feat: adds ability to read from config file

This commit is contained in:
2025-02-22 01:53:03 -06:00
parent 952e41cde6
commit 4729543c4e
12 changed files with 298 additions and 5 deletions

View File

@@ -25,4 +25,8 @@ costs
# Procedure
- `run.sh`
- `clean.sh` to clean things
# Style
Note we prefer tabs to spaces, but our autoformatter can handle that for us.

View File

@@ -17,6 +17,17 @@ checknix:
echo "Using poetry as runner, no nix detected."
fi
# update all test snapshots, use if snapshots are out of date
update-snapshots:
#!/usr/bin/env bash
set -euxo pipefail
if [[ "${DO_NIX_CUSTOM:=0}" -eq 1 ]]; then
pytest --snapshot-update
else
poetry run pytest --snapshot-update
fi
# run all tests
test: fmt
#!/usr/bin/env bash

View File

@@ -71,7 +71,7 @@ class GenerationConfig:
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
] = None
tantri_configs: typing.Sequence[TantriConfig] = field(
tantri_configs: typing.List[TantriConfig] = field(
default_factory=lambda: [TantriConfig()]
)

View File

@@ -0,0 +1,95 @@
import tomli
import pathlib
from kalpaa.config import (
Config,
# GenerationConfig,
GeneralConfig,
# DeepdogConfig,
MeasurementTypeEnum,
)
import tantri.dipoles.types
import dacite
import numpy
import logging
_logger = logging.getLogger(__name__)
# Shared dacite configuration for every config-parsing entry point in this module:
# - strict=True rejects unknown keys in the input dict instead of silently ignoring them
# - type_hooks converts raw TOML lists into numpy arrays for ndarray-typed fields
# - cast coerces plain strings/values into the listed enum/path/orientation types
_common_dacite_config = dacite.Config(
    strict=True,
    type_hooks={numpy.ndarray: numpy.array},
    cast=[MeasurementTypeEnum, pathlib.Path, tantri.dipoles.types.Orientation],
)
def read_general_config_dict(general_config_dict: dict) -> GeneralConfig:
    """
    Converts a dictionary to a GeneralConfig object

    :param general_config_dict: dictionary containing general config values
    :return: GeneralConfig object
    """
    # dacite handles field mapping; the shared config casts enum/path values
    # and rejects unknown keys.
    return dacite.from_dict(
        data_class=GeneralConfig,
        data=general_config_dict,
        config=_common_dacite_config,
    )
def read_config_dict(file_path: pathlib.Path) -> dict:
    """
    Reads a TOML file and returns the contents as a dictionary

    :param file_path: path to the TOML file
    :return: dictionary containing the config values
    """
    _logger.debug(f"Reading config from {file_path=}")
    # tomli requires the file to be opened in binary mode.
    with file_path.open("rb") as toml_file:
        return tomli.load(toml_file)
def serialize_config(config_dict: dict) -> Config:
    """
    Converts a dictionary to a Config object

    Makes assumptions about structure of the config_dict, so validation should happen here too if needed.

    :param config_dict: dictionary containing config values
    :return: Config object
    :raises dacite.DaciteError: if the dict does not match the Config schema
        (strict mode also rejects unknown keys)
    """
    # dacite recursively builds the nested dataclasses (generation_config,
    # general_config, deepdog_config) from the corresponding sub-dicts.
    config = dacite.from_dict(
        data_class=Config,
        data=config_dict,
        config=_common_dacite_config,
    )
    # Log the parsed config at debug level: emitting it on every successful
    # parse is diagnostic detail, not a warning condition.
    _logger.debug(config)
    return config
def read_config(file_path: pathlib.Path) -> Config:
    """
    Reads a TOML config file and parses it into a Config object.

    :param file_path: path to the TOML config file
    :return: parsed Config object
    """
    # Convenience wrapper: load the raw dict, then build the dataclass tree.
    return serialize_config(read_config_dict(file_path))

16
poetry.lock generated
View File

@@ -158,6 +158,20 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1
[package.extras]
toml = ["tomli"]
[[package]]
name = "dacite"
version = "1.9.2"
description = "Simple creation of data classes from dictionaries."
optional = false
python-versions = ">=3.7"
files = [
{file = "dacite-1.9.2-py3-none-any.whl", hash = "sha256:053f7c3f5128ca2e9aceb66892b1a3c8936d02c686e707bee96e19deef4bc4a0"},
{file = "dacite-1.9.2.tar.gz", hash = "sha256:6ccc3b299727c7aa17582f0021f6ae14d5de47c7227932c47fec4cdfefd26f09"},
]
[package.extras]
dev = ["black", "coveralls", "mypy", "pre-commit", "pylint", "pytest (>=5)", "pytest-benchmark", "pytest-cov"]
[[package]]
name = "deepdog"
version = "1.5.0"
@@ -600,4 +614,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<3.10"
content-hash = "b1e6dc0ba0d669dc7beba8f0018c2a9ff4a93e6093e179a561c8419b30ad35c5"
content-hash = "3c7c7ffbb6d6e8456ff887554c046a068772f637335810afed36f23556fd8966"

View File

@@ -8,11 +8,11 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<3.10"
pdme = "^1.5.0"
deepdog = "^1.5.0"
tantri = "^1.3.0"
tomli = "^2.0.1"
dacite = "^1.9.2"
[tool.poetry.group.dev.dependencies]
black = "^24.8.0"
flake8 = "^4.0.1"

0
tests/config/__init__.py Normal file
View File

View File

@@ -0,0 +1,10 @@
# serializer version: 1
# name: test_parse_config_all_fields_toml
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True))
# ---
# name: test_parse_config_few_fields_toml
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True))
# ---
# name: test_parse_config_toml
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False))
# ---

View File

