feat: add a MeasurementGroup type for groups of measurements to facilitate future refactors; technically a breaking change, but not for what the public interface should be
All checks were successful
gitea-physics/kalpa/pipeline/head This commit looks good
All checks were successful
gitea-physics/kalpa/pipeline/head This commit looks good
This commit is contained in:
parent
1e840a8c32
commit
42dddcae02
@ -1,3 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Needed for MeasurementGroup, whose method signatures refer to the class itself: postponed annotation evaluation avoids having to write those type hints as strings.
|
||||
|
||||
import re
|
||||
import numpy
|
||||
import dataclasses
|
||||
@ -29,13 +33,42 @@ class Measurement:
|
||||
stdev: float
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MeasurementGroup:
    """
    A collection of measurements that is passed around as a single unit.

    Iterating over a group, or taking its length, delegates to the wrapped
    measurements, so callers written against a plain sequence of measurements
    keep working when handed a group.

    :param measurements: The measurements that make up this group.
    """

    measurements: typing.Sequence[Measurement]

    def __iter__(self) -> typing.Iterator[Measurement]:
        """Iterate over the wrapped measurements in order."""
        return iter(self.measurements)

    def __len__(self) -> int:
        """Return the number of wrapped measurements."""
        return len(self.measurements)

    def add(self, other: MeasurementGroup) -> MeasurementGroup:
        """
        Combine two groups into a new one.

        Neither this group nor other is modified.

        :param other: The group whose measurements are appended after this group's.
        :return: A new MeasurementGroup holding this group's measurements followed by other's.
        """
        # this is probably not conformant to the ideal contract for typing.Sequence:
        # the combined group always stores a list, whatever sequence types the operands held
        new_measurements = [*self.measurements, *other.measurements]
        return MeasurementGroup(new_measurements)
|
||||
|
||||
|
||||
class CostFunction:
|
||||
def __init__(
    self,
    measurement_type,
    dot_inputs_array,
    actual_measurement_array,
    use_pair_measurement: bool = False,
):
    """
    Build a cost function from a set of measurements.

    :param measurement_type: Which kind of measurement is being used.
    :param dot_inputs_array: The array of dot inputs.
    :param actual_measurement_array: The measured values; its square (``**2``) is stored alongside it.
    :param use_pair_measurement: Whether pair measurements are requested. (default false)
    :raises NotImplementedError: If use_pair_measurement is true; pair measurements are not supported here yet.
    """
    _logger.info(f"Cost function with measurement type of {measurement_type}")

    self.measurement_type = measurement_type
    self.dot_inputs_array = dot_inputs_array
    self.actual_measurement_array = actual_measurement_array

    # square the measurements once, at construction time
    squared_measurements = actual_measurement_array**2
    self.actual_measurement_array2 = squared_measurements

    self.use_pair_measurement = use_pair_measurement
    if not use_pair_measurement:
        return
    raise NotImplementedError("Pair measurements are not yet supported")
|
||||
|
||||
def __call__(self, dipoles_to_test):
|
||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||
@ -60,7 +93,18 @@ class StDevUsingCostFunction:
|
||||
actual_measurement_array,
|
||||
actual_stdev_array,
|
||||
log_noise: bool = False,
|
||||
use_pair_measurement: bool = False,
|
||||
):
|
||||
"""
|
||||
Construct a cost function that uses the standard deviation of the measurements.
|
||||
|
||||
:param measurement_type: The type of measurement we're using.
|
||||
:param dot_inputs_array: The array of dot inputs. (may be actually inputses for pair measurements)
|
||||
:param actual_measurement_array: The actual measurements.
|
||||
:param actual_stdev_array: The actual standard deviations.
|
||||
:param use_pair_measurement: Whether to use pair measurements. (default false)
|
||||
:param log_noise: Whether to use log noise. (default false but we should probably use it)
|
||||
"""
|
||||
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
||||
self.measurement_type = measurement_type
|
||||
self.dot_inputs_array = dot_inputs_array
|
||||
@ -75,6 +119,7 @@ class StDevUsingCostFunction:
|
||||
numpy.log(self.actual_stdev_array + self.actual_measurement_array)
|
||||
- numpy.log(self.actual_measurement_array)
|
||||
) ** 2
|
||||
self.use_pair_measurement = use_pair_measurement
|
||||
|
||||
def __call__(self, dipoles_to_test):
|
||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||
@ -130,7 +175,11 @@ class ParsedBinHeader:
|
||||
dot_name: str
|
||||
# only used for pair measurements
|
||||
dot_name2: typing.Optional[str] = None
|
||||
cpsd_type: typing.Optional[str] = None
|
||||
cpsd_type: typing.Optional[typing.Literal["correlation", "phase"]] = None
|
||||
|
||||
@property
def pair(self) -> bool:
    """Whether this header has a second dot name set (dot_name2 is not None)."""
    has_second_dot = self.dot_name2 is not None
    return has_second_dot
