feat!: copying over from sd4

parent e408920bde
commit 5a0b294150
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
 import typing
 import tantri.dipoles.types
 import pathlib
-from enum import Enum
+from enum import Enum, IntEnum
 import logging

 _logger = logging.getLogger(__name__)
@@ -14,21 +14,30 @@ class MeasurementTypeEnum(Enum):
     POTENTIAL = "electric-potential"
     X_ELECTRIC_FIELD = "x-electric-field"


+class SkipToStage(IntEnum):
+    # shouldn't need this lol
+    STAGE_01 = 0
+    STAGE_02 = 1
+    STAGE_03 = 2
+    STAGE_04 = 3
+
+
 # Copy over some random constants to see if they're ever reused


 @dataclass(frozen=True)
 class GeneralConfig:
-    dots_json_name = "dots.json"
-    indexes_json_name = "indexes.json"
-    out_dir_name = "out"
-    log_pattern = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
+    dots_json_name: str = "dots.json"
+    indexes_json_name: str = "indexes.json"
+    out_dir_name: str = "out"
+    log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
     measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
     root_directory: pathlib.Path = pathlib.Path.cwd()

-    mega_merged_name = "mega_merged_coalesced.csv"
-    mega_merged_inferenced_name = "mega_merged_coalesced_inferenced.csv"
+    mega_merged_name: str = "mega_merged_coalesced.csv"
+    mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"

+    skip_to_stage: typing.Optional[int] = None


 @dataclass(frozen=True)
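The `: str` annotations here are load-bearing: in a dataclass, a bare `name = "value"` is just a class attribute rather than a field, so it never becomes an `__init__` parameter. With the annotations (plus the new `skip_to_stage` field) the config can be built per run, e.g. (a minimal sketch, values illustrative):

    config = GeneralConfig(
        measurement_type=MeasurementTypeEnum.POTENTIAL,
        out_dir_name="plots0/out",
        skip_to_stage=SkipToStage.STAGE_03,  # IntEnum value 2; None means run everything
    )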
@@ -51,12 +60,18 @@ class GenerationConfig:
             tantri.dipoles.types.Orientation.XY,
         ]
     )
+    # TODO: what's replica here?
     num_replicas: int = 3

+    # the above three can be overridden with manually specified configurations
+    override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None
+
     tantri_configs: typing.Sequence[TantriConfig] = field(
         default_factory=lambda: [TantriConfig()]
     )


     num_bin_time_series: int = 25
     bin_log_width: float = 0.25

@@ -70,6 +85,8 @@ class DeepdogConfig:
     costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
     target_success: int = 1000
     max_monte_carlo_cycles_steps: int = 20
+    # Whether to use a log-log cost function
+    use_log_noise: bool = False


 @dataclass(frozen=True)
@@ -80,7 +97,7 @@ class Config:

     def absify(self, filename: str) -> pathlib.Path:
         ret = (self.general_config.root_directory / filename).resolve()
-        _logger.debug(f"Absifying {filename=}, getting {ret}")
+        _logger.debug(f"Absifying {filename=}, geting {ret}")
         return ret

     def get_out_dir_path(self) -> pathlib.Path:
@@ -115,8 +132,8 @@ class ReducedModelParams:
     x_max: float = 20
     y_min: float = -10
     y_max: float = 10
-    z_min: float = 0
-    z_max: float = 5
+    z_min: float = 5
+    z_max: float = 6.5
     w_log_min: float = -5
     w_log_max: float = 1
     count: int = 1
@@ -66,45 +66,37 @@ class Coalescer:
         _logger.debug(f"subdict keys: {subdict.keys()}")

         # TODO hardcoding 3 generations
-        if self.num_replicas != 3:
-            raise ValueError(
-                f"num replicas was {self.num_replicas}, but we've hard coded 3"
-            )
+        # if self.num_replicas != 3:
+        #     raise ValueError(
+        #         f"num replicas was {self.num_replicas}, but we've hard coded 3"
+        #     )
         # generations_keys = ["0", "1", "2"]

+        _logger.info(f"Going through generation {0}")
+
         # 0th gen is easiest
         for model_key, val in subdict["0"].items():
             val["coalesced_prob"] = val["prob"]

-        weight1 = sum(
-            [
-                float(subdict["0"][key]["coalesced_prob"])
-                * float(subdict["1"][key]["prob"])
-                for key in subdict["1"].keys()
-            ]
-        )
-        _logger.debug(weight1)
-        for model_key, val in subdict["1"].items():
-            val["coalesced_prob"] = (
-                float(val["prob"])
-                * float(subdict["0"][model_key]["coalesced_prob"])
-                / weight1
-            )
-
-        weight2 = sum(
-            [
-                float(subdict["1"][key]["coalesced_prob"])
-                * float(subdict["2"][key]["prob"])
-                for key in subdict["2"].keys()
-            ]
-        )
-        _logger.debug(weight2)
-        for model_key, val in subdict["2"].items():
-            val["coalesced_prob"] = (
-                float(val["prob"])
-                * float(subdict["1"][model_key]["coalesced_prob"])
-                / weight2
-            )
+        if self.num_replicas > 1:
+            for gen in range(1, self.num_replicas):
+                _logger.info(f"Going through generation {gen}")
+
+                generation_weight = sum(
+                    [
+                        float(subdict[str(gen - 1)][key]["coalesced_prob"])
+                        * float(subdict[str(gen)][key]["prob"])
+                        for key in subdict[str(gen)].keys()
+                    ]
+                )
+                _logger.debug(generation_weight)
+                for model_key, val in subdict[str(gen)].items():
+                    val["coalesced_prob"] = (
+                        float(val["prob"])
+                        * float(subdict[str(gen - 1)][model_key]["coalesced_prob"])
+                        / generation_weight
+                    )

     def coalesce_all(self):
         for actual_key in self.actual_dict.keys():
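The replacement loop turns the hard-coded three-generation chain into a running update over any number of replicas: each generation's `coalesced_prob` is its own `prob` weighted by the previous generation's `coalesced_prob` and renormalised by the per-generation weight. A standalone sketch of the same arithmetic (function name and data shape illustrative, not the real `Coalescer` API):

    def chain_generations(generations):
        # generations: one dict per replica, each mapping model_key -> {"prob": ...}
        coalesced = {k: float(v["prob"]) for k, v in generations[0].items()}
        for gen in generations[1:]:
            weight = sum(coalesced[k] * float(gen[k]["prob"]) for k in gen)
            coalesced = {k: coalesced[k] * float(gen[k]["prob"]) / weight for k in gen}
        return coalesced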
@@ -58,6 +58,7 @@ class StDevUsingCostFunction:
         dot_inputs_array,
         actual_measurement_array,
         actual_stdev_array,
+        log_noise: bool = False,
     ):
         _logger.info(f"Cost function with measurement type of {measurement_type}")
         self.measurement_type = measurement_type
@@ -67,6 +68,16 @@ class StDevUsingCostFunction:
         self.actual_stdev_array = actual_stdev_array
         self.actual_stdev_array2 = actual_stdev_array**2

+        self.use_log_noise = log_noise
+        self.log_actual = numpy.log(self.actual_measurement_array)
+        self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array))**2
+        # if self.use_log_noise:
+        #     _logger.debug("remove these debugs later")
+        #     _logger.debug(self.actual_measurement_array)
+        #     _logger.debug(self.actual_stdev_array)
+        #     _logger.debug(self.log_actual)
+        #     _logger.debug(self.log_denom2)

     def __call__(self, dipoles_to_test):
         if self.measurement_type == X_ELECTRIC_FIELD:
             vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
@@ -76,7 +87,12 @@ class StDevUsingCostFunction:
             vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
                 self.dot_inputs_array, dipoles_to_test
             )
-        diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
+
+        if self.use_log_noise:
+            diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
+        else:
+            diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
+
         return numpy.sqrt(diffs.mean(axis=-1))


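For intuition on the new branch: the log-noise cost compares logarithms instead of raw values, and the precomputed `log_denom2` equals `log(1 + stdev/measurement)**2`, so the noise is treated as relative rather than absolute. A side-by-side sketch in plain numpy (array names illustrative):

    import numpy

    def linear_cost(vals, actual, stdev):
        return numpy.sqrt((((vals - actual) ** 2) / stdev**2).mean(axis=-1))

    def log_cost(vals, actual, stdev):
        log_denom2 = numpy.log(1 + stdev / actual) ** 2
        return numpy.sqrt((((numpy.log(vals) - numpy.log(actual)) ** 2) / log_denom2).mean(axis=-1))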
@@ -202,6 +218,7 @@ class BinnedData:
     csv_dict: typing.Dict[str, typing.Any]
     measurement_type: str

+    # we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
     def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]:
         if dot_name not in self.dots_dict:
             raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
@@ -255,6 +272,7 @@ class BinnedData:
     def _stdev_cost_function(
         self,
         measurements: typing.Sequence[Measurement],
+        use_log_noise: bool = False,
     ):
         meas_array = numpy.array([m.dot_measurement.v for m in measurements])
         stdev_array = numpy.array([m.stdev for m in measurements])
@@ -266,7 +284,7 @@ class BinnedData:
         _logger.debug(f"Obtained {input_array=}")

         return StDevUsingCostFunction(
-            self.measurement_type, input_array, meas_array, stdev_array
+            self.measurement_type, input_array, meas_array, stdev_array, use_log_noise
         )

     def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
@@ -277,10 +295,13 @@ class BinnedData:
         )

     def stdev_cost_function_filter(
-        self, dot_names: typing.Sequence[str], target_cost: float
+        self,
+        dot_names: typing.Sequence[str],
+        target_cost: float,
+        use_log_noise: bool = False,
     ):
         measurements = self.measurements(dot_names)
-        cost_function = self._stdev_cost_function(measurements)
+        cost_function = self._stdev_cost_function(measurements, use_log_noise)
         return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
             cost_function, target_cost
         )
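This threads the noise-model choice from the caller down to `StDevUsingCostFunction`; a sketch of the call site (variable names and values illustrative):

    filt = binned_data.stdev_cost_function_filter(
        ["dot1", "dot2"],    # dot_names
        1.0,                 # target_cost
        use_log_noise=True,  # forwarded as log_noise to the cost function
    )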
@@ -1,3 +1,4 @@
+import pathlib
 import logging

 import kalpa.stages.stage01
@@ -6,9 +7,13 @@ import kalpa.stages.stage03
 import kalpa.stages.stage04
 import kalpa.common
 import tantri.dipoles.types
+import kalpa.config

 import argparse
+
+# try not to use this outside of main or when defining config stuff pls
+import numpy

 _logger = logging.getLogger(__name__)


@@ -18,21 +23,41 @@ class Runner:
         _logger.info(f"Initialising runner with {config=}")

     def run(self):
-        _logger.info("*** Beginning Stage 01 ***")
-        stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
-        stage01.run()
-
-        _logger.info("*** Beginning Stage 02 ***")
-        stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
-        stage02.run()
-
-        _logger.info("*** Beginning Stage 03 ***")
-        stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
-        stage03.run()
-
-        _logger.info("*** Beginning Stage 04 ***")
-        stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
-        stage04.run()
+        if self.config.general_config.skip_to_stage is not None:
+
+            stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
+            stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
+            stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
+            stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
+
+            stages = [stage01, stage02, stage03, stage04]
+
+            start = int(self.config.general_config.skip_to_stage)
+            _logger.info(f"Received instruction to start at stage {start + 1}")
+            for i, stage in enumerate(stages[start: 4]):
+                _logger.info(f"*** Running stage {i + start + 1}")
+                stage.run()
+
+        else:
+            # standard run, can keep old
+
+            _logger.info("*** Beginning Stage 01 ***")
+            stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
+            stage01.run()
+
+            _logger.info("*** Beginning Stage 02 ***")
+            stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
+            stage02.run()
+
+            _logger.info("*** Beginning Stage 03 ***")
+            stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
+            stage03.run()
+
+            _logger.info("*** Beginning Stage 04 ***")
+            stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
+            stage04.run()


 def parse_args():
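Note the indexing convention: `skip_to_stage` is the zero-based position in the `stages` list, so `SkipToStage.STAGE_03` (value 2) slices to stages 3 and 4; the 1-based CLI flag is converted later in `main()`. A tiny illustration (stage names stand in for the runner objects):

    stages = ["stage01", "stage02", "stage03", "stage04"]
    start = 2  # kalpa.config.SkipToStage.STAGE_03
    assert stages[start:4] == ["stage03", "stage04"]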
@@ -42,9 +67,16 @@ def parse_args():
     )

     parser.add_argument(
-        "--log-file",
+        "--override-root",
         type=str,
-        help="A filename for logging to, if not provided will only log to stderr",
+        help="If provided, override the root dir.",
+        default=None,
+    )
+
+    parser.add_argument(
+        "-s", "--skip-to-stage",
+        type=int,
+        help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
         default=None,
     )
     args = parser.parse_args()
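Together with the runner change above, a re-run of only the later stages against an existing root might look like this (a sketch of the argparse behaviour; the flag values are illustrative):

    args = parser.parse_args(["--override-root", "plots0", "--skip-to-stage", "3"])
    assert args.override_root == "plots0"
    assert args.skip_to_stage == 3  # main() maps this to SkipToStage.STAGE_03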
@@ -55,23 +87,85 @@ def main():
     args = parse_args()

     tantri_configs = [
-        kalpa.TantriConfig(12345, 50, 0.5, 100000),
+        kalpa.TantriConfig(123456, 50, 0.5, 100000),
         # kalpa.TantriConfig(1234, 50, 0.0005, 10000),
     ]
+
+    override_config = {
+        # "test1": [
+        #     tantri.dipoles.types.DipoleTO(
+        #         numpy.array([0, 0, 100]),
+        #         numpy.array([-2, -2, 2.9]),
+        #         0.0005
+        #     )
+        # ],
+        "two_dipole_connors_geom": [
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([-2, -2, 5.75]),
+                0.0005
+            ),
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([6, 2, 5.75]),
+                0.05
+            )
+        ],
+        "two_dipole_connors_geom_omegaswap": [
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([-2, -2, 5.75]),
+                0.05
+            ),
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([6, 2, 5.75]),
+                0.0005
+            )
+        ]
+    }
+
     generation_config = kalpa.GenerationConfig(
         tantri_configs=tantri_configs,
-        counts=[1],
-        num_replicas=3,
-        orientations=[tantri.dipoles.types.Orientation.XY],
+        counts=[3, 31],
+        num_replicas=5,
+        # let's test this out
+        # override_dipole_configs=override_config,
+        orientations=[tantri.dipoles.types.Orientation.Z],
         num_bin_time_series=25,
     )

+    if args.override_root is None:
+        _logger.info("root dir not given")
+        root = pathlib.Path("plots0")
+    else:
+        root = pathlib.Path(args.override_root)
+
+    if args.skip_to_stage is not None:
+        if args.skip_to_stage not in [1, 2, 3, 4]:
+            raise ValueError(f"There is no stage {args.skip_to_stage}")
+        else:
+            skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
+    else:
+        skip = None
+
     general_config = kalpa.GeneralConfig(
-        measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL
+        measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
+        out_dir_name=str(root / "out"),
+        skip_to_stage=skip
     )

+    # kalpa.GeneralConfig
+
     deepdog_config = kalpa.DeepdogConfig(
-        costs_to_try=[10, 2, 1, 0.1],
+        costs_to_try=[2, 1],
         max_monte_carlo_cycles_steps=20,
+        target_success=200,
+        use_log_noise=True,
     )

     config = kalpa.Config(
@@ -80,7 +174,7 @@ def main():
         deepdog_config=deepdog_config,
     )

-    kalpa.common.set_up_logging(config, args.log_file)
+    kalpa.common.set_up_logging(config, str(root / f"logs/{root}.log"))

     _logger.info(f"Got {config=}")
     runner = Runner(config)
@@ -7,7 +7,10 @@ import logging
 import kalpa
 import kalpa.common
 import tantri.cli
+import tantri.cli.input_files
+import tantri.cli.input_files.write_dipoles
 import tantri.dipoles.types
+import typing


 _logger = logging.getLogger(__name__)
@@ -113,7 +116,9 @@ class Stage01Runner:
         for tantri_index, tantri_config in enumerate(
             self.config.generation_config.tantri_configs
         ):
-            output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
+            output_csv = directory / kalpa.common.tantri_full_output_name(
+                tantri_index
+            )
             binned_csv = directory / kalpa.common.tantri_binned_output_name(
                 tantri_index
             )
@@ -131,16 +136,95 @@ class Stage01Runner:
                 True,
             )

+    # This whole method is duplication and is ripe for refactor, but that's fine!
+    # deliberately bad to get it done.
+    # here we're going to be manually specifying dipoles as we have from our config
+    def generate_override_dipole(
+        self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO]
+    ):
+        """
+        create a directory, populate it with stuff.
+
+        seed: still a seed integer to use
+        override_name: the name of this dipole configuration, from config file
+        override_dipoles: dipoles to override
+        """
+        _logger.info(
+            f"Writing override config {override_name} with dipoles: [{override_dipoles}]"
+        )
+        out = self.config.get_out_dir_path()
+        directory = out / f"{override_name}"
+        directory.mkdir(parents=True, exist_ok=True)
+
+        _logger.debug("generated override directory")
+
+        # config_json = directory / "generation_config.json"
+        dipoles_json = directory / "dipoles.json"
+
+        # with open(config_json, "w") as conf_file:
+        #     params = kalpa.ReducedModelParams(
+        #         count=count, orientation=tantri.dipoles.types.Orientation(orientation)
+        #     )
+        #     _logger.debug(f"Got params {params=}")
+        #     json.dump(params.config_dict(seed), conf_file)
+        #     # json.dump(kalpa.common.model_config_dict(count, orientation, seed), conf_file)

+        # the original logic looked like this:
+        # tantri.cli._generate_dipoles(config_json, dipoles_json, (seed, replica, 1))
+        # We're replicating the bit that wrote the dipoles here, but that's a refactor opportunity
+        with dipoles_json.open("w") as dipole_out:
+            dipole_out.write(
+                json.dumps(
+                    [dip.as_dict() for dip in override_dipoles],
+                    cls=tantri.cli.input_files.write_dipoles.NumpyEncoder
+                )
+            )
+        _logger.info(f"Wrote to dipoles file {dipoles_json}")
+
+        # tantri.cli._write_apsd(dipoles_json, DOTS, X_ELECTRIC_FIELD, DELTA_T, NUM_ITERATIONS, NUM_BIN_TS, (index, replica, 2), output_csv, binned_csv, BIN_WIDTH_LOG, True)
+        for tantri_index, tantri_config in enumerate(
+            self.config.generation_config.tantri_configs
+        ):
+            output_csv = directory / kalpa.common.tantri_full_output_name(
+                tantri_index
+            )
+            binned_csv = directory / kalpa.common.tantri_binned_output_name(
+                tantri_index
+            )
+            tantri.cli._write_apsd(
+                dipoles_json,
+                self.config.general_config.dots_json_name,
+                self.config.general_config.measurement_type.value,
+                tantri_config.delta_t,
+                tantri_config.num_iterations,
+                self.config.generation_config.num_bin_time_series,
+                (seed, 2),
+                output_csv,
+                binned_csv,
+                self.config.generation_config.bin_log_width,
+                True,
+            )
+
     def run(self):
         seed_index = 0
-        for count in self.config.generation_config.counts:
-            for orientation in self.config.generation_config.orientations:
-                for replica in range(self.config.generation_config.num_replicas):
-                    _logger.info(
-                        f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
-                    )
-                    self.generate_single_subdir(seed_index, count, orientation, replica)
-                    seed_index += 1
+        if self.config.generation_config.override_dipole_configs is None:
+            # should be by default
+            _logger.debug("no override needed!")
+            for count in self.config.generation_config.counts:
+                for orientation in self.config.generation_config.orientations:
+                    for replica in range(self.config.generation_config.num_replicas):
+                        _logger.info(
+                            f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
+                        )
+                        self.generate_single_subdir(seed_index, count, orientation, replica)
+                        seed_index += 1
+        else:
+            _logger.debug(f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}")
+            for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items():
+                self.generate_override_dipole(seed_index, override_name, override_dipoles)


 def parse_args():
@@ -178,12 +178,16 @@ class Stage02Runner:

         # TODO find way to store this as a global config file
         occupancies_dict = {
-            1: (1000, 1000),
-            10: (1000, 100),
-            16: (10000, 10),
-            31: (1000, 100),
-            56: (1000, 100),
-            100: (100, 100),
+            1: (500, 1000),
+            2: (250, 2000),
+            3: (250, 2000),
+            5: (100, 5000),
+            10: (50, 10000),
+            16: (50, 10000),
+            17: (50, 10000),
+            31: (50, 10000),
+            56: (25, 20000),
+            100: (5, 100000),
         }

         mccount, mccountcycles = occupancies_dict[avg_filled]
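One way to read the retuned table: every entry keeps `mccount * mccountcycles` at 5×10^5, presumably a fixed total Monte Carlo budget per cost level, while shrinking the per-cycle batch as the average occupancy grows. A quick check (values copied from the new table):

    occupancies_dict = {
        1: (500, 1000), 2: (250, 2000), 3: (250, 2000), 5: (100, 5000),
        10: (50, 10000), 16: (50, 10000), 17: (50, 10000), 31: (50, 10000),
        56: (25, 20000), 100: (5, 100000),
    }
    assert all(a * b == 500_000 for a, b in occupancies_dict.values())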
@@ -212,7 +216,7 @@ class Stage02Runner:
         _logger.info(f"{deepdog_config=}")

         stdev_cost_function_filters = [
-            b.stdev_cost_function_filter(dot_names, cost) for b in binned_datas
+            b.stdev_cost_function_filter(dot_names, cost, self.config.deepdog_config.use_log_noise) for b in binned_datas
         ]

         _logger.debug(f"{stdev_cost_function_filters=}")
@@ -121,6 +121,31 @@ class Stage04Runner:
                 out_list.append(row)
             return out_list

+    def read_merged_coalesced_csv_override(self, override_name: str) -> typing.Sequence:
+        subdir_name = override_name
+        subdir_path = self.config.get_out_dir_path() / subdir_name
+        csv_path = (
+            subdir_path
+            / kalpa.common.sorted_bayesruns_name()
+            / kalpa.common.merged_coalesced_name()
+        )
+        _logger.debug(f"Reading {csv_path=}")
+        with csv_path.open(mode="r", newline="") as csvfile:
+            reader = csv.DictReader(csvfile)
+            out_list = []
+            for row in reader:
+                # We can't put any of the actual info in because it's totally arbitrary, but that's fine!
+
+                # normal_orientation = ORIENTATION_DICT[orientation]
+                row["subdir_name"] = subdir_name
+                # row["actual_orientation"] = ORIENTATION_DICT[orientation]
+                # row["actual_avg_filled"] = count
+                # row["generation_replica_index"] = replica
+                # row["is_row_actual"] = is_actual(row, normal_orientation, count)
+                out_list.append(row)
+            return out_list
+
     def run(self):
         megamerged_path = (
             self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
@@ -130,46 +155,64 @@ class Stage04Runner:
         with megamerged_path.open(mode="w", newline="") as outfile:
             writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
             writer.writeheader()
-            for count in self.config.generation_config.counts:
-                for orientation in self.config.generation_config.orientations:
-                    for replica in range(self.config.generation_config.num_replicas):
-                        _logger.info(f"Reading {count=} {orientation=} {replica=}")
-                        rows = self.read_merged_coalesced_csv(
-                            orientation, count, replica
-                        )
-                        for row in rows:
-                            writer.writerow(row)
+
+            if self.config.generation_config.override_dipole_configs is None:
+
+                for count in self.config.generation_config.counts:
+                    for orientation in self.config.generation_config.orientations:
+                        for replica in range(self.config.generation_config.num_replicas):
+                            _logger.info(f"Reading {count=} {orientation=} {replica=}")
+                            rows = self.read_merged_coalesced_csv(
+                                orientation, count, replica
+                            )
+                            for row in rows:
+                                writer.writerow(row)
+
+            else:
+                _logger.debug(f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}")
+                for override_name in self.config.generation_config.override_dipole_configs.keys():
+                    _logger.info(f"Working for subdir {override_name}")
+                    rows = self.read_merged_coalesced_csv_override(
+                        override_name
+                    )
+                    for row in rows:
+                        writer.writerow(row)

         # merge with inference
-        with megamerged_path.open(mode="r", newline="") as infile:
-            # Note that if you pass in fieldnames to a DictReader it doesn't skip. So this is bad:
-            # megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
-            megamerged_reader = csv.DictReader(infile)
-            rows = [row for row in megamerged_reader]
-            _logger.debug(rows[0])
-            coalescer = kalpa.inference_coalesce.Coalescer(
-                rows, num_replicas=self.config.generation_config.num_replicas
-            )
-            _logger.info(coalescer.actual_dict.keys())
-
-            # coalescer.coalesce_generations(("fixedxy", "1"), "dot1")
-            coalesced = coalescer.coalesce_all()
-
-            inferenced_path = (
-                self.config.get_out_dir_path()
-                / self.config.general_config.mega_merged_inferenced_name
-            )
-            with inferenced_path.open(mode="w", newline="") as outfile:
-                writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES)
-                writer.writeheader()
-                for val in coalesced.values():
-                    for dots in val.values():
-                        for generation in dots.values():
-                            for row in generation.values():
-                                writer.writerow(row)
+
+        if self.config.generation_config.override_dipole_configs is None:
+
+            with megamerged_path.open(mode="r", newline="") as infile:
+                # Note that if you pass in fieldnames to a DictReader it doesn't skip. So this is bad:
+                # megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
+                megamerged_reader = csv.DictReader(infile)
+                rows = [row for row in megamerged_reader]
+                _logger.debug(rows[0])
+                coalescer = kalpa.inference_coalesce.Coalescer(
+                    rows, num_replicas=self.config.generation_config.num_replicas
+                )
+                _logger.info(coalescer.actual_dict.keys())
+
+                # coalescer.coalesce_generations(("fixedxy", "1"), "dot1")
+                coalesced = coalescer.coalesce_all()
+
+                inferenced_path = (
+                    self.config.get_out_dir_path()
+                    / self.config.general_config.mega_merged_inferenced_name
+                )
+                with inferenced_path.open(mode="w", newline="") as outfile:
+                    writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES)
+                    writer.writeheader()
+                    for val in coalesced.values():
+                        for dots in val.values():
+                            for generation in dots.values():
+                                for row in generation.values():
+                                    writer.writerow(row)
+        else:
+            _logger.info("skipping inference metamerge, overridden dipole config specified")


 def parse_args():
@@ -10,8 +10,8 @@ python = ">=3.8.1,<3.10"
 pdme = "^1.5.0"


-deepdog = "^1.4.0"
-tantri = "^1.2.0"
+deepdog = "^1.5.0"
+tantri = "^1.3.0"
 tomli = "^2.0.1"
 [tool.poetry.group.dev.dependencies]
 black = "^24.8.0"