From 42dddcae0214fc576120531e3a15a7f9995dd126 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 23 Feb 2025 17:06:17 -0600 Subject: [PATCH] feat: adds new type for measurement groups to facilitate future refactors, technically breaking but not for what public interface should be --- kalpaa/read_bin_csv.py | 120 ++++++++++++++++-------- tests/read_bin_csv/test_read_bin_csv.py | 2 +- 2 files changed, 82 insertions(+), 40 deletions(-) diff --git a/kalpaa/read_bin_csv.py b/kalpaa/read_bin_csv.py index ed766a4..e41b2a4 100644 --- a/kalpaa/read_bin_csv.py +++ b/kalpaa/read_bin_csv.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +# useful for measurementgroup which is a type that has itself in method signatures, avoids having to manually specify the typehint as a string + import re import numpy import dataclasses @@ -29,13 +33,42 @@ class Measurement: stdev: float +@dataclasses.dataclass +class MeasurementGroup: + measurements: typing.Sequence[Measurement] + + def add(self, other: MeasurementGroup) -> MeasurementGroup: + + # this is probably not conformant to the ideal contract for typing.Sequence + new_measurements = [*self.measurements, *other.measurements] + + return MeasurementGroup(new_measurements) + + class CostFunction: - def __init__(self, measurement_type, dot_inputs_array, actual_measurement_array): + def __init__( + self, + measurement_type, + dot_inputs_array, + actual_measurement_array, + use_pair_measurement: bool = False, + ): + """ + Construct a cost function that uses the measurements. + + :param measurement_type: The type of measurement we're using. + :param dot_inputs_array: The array of dot inputs. + :param actual_measurement_array: The actual measurements. + :param use_pair_measurement: Whether to use pair measurements. (default false) + """ _logger.info(f"Cost function with measurement type of {measurement_type}") self.measurement_type = measurement_type self.dot_inputs_array = dot_inputs_array self.actual_measurement_array = actual_measurement_array self.actual_measurement_array2 = actual_measurement_array**2 + self.use_pair_measurement = use_pair_measurement + if self.use_pair_measurement: + raise NotImplementedError("Pair measurements are not yet supported") def __call__(self, dipoles_to_test): if self.measurement_type == X_ELECTRIC_FIELD: @@ -60,7 +93,18 @@ class StDevUsingCostFunction: actual_measurement_array, actual_stdev_array, log_noise: bool = False, + use_pair_measurement: bool = False, ): + """ + Construct a cost function that uses the standard deviation of the measurements. + + :param measurement_type: The type of measurement we're using. + :param dot_inputs_array: The array of dot inputs. (may be actually inputses for pair measurements) + :param actual_measurement_array: The actual measurements. + :param actual_stdev_array: The actual standard deviations. + :param use_pair_measurement: Whether to use pair measurements. (default false) + :param log_noise: Whether to use log noise. (default false but we should probably use it) + """ _logger.info(f"Cost function with measurement type of {measurement_type}") self.measurement_type = measurement_type self.dot_inputs_array = dot_inputs_array @@ -75,6 +119,7 @@ class StDevUsingCostFunction: numpy.log(self.actual_stdev_array + self.actual_measurement_array) - numpy.log(self.actual_measurement_array) ) ** 2 + self.use_pair_measurement = use_pair_measurement def __call__(self, dipoles_to_test): if self.measurement_type == X_ELECTRIC_FIELD: @@ -130,7 +175,11 @@ class ParsedBinHeader: dot_name: str # only used for pair measurements dot_name2: typing.Optional[str] = None - cpsd_type: typing.Optional[str] = None + cpsd_type: typing.Optional[typing.Literal["correlation", "phase"]] = None + + @property + def pair(self) -> bool: + return self.dot_name2 is not None def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]: @@ -151,12 +200,15 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]: pair_match := re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field) ) is not None: groups = pair_match.groupdict() + cpsd_type = typing.cast( + typing.Literal["correlation", "phase"], groups["cpsd_type"] + ) return ParsedBinHeader( original_field=field, measurement_type=groups["measurement_type"], dot_name=groups["dot_name"], dot_name2=groups["dot_name2"], - cpsd_type=groups["cpsd_type"], + cpsd_type=cpsd_type, summary_stat=groups["summary_stat"], ) else: @@ -247,7 +299,7 @@ class BinnedData: measurement_type: str # we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script. - def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]: + def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup: if dot_name not in self.dots_dict: raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}") if dot_name not in self.csv_dict: @@ -258,13 +310,15 @@ class BinnedData: vs = self.csv_dict[dot_name]["mean"] stdevs = self.csv_dict[dot_name]["stdev"] - return [ - Measurement( - dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r), - stdev=stdev, - ) - for f, v, stdev in zip(freqs, vs, stdevs) - ] + return MeasurementGroup( + [ + Measurement( + dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r), + stdev=stdev, + ) + for f, v, stdev in zip(freqs, vs, stdevs) + ] + ) def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]: if dot_name not in self.dots_dict: @@ -276,22 +330,21 @@ class BinnedData: return stdevs - def measurements( - self, dot_names: typing.Sequence[str] - ) -> typing.Sequence[Measurement]: + def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup: _logger.debug(f"Constructing measurements for dots {dot_names=}") - ret: typing.List[Measurement] = [] + ret = MeasurementGroup([]) + _logger.debug for dot_name in dot_names: - ret.extend(self._dot_to_measurement(dot_name)) + ret = ret.add(self._dot_to_measurements(dot_name)) return ret - def _cost_function(self, measurements: typing.Sequence[Measurement]): - dot_measurements = [m.dot_measurement for m in measurements] + def _cost_function(self, mg: MeasurementGroup): + dot_measurements = [m.dot_measurement for m in mg.measurements] meas_array = numpy.array([m.v for m in dot_measurements]) _logger.debug(f"Obtained {meas_array=}") - inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements] + inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements] input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs) _logger.debug(f"Obtained {input_array=}") @@ -299,20 +352,24 @@ class BinnedData: def _stdev_cost_function( self, - measurements: typing.Sequence[Measurement], + mg: MeasurementGroup, use_log_noise: bool = False, ): - meas_array = numpy.array([m.dot_measurement.v for m in measurements]) - stdev_array = numpy.array([m.stdev for m in measurements]) + meas_array = numpy.array([m.dot_measurement.v for m in mg.measurements]) + stdev_array = numpy.array([m.stdev for m in mg.measurements]) _logger.debug(f"Obtained {meas_array=}") - inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements] + inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements] input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs) _logger.debug(f"Obtained {input_array=}") return StDevUsingCostFunction( - self.measurement_type, input_array, meas_array, stdev_array, use_log_noise + self.measurement_type, + input_array, + meas_array, + stdev_array, + log_noise=use_log_noise, ) def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float): @@ -341,18 +398,3 @@ def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> Bin return BinnedData( measurement_type=measurement_type, dots_dict=dots, csv_dict=binned ) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - - print(read_dots_json(pathlib.Path("dots.json"))) - # print(read_bin_csv(pathlib.Path("binned-0.01-10000-50-12345.csv"))) - binned_data = read_dots_and_binned( - pathlib.Path("dots.json"), pathlib.Path("binned-0.01-10000-50-12345.csv") - ) - _logger.info(binned_data) - for entry in binned_data.measurements(["uprise1", "dot1"]): - _logger.info(entry) - filter = binned_data.cost_function_filter(["uprise1", "dot1"], 0.5) - _logger.info(filter) diff --git a/tests/read_bin_csv/test_read_bin_csv.py b/tests/read_bin_csv/test_read_bin_csv.py index 641c490..ff74859 100644 --- a/tests/read_bin_csv/test_read_bin_csv.py +++ b/tests/read_bin_csv/test_read_bin_csv.py @@ -93,7 +93,7 @@ def test_binned_data_dot_measurement_costs(snapshot): binned_ex = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, ex_csv_file) measurements_ex = binned_ex.measurements(["dot1"]) - # _logger.warning(measurements) + _logger.warning(measurements_ex) v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True) ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True)