From 31ba558364a9be25a52cf216204f99c23240bc14 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 23 Feb 2025 22:54:00 -0600 Subject: [PATCH] fix: better inference with overrides --- kalpaa/stages/stage04.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/kalpaa/stages/stage04.py b/kalpaa/stages/stage04.py index 8cff660..619f293 100644 --- a/kalpaa/stages/stage04.py +++ b/kalpaa/stages/stage04.py @@ -155,8 +155,23 @@ class Stage04Runner: writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES) writer.writeheader() - if self.config.generation_config.override_dipole_configs is None: + if self.config.generation_config.override_dipole_configs is not None: + override_names = self.config.generation_config.override_dipole_configs.keys() + elif self.config.generation_config.override_measurement_filesets is not None: + override_names = self.config.generation_config.override_measurement_filesets.keys() + else: + override_names = None + if (override_names is not None): + _logger.debug( + f"We had overridden dipole config, using override {override_names}" + ) + for override_name in override_names: + _logger.info(f"Working for subdir {override_name}") + rows = self.read_merged_coalesced_csv_override(override_name) + for row in rows: + writer.writerow(row) + else: for count in self.config.generation_config.counts: for orientation in self.config.generation_config.orientations: for replica in range( @@ -169,21 +184,9 @@ class Stage04Runner: for row in rows: writer.writerow(row) - else: - _logger.debug( - f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}" - ) - for ( - override_name - ) in self.config.generation_config.override_dipole_configs.keys(): - _logger.info(f"Working for subdir {override_name}") - rows = self.read_merged_coalesced_csv_override(override_name) - for row in rows: - writer.writerow(row) - # merge with inference - if self.config.generation_config.override_dipole_configs is None: + if override_names is None: with megamerged_path.open(mode="r", newline="") as infile: # Note that if you pass in fieldnames to a DictReader it doesn't skip. So this is bad: