kalpa/kalpaa/config/config_reader.py

96 lines
2.6 KiB
Python

import tomli
import pathlib
from kalpaa.config import (
Config,
# GenerationConfig,
GeneralConfig,
# DeepdogConfig,
MeasurementTypeEnum,
)
import tantri.dipoles.types
import dacite
import numpy
import logging
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# Single dacite configuration reused by every dict -> dataclass conversion
# in this module, so all config sections are parsed under the same rules.
_common_dacite_config = dacite.Config(
    # Types listed here are passed to dacite's `cast` option, so plain TOML
    # values (e.g. strings) can populate enum / path / orientation fields.
    cast=[MeasurementTypeEnum, pathlib.Path, tantri.dipoles.types.Orientation],
    # Values destined for numpy.ndarray fields are run through numpy.array.
    type_hooks={numpy.ndarray: numpy.array},
    # Strict mode: per the dacite docs, keys in the input that do not match
    # a dataclass field are reported as errors rather than ignored.
    strict=True,
)
def read_general_config_dict(general_config_dict: dict) -> GeneralConfig:
    """
    Builds a GeneralConfig object from a plain dictionary

    The conversion is delegated to dacite, using the module-wide
    _common_dacite_config.

    :param general_config_dict: dictionary containing general config values
    :return: GeneralConfig object
    """
    return dacite.from_dict(
        config=_common_dacite_config,
        data=general_config_dict,
        data_class=GeneralConfig,
    )
def read_config_dict(file_path: pathlib.Path) -> dict:
    """
    Loads a TOML file from disk and hands back its parsed contents

    :param file_path: path to the TOML file
    :return: dictionary containing the config values
    """
    _logger.debug(f"Reading config from {file_path=}")
    # The file is opened in binary mode ("rb") because it is handed straight
    # to tomli.load, which expects a binary file object.
    with open(file_path, "rb") as handle:
        return tomli.load(handle)
def serialize_config(config_dict: dict) -> Config:
    """
    Builds a Config object from a plain dictionary

    Makes assumptions about the structure of config_dict, so any extra
    validation that is needed should happen here too.

    :param config_dict: dictionary containing config values
    :return: Config object
    """
    # dacite builds the whole Config (including its nested sections) from the
    # dictionary in one call, so no per-section manual construction is needed.
    parsed = dacite.from_dict(
        config=_common_dacite_config,
        data=config_dict,
        data_class=Config,
    )
    # NOTE(review): the entire parsed config is logged at WARNING level on
    # every call; this looks like a debugging leftover - confirm whether
    # DEBUG/INFO was intended before changing it.
    _logger.warning(parsed)
    return parsed
def read_config(file_path: pathlib.Path) -> Config:
    """
    Reads a TOML file and converts its contents to a Config object

    Chains read_config_dict (file -> dictionary) and serialize_config
    (dictionary -> Config).

    :param file_path: path to the TOML file
    :return: Config object
    """
    return serialize_config(read_config_dict(file_path))