feat: adds ability to manually specify measurement files
This commit is contained in:
parent
6d24f96b0f
commit
3fe55dbb67
@ -8,6 +8,7 @@ from kalpaa.config.config import (
|
|||||||
DeepdogConfig,
|
DeepdogConfig,
|
||||||
Config,
|
Config,
|
||||||
ReducedModelParams,
|
ReducedModelParams,
|
||||||
|
OVERRIDE_MEASUREMENT_DIR_NAME,
|
||||||
)
|
)
|
||||||
from kalpaa.config.config_reader import (
|
from kalpaa.config.config_reader import (
|
||||||
read_config_dict,
|
read_config_dict,
|
||||||
@ -30,4 +31,5 @@ __all__ = [
|
|||||||
"serialize_config",
|
"serialize_config",
|
||||||
"read_config",
|
"read_config",
|
||||||
"read_general_config_dict",
|
"read_general_config_dict",
|
||||||
|
"OVERRIDE_MEASUREMENT_DIR_NAME",
|
||||||
]
|
]
|
||||||
|
@ -127,7 +127,9 @@ class GenerationConfig:
|
|||||||
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
override_measurements: typing.Optional[typing.Sequence[str]] = None
|
override_measurement_filesets: typing.Optional[
|
||||||
|
typing.Mapping[str, typing.Sequence[str]]
|
||||||
|
] = None
|
||||||
|
|
||||||
tantri_configs: typing.List[TantriConfig] = field(
|
tantri_configs: typing.List[TantriConfig] = field(
|
||||||
default_factory=lambda: [TantriConfig()]
|
default_factory=lambda: [TantriConfig()]
|
||||||
@ -170,6 +172,9 @@ class Config:
|
|||||||
def get_dots_json_path(self) -> pathlib.Path:
|
def get_dots_json_path(self) -> pathlib.Path:
|
||||||
return self.absify(self.general_config.dots_json_name)
|
return self.absify(self.general_config.dots_json_name)
|
||||||
|
|
||||||
|
def get_override_dir_path(self) -> pathlib.Path:
|
||||||
|
return self.absify(OVERRIDE_MEASUREMENT_DIR_NAME)
|
||||||
|
|
||||||
def indexifier(self) -> deepdog.indexify.Indexifier:
|
def indexifier(self) -> deepdog.indexify.Indexifier:
|
||||||
with self.absify(self.general_config.indexes_json_name).open(
|
with self.absify(self.general_config.indexes_json_name).open(
|
||||||
"r"
|
"r"
|
||||||
|
@ -662,11 +662,13 @@ class BinnedData:
|
|||||||
|
|
||||||
def stdev_cost_function_filter(
|
def stdev_cost_function_filter(
|
||||||
self,
|
self,
|
||||||
dot_names: typing.Sequence[str],
|
dot_names_or_pairs: typing.Union[
|
||||||
|
typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
|
||||||
|
],
|
||||||
target_cost: float,
|
target_cost: float,
|
||||||
use_log_noise: bool = False,
|
use_log_noise: bool = False,
|
||||||
):
|
):
|
||||||
measurements = self.measurements(dot_names)
|
measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs)
|
||||||
cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
|
cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
|
||||||
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||||
cost_function, target_cost
|
cost_function, target_cost
|
||||||
|
@ -135,10 +135,22 @@ def main():
|
|||||||
_logger.info(f"Copying {file} to {root}")
|
_logger.info(f"Copying {file} to {root}")
|
||||||
(root / file).write_text((pathlib.Path.cwd() / file).read_text())
|
(root / file).write_text((pathlib.Path.cwd() / file).read_text())
|
||||||
|
|
||||||
# if config.general_config is not None:
|
if config.generation_config.override_measurement_filesets is not None:
|
||||||
# _logger.info(
|
_logger.info(
|
||||||
# f"Overriding measurements with {config.general_config.override_measurements}"
|
f"Overriding measurements with {config.generation_config.override_measurement_filesets}"
|
||||||
# )
|
)
|
||||||
|
override_directory = root / kalpaa.config.OVERRIDE_MEASUREMENT_DIR_NAME
|
||||||
|
override_directory.mkdir(exist_ok=True, parents=True)
|
||||||
|
for (
|
||||||
|
key,
|
||||||
|
files,
|
||||||
|
) in config.generation_config.override_measurement_filesets.items():
|
||||||
|
_logger.info(f"Copying for {key=}, {files} to {override_directory}")
|
||||||
|
for file in files:
|
||||||
|
fileset_dir = override_directory / key
|
||||||
|
fileset_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
_logger.info(f"Copying {file} to {override_directory}")
|
||||||
|
(fileset_dir / file).write_text((pathlib.Path.cwd() / file).read_text())
|
||||||
|
|
||||||
overridden_config = dataclasses.replace(
|
overridden_config = dataclasses.replace(
|
||||||
config,
|
config,
|
||||||
|
@ -200,7 +200,27 @@ class Stage01Runner:
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
seed_index = 0
|
seed_index = 0
|
||||||
if self.config.generation_config.override_dipole_configs is None:
|
if self.config.generation_config.override_measurement_filesets is not None:
|
||||||
|
for (
|
||||||
|
override_name,
|
||||||
|
) in self.config.generation_config.override_measurement_filesets.keys():
|
||||||
|
# don't need to do anything with the files, just create the out dir
|
||||||
|
out = self.config.get_out_dir_path()
|
||||||
|
directory = out / f"{override_name}"
|
||||||
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
elif self.config.generation_config.override_dipole_configs is not None:
|
||||||
|
_logger.debug(
|
||||||
|
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
|
||||||
|
)
|
||||||
|
for (
|
||||||
|
override_name,
|
||||||
|
override_dipoles,
|
||||||
|
) in self.config.generation_config.override_dipole_configs.items():
|
||||||
|
self.generate_override_dipole(
|
||||||
|
seed_index, override_name, override_dipoles
|
||||||
|
)
|
||||||
|
else:
|
||||||
# should be by default
|
# should be by default
|
||||||
_logger.debug("no override needed!")
|
_logger.debug("no override needed!")
|
||||||
for count in self.config.generation_config.counts:
|
for count in self.config.generation_config.counts:
|
||||||
@ -213,17 +233,6 @@ class Stage01Runner:
|
|||||||
seed_index, count, orientation, replica
|
seed_index, count, orientation, replica
|
||||||
)
|
)
|
||||||
seed_index += 1
|
seed_index += 1
|
||||||
else:
|
|
||||||
_logger.debug(
|
|
||||||
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
|
|
||||||
)
|
|
||||||
for (
|
|
||||||
override_name,
|
|
||||||
override_dipoles,
|
|
||||||
) in self.config.generation_config.override_dipole_configs.items():
|
|
||||||
self.generate_override_dipole(
|
|
||||||
seed_index, override_name, override_dipoles
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -90,7 +90,9 @@ class Stage02Runner:
|
|||||||
else:
|
else:
|
||||||
return ["dot1", "dot2", current_dot]
|
return ["dot1", "dot2", current_dot]
|
||||||
|
|
||||||
def run_in_subdir(self, subdir: pathlib.Path):
|
def run_in_subdir(
|
||||||
|
self, subdir: pathlib.Path, override_key: typing.Optional[str] = None
|
||||||
|
):
|
||||||
with kalpaa.common.new_cd(subdir):
|
with kalpaa.common.new_cd(subdir):
|
||||||
_logger.debug(f"Running inside {subdir=}")
|
_logger.debug(f"Running inside {subdir=}")
|
||||||
|
|
||||||
@ -123,6 +125,7 @@ class Stage02Runner:
|
|||||||
dot_name: str,
|
dot_name: str,
|
||||||
trial_name: str,
|
trial_name: str,
|
||||||
seed_index: int,
|
seed_index: int,
|
||||||
|
override_name: typing.Optional[str] = None,
|
||||||
):
|
):
|
||||||
# _logger.info(f"Got job index {job_index}")
|
# _logger.info(f"Got job index {job_index}")
|
||||||
# NOTE This guy runs inside subdirs, obviously. In something like <kalpa>/out/z-10-2/dipoles
|
# NOTE This guy runs inside subdirs, obviously. In something like <kalpa>/out/z-10-2/dipoles
|
||||||
@ -137,14 +140,37 @@ class Stage02Runner:
|
|||||||
f"Have {self.config.generation_config.tantri_configs} as our tantri_configs"
|
f"Have {self.config.generation_config.tantri_configs} as our tantri_configs"
|
||||||
)
|
)
|
||||||
num_tantri_configs = len(self.config.generation_config.tantri_configs)
|
num_tantri_configs = len(self.config.generation_config.tantri_configs)
|
||||||
binned_datas = [
|
|
||||||
kalpaa.read_dots_and_binned(
|
if override_name is not None:
|
||||||
self.config.get_dots_json_path(),
|
if self.config.generation_config.override_measurement_filesets is None:
|
||||||
pathlib.Path("..")
|
raise ValueError(
|
||||||
/ kalpaa.common.tantri_binned_output_name(tantri_index),
|
"override_name provided but no override_measurement_filesets, shouldn't be possible to get here"
|
||||||
|
)
|
||||||
|
_logger.info(f"Time to read override measurement fileset {override_name}")
|
||||||
|
override_dir = self.config.get_override_dir_path()
|
||||||
|
override_measurements = (
|
||||||
|
self.config.generation_config.override_measurement_filesets[
|
||||||
|
override_name
|
||||||
|
]
|
||||||
)
|
)
|
||||||
for tantri_index in range(num_tantri_configs)
|
_logger.info(f"Finding files {override_measurements} in {override_dir}")
|
||||||
]
|
binned_datas = [
|
||||||
|
kalpaa.read_dots_and_binned(
|
||||||
|
self.config.get_dots_json_path(),
|
||||||
|
override_dir / measurement,
|
||||||
|
)
|
||||||
|
for measurement in override_measurements
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
|
||||||
|
binned_datas = [
|
||||||
|
kalpaa.read_dots_and_binned(
|
||||||
|
self.config.get_dots_json_path(),
|
||||||
|
pathlib.Path("..")
|
||||||
|
/ kalpaa.common.tantri_binned_output_name(tantri_index),
|
||||||
|
)
|
||||||
|
for tantri_index in range(num_tantri_configs)
|
||||||
|
]
|
||||||
|
|
||||||
dot_names = self._dots_to_include(dot_name)
|
dot_names = self._dots_to_include(dot_name)
|
||||||
_logger.debug(f"Got dot names {dot_names}")
|
_logger.debug(f"Got dot names {dot_names}")
|
||||||
@ -221,6 +247,16 @@ class Stage02Runner:
|
|||||||
_logger.info(results)
|
_logger.info(results)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
if self.config.generation_config.override_measurement_filesets is not None:
|
||||||
|
_logger.info("Using override configuration.")
|
||||||
|
for (
|
||||||
|
override_name
|
||||||
|
) in self.config.generation_config.override_measurement_filesets.keys():
|
||||||
|
subdir = self.config.get_out_dir_path() / override_name
|
||||||
|
dipoles_dir = subdir / "dipoles"
|
||||||
|
dipoles_dir.mkdir(exist_ok=True, parents=False)
|
||||||
|
self.run_in_subdir(dipoles_dir, override_key=override_name)
|
||||||
|
|
||||||
"""Going to iterate over every folder in out_dir, and execute the subdir stuff inside dirs like <kalpa>/out/z-10-2/dipoles"""
|
"""Going to iterate over every folder in out_dir, and execute the subdir stuff inside dirs like <kalpa>/out/z-10-2/dipoles"""
|
||||||
out_dir_path = self.config.get_out_dir_path()
|
out_dir_path = self.config.get_out_dir_path()
|
||||||
subdirs = [child for child in out_dir_path.iterdir() if child.is_dir]
|
subdirs = [child for child in out_dir_path.iterdir() if child.is_dir]
|
||||||
|
301
tests/config/__snapshots__/test_toml_as_dict_snaps.ambr
Normal file
301
tests/config/__snapshots__/test_toml_as_dict_snaps.ambr
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_parse_config_all_fields_toml_as_dict
|
||||||
|
dict({
|
||||||
|
'deepdog_config': dict({
|
||||||
|
'costs_to_try': list([
|
||||||
|
20,
|
||||||
|
2,
|
||||||
|
0.2,
|
||||||
|
]),
|
||||||
|
'max_monte_carlo_cycles_steps': 20,
|
||||||
|
'target_success': 2000,
|
||||||
|
'use_log_noise': True,
|
||||||
|
}),
|
||||||
|
'default_model_param_config': dict({
|
||||||
|
'w_log_max': 1,
|
||||||
|
'w_log_min': -5,
|
||||||
|
'x_max': 20,
|
||||||
|
'x_min': -20,
|
||||||
|
'y_max': 10,
|
||||||
|
'y_min': -10,
|
||||||
|
'z_max': 6.5,
|
||||||
|
'z_min': 5,
|
||||||
|
}),
|
||||||
|
'general_config': dict({
|
||||||
|
'dots_json_name': 'test_dots.json',
|
||||||
|
'indexes_json_name': 'test_indexes.json',
|
||||||
|
'log_pattern': '%(asctime)s | %(message)s',
|
||||||
|
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
|
||||||
|
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
|
||||||
|
'mega_merged_name': 'test_mega_merged.csv',
|
||||||
|
'out_dir_name': 'out',
|
||||||
|
'root_directory': PosixPath('test_root'),
|
||||||
|
'skip_to_stage': 1,
|
||||||
|
}),
|
||||||
|
'generation_config': dict({
|
||||||
|
'bin_log_width': 0.25,
|
||||||
|
'counts': list([
|
||||||
|
1,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
]),
|
||||||
|
'num_bin_time_series': 25,
|
||||||
|
'num_replicas': 3,
|
||||||
|
'orientations': list([
|
||||||
|
<Orientation.RANDOM: 'RANDOM'>,
|
||||||
|
<Orientation.Z: 'Z'>,
|
||||||
|
<Orientation.XY: 'XY'>,
|
||||||
|
]),
|
||||||
|
'override_dipole_configs': dict({
|
||||||
|
'scenario1': list([
|
||||||
|
dict({
|
||||||
|
'p': array([3, 5, 7]),
|
||||||
|
's': array([2, 4, 6]),
|
||||||
|
'w': 10,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'p': array([30, 50, 70]),
|
||||||
|
's': array([20, 40, 60]),
|
||||||
|
'w': 10.55,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'override_measurement_filesets': None,
|
||||||
|
'tantri_configs': list([
|
||||||
|
dict({
|
||||||
|
'delta_t': 0.01,
|
||||||
|
'index_seed_starter': 15151,
|
||||||
|
'num_iterations': 100,
|
||||||
|
'num_seeds': 5,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'delta_t': 1,
|
||||||
|
'index_seed_starter': 1234,
|
||||||
|
'num_iterations': 200,
|
||||||
|
'num_seeds': 100,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_parse_config_few_fields_toml_as_dict
|
||||||
|
dict({
|
||||||
|
'deepdog_config': dict({
|
||||||
|
'costs_to_try': list([
|
||||||
|
5,
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
0.5,
|
||||||
|
0.2,
|
||||||
|
]),
|
||||||
|
'max_monte_carlo_cycles_steps': 20,
|
||||||
|
'target_success': 2000,
|
||||||
|
'use_log_noise': True,
|
||||||
|
}),
|
||||||
|
'default_model_param_config': dict({
|
||||||
|
'w_log_max': 1,
|
||||||
|
'w_log_min': -5,
|
||||||
|
'x_max': 20,
|
||||||
|
'x_min': -20,
|
||||||
|
'y_max': 10,
|
||||||
|
'y_min': -10,
|
||||||
|
'z_max': 6.5,
|
||||||
|
'z_min': 5,
|
||||||
|
}),
|
||||||
|
'general_config': dict({
|
||||||
|
'dots_json_name': 'dots.json',
|
||||||
|
'indexes_json_name': 'indexes.json',
|
||||||
|
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||||
|
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||||
|
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
|
||||||
|
'mega_merged_name': 'mega_merged_coalesced.csv',
|
||||||
|
'out_dir_name': 'out',
|
||||||
|
'root_directory': PosixPath('test_root1'),
|
||||||
|
'skip_to_stage': None,
|
||||||
|
}),
|
||||||
|
'generation_config': dict({
|
||||||
|
'bin_log_width': 0.25,
|
||||||
|
'counts': list([
|
||||||
|
1,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
]),
|
||||||
|
'num_bin_time_series': 25,
|
||||||
|
'num_replicas': 2,
|
||||||
|
'orientations': list([
|
||||||
|
<Orientation.RANDOM: 'RANDOM'>,
|
||||||
|
<Orientation.Z: 'Z'>,
|
||||||
|
<Orientation.XY: 'XY'>,
|
||||||
|
]),
|
||||||
|
'override_dipole_configs': dict({
|
||||||
|
'scenario1': list([
|
||||||
|
dict({
|
||||||
|
'p': array([3, 5, 7]),
|
||||||
|
's': array([2, 4, 6]),
|
||||||
|
'w': 10,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'p': array([30, 50, 70]),
|
||||||
|
's': array([20, 40, 60]),
|
||||||
|
'w': 10.55,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'override_measurement_filesets': None,
|
||||||
|
'tantri_configs': list([
|
||||||
|
dict({
|
||||||
|
'delta_t': 0.01,
|
||||||
|
'index_seed_starter': 15151,
|
||||||
|
'num_iterations': 100,
|
||||||
|
'num_seeds': 5,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'delta_t': 1,
|
||||||
|
'index_seed_starter': 1234,
|
||||||
|
'num_iterations': 200,
|
||||||
|
'num_seeds': 100,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_parse_config_geom_params_toml_as_dict
|
||||||
|
dict({
|
||||||
|
'deepdog_config': dict({
|
||||||
|
'costs_to_try': list([
|
||||||
|
5,
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
0.5,
|
||||||
|
0.2,
|
||||||
|
]),
|
||||||
|
'max_monte_carlo_cycles_steps': 20,
|
||||||
|
'target_success': 2000,
|
||||||
|
'use_log_noise': True,
|
||||||
|
}),
|
||||||
|
'default_model_param_config': dict({
|
||||||
|
'w_log_max': 1.5,
|
||||||
|
'w_log_min': -3,
|
||||||
|
'x_max': 20,
|
||||||
|
'x_min': -20,
|
||||||
|
'y_max': 10,
|
||||||
|
'y_min': -10,
|
||||||
|
'z_max': 2,
|
||||||
|
'z_min': 0,
|
||||||
|
}),
|
||||||
|
'general_config': dict({
|
||||||
|
'dots_json_name': 'dots.json',
|
||||||
|
'indexes_json_name': 'indexes.json',
|
||||||
|
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||||
|
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||||
|
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
|
||||||
|
'mega_merged_name': 'mega_merged_coalesced.csv',
|
||||||
|
'out_dir_name': 'out',
|
||||||
|
'root_directory': PosixPath('test_root1'),
|
||||||
|
'skip_to_stage': None,
|
||||||
|
}),
|
||||||
|
'generation_config': dict({
|
||||||
|
'bin_log_width': 0.25,
|
||||||
|
'counts': list([
|
||||||
|
1,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
]),
|
||||||
|
'num_bin_time_series': 25,
|
||||||
|
'num_replicas': 2,
|
||||||
|
'orientations': list([
|
||||||
|
<Orientation.RANDOM: 'RANDOM'>,
|
||||||
|
<Orientation.Z: 'Z'>,
|
||||||
|
<Orientation.XY: 'XY'>,
|
||||||
|
]),
|
||||||
|
'override_dipole_configs': dict({
|
||||||
|
'scenario1': list([
|
||||||
|
dict({
|
||||||
|
'p': array([3, 5, 7]),
|
||||||
|
's': array([2, 4, 6]),
|
||||||
|
'w': 10,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'p': array([30, 50, 70]),
|
||||||
|
's': array([20, 40, 60]),
|
||||||
|
'w': 10.55,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'override_measurement_filesets': None,
|
||||||
|
'tantri_configs': list([
|
||||||
|
dict({
|
||||||
|
'delta_t': 0.01,
|
||||||
|
'index_seed_starter': 15151,
|
||||||
|
'num_iterations': 100,
|
||||||
|
'num_seeds': 5,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'delta_t': 1,
|
||||||
|
'index_seed_starter': 1234,
|
||||||
|
'num_iterations': 200,
|
||||||
|
'num_seeds': 100,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_parse_config_toml_as_dict
|
||||||
|
dict({
|
||||||
|
'deepdog_config': dict({
|
||||||
|
'costs_to_try': list([
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
0.1,
|
||||||
|
]),
|
||||||
|
'max_monte_carlo_cycles_steps': 20,
|
||||||
|
'target_success': 1000,
|
||||||
|
'use_log_noise': False,
|
||||||
|
}),
|
||||||
|
'default_model_param_config': dict({
|
||||||
|
'w_log_max': 1,
|
||||||
|
'w_log_min': -5,
|
||||||
|
'x_max': 20,
|
||||||
|
'x_min': -20,
|
||||||
|
'y_max': 10,
|
||||||
|
'y_min': -10,
|
||||||
|
'z_max': 6.5,
|
||||||
|
'z_min': 5,
|
||||||
|
}),
|
||||||
|
'general_config': dict({
|
||||||
|
'dots_json_name': 'test_dots.json',
|
||||||
|
'indexes_json_name': 'test_indexes.json',
|
||||||
|
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||||
|
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
|
||||||
|
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
|
||||||
|
'mega_merged_name': 'test_mega_merged.csv',
|
||||||
|
'out_dir_name': 'test_out',
|
||||||
|
'root_directory': PosixPath('test_root'),
|
||||||
|
'skip_to_stage': 1,
|
||||||
|
}),
|
||||||
|
'generation_config': dict({
|
||||||
|
'bin_log_width': 0.25,
|
||||||
|
'counts': list([
|
||||||
|
1,
|
||||||
|
10,
|
||||||
|
]),
|
||||||
|
'num_bin_time_series': 25,
|
||||||
|
'num_replicas': 3,
|
||||||
|
'orientations': list([
|
||||||
|
<Orientation.RANDOM: 'RANDOM'>,
|
||||||
|
<Orientation.Z: 'Z'>,
|
||||||
|
<Orientation.XY: 'XY'>,
|
||||||
|
]),
|
||||||
|
'override_dipole_configs': None,
|
||||||
|
'override_measurement_filesets': None,
|
||||||
|
'tantri_configs': list([
|
||||||
|
dict({
|
||||||
|
'delta_t': 0.05,
|
||||||
|
'index_seed_starter': 31415,
|
||||||
|
'num_iterations': 100000,
|
||||||
|
'num_seeds': 100,
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
@ -1,13 +1,13 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_parse_config_all_fields_toml
|
# name: test_parse_config_all_fields_toml
|
||||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||||
# ---
|
# ---
|
||||||
# name: test_parse_config_few_fields_toml
|
# name: test_parse_config_few_fields_toml
|
||||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||||
# ---
|
# ---
|
||||||
# name: test_parse_config_geom_params_toml
|
# name: test_parse_config_geom_params_toml
|
||||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5))
|
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5))
|
||||||
# ---
|
# ---
|
||||||
# name: test_parse_config_toml
|
# name: test_parse_config_toml
|
||||||
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, override_measurements=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||||
# ---
|
# ---
|
||||||
|
35
tests/config/test_toml_as_dict_snaps.py
Normal file
35
tests/config/test_toml_as_dict_snaps.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import pathlib
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
import kalpaa.config.config_reader
|
||||||
|
|
||||||
|
|
||||||
|
TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_config_toml_as_dict(snapshot):
|
||||||
|
test_config_file = TEST_DATA_DIR / "test_config.toml"
|
||||||
|
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||||
|
|
||||||
|
assert asdict(actual_config) == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_config_all_fields_toml_as_dict(snapshot):
|
||||||
|
test_config_file = TEST_DATA_DIR / "test_config_all_fields.toml"
|
||||||
|
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||||
|
|
||||||
|
assert asdict(actual_config) == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_config_few_fields_toml_as_dict(snapshot):
|
||||||
|
test_config_file = TEST_DATA_DIR / "test_config_few_fields.toml"
|
||||||
|
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||||
|
|
||||||
|
assert asdict(actual_config) == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_config_geom_params_toml_as_dict(snapshot):
|
||||||
|
test_config_file = TEST_DATA_DIR / "test_config_geom_params.toml"
|
||||||
|
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||||
|
|
||||||
|
assert asdict(actual_config) == snapshot
|
Loading…
x
Reference in New Issue
Block a user