diff --git a/kalpa/config.py b/kalpa/config.py
index 740b6df..1ed44d0 100755
--- a/kalpa/config.py
+++ b/kalpa/config.py
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
 import typing
 import tantri.dipoles.types
 import pathlib
-from enum import Enum
+from enum import Enum, IntEnum
 import logging
 
 _logger = logging.getLogger(__name__)
@@ -14,21 +14,30 @@ class MeasurementTypeEnum(Enum):
     POTENTIAL = "electric-potential"
     X_ELECTRIC_FIELD = "x-electric-field"
 
+class SkipToStage(IntEnum):
+    # shouldn't need this lol
+    STAGE_01 = 0
+    STAGE_02 = 1
+    STAGE_03 = 2
+    STAGE_04 = 3
+
 # Copy over some random constants to see if they're ever reused
 @dataclass(frozen=True)
 class GeneralConfig:
-    dots_json_name = "dots.json"
-    indexes_json_name = "indexes.json"
-    out_dir_name = "out"
-    log_pattern = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
+    dots_json_name: str = "dots.json"
+    indexes_json_name: str = "indexes.json"
+    out_dir_name: str = "out"
+    log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
     measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
     root_directory: pathlib.Path = pathlib.Path.cwd()
-    mega_merged_name = "mega_merged_coalesced.csv"
-    mega_merged_inferenced_name = "mega_merged_coalesced_inferenced.csv"
+    mega_merged_name: str = "mega_merged_coalesced.csv"
+    mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"
+
+    skip_to_stage: typing.Optional[int] = None
 
 
 @dataclass(frozen=True)
@@ -51,12 +60,18 @@ class GenerationConfig:
             tantri.dipoles.types.Orientation.XY,
         ]
     )
+    # TODO: what's a replica here?
     num_replicas: int = 3
 
+    # the above three can be overridden with manually specified configurations
+    override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None
+
     tantri_configs: typing.Sequence[TantriConfig] = field(
        default_factory=lambda: [TantriConfig()]
     )
+
+    num_bin_time_series: int = 25
     bin_log_width: float = 0.25
 
@@ -70,6 +85,8 @@ class DeepdogConfig:
     costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
     target_success: int = 1000
     max_monte_carlo_cycles_steps: int = 20
+    # Whether to use a log-log cost function
+    use_log_noise: bool = False
 
 
 @dataclass(frozen=True)
@@ -80,7 +97,7 @@ class Config:
 
     def absify(self, filename: str) -> pathlib.Path:
         ret = (self.general_config.root_directory / filename).resolve()
-        _logger.debug(f"Absifying {filename=}, geting {ret}")
+        _logger.debug(f"Absifying {filename=}, getting {ret}")
         return ret
 
     def get_out_dir_path(self) -> pathlib.Path:
@@ -115,8 +132,8 @@ class ReducedModelParams:
     x_max: float = 20
     y_min: float = -10
     y_max: float = 10
-    z_min: float = 0
-    z_max: float = 5
+    z_min: float = 5
+    z_max: float = 6.5
     w_log_min: float = -5
     w_log_max: float = 1
     count: int = 1
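Note on the GeneralConfig change above: the added ": str" annotations are load-bearing, not cosmetic. In a dataclass, only annotated names become fields; a bare "name = value" is just a class attribute that __init__, __repr__, and the frozen machinery all ignore. A minimal standalone sketch (names hypothetical, not project code):

    from dataclasses import dataclass, fields

    @dataclass(frozen=True)
    class Before:
        out_dir_name = "out"  # bare assignment: class attribute, not a field

    @dataclass(frozen=True)
    class After:
        out_dir_name: str = "out"  # annotated: a real, per-instance field

    print([f.name for f in fields(Before)])  # []
    print([f.name for f in fields(After)])   # ['out_dir_name']
    After(out_dir_name="elsewhere")          # only After can be overridden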
diff --git a/kalpa/inference_coalesce/coalescer.py b/kalpa/inference_coalesce/coalescer.py
index f781a5b..87bbc99 100755
--- a/kalpa/inference_coalesce/coalescer.py
+++ b/kalpa/inference_coalesce/coalescer.py
@@ -66,45 +66,37 @@ class Coalescer:
         _logger.debug(f"subdict keys: {subdict.keys()}")
 
         # TODO hardcoding 3 generations
-        if self.num_replicas != 3:
-            raise ValueError(
-                f"num replicas was {self.num_replicas}, but we've hard coded 3"
-            )
+        # if self.num_replicas != 3:
+        #     raise ValueError(
+        #         f"num replicas was {self.num_replicas}, but we've hard coded 3"
+        #     )
         # generations_keys = ["0", "1", "2"]
 
+        _logger.info("Going through generation 0")
+        # 0th gen is easiest
         for model_key, val in subdict["0"].items():
             val["coalesced_prob"] = val["prob"]
 
-        weight1 = sum(
-            [
-                float(subdict["0"][key]["coalesced_prob"])
-                * float(subdict["1"][key]["prob"])
-                for key in subdict["1"].keys()
-            ]
-        )
-        _logger.debug(weight1)
-        for model_key, val in subdict["1"].items():
-            val["coalesced_prob"] = (
-                float(val["prob"])
-                * float(subdict["0"][model_key]["coalesced_prob"])
-                / weight1
-            )
+        if self.num_replicas > 1:
+            for gen in range(1, self.num_replicas):
+                _logger.info(f"Going through generation {gen}")
+
+                generation_weight = sum(
+                    [
+                        float(subdict[str(gen - 1)][key]["coalesced_prob"])
+                        * float(subdict[str(gen)][key]["prob"])
+                        for key in subdict[str(gen)].keys()
+                    ]
+                )
+                _logger.debug(generation_weight)
+                for model_key, val in subdict[str(gen)].items():
+                    val["coalesced_prob"] = (
+                        float(val["prob"])
+                        * float(subdict[str(gen - 1)][model_key]["coalesced_prob"])
+                        / generation_weight
+                    )
 
-        weight2 = sum(
-            [
-                float(subdict["1"][key]["coalesced_prob"])
-                * float(subdict["2"][key]["prob"])
-                for key in subdict["2"].keys()
-            ]
-        )
-        _logger.debug(weight2)
-        for model_key, val in subdict["2"].items():
-            val["coalesced_prob"] = (
-                float(val["prob"])
-                * float(subdict["1"][model_key]["coalesced_prob"])
-                / weight2
-            )
 
     def coalesce_all(self):
         for actual_key in self.actual_dict.keys():
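The rewritten loop generalizes the hard-coded three-generation chain to any number of replicas: generation 0's coalesced probability is its raw probability, and each later generation multiplies its probabilities by the previous generation's coalesced posterior and renormalizes by generation_weight. A standalone sketch of the same update (toy data, not project code):

    def coalesce(gens):
        # gens: list of {model_key: prob} dicts, one per generation
        coalesced = dict(gens[0])  # generation 0: coalesced_prob == prob
        for gen in gens[1:]:
            weight = sum(coalesced[k] * gen[k] for k in gen)  # normalizer
            coalesced = {k: gen[k] * coalesced[k] / weight for k in gen}
        return coalesced

    print(coalesce([{"a": 0.5, "b": 0.5}, {"a": 0.9, "b": 0.1}]))
    # {'a': 0.9, 'b': 0.1} -- with a uniform first generation, the second
    # generation's evidence dominates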
diff --git a/kalpa/read_bin_csv.py b/kalpa/read_bin_csv.py
index 203806a..97bd663 100755
--- a/kalpa/read_bin_csv.py
+++ b/kalpa/read_bin_csv.py
@@ -58,6 +58,7 @@ class StDevUsingCostFunction:
         dot_inputs_array,
         actual_measurement_array,
         actual_stdev_array,
+        log_noise: bool = False,
     ):
         _logger.info(f"Cost function with measurement type of {measurement_type}")
         self.measurement_type = measurement_type
@@ -67,6 +68,16 @@
         self.actual_stdev_array = actual_stdev_array
         self.actual_stdev_array2 = actual_stdev_array**2
 
+        self.use_log_noise = log_noise
+        self.log_actual = numpy.log(self.actual_measurement_array)
+        self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array)) ** 2
+        # if self.use_log_noise:
+        #     _logger.debug("remove these debugs later")
+        #     _logger.debug(self.actual_measurement_array)
+        #     _logger.debug(self.actual_stdev_array)
+        #     _logger.debug(self.log_actual)
+        #     _logger.debug(self.log_denom2)
+
     def __call__(self, dipoles_to_test):
         if self.measurement_type == X_ELECTRIC_FIELD:
             vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
@@ -76,7 +87,12 @@
             vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
                 self.dot_inputs_array, dipoles_to_test
             )
-        diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
+
+        if self.use_log_noise:
+            diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
+        else:
+            diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
+
         return numpy.sqrt(diffs.mean(axis=-1))
 
 
@@ -202,6 +218,7 @@ class BinnedData:
     csv_dict: typing.Dict[str, typing.Any]
     measurement_type: str
 
+    # we're ignoring stdevs for the moment, as in the single_dipole_matches.py calculator script
     def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]:
         if dot_name not in self.dots_dict:
             raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
@@ -255,6 +272,7 @@ class BinnedData:
     def _stdev_cost_function(
         self,
         measurements: typing.Sequence[Measurement],
+        use_log_noise: bool = False,
     ):
         meas_array = numpy.array([m.dot_measurement.v for m in measurements])
         stdev_array = numpy.array([m.stdev for m in measurements])
@@ -266,7 +284,7 @@
         _logger.debug(f"Obtained {input_array=}")
 
         return StDevUsingCostFunction(
-            self.measurement_type, input_array, meas_array, stdev_array
+            self.measurement_type, input_array, meas_array, stdev_array, use_log_noise
         )
 
     def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
@@ -277,10 +295,13 @@
         )
 
     def stdev_cost_function_filter(
-        self, dot_names: typing.Sequence[str], target_cost: float
+        self,
+        dot_names: typing.Sequence[str],
+        target_cost: float,
+        use_log_noise: bool = False,
     ):
         measurements = self.measurements(dot_names)
-        cost_function = self._stdev_cost_function(measurements)
+        cost_function = self._stdev_cost_function(measurements, use_log_noise)
         return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
             cost_function, target_cost
         )
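For the log-noise variant above, the cost compares log-measurements instead of raw values: the residual is (ln v - ln m)^2 and the per-point scale is (ln(sigma + m) - ln m)^2, i.e. the width of one standard deviation measured in log space. A side-by-side sketch of the two branches (standalone; it assumes strictly positive measurements, since numpy.log of zero or negative values yields -inf or nan):

    import numpy

    def linear_cost(vals, meas, stdev):
        diffs = ((vals - meas) ** 2) / stdev**2
        return numpy.sqrt(diffs.mean(axis=-1))

    def log_cost(vals, meas, stdev):
        log_denom2 = (numpy.log(stdev + meas) - numpy.log(meas)) ** 2
        diffs = ((numpy.log(vals) - numpy.log(meas)) ** 2) / log_denom2
        return numpy.sqrt(diffs.mean(axis=-1))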
diff --git a/kalpa/stages/__init__.py b/kalpa/stages/__init__.py
index e42333a..556e73d 100755
--- a/kalpa/stages/__init__.py
+++ b/kalpa/stages/__init__.py
@@ -1,3 +1,4 @@
+import pathlib
 import logging
 
 import kalpa.stages.stage01
@@ -6,9 +7,13 @@ import kalpa.stages.stage03
 import kalpa.stages.stage04
 import kalpa.common
 import tantri.dipoles.types
+import kalpa.config
 
 import argparse
 
+# try not to use this outside of main or when defining config stuff please
+import numpy
+
 _logger = logging.getLogger(__name__)
 
 
@@ -18,21 +23,41 @@ class Runner:
         _logger.info(f"Initialising runner with {config=}")
 
     def run(self):
-        _logger.info("*** Beginning Stage 01 ***")
-        stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
-        stage01.run()
 
-        _logger.info("*** Beginning Stage 02 ***")
-        stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
-        stage02.run()
+        if self.config.general_config.skip_to_stage is not None:
 
-        _logger.info("*** Beginning Stage 03 ***")
-        stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
-        stage03.run()
+            stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
+            stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
+            stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
+            stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
 
-        _logger.info("*** Beginning Stage 04 ***")
-        stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
-        stage04.run()
+            stages = [stage01, stage02, stage03, stage04]
+
+            start = int(self.config.general_config.skip_to_stage)
+            _logger.info(f"Received instruction to start at stage {start + 1}")
+            for i, stage in enumerate(stages[start:]):
+                _logger.info(f"*** Running stage {i + start + 1} ***")
+                stage.run()
+
+        else:
+            # standard run, keep the old behaviour
+
+            _logger.info("*** Beginning Stage 01 ***")
+            stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
+            stage01.run()
+
+            _logger.info("*** Beginning Stage 02 ***")
+            stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
+            stage02.run()
+
+            _logger.info("*** Beginning Stage 03 ***")
+            stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
+            stage03.run()
+
+            _logger.info("*** Beginning Stage 04 ***")
+            stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
+            stage04.run()
 
 
 def parse_args():
@@ -42,9 +67,16 @@ def parse_args():
     )
 
     parser.add_argument(
-        "--log-file",
+        "--override-root",
         type=str,
-        help="A filename for logging to, if not provided will only log to stderr",
+        help="If provided, override the root dir.",
+        default=None,
+    )
+
+    parser.add_argument(
+        "-s", "--skip-to-stage",
+        type=int,
+        help="Skip to this stage: 1 means stages 1-4 will run, 4 means only stage 4 will run.",
         default=None,
     )
     args = parser.parse_args()
@@ -55,23 +87,85 @@ def main():
     args = parse_args()
 
     tantri_configs = [
-        kalpa.TantriConfig(12345, 50, 0.5, 100000),
+        kalpa.TantriConfig(123456, 50, 0.5, 100000),
         # kalpa.TantriConfig(1234, 50, 0.0005, 10000),
     ]
+
+    override_config = {
+        # "test1": [
+        #     tantri.dipoles.types.DipoleTO(
+        #         numpy.array([0, 0, 100]),
+        #         numpy.array([-2, -2, 2.9]),
+        #         0.0005
+        #     )
+        # ],
+        "two_dipole_connors_geom": [
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([-2, -2, 5.75]),
+                0.0005
+            ),
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([6, 2, 5.75]),
+                0.05
+            )
+        ],
+        "two_dipole_connors_geom_omegaswap": [
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([-2, -2, 5.75]),
+                0.05
+            ),
+            tantri.dipoles.types.DipoleTO(
+                numpy.array([0, 0, 100]),
+                numpy.array([6, 2, 5.75]),
+                0.0005
+            )
+        ]
+    }
+
     generation_config = kalpa.GenerationConfig(
         tantri_configs=tantri_configs,
-        counts=[1],
-        num_replicas=3,
-        orientations=[tantri.dipoles.types.Orientation.XY],
+        counts=[3, 31],
+        num_replicas=5,
+
+        # let's test this out
+        # override_dipole_configs=override_config,
+
+        orientations=[tantri.dipoles.types.Orientation.Z],
         num_bin_time_series=25,
     )
+
+    if args.override_root is None:
+        _logger.info("root dir not given, defaulting to plots0")
+        root = pathlib.Path("plots0")
+    else:
+        root = pathlib.Path(args.override_root)
+
+    if args.skip_to_stage is not None:
+        if args.skip_to_stage not in [1, 2, 3, 4]:
+            raise ValueError(f"There is no stage {args.skip_to_stage}")
+        else:
+            skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
+    else:
+        skip = None
+
     general_config = kalpa.GeneralConfig(
-        measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL
+        measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
+        out_dir_name=str(root / "out"),
+        skip_to_stage=skip,
     )
 
     deepdog_config = kalpa.DeepdogConfig(
-        costs_to_try=[10, 2, 1, 0.1],
+        costs_to_try=[2, 1],
         max_monte_carlo_cycles_steps=20,
+        target_success=200,
+        use_log_noise=True,
     )
 
     config = kalpa.Config(
@@ -80,7 +174,7 @@ def main():
         deepdog_config=deepdog_config,
     )
 
-    kalpa.common.set_up_logging(config, args.log_file)
+    kalpa.common.set_up_logging(config, str(root / f"logs/{root.name}.log"))
 
     _logger.info(f"Got {config=}")
 
     runner = Runner(config)
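The skip-to-stage plumbing converts the 1-based CLI flag into the 0-based SkipToStage value and slices the stage list, so everything from the requested stage onward still runs. A toy check of the arithmetic (standalone, stage names stand in for the real runners):

    stages = ["stage01", "stage02", "stage03", "stage04"]
    cli_flag = 3               # user passed --skip-to-stage 3
    start = cli_flag - 1       # stored as SkipToStage(2)
    for i, stage in enumerate(stages[start:]):
        print(f"*** Running stage {i + start + 1} ***", stage)
    # *** Running stage 3 *** stage03
    # *** Running stage 4 *** stage04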
diff --git a/kalpa/stages/stage01.py b/kalpa/stages/stage01.py
index 6651d74..118bf4d 100755
--- a/kalpa/stages/stage01.py
+++ b/kalpa/stages/stage01.py
@@ -7,7 +7,10 @@ import logging
 import kalpa
 import kalpa.common
 import tantri.cli
+import tantri.cli.input_files
+import tantri.cli.input_files.write_dipoles
 import tantri.dipoles.types
+import typing
 
 _logger = logging.getLogger(__name__)
 
@@ -113,7 +116,9 @@ class Stage01Runner:
         for tantri_index, tantri_config in enumerate(
             self.config.generation_config.tantri_configs
         ):
-            output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
+            output_csv = directory / kalpa.common.tantri_full_output_name(
+                tantri_index
+            )
             binned_csv = directory / kalpa.common.tantri_binned_output_name(
                 tantri_index
             )
@@ -131,16 +136,95 @@ class Stage01Runner:
                 True,
             )
 
+    # This whole method is duplicated code and ripe for refactoring, but that's fine!
+    # Deliberately quick and dirty to get it done.
+    # Here we manually specify dipoles as provided in our config.
+    def generate_override_dipole(
+        self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO]
+    ):
+        """
+        Create a directory for this override configuration and populate it with dipole and time-series files.
+
+        seed: still a seed integer to use
+        override_name: the name of this dipole configuration, from the config file
+        override_dipoles: the manually specified dipoles to write out
+        """
+
+        _logger.info(
+            f"Writing override config {override_name} with dipoles: [{override_dipoles}]"
+        )
+        out = self.config.get_out_dir_path()
+        directory = out / f"{override_name}"
+        directory.mkdir(parents=True, exist_ok=True)
+
+        _logger.debug("generated override directory")
+
+        # config_json = directory / "generation_config.json"
+        dipoles_json = directory / "dipoles.json"
+
+        # with open(config_json, "w") as conf_file:
+        #     params = kalpa.ReducedModelParams(
+        #         count=count, orientation=tantri.dipoles.types.Orientation(orientation)
+        #     )
+        #     _logger.debug(f"Got params {params=}")
+        #     json.dump(params.config_dict(seed), conf_file)
+        #     # json.dump(kalpa.common.model_config_dict(count, orientation, seed), conf_file)
+
+        # the original logic looked like this:
+        # tantri.cli._generate_dipoles(config_json, dipoles_json, (seed, replica, 1))
+        # We're replicating the bit that wrote the dipoles here, but that's a refactor opportunity
+        with dipoles_json.open("w") as dipole_out:
+            dipole_out.write(
+                json.dumps(
+                    [dip.as_dict() for dip in override_dipoles],
+                    cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
+                )
+            )
+
+        _logger.info(f"Wrote to dipoles file {dipoles_json}")
+
+        # tantri.cli._write_apsd(dipoles_json, DOTS, X_ELECTRIC_FIELD, DELTA_T, NUM_ITERATIONS, NUM_BIN_TS, (index, replica, 2), output_csv, binned_csv, BIN_WIDTH_LOG, True)
+        for tantri_index, tantri_config in enumerate(
+            self.config.generation_config.tantri_configs
+        ):
+            output_csv = directory / kalpa.common.tantri_full_output_name(
+                tantri_index
+            )
+            binned_csv = directory / kalpa.common.tantri_binned_output_name(
+                tantri_index
+            )
+            tantri.cli._write_apsd(
+                dipoles_json,
+                self.config.general_config.dots_json_name,
+                self.config.general_config.measurement_type.value,
+                tantri_config.delta_t,
+                tantri_config.num_iterations,
+                self.config.generation_config.num_bin_time_series,
+                (seed, 2),
+                output_csv,
+                binned_csv,
+                self.config.generation_config.bin_log_width,
+                True,
+            )
+
     def run(self):
         seed_index = 0
-        for count in self.config.generation_config.counts:
-            for orientation in self.config.generation_config.orientations:
-                for replica in range(self.config.generation_config.num_replicas):
-                    _logger.info(
-                        f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
-                    )
-                    self.generate_single_subdir(seed_index, count, orientation, replica)
-                    seed_index += 1
+        if self.config.generation_config.override_dipole_configs is None:
+            # should be the default
+            _logger.debug("no override needed!")
+            for count in self.config.generation_config.counts:
+                for orientation in self.config.generation_config.orientations:
+                    for replica in range(self.config.generation_config.num_replicas):
+                        _logger.info(
+                            f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}]"
+                        )
+                        self.generate_single_subdir(seed_index, count, orientation, replica)
+                        seed_index += 1
+        else:
+            _logger.debug(
+                f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
+            )
+            for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items():
+                self.generate_override_dipole(seed_index, override_name, override_dipoles)
+                # give each override configuration its own seed
+                seed_index += 1
 
 
 def parse_args():
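The JSON write above leans on tantri's NumpyEncoder because each dip.as_dict() contains numpy arrays, which json.dumps can't serialize natively. Purely as an illustration of the shape of such an encoder (this is a sketch, not tantri's actual implementation):

    import json
    import numpy

    class NumpyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, numpy.ndarray):
                return obj.tolist()  # arrays become plain lists
            if isinstance(obj, numpy.generic):
                return obj.item()    # numpy scalars become Python scalars
            return super().default(obj)

    print(json.dumps({"p": numpy.array([0, 0, 100]), "w": 0.0005}, cls=NumpyEncoder))
    # {"p": [0, 0, 100], "w": 0.0005}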
diff --git a/kalpa/stages/stage02.py b/kalpa/stages/stage02.py
index 0d1d4ff..ef849d8 100755
--- a/kalpa/stages/stage02.py
+++ b/kalpa/stages/stage02.py
@@ -178,12 +178,16 @@ class Stage02Runner:
 
         # TODO find way to store this as a global config file
         occupancies_dict = {
-            1: (1000, 1000),
-            10: (1000, 100),
-            16: (10000, 10),
-            31: (1000, 100),
-            56: (1000, 100),
-            100: (100, 100),
+            1: (500, 1000),
+            2: (250, 2000),
+            3: (250, 2000),
+            5: (100, 5000),
+            10: (50, 10000),
+            16: (50, 10000),
+            17: (50, 10000),
+            31: (50, 10000),
+            56: (25, 20000),
+            100: (5, 100000),
         }
 
         mccount, mccountcycles = occupancies_dict[avg_filled]
@@ -212,7 +216,7 @@ class Stage02Runner:
         _logger.info(f"{deepdog_config=}")
 
         stdev_cost_function_filters = [
-            b.stdev_cost_function_filter(dot_names, cost) for b in binned_datas
+            b.stdev_cost_function_filter(
+                dot_names, cost, self.config.deepdog_config.use_log_noise
+            )
+            for b in binned_datas
         ]
 
         _logger.debug(f"{stdev_cost_function_filters=}")
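The retuned occupancies table trades Monte Carlo batch size against cycle count: higher occupancies get smaller batches but more cycles, presumably to bound per-batch memory, and the total sample budget now comes out the same for every occupancy. A quick standalone check:

    occupancies_dict = {
        1: (500, 1000), 2: (250, 2000), 3: (250, 2000), 5: (100, 5000),
        10: (50, 10000), 16: (50, 10000), 17: (50, 10000), 31: (50, 10000),
        56: (25, 20000), 100: (5, 100000),
    }
    assert all(
        mccount * mccountcycles == 500_000
        for mccount, mccountcycles in occupancies_dict.values()
    )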
diff --git a/kalpa/stages/stage04.py b/kalpa/stages/stage04.py
index b94571a..a5daefd 100755
--- a/kalpa/stages/stage04.py
+++ b/kalpa/stages/stage04.py
@@ -121,6 +121,31 @@ class Stage04Runner:
                 out_list.append(row)
         return out_list
 
+    def read_merged_coalesced_csv_override(self, override_name: str) -> typing.Sequence:
+        subdir_name = override_name
+        subdir_path = self.config.get_out_dir_path() / subdir_name
+        csv_path = (
+            subdir_path
+            / kalpa.common.sorted_bayesruns_name()
+            / kalpa.common.merged_coalesced_name()
+        )
+        _logger.debug(f"Reading {csv_path=}")
+        with csv_path.open(mode="r", newline="") as csvfile:
+            reader = csv.DictReader(csvfile)
+            out_list = []
+            for row in reader:
+                # We can't fill in the actual-configuration columns here because the override is arbitrary, but that's fine!
+                # normal_orientation = ORIENTATION_DICT[orientation]
+                row["subdir_name"] = subdir_name
+                # row["actual_orientation"] = ORIENTATION_DICT[orientation]
+                # row["actual_avg_filled"] = count
+                # row["generation_replica_index"] = replica
+                # row["is_row_actual"] = is_actual(row, normal_orientation, count)
+                out_list.append(row)
+        return out_list
+
     def run(self):
         megamerged_path = (
             self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
         )
@@ -130,46 +155,64 @@ class Stage04Runner:
         with megamerged_path.open(mode="w", newline="") as outfile:
             writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
             writer.writeheader()
-            for count in self.config.generation_config.counts:
-                for orientation in self.config.generation_config.orientations:
-                    for replica in range(self.config.generation_config.num_replicas):
-                        _logger.info(f"Reading {count=} {orientation=} {replica=}")
-                        rows = self.read_merged_coalesced_csv(
-                            orientation, count, replica
-                        )
-                        for row in rows:
-                            writer.writerow(row)
+
+            if self.config.generation_config.override_dipole_configs is None:
+                for count in self.config.generation_config.counts:
+                    for orientation in self.config.generation_config.orientations:
+                        for replica in range(self.config.generation_config.num_replicas):
+                            _logger.info(f"Reading {count=} {orientation=} {replica=}")
+                            rows = self.read_merged_coalesced_csv(
+                                orientation, count, replica
+                            )
+                            for row in rows:
+                                writer.writerow(row)
+            else:
+                _logger.debug(
+                    f"Overridden dipole config specified, using override {self.config.generation_config.override_dipole_configs}"
+                )
+                for override_name in self.config.generation_config.override_dipole_configs.keys():
+                    _logger.info(f"Working on subdir {override_name}")
+                    rows = self.read_merged_coalesced_csv_override(
+                        override_name
+                    )
+                    for row in rows:
+                        writer.writerow(row)
 
         # merge with inference
-        with megamerged_path.open(mode="r", newline="") as infile:
-            # Note that if you pass in fieldnames to a DictReader it doesn't skip the header row. So this is bad:
-            # megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
-            megamerged_reader = csv.DictReader(infile)
-            rows = [row for row in megamerged_reader]
-            _logger.debug(rows[0])
-            coalescer = kalpa.inference_coalesce.Coalescer(
-                rows, num_replicas=self.config.generation_config.num_replicas
-            )
-            _logger.info(coalescer.actual_dict.keys())
+        if self.config.generation_config.override_dipole_configs is None:
+            with megamerged_path.open(mode="r", newline="") as infile:
+                # Note that if you pass in fieldnames to a DictReader it doesn't skip the header row. So this is bad:
+                # megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
+                megamerged_reader = csv.DictReader(infile)
+                rows = [row for row in megamerged_reader]
+                _logger.debug(rows[0])
+                coalescer = kalpa.inference_coalesce.Coalescer(
+                    rows, num_replicas=self.config.generation_config.num_replicas
+                )
+                _logger.info(coalescer.actual_dict.keys())
 
-            # coalescer.coalesce_generations(("fixedxy", "1"), "dot1")
+                # coalescer.coalesce_generations(("fixedxy", "1"), "dot1")
 
-            coalesced = coalescer.coalesce_all()
+                coalesced = coalescer.coalesce_all()
 
-            inferenced_path = (
-                self.config.get_out_dir_path()
-                / self.config.general_config.mega_merged_inferenced_name
-            )
-            with inferenced_path.open(mode="w", newline="") as outfile:
-                writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES)
-                writer.writeheader()
-                for val in coalesced.values():
-                    for dots in val.values():
-                        for generation in dots.values():
-                            for row in generation.values():
-                                writer.writerow(row)
+                inferenced_path = (
+                    self.config.get_out_dir_path()
+                    / self.config.general_config.mega_merged_inferenced_name
+                )
+                with inferenced_path.open(mode="w", newline="") as outfile:
+                    writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES)
+                    writer.writeheader()
+                    for val in coalesced.values():
+                        for dots in val.values():
+                            for generation in dots.values():
+                                for row in generation.values():
+                                    writer.writerow(row)
+        else:
+            _logger.info("skipping inference mega-merge, overridden dipole config specified")
 
 
 def parse_args():
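The DictReader comment above is worth spelling out: passing fieldnames tells DictReader the file has no header row, so the real header comes back as a data row. A standalone demonstration:

    import csv
    import io

    data = io.StringIO("name,prob\na,0.5\n")
    rows = list(csv.DictReader(data, fieldnames=["name", "prob"]))
    print(rows[0])  # {'name': 'name', 'prob': 'prob'} <- header read as data

    data.seek(0)
    rows = list(csv.DictReader(data))  # header inferred, as the code now does
    print(rows[0])  # {'name': 'a', 'prob': '0.5'}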
diff --git a/pyproject.toml b/pyproject.toml
index 5708478..4b631e7 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,8 +10,8 @@
 python = ">=3.8.1,<3.10"
 pdme = "^1.5.0"
-deepdog = "^1.4.0"
-tantri = "^1.2.0"
+deepdog = "^1.5.0"
+tantri = "^1.3.0"
 tomli = "^2.0.1"
 
 [tool.poetry.group.dev.dependencies]
 black = "^24.8.0"