@@ -0,0 +1,22 @@
[general_config]
dots_json_name = "test_dots.json"
indexes_json_name = "test_indexes.json"
out_dir_name = "test_out"
measurement_type = "x-electric-field"
root_directory = "test_root"
mega_merged_name = "test_mega_merged.csv"
mega_merged_inferenced_name = "test_mega_merged_inferenced.csv"
skip_to_stage = 1
[generation_config]
counts = [1, 10]
orientations = ["RANDOM", "Z", "XY"]
num_replicas = 3
num_bin_time_series = 25
bin_log_width = 0.25
[deepdog_config]
costs_to_try = [10, 1, 0.1]
target_success = 1000
max_monte_carlo_cycles_steps = 20
use_log_noise = false

View File

@@ -0,0 +1,33 @@
[general_config]
dots_json_name = "test_dots.json"
indexes_json_name = "test_indexes.json"
out_dir_name = "out"
measurement_type = "x-electric-field"
root_directory = "test_root"
mega_merged_name = "test_mega_merged.csv"
mega_merged_inferenced_name = "test_mega_merged_inferenced.csv"
skip_to_stage = 1
log_pattern = "%(asctime)s | %(message)s"
[generation_config]
counts = [1, 5, 10]
orientations = ["RANDOM", "Z", "XY"]
num_replicas = 3
num_bin_time_series = 25
bin_log_width = 0.25
tantri_configs = [
{index_seed_starter = 15151, num_seeds = 5, delta_t = 0.01, num_iterations = 100},
{index_seed_starter = 1234, num_seeds = 100, delta_t = 1, num_iterations = 200}
]
[generation_config.override_dipole_configs]
scenario1 = [
{p = [3, 5, 7], s = [2, 4, 6], w = 10},
{p = [30, 50, 70], s = [20, 40, 60], w = 10.55},
]
[deepdog_config]
costs_to_try = [20, 2, 0.2]
target_success = 2000
max_monte_carlo_cycles_steps = 20
use_log_noise = true

View File

@@ -0,0 +1,22 @@
[general_config]
root_directory = "test_root1"
measurement_type = "electric-potential"
[generation_config]
counts = [1, 5, 10]
num_replicas = 2
tantri_configs = [
{index_seed_starter = 15151, num_seeds = 5, delta_t = 0.01, num_iterations = 100},
{index_seed_starter = 1234, num_seeds = 100, delta_t = 1, num_iterations = 200}
]
[generation_config.override_dipole_configs]
scenario1 = [
{p = [3, 5, 7], s = [2, 4, 6], w = 10},
{p = [30, 50, 70], s = [20, 40, 60], w = 10.55},
]
[deepdog_config]
costs_to_try = [5, 2, 1, 0.5, 0.2]
target_success = 2000
use_log_noise = true

View File

@@ -0,0 +1,82 @@
import pathlib
from kalpaa.config import MeasurementTypeEnum
import kalpaa.config.config_reader
# Directory holding the TOML fixture files, resolved relative to this test module.
TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files"
def test_parse_general_config_dict():
    """A fully-specified general_config dict parses into the matching GeneralConfig."""
    raw_dict = {
        "dots_json_name": "test_dots.json",
        "indexes_json_name": "test_indexes.json",
        "out_dir_name": "test_out",
        "log_pattern": "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s",
        "measurement_type": "x-electric-field",
        "root_directory": "test_root",
        "mega_merged_name": "test_mega_merged.csv",
        "mega_merged_inferenced_name": "test_mega_merged_inferenced.csv",
        "skip_to_stage": 1,
    }
    expected = kalpaa.config.GeneralConfig(
        dots_json_name="test_dots.json",
        indexes_json_name="test_indexes.json",
        out_dir_name="test_out",
        log_pattern="%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s",
        measurement_type=MeasurementTypeEnum.X_ELECTRIC_FIELD,
        root_directory=pathlib.Path("test_root"),
        mega_merged_name="test_mega_merged.csv",
        mega_merged_inferenced_name="test_mega_merged_inferenced.csv",
        skip_to_stage=1,
    )
    parsed = kalpaa.config.config_reader.read_general_config_dict(raw_dict)
    assert parsed == expected
def test_parse_empty_general_config_dict():
    """An empty dict yields a GeneralConfig populated entirely from defaults."""
    parsed = kalpaa.config.config_reader.read_general_config_dict({})
    # Every field below is the dataclass default, including cwd as root_directory.
    expected = kalpaa.config.GeneralConfig(
        dots_json_name="dots.json",
        indexes_json_name="indexes.json",
        out_dir_name="out",
        log_pattern="%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s",
        measurement_type=MeasurementTypeEnum.X_ELECTRIC_FIELD,
        root_directory=pathlib.Path.cwd(),
        mega_merged_name="mega_merged_coalesced.csv",
        mega_merged_inferenced_name="mega_merged_coalesced_inferenced.csv",
        skip_to_stage=None,
    )
    assert parsed == expected
def test_parse_config_toml(snapshot):
    """Parses the baseline fixture and compares against the stored snapshot."""
    config_path = TEST_DATA_DIR / "test_config.toml"
    assert kalpaa.config.config_reader.read_config(config_path) == snapshot
def test_parse_config_all_fields_toml(snapshot):
    """Parses the fixture with every field populated and compares to its snapshot."""
    config_path = TEST_DATA_DIR / "test_config_all_fields.toml"
    assert kalpaa.config.config_reader.read_config(config_path) == snapshot
def test_parse_config_few_fields_toml(snapshot):
    """Parses the minimal fixture (defaults fill the gaps) and compares to its snapshot."""
    config_path = TEST_DATA_DIR / "test_config_few_fields.toml"
    assert kalpaa.config.config_reader.read_config(config_path) == snapshot