feat: Now can run through config file, with smarter label
This commit is contained in:
22
kalpaa.toml
Normal file
22
kalpaa.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[general_config]
|
||||
root_directory = "out"
|
||||
measurement_type = "electric-potential"
|
||||
|
||||
[generation_config]
|
||||
counts = [1, 5, 10]
|
||||
num_replicas = 2
|
||||
tantri_configs = [
|
||||
{index_seed_starter = 15151, num_seeds = 5, delta_t = 0.01, num_iterations = 100},
|
||||
{index_seed_starter = 1234, num_seeds = 100, delta_t = 1, num_iterations = 200}
|
||||
]
|
||||
|
||||
[generation_config.override_dipole_configs]
|
||||
scenario1 = [
|
||||
{p = [3, 5, 7], s = [2, 4, 6], w = 10},
|
||||
{p = [30, 50, 70], s = [20, 40, 60], w = 10.55},
|
||||
]
|
||||
|
||||
[deepdog_config]
|
||||
costs_to_try = [5, 2, 1, 0.5, 0.2]
|
||||
target_success = 2000
|
||||
use_log_noise = true
|
||||
@@ -1,163 +1,31 @@
|
||||
import json
|
||||
import deepdog.indexify
|
||||
from dataclasses import dataclass, field
|
||||
import typing
|
||||
import tantri.dipoles.types
|
||||
import pathlib
|
||||
from enum import Enum, IntEnum
|
||||
import logging
|
||||
from kalpaa.config.config import (
|
||||
MeasurementTypeEnum,
|
||||
SkipToStage,
|
||||
GeneralConfig,
|
||||
TantriConfig,
|
||||
GenerationConfig,
|
||||
DeepdogConfig,
|
||||
Config,
|
||||
ReducedModelParams,
|
||||
)
|
||||
from kalpaa.config.config_reader import (
|
||||
read_config_dict,
|
||||
serialize_config,
|
||||
read_config,
|
||||
read_general_config_dict,
|
||||
)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MeasurementTypeEnum(Enum):
|
||||
POTENTIAL = "electric-potential"
|
||||
X_ELECTRIC_FIELD = "x-electric-field"
|
||||
|
||||
|
||||
class SkipToStage(IntEnum):
|
||||
# shouldn't need this lol
|
||||
STAGE_01 = 0
|
||||
STAGE_02 = 1
|
||||
STAGE_03 = 2
|
||||
STAGE_04 = 3
|
||||
|
||||
|
||||
# Copy over some random constants to see if they're ever reused
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneralConfig:
|
||||
dots_json_name: str = "dots.json"
|
||||
indexes_json_name: str = "indexes.json"
|
||||
out_dir_name: str = "out"
|
||||
log_pattern: str = (
|
||||
"%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
|
||||
)
|
||||
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
|
||||
root_directory: pathlib.Path = pathlib.Path.cwd()
|
||||
|
||||
mega_merged_name: str = "mega_merged_coalesced.csv"
|
||||
mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"
|
||||
|
||||
skip_to_stage: typing.Optional[int] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TantriConfig:
|
||||
index_seed_starter: int = 31415
|
||||
num_seeds: int = 100
|
||||
delta_t: float = 0.05
|
||||
num_iterations: int = 100000
|
||||
# sample_rate = 10
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationConfig:
|
||||
# Interact with indexes.json, probably should be a subset
|
||||
counts: typing.Sequence[int] = field(default_factory=lambda: [1, 10])
|
||||
orientations: typing.Sequence[tantri.dipoles.types.Orientation] = field(
|
||||
default_factory=lambda: [
|
||||
tantri.dipoles.types.Orientation.RANDOM,
|
||||
tantri.dipoles.types.Orientation.Z,
|
||||
tantri.dipoles.types.Orientation.XY,
|
||||
]
|
||||
)
|
||||
# TODO: what's replica here?
|
||||
num_replicas: int = 3
|
||||
|
||||
# the above three can be overrided with manually specified configurations
|
||||
override_dipole_configs: typing.Optional[
|
||||
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
||||
] = None
|
||||
|
||||
tantri_configs: typing.List[TantriConfig] = field(
|
||||
default_factory=lambda: [TantriConfig()]
|
||||
)
|
||||
|
||||
num_bin_time_series: int = 25
|
||||
bin_log_width: float = 0.25
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeepdogConfig:
|
||||
"""
|
||||
Class that holds all of the computational parameters
|
||||
"""
|
||||
|
||||
costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
|
||||
target_success: int = 1000
|
||||
max_monte_carlo_cycles_steps: int = 20
|
||||
# Whether to use a log log cost function
|
||||
use_log_noise: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Config:
|
||||
generation_config: GenerationConfig = GenerationConfig()
|
||||
general_config: GeneralConfig = GeneralConfig()
|
||||
deepdog_config: DeepdogConfig = DeepdogConfig()
|
||||
|
||||
def absify(self, filename: str) -> pathlib.Path:
|
||||
ret = (self.general_config.root_directory / filename).resolve()
|
||||
_logger.debug(f"Absifying {filename=}, geting {ret}")
|
||||
return ret
|
||||
|
||||
def get_out_dir_path(self) -> pathlib.Path:
|
||||
return self.absify(self.general_config.out_dir_name)
|
||||
|
||||
def get_dots_json_path(self) -> pathlib.Path:
|
||||
return self.absify(self.general_config.dots_json_name)
|
||||
|
||||
def indexifier(self) -> deepdog.indexify.Indexifier:
|
||||
with self.absify(self.general_config.indexes_json_name).open(
|
||||
"r"
|
||||
) as indexify_json_file:
|
||||
indexify_spec = json.load(indexify_json_file)
|
||||
indexify_data = indexify_spec["indexes"]
|
||||
if "seed_spec" in indexify_spec:
|
||||
seed_spec = indexify_spec["seed_spec"]
|
||||
indexify_data[seed_spec["field_name"]] = list(
|
||||
range(seed_spec["num_seeds"])
|
||||
)
|
||||
|
||||
_logger.info(f"loading indexifier with data {indexify_data=}")
|
||||
return deepdog.indexify.Indexifier(indexify_data)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReducedModelParams:
|
||||
"""
|
||||
Units usually in 10s of nm for distance, s or Hz as needed for time units, log units are log base 10 of Hz or s values.
|
||||
"""
|
||||
|
||||
x_min: float = -20
|
||||
x_max: float = 20
|
||||
y_min: float = -10
|
||||
y_max: float = 10
|
||||
z_min: float = 5
|
||||
z_max: float = 6.5
|
||||
w_log_min: float = -5
|
||||
w_log_max: float = 1
|
||||
count: int = 1
|
||||
log_magnitude: float = 2
|
||||
orientation: tantri.dipoles.types.Orientation = (
|
||||
tantri.dipoles.types.Orientation.RANDOM
|
||||
)
|
||||
|
||||
def config_dict(self, seed: int) -> typing.Dict[str, typing.Any]:
|
||||
output_dict = {
|
||||
"x_min": self.x_min,
|
||||
"x_max": self.x_max,
|
||||
"y_min": self.y_min,
|
||||
"y_max": self.y_max,
|
||||
"z_min": self.z_min,
|
||||
"z_max": self.z_max,
|
||||
"mag": 10**self.log_magnitude,
|
||||
"w_log_min": self.w_log_min,
|
||||
"w_log_max": self.w_log_max,
|
||||
"orientation": self.orientation,
|
||||
"dipole_count": self.count,
|
||||
"generation_seed": seed,
|
||||
}
|
||||
return output_dict
|
||||
__all__ = [
|
||||
"MeasurementTypeEnum",
|
||||
"SkipToStage",
|
||||
"GeneralConfig",
|
||||
"TantriConfig",
|
||||
"GenerationConfig",
|
||||
"DeepdogConfig",
|
||||
"Config",
|
||||
"ReducedModelParams",
|
||||
"read_config_dict",
|
||||
"serialize_config",
|
||||
"read_config",
|
||||
"read_general_config_dict",
|
||||
]
|
||||
|
||||
163
kalpaa/config/config.py
Normal file
163
kalpaa/config/config.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import json
|
||||
import deepdog.indexify
|
||||
from dataclasses import dataclass, field
|
||||
import typing
|
||||
import tantri.dipoles.types
|
||||
import pathlib
|
||||
from enum import Enum, IntEnum
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MeasurementTypeEnum(Enum):
    """Which dot measurement the pipeline simulates and analyzes.

    The string values are what appears as ``measurement_type`` in
    kalpaa.toml configuration files.
    """

    POTENTIAL = "electric-potential"
    X_ELECTRIC_FIELD = "x-electric-field"
