fmt: formatting changes
This commit is contained in:
parent
4496a6b4da
commit
0d8213c500
@ -14,6 +14,7 @@ class MeasurementTypeEnum(Enum):
|
|||||||
POTENTIAL = "electric-potential"
|
POTENTIAL = "electric-potential"
|
||||||
X_ELECTRIC_FIELD = "x-electric-field"
|
X_ELECTRIC_FIELD = "x-electric-field"
|
||||||
|
|
||||||
|
|
||||||
class SkipToStage(IntEnum):
|
class SkipToStage(IntEnum):
|
||||||
# shouldn't need this lol
|
# shouldn't need this lol
|
||||||
STAGE_01 = 0
|
STAGE_01 = 0
|
||||||
@ -30,7 +31,9 @@ class GeneralConfig:
|
|||||||
dots_json_name: str = "dots.json"
|
dots_json_name: str = "dots.json"
|
||||||
indexes_json_name: str = "indexes.json"
|
indexes_json_name: str = "indexes.json"
|
||||||
out_dir_name: str = "out"
|
out_dir_name: str = "out"
|
||||||
log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
|
log_pattern: str = (
|
||||||
|
"%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
|
||||||
|
)
|
||||||
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
|
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
|
||||||
root_directory: pathlib.Path = pathlib.Path.cwd()
|
root_directory: pathlib.Path = pathlib.Path.cwd()
|
||||||
|
|
||||||
@ -64,14 +67,14 @@ class GenerationConfig:
|
|||||||
num_replicas: int = 3
|
num_replicas: int = 3
|
||||||
|
|
||||||
# the above three can be overrided with manually specified configurations
|
# the above three can be overrided with manually specified configurations
|
||||||
override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None
|
override_dipole_configs: typing.Optional[
|
||||||
|
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
||||||
|
] = None
|
||||||
|
|
||||||
tantri_configs: typing.Sequence[TantriConfig] = field(
|
tantri_configs: typing.Sequence[TantriConfig] = field(
|
||||||
default_factory=lambda: [TantriConfig()]
|
default_factory=lambda: [TantriConfig()]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
num_bin_time_series: int = 25
|
num_bin_time_series: int = 25
|
||||||
bin_log_width: float = 0.25
|
bin_log_width: float = 0.25
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ class Coalescer:
|
|||||||
val["coalesced_prob"] = val["prob"]
|
val["coalesced_prob"] = val["prob"]
|
||||||
|
|
||||||
if self.num_replicas > 1:
|
if self.num_replicas > 1:
|
||||||
for gen in range(1, self.num_replicas):
|
for gen in range(1, self.num_replicas):
|
||||||
_logger.info(f"Going through generation {gen}")
|
_logger.info(f"Going through generation {gen}")
|
||||||
|
|
||||||
generation_weight = sum(
|
generation_weight = sum(
|
||||||
@ -93,11 +93,10 @@ class Coalescer:
|
|||||||
for model_key, val in subdict[str(gen)].items():
|
for model_key, val in subdict[str(gen)].items():
|
||||||
val["coalesced_prob"] = (
|
val["coalesced_prob"] = (
|
||||||
float(val["prob"])
|
float(val["prob"])
|
||||||
* float(subdict[str(gen-1)][model_key]["coalesced_prob"])
|
* float(subdict[str(gen - 1)][model_key]["coalesced_prob"])
|
||||||
/ generation_weight
|
/ generation_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def coalesce_all(self):
|
def coalesce_all(self):
|
||||||
for actual_key in self.actual_dict.keys():
|
for actual_key in self.actual_dict.keys():
|
||||||
for dot_key in self.actual_dict[actual_key].keys():
|
for dot_key in self.actual_dict[actual_key].keys():
|
||||||
|
@ -70,7 +70,10 @@ class StDevUsingCostFunction:
|
|||||||
|
|
||||||
self.use_log_noise = log_noise
|
self.use_log_noise = log_noise
|
||||||
self.log_actual = numpy.log(self.actual_measurement_array)
|
self.log_actual = numpy.log(self.actual_measurement_array)
|
||||||
self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array))**2
|
self.log_denom2 = (
|
||||||
|
numpy.log(self.actual_stdev_array + self.actual_measurement_array)
|
||||||
|
- numpy.log(self.actual_measurement_array)
|
||||||
|
) ** 2
|
||||||
# if self.use_log_noise:
|
# if self.use_log_noise:
|
||||||
# _logger.debug("remove these debugs later")
|
# _logger.debug("remove these debugs later")
|
||||||
# _logger.debug(self.actual_measurement_array)
|
# _logger.debug(self.actual_measurement_array)
|
||||||
@ -89,9 +92,11 @@ class StDevUsingCostFunction:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.use_log_noise:
|
if self.use_log_noise:
|
||||||
diffs = ((numpy.log(vals) - self.log_actual) **2) / self.log_denom2
|
diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
|
||||||
else:
|
else:
|
||||||
diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
|
diffs = (
|
||||||
|
(vals - self.actual_measurement_array) ** 2
|
||||||
|
) / self.actual_stdev_array2
|
||||||
|
|
||||||
return numpy.sqrt(diffs.mean(axis=-1))
|
return numpy.sqrt(diffs.mean(axis=-1))
|
||||||
|
|
||||||
|
@ -35,11 +35,10 @@ class Runner:
|
|||||||
|
|
||||||
start = int(self.config.general_config.skip_to_stage)
|
start = int(self.config.general_config.skip_to_stage)
|
||||||
_logger.info(f"Received instruction to start at stage {start + 1}")
|
_logger.info(f"Received instruction to start at stage {start + 1}")
|
||||||
for i, stage in enumerate(stages[start: 4]):
|
for i, stage in enumerate(stages[start:4]):
|
||||||
_logger.info(f"*** Running stage {i + start + 1}")
|
_logger.info(f"*** Running stage {i + start + 1}")
|
||||||
stage.run()
|
stage.run()
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# standard run, can keep old
|
# standard run, can keep old
|
||||||
|
|
||||||
@ -74,7 +73,8 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-s", "--skip-to-stage",
|
"-s",
|
||||||
|
"--skip-to-stage",
|
||||||
type=int,
|
type=int,
|
||||||
help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
|
help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
|
||||||
default=None,
|
default=None,
|
||||||
@ -91,7 +91,6 @@ def main():
|
|||||||
# kalpa.TantriConfig(1234, 50, 0.0005, 10000),
|
# kalpa.TantriConfig(1234, 50, 0.0005, 10000),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
override_config = {
|
override_config = {
|
||||||
# "test1": [
|
# "test1": [
|
||||||
# tantri.dipoles.types.DipoleTO(
|
# tantri.dipoles.types.DipoleTO(
|
||||||
@ -102,38 +101,28 @@ def main():
|
|||||||
# ],
|
# ],
|
||||||
"two_dipole_connors_geom": [
|
"two_dipole_connors_geom": [
|
||||||
tantri.dipoles.types.DipoleTO(
|
tantri.dipoles.types.DipoleTO(
|
||||||
numpy.array([0, 0, 100]),
|
numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.0005
|
||||||
numpy.array([-2, -2, 5.75]),
|
|
||||||
0.0005
|
|
||||||
),
|
),
|
||||||
tantri.dipoles.types.DipoleTO(
|
tantri.dipoles.types.DipoleTO(
|
||||||
numpy.array([0, 0, 100]),
|
numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.05
|
||||||
numpy.array([6, 2, 5.75]),
|
),
|
||||||
0.05
|
|
||||||
)
|
|
||||||
],
|
],
|
||||||
"two_dipole_connors_geom_omegaswap": [
|
"two_dipole_connors_geom_omegaswap": [
|
||||||
tantri.dipoles.types.DipoleTO(
|
tantri.dipoles.types.DipoleTO(
|
||||||
numpy.array([0, 0, 100]),
|
numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.05
|
||||||
numpy.array([-2, -2, 5.75]),
|
|
||||||
0.05
|
|
||||||
),
|
),
|
||||||
tantri.dipoles.types.DipoleTO(
|
tantri.dipoles.types.DipoleTO(
|
||||||
numpy.array([0, 0, 100]),
|
numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.0005
|
||||||
numpy.array([6, 2, 5.75]),
|
),
|
||||||
0.0005
|
],
|
||||||
)
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_config = kalpa.GenerationConfig(
|
generation_config = kalpa.GenerationConfig(
|
||||||
tantri_configs=tantri_configs,
|
tantri_configs=tantri_configs,
|
||||||
counts=[3, 31],
|
counts=[3, 31],
|
||||||
num_replicas=5,
|
num_replicas=5,
|
||||||
|
|
||||||
# let's test this out
|
# let's test this out
|
||||||
# override_dipole_configs=override_config,
|
# override_dipole_configs=override_config,
|
||||||
|
|
||||||
orientations=[tantri.dipoles.types.Orientation.Z],
|
orientations=[tantri.dipoles.types.Orientation.Z],
|
||||||
num_bin_time_series=25,
|
num_bin_time_series=25,
|
||||||
)
|
)
|
||||||
@ -150,13 +139,12 @@ def main():
|
|||||||
else:
|
else:
|
||||||
skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
|
skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
|
||||||
else:
|
else:
|
||||||
skip = None
|
skip = None
|
||||||
|
|
||||||
|
|
||||||
general_config = kalpa.GeneralConfig(
|
general_config = kalpa.GeneralConfig(
|
||||||
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
|
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
|
||||||
out_dir_name=str(root / "out"),
|
out_dir_name=str(root / "out"),
|
||||||
skip_to_stage=skip
|
skip_to_stage=skip,
|
||||||
)
|
)
|
||||||
|
|
||||||
# kalpa.GeneralConfig
|
# kalpa.GeneralConfig
|
||||||
|
@ -116,9 +116,7 @@ class Stage01Runner:
|
|||||||
for tantri_index, tantri_config in enumerate(
|
for tantri_index, tantri_config in enumerate(
|
||||||
self.config.generation_config.tantri_configs
|
self.config.generation_config.tantri_configs
|
||||||
):
|
):
|
||||||
output_csv = directory / kalpa.common.tantri_full_output_name(
|
output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
|
||||||
tantri_index
|
|
||||||
)
|
|
||||||
binned_csv = directory / kalpa.common.tantri_binned_output_name(
|
binned_csv = directory / kalpa.common.tantri_binned_output_name(
|
||||||
tantri_index
|
tantri_index
|
||||||
)
|
)
|
||||||
@ -140,7 +138,10 @@ class Stage01Runner:
|
|||||||
# deliberately bad to get it done.
|
# deliberately bad to get it done.
|
||||||
# here we're going to be manually specifying dipoles as we have from our config
|
# here we're going to be manually specifying dipoles as we have from our config
|
||||||
def generate_override_dipole(
|
def generate_override_dipole(
|
||||||
self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO]
|
self,
|
||||||
|
seed: int,
|
||||||
|
override_name: str,
|
||||||
|
override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
create a directory, populate it with stuff.
|
create a directory, populate it with stuff.
|
||||||
@ -157,9 +158,9 @@ class Stage01Runner:
|
|||||||
out = self.config.get_out_dir_path()
|
out = self.config.get_out_dir_path()
|
||||||
directory = out / f"{override_name}"
|
directory = out / f"{override_name}"
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
_logger.debug("generated override directory")
|
_logger.debug("generated override directory")
|
||||||
|
|
||||||
# config_json = directory / "generation_config.json"
|
# config_json = directory / "generation_config.json"
|
||||||
dipoles_json = directory / "dipoles.json"
|
dipoles_json = directory / "dipoles.json"
|
||||||
|
|
||||||
@ -178,7 +179,7 @@ class Stage01Runner:
|
|||||||
dipole_out.write(
|
dipole_out.write(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
[dip.as_dict() for dip in override_dipoles],
|
[dip.as_dict() for dip in override_dipoles],
|
||||||
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder
|
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -188,9 +189,7 @@ class Stage01Runner:
|
|||||||
for tantri_index, tantri_config in enumerate(
|
for tantri_index, tantri_config in enumerate(
|
||||||
self.config.generation_config.tantri_configs
|
self.config.generation_config.tantri_configs
|
||||||
):
|
):
|
||||||
output_csv = directory / kalpa.common.tantri_full_output_name(
|
output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
|
||||||
tantri_index
|
|
||||||
)
|
|
||||||
binned_csv = directory / kalpa.common.tantri_binned_output_name(
|
binned_csv = directory / kalpa.common.tantri_binned_output_name(
|
||||||
tantri_index
|
tantri_index
|
||||||
)
|
)
|
||||||
@ -219,12 +218,21 @@ class Stage01Runner:
|
|||||||
_logger.info(
|
_logger.info(
|
||||||
f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
|
f"Generating for {seed_index=}: [{count=}, {orientation=}, {replica=}"
|
||||||
)
|
)
|
||||||
self.generate_single_subdir(seed_index, count, orientation, replica)
|
self.generate_single_subdir(
|
||||||
|
seed_index, count, orientation, replica
|
||||||
|
)
|
||||||
seed_index += 1
|
seed_index += 1
|
||||||
else:
|
else:
|
||||||
_logger.debug(f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}")
|
_logger.debug(
|
||||||
for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items():
|
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
|
||||||
self.generate_override_dipole(seed_index, override_name, override_dipoles)
|
)
|
||||||
|
for (
|
||||||
|
override_name,
|
||||||
|
override_dipoles,
|
||||||
|
) in self.config.generation_config.override_dipole_configs.items():
|
||||||
|
self.generate_override_dipole(
|
||||||
|
seed_index, override_name, override_dipoles
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -216,7 +216,10 @@ class Stage02Runner:
|
|||||||
_logger.info(f"{deepdog_config=}")
|
_logger.info(f"{deepdog_config=}")
|
||||||
|
|
||||||
stdev_cost_function_filters = [
|
stdev_cost_function_filters = [
|
||||||
b.stdev_cost_function_filter(dot_names, cost, self.config.deepdog_config.use_log_noise) for b in binned_datas
|
b.stdev_cost_function_filter(
|
||||||
|
dot_names, cost, self.config.deepdog_config.use_log_noise
|
||||||
|
)
|
||||||
|
for b in binned_datas
|
||||||
]
|
]
|
||||||
|
|
||||||
_logger.debug(f"{stdev_cost_function_filters=}")
|
_logger.debug(f"{stdev_cost_function_filters=}")
|
||||||
|
@ -145,7 +145,6 @@ class Stage04Runner:
|
|||||||
out_list.append(row)
|
out_list.append(row)
|
||||||
return out_list
|
return out_list
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
megamerged_path = (
|
megamerged_path = (
|
||||||
self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
|
self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
|
||||||
@ -156,30 +155,32 @@ class Stage04Runner:
|
|||||||
writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
|
writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
|
|
||||||
|
|
||||||
if self.config.generation_config.override_dipole_configs is None:
|
if self.config.generation_config.override_dipole_configs is None:
|
||||||
|
|
||||||
for count in self.config.generation_config.counts:
|
for count in self.config.generation_config.counts:
|
||||||
for orientation in self.config.generation_config.orientations:
|
for orientation in self.config.generation_config.orientations:
|
||||||
for replica in range(self.config.generation_config.num_replicas):
|
for replica in range(
|
||||||
|
self.config.generation_config.num_replicas
|
||||||
|
):
|
||||||
_logger.info(f"Reading {count=} {orientation=} {replica=}")
|
_logger.info(f"Reading {count=} {orientation=} {replica=}")
|
||||||
rows = self.read_merged_coalesced_csv(
|
rows = self.read_merged_coalesced_csv(
|
||||||
orientation, count, replica
|
orientation, count, replica
|
||||||
)
|
)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_logger.debug(f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}")
|
_logger.debug(
|
||||||
for override_name in self.config.generation_config.override_dipole_configs.keys():
|
f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}"
|
||||||
|
)
|
||||||
|
for (
|
||||||
|
override_name
|
||||||
|
) in self.config.generation_config.override_dipole_configs.keys():
|
||||||
_logger.info(f"Working for subdir {override_name}")
|
_logger.info(f"Working for subdir {override_name}")
|
||||||
rows = self.read_merged_coalesced_csv_override(
|
rows = self.read_merged_coalesced_csv_override(override_name)
|
||||||
override_name
|
|
||||||
)
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
|
|
||||||
# merge with inference
|
# merge with inference
|
||||||
|
|
||||||
if self.config.generation_config.override_dipole_configs is None:
|
if self.config.generation_config.override_dipole_configs is None:
|
||||||
@ -204,7 +205,9 @@ class Stage04Runner:
|
|||||||
/ self.config.general_config.mega_merged_inferenced_name
|
/ self.config.general_config.mega_merged_inferenced_name
|
||||||
)
|
)
|
||||||
with inferenced_path.open(mode="w", newline="") as outfile:
|
with inferenced_path.open(mode="w", newline="") as outfile:
|
||||||
writer = csv.DictWriter(outfile, fieldnames=INFERENCED_OUT_FIELDNAMES)
|
writer = csv.DictWriter(
|
||||||
|
outfile, fieldnames=INFERENCED_OUT_FIELDNAMES
|
||||||
|
)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
for val in coalesced.values():
|
for val in coalesced.values():
|
||||||
for dots in val.values():
|
for dots in val.values():
|
||||||
@ -212,7 +215,10 @@ class Stage04Runner:
|
|||||||
for row in generation.values():
|
for row in generation.values():
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
else:
|
else:
|
||||||
_logger.info("skipping inference metamerge, overridden dipole config specified")
|
_logger.info(
|
||||||
|
"skipping inference metamerge, overridden dipole config specified"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user