diff --git a/kalpaa/config/__init__.py b/kalpaa/config/__init__.py index 073e91d..f3d6744 100644 --- a/kalpaa/config/__init__.py +++ b/kalpaa/config/__init__.py @@ -8,6 +8,7 @@ from kalpaa.config.config import ( DeepdogConfig, Config, ReducedModelParams, + OVERRIDE_MEASUREMENT_DIR_NAME, ) from kalpaa.config.config_reader import ( read_config_dict, @@ -30,4 +31,5 @@ __all__ = [ "serialize_config", "read_config", "read_general_config_dict", + "OVERRIDE_MEASUREMENT_DIR_NAME", ] diff --git a/kalpaa/config/config.py b/kalpaa/config/config.py index 58c9b5a..07cf5b8 100644 --- a/kalpaa/config/config.py +++ b/kalpaa/config/config.py @@ -127,7 +127,9 @@ class GenerationConfig: typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]] ] = None - override_measurements: typing.Optional[typing.Sequence[str]] = None + override_measurement_filesets: typing.Optional[ + typing.Mapping[str, typing.Sequence[str]] + ] = None tantri_configs: typing.List[TantriConfig] = field( default_factory=lambda: [TantriConfig()] @@ -170,6 +172,9 @@ class Config: def get_dots_json_path(self) -> pathlib.Path: return self.absify(self.general_config.dots_json_name) + def get_override_dir_path(self) -> pathlib.Path: + return self.absify(OVERRIDE_MEASUREMENT_DIR_NAME) + def indexifier(self) -> deepdog.indexify.Indexifier: with self.absify(self.general_config.indexes_json_name).open( "r" diff --git a/kalpaa/read_bin_csv.py b/kalpaa/read_bin_csv.py index bcb2f84..00c75fa 100644 --- a/kalpaa/read_bin_csv.py +++ b/kalpaa/read_bin_csv.py @@ -662,11 +662,13 @@ class BinnedData: def stdev_cost_function_filter( self, - dot_names: typing.Sequence[str], + dot_names_or_pairs: typing.Union[ + typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]] + ], target_cost: float, use_log_noise: bool = False, ): - measurements = self.measurements(dot_names) + measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs) cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise) return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter( cost_function, target_cost diff --git a/kalpaa/stages/__init__.py b/kalpaa/stages/__init__.py index 0c4ee6f..e87e017 100644 --- a/kalpaa/stages/__init__.py +++ b/kalpaa/stages/__init__.py @@ -135,10 +135,22 @@ def main(): _logger.info(f"Copying {file} to {root}") (root / file).write_text((pathlib.Path.cwd() / file).read_text()) - # if config.general_config is not None: - # _logger.info( - # f"Overriding measurements with {config.general_config.override_measurements}" - # ) + if config.generation_config.override_measurement_filesets is not None: + _logger.info( + f"Overriding measurements with {config.generation_config.override_measurement_filesets}" + ) + override_directory = root / kalpaa.config.OVERRIDE_MEASUREMENT_DIR_NAME + override_directory.mkdir(exist_ok=True, parents=True) + for ( + key, + files, + ) in config.generation_config.override_measurement_filesets.items(): + _logger.info(f"Copying for {key=}, {files} to {override_directory}") + for file in files: + fileset_dir = override_directory / key + fileset_dir.mkdir(exist_ok=True, parents=True) + _logger.info(f"Copying {file} to {override_directory}") + (fileset_dir / file).write_text((pathlib.Path.cwd() / file).read_text()) overridden_config = dataclasses.replace( config, diff --git a/kalpaa/stages/stage01.py b/kalpaa/stages/stage01.py index 6bcd983..c68aa6b 100644 --- a/kalpaa/stages/stage01.py +++ b/kalpaa/stages/stage01.py @@ -200,7 +200,27 @@ class Stage01Runner: def run(self): seed_index = 0 - if self.config.generation_config.override_dipole_configs is None: + if self.config.generation_config.override_measurement_filesets is not None: + for ( + override_name, + ) in self.config.generation_config.override_measurement_filesets.keys(): + # don't need to do anything with the files, just create the out dir + out = self.config.get_out_dir_path() + directory = out / f"{override_name}" + directory.mkdir(parents=True, exist_ok=True) + + elif self.config.generation_config.override_dipole_configs is not None: + _logger.debug( + f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}" + ) + for ( + override_name, + override_dipoles, + ) in self.config.generation_config.override_dipole_configs.items(): + self.generate_override_dipole( + seed_index, override_name, override_dipoles + ) + else: # should be by default _logger.debug("no override needed!") for count in self.config.generation_config.counts: @@ -213,17 +233,6 @@ class Stage01Runner: seed_index, count, orientation, replica ) seed_index += 1 - else: - _logger.debug( - f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}" - ) - for ( - override_name, - override_dipoles, - ) in self.config.generation_config.override_dipole_configs.items(): - self.generate_override_dipole( - seed_index, override_name, override_dipoles - ) def parse_args(): diff --git a/kalpaa/stages/stage02.py b/kalpaa/stages/stage02.py index b3ff259..8ce2dc0 100644 --- a/kalpaa/stages/stage02.py +++ b/kalpaa/stages/stage02.py @@ -90,7 +90,9 @@ class Stage02Runner: else: return ["dot1", "dot2", current_dot] - def run_in_subdir(self, subdir: pathlib.Path): + def run_in_subdir( + self, subdir: pathlib.Path, override_key: typing.Optional[str] = None + ): with kalpaa.common.new_cd(subdir): _logger.debug(f"Running inside {subdir=}") @@ -123,6 +125,7 @@ class Stage02Runner: dot_name: str, trial_name: str, seed_index: int, + override_name: typing.Optional[str] = None, ): # _logger.info(f"Got job index {job_index}") # NOTE This guy runs inside subdirs, obviously. In something like /out/z-10-2/dipoles @@ -137,14 +140,37 @@ class Stage02Runner: f"Have {self.config.generation_config.tantri_configs} as our tantri_configs" ) num_tantri_configs = len(self.config.generation_config.tantri_configs) - binned_datas = [ - kalpaa.read_dots_and_binned( - self.config.get_dots_json_path(), - pathlib.Path("..") - / kalpaa.common.tantri_binned_output_name(tantri_index), + + if override_name is not None: + if self.config.generation_config.override_measurement_filesets is None: + raise ValueError( + "override_name provided but no override_measurement_filesets, shouldn't be possible to get here" + ) + _logger.info(f"Time to read override measurement fileset {override_name}") + override_dir = self.config.get_override_dir_path() + override_measurements = ( + self.config.generation_config.override_measurement_filesets[ + override_name + ] ) - for tantri_index in range(num_tantri_configs) - ] + _logger.info(f"Finding files {override_measurements} in {override_dir}") + binned_datas = [ + kalpaa.read_dots_and_binned( + self.config.get_dots_json_path(), + override_dir / measurement, + ) + for measurement in override_measurements + ] + else: + + binned_datas = [ + kalpaa.read_dots_and_binned( + self.config.get_dots_json_path(), + pathlib.Path("..") + / kalpaa.common.tantri_binned_output_name(tantri_index), + ) + for tantri_index in range(num_tantri_configs) + ] dot_names = self._dots_to_include(dot_name) _logger.debug(f"Got dot names {dot_names}") @@ -221,6 +247,16 @@ class Stage02Runner: _logger.info(results) def run(self): + if self.config.generation_config.override_measurement_filesets is not None: + _logger.info("Using override configuration.") + for ( + override_name + ) in self.config.generation_config.override_measurement_filesets.keys(): + subdir = self.config.get_out_dir_path() / override_name + dipoles_dir = subdir / "dipoles" + dipoles_dir.mkdir(exist_ok=True, parents=False) + self.run_in_subdir(dipoles_dir, override_key=override_name) + """Going to iterate over every folder in out_dir, and execute the subdir stuff inside dirs like /out/z-10-2/dipoles""" out_dir_path = self.config.get_out_dir_path() subdirs = [child for child in out_dir_path.iterdir() if child.is_dir] diff --git a/tests/config/__snapshots__/test_toml_as_dict_snaps.ambr b/tests/config/__snapshots__/test_toml_as_dict_snaps.ambr new file mode 100644 index 0000000..56c79c2 --- /dev/null +++ b/tests/config/__snapshots__/test_toml_as_dict_snaps.ambr @@ -0,0 +1,301 @@ +# serializer version: 1 +# name: test_parse_config_all_fields_toml_as_dict + dict({ + 'deepdog_config': dict({ + 'costs_to_try': list([ + 20, + 2, + 0.2, + ]), + 'max_monte_carlo_cycles_steps': 20, + 'target_success': 2000, + 'use_log_noise': True, + }), + 'default_model_param_config': dict({ + 'w_log_max': 1, + 'w_log_min': -5, + 'x_max': 20, + 'x_min': -20, + 'y_max': 10, + 'y_min': -10, + 'z_max': 6.5, + 'z_min': 5, + }), + 'general_config': dict({ + 'dots_json_name': 'test_dots.json', + 'indexes_json_name': 'test_indexes.json', + 'log_pattern': '%(asctime)s | %(message)s', + 'measurement_type': , + 'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv', + 'mega_merged_name': 'test_mega_merged.csv', + 'out_dir_name': 'out', + 'root_directory': PosixPath('test_root'), + 'skip_to_stage': 1, + }), + 'generation_config': dict({ + 'bin_log_width': 0.25, + 'counts': list([ + 1, + 5, + 10, + ]), + 'num_bin_time_series': 25, + 'num_replicas': 3, + 'orientations': list([ + , + , + , + ]), + 'override_dipole_configs': dict({ + 'scenario1': list([ + dict({ + 'p': array([3, 5, 7]), + 's': array([2, 4, 6]), + 'w': 10, + }), + dict({ + 'p': array([30, 50, 70]), + 's': array([20, 40, 60]), + 'w': 10.55, + }), + ]), + }), + 'override_measurement_filesets': None, + 'tantri_configs': list([ + dict({ + 'delta_t': 0.01, + 'index_seed_starter': 15151, + 'num_iterations': 100, + 'num_seeds': 5, + }), + dict({ + 'delta_t': 1, + 'index_seed_starter': 1234, + 'num_iterations': 200, + 'num_seeds': 100, + }), + ]), + }), + }) +# --- +# name: test_parse_config_few_fields_toml_as_dict + dict({ + 'deepdog_config': dict({ + 'costs_to_try': list([ + 5, + 2, + 1, + 0.5, + 0.2, + ]), + 'max_monte_carlo_cycles_steps': 20, + 'target_success': 2000, + 'use_log_noise': True, + }), + 'default_model_param_config': dict({ + 'w_log_max': 1, + 'w_log_min': -5, + 'x_max': 20, + 'x_min': -20, + 'y_max': 10, + 'y_min': -10, + 'z_max': 6.5, + 'z_min': 5, + }), + 'general_config': dict({ + 'dots_json_name': 'dots.json', + 'indexes_json_name': 'indexes.json', + 'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', + 'measurement_type': , + 'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv', + 'mega_merged_name': 'mega_merged_coalesced.csv', + 'out_dir_name': 'out', + 'root_directory': PosixPath('test_root1'), + 'skip_to_stage': None, + }), + 'generation_config': dict({ + 'bin_log_width': 0.25, + 'counts': list([ + 1, + 5, + 10, + ]), + 'num_bin_time_series': 25, + 'num_replicas': 2, + 'orientations': list([ + , + , + , + ]), + 'override_dipole_configs': dict({ + 'scenario1': list([ + dict({ + 'p': array([3, 5, 7]), + 's': array([2, 4, 6]), + 'w': 10, + }), + dict({ + 'p': array([30, 50, 70]), + 's': array([20, 40, 60]), + 'w': 10.55, + }), + ]), + }), + 'override_measurement_filesets': None, + 'tantri_configs': list([ + dict({ + 'delta_t': 0.01, + 'index_seed_starter': 15151, + 'num_iterations': 100, + 'num_seeds': 5, + }), + dict({ + 'delta_t': 1, + 'index_seed_starter': 1234, + 'num_iterations': 200, + 'num_seeds': 100, + }), + ]), + }), + }) +# --- +# name: test_parse_config_geom_params_toml_as_dict + dict({ + 'deepdog_config': dict({ + 'costs_to_try': list([ + 5, + 2, + 1, + 0.5, + 0.2, + ]), + 'max_monte_carlo_cycles_steps': 20, + 'target_success': 2000, + 'use_log_noise': True, + }), + 'default_model_param_config': dict({ + 'w_log_max': 1.5, + 'w_log_min': -3, + 'x_max': 20, + 'x_min': -20, + 'y_max': 10, + 'y_min': -10, + 'z_max': 2, + 'z_min': 0, + }), + 'general_config': dict({ + 'dots_json_name': 'dots.json', + 'indexes_json_name': 'indexes.json', + 'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', + 'measurement_type': , + 'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv', + 'mega_merged_name': 'mega_merged_coalesced.csv', + 'out_dir_name': 'out', + 'root_directory': PosixPath('test_root1'), + 'skip_to_stage': None, + }), + 'generation_config': dict({ + 'bin_log_width': 0.25, + 'counts': list([ + 1, + 5, + 10, + ]), + 'num_bin_time_series': 25, + 'num_replicas': 2, + 'orientations': list([ + , + , + , + ]), + 'override_dipole_configs': dict({ + 'scenario1': list([ + dict({ + 'p': array([3, 5, 7]), + 's': array([2, 4, 6]), + 'w': 10, + }), + dict({ + 'p': array([30, 50, 70]), + 's': array([20, 40, 60]), + 'w': 10.55, + }), + ]), + }), + 'override_measurement_filesets': None, + 'tantri_configs': list([ + dict({ + 'delta_t': 0.01, + 'index_seed_starter': 15151, + 'num_iterations': 100, + 'num_seeds': 5, + }), + dict({ + 'delta_t': 1, + 'index_seed_starter': 1234, + 'num_iterations': 200, + 'num_seeds': 100, + }), + ]), + }), + }) +# --- +# name: test_parse_config_toml_as_dict + dict({ + 'deepdog_config': dict({ + 'costs_to_try': list([ + 10, + 1, + 0.1, + ]), + 'max_monte_carlo_cycles_steps': 20, + 'target_success': 1000, + 'use_log_noise': False, + }), + 'default_model_param_config': dict({ + 'w_log_max': 1, + 'w_log_min': -5, + 'x_max': 20, + 'x_min': -20, + 'y_max': 10, + 'y_min': -10, + 'z_max': 6.5, + 'z_min': 5, + }), + 'general_config': dict({ + 'dots_json_name': 'test_dots.json', + 'indexes_json_name': 'test_indexes.json', + 'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', + 'measurement_type': , + 'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv', + 'mega_merged_name': 'test_mega_merged.csv', + 'out_dir_name': 'test_out', + 'root_directory': PosixPath('test_root'), + 'skip_to_stage': 1, + }), + 'generation_config': dict({ + 'bin_log_width': 0.25, + 'counts': list([ + 1, + 10, + ]), + 'num_bin_time_series': 25, + 'num_replicas': 3, + 'orientations': list([ + , + , + , + ]), + 'override_dipole_configs': None, + 'override_measurement_filesets': None, + 'tantri_configs': list([ + dict({ + 'delta_t': 0.05, + 'index_seed_starter': 31415, + 'num_iterations': 100000, + 'num_seeds': 100, + }), + ]), + }), + }) +# --- diff --git a/tests/config/__snapshots__/test_toml_reader.ambr b/tests/config/__snapshots__/test_toml_reader.ambr index b443583..8678458 100644 --- a/tests/config/__snapshots__/test_toml_reader.ambr +++ b/tests/config/__snapshots__/test_toml_reader.ambr @@ -1,13 +1,13 @@ # serializer version: 1 # name: test_parse_config_all_fields_toml - Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[, , ], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1)) + Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[, , ], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1)) # --- # name: test_parse_config_few_fields_toml - Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[, , ], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1)) + Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[, , ], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1)) # --- # name: test_parse_config_geom_params_toml - Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[, , ], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5)) + Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[, , ], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5)) # --- # name: test_parse_config_toml - Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[, , ], num_replicas=3, override_dipole_configs=None, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1)) + Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[, , ], num_replicas=3, override_dipole_configs=None, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1)) # --- diff --git a/tests/config/test_toml_as_dict_snaps.py b/tests/config/test_toml_as_dict_snaps.py new file mode 100644 index 0000000..a060e91 --- /dev/null +++ b/tests/config/test_toml_as_dict_snaps.py @@ -0,0 +1,35 @@ +import pathlib +from dataclasses import asdict + +import kalpaa.config.config_reader + + +TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files" + + +def test_parse_config_toml_as_dict(snapshot): + test_config_file = TEST_DATA_DIR / "test_config.toml" + actual_config = kalpaa.config.config_reader.read_config(test_config_file) + + assert asdict(actual_config) == snapshot + + +def test_parse_config_all_fields_toml_as_dict(snapshot): + test_config_file = TEST_DATA_DIR / "test_config_all_fields.toml" + actual_config = kalpaa.config.config_reader.read_config(test_config_file) + + assert asdict(actual_config) == snapshot + + +def test_parse_config_few_fields_toml_as_dict(snapshot): + test_config_file = TEST_DATA_DIR / "test_config_few_fields.toml" + actual_config = kalpaa.config.config_reader.read_config(test_config_file) + + assert asdict(actual_config) == snapshot + + +def test_parse_config_geom_params_toml_as_dict(snapshot): + test_config_file = TEST_DATA_DIR / "test_config_geom_params.toml" + actual_config = kalpaa.config.config_reader.read_config(test_config_file) + + assert asdict(actual_config) == snapshot