fmt: formatting changes

This commit is contained in:
Deepak Mallubhotla 2024-12-30 04:16:31 +00:00
parent 4496a6b4da
commit 0d8213c500
Signed by: deepak
GPG Key ID: 47831B15427F5A55
7 changed files with 73 additions and 61 deletions

View File

@ -14,6 +14,7 @@ class MeasurementTypeEnum(Enum):
POTENTIAL = "electric-potential"
X_ELECTRIC_FIELD = "x-electric-field"
class SkipToStage(IntEnum):
# shouldn't need this lol
STAGE_01 = 0
@ -30,7 +31,9 @@ class GeneralConfig:
dots_json_name: str = "dots.json"
indexes_json_name: str = "indexes.json"
out_dir_name: str = "out"
log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
log_pattern: str = (
"%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
)
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
root_directory: pathlib.Path = pathlib.Path.cwd()
@ -64,14 +67,14 @@ class GenerationConfig:
num_replicas: int = 3
# the above three can be overrided with manually specified configurations
override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None
override_dipole_configs: typing.Optional[
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
] = None
tantri_configs: typing.Sequence[TantriConfig] = field(
default_factory=lambda: [TantriConfig()]
)
num_bin_time_series: int = 25
bin_log_width: float = 0.25

View File

@ -93,11 +93,10 @@ class Coalescer:
for model_key, val in subdict[str(gen)].items():
val["coalesced_prob"] = (
float(val["prob"])
* float(subdict[str(gen-1)][model_key]["coalesced_prob"])
* float(subdict[str(gen - 1)][model_key]["coalesced_prob"])
/ generation_weight
)
def coalesce_all(self):
for actual_key in self.actual_dict.keys():
for dot_key in self.actual_dict[actual_key].keys():

View File

@ -70,7 +70,10 @@ class StDevUsingCostFunction:
self.use_log_noise = log_noise
self.log_actual = numpy.log(self.actual_measurement_array)
self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array))**2
self.log_denom2 = (
numpy.log(self.actual_stdev_array + self.actual_measurement_array)
- numpy.log(self.actual_measurement_array)
) ** 2
# if self.use_log_noise:
# _logger.debug("remove these debugs later")
# _logger.debug(self.actual_measurement_array)
@ -89,9 +92,11 @@ class StDevUsingCostFunction:
)
if self.use_log_noise:
diffs = ((numpy.log(vals) - self.log_actual) **2) / self.log_denom2
diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
else:
diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
diffs = (
(vals - self.actual_measurement_array) ** 2
) / self.actual_stdev_array2
return numpy.sqrt(diffs.mean(axis=-1))

View File

@ -35,11 +35,10 @@ class Runner:
start = int(self.config.general_config.skip_to_stage)
_logger.info(f"Received instruction to start at stage {start + 1}")
for i, stage in enumerate(stages[start: 4]):
for i, stage in enumerate(stages[start:4]):
_logger.info(f"*** Running stage {i + start + 1}")
stage.run()
else:
# standard run, can keep old
@ -74,7 +73,8 @@ def parse_args():
)
parser.add_argument(
"-s", "--skip-to-stage",
"-s",
"--skip-to-stage",
type=int,
help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
default=None,
@ -91,7 +91,6 @@ def main():
# kalpa.TantriConfig(1234, 50, 0.0005, 10000),
]
override_config = {
# "test1": [
# tantri.dipoles.types.DipoleTO(
@ -102,38 +101,28 @@ def main():
# ],
"two_dipole_connors_geom": [
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([-2, -2, 5.75]),
0.0005
numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.0005
),
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([6, 2, 5.75]),
0.05
)
numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.05
),
],
"two_dipole_connors_geom_omegaswap": [
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([-2, -2, 5.75]),
0.05
numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.05
),
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([6, 2, 5.75]),
0.0005
)
]
numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.0005
),
],
}
generation_config = kalpa.GenerationConfig(
tantri_configs=tantri_configs,
counts=[3, 31],
num_replicas=5,
# let's test this out
# override_dipole_configs=override_config,
orientations=[tantri.dipoles.types.Orientation.Z],
num_bin_time_series=25,
)
@ -150,13 +139,12 @@ def main():
else:
skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
else:
skip = None
skip = None
general_config = kalpa.GeneralConfig(
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
out_dir_name=str(root / "out"),
skip_to_stage=skip
skip_to_stage=skip,
)
# kalpa.GeneralConfig