|
||||
|
||||
|
||||
class SkipToStage(IntEnum):
    """Zero-based index of the pipeline stage to resume from.

    NOTE(review): stage N maps to value N - 1 — callers appear to pass
    1-4 and convert down; confirm against the CLI parsing code.
    """

    # shouldn't need this lol
    STAGE_01 = 0
    STAGE_02 = 1
    STAGE_03 = 2
    STAGE_04 = 3
|
||||
|
||||
|
||||
# Copy over some random constants to see if they're ever reused
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GeneralConfig:
    """File names, logging format, and measurement settings shared by every stage."""

    # Input JSON file names, resolved relative to root_directory via Config.absify.
    dots_json_name: str = "dots.json"
    indexes_json_name: str = "indexes.json"
    # Directory name (also resolved against root_directory) for generated output.
    out_dir_name: str = "out"
    # Format string handed to the logging setup.
    log_pattern: str = (
        "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
    )
    measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
    # Base directory everything else is resolved against; defaults to the
    # working directory at import time.
    root_directory: pathlib.Path = pathlib.Path.cwd()

    # Output CSV names for the merged (and merged + inferenced) result tables.
    mega_merged_name: str = "mega_merged_coalesced.csv"
    mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"

    # If set, resume the pipeline at this stage index (see SkipToStage); None runs all stages.
    skip_to_stage: typing.Optional[int] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TantriConfig:
    """Parameters for one tantri time-series generation run."""

    # Base value for deriving seed indexes — presumably offsets the
    # generated seeds; TODO confirm against the tantri stage.
    index_seed_starter: int = 31415
    # Number of independent seeds (time series) to generate.
    num_seeds: int = 100
    # Simulation time step.
    delta_t: float = 0.05
    # Steps per generated series.
    num_iterations: int = 100000
    # sample_rate = 10
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GenerationConfig:
    """Controls how synthetic dipole configurations and time series are generated."""

    # Dipole counts to sweep over. Interact with indexes.json, probably
    # should be a subset — TODO confirm.
    counts: typing.Sequence[int] = field(default_factory=lambda: [1, 10])
    # Dipole orientation classes to sweep over.
    orientations: typing.Sequence[tantri.dipoles.types.Orientation] = field(
        default_factory=lambda: [
            tantri.dipoles.types.Orientation.RANDOM,
            tantri.dipoles.types.Orientation.Z,
            tantri.dipoles.types.Orientation.XY,
        ]
    )
    # TODO: what's replica here?
    num_replicas: int = 3

    # The above three can be overridden with manually specified
    # configurations: a mapping of scenario name -> explicit dipole list.
    override_dipole_configs: typing.Optional[
        typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
    ] = None

    # One or more tantri runs to perform per configuration.
    tantri_configs: typing.List[TantriConfig] = field(
        default_factory=lambda: [TantriConfig()]
    )

    # Binning parameters — presumably number of binned series and the
    # log-width of each bin; TODO confirm against the binning stage.
    num_bin_time_series: int = 25
    bin_log_width: float = 0.25
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DeepdogConfig:
    """Computational parameters for the deepdog inference runs."""

    # Cost thresholds to attempt, ordered loosest to tightest
    # (matches the [10, 1, 0.1] default and the kalpaa.toml examples).
    costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
    # Number of successes to target — exact semantics defined by deepdog.
    target_success: int = 1000
    # Cap on monte carlo cycle steps.
    max_monte_carlo_cycles_steps: int = 20
    # Whether to use a log log cost function.
    use_log_noise: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Config:
    """Top-level bundle of the three stage configs, plus path helpers.

    All relative file names are resolved against
    ``general_config.root_directory``.
    """

    generation_config: GenerationConfig = GenerationConfig()
    general_config: GeneralConfig = GeneralConfig()
    deepdog_config: DeepdogConfig = DeepdogConfig()

    def absify(self, filename: str) -> pathlib.Path:
        """Return *filename* resolved to an absolute path under the root directory."""
        ret = (self.general_config.root_directory / filename).resolve()
        _logger.debug(
            f"Absifying {filename=}, for root directory {self.general_config.root_directory}, getting {ret}"
        )
        return ret

    def get_out_dir_path(self) -> pathlib.Path:
        """Absolute path of the output directory."""
        return self.absify(self.general_config.out_dir_name)

    def get_dots_json_path(self) -> pathlib.Path:
        """Absolute path of the dots JSON input file."""
        return self.absify(self.general_config.dots_json_name)

    def indexifier(self) -> deepdog.indexify.Indexifier:
        """Build a deepdog Indexifier from the indexes JSON file.

        Reads the "indexes" mapping; if the spec also contains a
        "seed_spec" entry, expands it into list(range(num_seeds)) under
        the declared field name before constructing the Indexifier.
        """
        with self.absify(self.general_config.indexes_json_name).open(
            "r"
        ) as indexify_json_file:
            indexify_spec = json.load(indexify_json_file)
            indexify_data = indexify_spec["indexes"]
            if "seed_spec" in indexify_spec:
                seed_spec = indexify_spec["seed_spec"]
                indexify_data[seed_spec["field_name"]] = list(
                    range(seed_spec["num_seeds"])
                )

            _logger.info(f"loading indexifier with data {indexify_data=}")
            return deepdog.indexify.Indexifier(indexify_data)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ReducedModelParams:
    """
    Units usually in 10s of nm for distance, s or Hz as needed for time units, log units are log base 10 of Hz or s values.
    """

    # Bounding box for dipole positions.
    x_min: float = -20
    x_max: float = 20
    y_min: float = -10
    y_max: float = 10
    z_min: float = 5
    z_max: float = 6.5
    # log10 range for the frequency/rate parameter w.
    w_log_min: float = -5
    w_log_max: float = 1
    # Number of dipoles.
    count: int = 1
    # log10 of the dipole magnitude (expanded to 10**log_magnitude below).
    log_magnitude: float = 2
    orientation: tantri.dipoles.types.Orientation = (
        tantri.dipoles.types.Orientation.RANDOM
    )

    def config_dict(self, seed: int) -> typing.Dict[str, typing.Any]:
        """Return these parameters as the keyword dict consumed by dipole
        generation, with *seed* included as "generation_seed"."""
        output_dict = {
            "x_min": self.x_min,
            "x_max": self.x_max,
            "y_min": self.y_min,
            "y_max": self.y_max,
            "z_min": self.z_min,
            "z_max": self.z_max,
            # magnitude is stored as log10; expand to the linear value here
            "mag": 10**self.log_magnitude,
            "w_log_min": self.w_log_min,
            "w_log_max": self.w_log_max,
            "orientation": self.orientation,
            "dipole_count": self.count,
            "generation_seed": seed,
        }
        return output_dict