|
||||
|
||||
|
||||
def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||
@ -151,12 +200,15 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||
pair_match := re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field)
|
||||
) is not None:
|
||||
groups = pair_match.groupdict()
|
||||
cpsd_type = typing.cast(
|
||||
typing.Literal["correlation", "phase"], groups["cpsd_type"]
|
||||
)
|
||||
return ParsedBinHeader(
|
||||
original_field=field,
|
||||
measurement_type=groups["measurement_type"],
|
||||
dot_name=groups["dot_name"],
|
||||
dot_name2=groups["dot_name2"],
|
||||
cpsd_type=groups["cpsd_type"],
|
||||
cpsd_type=cpsd_type,
|
||||
summary_stat=groups["summary_stat"],
|
||||
)
|
||||
else:
|
||||
@ -247,7 +299,7 @@ class BinnedData:
|
||||
measurement_type: str
|
||||
|
||||
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
|
||||
def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]:
|
||||
def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup:
|
||||
if dot_name not in self.dots_dict:
|
||||
raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
|
||||
if dot_name not in self.csv_dict:
|
||||
@ -258,13 +310,15 @@ class BinnedData:
|
||||
vs = self.csv_dict[dot_name]["mean"]
|
||||
stdevs = self.csv_dict[dot_name]["stdev"]
|
||||
|
||||
return [
|
||||
Measurement(
|
||||
dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
|
||||
stdev=stdev,
|
||||
)
|
||||
for f, v, stdev in zip(freqs, vs, stdevs)
|
||||
]
|
||||
return MeasurementGroup(
|
||||
[
|
||||
Measurement(
|
||||
dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
|
||||
stdev=stdev,
|
||||
)
|
||||
for f, v, stdev in zip(freqs, vs, stdevs)
|
||||
]
|
||||
)
|
||||
|
||||
def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
|
||||
if dot_name not in self.dots_dict:
|
||||
@ -276,22 +330,21 @@ class BinnedData:
|
||||
|
||||
return stdevs
|
||||
|
||||
def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup:
    """
    Gather the measurements of every requested dot into a single group.

    :param dot_names: The names of the dots to collect measurements for.
    :return: A MeasurementGroup holding each dot's measurements, concatenated in the order of dot_names
        (an empty group if dot_names is empty).
    :raises KeyError: Propagated from _dot_to_measurements when a dot name is unknown.
    """
    _logger.debug(f"Constructing measurements for dots {dot_names=}")
    ret = MeasurementGroup([])
    for dot_name in dot_names:
        # add() returns a new group rather than mutating, so rebind on every step
        ret = ret.add(self._dot_to_measurements(dot_name))
    return ret
|
||||
|
||||
def _cost_function(self, measurements: typing.Sequence[Measurement]):
|
||||
dot_measurements = [m.dot_measurement for m in measurements]
|
||||
def _cost_function(self, mg: MeasurementGroup):
|
||||
dot_measurements = [m.dot_measurement for m in mg.measurements]
|
||||
meas_array = numpy.array([m.v for m in dot_measurements])
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements]
|
||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
|
||||
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
@ -299,20 +352,24 @@ class BinnedData:
|
||||
|
||||
def _stdev_cost_function(
    self,
    mg: MeasurementGroup,
    use_log_noise: bool = False,
):
    """
    Build a StDevUsingCostFunction from a group of measurements.

    :param mg: The measurement group supplying values, standard deviations and dot inputs.
    :param use_log_noise: Forwarded as the cost function's log_noise flag. (default false)
    :return: A StDevUsingCostFunction for this object's measurement type.
    """
    members = mg.measurements

    meas_array = numpy.array([member.dot_measurement.v for member in members])
    stdev_array = numpy.array([member.stdev for member in members])
    _logger.debug(f"Obtained {meas_array=}")

    # one (position, frequency) pair per measurement, in group order
    inputs = [
        (member.dot_measurement.r, member.dot_measurement.f) for member in members
    ]
    input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
    _logger.debug(f"Obtained {input_array=}")

    cost_function = StDevUsingCostFunction(
        self.measurement_type,
        input_array,
        meas_array,
        stdev_array,
        log_noise=use_log_noise,
    )
    return cost_function
|
||||
|
||||
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
|
||||
@ -341,18 +398,3 @@ def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> Bin
|
||||
return BinnedData(
|
||||
measurement_type=measurement_type, dots_dict=dots, csv_dict=binned
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
print(read_dots_json(pathlib.Path("dots.json")))
|
||||
# print(read_bin_csv(pathlib.Path("binned-0.01-10000-50-12345.csv")))
|
||||
binned_data = read_dots_and_binned(
|
||||
pathlib.Path("dots.json"), pathlib.Path("binned-0.01-10000-50-12345.csv")
|
||||
)
|
||||
_logger.info(binned_data)
|
||||
for entry in binned_data.measurements(["uprise1", "dot1"]):
|
||||
_logger.info(entry)
|
||||
filter = binned_data.cost_function_filter(["uprise1", "dot1"], 0.5)
|
||||
_logger.info(filter)
|
||||
|
@ -93,7 +93,7 @@ def test_binned_data_dot_measurement_costs(snapshot):
|
||||
|
||||
binned_ex = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, ex_csv_file)
|
||||
measurements_ex = binned_ex.measurements(["dot1"])
|
||||
# _logger.warning(measurements)
|
||||
_logger.warning(measurements_ex)
|
||||
|
||||
v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True)
|
||||
ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user