fmt: formatting changes
parent 4496a6b4da
commit 0d8213c500
@@ -14,6 +14,7 @@ class MeasurementTypeEnum(Enum):
     POTENTIAL = "electric-potential"
     X_ELECTRIC_FIELD = "x-electric-field"
 
+
 class SkipToStage(IntEnum):
     # shouldn't need this lol
     STAGE_01 = 0
@@ -30,7 +31,9 @@ class GeneralConfig:
     dots_json_name: str = "dots.json"
     indexes_json_name: str = "indexes.json"
     out_dir_name: str = "out"
-    log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
+    log_pattern: str = (
+        "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
+    )
     measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
     root_directory: pathlib.Path = pathlib.Path.cwd()
 
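Note: the log_pattern rewrite above only wraps the same string literal in parentheses, so the configured value does not change. A minimal standalone sketch (plain Python, not taken from the repository) illustrating the equivalence:

    single = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
    wrapped = (
        "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
    )
    # The parentheses only group the literal for line wrapping; the value is identical.
    assert single == wrapped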
@@ -64,14 +67,14 @@ class GenerationConfig:
     num_replicas: int = 3
 
     # the above three can be overrided with manually specified configurations
-    override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None
+    override_dipole_configs: typing.Optional[
+        typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
+    ] = None
 
     tantri_configs: typing.Sequence[TantriConfig] = field(
         default_factory=lambda: [TantriConfig()]
     )
-
-
 
     num_bin_time_series: int = 25
     bin_log_width: float = 0.25
 
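Likewise, splitting the typing.Optional[...] subscript over several lines leaves the annotation unchanged, since line breaks inside brackets are purely cosmetic. A small self-contained illustration, where DipoleTO is a hypothetical stand-in for tantri.dipoles.types.DipoleTO:

    import typing

    class DipoleTO:  # hypothetical stand-in, used only for this comparison
        pass

    one_line = typing.Optional[typing.Mapping[str, typing.Sequence[DipoleTO]]]
    wrapped = typing.Optional[
        typing.Mapping[str, typing.Sequence[DipoleTO]]
    ]
    # Subscription is an ordinary expression, so the wrapped form builds the same alias.
    assert one_line == wrapped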
@@ -97,7 +97,6 @@ class Coalescer:
                 / generation_weight
             )
 
-
     def coalesce_all(self):
         for actual_key in self.actual_dict.keys():
             for dot_key in self.actual_dict[actual_key].keys():
@@ -70,7 +70,10 @@ class StDevUsingCostFunction:
 
         self.use_log_noise = log_noise
         self.log_actual = numpy.log(self.actual_measurement_array)
-        self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array))**2
+        self.log_denom2 = (
+            numpy.log(self.actual_stdev_array + self.actual_measurement_array)
+            - numpy.log(self.actual_measurement_array)
+        ) ** 2
         # if self.use_log_noise:
         # _logger.debug("remove these debugs later")
         # _logger.debug(self.actual_measurement_array)
@@ -91,7 +94,9 @@ class StDevUsingCostFunction:
        if self.use_log_noise:
            diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
        else:
-            diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
+            diffs = (
+                (vals - self.actual_measurement_array) ** 2
+            ) / self.actual_stdev_array2
 
        return numpy.sqrt(diffs.mean(axis=-1))
 
@@ -39,7 +39,6 @@ class Runner:
                 _logger.info(f"*** Running stage {i + start + 1}")
                 stage.run()
 
-
         else:
             # standard run, can keep old
 
@@ -74,7 +73,8 @@ def parse_args():
     )
 
     parser.add_argument(
-        "-s", "--skip-to-stage",
+        "-s",
+        "--skip-to-stage",
         type=int,
         help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
         default=None,
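Putting "-s" and "--skip-to-stage" on separate lines is also behaviour-preserving: argparse receives the same positional option strings either way. A minimal standalone parser (not the project's parse_args; the help text is shortened here) demonstrating it:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--skip-to-stage",
        type=int,
        help="Skip to stage, if provided.",
        default=None,
    )
    # argparse derives the destination from the long option name: skip_to_stage.
    print(parser.parse_args(["-s", "3"]).skip_to_stage)  # prints 3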
@@ -91,7 +91,6 @@ def main():
         # kalpa.TantriConfig(1234, 50, 0.0005, 10000),
     ]
 
-
     override_config = {
         # "test1": [
         #     tantri.dipoles.types.DipoleTO(
@@ -102,38 +101,28 @@ def main():
         # ],
         "two_dipole_connors_geom": [
             tantri.dipoles.types.DipoleTO(
-                numpy.array([0, 0, 100]),
-                numpy.array([-2, -2, 5.75]),
-                0.0005
+                numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.0005
             ),
             tantri.dipoles.types.DipoleTO(
-                numpy.array([0, 0, 100]),
-                numpy.array([6, 2, 5.75]),
-                0.05
-            )
+                numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.05
+            ),
         ],
         "two_dipole_connors_geom_omegaswap": [
             tantri.dipoles.types.DipoleTO(
-                numpy.array([0, 0, 100]),
-                numpy.array([-2, -2, 5.75]),
-                0.05
+                numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.05
             ),
             tantri.dipoles.types.DipoleTO(
-                numpy.array([0, 0, 100]),
-                numpy.array([6, 2, 5.75]),
-                0.0005
-            )
-        ]
+                numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.0005
+            ),
+        ],
     }
 
     generation_config = kalpa.GenerationConfig(
         tantri_configs=tantri_configs,
         counts=[3, 31],
         num_replicas=5,
-
         # let's test this out
         # override_dipole_configs=override_config,
-
         orientations=[tantri.dipoles.types.Orientation.Z],
         num_bin_time_series=25,
     )
@@ -152,11 +141,10 @@ def main():
     else:
         skip = None
 
-
     general_config = kalpa.GeneralConfig(
         measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
         out_dir_name=str(root / "out"),
-        skip_to_stage=skip
+        skip_to_stage=skip,
     )
 
     # kalpa.GeneralConfig
@@ -116,9 +116,7 @@ class Stage01Runner:
         for tantri_index, tantri_config in enumerate(
             self.config.generation_config.tantri_configs
         ):
-            output_csv = directory / kalpa.common.tantri_full_output_name(
-                tantri_index
-            )
+            output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
             binned_csv = directory / kalpa.common.tantri_binned_output_name(
                 tantri_index
             )
@@ -140,7 +138,10 @@ class Stage01Runner:
     # deliberately bad to get it done.
     # here we're going to be manually specifying dipoles as we have from our config
     def generate_override_dipole(
-        self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO]
+        self,
+        seed: int,
+        override_name: str,
+        override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO],
     ):
         """
         create a directory, populate it with stuff.
@@ -178,7 +179,7 @@ class Stage01Runner:
             dipole_out.write(
                 json.dumps(
                     [dip.as_dict() for dip in override_dipoles],
-                    cls=tantri.cli.input_files.write_dipoles.NumpyEncoder
+                    cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
                 )
             )
 
@@ -188,9 +189,7 @@ class Stage01Runner:
         for tantri_index, tantri_config in enumerate(
             self.config.generation_config.tantri_configs
         ):
-            output_csv = directory / kalpa.common.tantri_full_output_name(
-                tantri_index
-            )
+            output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
             binned_csv = directory / kalpa.common.tantri_binned_output_name(
                 tantri_index
             )
@@ -219,12 +218,21 @@ class Stage01Runner:
                         _logger.info(
                             f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
                         )
-                        self.generate_single_subdir(seed_index, count, orientation, replica)
+                        self.generate_single_subdir(
+                            seed_index, count, orientation, replica
+                        )
                         seed_index += 1
         else:
-            _logger.debug(f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}")
-            for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items():
-                self.generate_override_dipole(seed_index, override_name, override_dipoles)
+            _logger.debug(
+                f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
+            )
+            for (
+                override_name,
+                override_dipoles,
+            ) in self.config.generation_config.override_dipole_configs.items():
+                self.generate_override_dipole(
+                    seed_index, override_name, override_dipoles
+                )
 
 
 def parse_args():
@@ -216,7 +216,10 @@ class Stage02Runner:
         _logger.info(f"{deepdog_config=}")
 
         stdev_cost_function_filters = [
-            b.stdev_cost_function_filter(dot_names, cost, self.config.deepdog_config.use_log_noise) for b in binned_datas
+            b.stdev_cost_function_filter(
+                dot_names, cost, self.config.deepdog_config.use_log_noise
+            )
+            for b in binned_datas
         ]
 
         _logger.debug(f"{stdev_cost_function_filters=}")
@@ -145,7 +145,6 @@ class Stage04Runner:
                 out_list.append(row)
         return out_list
 
-
     def run(self):
         megamerged_path = (
             self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
@@ -156,12 +155,13 @@ class Stage04Runner:
             writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
             writer.writeheader()
 
-
             if self.config.generation_config.override_dipole_configs is None:
 
                 for count in self.config.generation_config.counts:
                     for orientation in self.config.generation_config.orientations:
-                        for replica in range(self.config.generation_config.num_replicas):
+                        for replica in range(
+                            self.config.generation_config.num_replicas
+                        ):
                             _logger.info(f"Reading {count=} {orientation=} {replica=}")
                             rows = self.read_merged_coalesced_csv(
                                 orientation, count, replica
@@ -170,16 +170,17 @@ class Stage04Runner:
                                 writer.writerow(row)
 
             else:
-                _logger.debug(f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}")
-                for override_name in self.config.generation_config.override_dipole_configs.keys():
+                _logger.debug(
+                    f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}"
+                )
+                for (
+                    override_name
+                ) in self.config.generation_config.override_dipole_configs.keys():
                     _logger.info(f"Working for subdir {override_name}")
-                    rows = self.read_merged_coalesced_csv_override(
-                        override_name
-                    )
+                    rows = self.read_merged_coalesced_csv_override(override_name)
                     for row in rows:
                         writer.writerow(row)
 
-
             # merge with inference
 
             if self.config.generation_config.override_dipole_configs is None:
@@ -204,7 +205,9 @@ class Stage04Runner:
                     / self.config.general_config.mega_merged_inferenced_name
                 )
                 with inferenced_path.open(mode="w", newline="") as outfile:
-                    writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES)
+                    writer = csv.DictWriter(
+                        outfile, fieldnames=INFERENCED_OUT_FIELDNAMES
+                    )
                     writer.writeheader()
                     for val in coalesced.values():
                         for dots in val.values():
@@ -212,7 +215,10 @@ class Stage04Runner:
                                 for row in generation.values():
                                     writer.writerow(row)
             else:
-                _logger.info("skipping inference metamerge, overridden dipole config specified")
+                _logger.info(
+                    "skipping inference metamerge, overridden dipole config specified"
+                )
 
 
 def parse_args():