feat: adds ability to read from config file
This commit is contained in:
@@ -25,4 +25,8 @@ costs
|
||||
|
||||
# Procedure
|
||||
- `run.sh`
|
||||
- `clean.sh` to clean things
|
||||
- `clean.sh` to clean things
|
||||
|
||||
# style
|
||||
|
||||
Note we prefer tabs to spaces, but our autoformatter can handle that for us.
|
||||
|
||||
11
justfile
11
justfile
@@ -17,6 +17,17 @@ checknix:
|
||||
echo "Using poetry as runner, no nix detected."
|
||||
fi
|
||||
|
||||
# update all test snapshots, use if snapshots are out of date
|
||||
update-snapshots:
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
if [[ "${DO_NIX_CUSTOM:=0}" -eq 1 ]]; then
|
||||
pytest --snapshot-update
|
||||
else
|
||||
poetry run pytest --snapshot-update
|
||||
fi
|
||||
|
||||
|
||||
# run all tests
|
||||
test: fmt
|
||||
#!/usr/bin/env bash
|
||||
|
||||
@@ -71,7 +71,7 @@ class GenerationConfig:
|
||||
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
||||
] = None
|
||||
|
||||
tantri_configs: typing.Sequence[TantriConfig] = field(
|
||||
tantri_configs: typing.List[TantriConfig] = field(
|
||||
default_factory=lambda: [TantriConfig()]
|
||||
)
|
||||
|
||||
|
||||
95
kalpaa/config/config_reader.py
Normal file
95
kalpaa/config/config_reader.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import tomli
|
||||
import pathlib
|
||||
from kalpaa.config import (
|
||||
Config,
|
||||
# GenerationConfig,
|
||||
GeneralConfig,
|
||||
# DeepdogConfig,
|
||||
MeasurementTypeEnum,
|
||||
)
|
||||
import tantri.dipoles.types
|
||||
import dacite
|
||||
import numpy
|
||||
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
_common_dacite_config = dacite.Config(
|
||||
strict=True,
|
||||
type_hooks={numpy.ndarray: numpy.array},
|
||||
cast=[MeasurementTypeEnum, pathlib.Path, tantri.dipoles.types.Orientation],
|
||||
)
|
||||
|
||||
|
||||
def read_general_config_dict(general_config_dict: dict) -> GeneralConfig:
    """
    Build a GeneralConfig from a plain dictionary.

    The dictionary is handed to dacite together with the shared module-level
    dacite configuration.

    :param general_config_dict: dictionary containing general config values
    :return: GeneralConfig object
    """
    return dacite.from_dict(
        data_class=GeneralConfig,
        data=general_config_dict,
        config=_common_dacite_config,
    )
|
||||
|
||||
|
||||
def read_config_dict(file_path: pathlib.Path) -> dict:
    """
    Load a TOML file into a dictionary.

    :param file_path: path to the TOML file
    :return: dictionary containing the config values
    """
    _logger.debug(f"Reading config from {file_path=}")
    # tomli.load requires a binary file handle, hence the "rb" mode.
    with open(file_path, "rb") as handle:
        return tomli.load(handle)
|
||||
|
||||
|
||||
def serialize_config(config_dict: dict) -> Config:
    """
    Converts a dictionary to a Config object

    Despite the name, this goes from plain data to a typed object: the whole
    dictionary is handed to dacite, which builds the Config using the shared
    module-level dacite configuration.

    Makes assumptions about structure of the config_dict, so validation should happen here too if needed.

    :param config_dict: dictionary containing config values
    :return: Config object
    """
    config = dacite.from_dict(
        data_class=Config,
        data=config_dict,
        config=_common_dacite_config,
    )
    # Parsing a config successfully is the normal path, so report the result
    # at DEBUG rather than WARNING; %s keeps the formatting lazy.
    _logger.debug("Parsed config: %s", config)

    return config
|
||||
|
||||
|
||||
def read_config(file_path: pathlib.Path) -> Config:
    """
    Read a TOML file and convert its contents to a Config object.

    :param file_path: path to the TOML file
    :return: Config object
    """
    return serialize_config(read_config_dict(file_path))
|
||||
16
poetry.lock
generated
16
poetry.lock
generated
@@ -158,6 +158,20 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1
|
||||
[package.extras]
|
||||
toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "dacite"
|
||||
version = "1.9.2"
|
||||
description = "Simple creation of data classes from dictionaries."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0"},
|
||||
{file = "dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["black", "coveralls", "mypy", "pre-commit", "pylint", "pytest (>=5)", "pytest-benchmark", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "deepdog"
|
||||
version = "1.5.0"
|
||||
@@ -600,4 +614,4 @@ files = [
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<3.10"
|
||||
content-hash = "b1e6dc0ba0d669dc7beba8f0018c2a9ff4a93e6093e179a561c8419b30ad35c5"
|
||||
content-hash = "3c7c7ffbb6d6e8456ff887554c046a068772f637335810afed36f23556fd8966"
|
||||
|
||||
@@ -8,11 +8,11 @@ readme = "README.md"
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<3.10"
|
||||
pdme = "^1.5.0"
|
||||
|
||||
|
||||
deepdog = "^1.5.0"
|
||||
tantri = "^1.3.0"
|
||||
tomli = "^2.0.1"
|
||||
dacite = "^1.9.2"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.8.0"
|
||||
flake8 = "^4.0.1"
|
||||
|
||||
0
tests/config/__init__.py
Normal file
0
tests/config/__init__.py
Normal file
10
tests/config/__snapshots__/test_toml_reader.ambr
Normal file
10
tests/config/__snapshots__/test_toml_reader.ambr
Normal file
@@ -0,0 +1,10 @@
|
||||
# serializer version: 1
|
||||
# name: test_parse_config_all_fields_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True))
|
||||
# ---
|
||||
# name: test_parse_config_few_fields_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True))
|
||||
# ---
|
||||
# name: test_parse_config_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False))
|
||||
# ---
|
||||
22
tests/config/test_files/test_config.toml
Normal file
22
tests/config/test_files/test_config.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[general_config]
|
||||
dots_json_name = "test_dots.json"
|
||||
indexes_json_name = "test_indexes.json"
|
||||
out_dir_name = "test_out"
|
||||
measurement_type = "x-electric-field"
|
||||
root_directory = "test_root"
|
||||
mega_merged_name = "test_mega_merged.csv"
|
||||
mega_merged_inferenced_name = "test_mega_merged_inferenced.csv"
|
||||
skip_to_stage = 1
|
||||
|
||||
[generation_config]
|
||||
counts = [1, 10]
|
||||
orientations = ["RANDOM", "Z", "XY"]
|
||||
num_replicas = 3
|
||||
num_bin_time_series = 25
|
||||
bin_log_width = 0.25
|
||||
|
||||
[deepdog_config]
|
||||
costs_to_try = [10, 1, 0.1]
|
||||
target_success = 1000
|
||||
max_monte_carlo_cycles_steps = 20
|
||||
use_log_noise = false
|
||||
33
tests/config/test_files/test_config_all_fields.toml
Normal file
33
tests/config/test_files/test_config_all_fields.toml
Normal file
@@ -0,0 +1,33 @@
|
||||
[general_config]
|
||||
dots_json_name = "test_dots.json"
|
||||
indexes_json_name = "test_indexes.json"
|
||||
out_dir_name = "out"
|
||||
measurement_type = "x-electric-field"
|
||||
root_directory = "test_root"
|
||||
mega_merged_name = "test_mega_merged.csv"
|
||||
mega_merged_inferenced_name = "test_mega_merged_inferenced.csv"
|
||||
skip_to_stage = 1
|
||||
log_pattern = "%(asctime)s | %(message)s"
|
||||
|
||||
[generation_config]
|
||||
counts = [1, 5, 10]
|
||||
orientations = ["RANDOM", "Z", "XY"]
|
||||
num_replicas = 3
|
||||
num_bin_time_series = 25
|
||||
bin_log_width = 0.25
|
||||
tantri_configs = [
|
||||
{index_seed_starter = 15151, num_seeds = 5, delta_t = 0.01, num_iterations = 100},
|
||||
{index_seed_starter = 1234, num_seeds = 100, delta_t = 1, num_iterations = 200}
|
||||
]
|
||||
|
||||
[generation_config.override_dipole_configs]
|
||||
scenario1 = [
|
||||
{p = [3, 5, 7], s = [2, 4, 6], w = 10},
|
||||
{p = [30, 50, 70], s = [20, 40, 60], w = 10.55},
|
||||
]
|
||||
|
||||
[deepdog_config]
|
||||
costs_to_try = [20, 2, 0.2]
|
||||
target_success = 2000
|
||||
max_monte_carlo_cycles_steps = 20
|
||||
use_log_noise = true
|
||||
22
tests/config/test_files/test_config_few_fields.toml
Normal file
22
tests/config/test_files/test_config_few_fields.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[general_config]
|
||||
root_directory = "test_root1"
|
||||
measurement_type = "electric-potential"
|
||||
|
||||
[generation_config]
|
||||
counts = [1, 5, 10]
|
||||
num_replicas = 2
|
||||
tantri_configs = [
|
||||
{index_seed_starter = 15151, num_seeds = 5, delta_t = 0.01, num_iterations = 100},
|
||||
{index_seed_starter = 1234, num_seeds = 100, delta_t = 1, num_iterations = 200}
|
||||
]
|
||||
|
||||
[generation_config.override_dipole_configs]
|
||||
scenario1 = [
|
||||
{p = [3, 5, 7], s = [2, 4, 6], w = 10},
|
||||
{p = [30, 50, 70], s = [20, 40, 60], w = 10.55},
|
||||
]
|
||||
|
||||
[deepdog_config]
|
||||
costs_to_try = [5, 2, 1, 0.5, 0.2]
|
||||
target_success = 2000
|
||||
use_log_noise = true
|
||||
82
tests/config/test_toml_reader.py
Normal file
82
tests/config/test_toml_reader.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import pathlib
|
||||
|
||||
from kalpaa.config import MeasurementTypeEnum
|
||||
import kalpaa.config.config_reader
|
||||
|
||||
|
||||
TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files"
|
||||
|
||||
|
||||
def test_parse_general_config_dict():
    """A fully populated dictionary maps field-for-field onto GeneralConfig."""
    raw_values = dict(
        dots_json_name="test_dots.json",
        indexes_json_name="test_indexes.json",
        out_dir_name="test_out",
        log_pattern="%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s",
        measurement_type="x-electric-field",
        root_directory="test_root",
        mega_merged_name="test_mega_merged.csv",
        mega_merged_inferenced_name="test_mega_merged_inferenced.csv",
        skip_to_stage=1,
    )

    parsed = kalpaa.config.config_reader.read_general_config_dict(raw_values)

    # String inputs for the enum and the path are expected to come back typed.
    expected = kalpaa.config.GeneralConfig(
        dots_json_name="test_dots.json",
        indexes_json_name="test_indexes.json",
        out_dir_name="test_out",
        log_pattern="%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s",
        measurement_type=MeasurementTypeEnum.X_ELECTRIC_FIELD,
        root_directory=pathlib.Path("test_root"),
        mega_merged_name="test_mega_merged.csv",
        mega_merged_inferenced_name="test_mega_merged_inferenced.csv",
        skip_to_stage=1,
    )
    assert parsed == expected
|
||||
|
||||
|
||||
def test_parse_empty_general_config_dict():
    """An empty dictionary yields a GeneralConfig equal to the values spelled out below."""
    parsed = kalpaa.config.config_reader.read_general_config_dict({})

    expected = kalpaa.config.GeneralConfig(
        dots_json_name="dots.json",
        indexes_json_name="indexes.json",
        out_dir_name="out",
        log_pattern="%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s",
        measurement_type=MeasurementTypeEnum.X_ELECTRIC_FIELD,
        root_directory=pathlib.Path.cwd(),
        mega_merged_name="mega_merged_coalesced.csv",
        mega_merged_inferenced_name="mega_merged_coalesced_inferenced.csv",
        skip_to_stage=None,
    )
    assert parsed == expected
|
||||
|
||||
|
||||
def test_parse_config_toml(snapshot):
    """Parsing test_config.toml matches the recorded snapshot."""
    parsed = kalpaa.config.config_reader.read_config(
        TEST_DATA_DIR / "test_config.toml"
    )

    assert parsed == snapshot
|
||||
|
||||
|
||||
def test_parse_config_all_fields_toml(snapshot):
    """Parsing test_config_all_fields.toml matches the recorded snapshot."""
    parsed = kalpaa.config.config_reader.read_config(
        TEST_DATA_DIR / "test_config_all_fields.toml"
    )

    assert parsed == snapshot
|
||||
|
||||
|
||||
def test_parse_config_few_fields_toml(snapshot):
    """Parsing test_config_few_fields.toml matches the recorded snapshot."""
    parsed = kalpaa.config.config_reader.read_config(
        TEST_DATA_DIR / "test_config_few_fields.toml"
    )

    assert parsed == snapshot
|
||||
Reference in New Issue
Block a user