96 lines
2.6 KiB
Python
96 lines
2.6 KiB
Python
import tomli
|
|
import pathlib
|
|
from kalpaa.config import (
|
|
Config,
|
|
# GenerationConfig,
|
|
GeneralConfig,
|
|
# DeepdogConfig,
|
|
MeasurementTypeEnum,
|
|
)
|
|
import tantri.dipoles.types
|
|
import dacite
|
|
import numpy
|
|
|
|
import logging
|
|
|
|
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# dacite settings shared by every dict -> dataclass conversion in this module.
_common_dacite_config = dacite.Config(
    # Values destined for numpy.ndarray fields are passed through numpy.array.
    type_hooks={numpy.ndarray: numpy.array},
    # Raw values for these field types are cast by calling the type on them.
    cast=[MeasurementTypeEnum, pathlib.Path, tantri.dipoles.types.Orientation],
    # NOTE(review): strict mode is expected to reject keys that do not match
    # a dataclass field -- confirm against the dacite documentation.
    strict=True,
)
|
|
|
|
|
|
def read_general_config_dict(general_config_dict: dict) -> GeneralConfig:
    """
    Build a GeneralConfig object from a plain dictionary

    The conversion is delegated to dacite, using this module's shared dacite configuration.

    :param general_config_dict: dictionary holding the general config values
    :return: the populated GeneralConfig object
    """
    return dacite.from_dict(
        GeneralConfig,
        general_config_dict,
        config=_common_dacite_config,
    )
|
|
|
|
|
|
def read_config_dict(file_path: pathlib.Path) -> dict:
    """
    Load a TOML file from disk and hand back its parsed contents

    :param file_path: location of the TOML file to read
    :return: dictionary of the values parsed from the file
    """
    _logger.debug(f"Reading config from {file_path=}")
    # tomli.load expects a binary file object, hence the "rb" mode.
    with open(file_path, "rb") as handle:
        return tomli.load(handle)
|
|
|
|
|
|
def serialize_config(config_dict: dict) -> Config:
    """
    Converts a dictionary to a Config object

    Despite the name, this goes from plain data to a Config object (the
    dictionary is the input, not the output). The conversion is delegated to
    dacite using this module's shared dacite configuration.

    Makes assumptions about structure of the config_dict, so validation should happen here too if needed.

    :param config_dict: dictionary containing config values
    :return: Config object
    """
    config = dacite.from_dict(
        data_class=Config,
        data=config_dict,
        config=_common_dacite_config,
    )

    # Dumping the parsed config is routine diagnostics, not a warning
    # condition, so it is logged at DEBUG to keep normal runs quiet.
    _logger.debug(config)

    return config
|
|
|
|
|
|
def read_config(file_path: pathlib.Path) -> Config:
    """
    Reads a TOML file and converts its contents to a Config object

    :param file_path: path to the TOML file
    :return: Config object built from the file's contents
    """
    return serialize_config(read_config_dict(file_path))
|