View File

@ -116,9 +116,7 @@ class Stage01Runner:
for tantri_index, tantri_config in enumerate(
self.config.generation_config.tantri_configs
):
output_csv = directory / kalpa.common.tantri_full_output_name(
tantri_index
)
output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
binned_csv = directory / kalpa.common.tantri_binned_output_name(
tantri_index
)
@ -140,7 +138,10 @@ class Stage01Runner:
# deliberately bad to get it done.
# here we're going to be manually specifying dipoles as we have from our config
def generate_override_dipole(
self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO]
self,
seed: int,
override_name: str,
override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO],
):
"""
create a directory, populate it with stuff.
@ -178,7 +179,7 @@ class Stage01Runner:
dipole_out.write(
json.dumps(
[dip.as_dict() for dip in override_dipoles],
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
)
)
@ -188,9 +189,7 @@ class Stage01Runner:
for tantri_index, tantri_config in enumerate(
self.config.generation_config.tantri_configs
):
output_csv = directory / kalpa.common.tantri_full_output_name(
tantri_index
)
output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
binned_csv = directory / kalpa.common.tantri_binned_output_name(
tantri_index
)
@ -219,12 +218,21 @@ class Stage01Runner:
_logger.info(
f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
)
self.generate_single_subdir(seed_index, count, orientation, replica)
self.generate_single_subdir(
seed_index, count, orientation, replica
)
seed_index += 1
else:
_logger.debug(f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}")
for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items():
self.generate_override_dipole(seed_index, override_name, override_dipoles)
_logger.debug(
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
)
for (
override_name,
override_dipoles,
) in self.config.generation_config.override_dipole_configs.items():
self.generate_override_dipole(
seed_index, override_name, override_dipoles
)
def parse_args():

View File

@ -216,7 +216,10 @@ class Stage02Runner:
_logger.info(f"{deepdog_config=}")
stdev_cost_function_filters = [
b.stdev_cost_function_filter(dot_names, cost, self.config.deepdog_config.use_log_noise) for b in binned_datas
b.stdev_cost_function_filter(
dot_names, cost, self.config.deepdog_config.use_log_noise
)
for b in binned_datas
]
_logger.debug(f"{stdev_cost_function_filters=}")

View File

@ -145,7 +145,6 @@ class Stage04Runner:
out_list.append(row)
return out_list
def run(self):
megamerged_path = (
self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
@ -156,12 +155,13 @@ class Stage04Runner:
writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
writer.writeheader()
if self.config.generation_config.override_dipole_configs is None:
for count in self.config.generation_config.counts:
for orientation in self.config.generation_config.orientations:
for replica in range(self.config.generation_config.num_replicas):
for replica in range(
self.config.generation_config.num_replicas
):
_logger.info(f"Reading {count=} {orientation=} {replica=}")
rows = self.read_merged_coalesced_csv(
orientation, count, replica
@ -170,16 +170,17 @@ class Stage04Runner:
writer.writerow(row)
else:
_logger.debug(f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}")
for override_name in self.config.generation_config.override_dipole_configs.keys():
_logger.debug(
f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}"
)
for (
override_name
) in self.config.generation_config.override_dipole_configs.keys():
_logger.info(f"Working for subdir {override_name}")
rows = self.read_merged_coalesced_csv_override(
override_name
)
rows = self.read_merged_coalesced_csv_override(override_name)
for row in rows:
writer.writerow(row)
# merge with inference
if self.config.generation_config.override_dipole_configs is None:
@ -204,7 +205,9 @@ class Stage04Runner:
/ self.config.general_config.mega_merged_inferenced_name
)
with inferenced_path.open(mode="w", newline="") as outfile:
writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES)
writer = csv.DictWriter(
outfile, fieldnames=INFERENCED_OUT_FIELDNAMES
)
writer.writeheader()
for val in coalesced.values():
for dots in val.values():
@ -212,7 +215,10 @@ class Stage04Runner:
for row in generation.values():
writer.writerow(row)
else:
_logger.info("skipping inference metamerge, overridden dipole config specified")
_logger.info(
"skipping inference metamerge, overridden dipole config specified"
)
def parse_args():