fmt: formatting changes

This commit is contained in:
Deepak Mallubhotla 2024-12-30 04:16:31 +00:00
parent 4496a6b4da
commit 0d8213c500
Signed by: deepak
GPG Key ID: 47831B15427F5A55
7 changed files with 73 additions and 61 deletions

View File

@ -14,6 +14,7 @@ class MeasurementTypeEnum(Enum):
POTENTIAL = "electric-potential" POTENTIAL = "electric-potential"
X_ELECTRIC_FIELD = "x-electric-field" X_ELECTRIC_FIELD = "x-electric-field"
class SkipToStage(IntEnum): class SkipToStage(IntEnum):
# shouldn't need this lol # shouldn't need this lol
STAGE_01 = 0 STAGE_01 = 0
@ -30,7 +31,9 @@ class GeneralConfig:
dots_json_name: str = "dots.json" dots_json_name: str = "dots.json"
indexes_json_name: str = "indexes.json" indexes_json_name: str = "indexes.json"
out_dir_name: str = "out" out_dir_name: str = "out"
log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s" log_pattern: str = (
"%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
)
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
root_directory: pathlib.Path = pathlib.Path.cwd() root_directory: pathlib.Path = pathlib.Path.cwd()
@ -64,14 +67,14 @@ class GenerationConfig:
num_replicas: int = 3 num_replicas: int = 3
# the above three can be overridden with manually specified configurations # the above three can be overridden with manually specified configurations
override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None override_dipole_configs: typing.Optional[
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
] = None
tantri_configs: typing.Sequence[TantriConfig] = field( tantri_configs: typing.Sequence[TantriConfig] = field(
default_factory=lambda: [TantriConfig()] default_factory=lambda: [TantriConfig()]
) )
num_bin_time_series: int = 25 num_bin_time_series: int = 25
bin_log_width: float = 0.25 bin_log_width: float = 0.25

View File

@ -79,7 +79,7 @@ class Coalescer:
val["coalesced_prob"] = val["prob"] val["coalesced_prob"] = val["prob"]
if self.num_replicas > 1: if self.num_replicas > 1:
for gen in range(1, self.num_replicas): for gen in range(1, self.num_replicas):
_logger.info(f"Going through generation {gen}") _logger.info(f"Going through generation {gen}")
generation_weight = sum( generation_weight = sum(
@ -93,11 +93,10 @@ class Coalescer:
for model_key, val in subdict[str(gen)].items(): for model_key, val in subdict[str(gen)].items():
val["coalesced_prob"] = ( val["coalesced_prob"] = (
float(val["prob"]) float(val["prob"])
* float(subdict[str(gen-1)][model_key]["coalesced_prob"]) * float(subdict[str(gen - 1)][model_key]["coalesced_prob"])
/ generation_weight / generation_weight
) )
def coalesce_all(self): def coalesce_all(self):
for actual_key in self.actual_dict.keys(): for actual_key in self.actual_dict.keys():
for dot_key in self.actual_dict[actual_key].keys(): for dot_key in self.actual_dict[actual_key].keys():

View File

@ -70,7 +70,10 @@ class StDevUsingCostFunction:
self.use_log_noise = log_noise self.use_log_noise = log_noise
self.log_actual = numpy.log(self.actual_measurement_array) self.log_actual = numpy.log(self.actual_measurement_array)
self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array))**2 self.log_denom2 = (
numpy.log(self.actual_stdev_array + self.actual_measurement_array)
- numpy.log(self.actual_measurement_array)
) ** 2
# if self.use_log_noise: # if self.use_log_noise:
# _logger.debug("remove these debugs later") # _logger.debug("remove these debugs later")
# _logger.debug(self.actual_measurement_array) # _logger.debug(self.actual_measurement_array)
@ -89,9 +92,11 @@ class StDevUsingCostFunction:
) )
if self.use_log_noise: if self.use_log_noise:
diffs = ((numpy.log(vals) - self.log_actual) **2) / self.log_denom2 diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
else: else:
diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2 diffs = (
(vals - self.actual_measurement_array) ** 2
) / self.actual_stdev_array2
return numpy.sqrt(diffs.mean(axis=-1)) return numpy.sqrt(diffs.mean(axis=-1))

View File

