refactor: adds the ability to hold pair measurements and abstracts that out

This commit is contained in:
Deepak Mallubhotla 2025-02-23 18:35:53 -06:00
parent e83706b2b1
commit 785746049a
Signed by: deepak
GPG Key ID: BEBAEBF28083E022

View File

@ -40,14 +40,31 @@ def short_string_to_measurement_type(short_string: str) -> MeasurementTypeEnum:
@dataclasses.dataclass
class Measurement:
    """
    A single measured value together with its standard deviation.

    A measurement is intended to carry either a single-dot measurement or a
    dot-pair measurement; the slot that is not used is left as None.
    """

    # Single-dot measurement, or None when this entry holds a pair measurement.
    dot_measurement: typing.Optional[pdme.measurement.DotMeasurement]
    # Standard deviation associated with the measured value.
    stdev: float
    # Dot-pair measurement; defaults to None so single-dot callers need not pass it.
    dot_pair_measurement: typing.Optional[pdme.measurement.DotPairMeasurement] = None
@dataclasses.dataclass
class MeasurementGroup:
_measurements: typing.Sequence[Measurement]
_measurement_type: MeasurementTypeEnum
_using_pairs: bool = dataclasses.field(init=False, default=False)
def validate(self):
    """
    Check that the group is usable and record whether it holds pair measurements.

    A valid group is non-empty and homogeneous: every measurement carries a
    single-dot value, or every measurement carries a dot-pair value. On
    success ``self._using_pairs`` is set accordingly.

    Raises:
        ValueError: if the group is empty, mixes single and pair
            measurements, contains no measured values at all, or contains
            a measurement that lacks a value of the group's kind.
    """
    if not self._measurements:
        raise ValueError("Cannot have an empty measurement group")

    pair_count = sum(
        1 for m in self._measurements if m.dot_pair_measurement is not None
    )
    single_count = sum(
        1 for m in self._measurements if m.dot_measurement is not None
    )

    if pair_count and single_count:
        raise ValueError(
            "Cannot mix single and pair measurements in a single measurement group"
        )
    if not pair_count and not single_count:
        raise ValueError("Cannot have a measurement group with no measurements")

    # Exactly one of the counts is non-zero here. If it does not cover the
    # whole group, some entries hold neither kind; the value/input arrays
    # would drop them while the stdev array keeps them, so reject up front.
    if max(pair_count, single_count) != len(self._measurements):
        raise ValueError(
            "Every measurement in a measurement group must carry a value of the group's kind"
        )

    self._using_pairs = pair_count > 0
def add(self, other: MeasurementGroup) -> MeasurementGroup:
@ -62,17 +79,50 @@ class MeasurementGroup:
return MeasurementGroup(new_measurements, self._measurement_type)
def _meas_array(self) -> numpy.ndarray:
return numpy.array([m.dot_measurement.v for m in self._measurements])
if self._using_pairs:
return numpy.array(
[
m.dot_pair_measurement.v
for m in self._measurements
if m.dot_pair_measurement is not None
]
)
else:
return numpy.array(
[
m.dot_measurement.v
for m in self._measurements
if m.dot_measurement is not None
]
)
def _input_array(self) -> numpy.ndarray:
    """
    Build the array of measurement inputs for the group.

    For pair groups (``self._using_pairs`` true, as set by ``validate``) each
    input is the tuple ``(r1, r2, f)`` and is converted with
    ``dot_pair_inputs_to_array``; otherwise each input is ``(r, f)`` and is
    converted with ``dot_inputs_to_array``. Entries whose relevant
    measurement is None are skipped.
    """
    if self._using_pairs:
        return pdme.measurement.input_types.dot_pair_inputs_to_array(
            [
                (
                    m.dot_pair_measurement.r1,
                    m.dot_pair_measurement.r2,
                    m.dot_pair_measurement.f,
                )
                for m in self._measurements
                # The None check also narrows the Optional for type checkers.
                if m.dot_pair_measurement is not None
            ]
        )
    return pdme.measurement.input_types.dot_inputs_to_array(
        [
            (m.dot_measurement.r, m.dot_measurement.f)
            for m in self._measurements
            if m.dot_measurement is not None
        ]
    )
def _stdev_array(self) -> numpy.ndarray:
return numpy.array([m.stdev for m in self._measurements])
def cost_function(self):
self.validate()
meas_array = self._meas_array()
_logger.debug(f"Obtained {meas_array=}")
@ -86,6 +136,7 @@ class MeasurementGroup:
self,
use_log_noise: bool = False,
):
self.validate()
stdev_array = self._stdev_array()
meas_array = self._meas_array()
@ -279,9 +330,15 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
return None
@dataclasses.dataclass
class CSV_BinnedData:
    """Result of reading a binned csv file: the measurement type plus the binned data."""

    # Measurement type determined while reading the csv file.
    measurement_type: MeasurementTypeEnum

    # Aggregated binned data read from the csv file.
    single_dot_dict: typing.Dict[str, typing.Any]
def read_bin_csv(
csv_file: pathlib.Path,
) -> typing.Tuple[MeasurementTypeEnum, typing.Dict[str, typing.Any]]:
) -> CSV_BinnedData:
"""
Read a binned csv file and return the measurement type and the binned data.
@ -353,7 +410,10 @@ def read_bin_csv(
raise ValueError(
f"For some reason {measurement_type=} is None? We want to know our measurement type."
)
return measurement_type, aggregated_dict
return CSV_BinnedData(
measurement_type=measurement_type, single_dot_dict=aggregated_dict
)
except Exception as e:
_logger.error(
f"Had a bad time reading the binned data {csv_file}, sorry.", exc_info=e
@ -462,7 +522,9 @@ class BinnedData:
def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> BinnedData:
    """
    Read the dots json file and the binned csv file and bundle them into a BinnedData.

    Args:
        json_file: path handed to ``read_dots_json``.
        csv_file: path handed to ``read_bin_csv``.

    Returns:
        A BinnedData built from the dots, the csv measurement type and the
        csv binned data.
    """
    dots = read_dots_json(json_file)
    # read_bin_csv returns a CSV_BinnedData object, not a tuple, so its
    # fields are accessed by name instead of being unpacked.
    csv_data = read_bin_csv(csv_file)
    return BinnedData(
        measurement_type=csv_data.measurement_type,
        dots_dict=dots,
        csv_dict=csv_data.single_dot_dict,
    )