refactor: refactors to use measurement enums internally
This commit is contained in:
parent
42dddcae02
commit
e83706b2b1
@ -16,6 +16,8 @@ import deepdog.direct_monte_carlo.cost_function_filter
|
||||
|
||||
# import tantri.cli
|
||||
|
||||
from kalpaa.config import MeasurementTypeEnum
|
||||
|
||||
import pdme
|
||||
import pdme.util.fast_v_calc
|
||||
import pdme.measurement
|
||||
@ -27,6 +29,15 @@ X_ELECTRIC_FIELD = "Ex"
|
||||
POTENTIAL = "V"
|
||||
|
||||
|
||||
def short_string_to_measurement_type(short_string: str) -> MeasurementTypeEnum:
|
||||
if short_string == X_ELECTRIC_FIELD:
|
||||
return MeasurementTypeEnum.X_ELECTRIC_FIELD
|
||||
elif short_string == POTENTIAL:
|
||||
return MeasurementTypeEnum.POTENTIAL
|
||||
else:
|
||||
raise ValueError(f"Could not find {short_string=}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Measurement:
|
||||
dot_measurement: pdme.measurement.DotMeasurement
|
||||
@ -35,22 +46,70 @@ class Measurement:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MeasurementGroup:
|
||||
measurements: typing.Sequence[Measurement]
|
||||
_measurements: typing.Sequence[Measurement]
|
||||
_measurement_type: MeasurementTypeEnum
|
||||
|
||||
def add(self, other: MeasurementGroup) -> MeasurementGroup:
|
||||
|
||||
# this is probably not conformant to the ideal contract for typing.Sequence
|
||||
new_measurements = [*self.measurements, *other.measurements]
|
||||
if other._measurement_type != self._measurement_type:
|
||||
raise ValueError(
|
||||
f"Cannot add {other._measurement_type=} to {self._measurement_type=}, as they have different measurement types"
|
||||
)
|
||||
|
||||
return MeasurementGroup(new_measurements)
|
||||
# this is probably not conformant to the ideal contract for typing.Sequence
|
||||
new_measurements = [*self._measurements, *other._measurements]
|
||||
|
||||
return MeasurementGroup(new_measurements, self._measurement_type)
|
||||
|
||||
def _meas_array(self) -> numpy.ndarray:
|
||||
return numpy.array([m.dot_measurement.v for m in self._measurements])
|
||||
|
||||
def _input_array(self) -> numpy.ndarray:
|
||||
return pdme.measurement.input_types.dot_inputs_to_array(
|
||||
[(m.dot_measurement.r, m.dot_measurement.f) for m in self._measurements]
|
||||
)
|
||||
|
||||
def _stdev_array(self) -> numpy.ndarray:
|
||||
return numpy.array([m.stdev for m in self._measurements])
|
||||
|
||||
def cost_function(self):
|
||||
meas_array = self._meas_array()
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
input_array = self._input_array()
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return CostFunction(self._measurement_type, input_array, meas_array)
|
||||
|
||||
def stdev_cost_function(
|
||||
self,
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
stdev_array = self._stdev_array()
|
||||
|
||||
meas_array = self._meas_array()
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
input_array = self._input_array()
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return StDevUsingCostFunction(
|
||||
self._measurement_type,
|
||||
input_array,
|
||||
meas_array,
|
||||
stdev_array,
|
||||
log_noise=use_log_noise,
|
||||
)
|
||||
|
||||
|
||||
class CostFunction:
|
||||
def __init__(
|
||||
self,
|
||||
measurement_type,
|
||||
dot_inputs_array,
|
||||
actual_measurement_array,
|
||||
measurement_type: MeasurementTypeEnum,
|
||||
dot_inputs_array: numpy.ndarray,
|
||||
actual_measurement_array: numpy.ndarray,
|
||||
use_pair_measurement: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -71,11 +130,11 @@ class CostFunction:
|
||||
raise NotImplementedError("Pair measurements are not yet supported")
|
||||
|
||||
def __call__(self, dipoles_to_test):
|
||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||
if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
|
||||
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
elif self.measurement_type == POTENTIAL:
|
||||
elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
@ -88,10 +147,10 @@ class CostFunction:
|
||||
class StDevUsingCostFunction:
|
||||
def __init__(
|
||||
self,
|
||||
measurement_type,
|
||||
dot_inputs_array,
|
||||
actual_measurement_array,
|
||||
actual_stdev_array,
|
||||
measurement_type: MeasurementTypeEnum,
|
||||
dot_inputs_array: numpy.ndarray,
|
||||
actual_measurement_array: numpy.ndarray,
|
||||
actual_stdev_array: numpy.ndarray,
|
||||
log_noise: bool = False,
|
||||
use_pair_measurement: bool = False,
|
||||
):
|
||||
@ -122,11 +181,11 @@ class StDevUsingCostFunction:
|
||||
self.use_pair_measurement = use_pair_measurement
|
||||
|
||||
def __call__(self, dipoles_to_test):
|
||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||
if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
|
||||
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
elif self.measurement_type == POTENTIAL:
|
||||
elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
@ -170,7 +229,7 @@ PAIR_MEASUREMENT_BINNED_HEADER_REGEX = r"\s*CPSD_(?P<cpsd_type>correlation|phase
|
||||
@dataclasses.dataclass
|
||||
class ParsedBinHeader:
|
||||
original_field: str
|
||||
measurement_type: str
|
||||
measurement_type: MeasurementTypeEnum
|
||||
summary_stat: str
|
||||
dot_name: str
|
||||
# only used for pair measurements
|
||||
@ -192,7 +251,9 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||
match_groups = match.groupdict()
|
||||
return ParsedBinHeader(
|
||||
original_field=field,
|
||||
measurement_type=match_groups["measurement_type"],
|
||||
measurement_type=short_string_to_measurement_type(
|
||||
match_groups["measurement_type"]
|
||||
),
|
||||
dot_name=match_groups["dot_name"],
|
||||
summary_stat=match_groups["summary_stat"],
|
||||
)
|
||||
@ -205,7 +266,9 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||
)
|
||||
return ParsedBinHeader(
|
||||
original_field=field,
|
||||
measurement_type=groups["measurement_type"],
|
||||
measurement_type=short_string_to_measurement_type(
|
||||
groups["measurement_type"]
|
||||
),
|
||||
dot_name=groups["dot_name"],
|
||||
dot_name2=groups["dot_name2"],
|
||||
cpsd_type=cpsd_type,
|
||||
@ -218,7 +281,13 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||
|
||||
def read_bin_csv(
|
||||
csv_file: pathlib.Path,
|
||||
) -> typing.Tuple[str, typing.Dict[str, typing.Any]]:
|
||||
) -> typing.Tuple[MeasurementTypeEnum, typing.Dict[str, typing.Any]]:
|
||||
"""
|
||||
Read a binned csv file and return the measurement type and the binned data.
|
||||
|
||||
:param csv_file: The csv file to read.
|
||||
:return: A tuple of the measurement type and the binned data.
|
||||
"""
|
||||
|
||||
measurement_type = None
|
||||
_logger.info(f"Assuming measurement type is {measurement_type} for now")
|
||||
@ -296,7 +365,7 @@ def read_bin_csv(
|
||||
class BinnedData:
|
||||
dots_dict: typing.Dict
|
||||
csv_dict: typing.Dict[str, typing.Any]
|
||||
measurement_type: str
|
||||
measurement_type: MeasurementTypeEnum
|
||||
|
||||
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
|
||||
def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup:
|
||||
@ -317,7 +386,8 @@ class BinnedData:
|
||||
stdev=stdev,
|
||||
)
|
||||
for f, v, stdev in zip(freqs, vs, stdevs)
|
||||
]
|
||||
],
|
||||
_measurement_type=self.measurement_type,
|
||||
)
|
||||
|
||||
def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
|
||||
@ -332,49 +402,47 @@ class BinnedData:
|
||||
|
||||
def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup:
|
||||
_logger.debug(f"Constructing measurements for dots {dot_names=}")
|
||||
ret = MeasurementGroup([])
|
||||
ret = MeasurementGroup([], self.measurement_type)
|
||||
_logger.debug
|
||||
for dot_name in dot_names:
|
||||
ret = ret.add(self._dot_to_measurements(dot_name))
|
||||
return ret
|
||||
|
||||
def _cost_function(self, mg: MeasurementGroup):
|
||||
dot_measurements = [m.dot_measurement for m in mg.measurements]
|
||||
meas_array = numpy.array([m.v for m in dot_measurements])
|
||||
# def _cost_function(self, mg: MeasurementGroup):
|
||||
# meas_array = mg.meas_array()
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
# _logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
|
||||
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
# input_array = mg.input_array()
|
||||
# _logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return CostFunction(self.measurement_type, input_array, meas_array)
|
||||
# return CostFunction(self.measurement_type, input_array, meas_array)
|
||||
|
||||
def _stdev_cost_function(
|
||||
self,
|
||||
mg: MeasurementGroup,
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
meas_array = numpy.array([m.dot_measurement.v for m in mg.measurements])
|
||||
stdev_array = numpy.array([m.stdev for m in mg.measurements])
|
||||
# def _stdev_cost_function(
|
||||
# self,
|
||||
# mg: MeasurementGroup,
|
||||
# use_log_noise: bool = False,
|
||||
# ):
|
||||
# stdev_array = mg.stdev_array()
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
# meas_array = mg.meas_array()
|
||||
|
||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
|
||||
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
# _logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
return StDevUsingCostFunction(
|
||||
self.measurement_type,
|
||||
input_array,
|
||||
meas_array,
|
||||
stdev_array,
|
||||
log_noise=use_log_noise,
|
||||
)
|
||||
# input_array = mg.input_array()
|
||||
# _logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
# return StDevUsingCostFunction(
|
||||
# self.measurement_type,
|
||||
# input_array,
|
||||
# meas_array,
|
||||
# stdev_array,
|
||||
# log_noise=use_log_noise,
|
||||
# )
|
||||
|
||||
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
|
||||
measurements = self.measurements(dot_names)
|
||||
cost_function = self._cost_function(measurements)
|
||||
cost_function = measurements.cost_function()
|
||||
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, target_cost
|
||||
)
|
||||
@ -386,7 +454,7 @@ class BinnedData:
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
measurements = self.measurements(dot_names)
|
||||
cost_function = self._stdev_cost_function(measurements, use_log_noise)
|
||||
cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
|
||||
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, target_cost
|
||||
)
|
||||
|
@ -189,7 +189,7 @@
|
||||
0.5,
|
||||
]),
|
||||
}),
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
})
|
||||
# ---
|
||||
# name: test_binned_data_dot_measurement_costs
|
||||
@ -222,7 +222,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot1_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -230,7 +230,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot1_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
@ -238,7 +238,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -246,7 +246,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
@ -254,7 +254,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'line',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_line_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -262,7 +262,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'line',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_line_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
@ -270,7 +270,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle1_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -278,7 +278,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle1_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
@ -286,7 +286,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -294,7 +294,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
@ -302,7 +302,7 @@
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'uprise1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_uprise1_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -311,7 +311,7 @@
|
||||
'cpsd_type': 'correlation',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_correlation_V_dot1_dot2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -319,7 +319,7 @@
|
||||
'cpsd_type': 'correlation',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_correlation_V_dot1_dot2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
@ -327,7 +327,7 @@
|
||||
'cpsd_type': 'phase',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_phase_V_dot1_dot2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
@ -335,7 +335,7 @@
|
||||
'cpsd_type': 'phase',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': 'V',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_phase_V_dot1_dot2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
|
@ -95,14 +95,16 @@ def test_binned_data_dot_measurement_costs(snapshot):
|
||||
measurements_ex = binned_ex.measurements(["dot1"])
|
||||
_logger.warning(measurements_ex)
|
||||
|
||||
v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True)
|
||||
ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True)
|
||||
|
||||
v_linear_noise_stdev_cost_func = binned_v._stdev_cost_function(
|
||||
measurements_v, False
|
||||
v_log_noise_stdev_cost_func = measurements_v.stdev_cost_function(use_log_noise=True)
|
||||
ex_log_noise_stdev_cost_func = measurements_ex.stdev_cost_function(
|
||||
use_log_noise=True
|
||||
)
|
||||
ex_linear_noise_stdev_cost_func = binned_ex._stdev_cost_function(
|
||||
measurements_ex, False
|
||||
|
||||
v_linear_noise_stdev_cost_func = measurements_v.stdev_cost_function(
|
||||
use_log_noise=False
|
||||
)
|
||||
ex_linear_noise_stdev_cost_func = measurements_ex.stdev_cost_function(
|
||||
use_log_noise=False
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user