feat!: copying over from sd4
This commit is contained in:
parent
e408920bde
commit
5a0b294150
@ -4,7 +4,7 @@ from dataclasses import dataclass, field
|
||||
import typing
|
||||
import tantri.dipoles.types
|
||||
import pathlib
|
||||
from enum import Enum
|
||||
from enum import Enum, IntEnum
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@ -14,21 +14,30 @@ class MeasurementTypeEnum(Enum):
|
||||
POTENTIAL = "electric-potential"
|
||||
X_ELECTRIC_FIELD = "x-electric-field"
|
||||
|
||||
class SkipToStage(IntEnum):
|
||||
# shouldn't need this lol
|
||||
STAGE_01 = 0
|
||||
STAGE_02 = 1
|
||||
STAGE_03 = 2
|
||||
STAGE_04 = 3
|
||||
|
||||
|
||||
# Copy over some random constants to see if they're ever reused
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneralConfig:
|
||||
dots_json_name = "dots.json"
|
||||
indexes_json_name = "indexes.json"
|
||||
out_dir_name = "out"
|
||||
log_pattern = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
|
||||
dots_json_name: str = "dots.json"
|
||||
indexes_json_name: str = "indexes.json"
|
||||
out_dir_name: str = "out"
|
||||
log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
|
||||
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
|
||||
root_directory: pathlib.Path = pathlib.Path.cwd()
|
||||
|
||||
mega_merged_name = "mega_merged_coalesced.csv"
|
||||
mega_merged_inferenced_name = "mega_merged_coalesced_inferenced.csv"
|
||||
mega_merged_name: str = "mega_merged_coalesced.csv"
|
||||
mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"
|
||||
|
||||
skip_to_stage: typing.Optional[int] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -51,12 +60,18 @@ class GenerationConfig:
|
||||
tantri.dipoles.types.Orientation.XY,
|
||||
]
|
||||
)
|
||||
# TODO: what's replica here?
|
||||
num_replicas: int = 3
|
||||
|
||||
# the above three can be overrided with manually specified configurations
|
||||
override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None
|
||||
|
||||
tantri_configs: typing.Sequence[TantriConfig] = field(
|
||||
default_factory=lambda: [TantriConfig()]
|
||||
)
|
||||
|
||||
|
||||
|
||||
num_bin_time_series: int = 25
|
||||
bin_log_width: float = 0.25
|
||||
|
||||
@ -70,6 +85,8 @@ class DeepdogConfig:
|
||||
costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
|
||||
target_success: int = 1000
|
||||
max_monte_carlo_cycles_steps: int = 20
|
||||
# Whetehr to use a log log cost function
|
||||
use_log_noise: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -80,7 +97,7 @@ class Config:
|
||||
|
||||
def absify(self, filename: str) -> pathlib.Path:
|
||||
ret = (self.general_config.root_directory / filename).resolve()
|
||||
_logger.debug(f"Absifying {filename=}, getting {ret}")
|
||||
_logger.debug(f"Absifying {filename=}, geting {ret}")
|
||||
return ret
|
||||
|
||||
def get_out_dir_path(self) -> pathlib.Path:
|
||||
@ -115,8 +132,8 @@ class ReducedModelParams:
|
||||
x_max: float = 20
|
||||
y_min: float = -10
|
||||
y_max: float = 10
|
||||
z_min: float = 0
|
||||
z_max: float = 5
|
||||
z_min: float = 5
|
||||
z_max: float = 6.5
|
||||
w_log_min: float = -5
|
||||
w_log_max: float = 1
|
||||
count: int = 1
|
||||
|
@ -66,45 +66,37 @@ class Coalescer:
|
||||
_logger.debug(f"subdict keys: {subdict.keys()}")
|
||||
|
||||
# TODO hardcoding 3 generations
|
||||
if self.num_replicas != 3:
|
||||
raise ValueError(
|
||||
f"num replicas was {self.num_replicas}, but we've hard coded 3"
|
||||
)
|
||||
# if self.num_replicas != 3:
|
||||
# raise ValueError(
|
||||
# f"num replicas was {self.num_replicas}, but we've hard coded 3"
|
||||
# )
|
||||
# generations_keys = ["0", "1", "2"]
|
||||
|
||||
_logger.info(f"Going through generation {0}")
|
||||
|
||||
# 0th gen is easiest
|
||||
for model_key, val in subdict["0"].items():
|
||||
val["coalesced_prob"] = val["prob"]
|
||||
|
||||
weight1 = sum(
|
||||
if self.num_replicas > 1:
|
||||
for gen in range(1, self.num_replicas):
|
||||
_logger.info(f"Going through generation {gen}")
|
||||
|
||||
generation_weight = sum(
|
||||
[
|
||||
float(subdict["0"][key]["coalesced_prob"])
|
||||
* float(subdict["1"][key]["prob"])
|
||||
for key in subdict["1"].keys()
|
||||
float(subdict[str(gen - 1)][key]["coalesced_prob"])
|
||||
* float(subdict[str(gen)][key]["prob"])
|
||||
for key in subdict[str(gen)].keys()
|
||||
]
|
||||
)
|
||||
_logger.debug(weight1)
|
||||
for model_key, val in subdict["1"].items():
|
||||
_logger.debug(generation_weight)
|
||||
for model_key, val in subdict[str(gen)].items():
|
||||
val["coalesced_prob"] = (
|
||||
float(val["prob"])
|
||||
* float(subdict["0"][model_key]["coalesced_prob"])
|
||||
/ weight1
|
||||
* float(subdict[str(gen-1)][model_key]["coalesced_prob"])
|
||||
/ generation_weight
|
||||
)
|
||||
|
||||
weight2 = sum(
|
||||
[
|
||||
float(subdict["1"][key]["coalesced_prob"])
|
||||
* float(subdict["2"][key]["prob"])
|
||||
for key in subdict["2"].keys()
|
||||
]
|
||||
)
|
||||
_logger.debug(weight2)
|
||||
for model_key, val in subdict["2"].items():
|
||||
val["coalesced_prob"] = (
|
||||
float(val["prob"])
|
||||
* float(subdict["1"][model_key]["coalesced_prob"])
|
||||
/ weight2
|
||||
)
|
||||
|
||||
def coalesce_all(self):
|
||||
for actual_key in self.actual_dict.keys():
|
||||
|
@ -58,6 +58,7 @@ class StDevUsingCostFunction:
|
||||
dot_inputs_array,
|
||||
actual_measurement_array,
|
||||
actual_stdev_array,
|
||||
log_noise: bool = False,
|
||||
):
|
||||
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
||||
self.measurement_type = measurement_type
|
||||
@ -67,6 +68,16 @@ class StDevUsingCostFunction:
|
||||
self.actual_stdev_array = actual_stdev_array
|
||||
self.actual_stdev_array2 = actual_stdev_array**2
|
||||
|
||||
self.use_log_noise = log_noise
|
||||
self.log_actual = numpy.log(self.actual_measurement_array)
|
||||
self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array))**2
|
||||
# if self.use_log_noise:
|
||||
# _logger.debug("remove these debugs later")
|
||||
# _logger.debug(self.actual_measurement_array)
|
||||
# _logger.debug(self.actual_stdev_array)
|
||||
# _logger.debug(self.log_actual)
|
||||
# _logger.debug(self.log_denom2)
|
||||
|
||||
def __call__(self, dipoles_to_test):
|
||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
|
||||
@ -76,7 +87,12 @@ class StDevUsingCostFunction:
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
|
||||
if self.use_log_noise:
|
||||
diffs = ((numpy.log(vals) - self.log_actual) **2) / self.log_denom2
|
||||
else:
|
||||
diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
|
||||
|
||||
return numpy.sqrt(diffs.mean(axis=-1))
|
||||
|
||||
|
||||
@ -202,6 +218,7 @@ class BinnedData:
|
||||
csv_dict: typing.Dict[str, typing.Any]
|
||||
measurement_type: str
|
||||
|
||||
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
|
||||
def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]:
|
||||
if dot_name not in self.dots_dict:
|
||||
raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
|
||||
@ -255,6 +272,7 @@ class BinnedData:
|
||||
def _stdev_cost_function(
|
||||
self,
|
||||
measurements: typing.Sequence[Measurement],
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
meas_array = numpy.array([m.dot_measurement.v for m in measurements])
|
||||
stdev_array = numpy.array([m.stdev for m in measurements])
|
||||
@ -266,7 +284,7 @@ class BinnedData:
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return StDevUsingCostFunction(
|
||||
self.measurement_type, input_array, meas_array, stdev_array
|
||||
self.measurement_type, input_array, meas_array, stdev_array, use_log_noise
|
||||
)
|
||||
|
||||
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
|
||||
@ -277,10 +295,13 @@ class BinnedData:
|
||||
)
|
||||
|
||||
def stdev_cost_function_filter(
|
||||
self, dot_names: typing.Sequence[str], target_cost: float
|
||||
self,
|
||||
dot_names: typing.Sequence[str],
|
||||
target_cost: float,
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
measurements = self.measurements(dot_names)
|
||||
cost_function = self._stdev_cost_function(measurements)
|
||||
cost_function = self._stdev_cost_function(measurements, use_log_noise)
|
||||
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, target_cost
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import pathlib
|
||||
import logging
|
||||
|
||||
import kalpa.stages.stage01
|
||||
@ -6,9 +7,13 @@ import kalpa.stages.stage03
|
||||
import kalpa.stages.stage04
|
||||
import kalpa.common
|
||||
import tantri.dipoles.types
|
||||
import kalpa.config
|
||||
|
||||
import argparse
|
||||
|
||||
# try not to use this out side of main or when defining config stuff pls
|
||||
import numpy
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -18,6 +23,26 @@ class Runner:
|
||||
_logger.info(f"Initialising runner with {config=}")
|
||||
|
||||
def run(self):
|
||||
|
||||
if self.config.general_config.skip_to_stage is not None:
|
||||
|
||||
stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
|
||||
stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
|
||||
stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
|
||||
stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
|
||||
|
||||
stages = [stage01, stage02, stage03, stage04]
|
||||
|
||||
start = int(self.config.general_config.skip_to_stage)
|
||||
_logger.info(f"Received instruction to start at stage {start + 1}")
|
||||
for i, stage in enumerate(stages[start: 4]):
|
||||
_logger.info(f"*** Running stage {i + start + 1}")
|
||||
stage.run()
|
||||
|
||||
|
||||
else:
|
||||
# standard run, can keep old
|
||||
|
||||
_logger.info("*** Beginning Stage 01 ***")
|
||||
stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
|
||||
stage01.run()
|
||||
@ -42,9 +67,16 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
"--override-root",
|
||||
type=str,
|
||||
help="A filename for logging to, if not provided will only log to stderr",
|
||||
help="If provided, override the root dir.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s", "--skip-to-stage",
|
||||
type=int,
|
||||
help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
@ -55,23 +87,85 @@ def main():
|
||||
args = parse_args()
|
||||
|
||||
tantri_configs = [
|
||||
kalpa.TantriConfig(12345, 50, 0.5, 100000),
|
||||
kalpa.TantriConfig(123456, 50, 0.5, 100000),
|
||||
# kalpa.TantriConfig(1234, 50, 0.0005, 10000),
|
||||
]
|
||||
|
||||
|
||||
override_config = {
|
||||
# "test1": [
|
||||
# tantri.dipoles.types.DipoleTO(
|
||||
# numpy.array([0, 0, 100]),
|
||||
# numpy.array([-2, -2, 2.9]),
|
||||
# 0.0005
|
||||
# )
|
||||
# ],
|
||||
"two_dipole_connors_geom": [
|
||||
tantri.dipoles.types.DipoleTO(
|
||||
numpy.array([0, 0, 100]),
|
||||
numpy.array([-2, -2, 5.75]),
|
||||
0.0005
|
||||
),
|
||||
tantri.dipoles.types.DipoleTO(
|
||||
numpy.array([0, 0, 100]),
|
||||
numpy.array([6, 2, 5.75]),
|
||||
0.05
|
||||
)
|
||||
],
|
||||
"two_dipole_connors_geom_omegaswap": [
|
||||
tantri.dipoles.types.DipoleTO(
|
||||
numpy.array([0, 0, 100]),
|
||||
numpy.array([-2, -2, 5.75]),
|
||||
0.05
|
||||
),
|
||||
tantri.dipoles.types.DipoleTO(
|
||||
numpy.array([0, 0, 100]),
|
||||
numpy.array([6, 2, 5.75]),
|
||||
0.0005
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
generation_config = kalpa.GenerationConfig(
|
||||
tantri_configs=tantri_configs,
|
||||
counts=[1],
|
||||
num_replicas=3,
|
||||
orientations=[tantri.dipoles.types.Orientation.XY],
|
||||
counts=[3, 31],
|
||||
num_replicas=5,
|
||||
|
||||
# let's test this out
|
||||
# override_dipole_configs=override_config,
|
||||
|
||||
orientations=[tantri.dipoles.types.Orientation.Z],
|
||||
num_bin_time_series=25,
|
||||
)
|
||||
|
||||
if args.override_root is None:
|
||||
_logger.info("root dir not given")
|
||||
root = pathlib.Path("plots0")
|
||||
else:
|
||||
root = pathlib.Path(args.override_root)
|
||||
|
||||
if args.skip_to_stage is not None:
|
||||
if args.skip_to_stage not in [1, 2, 3, 4]:
|
||||
raise ValueError(f"There is no stage {args.skip_to_stage}")
|
||||
else:
|
||||
skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
|
||||
else:
|
||||
skip = None
|
||||
|
||||
|
||||
general_config = kalpa.GeneralConfig(
|
||||
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL
|
||||
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
|
||||
out_dir_name=str(root / "out"),
|
||||
skip_to_stage=skip
|
||||
)
|
||||
|
||||
# kalpa.GeneralConfig
|
||||
|
||||
deepdog_config = kalpa.DeepdogConfig(
|
||||
costs_to_try=[10, 2, 1, 0.1],
|
||||
costs_to_try=[2, 1],
|
||||
max_monte_carlo_cycles_steps=20,
|
||||
target_success=200,
|
||||
use_log_noise=True,
|
||||
)
|
||||
|
||||
config = kalpa.Config(
|
||||
@ -80,7 +174,7 @@ def main():
|
||||
deepdog_config=deepdog_config,
|
||||
)
|
||||
|
||||
kalpa.common.set_up_logging(config, args.log_file)
|
||||
kalpa.common.set_up_logging(config, str(root / f"logs/{root}.log"))
|
||||
|
||||
_logger.info(f"Got {config=}")
|
||||
runner = Runner(config)
|
||||
|
@ -7,7 +7,10 @@ import logging
|
||||
import kalpa
|
||||
import kalpa.common
|
||||
import tantri.cli
|
||||
import tantri.cli.input_files
|
||||
import tantri.cli.input_files.write_dipoles
|
||||
import tantri.dipoles.types
|
||||
import typing
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@ -113,7 +116,9 @@ class Stage01Runner:
|
||||
for tantri_index, tantri_config in enumerate(
|
||||
self.config.generation_config.tantri_configs
|
||||
):
|
||||
output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
|
||||
output_csv = directory / kalpa.common.tantri_full_output_name(
|
||||
tantri_index
|
||||
)
|
||||
binned_csv = directory / kalpa.common.tantri_binned_output_name(
|
||||
tantri_index
|
||||
)
|
||||
@ -131,8 +136,83 @@ class Stage01Runner:
|
||||
True,
|
||||
)
|
||||
|
||||
# This whole method is duplication and is ripe for refactor, but that's fine!
|
||||
# deliberately bad to get it done.
|
||||
# here we're going to be manually specifying dipoles as we have from our config
|
||||
def generate_override_dipole(
|
||||
self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO]
|
||||
):
|
||||
"""
|
||||
create a directory, populate it with stuff.
|
||||
|
||||
seed: still a seed integer to use
|
||||
override_name: the name of this dipole configuration, from config file
|
||||
override_dipoles: dipoles to override
|
||||
|
||||
"""
|
||||
|
||||
_logger.info(
|
||||
f"Writing override config {override_name} with dipoles: [{override_dipoles}]"
|
||||
)
|
||||
out = self.config.get_out_dir_path()
|
||||
directory = out / f"{override_name}"
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_logger.debug("generated override directory")
|
||||
|
||||
# config_json = directory / "generation_config.json"
|
||||
dipoles_json = directory / "dipoles.json"
|
||||
|
||||
# with open(config_json, "w") as conf_file:
|
||||
# params = kalpa.ReducedModelParams(
|
||||
# count=count, orientation=tantri.dipoles.types.Orientation(orientation)
|
||||
# )
|
||||
# _logger.debug(f"Got params {params=}")
|
||||
# json.dump(params.config_dict(seed), conf_file)
|
||||
# # json.dump(kalpa.common.model_config_dict(count, orientation, seed), conf_file)
|
||||
|
||||
# the original logic looked like this:
|
||||
# tantri.cli._generate_dipoles(config_json, dipoles_json, (seed, replica, 1))
|
||||
# We're replicating the bit that wrote the dipoles here, but that's a refactor opportunity
|
||||
with dipoles_json.open("w") as dipole_out:
|
||||
dipole_out.write(
|
||||
json.dumps(
|
||||
[dip.as_dict() for dip in override_dipoles],
|
||||
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder
|
||||
)
|
||||
)
|
||||
|
||||
_logger.info(f"Wrote to dipoles file {dipoles_json}")
|
||||
|
||||
# tantri.cli._write_apsd(dipoles_json, DOTS, X_ELECTRIC_FIELD, DELTA_T, NUM_ITERATIONS, NUM_BIN_TS, (index, replica, 2), output_csv, binned_csv, BIN_WIDTH_LOG, True)
|
||||
for tantri_index, tantri_config in enumerate(
|
||||
self.config.generation_config.tantri_configs
|
||||
):
|
||||
output_csv = directory / kalpa.common.tantri_full_output_name(
|
||||
tantri_index
|
||||
)
|
||||
binned_csv = directory / kalpa.common.tantri_binned_output_name(
|
||||
tantri_index
|
||||
)
|
||||
tantri.cli._write_apsd(
|
||||
dipoles_json,
|
||||
self.config.general_config.dots_json_name,
|
||||
self.config.general_config.measurement_type.value,
|
||||
tantri_config.delta_t,
|
||||
tantri_config.num_iterations,
|
||||
self.config.generation_config.num_bin_time_series,
|
||||
(seed, 2),
|
||||
output_csv,
|
||||
binned_csv,
|
||||
self.config.generation_config.bin_log_width,
|
||||
True,
|
||||
)
|
||||
|
||||
def run(self):
|
||||
seed_index = 0
|
||||
if self.config.generation_config.override_dipole_configs is None:
|
||||
# should be by default
|
||||
_logger.debug("no override needed!")
|
||||
for count in self.config.generation_config.counts:
|
||||
for orientation in self.config.generation_config.orientations:
|
||||
for replica in range(self.config.generation_config.num_replicas):
|
||||
@ -141,6 +221,10 @@ class Stage01Runner:
|
||||
)
|
||||
self.generate_single_subdir(seed_index, count, orientation, replica)
|
||||
seed_index += 1
|
||||
else:
|
||||
_logger.debug(f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}")
|
||||
for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items():
|
||||
self.generate_override_dipole(seed_index, override_name, override_dipoles)
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -178,12 +178,16 @@ class Stage02Runner:
|
||||
|
||||
# TODO find way to store this as a global config file
|
||||
occupancies_dict = {
|
||||
1: (1000, 1000),
|
||||
10: (1000, 100),
|
||||
16: (10000, 10),
|
||||
31: (1000, 100),
|
||||
56: (1000, 100),
|
||||
100: (100, 100),
|
||||
1: (500, 1000),
|
||||
2: (250, 2000),
|
||||
3: (250, 2000),
|
||||
5: (100, 5000),
|
||||
10: (50, 10000),
|
||||
16: (50, 10000),
|
||||
17: (50, 10000),
|
||||
31: (50, 10000),
|
||||
56: (25, 20000),
|
||||
100: (5, 100000),
|
||||
}
|
||||
|
||||
mccount, mccountcycles = occupancies_dict[avg_filled]
|
||||
@ -212,7 +216,7 @@ class Stage02Runner:
|
||||
_logger.info(f"{deepdog_config=}")
|
||||
|
||||
stdev_cost_function_filters = [
|
||||
b.stdev_cost_function_filter(dot_names, cost) for b in binned_datas
|
||||
b.stdev_cost_function_filter(dot_names, cost, self.config.deepdog_config.use_log_noise) for b in binned_datas
|
||||
]
|
||||
|
||||
_logger.debug(f"{stdev_cost_function_filters=}")
|
||||
|
@ -121,6 +121,31 @@ class Stage04Runner:
|
||||
out_list.append(row)
|
||||
return out_list
|
||||
|
||||
def read_merged_coalesced_csv_override(self, override_name: str) -> typing.Sequence:
|
||||
subdir_name = override_name
|
||||
subdir_path = self.config.get_out_dir_path() / subdir_name
|
||||
csv_path = (
|
||||
subdir_path
|
||||
/ kalpa.common.sorted_bayesruns_name()
|
||||
/ kalpa.common.merged_coalesced_name()
|
||||
)
|
||||
_logger.debug(f"Reading {csv_path=}")
|
||||
with csv_path.open(mode="r", newline="") as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
out_list = []
|
||||
for row in reader:
|
||||
# We can't put any of the actual info in because it's totally arbitrary, but that's fine!
|
||||
|
||||
# normal_orientation = ORIENTATION_DICT[orientation]
|
||||
row["subdir_name"] = subdir_name
|
||||
# row["actual_orientation"] = ORIENTATION_DICT[orientation]
|
||||
# row["actual_avg_filled"] = count
|
||||
# row["generation_replica_index"] = replica
|
||||
# row["is_row_actual"] = is_actual(row, normal_orientation, count)
|
||||
out_list.append(row)
|
||||
return out_list
|
||||
|
||||
|
||||
def run(self):
|
||||
megamerged_path = (
|
||||
self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
|
||||
@ -130,6 +155,10 @@ class Stage04Runner:
|
||||
with megamerged_path.open(mode="w", newline="") as outfile:
|
||||
writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
|
||||
writer.writeheader()
|
||||
|
||||
|
||||
if self.config.generation_config.override_dipole_configs is None:
|
||||
|
||||
for count in self.config.generation_config.counts:
|
||||
for orientation in self.config.generation_config.orientations:
|
||||
for replica in range(self.config.generation_config.num_replicas):
|
||||
@ -140,8 +169,21 @@ class Stage04Runner:
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
|
||||
else:
|
||||
_logger.debug(f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}")
|
||||
for override_name in self.config.generation_config.override_dipole_configs.keys():
|
||||
_logger.info(f"Working for subdir {override_name}")
|
||||
rows = self.read_merged_coalesced_csv_override(
|
||||
override_name
|
||||
)
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
# merge with inference
|
||||
|
||||
if self.config.generation_config.override_dipole_configs is None:
|
||||
|
||||
with megamerged_path.open(mode="r", newline="") as infile:
|
||||
# Note that if you pass in fieldnames to a DictReader it doesn't skip. So this is bad:
|
||||
# megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
|
||||
@ -169,7 +211,8 @@ class Stage04Runner:
|
||||
for generation in dots.values():
|
||||
for row in generation.values():
|
||||
writer.writerow(row)
|
||||
|
||||
else:
|
||||
_logger.info("skipping inference metamerge, overridden dipole config specified")
|
||||
|
||||
def parse_args():
|
||||
|
||||
|
@ -10,8 +10,8 @@ python = ">=3.8.1,<3.10"
|
||||
pdme = "^1.5.0"
|
||||
|
||||
|
||||
deepdog = "^1.4.0"
|
||||
tantri = "^1.2.0"
|
||||
deepdog = "^1.5.0"
|
||||
tantri = "^1.3.0"
|
||||
tomli = "^2.0.1"
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.8.0"
|
||||
|
Loading…
x
Reference in New Issue
Block a user