Deepak Mallubhotla a731a81c6a
Some checks failed
gitea-physics/kalpa/pipeline/head There was a failure building this commit
feat: many disparate updates for modernising
2025-02-21 15:58:12 -06:00

105 lines
2.8 KiB
Python

import logging
_logger = logging.getLogger(__name__)
class Keys:
    """Derive the grouping keys of a single row.

    ``row`` may be any object that supports ``row[name]`` lookups for the
    column names used below; it is stored unchanged on ``self.row``.
    """

    def __init__(self, row):
        self.row = row

    def actual_key(self):
        """Return ``(actual_orientation, actual_avg_filled)`` for the row."""
        record = self.row
        orientation = record["actual_orientation"]
        avg_filled = record["actual_avg_filled"]
        return orientation, avg_filled

    def dot_cost_key(self):
        """Return ``(dot_name, target_cost)`` for the row."""
        record = self.row
        dot_name = record["dot_name"]
        target_cost = record["target_cost"]
        return dot_name, target_cost

    def model_key(self):
        """Return ``(orientation, avg_filled, log_magnitude)`` for the row."""
        record = self.row
        columns = ("orientation", "avg_filled", "log_magnitude")
        return tuple(record[column] for column in columns)

    def replica_key(self):
        """Return the row's ``generation_replica_index`` value as-is."""
        record = self.row
        return record["generation_replica_index"]

    def all_keys(self):
        """Return every key as ``(actual, dot_cost, replica, model)``.

        Note the order: the replica key comes before the model key.
        """
        actual = self.actual_key()
        dot_cost = self.dot_cost_key()
        replica = self.replica_key()
        model = self.model_key()
        return actual, dot_cost, replica, model
class Coalescer:
    """Index rows by their keys and coalesce probabilities across generations.

    Rows are stored four levels deep, using the keys produced by ``Keys``::

        actual_dict[actual_key][dot_cost_key][replica_key][model_key] = row

    ``coalesce_generations`` and ``coalesce_all`` then add a
    ``"coalesced_prob"`` entry to every row they visit, in place.
    """

    def __init__(self, rows, num_replicas: int):
        """Build the nested lookup from ``rows``.

        Args:
            rows: Iterable of rows understood by ``Keys``. Each row must also
                support item assignment, because coalescing writes
                ``row["coalesced_prob"]``. If two rows share all four keys,
                the later one replaces the earlier one.
            num_replicas: Number of generations to coalesce. Generations are
                looked up under the string keys ``"0"``, ``"1"``, ...,
                ``str(num_replicas - 1)``, so ``generation_replica_index``
                must hold exactly those strings.
        """
        self.rows = rows
        self.num_replicas = num_replicas

        # sort into actuals, then dots, then replicas, then models
        self.actual_dict: dict = {}
        for row in self.rows:
            keys = Keys(row).all_keys()
            _logger.debug(keys)
            actual_key, dot_key, replica_key, model_key = keys

            # Logging arguments are passed lazily (%s) so that nothing is
            # formatted unless DEBUG is enabled; formatting the growing
            # innermost dict for every row would otherwise dominate runtime.
            if actual_key not in self.actual_dict:
                _logger.debug("Creating layer 0 for %s", actual_key)
                self.actual_dict[actual_key] = {}
            by_dot = self.actual_dict[actual_key]

            if dot_key not in by_dot:
                _logger.debug("Creating layer 1 for %s, %s", actual_key, dot_key)
                by_dot[dot_key] = {}
            by_replica = by_dot[dot_key]

            if replica_key not in by_replica:
                _logger.debug(
                    "Creating layer 2 for %s, %s, %s",
                    actual_key,
                    dot_key,
                    replica_key,
                )
                by_replica[replica_key] = {}
            by_model = by_replica[replica_key]

            _logger.debug("Adding to %s for %s", by_model, model_key)
            by_model[model_key] = row

    def coalesce_generations(self, actual_key, dot_key):
        """Fill in ``"coalesced_prob"`` for one ``(actual_key, dot_key)`` group.

        Generation ``"0"`` copies ``"prob"`` unchanged (no conversion, so it
        keeps whatever type the row holds). For each later generation ``g``::

            coalesced_prob[g][m] = prob[g][m] * coalesced_prob[g - 1][m] / W
            W = sum over m of prob[g][m] * coalesced_prob[g - 1][m]

        where ``m`` runs over the model keys of generation ``g``; these values
        are stored as ``float``. Rows are modified in place and nothing is
        returned.

        Raises:
            KeyError: If the group, an expected generation, or a model key in
                the previous generation is missing.
            ZeroDivisionError: If the weight ``W`` of a non-empty generation
                is zero.
        """
        _logger.debug(self.actual_dict.keys())
        _logger.debug(self.actual_dict[actual_key].keys())
        subdict = self.actual_dict[actual_key][dot_key]
        _logger.debug("subdict keys: %s", subdict.keys())

        _logger.info("Going through generation %s", 0)
        # 0th gen is easiest
        for val in subdict["0"].values():
            val["coalesced_prob"] = val["prob"]

        # range() is empty when num_replicas <= 1, so no extra guard is needed.
        for gen in range(1, self.num_replicas):
            _logger.info("Going through generation %s", gen)
            previous = subdict[str(gen - 1)]
            current = subdict[str(gen)]

            generation_weight = sum(
                float(previous[key]["coalesced_prob"]) * float(val["prob"])
                for key, val in current.items()
            )
            _logger.debug(generation_weight)

            try:
                for model_key, val in current.items():
                    val["coalesced_prob"] = (
                        float(val["prob"])
                        * float(previous[model_key]["coalesced_prob"])
                        / generation_weight
                    )
            except ZeroDivisionError as err:
                # Same exception type as before, but say which group failed.
                raise ZeroDivisionError(
                    f"generation weight is zero for generation {gen} "
                    f"(actual_key={actual_key!r}, dot_key={dot_key!r}); "
                    "cannot normalise coalesced probabilities"
                ) from err

    def coalesce_all(self):
        """Coalesce every ``(actual_key, dot_key)`` group.

        Returns:
            ``self.actual_dict``, whose rows now carry ``"coalesced_prob"``.
        """
        for actual_key, by_dot in self.actual_dict.items():
            for dot_key in by_dot:
                self.coalesce_generations(actual_key, dot_key)
        return self.actual_dict