feat: adds ability to manually specify measurement files
This commit is contained in:
parent
6d24f96b0f
commit
3fe55dbb67
@ -8,6 +8,7 @@ from kalpaa.config.config import (
|
||||
DeepdogConfig,
|
||||
Config,
|
||||
ReducedModelParams,
|
||||
OVERRIDE_MEASUREMENT_DIR_NAME,
|
||||
)
|
||||
from kalpaa.config.config_reader import (
|
||||
read_config_dict,
|
||||
@ -30,4 +31,5 @@ __all__ = [
|
||||
"serialize_config",
|
||||
"read_config",
|
||||
"read_general_config_dict",
|
||||
"OVERRIDE_MEASUREMENT_DIR_NAME",
|
||||
]
|
||||
|
@ -127,7 +127,9 @@ class GenerationConfig:
|
||||
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
||||
] = None
|
||||
|
||||
override_measurements: typing.Optional[typing.Sequence[str]] = None
|
||||
override_measurement_filesets: typing.Optional[
|
||||
typing.Mapping[str, typing.Sequence[str]]
|
||||
] = None
|
||||
|
||||
tantri_configs: typing.List[TantriConfig] = field(
|
||||
default_factory=lambda: [TantriConfig()]
|
||||
@ -170,6 +172,9 @@ class Config:
|
||||
def get_dots_json_path(self) -> pathlib.Path:
|
||||
return self.absify(self.general_config.dots_json_name)
|
||||
|
||||
def get_override_dir_path(self) -> pathlib.Path:
|
||||
return self.absify(OVERRIDE_MEASUREMENT_DIR_NAME)
|
||||
|
||||
def indexifier(self) -> deepdog.indexify.Indexifier:
|
||||
with self.absify(self.general_config.indexes_json_name).open(
|
||||
"r"
|
||||
|
@ -662,11 +662,13 @@ class BinnedData:
|
||||
|
||||
def stdev_cost_function_filter(
|
||||
self,
|
||||
dot_names: typing.Sequence[str],
|
||||
dot_names_or_pairs: typing.Union[
|
||||
typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
|
||||
],
|
||||
target_cost: float,
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
measurements = self.measurements(dot_names)
|
||||
measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs)
|
||||
cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
|
||||
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, target_cost
|
||||
|
@ -135,10 +135,22 @@ def main():
|
||||
_logger.info(f"Copying {file} to {root}")
|
||||
(root / file).write_text((pathlib.Path.cwd() / file).read_text())
|
||||
|
||||
# if config.general_config is not None:
|
||||
# _logger.info(
|
||||
# f"Overriding measurements with {config.general_config.override_measurements}"
|
||||
# )
|
||||
if config.generation_config.override_measurement_filesets is not None:
|
||||
_logger.info(
|
||||
f"Overriding measurements with {config.generation_config.override_measurement_filesets}"
|
||||
)
|
||||
override_directory = root / kalpaa.config.OVERRIDE_MEASUREMENT_DIR_NAME
|
||||
override_directory.mkdir(exist_ok=True, parents=True)
|
||||
for (
|
||||
key,
|
||||
files,
|
||||
) in config.generation_config.override_measurement_filesets.items():
|
||||
_logger.info(f"Copying for {key=}, {files} to {override_directory}")
|
||||
for file in files:
|
||||
fileset_dir = override_directory / key
|
||||
fileset_dir.mkdir(exist_ok=True, parents=True)
|
||||
_logger.info(f"Copying {file} to {override_directory}")
|
||||
(fileset_dir / file).write_text((pathlib.Path.cwd() / file).read_text())
|
||||
|
||||
overridden_config = dataclasses.replace(
|
||||
config,
|
||||
|
@ -200,7 +200,27 @@ class Stage01Runner:
|
||||
|
||||
def run(self):
|
||||
seed_index = 0
|
||||
if self.config.generation_config.override_dipole_configs is None:
|
||||
if self.config.generation_config.override_measurement_filesets is not None:
|
||||
for (
|
||||
override_name,
|
||||
) in self.config.generation_config.override_measurement_filesets.keys():
|
||||
# don't need to do anything with the files, just create the out dir
|
||||
out = self.config.get_out_dir_path()
|
||||
directory = out / f"{override_name}"
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
elif self.config.generation_config.override_dipole_configs is not None:
|
||||
_logger.debug(
|
||||
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
|
||||
)
|
||||
for (
|
||||
override_name,
|
||||
override_dipoles,
|
||||
) in self.config.generation_config.override_dipole_configs.items():
|
||||
self.generate_override_dipole(
|
||||
seed_index, override_name, override_dipoles
|
||||
)
|
||||
else:
|
||||
# should be by default
|
||||
_logger.debug("no override needed!")
|
||||
for count in self.config.generation_config.counts:
|
||||
@ -213,17 +233,6 @@ class Stage01Runner:
|
||||
seed_index, count, orientation, replica
|
||||
)
|
||||
seed_index += 1
|
||||
else:
|
||||
_logger.debug(
|
||||
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
|
||||
)
|
||||
for (
|
||||
override_name,
|
||||
override_dipoles,
|
||||
) in self.config.generation_config.override_dipole_configs.items():
|
||||
self.generate_override_dipole(
|
||||
seed_index, override_name, override_dipoles
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -90,7 +90,9 @@ class Stage02Runner:
|
||||
else:
|
||||
return ["dot1", "dot2", current_dot]
|
||||
|
||||
def run_in_subdir(self, subdir: pathlib.Path):
|
||||
def run_in_subdir(
|
||||
self, subdir: pathlib.Path, override_key: typing.Optional[str] = None
|
||||
):
|
||||
with kalpaa.common.new_cd(subdir):
|
||||
_logger.debug(f"Running inside {subdir=}")
|
||||
|
||||
@ -123,6 +125,7 @@ class Stage02Runner:
|
||||
dot_name: str,
|
||||
trial_name: str,
|
||||
seed_index: int,
|
||||
override_name: typing.Optional[str] = None,
|
||||
):
|
||||
# _logger.info(f"Got job index {job_index}")
|
||||
# NOTE This guy runs inside subdirs, obviously. In something like <kalpa>/out/z-10-2/dipoles
|
||||
@ -137,6 +140,29 @@ class Stage02Runner:
|
||||
f"Have {self.config.generation_config.tantri_configs} as our tantri_configs"
|
||||
)
|
||||
num_tantri_configs = len(self.config.generation_config.tantri_configs)
|
||||
|
||||
if override_name is not None:
|
||||
if self.config.generation_config.override_measurement_filesets is None:
|
||||
raise ValueError(
|
||||
"override_name provided but no override_measurement_filesets, shouldn't be possible to get here"
|
||||
)
|
||||
_logger.info(f"Time to read override measurement fileset {override_name}")
|
||||
override_dir = self.config.get_override_dir_path()
|
||||
override_measurements = (
|
||||
self.config.generation_config.override_measurement_filesets[
|
||||
override_name
|
||||
]
|
||||
)
|
||||
_logger.info(f"Finding files {override_measurements} in {override_dir}")
|
||||
binned_datas = [
|
||||
kalpaa.read_dots_and_binned(
|
||||
self.config.get_dots_json_path(),
|
||||
override_dir / measurement,
|
||||
)
|
||||
for measurement in override_measurements
|
||||
]
|
||||
else:
|
||||
|
||||
binned_datas = [
|
||||
kalpaa.read_dots_and_binned(
|
||||
self.config.get_dots_json_path(),
|
||||
@ -221,6 +247,16 @@ class Stage02Runner:
|
||||
_logger.info(results)
|
||||
|
||||
def run(self):
|
||||
if self.config.generation_config.override_measurement_filesets is not None:
|
||||
_logger.info("Using override configuration.")
|
||||
for (
|
||||
override_name
|
||||
) in self.config.generation_config.override_measurement_filesets.keys():
|
||||
subdir = self.config.get_out_dir_path() / override_name
|
||||
dipoles_dir = subdir / "dipoles"
|
||||
dipoles_dir.mkdir(exist_ok=True, parents=False)
|
||||
self.run_in_subdir(dipoles_dir, override_key=override_name)
|
||||
|
||||
"""Going to iterate over every folder in out_dir, and execute the subdir stuff inside dirs like <kalpa>/out/z-10-2/dipoles"""
|
||||
out_dir_path = self.config.get_out_dir_path()
|
||||
subdirs = [child for child in out_dir_path.iterdir() if child.is_dir]
|
||||
|
301
tests/config/__snapshots__/test_toml_as_dict_snaps.ambr
Normal file
301
tests/config/__snapshots__/test_toml_as_dict_snaps.ambr
Normal file
@ -0,0 +1,301 @@
|
||||
# serializer version: 1
|
||||
# name: test_parse_config_all_fields_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
20,
|
||||
2,
|
||||
0.2,
|
||||
]),
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 2000,
|
||||
'use_log_noise': True,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1,
|
||||
'w_log_min': -5,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 6.5,
|
||||
'z_min': 5,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'dots_json_name': 'test_dots.json',
|
||||
'indexes_json_name': 'test_indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
|
||||
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
|
||||
'mega_merged_name': 'test_mega_merged.csv',
|
||||
'out_dir_name': 'out',
|
||||
'root_directory': PosixPath('test_root'),
|
||||
'skip_to_stage': 1,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 3,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': dict({
|
||||
'scenario1': list([
|
||||
dict({
|
||||
'p': array([3, 5, 7]),
|
||||
's': array([2, 4, 6]),
|
||||
'w': 10,
|
||||
}),
|
||||
dict({
|
||||
'p': array([30, 50, 70]),
|
||||
's': array([20, 40, 60]),
|
||||
'w': 10.55,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.01,
|
||||
'index_seed_starter': 15151,
|
||||
'num_iterations': 100,
|
||||
'num_seeds': 5,
|
||||
}),
|
||||
dict({
|
||||
'delta_t': 1,
|
||||
'index_seed_starter': 1234,
|
||||
'num_iterations': 200,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_parse_config_few_fields_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
5,
|
||||
2,
|
||||
1,
|
||||
0.5,
|
||||
0.2,
|
||||
]),
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 2000,
|
||||
'use_log_noise': True,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1,
|
||||
'w_log_min': -5,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 6.5,
|
||||
'z_min': 5,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'dots_json_name': 'dots.json',
|
||||
'indexes_json_name': 'indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
|
||||
'mega_merged_name': 'mega_merged_coalesced.csv',
|
||||
'out_dir_name': 'out',
|
||||
'root_directory': PosixPath('test_root1'),
|
||||
'skip_to_stage': None,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 2,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': dict({
|
||||
'scenario1': list([
|
||||
dict({
|
||||
'p': array([3, 5, 7]),
|
||||
's': array([2, 4, 6]),
|
||||
'w': 10,
|
||||
}),
|
||||
dict({
|
||||
'p': array([30, 50, 70]),
|
||||
's': array([20, 40, 60]),
|
||||
'w': 10.55,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.01,
|
||||
'index_seed_starter': 15151,
|
||||
'num_iterations': 100,
|
||||
'num_seeds': 5,
|
||||
}),
|
||||
dict({
|
||||
'delta_t': 1,
|
||||
'index_seed_starter': 1234,
|
||||
'num_iterations': 200,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_parse_config_geom_params_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
5,
|
||||
2,
|
||||
1,
|
||||
0.5,
|
||||
0.2,
|
||||
]),
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 2000,
|
||||
'use_log_noise': True,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1.5,
|
||||
'w_log_min': -3,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 2,
|
||||
'z_min': 0,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'dots_json_name': 'dots.json',
|
||||
'indexes_json_name': 'indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
|
||||
'mega_merged_name': 'mega_merged_coalesced.csv',
|
||||
'out_dir_name': 'out',
|
||||
'root_directory': PosixPath('test_root1'),
|
||||
'skip_to_stage': None,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 2,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': dict({
|
||||
'scenario1': list([
|
||||
dict({
|
||||
'p': array([3, 5, 7]),
|
||||
's': array([2, 4, 6]),
|
||||
'w': 10,
|
||||
}),
|
||||
dict({
|
||||
'p': array([30, 50, 70]),
|
||||
's': array([20, 40, 60]),
|
||||
'w': 10.55,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.01,
|
||||
'index_seed_starter': 15151,
|
||||
'num_iterations': 100,
|
||||
'num_seeds': 5,
|
||||
}),
|
||||
dict({
|
||||
'delta_t': 1,
|
||||
'index_seed_starter': 1234,
|
||||
'num_iterations': 200,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_parse_config_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
10,
|
||||
1,
|
||||
0.1,
|
||||
]),
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 1000,
|
||||
'use_log_noise': False,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1,
|
||||
'w_log_min': -5,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 6.5,
|
||||
'z_min': 5,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'dots_json_name': 'test_dots.json',
|
||||
'indexes_json_name': 'test_indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
|
||||
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
|
||||
'mega_merged_name': 'test_mega_merged.csv',
|
||||
'out_dir_name': 'test_out',
|
||||
'root_directory': PosixPath('test_root'),
|
||||
'skip_to_stage': 1,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 3,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': None,
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.05,
|
||||
'index_seed_starter': 31415,
|
||||
'num_iterations': 100000,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
@ -1,13 +1,13 @@
|
||||
# serializer version: 1
|
||||
# name: test_parse_config_all_fields_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
# ---
|
||||
# name: test_parse_config_few_fields_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
# ---
|
||||
# name: test_parse_config_geom_params_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5))
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5))
|
||||
# ---
|
||||
# name: test_parse_config_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
# ---
|
||||
|
35
tests/config/test_toml_as_dict_snaps.py
Normal file
35
tests/config/test_toml_as_dict_snaps.py
Normal file
@ -0,0 +1,35 @@
|
||||
import pathlib
|
||||
from dataclasses import asdict
|
||||
|
||||
import kalpaa.config.config_reader
|
||||
|
||||
|
||||
TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files"
|
||||
|
||||
|
||||
def test_parse_config_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
||||
|
||||
|
||||
def test_parse_config_all_fields_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config_all_fields.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
||||
|
||||
|
||||
def test_parse_config_few_fields_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config_few_fields.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
||||
|
||||
|
||||
def test_parse_config_geom_params_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config_geom_params.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
Loading…
x
Reference in New Issue
Block a user