feat!: copying over from sd4

Deepak Mallubhotla 2024-12-30 04:10:17 +00:00
parent e408920bde
commit 5a0b294150
Signed by: deepak
GPG Key ID: 47831B15427F5A55
8 changed files with 373 additions and 118 deletions

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass, field
import typing
import tantri.dipoles.types
import pathlib
from enum import Enum
from enum import Enum, IntEnum
import logging
_logger = logging.getLogger(__name__)
@ -14,21 +14,30 @@ class MeasurementTypeEnum(Enum):
POTENTIAL = "electric-potential"
X_ELECTRIC_FIELD = "x-electric-field"
class SkipToStage(IntEnum):
# shouldn't need this lol
STAGE_01 = 0
STAGE_02 = 1
STAGE_03 = 2
STAGE_04 = 3
# Copy over some random constants to see if they're ever reused
@dataclass(frozen=True)
class GeneralConfig:
dots_json_name = "dots.json"
indexes_json_name = "indexes.json"
out_dir_name = "out"
log_pattern = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
dots_json_name: str = "dots.json"
indexes_json_name: str = "indexes.json"
out_dir_name: str = "out"
log_pattern: str = "%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
root_directory: pathlib.Path = pathlib.Path.cwd()
mega_merged_name = "mega_merged_coalesced.csv"
mega_merged_inferenced_name = "mega_merged_coalesced_inferenced.csv"
mega_merged_name: str = "mega_merged_coalesced.csv"
mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"
skip_to_stage: typing.Optional[int] = None
@dataclass(frozen=True)
@ -51,12 +60,18 @@ class GenerationConfig:
tantri.dipoles.types.Orientation.XY,
]
)
# TODO: what's replica here?
num_replicas: int = 3
# the above three can be overridden with manually specified configurations
override_dipole_configs: typing.Optional[typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]] = None
tantri_configs: typing.Sequence[TantriConfig] = field(
default_factory=lambda: [TantriConfig()]
)
num_bin_time_series: int = 25
bin_log_width: float = 0.25
@ -70,6 +85,8 @@ class DeepdogConfig:
costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
target_success: int = 1000
max_monte_carlo_cycles_steps: int = 20
# Whether to use a log-log cost function
use_log_noise: bool = False
@dataclass(frozen=True)
@ -80,7 +97,7 @@ class Config:
def absify(self, filename: str) -> pathlib.Path:
ret = (self.general_config.root_directory / filename).resolve()
_logger.debug(f"Absifying {filename=}, getting {ret}")
_logger.debug(f"Absifying {filename=}, geting {ret}")
return ret
def get_out_dir_path(self) -> pathlib.Path:
@ -115,8 +132,8 @@ class ReducedModelParams:
x_max: float = 20
y_min: float = -10
y_max: float = 10
z_min: float = 0
z_max: float = 5
z_min: float = 5
z_max: float = 6.5
w_log_min: float = -5
w_log_max: float = 1
count: int = 1
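
A minimal sketch of how the new skip_to_stage plumbing fits together, assuming the kalpa.config import path used in main() further down; the stage value here is illustrative:

import kalpa.config

# The CLI takes 1-based stage numbers; SkipToStage is 0-based (see main() below).
cli_stage = 3
skip = kalpa.config.SkipToStage(cli_stage - 1)  # SkipToStage.STAGE_03, value 2
general = kalpa.config.GeneralConfig(skip_to_stage=skip)
# Runner.run() slices its four-stage list from int(skip) onward.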

View File

@ -66,45 +66,37 @@ class Coalescer:
_logger.debug(f"subdict keys: {subdict.keys()}")
# TODO hardcoding 3 generations
if self.num_replicas != 3:
raise ValueError(
f"num replicas was {self.num_replicas}, but we've hard coded 3"
)
# if self.num_replicas != 3:
# raise ValueError(
# f"num replicas was {self.num_replicas}, but we've hard coded 3"
# )
# generations_keys = ["0", "1", "2"]
_logger.info(f"Going through generation {0}")
# 0th gen is easiest
for model_key, val in subdict["0"].items():
val["coalesced_prob"] = val["prob"]
weight1 = sum(
if self.num_replicas > 1:
for gen in range(1, self.num_replicas):
_logger.info(f"Going through generation {gen}")
generation_weight = sum(
[
float(subdict["0"][key]["coalesced_prob"])
* float(subdict["1"][key]["prob"])
for key in subdict["1"].keys()
float(subdict[str(gen - 1)][key]["coalesced_prob"])
* float(subdict[str(gen)][key]["prob"])
for key in subdict[str(gen)].keys()
]
)
_logger.debug(weight1)
for model_key, val in subdict["1"].items():
_logger.debug(generation_weight)
for model_key, val in subdict[str(gen)].items():
val["coalesced_prob"] = (
float(val["prob"])
* float(subdict["0"][model_key]["coalesced_prob"])
/ weight1
* float(subdict[str(gen - 1)][model_key]["coalesced_prob"])
/ generation_weight
)
weight2 = sum(
[
float(subdict["1"][key]["coalesced_prob"])
* float(subdict["2"][key]["prob"])
for key in subdict["2"].keys()
]
)
_logger.debug(weight2)
for model_key, val in subdict["2"].items():
val["coalesced_prob"] = (
float(val["prob"])
* float(subdict["1"][model_key]["coalesced_prob"])
/ weight2
)
def coalesce_all(self):
for actual_key in self.actual_dict.keys():
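
The generalized loop reduces to a per-generation reweighting: generation 0 keeps its own probabilities, and each later generation multiplies by the previous generation's coalesced probability and renormalizes. A self-contained sketch with illustrative numbers, plain dicts standing in for the parsed structure:

# coalesced_0(m) = prob_0(m)
# coalesced_g(m) = prob_g(m) * coalesced_{g-1}(m) / W_g
# where W_g = sum_m coalesced_{g-1}(m) * prob_g(m) renormalizes each generation.
subdict = {
    "0": {"modelA": {"prob": 0.6}, "modelB": {"prob": 0.4}},
    "1": {"modelA": {"prob": 0.9}, "modelB": {"prob": 0.1}},
}
for val in subdict["0"].values():
    val["coalesced_prob"] = val["prob"]
for gen in range(1, len(subdict)):
    weight = sum(
        subdict[str(gen - 1)][key]["coalesced_prob"] * subdict[str(gen)][key]["prob"]
        for key in subdict[str(gen)]
    )
    for key, val in subdict[str(gen)].items():
        val["coalesced_prob"] = (
            val["prob"] * subdict[str(gen - 1)][key]["coalesced_prob"] / weight
        )
# After gen 1: modelA -> 0.54 / 0.58 ≈ 0.93, modelB -> 0.04 / 0.58 ≈ 0.07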

View File

@ -58,6 +58,7 @@ class StDevUsingCostFunction:
dot_inputs_array,
actual_measurement_array,
actual_stdev_array,
log_noise: bool = False,
):
_logger.info(f"Cost function with measurement type of {measurement_type}")
self.measurement_type = measurement_type
@ -67,6 +68,16 @@ class StDevUsingCostFunction:
self.actual_stdev_array = actual_stdev_array
self.actual_stdev_array2 = actual_stdev_array**2
self.use_log_noise = log_noise
self.log_actual = numpy.log(self.actual_measurement_array)
self.log_denom2 = (numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array))**2
# if self.use_log_noise:
# _logger.debug("remove these debugs later")
# _logger.debug(self.actual_measurement_array)
# _logger.debug(self.actual_stdev_array)
# _logger.debug(self.log_actual)
# _logger.debug(self.log_denom2)
def __call__(self, dipoles_to_test):
if self.measurement_type == X_ELECTRIC_FIELD:
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
@ -76,7 +87,12 @@ class StDevUsingCostFunction:
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
self.dot_inputs_array, dipoles_to_test
)
if self.use_log_noise:
diffs = ((numpy.log(vals) - self.log_actual) **2) / self.log_denom2
else:
diffs = ((vals - self.actual_measurement_array) ** 2) / self.actual_stdev_array2
return numpy.sqrt(diffs.mean(axis=-1))
@ -202,6 +218,7 @@ class BinnedData:
csv_dict: typing.Dict[str, typing.Any]
measurement_type: str
# we're ignoring stdevs for the moment, as in the calculator single_dipole_matches.py script.
def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]:
if dot_name not in self.dots_dict:
raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
@ -255,6 +272,7 @@ class BinnedData:
def _stdev_cost_function(
self,
measurements: typing.Sequence[Measurement],
use_log_noise: bool = False,
):
meas_array = numpy.array([m.dot_measurement.v for m in measurements])
stdev_array = numpy.array([m.stdev for m in measurements])
@ -266,7 +284,7 @@ class BinnedData:
_logger.debug(f"Obtained {input_array=}")
return StDevUsingCostFunction(
self.measurement_type, input_array, meas_array, stdev_array
self.measurement_type, input_array, meas_array, stdev_array, use_log_noise
)
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
@ -277,10 +295,13 @@ class BinnedData:
)
def stdev_cost_function_filter(
self, dot_names: typing.Sequence[str], target_cost: float
self,
dot_names: typing.Sequence[str],
target_cost: float,
use_log_noise: bool = False,
):
measurements = self.measurements(dot_names)
cost_function = self._stdev_cost_function(measurements)
cost_function = self._stdev_cost_function(measurements, use_log_noise)
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
cost_function, target_cost
)
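
In the log-noise branch the effective noise scale per point is ln(m + σ) − ln(m) = ln(1 + σ/m), i.e. the relative error expressed in log space. A minimal numpy sketch of the two branches (illustrative arrays; the log form assumes strictly positive measurements and predictions):

import numpy

measured = numpy.array([1e-3, 1e-5, 1e-7])    # actual binned measurements
stdev = 0.1 * measured                        # per-point noise
vals = numpy.array([1.2e-3, 0.9e-5, 1.1e-7])  # model predictions

# linear branch: (v - m)^2 / sigma^2
linear = ((vals - measured) ** 2) / stdev**2

# log branch: (ln v - ln m)^2 / (ln(m + sigma) - ln m)^2
log_denom2 = (numpy.log(measured + stdev) - numpy.log(measured)) ** 2
log_space = ((numpy.log(vals) - numpy.log(measured)) ** 2) / log_denom2

cost_linear = numpy.sqrt(linear.mean(axis=-1))
cost_log = numpy.sqrt(log_space.mean(axis=-1))
# The log form weights each decade roughly equally instead of letting the
# largest-magnitude points dominate the cost.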

View File

@ -1,3 +1,4 @@
import pathlib
import logging
import kalpa.stages.stage01
@ -6,9 +7,13 @@ import kalpa.stages.stage03
import kalpa.stages.stage04
import kalpa.common
import tantri.dipoles.types
import kalpa.config
import argparse
# try not to use this outside of main or when defining config stuff pls
import numpy
_logger = logging.getLogger(__name__)
@ -18,6 +23,26 @@ class Runner:
_logger.info(f"Initialising runner with {config=}")
def run(self):
if self.config.general_config.skip_to_stage is not None:
stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
stage02 = kalpa.stages.stage02.Stage02Runner(self.config)
stage03 = kalpa.stages.stage03.Stage03Runner(self.config)
stage04 = kalpa.stages.stage04.Stage04Runner(self.config)
stages = [stage01, stage02, stage03, stage04]
start = int(self.config.general_config.skip_to_stage)
_logger.info(f"Received instruction to start at stage {start + 1}")
for i, stage in enumerate(stages[start:]):
_logger.info(f"*** Running stage {i + start + 1}")
stage.run()
else:
# standard run, can keep old
_logger.info("*** Beginning Stage 01 ***")
stage01 = kalpa.stages.stage01.Stage01Runner(self.config)
stage01.run()
@ -42,9 +67,16 @@ def parse_args():
)
parser.add_argument(
"--log-file",
"--override-root",
type=str,
help="A filename for logging to, if not provided will only log to stderr",
help="If provided, override the root dir.",
default=None,
)
parser.add_argument(
"-s", "--skip-to-stage",
type=int,
help="Skip to stage, if provided. 1 means stages 1-4 will run, 4 means only stage 4 will run.",
default=None,
)
args = parser.parse_args()
@ -55,23 +87,85 @@ def main():
args = parse_args()
tantri_configs = [
kalpa.TantriConfig(12345, 50, 0.5, 100000),
kalpa.TantriConfig(123456, 50, 0.5, 100000),
# kalpa.TantriConfig(1234, 50, 0.0005, 10000),
]
override_config = {
# "test1": [
# tantri.dipoles.types.DipoleTO(
# numpy.array([0, 0, 100]),
# numpy.array([-2, -2, 2.9]),
# 0.0005
# )
# ],
"two_dipole_connors_geom": [
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([-2, -2, 5.75]),
0.0005
),
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([6, 2, 5.75]),
0.05
)
],
"two_dipole_connors_geom_omegaswap": [
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([-2, -2, 5.75]),
0.05
),
tantri.dipoles.types.DipoleTO(
numpy.array([0, 0, 100]),
numpy.array([6, 2, 5.75]),
0.0005
)
]
}
generation_config = kalpa.GenerationConfig(
tantri_configs=tantri_configs,
counts=[1],
num_replicas=3,
orientations=[tantri.dipoles.types.Orientation.XY],
counts=[3, 31],
num_replicas=5,
# let's test this out
# override_dipole_configs=override_config,
orientations=[tantri.dipoles.types.Orientation.Z],
num_bin_time_series=25,
)
if args.override_root is None:
_logger.info("root dir not given")
root = pathlib.Path("plots0")
else:
root = pathlib.Path(args.override_root)
if args.skip_to_stage is not None:
if args.skip_to_stage not in [1, 2, 3, 4]:
raise ValueError(f"There is no stage {args.skip_to_stage}")
else:
skip = kalpa.config.SkipToStage(args.skip_to_stage - 1)
else:
skip = None
general_config = kalpa.GeneralConfig(
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL
measurement_type=kalpa.MeasurementTypeEnum.POTENTIAL,
out_dir_name=str(root / "out"),
skip_to_stage=skip
)
# kalpa.GeneralConfig
deepdog_config = kalpa.DeepdogConfig(
costs_to_try=[10, 2, 1, 0.1],
costs_to_try=[2, 1],
max_monte_carlo_cycles_steps=20,
target_success=200,
use_log_noise=True,
)
config = kalpa.Config(
@ -80,7 +174,7 @@ def main():
deepdog_config=deepdog_config,
)
kalpa.common.set_up_logging(config, args.log_file)
kalpa.common.set_up_logging(config, str(root / f"logs/{root}.log"))
_logger.info(f"Got {config=}")
runner = Runner(config)
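
For reference, a sketch of how the two new flags flow into the run, mirroring main() above (the argv list is illustrative; the actual entry-point name isn't shown in this diff):

import argparse
import pathlib

parser = argparse.ArgumentParser()
parser.add_argument("--override-root", type=str, default=None)
parser.add_argument("-s", "--skip-to-stage", type=int, default=None)
args = parser.parse_args(["--override-root", "plots1", "-s", "3"])

root = pathlib.Path(args.override_root) if args.override_root else pathlib.Path("plots0")
if args.skip_to_stage is not None and args.skip_to_stage not in [1, 2, 3, 4]:
    raise ValueError(f"There is no stage {args.skip_to_stage}")
# 1-based CLI stage -> 0-based enum value, as in main() above
print(root, args.skip_to_stage)  # plots1 3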

View File

@ -7,7 +7,10 @@ import logging
import kalpa
import kalpa.common
import tantri.cli
import tantri.cli.input_files
import tantri.cli.input_files.write_dipoles
import tantri.dipoles.types
import typing
_logger = logging.getLogger(__name__)
@ -113,7 +116,9 @@ class Stage01Runner:
for tantri_index, tantri_config in enumerate(
self.config.generation_config.tantri_configs
):
output_csv = directory / kalpa.common.tantri_full_output_name(tantri_index)
output_csv = directory / kalpa.common.tantri_full_output_name(
tantri_index
)
binned_csv = directory / kalpa.common.tantri_binned_output_name(
tantri_index
)
@ -131,8 +136,83 @@ class Stage01Runner:
True,
)
# This whole method is duplication and is ripe for refactor, but that's fine!
# deliberately bad to get it done.
# here we're going to be manually specifying dipoles as we have from our config
def generate_override_dipole(
self, seed: int, override_name: str, override_dipoles: typing.Sequence[tantri.dipoles.types.DipoleTO]
):
"""
Create a directory and populate it with the override dipole data.

seed: seed integer to use
override_name: the name of this dipole configuration, from the config file
override_dipoles: the dipoles to write out for this configuration
"""
_logger.info(
f"Writing override config {override_name} with dipoles: [{override_dipoles}]"
)
out = self.config.get_out_dir_path()
directory = out / f"{override_name}"
directory.mkdir(parents=True, exist_ok=True)
_logger.debug("generated override directory")
# config_json = directory / "generation_config.json"
dipoles_json = directory / "dipoles.json"
# with open(config_json, "w") as conf_file:
# params = kalpa.ReducedModelParams(
# count=count, orientation=tantri.dipoles.types.Orientation(orientation)
# )
# _logger.debug(f"Got params {params=}")
# json.dump(params.config_dict(seed), conf_file)
# # json.dump(kalpa.common.model_config_dict(count, orientation, seed), conf_file)
# the original logic looked like this:
# tantri.cli._generate_dipoles(config_json, dipoles_json, (seed, replica, 1))
# We're replicating the bit that wrote the dipoles here, but that's a refactor opportunity
with dipoles_json.open("w") as dipole_out:
dipole_out.write(
json.dumps(
[dip.as_dict() for dip in override_dipoles],
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder
)
)
_logger.info(f"Wrote to dipoles file {dipoles_json}")
# tantri.cli._write_apsd(dipoles_json, DOTS, X_ELECTRIC_FIELD, DELTA_T, NUM_ITERATIONS, NUM_BIN_TS, (index, replica, 2), output_csv, binned_csv, BIN_WIDTH_LOG, True)
for tantri_index, tantri_config in enumerate(
self.config.generation_config.tantri_configs
):
output_csv = directory / kalpa.common.tantri_full_output_name(
tantri_index
)
binned_csv = directory / kalpa.common.tantri_binned_output_name(
tantri_index
)
tantri.cli._write_apsd(
dipoles_json,
self.config.general_config.dots_json_name,
self.config.general_config.measurement_type.value,
tantri_config.delta_t,
tantri_config.num_iterations,
self.config.generation_config.num_bin_time_series,
(seed, 2),
output_csv,
binned_csv,
self.config.generation_config.bin_log_width,
True,
)
def run(self):
seed_index = 0
if self.config.generation_config.override_dipole_configs is None:
# this is the default path
_logger.debug("no override needed!")
for count in self.config.generation_config.counts:
for orientation in self.config.generation_config.orientations:
for replica in range(self.config.generation_config.num_replicas):
@ -141,6 +221,10 @@ class Stage01Runner:
)
self.generate_single_subdir(seed_index, count, orientation, replica)
seed_index += 1
else:
_logger.debug(f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}")
for override_name, override_dipoles in self.config.generation_config.override_dipole_configs.items():
self.generate_override_dipole(seed_index, override_name, override_dipoles)
def parse_args():
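
The override path serializes the dipoles directly instead of going through a generation-config JSON. A trimmed-down sketch of just that step, using the same tantri helpers as the code above (dipole values copied from the two_dipole_connors_geom override in main()):

import json
import numpy
import tantri.dipoles.types
import tantri.cli.input_files.write_dipoles

# Two fixed dipoles, as in the two_dipole_connors_geom override in main().
dipoles = [
    tantri.dipoles.types.DipoleTO(numpy.array([0, 0, 100]), numpy.array([-2, -2, 5.75]), 0.0005),
    tantri.dipoles.types.DipoleTO(numpy.array([0, 0, 100]), numpy.array([6, 2, 5.75]), 0.05),
]
# NumpyEncoder handles the numpy arrays that a plain json.dumps would reject.
serialized = json.dumps(
    [dip.as_dict() for dip in dipoles],
    cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
)
with open("dipoles.json", "w") as dipole_out:
    dipole_out.write(serialized)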

View File

@ -178,12 +178,16 @@ class Stage02Runner:
# TODO find way to store this as a global config file
occupancies_dict = {
1: (1000, 1000),
10: (1000, 100),
16: (10000, 10),
31: (1000, 100),
56: (1000, 100),
100: (100, 100),
1: (500, 1000),
2: (250, 2000),
3: (250, 2000),
5: (100, 5000),
10: (50, 10000),
16: (50, 10000),
17: (50, 10000),
31: (50, 10000),
56: (25, 20000),
100: (5, 100000),
}
mccount, mccountcycles = occupancies_dict[avg_filled]
@ -212,7 +216,7 @@ class Stage02Runner:
_logger.info(f"{deepdog_config=}")
stdev_cost_function_filters = [
b.stdev_cost_function_filter(dot_names, cost) for b in binned_datas
b.stdev_cost_function_filter(dot_names, cost, self.config.deepdog_config.use_log_noise) for b in binned_datas
]
_logger.debug(f"{stdev_cost_function_filters=}")

View File

@ -121,6 +121,31 @@ class Stage04Runner:
out_list.append(row)
return out_list
def read_merged_coalesced_csv_override(self, override_name: str) -> typing.Sequence:
subdir_name = override_name
subdir_path = self.config.get_out_dir_path() / subdir_name
csv_path = (
subdir_path
/ kalpa.common.sorted_bayesruns_name()
/ kalpa.common.merged_coalesced_name()
)
_logger.debug(f"Reading {csv_path=}")
with csv_path.open(mode="r", newline="") as csvfile:
reader = csv.DictReader(csvfile)
out_list = []
for row in reader:
# We can't put any of the actual info in because it's totally arbitrary, but that's fine!
# normal_orientation = ORIENTATION_DICT[orientation]
row["subdir_name"] = subdir_name
# row["actual_orientation"] = ORIENTATION_DICT[orientation]
# row["actual_avg_filled"] = count
# row["generation_replica_index"] = replica
# row["is_row_actual"] = is_actual(row, normal_orientation, count)
out_list.append(row)
return out_list
def run(self):
megamerged_path = (
self.config.get_out_dir_path() / self.config.general_config.mega_merged_name
@ -130,6 +155,10 @@ class Stage04Runner:
with megamerged_path.open(mode="w", newline="") as outfile:
writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
writer.writeheader()
if self.config.generation_config.override_dipole_configs is None:
for count in self.config.generation_config.counts:
for orientation in self.config.generation_config.orientations:
for replica in range(self.config.generation_config.num_replicas):
@ -140,8 +169,21 @@ class Stage04Runner:
for row in rows:
writer.writerow(row)
else:
_logger.debug(f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}")
for override_name in self.config.generation_config.override_dipole_configs.keys():
_logger.info(f"Working for subdir {override_name}")
rows = self.read_merged_coalesced_csv_override(
override_name
)
for row in rows:
writer.writerow(row)
# merge with inference
if self.config.generation_config.override_dipole_configs is None:
with megamerged_path.open(mode="r", newline="") as infile:
# Note that if you pass in fieldnames to a DictReader it doesn't skip. So this is bad:
# megamerged_reader = csv.DictReader(infile, fieldnames=MERGED_OUT_FIELDNAMES)
@ -169,7 +211,8 @@ class Stage04Runner:
for generation in dots.values():
for row in generation.values():
writer.writerow(row)
else:
_logger.info("skipping inference metamerge, overridden dipole config specified")
def parse_args():

View File

@ -10,8 +10,8 @@ python = ">=3.8.1,<3.10"
pdme = "^1.5.0"
deepdog = "^1.4.0"
tantri = "^1.2.0"
deepdog = "^1.5.0"
tantri = "^1.3.0"
tomli = "^2.0.1"
[tool.poetry.group.dev.dependencies]
black = "^24.8.0"