@ -35,11 +35,10 @@ class Runner:
start = int(self.config.general_config.skip_to_stage) start = int(self.config.general_config.skip_to_stage)
_logger.info(f"Received instruction to start at stage {start + 1}") _logger.info(f"Received instruction to start at stage {start + 1}")
for i, stage in enumerate(stages[start: 4]): for i, stage in enumerate(stages[start:4]):
_logger.info(f"*** Running stage {i + start + 1}") _logger.info(f"*** Running stage {i + start + 1}")
stage.run() stage.run()
else: else:
# standard run, can keep old # standard run, can keep old
@ -74,7 +73,8 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"-s", "--skip-to-stage", "-s",
"--skip-to-stage",
type=int, type=int,
help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.", help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
default=None, default=None,
@ -91,7 +91,6 @@ def main():
# kalpa.TantriConfig(1234, 50, 0.0005, 10000), # kalpa.TantriConfig(1234, 50, 0.0005, 10000),
] ]
override_config = { override_config = {
# "test1": [ # "test1": [
# tantri.dipoles.types.DipoleTO( # tantri.dipoles.types.DipoleTO(
@ -102,38 +101,28 @@ def main():
# ], # ],
"two_dipole_connors_geom": [ "two_dipole_connors_geom": [
tantri.dipoles.types.DipoleTO( tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]), numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.0005
numpy.array([-2, -2, 5.75]),
0.0005
), ),
tantri.dipoles.types.DipoleTO( tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]), numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.05
numpy.array([6, 2, 5.75]), ),
0.05
)
], ],
"two_dipole_connors_geom_omegaswap": [ "two_dipole_connors_geom_omegaswap": [
tantri.dipoles.types.DipoleTO( tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]), numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.05
numpy.array([-2, -2, 5.75]),
0.05
), ),
tantri.dipoles.types.DipoleTO( tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]), numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.0005
numpy.array([6, 2, 5.75]), ),
0.0005 ],
)
]
} }
generation_config = kalpa.GenerationConfig( generation_config = kalpa.GenerationConfig(
tantri_configs=tantri_configs, tantri_configs=tantri_configs,
counts=[3, 31], counts=[3, 31],
num_replicas=5, num_replicas=5,
# let's test this out # let's test this out
# override_dipole_configs=override_config, # override_dipole_configs=override_config,
orientations=[tantri.dipoles.types.Orientation.Z], orientations=[tantri.dipoles.types.Orientation.Z],
num_bin_time_series=25, num_bin_time_series=25,
) )
@ -150,13 +139,12 @@ def main():
else: else:
skip = kalpa.config.SkipToStage(args.skip_to_stage - 1) skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
else: else:
skip = None skip = None
general_config = kalpa.GeneralConfig( general_config = kalpa.GeneralConfig(
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL, measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
out_dir_name=str(root / "out"), out_dir_name=str(root / "out"),
skip_to_stage=skip skip_to_stage=skip,
) )
# kalpa.GeneralConfig # kalpa.GeneralConfig

View File

@ -116,9 +116,7 @@ class Stage01Runner:
for tantri_index, tantri_config in enumerate( for tantri_index, tantri_config in enumerate(
self.config.generation_config.tantri_configs self.config.generation_config.tantri_configs
): ):
output_csv = directory / kalpa.common.tantri_full_output_name( output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
tantri_index
)
binned_csv = directory / kalpa.common.tantri_binned_output_name( binned_csv = directory / kalpa.common.tantri_binned_output_name(
tantri_index tantri_index
) )
@ -140,7 +138,10 @@ class Stage01Runner:
# deliberately bad to get it done. # deliberately bad to get it done.
# here we're going to be manually specifying dipoles as we have from our config # here we're going to be manually specifying dipoles as we have from our config
def generate_override_dipole( def generate_override_dipole(
self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO] self,
seed: int,
override_name: str,
override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO],
): ):
""" """
create a directory, populate it with stuff. create a directory, populate it with stuff.
@ -157,9 +158,9 @@ class Stage01Runner:
out = self.config.get_out_dir_path() out = self.config.get_out_dir_path()
directory = out / f"{override_name}" directory = out / f"{override_name}"
directory.mkdir(parents=True, exist_ok=True) directory.mkdir(parents=True, exist_ok=True)
_logger.debug("generated override directory") _logger.debug("generated override directory")
# config_json = directory / "generation_config.json" # config_json = directory / "generation_config.json"
dipoles_json = directory / "dipoles.json" dipoles_json = directory / "dipoles.json"
@ -178,7 +179,7 @@ class Stage01Runner:
dipole_out.write( dipole_out.write(
json.dumps( json.dumps(
[dip.as_dict() for dip in override_dipoles], [dip.as_dict() for dip in override_dipoles],
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
) )
) )
@ -188,9 +189,7 @@ class Stage01Runner:
for tantri_index, tantri_config in enumerate( for tantri_index, tantri_config in enumerate(
self.config.generation_config.tantri_configs self.config.generation_config.tantri_configs
): ):
output_csv = directory / kalpa.common.tantri_full_output_name( output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
tantri_index
)
binned_csv = directory / kalpa.common.tantri_binned_output_name( binned_csv = directory / kalpa.common.tantri_binned_output_name(
tantri_index tantri_index
) )
@ -219,12 +218,21 @@ class Stage01Runner:
_logger.info( _logger.info(
f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}" f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
) )
self.generate_single_subdir(seed_index, count, orientation, replica) self.generate_single_subdir(
seed_index, count, orientation, replica
)
seed_index += 1 seed_index += 1
else: else:
_logger.debug(f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}") _logger.debug(
for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items(): f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
self.generate_override_dipole(seed_index, override_name, override_dipoles) )
for (
override_name,
override_dipoles,
) in self.config.generation_config.override_dipole_configs.items():
self.generate_override_dipole(
seed_index, override_name, override_dipoles
)
def parse_args(): def parse_args():

