feat: adds ability to manually specify measurement files

This commit is contained in:
Deepak Mallubhotla 2025-02-23 22:49:23 -06:00
parent 6d24f96b0f
commit 3fe55dbb67
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
9 changed files with 433 additions and 31 deletions

View File

@ -8,6 +8,7 @@ from kalpaa.config.config import (
DeepdogConfig,
Config,
ReducedModelParams,
OVERRIDE_MEASUREMENT_DIR_NAME,
)
from kalpaa.config.config_reader import (
read_config_dict,
@ -30,4 +31,5 @@ __all__ = [
"serialize_config",
"read_config",
"read_general_config_dict",
"OVERRIDE_MEASUREMENT_DIR_NAME",
]

View File

@ -127,7 +127,9 @@ class GenerationConfig:
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
] = None
override_measurements: typing.Optional[typing.Sequence[str]] = None
override_measurement_filesets: typing.Optional[
typing.Mapping[str, typing.Sequence[str]]
] = None
tantri_configs: typing.List[TantriConfig] = field(
default_factory=lambda: [TantriConfig()]
@ -170,6 +172,9 @@ class Config:
def get_dots_json_path(self) -> pathlib.Path:
return self.absify(self.general_config.dots_json_name)
def get_override_dir_path(self) -> pathlib.Path:
return self.absify(OVERRIDE_MEASUREMENT_DIR_NAME)
def indexifier(self) -> deepdog.indexify.Indexifier:
with self.absify(self.general_config.indexes_json_name).open(
"r"

View File

@ -662,11 +662,13 @@ class BinnedData:
def stdev_cost_function_filter(
self,
dot_names: typing.Sequence[str],
dot_names_or_pairs: typing.Union[
typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
],
target_cost: float,
use_log_noise: bool = False,
):
measurements = self.measurements(dot_names)
measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs)
cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
cost_function, target_cost

View File

@ -135,10 +135,22 @@ def main():
_logger.info(f"Copying {file} to {root}")
(root / file).write_text((pathlib.Path.cwd() / file).read_text())
# if config.general_config is not None:
# _logger.info(
# f"Overriding measurements with {config.general_config.override_measurements}"
# )
if config.generation_config.override_measurement_filesets is not None:
_logger.info(
f"Overriding measurements with {config.generation_config.override_measurement_filesets}"
)
override_directory = root / kalpaa.config.OVERRIDE_MEASUREMENT_DIR_NAME
override_directory.mkdir(exist_ok=True, parents=True)
for (
key,
files,
) in config.generation_config.override_measurement_filesets.items():
_logger.info(f"Copying for {key=}, {files} to {override_directory}")
for file in files:
fileset_dir = override_directory / key
fileset_dir.mkdir(exist_ok=True, parents=True)
_logger.info(f"Copying {file} to {override_directory}")
(fileset_dir / file).write_text((pathlib.Path.cwd() / file).read_text())
overridden_config = dataclasses.replace(
config,

View File

@ -200,7 +200,27 @@ class Stage01Runner:
def run(self):
seed_index = 0
if self.config.generation_config.override_dipole_configs is None:
if self.config.generation_config.override_measurement_filesets is not None:
for (
override_name,
) in self.config.generation_config.override_measurement_filesets.keys():
# don't need to do anything with the files, just create the out dir
out = self.config.get_out_dir_path()
directory = out / f"{override_name}"
directory.mkdir(parents=True, exist_ok=True)
elif self.config.generation_config.override_dipole_configs is not None:
_logger.debug(
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
)
for (
override_name,
override_dipoles,
) in self.config.generation_config.override_dipole_configs.items():
self.generate_override_dipole(
seed_index, override_name, override_dipoles
)
else:
# should be by default
_logger.debug("no override needed!")
for count in self.config.generation_config.counts:
@ -213,17 +233,6 @@ class Stage01Runner:
seed_index, count, orientation, replica
)
seed_index += 1
else:
_logger.debug(
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
)
for (
override_name,
override_dipoles,
) in self.config.generation_config.override_dipole_configs.items():
self.generate_override_dipole(
seed_index, override_name, override_dipoles
)
def parse_args():

View File

@ -90,7 +90,9 @@ class Stage02Runner:
else:
return ["dot1", "dot2", current_dot]
def run_in_subdir(self, subdir: pathlib.Path):
def run_in_subdir(
self, subdir: pathlib.Path, override_key: typing.Optional[str] = None
):
with kalpaa.common.new_cd(subdir):
_logger.debug(f"Running inside {subdir=}")
@ -123,6 +125,7 @@ class Stage02Runner:
dot_name: str,
trial_name: str,
seed_index: int,
override_name: typing.Optional[str] = None,
):
# _logger.info(f"Got job index {job_index}")
# NOTE This guy runs inside subdirs, obviously. In something like <kalpa>/out/z-10-2/dipoles
@ -137,6 +140,29 @@ class Stage02Runner:
f"Have {self.config.generation_config.tantri_configs} as our tantri_configs"
)
num_tantri_configs = len(self.config.generation_config.tantri_configs)
if override_name is not None:
if self.config.generation_config.override_measurement_filesets is None:
raise ValueError(
"override_name provided but no override_measurement_filesets, shouldn't be possible to get here"
)
_logger.info(f"Time to read override measurement fileset {override_name}")
override_dir = self.config.get_override_dir_path()
override_measurements = (
self.config.generation_config.override_measurement_filesets[
override_name
]
)
_logger.info(f"Finding files {override_measurements} in {override_dir}")
binned_datas = [
kalpaa.read_dots_and_binned(
self.config.get_dots_json_path(),
override_dir / measurement,
)
for measurement in override_measurements
]
else:
binned_datas = [
kalpaa.read_dots_and_binned(
self.config.get_dots_json_path(),
@ -221,6 +247,16 @@ class Stage02Runner:
_logger.info(results)
def run(self):
if self.config.generation_config.override_measurement_filesets is not None:
_logger.info("Using override configuration.")
for (
override_name
) in self.config.generation_config.override_measurement_filesets.keys():
subdir = self.config.get_out_dir_path() / override_name
dipoles_dir = subdir / "dipoles"
dipoles_dir.mkdir(exist_ok=True, parents=False)
self.run_in_subdir(dipoles_dir, override_key=override_name)
"""Going to iterate over every folder in out_dir, and execute the subdir stuff inside dirs like <kalpa>/out/z-10-2/dipoles"""
out_dir_path = self.config.get_out_dir_path()
subdirs = [child for child in out_dir_path.iterdir() if child.is_dir]

View File

@ -0,0 +1,301 @@
# serializer version: 1
# name: test_parse_config_all_fields_toml_as_dict
dict({
'deepdog_config': dict({
'costs_to_try': list([
20,
2,
0.2,
]),
'max_monte_carlo_cycles_steps': 20,
'target_success': 2000,
'use_log_noise': True,
}),
'default_model_param_config': dict({
'w_log_max': 1,
'w_log_min': -5,
'x_max': 20,
'x_min': -20,
'y_max': 10,
'y_min': -10,
'z_max': 6.5,
'z_min': 5,
}),
'general_config': dict({
'dots_json_name': 'test_dots.json',
'indexes_json_name': 'test_indexes.json',
'log_pattern': '%(asctime)s | %(message)s',
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
'mega_merged_name': 'test_mega_merged.csv',
'out_dir_name': 'out',
'root_directory': PosixPath('test_root'),
'skip_to_stage': 1,
}),
'generation_config': dict({
'bin_log_width': 0.25,
'counts': list([
1,
5,
10,
]),
'num_bin_time_series': 25,
'num_replicas': 3,
'orientations': list([
<Orientation.RANDOM: 'RANDOM'>,
<Orientation.Z: 'Z'>,
<Orientation.XY: 'XY'>,
]),
'override_dipole_configs': dict({
'scenario1': list([
dict({
'p': array([3, 5, 7]),
's': array([2, 4, 6]),
'w': 10,
}),
dict({
'p': array([30, 50, 70]),
's': array([20, 40, 60]),
'w': 10.55,
}),
]),
}),
'override_measurement_filesets': None,
'tantri_configs': list([
dict({
'delta_t': 0.01,
'index_seed_starter': 15151,
'num_iterations': 100,
'num_seeds': 5,
}),
dict({
'delta_t': 1,
'index_seed_starter': 1234,
'num_iterations': 200,
'num_seeds': 100,
}),
]),
}),
})
# ---
# name: test_parse_config_few_fields_toml_as_dict
dict({
'deepdog_config': dict({
'costs_to_try': list([
5,
2,
1,
0.5,
0.2,
]),
'max_monte_carlo_cycles_steps': 20,
'target_success': 2000,
'use_log_noise': True,
}),
'default_model_param_config': dict({
'w_log_max': 1,
'w_log_min': -5,
'x_max': 20,
'x_min': -20,
'y_max': 10,
'y_min': -10,
'z_max': 6.5,
'z_min': 5,
}),
'general_config': dict({
'dots_json_name': 'dots.json',
'indexes_json_name': 'indexes.json',
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
'mega_merged_name': 'mega_merged_coalesced.csv',
'out_dir_name': 'out',
'root_directory': PosixPath('test_root1'),
'skip_to_stage': None,
}),
'generation_config': dict({
'bin_log_width': 0.25,
'counts': list([
1,
5,
10,
]),
'num_bin_time_series': 25,
'num_replicas': 2,
'orientations': list([
<Orientation.RANDOM: 'RANDOM'>,
<Orientation.Z: 'Z'>,
<Orientation.XY: 'XY'>,
]),
'override_dipole_configs': dict({
'scenario1': list([
dict({
'p': array([3, 5, 7]),
's': array([2, 4, 6]),
'w': 10,
}),
dict({
'p': array([30, 50, 70]),
's': array([20, 40, 60]),
'w': 10.55,
}),
]),
}),
'override_measurement_filesets': None,
'tantri_configs': list([
dict({
'delta_t': 0.01,
'index_seed_starter': 15151,
'num_iterations': 100,
'num_seeds': 5,
}),
dict({
'delta_t': 1,
'index_seed_starter': 1234,
'num_iterations': 200,
'num_seeds': 100,
}),
]),
}),
})
# ---
# name: test_parse_config_geom_params_toml_as_dict
dict({
'deepdog_config': dict({
'costs_to_try': list([
5,
2,
1,
0.5,
0.2,
]),
'max_monte_carlo_cycles_steps': 20,
'target_success': 2000,
'use_log_noise': True,
}),
'default_model_param_config': dict({
'w_log_max': 1.5,
'w_log_min': -3,
'x_max': 20,
'x_min': -20,
'y_max': 10,
'y_min': -10,
'z_max': 2,
'z_min': 0,
}),
'general_config': dict({
'dots_json_name': 'dots.json',
'indexes_json_name': 'indexes.json',
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
'mega_merged_name': 'mega_merged_coalesced.csv',
'out_dir_name': 'out',
'root_directory': PosixPath('test_root1'),
'skip_to_stage': None,
}),
'generation_config': dict({
'bin_log_width': 0.25,
'counts': list([
1,
5,
10,
]),
'num_bin_time_series': 25,
'num_replicas': 2,
'orientations': list([
<Orientation.RANDOM: 'RANDOM'>,
<Orientation.Z: 'Z'>,
<Orientation.XY: 'XY'>,
]),
'override_dipole_configs': dict({
'scenario1': list([
dict({
'p': array([3, 5, 7]),
's': array([2, 4, 6]),
'w': 10,
}),
dict({
'p': array([30, 50, 70]),
's': array([20, 40, 60]),
'w': 10.55,
}),
]),
}),
'override_measurement_filesets': None,
'tantri_configs': list([
dict({
'delta_t': 0.01,
'index_seed_starter': 15151,
'num_iterations': 100,
'num_seeds': 5,
}),
dict({
'delta_t': 1,
'index_seed_starter': 1234,
'num_iterations': 200,
'num_seeds': 100,
}),
]),
}),
})
# ---
# name: test_parse_config_toml_as_dict
dict({
'deepdog_config': dict({
'costs_to_try': list([
10,
1,
0.1,
]),
'max_monte_carlo_cycles_steps': 20,
'target_success': 1000,
'use_log_noise': False,
}),
'default_model_param_config': dict({
'w_log_max': 1,
'w_log_min': -5,
'x_max': 20,
'x_min': -20,
'y_max': 10,
'y_min': -10,
'z_max': 6.5,
'z_min': 5,
}),
'general_config': dict({
'dots_json_name': 'test_dots.json',
'indexes_json_name': 'test_indexes.json',
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
'mega_merged_name': 'test_mega_merged.csv',
'out_dir_name': 'test_out',
'root_directory': PosixPath('test_root'),
'skip_to_stage': 1,
}),
'generation_config': dict({
'bin_log_width': 0.25,
'counts': list([
1,
10,
]),
'num_bin_time_series': 25,
'num_replicas': 3,
'orientations': list([
<Orientation.RANDOM: 'RANDOM'>,
<Orientation.Z: 'Z'>,
<Orientation.XY: 'XY'>,
]),
'override_dipole_configs': None,
'override_measurement_filesets': None,
'tantri_configs': list([
dict({
'delta_t': 0.05,
'index_seed_starter': 31415,
'num_iterations': 100000,
'num_seeds': 100,
}),
]),
}),
})
# ---

View File

@ -1,13 +1,13 @@
# serializer version: 1
# name: test_parse_config_all_fields_toml
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
# ---
# name: test_parse_config_few_fields_toml
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
# ---
# name: test_parse_config_geom_params_toml
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5))
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5))
# ---
# name: test_parse_config_toml
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
# ---

View File

@ -0,0 +1,35 @@
import pathlib
from dataclasses import asdict
import kalpaa.config.config_reader
TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files"
def test_parse_config_toml_as_dict(snapshot):
test_config_file = TEST_DATA_DIR / "test_config.toml"
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
assert asdict(actual_config) == snapshot
def test_parse_config_all_fields_toml_as_dict(snapshot):
test_config_file = TEST_DATA_DIR / "test_config_all_fields.toml"
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
assert asdict(actual_config) == snapshot
def test_parse_config_few_fields_toml_as_dict(snapshot):
test_config_file = TEST_DATA_DIR / "test_config_few_fields.toml"
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
assert asdict(actual_config) == snapshot
def test_parse_config_geom_params_toml_as_dict(snapshot):
test_config_file = TEST_DATA_DIR / "test_config_geom_params.toml"
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
assert asdict(actual_config) == snapshot