feat: adds new type for measurement groups to facilitate future refactors, technically breaking but not for what public interface should be
All checks were successful
gitea-physics/kalpa/pipeline/head This commit looks good
All checks were successful
gitea-physics/kalpa/pipeline/head This commit looks good
This commit is contained in:
parent
1e840a8c32
commit
42dddcae02
@ -1,3 +1,7 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# useful for measurementgroup which is a type that has itself in method signatures, avoids having to manually specify the typehint as a string
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import numpy
|
import numpy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
@ -29,13 +33,42 @@ class Measurement:
|
|||||||
stdev: float
|
stdev: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class MeasurementGroup:
|
||||||
|
measurements: typing.Sequence[Measurement]
|
||||||
|
|
||||||
|
def add(self, other: MeasurementGroup) -> MeasurementGroup:
|
||||||
|
|
||||||
|
# this is probably not conformant to the ideal contract for typing.Sequence
|
||||||
|
new_measurements = [*self.measurements, *other.measurements]
|
||||||
|
|
||||||
|
return MeasurementGroup(new_measurements)
|
||||||
|
|
||||||
|
|
||||||
class CostFunction:
|
class CostFunction:
|
||||||
def __init__(self, measurement_type, dot_inputs_array, actual_measurement_array):
|
def __init__(
|
||||||
|
self,
|
||||||
|
measurement_type,
|
||||||
|
dot_inputs_array,
|
||||||
|
actual_measurement_array,
|
||||||
|
use_pair_measurement: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Construct a cost function that uses the measurements.
|
||||||
|
|
||||||
|
:param measurement_type: The type of measurement we're using.
|
||||||
|
:param dot_inputs_array: The array of dot inputs.
|
||||||
|
:param actual_measurement_array: The actual measurements.
|
||||||
|
:param use_pair_measurement: Whether to use pair measurements. (default false)
|
||||||
|
"""
|
||||||
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
||||||
self.measurement_type = measurement_type
|
self.measurement_type = measurement_type
|
||||||
self.dot_inputs_array = dot_inputs_array
|
self.dot_inputs_array = dot_inputs_array
|
||||||
self.actual_measurement_array = actual_measurement_array
|
self.actual_measurement_array = actual_measurement_array
|
||||||
self.actual_measurement_array2 = actual_measurement_array**2
|
self.actual_measurement_array2 = actual_measurement_array**2
|
||||||
|
self.use_pair_measurement = use_pair_measurement
|
||||||
|
if self.use_pair_measurement:
|
||||||
|
raise NotImplementedError("Pair measurements are not yet supported")
|
||||||
|
|
||||||
def __call__(self, dipoles_to_test):
|
def __call__(self, dipoles_to_test):
|
||||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||||
@ -60,7 +93,18 @@ class StDevUsingCostFunction:
|
|||||||
actual_measurement_array,
|
actual_measurement_array,
|
||||||
actual_stdev_array,
|
actual_stdev_array,
|
||||||
log_noise: bool = False,
|
log_noise: bool = False,
|
||||||
|
use_pair_measurement: bool = False,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Construct a cost function that uses the standard deviation of the measurements.
|
||||||
|
|
||||||
|
:param measurement_type: The type of measurement we're using.
|
||||||
|
:param dot_inputs_array: The array of dot inputs. (may be actually inputses for pair measurements)
|
||||||
|
:param actual_measurement_array: The actual measurements.
|
||||||
|
:param actual_stdev_array: The actual standard deviations.
|
||||||
|
:param use_pair_measurement: Whether to use pair measurements. (default false)
|
||||||
|
:param log_noise: Whether to use log noise. (default false but we should probably use it)
|
||||||
|
"""
|
||||||
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
||||||
self.measurement_type = measurement_type
|
self.measurement_type = measurement_type
|
||||||
self.dot_inputs_array = dot_inputs_array
|
self.dot_inputs_array = dot_inputs_array
|
||||||
@ -75,6 +119,7 @@ class StDevUsingCostFunction:
|
|||||||
numpy.log(self.actual_stdev_array + self.actual_measurement_array)
|
numpy.log(self.actual_stdev_array + self.actual_measurement_array)
|
||||||
- numpy.log(self.actual_measurement_array)
|
- numpy.log(self.actual_measurement_array)
|
||||||
) ** 2
|
) ** 2
|
||||||
|
self.use_pair_measurement = use_pair_measurement
|
||||||
|
|
||||||
def __call__(self, dipoles_to_test):
|
def __call__(self, dipoles_to_test):
|
||||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||||
@ -130,7 +175,11 @@ class ParsedBinHeader:
|
|||||||
dot_name: str
|
dot_name: str
|
||||||
# only used for pair measurements
|
# only used for pair measurements
|
||||||
dot_name2: typing.Optional[str] = None
|
dot_name2: typing.Optional[str] = None
|
||||||
cpsd_type: typing.Optional[str] = None
|
cpsd_type: typing.Optional[typing.Literal["correlation", "phase"]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pair(self) -> bool:
|
||||||
|
return self.dot_name2 is not None
|
||||||
|
|
||||||
|
|
||||||
def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||||
@ -151,12 +200,15 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
|||||||
pair_match := re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field)
|
pair_match := re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field)
|
||||||
) is not None:
|
) is not None:
|
||||||
groups = pair_match.groupdict()
|
groups = pair_match.groupdict()
|
||||||
|
cpsd_type = typing.cast(
|
||||||
|
typing.Literal["correlation", "phase"], groups["cpsd_type"]
|
||||||
|
)
|
||||||
return ParsedBinHeader(
|
return ParsedBinHeader(
|
||||||
original_field=field,
|
original_field=field,
|
||||||
measurement_type=groups["measurement_type"],
|
measurement_type=groups["measurement_type"],
|
||||||
dot_name=groups["dot_name"],
|
dot_name=groups["dot_name"],
|
||||||
dot_name2=groups["dot_name2"],
|
dot_name2=groups["dot_name2"],
|
||||||
cpsd_type=groups["cpsd_type"],
|
cpsd_type=cpsd_type,
|
||||||
summary_stat=groups["summary_stat"],
|
summary_stat=groups["summary_stat"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -247,7 +299,7 @@ class BinnedData:
|
|||||||
measurement_type: str
|
measurement_type: str
|
||||||
|
|
||||||
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
|
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
|
||||||
def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]:
|
def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup:
|
||||||
if dot_name not in self.dots_dict:
|
if dot_name not in self.dots_dict:
|
||||||
raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
|
raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
|
||||||
if dot_name not in self.csv_dict:
|
if dot_name not in self.csv_dict:
|
||||||
@ -258,13 +310,15 @@ class BinnedData:
|
|||||||
vs = self.csv_dict[dot_name]["mean"]
|
vs = self.csv_dict[dot_name]["mean"]
|
||||||
stdevs = self.csv_dict[dot_name]["stdev"]
|
stdevs = self.csv_dict[dot_name]["stdev"]
|
||||||
|
|
||||||
return [
|
return MeasurementGroup(
|
||||||
Measurement(
|
[
|
||||||
dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
|
Measurement(
|
||||||
stdev=stdev,
|
dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
|
||||||
)
|
stdev=stdev,
|
||||||
for f, v, stdev in zip(freqs, vs, stdevs)
|
)
|
||||||
]
|
for f, v, stdev in zip(freqs, vs, stdevs)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
|
def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
|
||||||
if dot_name not in self.dots_dict:
|
if dot_name not in self.dots_dict:
|
||||||
@ -276,22 +330,21 @@ class BinnedData:
|
|||||||
|
|
||||||
return stdevs
|
return stdevs
|
||||||
|
|
||||||
def measurements(
|
def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup:
|
||||||
self, dot_names: typing.Sequence[str]
|
|
||||||
) -> typing.Sequence[Measurement]:
|
|
||||||
_logger.debug(f"Constructing measurements for dots {dot_names=}")
|
_logger.debug(f"Constructing measurements for dots {dot_names=}")
|
||||||
ret: typing.List[Measurement] = []
|
ret = MeasurementGroup([])
|
||||||
|
_logger.debug
|
||||||
for dot_name in dot_names:
|
for dot_name in dot_names:
|
||||||
ret.extend(self._dot_to_measurement(dot_name))
|
ret = ret.add(self._dot_to_measurements(dot_name))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _cost_function(self, measurements: typing.Sequence[Measurement]):
|
def _cost_function(self, mg: MeasurementGroup):
|
||||||
dot_measurements = [m.dot_measurement for m in measurements]
|
dot_measurements = [m.dot_measurement for m in mg.measurements]
|
||||||
meas_array = numpy.array([m.v for m in dot_measurements])
|
meas_array = numpy.array([m.v for m in dot_measurements])
|
||||||
|
|
||||||
_logger.debug(f"Obtained {meas_array=}")
|
_logger.debug(f"Obtained {meas_array=}")
|
||||||
|
|
||||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements]
|
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
|
||||||
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
||||||
_logger.debug(f"Obtained {input_array=}")
|
_logger.debug(f"Obtained {input_array=}")
|
||||||
|
|
||||||
@ -299,20 +352,24 @@ class BinnedData:
|
|||||||
|
|
||||||
def _stdev_cost_function(
|
def _stdev_cost_function(
|
||||||
self,
|
self,
|
||||||
measurements: typing.Sequence[Measurement],
|
mg: MeasurementGroup,
|
||||||
use_log_noise: bool = False,
|
use_log_noise: bool = False,
|
||||||
):
|
):
|
||||||
meas_array = numpy.array([m.dot_measurement.v for m in measurements])
|
meas_array = numpy.array([m.dot_measurement.v for m in mg.measurements])
|
||||||
stdev_array = numpy.array([m.stdev for m in measurements])
|
stdev_array = numpy.array([m.stdev for m in mg.measurements])
|
||||||
|
|
||||||
_logger.debug(f"Obtained {meas_array=}")
|
_logger.debug(f"Obtained {meas_array=}")
|
||||||
|
|
||||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements]
|
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
|
||||||
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
||||||
_logger.debug(f"Obtained {input_array=}")
|
_logger.debug(f"Obtained {input_array=}")
|
||||||
|
|
||||||
return StDevUsingCostFunction(
|
return StDevUsingCostFunction(
|
||||||
self.measurement_type, input_array, meas_array, stdev_array, use_log_noise
|
self.measurement_type,
|
||||||
|
input_array,
|
||||||
|
meas_array,
|
||||||
|
stdev_array,
|
||||||
|
log_noise=use_log_noise,
|
||||||
)
|
)
|
||||||
|
|
||||||
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
|
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
|
||||||
@ -341,18 +398,3 @@ def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> Bin
|
|||||||
return BinnedData(
|
return BinnedData(
|
||||||
measurement_type=measurement_type, dots_dict=dots, csv_dict=binned
|
measurement_type=measurement_type, dots_dict=dots, csv_dict=binned
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
|
|
||||||
print(read_dots_json(pathlib.Path("dots.json")))
|
|
||||||
# print(read_bin_csv(pathlib.Path("binned-0.01-10000-50-12345.csv")))
|
|
||||||
binned_data = read_dots_and_binned(
|
|
||||||
pathlib.Path("dots.json"), pathlib.Path("binned-0.01-10000-50-12345.csv")
|
|
||||||
)
|
|
||||||
_logger.info(binned_data)
|
|
||||||
for entry in binned_data.measurements(["uprise1", "dot1"]):
|
|
||||||
_logger.info(entry)
|
|
||||||
filter = binned_data.cost_function_filter(["uprise1", "dot1"], 0.5)
|
|
||||||
_logger.info(filter)
|
|
||||||
|
@ -93,7 +93,7 @@ def test_binned_data_dot_measurement_costs(snapshot):
|
|||||||
|
|
||||||
binned_ex = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, ex_csv_file)
|
binned_ex = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, ex_csv_file)
|
||||||
measurements_ex = binned_ex.measurements(["dot1"])
|
measurements_ex = binned_ex.measurements(["dot1"])
|
||||||
# _logger.warning(measurements)
|
_logger.warning(measurements_ex)
|
||||||
|
|
||||||
v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True)
|
v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True)
|
||||||
ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True)
|
ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user