View File

@ -216,7 +216,10 @@ class Stage02Runner:
_logger.info(f"{deepdog_config=}") _logger.info(f"{deepdog_config=}")
stdev_cost_function_filters = [ stdev_cost_function_filters = [
b.stdev_cost_function_filter(dot_names, cost, self.config.deepdog_config.use_log_noise) for b in binned_datas b.stdev_cost_function_filter(
dot_names, cost, self.config.deepdog_config.use_log_noise
)
for b in binned_datas
] ]
_logger.debug(f"{stdev_cost_function_filters=}") _logger.debug(f"{stdev_cost_function_filters=}")

View File

@ -145,7 +145,6 @@ class Stage04Runner:
out_list.append(row) out_list.append(row)
return out_list return out_list
def run(self): def run(self):
megamerged_path = ( megamerged_path = (
self.config.get_out_dir_path() / self.config.general_config.mega_merged_name self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
@ -156,30 +155,32 @@ class Stage04Runner:
writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES) writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
writer.writeheader() writer.writeheader()
if self.config.generation_config.override_dipole_configs is None: if self.config.generation_config.override_dipole_configs is None:
for count in self.config.generation_config.counts: for count in self.config.generation_config.counts:
for orientation in self.config.generation_config.orientations: for orientation in self.config.generation_config.orientations:
for replica in range(self.config.generation_config.num_replicas): for replica in range(
self.config.generation_config.num_replicas
):
_logger.info(f"Reading {count=} {orientation=} {replica=}") _logger.info(f"Reading {count=} {orientation=} {replica=}")
rows = self.read_merged_coalesced_csv( rows = self.read_merged_coalesced_csv(
orientation, count, replica orientation, count, replica
) )
for row in rows: for row in rows:
writer.writerow(row) writer.writerow(row)
else: else:
_logger.debug(f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}") _logger.debug(
for override_name in self.config.generation_config.override_dipole_configs.keys(): f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}"
)
for (
override_name
) in self.config.generation_config.override_dipole_configs.keys():
_logger.info(f"Working for subdir {override_name}") _logger.info(f"Working for subdir {override_name}")
rows = self.read_merged_coalesced_csv_override( rows = self.read_merged_coalesced_csv_override(override_name)
override_name
)
for row in rows: for row in rows:
writer.writerow(row) writer.writerow(row)
# merge with inference # merge with inference
if self.config.generation_config.override_dipole_configs is None: if self.config.generation_config.override_dipole_configs is None:
@ -204,7 +205,9 @@ class Stage04Runner:
/ self.config.general_config.mega_merged_inferenced_name / self.config.general_config.mega_merged_inferenced_name
) )
with inferenced_path.open(mode="w", newline="") as outfile: with inferenced_path.open(mode="w", newline="") as outfile:
writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES) writer = csv.DictWriter(
outfile, fieldnames=INFERENCED_OUT_FIELDNAMES
)
writer.writeheader() writer.writeheader()
for val in coalesced.values(): for val in coalesced.values():
for dots in val.values(): for dots in val.values():
@ -212,7 +215,10 @@ class Stage04Runner:
for row in generation.values(): for row in generation.values():
writer.writerow(row) writer.writerow(row)
else: else:
_logger.info("skipping inference metamerge, overridden dipole config specified") _logger.info(
"skipping inference metamerge, overridden dipole config specified"
)
def parse_args(): def parse_args():