# kalpa/kalpaa/stages/stage04.py
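"""Stage 4: merge each subdirectory's coalesced bayesrun CSV into one
mega-merged CSV, then (unless dipole overrides are configured) coalesce
inference across generation replicas into a second, inferenced CSV."""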
import argparse
import csv
import logging
import typing

import kalpaa
import kalpaa.common
import kalpaa.inference_coalesce
import tantri.dipoles.types

_logger = logging.getLogger(__name__)
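
# Maps generation orientation labels to the orientation names that appear
# in bayesrun result rows.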
ORIENTATION_DICT = {
    "XY": "fixedxy",
    "RANDOM": "free",
    "Z": "fixedz",
}

# These could be detected from the input files, but hardcoding keeps the
# output schema explicit.
MERGED_OUT_FIELDNAMES = [
    "subdir_name",
    "actual_orientation",
    "actual_avg_filled",
    "generation_replica_index",
    "is_row_actual",
    # old fields
    "dot_name",
    "target_cost",
    "xmin",
    "xmax",
    "ymin",
    "ymax",
    "zmin",
    "zmax",
    "orientation",
    "avg_filled",
    "log_magnitude",
    "calculations_coalesced",
    "success",
    "count",
    "prob",
]

# Same as MERGED_OUT_FIELDNAMES, plus the coalesced probability column.
INFERENCED_OUT_FIELDNAMES = [
    "subdir_name",
    "actual_orientation",
    "actual_avg_filled",
    "generation_replica_index",
    "is_row_actual",
    # old fields
    "dot_name",
    "target_cost",
    "xmin",
    "xmax",
    "ymin",
    "ymax",
    "zmin",
    "zmax",
    "orientation",
    "avg_filled",
    "log_magnitude",
    "calculations_coalesced",
    "success",
    "count",
    "prob",
    "coalesced_prob",
]
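

# Decide whether a merged row corresponds to the configuration that actually
# generated its subdirectory (matching orientation, filled count, and the
# hardcoded log magnitude).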
def is_actual(row, actual_normal_orientation, actual_count):
    _logger.debug("Check orientations")
    row_or = row["orientation"]
    _logger.debug(f"row: {row_or}, actual: {actual_normal_orientation}")
    is_or = row_or == actual_normal_orientation

    _logger.debug("Check counts")
    row_count = row["avg_filled"]
    _logger.debug(f"row: {row_count}, actual: {actual_count}")
    is_count = int(row_count) == int(actual_count)

    _logger.debug("Check magnitude")
    row_logmag = row["log_magnitude"]
    # TODO: the actual log magnitude is hardcoded to 2 here
    is_mag = int(row_logmag) == 2

    _logger.debug(f"{is_or=} and {is_count=} and {is_mag=}")
    if is_or and is_count and is_mag:
        _logger.debug("Returning 1")
        return 1
    else:
        _logger.debug("Returning 0")
        return 0


class Stage04Runner:
    def __init__(self, config: kalpaa.Config):
        self.config = config
        _logger.info(f"Initialising Stage04 runner with {config=}")
        self.indexifier = self.config.indexifier()
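
    # Read one subdirectory's merged coalesced CSV and annotate each row with
    # the ground-truth generation parameters for that subdirectory.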
    def read_merged_coalesced_csv(
        self, orientation, count, replica
    ) -> typing.Sequence:
        subdir_name = f"{orientation.lower()}-{count}-{replica}"
        subdir_path = self.config.get_out_dir_path() / subdir_name
        csv_path = (
            subdir_path
            / kalpaa.common.sorted_bayesruns_name()
            / kalpaa.common.merged_coalesced_name()
        )
        _logger.debug(f"Reading {csv_path=}")
        with csv_path.open(mode="r", newline="") as csvfile:
            reader = csv.DictReader(csvfile)
            out_list = []
            normal_orientation = ORIENTATION_DICT[orientation]
            for row in reader:
                row["subdir_name"] = subdir_name
                row["actual_orientation"] = normal_orientation
                row["actual_avg_filled"] = count
                row["generation_replica_index"] = replica
                row["is_row_actual"] = is_actual(row, normal_orientation, count)
                out_list.append(row)
            return out_list
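
    # As above, but for override subdirectories: the dipole configuration is
    # arbitrary, so no ground-truth annotations can be added.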
    def read_merged_coalesced_csv_override(
        self, override_name: str
    ) -> typing.Sequence:
        subdir_name = override_name
        subdir_path = self.config.get_out_dir_path() / subdir_name
        csv_path = (
            subdir_path
            / kalpaa.common.sorted_bayesruns_name()
            / kalpaa.common.merged_coalesced_name()
        )
        _logger.debug(f"Reading {csv_path=}")
        with csv_path.open(mode="r", newline="") as csvfile:
            reader = csv.DictReader(csvfile)
            out_list = []
            for row in reader:
                # We can't fill in the actual_* or is_row_actual fields because
                # the override configuration is totally arbitrary, but that's fine!
                row["subdir_name"] = subdir_name
                out_list.append(row)
            return out_list
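
    # Phase 1: concatenate every subdirectory's rows into the mega-merged CSV.
    # Phase 2 (skipped when overrides are set): coalesce inference across
    # replicas into the mega-merged-inferenced CSV.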
    def run(self):
        megamerged_path = (
            self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
        )

        # normal merged
        with megamerged_path.open(mode="w", newline="") as outfile:
            writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
            writer.writeheader()
            if self.config.generation_config.override_dipole_configs is None:
                for count in self.config.generation_config.counts:
                    for orientation in self.config.generation_config.orientations:
                        for replica in range(
                            self.config.generation_config.num_replicas
                        ):
                            _logger.info(f"Reading {count=} {orientation=} {replica=}")
                            rows = self.read_merged_coalesced_csv(
                                orientation, count, replica
                            )
                            for row in rows:
                                writer.writerow(row)
            else:
                _logger.debug(
                    f"Overridden dipole config specified, using override {self.config.generation_config.override_dipole_configs}"
                )
                for (
                    override_name
                ) in self.config.generation_config.override_dipole_configs.keys():
                    _logger.info(f"Working for subdir {override_name}")
                    rows = self.read_merged_coalesced_csv_override(override_name)
                    for row in rows:
                        writer.writerow(row)

        # merge with inference
        if self.config.generation_config.override_dipole_configs is None:
            with megamerged_path.open(mode="r", newline="") as infile:
                # Note that if you pass fieldnames to a DictReader it doesn't skip
                # the header row, so this would be wrong:
                # megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
                megamerged_reader = csv.DictReader(infile)
                rows = list(megamerged_reader)
                _logger.debug(rows[0])
                coalescer = kalpaa.inference_coalesce.Coalescer(
                    rows, num_replicas=self.config.generation_config.num_replicas
                )
                _logger.info(coalescer.actual_dict.keys())
                # e.g. coalescer.coalesce_generations(("fixedxy", "1"), "dot1")
                coalesced = coalescer.coalesce_all()

                inferenced_path = (
                    self.config.get_out_dir_path()
                    / self.config.general_config.mega_merged_inferenced_name
                )
                with inferenced_path.open(mode="w", newline="") as outfile:
                    writer = csv.DictWriter(
                        outfile, fieldnames=INFERENCED_OUT_FIELDNAMES
                    )
                    writer.writeheader()
                    # coalesced is nested dicts: actual config -> dot -> generation -> row
                    for val in coalesced.values():
                        for dots in val.values():
                            for generation in dots.values():
                                for row in generation.values():
                                    writer.writerow(row)
        else:
            _logger.info(
                "skipping inference metamerge, overridden dipole config specified"
            )


def parse_args():
    parser = argparse.ArgumentParser(
        description="Merge coalesced bayesrun results into mega-merged CSVs (stage 4)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log-file",
        type=str,
        help="A filename for logging to; if not provided, will only log to stderr",
        default=None,
    )
    args = parser.parse_args()
    return args
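

# Standalone entry point with a hardcoded example configuration.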
def main():
    args = parse_args()
    tantri_configs = [
        kalpaa.TantriConfig(31415, 100, 5, 100000),
        kalpaa.TantriConfig(314, 100, 0.00005, 100000),
    ]
    generation_config = kalpaa.GenerationConfig(
        tantri_configs=tantri_configs,
        counts=[1],
        num_replicas=3,
        orientations=[tantri.dipoles.types.Orientation.Z],
    )
    config = kalpaa.Config(generation_config=generation_config)
    kalpaa.common.set_up_logging(config, args.log_file)

    _logger.info("Generating our data for the following iterations")
    _logger.info(config)

    runner = Stage04Runner(config)
    runner.run()


if __name__ == "__main__":
    main()