|
||||
@@ -1,12 +1,13 @@
|
||||
import pathlib
|
||||
import logging
|
||||
|
||||
import dataclasses
|
||||
|
||||
import kalpaa.stages.stage01
|
||||
import kalpaa.stages.stage02
|
||||
import kalpaa.stages.stage03
|
||||
import kalpaa.stages.stage04
|
||||
import kalpaa.common
|
||||
import tantri.dipoles.types
|
||||
import kalpaa.config
|
||||
|
||||
import argparse
|
||||
@@ -72,6 +73,22 @@ def parse_args():
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--directory-label",
|
||||
type=str,
|
||||
help="Label for directory to put files in within root",
|
||||
default="output1",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--config-file",
|
||||
type=str,
|
||||
help="kalpaa.toml file to use for configuration",
|
||||
default="kalpaa.toml",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--skip-to-stage",
|
||||
@@ -86,52 +103,15 @@ def parse_args():
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
tantri_configs = [
|
||||
kalpaa.TantriConfig(123456, 50, 0.5, 100000),
|
||||
# kalpa.TantriConfig(1234, 50, 0.0005, 10000),
|
||||
]
|
||||
|
||||
# override_config = {
|
||||
# # "test1": [
|
||||
# # tantri.dipoles.types.DipoleTO(
|
||||
# # numpy.array([0, 0, 100]),
|
||||
# # numpy.array([-2, -2, 2.9]),
|
||||
# # 0.0005
|
||||
# # )
|
||||
# # ],
|
||||
# "two_dipole_connors_geom": [
|
||||
# tantri.dipoles.types.DipoleTO(
|
||||
# numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.0005
|
||||
# ),
|
||||
# tantri.dipoles.types.DipoleTO(
|
||||
# numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.05
|
||||
# ),
|
||||
# ],
|
||||
# "two_dipole_connors_geom_omegaswap": [
|
||||
# tantri.dipoles.types.DipoleTO(
|
||||
# numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.05
|
||||
# ),
|
||||
# tantri.dipoles.types.DipoleTO(
|
||||
# numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.0005
|
||||
# ),
|
||||
# ],
|
||||
# }
|
||||
|
||||
generation_config = kalpaa.GenerationConfig(
|
||||
tantri_configs=tantri_configs,
|
||||
counts=[3, 31],
|
||||
num_replicas=5,
|
||||
# let's test this out
|
||||
# override_dipole_configs=override_config,
|
||||
orientations=[tantri.dipoles.types.Orientation.Z],
|
||||
num_bin_time_series=25,
|
||||
)
|
||||
config = kalpaa.config.read_config(pathlib.Path(args.config_file))
|
||||
label = args.directory_label
|
||||
|
||||
if args.override_root is None:
|
||||
_logger.info("root dir not given")
|
||||
root = pathlib.Path("plots0")
|
||||
# root = pathlib.Path("hardcodedoutplace")
|
||||
root = config.general_config.root_directory / label
|
||||
else:
|
||||
root = pathlib.Path(args.override_root)
|
||||
root = pathlib.Path(args.override_root) / label
|
||||
|
||||
if args.skip_to_stage is not None:
|
||||
if args.skip_to_stage not in [1, 2, 3, 4]:
|
||||
@@ -141,31 +121,20 @@ def main():
|
||||
else:
|
||||
skip = None
|
||||
|
||||
general_config = kalpaa.GeneralConfig(
|
||||
measurement_type=kalpaa.MeasurementTypeEnum.POTENTIAL,
|
||||
out_dir_name=str(root / "out"),
|
||||
skip_to_stage=skip,
|
||||
)
|
||||
_logger.info(skip)
|
||||
|
||||
# kalpa.GeneralConfig
|
||||
kalpaa.common.set_up_logging(config, str(root / f"logs/kalpaa.log"))
|
||||
|
||||
deepdog_config = kalpaa.DeepdogConfig(
|
||||
costs_to_try=[2, 1],
|
||||
max_monte_carlo_cycles_steps=20,
|
||||
target_success=200,
|
||||
use_log_noise=True,
|
||||
)
|
||||
_logger.info(f"Root dir is {root}, copying over {config.general_config.indexes_json_name}, {config.general_config.dots_json_name} and {args.config_file}")
|
||||
for file in [config.general_config.indexes_json_name, config.general_config.dots_json_name, args.config_file]:
|
||||
_logger.info(f"Copying {file} to {root}")
|
||||
(root / file).write_text((pathlib.Path.cwd() / file).read_text())
|
||||
|
||||
config = kalpaa.Config(
|
||||
generation_config=generation_config,
|
||||
general_config=general_config,
|
||||
deepdog_config=deepdog_config,
|
||||
)
|
||||
|
||||
kalpaa.common.set_up_logging(config, str(root / f"logs/{root}.log"))
|
||||
overridden_config = dataclasses.replace(config, general_config=dataclasses.replace(config.general_config, root_directory=root.resolve(), skip_to_stage=skip))
|
||||
|
||||
_logger.info(f"Got {config=}")
|
||||
runner = Runner(config)
|
||||
runner = Runner(overridden_config)
|
||||
runner.run()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user