feat: adds new type for measurement groups to facilitate future refactors, technically breaking but not for what public interface should be
All checks were successful
gitea-physics/kalpa/pipeline/head This commit looks good

This commit is contained in:
Deepak Mallubhotla 2025-02-23 17:06:17 -06:00
parent 1e840a8c32
commit 42dddcae02
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
2 changed files with 82 additions and 40 deletions

View File

@ -1,3 +1,7 @@
from __future__ import annotations
# useful for measurementgroup which is a type that has itself in method signatures, avoids having to manually specify the typehint as a string
import re import re
import numpy import numpy
import dataclasses import dataclasses
@ -29,13 +33,42 @@ class Measurement:
stdev: float stdev: float
@dataclasses.dataclass
class MeasurementGroup:
measurements: typing.Sequence[Measurement]
def add(self, other: MeasurementGroup) -> MeasurementGroup:
# this is probably not conformant to the ideal contract for typing.Sequence
new_measurements = [*self.measurements, *other.measurements]
return MeasurementGroup(new_measurements)
class CostFunction: class CostFunction:
def __init__(self, measurement_type, dot_inputs_array, actual_measurement_array): def __init__(
self,
measurement_type,
dot_inputs_array,
actual_measurement_array,
use_pair_measurement: bool = False,
):
"""
Construct a cost function that uses the measurements.
:param measurement_type: The type of measurement we're using.
:param dot_inputs_array: The array of dot inputs.
:param actual_measurement_array: The actual measurements.
:param use_pair_measurement: Whether to use pair measurements. (default false)
"""
_logger.info(f"Cost function with measurement type of {measurement_type}") _logger.info(f"Cost function with measurement type of {measurement_type}")
self.measurement_type = measurement_type self.measurement_type = measurement_type
self.dot_inputs_array = dot_inputs_array self.dot_inputs_array = dot_inputs_array
self.actual_measurement_array = actual_measurement_array self.actual_measurement_array = actual_measurement_array
self.actual_measurement_array2 = actual_measurement_array**2 self.actual_measurement_array2 = actual_measurement_array**2
self.use_pair_measurement = use_pair_measurement
if self.use_pair_measurement:
raise NotImplementedError("Pair measurements are not yet supported")
def __call__(self, dipoles_to_test): def __call__(self, dipoles_to_test):
if self.measurement_type == X_ELECTRIC_FIELD: if self.measurement_type == X_ELECTRIC_FIELD:
@ -60,7 +93,18 @@ class StDevUsingCostFunction:
actual_measurement_array, actual_measurement_array,
actual_stdev_array, actual_stdev_array,
log_noise: bool = False, log_noise: bool = False,
use_pair_measurement: bool = False,
): ):
"""
Construct a cost function that uses the standard deviation of the measurements.
:param measurement_type: The type of measurement we're using.
:param dot_inputs_array: The array of dot inputs. (may be actually inputses for pair measurements)
:param actual_measurement_array: The actual measurements.
:param actual_stdev_array: The actual standard deviations.
:param use_pair_measurement: Whether to use pair measurements. (default false)
:param log_noise: Whether to use log noise. (default false but we should probably use it)
"""
_logger.info(f"Cost function with measurement type of {measurement_type}") _logger.info(f"Cost function with measurement type of {measurement_type}")
self.measurement_type = measurement_type self.measurement_type = measurement_type
self.dot_inputs_array = dot_inputs_array self.dot_inputs_array = dot_inputs_array
@ -75,6 +119,7 @@ class StDevUsingCostFunction:
numpy.log(self.actual_stdev_array + self.actual_measurement_array) numpy.log(self.actual_stdev_array + self.actual_measurement_array)
- numpy.log(self.actual_measurement_array) - numpy.log(self.actual_measurement_array)
) ** 2 ) ** 2
self.use_pair_measurement = use_pair_measurement
def __call__(self, dipoles_to_test): def __call__(self, dipoles_to_test):
if self.measurement_type == X_ELECTRIC_FIELD: if self.measurement_type == X_ELECTRIC_FIELD:
@ -130,7 +175,11 @@ class ParsedBinHeader:
dot_name: str dot_name: str
# only used for pair measurements # only used for pair measurements
dot_name2: typing.Optional[str] = None dot_name2: typing.Optional[str] = None
cpsd_type: typing.Optional[str] = None cpsd_type: typing.Optional[typing.Literal["correlation", "phase"]] = None
@property
def pair(self) -> bool:
return self.dot_name2 is not None
def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]: def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
@ -151,12 +200,15 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
pair_match := re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field) pair_match := re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field)
) is not None: ) is not None:
groups = pair_match.groupdict() groups = pair_match.groupdict()
cpsd_type = typing.cast(
typing.Literal["correlation", "phase"], groups["cpsd_type"]
)
return ParsedBinHeader( return ParsedBinHeader(
original_field=field, original_field=field,
measurement_type=groups["measurement_type"], measurement_type=groups["measurement_type"],
dot_name=groups["dot_name"], dot_name=groups["dot_name"],
dot_name2=groups["dot_name2"], dot_name2=groups["dot_name2"],
cpsd_type=groups["cpsd_type"], cpsd_type=cpsd_type,
summary_stat=groups["summary_stat"], summary_stat=groups["summary_stat"],
) )
else: else:
@ -247,7 +299,7 @@ class BinnedData:
measurement_type: str measurement_type: str
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script. # we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]: def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup:
if dot_name not in self.dots_dict: if dot_name not in self.dots_dict:
raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}") raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
if dot_name not in self.csv_dict: if dot_name not in self.csv_dict:
@ -258,13 +310,15 @@ class BinnedData:
vs = self.csv_dict[dot_name]["mean"] vs = self.csv_dict[dot_name]["mean"]
stdevs = self.csv_dict[dot_name]["stdev"] stdevs = self.csv_dict[dot_name]["stdev"]
return [ return MeasurementGroup(
[
Measurement( Measurement(
dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r), dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
stdev=stdev, stdev=stdev,
) )
for f, v, stdev in zip(freqs, vs, stdevs) for f, v, stdev in zip(freqs, vs, stdevs)
] ]
)
def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]: def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
if dot_name not in self.dots_dict: if dot_name not in self.dots_dict:
@ -276,22 +330,21 @@ class BinnedData:
return stdevs return stdevs
def measurements( def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup:
self, dot_names: typing.Sequence[str]
) -> typing.Sequence[Measurement]:
_logger.debug(f"Constructing measurements for dots {dot_names=}") _logger.debug(f"Constructing measurements for dots {dot_names=}")
ret: typing.List[Measurement] = [] ret = MeasurementGroup([])
_logger.debug
for dot_name in dot_names: for dot_name in dot_names:
ret.extend(self._dot_to_measurement(dot_name)) ret = ret.add(self._dot_to_measurements(dot_name))
return ret return ret
def _cost_function(self, measurements: typing.Sequence[Measurement]): def _cost_function(self, mg: MeasurementGroup):
dot_measurements = [m.dot_measurement for m in measurements] dot_measurements = [m.dot_measurement for m in mg.measurements]
meas_array = numpy.array([m.v for m in dot_measurements]) meas_array = numpy.array([m.v for m in dot_measurements])
_logger.debug(f"Obtained {meas_array=}") _logger.debug(f"Obtained {meas_array=}")
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements] inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs) input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
_logger.debug(f"Obtained {input_array=}") _logger.debug(f"Obtained {input_array=}")
@ -299,20 +352,24 @@ class BinnedData:
def _stdev_cost_function( def _stdev_cost_function(
self, self,
measurements: typing.Sequence[Measurement], mg: MeasurementGroup,
use_log_noise: bool = False, use_log_noise: bool = False,
): ):
meas_array = numpy.array([m.dot_measurement.v for m in measurements]) meas_array = numpy.array([m.dot_measurement.v for m in mg.measurements])
stdev_array = numpy.array([m.stdev for m in measurements]) stdev_array = numpy.array([m.stdev for m in mg.measurements])
_logger.debug(f"Obtained {meas_array=}") _logger.debug(f"Obtained {meas_array=}")
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements] inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs) input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
_logger.debug(f"Obtained {input_array=}") _logger.debug(f"Obtained {input_array=}")
return StDevUsingCostFunction( return StDevUsingCostFunction(
self.measurement_type, input_array, meas_array, stdev_array, use_log_noise self.measurement_type,
input_array,
meas_array,
stdev_array,
log_noise=use_log_noise,
) )
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float): def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
@ -341,18 +398,3 @@ def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> Bin
return BinnedData( return BinnedData(
measurement_type=measurement_type, dots_dict=dots, csv_dict=binned measurement_type=measurement_type, dots_dict=dots, csv_dict=binned
) )
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
print(read_dots_json(pathlib.Path("dots.json")))
# print(read_bin_csv(pathlib.Path("binned-0.01-10000-50-12345.csv")))
binned_data = read_dots_and_binned(
pathlib.Path("dots.json"), pathlib.Path("binned-0.01-10000-50-12345.csv")
)
_logger.info(binned_data)
for entry in binned_data.measurements(["uprise1", "dot1"]):
_logger.info(entry)
filter = binned_data.cost_function_filter(["uprise1", "dot1"], 0.5)
_logger.info(filter)

View File

@ -93,7 +93,7 @@ def test_binned_data_dot_measurement_costs(snapshot):
binned_ex = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, ex_csv_file) binned_ex = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, ex_csv_file)
measurements_ex = binned_ex.measurements(["dot1"]) measurements_ex = binned_ex.measurements(["dot1"])
# _logger.warning(measurements) _logger.warning(measurements_ex)
v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True) v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True)
ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True) ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True)