refactor: adds ability to hold pair measurements and abstract that out
This commit is contained in:
parent
e83706b2b1
commit
785746049a
@ -40,14 +40,31 @@ def short_string_to_measurement_type(short_string: str) -> MeasurementTypeEnum:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Measurement:
|
||||
dot_measurement: pdme.measurement.DotMeasurement
|
||||
dot_measurement: typing.Optional[pdme.measurement.DotMeasurement]
|
||||
stdev: float
|
||||
dot_pair_measurement: typing.Optional[pdme.measurement.DotPairMeasurement] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MeasurementGroup:
|
||||
_measurements: typing.Sequence[Measurement]
|
||||
_measurement_type: MeasurementTypeEnum
|
||||
_using_pairs: bool = dataclasses.field(init=False, default=False)
|
||||
|
||||
def validate(self):
|
||||
if not self._measurements:
|
||||
raise ValueError("Cannot have an empty measurement group")
|
||||
using_pairs = any(
|
||||
m.dot_pair_measurement is not None for m in self._measurements
|
||||
)
|
||||
using_singles = any(m.dot_measurement is not None for m in self._measurements)
|
||||
if using_pairs and using_singles:
|
||||
raise ValueError(
|
||||
"Cannot mix single and pair measurements in a single measurement group"
|
||||
)
|
||||
if not using_pairs and not using_singles:
|
||||
raise ValueError("Cannot have a measurement group with no measurements")
|
||||
self._using_pairs = using_pairs
|
||||
|
||||
def add(self, other: MeasurementGroup) -> MeasurementGroup:
|
||||
|
||||
@ -62,17 +79,50 @@ class MeasurementGroup:
|
||||
return MeasurementGroup(new_measurements, self._measurement_type)
|
||||
|
||||
def _meas_array(self) -> numpy.ndarray:
|
||||
return numpy.array([m.dot_measurement.v for m in self._measurements])
|
||||
if self._using_pairs:
|
||||
return numpy.array(
|
||||
[
|
||||
m.dot_pair_measurement.v
|
||||
for m in self._measurements
|
||||
if m.dot_pair_measurement is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
return numpy.array(
|
||||
[
|
||||
m.dot_measurement.v
|
||||
for m in self._measurements
|
||||
if m.dot_measurement is not None
|
||||
]
|
||||
)
|
||||
|
||||
def _input_array(self) -> numpy.ndarray:
|
||||
return pdme.measurement.input_types.dot_inputs_to_array(
|
||||
[(m.dot_measurement.r, m.dot_measurement.f) for m in self._measurements]
|
||||
)
|
||||
if self._using_pairs:
|
||||
return pdme.measurement.input_types.dot_pair_inputs_to_array(
|
||||
[
|
||||
(
|
||||
m.dot_pair_measurement.r1,
|
||||
m.dot_pair_measurement.r2,
|
||||
m.dot_pair_measurement.f,
|
||||
)
|
||||
for m in self._measurements
|
||||
if m.dot_pair_measurement is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
return pdme.measurement.input_types.dot_inputs_to_array(
|
||||
[
|
||||
(m.dot_measurement.r, m.dot_measurement.f)
|
||||
for m in self._measurements
|
||||
if m.dot_measurement is not None
|
||||
]
|
||||
)
|
||||
|
||||
def _stdev_array(self) -> numpy.ndarray:
|
||||
return numpy.array([m.stdev for m in self._measurements])
|
||||
|
||||
def cost_function(self):
|
||||
self.validate()
|
||||
meas_array = self._meas_array()
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
@ -86,6 +136,7 @@ class MeasurementGroup:
|
||||
self,
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
self.validate()
|
||||
stdev_array = self._stdev_array()
|
||||
|
||||
meas_array = self._meas_array()
|
||||
@ -279,9 +330,15 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||
return None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CSV_BinnedData:
|
||||
measurement_type: MeasurementTypeEnum
|
||||
single_dot_dict: typing.Dict[str, typing.Any]
|
||||
|
||||
|
||||
def read_bin_csv(
|
||||
csv_file: pathlib.Path,
|
||||
) -> typing.Tuple[MeasurementTypeEnum, typing.Dict[str, typing.Any]]:
|
||||
) -> CSV_BinnedData:
|
||||
"""
|
||||
Read a binned csv file and return the measurement type and the binned data.
|
||||
|
||||
@ -353,7 +410,10 @@ def read_bin_csv(
|
||||
raise ValueError(
|
||||
f"For some reason {measurement_type=} is None? We want to know our measurement type."
|
||||
)
|
||||
return measurement_type, aggregated_dict
|
||||
|
||||
return CSV_BinnedData(
|
||||
measurement_type=measurement_type, single_dot_dict=aggregated_dict
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.error(
|
||||
f"Had a bad time reading the binned data {csv_file}, sorry.", exc_info=e
|
||||
@ -462,7 +522,9 @@ class BinnedData:
|
||||
|
||||
def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> BinnedData:
|
||||
dots = read_dots_json(json_file)
|
||||
measurement_type, binned = read_bin_csv(csv_file)
|
||||
csv_data = read_bin_csv(csv_file)
|
||||
return BinnedData(
|
||||
measurement_type=measurement_type, dots_dict=dots, csv_dict=binned
|
||||
measurement_type=csv_data.measurement_type,
|
||||
dots_dict=dots,
|
||||
csv_dict=csv_data.single_dot_dict,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user