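"""Stage 04: merge coalesced bayesrun results.

Reads each generation subdirectory's merged coalesced CSV, tags rows with
their actual generation parameters, and writes a mega-merged CSV. When no
overrides are configured, it also coalesces inference results across
generation replicas into a mega-merged-inferenced CSV.
"""
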
import argparse
import csv
import logging
import typing

import kalpaa
import kalpaa.common
import kalpaa.inference_coalesce
import tantri.dipoles.types

_logger = logging.getLogger(__name__)


# Maps generation orientation names to the normalised names that appear in result rows.
ORIENTATION_DICT = {
    "XY": "fixedxy",
    "RANDOM": "free",
    "Z": "fixedz",
}

# These could be detected from the data, but hardcoding keeps the output schema explicit.
MERGED_OUT_FIELDNAMES = [
    "subdir_name",
    "actual_orientation",
    "actual_avg_filled",
    "generation_replica_index",
    "is_row_actual",
    # old fields
    "dot_name",
    "target_cost",
    "xmin",
    "xmax",
    "ymin",
    "ymax",
    "zmin",
    "zmax",
    "orientation",
    "avg_filled",
    "log_magnitude",
    "calculations_coalesced",
    "success",
    "count",
    "prob",
]

# Same as MERGED_OUT_FIELDNAMES, with the coalesced probability appended.
INFERENCED_OUT_FIELDNAMES = [
    "subdir_name",
    "actual_orientation",
    "actual_avg_filled",
    "generation_replica_index",
    "is_row_actual",
    # old fields
    "dot_name",
    "target_cost",
    "xmin",
    "xmax",
    "ymin",
    "ymax",
    "zmin",
    "zmax",
    "orientation",
    "avg_filled",
    "log_magnitude",
    "calculations_coalesced",
    "success",
    "count",
    "prob",
    "coalesced_prob",
]


def is_actual(row, actual_normal_orientation, actual_count):
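    """Return 1 if a row's parameters match the actual generation parameters, else 0.

    Compares the row's orientation and (integer) average fill count against
    the actual values, and additionally requires a log magnitude of 2 (still
    hardcoded, see the TODO below). Returns 1/0 rather than a bool so the
    flag can be written straight into the output CSV.
    """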
_logger.debug("Check orientations")
|
|
row_or = row["orientation"]
|
|
_logger.debug(f"row: {row_or}, actual: {actual_normal_orientation}")
|
|
is_or = row_or == actual_normal_orientation
|
|
|
|
_logger.debug("Check counts")
|
|
row_count = row["avg_filled"]
|
|
_logger.debug(f"row: {row_count}, actual: {actual_count}")
|
|
is_count = int(row_count) == int(actual_count)
|
|
|
|
_logger.debug("Check magnitude")
|
|
row_logmag = row["log_magnitude"]
|
|
# TODO hardcoding
|
|
is_mag = int(row_logmag) == 2
|
|
|
|
_logger.debug(f"{is_or=} and {is_count=}and {is_mag=}")
|
|
if is_or and is_count and is_mag:
|
|
_logger.debug("Returning 1")
|
|
return 1
|
|
else:
|
|
_logger.debug("Returning 0")
|
|
return 0
|
|
|
|
|
|
class Stage04Runner:
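    """Runs stage 04: merges per-subdirectory coalesced bayesrun CSVs into a
    single mega-merged CSV and, when no overrides are set, coalesces inference
    results across generation replicas."""
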
    def __init__(self, config: kalpaa.Config):
        self.config = config
        _logger.info(f"Initialising Stage04 runner with {config=}")

        self.indexifier = self.config.indexifier()

    def read_merged_coalesced_csv(
        self, orientation, count, replica
    ) -> typing.Sequence:
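        """Read one subdirectory's merged coalesced CSV and annotate each row.

        The subdirectory is identified by its generation parameters
        (orientation, count, replica); each row gains those actual generation
        parameters plus an is_row_actual flag.
        """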
subdir_name = f"{orientation.lower()}-{count}-{replica}"
|
|
subdir_path = self.config.get_out_dir_path() / subdir_name
|
|
csv_path = (
|
|
subdir_path
|
|
/ kalpaa.common.sorted_bayesruns_name()
|
|
/ kalpaa.common.merged_coalesced_name()
|
|
)
|
|
_logger.debug(f"Reading {csv_path=}")
|
|
with csv_path.open(mode="r", newline="") as csvfile:
|
|
reader = csv.DictReader(csvfile)
|
|
out_list = []
|
|
for row in reader:
|
|
normal_orientation = ORIENTATION_DICT[orientation]
|
|
row["subdir_name"] = subdir_name
|
|
row["actual_orientation"] = ORIENTATION_DICT[orientation]
|
|
row["actual_avg_filled"] = count
|
|
row["generation_replica_index"] = replica
|
|
row["is_row_actual"] = is_actual(row, normal_orientation, count)
|
|
out_list.append(row)
|
|
return out_list
|
|
|
|
    def read_merged_coalesced_csv_override(
        self, override_name: str
    ) -> typing.Sequence:
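        """Read the merged coalesced CSV for an override subdirectory.

        With overridden configs there are no known actual generation
        parameters, so only subdir_name is attached to each row.
        """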
        subdir_name = override_name
        subdir_path = self.config.get_out_dir_path() / subdir_name
        csv_path = (
            subdir_path
            / kalpaa.common.sorted_bayesruns_name()
            / kalpaa.common.merged_coalesced_name()
        )
        _logger.debug(f"Reading {csv_path=}")
        with csv_path.open(mode="r", newline="") as csvfile:
            reader = csv.DictReader(csvfile)
            out_list = []
            for row in reader:
                # We can't put any of the actual info (orientation, count,
                # replica, is_row_actual) in because the override is totally
                # arbitrary, but that's fine!
                row["subdir_name"] = subdir_name
                out_list.append(row)
            return out_list

    def run(self):
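        """Write the mega-merged CSV, then the mega-merged-inferenced CSV.

        The inference merge only runs without overrides, since override rows
        carry no actual generation parameters to coalesce against.
        """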
        megamerged_path = (
            self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
        )

        # normal merged
        with megamerged_path.open(mode="w", newline="") as outfile:
            writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
            writer.writeheader()

            if self.config.generation_config.override_dipole_configs is not None:
                override_names = (
                    self.config.generation_config.override_dipole_configs.keys()
                )
            elif self.config.generation_config.override_measurement_filesets is not None:
                override_names = (
                    self.config.generation_config.override_measurement_filesets.keys()
                )
            else:
                override_names = None

            if override_names is not None:
                _logger.debug(
                    f"We had an overridden dipole config, using overrides {override_names}"
                )
                for override_name in override_names:
                    _logger.info(f"Working for subdir {override_name}")
                    rows = self.read_merged_coalesced_csv_override(override_name)
                    for row in rows:
                        writer.writerow(row)
            else:
                for count in self.config.generation_config.counts:
                    for orientation in self.config.generation_config.orientations:
                        for replica in range(
                            self.config.generation_config.num_replicas
                        ):
                            _logger.info(f"Reading {count=} {orientation=} {replica=}")
                            rows = self.read_merged_coalesced_csv(
                                orientation, count, replica
                            )
                            for row in rows:
                                writer.writerow(row)

        # merge with inference
        if override_names is None:
            with megamerged_path.open(mode="r", newline="") as infile:
                # Note that passing fieldnames to a DictReader stops it skipping
                # the header row. So this would be wrong:
                # megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
                megamerged_reader = csv.DictReader(infile)
                rows = list(megamerged_reader)
                _logger.debug(rows[0])
                coalescer = kalpaa.inference_coalesce.Coalescer(
                    rows, num_replicas=self.config.generation_config.num_replicas
                )
                _logger.info(coalescer.actual_dict.keys())

                # coalescer.coalesce_generations(("fixedxy", "1"), "dot1")

                coalesced = coalescer.coalesce_all()

                inferenced_path = (
                    self.config.get_out_dir_path()
                    / self.config.general_config.mega_merged_inferenced_name
                )
                with inferenced_path.open(mode="w", newline="") as outfile:
                    writer = csv.DictWriter(
                        outfile, fieldnames=INFERENCED_OUT_FIELDNAMES
                    )
                    writer.writeheader()
                    # Walk the nested coalesced structure down to the row level.
                    for val in coalesced.values():
                        for dots in val.values():
                            for generation in dots.values():
                                for row in generation.values():
                                    writer.writerow(row)
        else:
            _logger.info(
                "Skipping inference metamerge, overridden dipole config specified"
            )


def parse_args():
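    """Parse command-line arguments (just an optional log file)."""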
    parser = argparse.ArgumentParser(
        description="put files in directory",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--log-file",
        type=str,
        help="A filename to log to; if not provided, log only to stderr",
        default=None,
    )
    args = parser.parse_args()
    return args


def main():
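    """Entry point: build the hardcoded stage 04 config and run it."""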
    args = parse_args()

    tantri_configs = [
        kalpaa.TantriConfig(31415, 100, 5, 100000),
        kalpaa.TantriConfig(314, 100, 0.00005, 100000),
    ]
    generation_config = kalpaa.GenerationConfig(
        tantri_configs=tantri_configs,
        counts=[1],
        num_replicas=3,
        orientations=[tantri.dipoles.types.Orientation.Z],
    )

    config = kalpaa.Config(generation_config=generation_config)

    kalpaa.common.set_up_logging(config, args.log_file)

    _logger.info("Running stage 04 with the following configuration:")
    _logger.info(config)

    runner = Stage04Runner(config)
    runner.run()


if __name__ == "__main__":
    main()