refactor: refactors to use measurement enums internally

This commit is contained in:
Deepak Mallubhotla 2025-02-23 17:58:21 -06:00
parent 42dddcae02
commit e83706b2b1
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
3 changed files with 143 additions and 73 deletions

View File

@ -16,6 +16,8 @@ import deepdog.direct_monte_carlo.cost_function_filter
# import tantri.cli
from kalpaa.config import MeasurementTypeEnum
import pdme
import pdme.util.fast_v_calc
import pdme.measurement
@ -27,6 +29,15 @@ X_ELECTRIC_FIELD = "Ex"
POTENTIAL = "V"
def short_string_to_measurement_type(short_string: str) -> MeasurementTypeEnum:
    """
    Translate a short measurement-type token into its MeasurementTypeEnum member.

    :param short_string: The short token, compared against the module constants
        X_ELECTRIC_FIELD and POTENTIAL.
    :return: The MeasurementTypeEnum member corresponding to the token.
    :raises ValueError: If the token matches neither constant.
    """
    # Guard clauses: every recognised token returns immediately, so reaching
    # the end of the function means the token is unknown.
    if short_string == X_ELECTRIC_FIELD:
        return MeasurementTypeEnum.X_ELECTRIC_FIELD
    if short_string == POTENTIAL:
        return MeasurementTypeEnum.POTENTIAL
    raise ValueError(f"Could not find {short_string=}")
@dataclasses.dataclass
class Measurement:
dot_measurement: pdme.measurement.DotMeasurement
@ -35,22 +46,70 @@ class Measurement:
@dataclasses.dataclass
class MeasurementGroup:
    """
    A collection of measurements tagged with a single measurement type.

    Groups of the same type can be combined with add(), and a group can build
    the cost functions for its own measurements via cost_function() and
    stdev_cost_function().
    """

    # The individual measurements held by this group.
    _measurements: typing.Sequence[Measurement]
    # The measurement type that applies to the whole group.
    _measurement_type: MeasurementTypeEnum

    def add(self, other: MeasurementGroup) -> MeasurementGroup:
        """
        Return a new group holding this group's measurements followed by other's.

        Neither operand is modified.

        :param other: The group to append; must have the same measurement type.
        :return: A new MeasurementGroup with the combined measurements.
        :raises ValueError: If the two groups have different measurement types.
        """
        # Check compatibility first so that mismatched groups are never merged.
        if other._measurement_type != self._measurement_type:
            raise ValueError(
                f"Cannot add {other._measurement_type=} to {self._measurement_type=}, as they have different measurement types"
            )

        # this is probably not conformant to the ideal contract for typing.Sequence
        new_measurements = [*self._measurements, *other._measurements]
        return MeasurementGroup(new_measurements, self._measurement_type)

    def _meas_array(self) -> numpy.ndarray:
        """Return the `v` value of every dot measurement as a numpy array."""
        return numpy.array([m.dot_measurement.v for m in self._measurements])

    def _input_array(self) -> numpy.ndarray:
        """Return the (r, f) pair of every dot measurement, converted by pdme's dot_inputs_to_array."""
        return pdme.measurement.input_types.dot_inputs_to_array(
            [(m.dot_measurement.r, m.dot_measurement.f) for m in self._measurements]
        )

    def _stdev_array(self) -> numpy.ndarray:
        """Return the stdev of every measurement as a numpy array."""
        return numpy.array([m.stdev for m in self._measurements])

    def cost_function(self) -> CostFunction:
        """
        Build a CostFunction for this group's measurements.

        :return: A CostFunction constructed from the group's measurement type,
            input array and measured values.
        """
        meas_array = self._meas_array()
        _logger.debug(f"Obtained {meas_array=}")

        input_array = self._input_array()
        _logger.debug(f"Obtained {input_array=}")

        return CostFunction(self._measurement_type, input_array, meas_array)

    def stdev_cost_function(
        self,
        use_log_noise: bool = False,
    ) -> StDevUsingCostFunction:
        """
        Build a StDevUsingCostFunction for this group's measurements.

        :param use_log_noise: Forwarded to StDevUsingCostFunction as `log_noise`.
        :return: A StDevUsingCostFunction constructed from the group's measurement
            type, input array, measured values and standard deviations.
        """
        stdev_array = self._stdev_array()

        meas_array = self._meas_array()
        _logger.debug(f"Obtained {meas_array=}")

        input_array = self._input_array()
        _logger.debug(f"Obtained {input_array=}")

        return StDevUsingCostFunction(
            self._measurement_type,
            input_array,
            meas_array,
            stdev_array,
            log_noise=use_log_noise,
        )
class CostFunction:
def __init__(
self,
measurement_type,
dot_inputs_array,
actual_measurement_array,
measurement_type: MeasurementTypeEnum,
dot_inputs_array: numpy.ndarray,
actual_measurement_array: numpy.ndarray,
use_pair_measurement: bool = False,
):
"""
@ -71,11 +130,11 @@ class CostFunction:
raise NotImplementedError("Pair measurements are not yet supported")
def __call__(self, dipoles_to_test):
if self.measurement_type == X_ELECTRIC_FIELD:
if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
self.dot_inputs_array, dipoles_to_test
)
elif self.measurement_type == POTENTIAL:
elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
self.dot_inputs_array, dipoles_to_test
)
@ -88,10 +147,10 @@ class CostFunction:
class StDevUsingCostFunction:
def __init__(
self,
measurement_type,
dot_inputs_array,
actual_measurement_array,
actual_stdev_array,
measurement_type: MeasurementTypeEnum,
dot_inputs_array: numpy.ndarray,
actual_measurement_array: numpy.ndarray,
actual_stdev_array: numpy.ndarray,
log_noise: bool = False,
use_pair_measurement: bool = False,
):
@ -122,11 +181,11 @@ class StDevUsingCostFunction:
self.use_pair_measurement = use_pair_measurement
def __call__(self, dipoles_to_test):
if self.measurement_type == X_ELECTRIC_FIELD:
if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
self.dot_inputs_array, dipoles_to_test
)
elif self.measurement_type == POTENTIAL:
elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
self.dot_inputs_array, dipoles_to_test
)
@ -170,7 +229,7 @@ PAIR_MEASUREMENT_BINNED_HEADER_REGEX = r"\s*CPSD_(?P<cpsd_type>correlation|phase
@dataclasses.dataclass
class ParsedBinHeader:
original_field: str
measurement_type: str
measurement_type: MeasurementTypeEnum
summary_stat: str
dot_name: str
# only used for pair measurements
@ -192,7 +251,9 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
match_groups = match.groupdict()
return ParsedBinHeader(
original_field=field,
measurement_type=match_groups["measurement_type"],
measurement_type=short_string_to_measurement_type(
match_groups["measurement_type"]
),
dot_name=match_groups["dot_name"],
summary_stat=match_groups["summary_stat"],
)
@ -205,7 +266,9 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
)
return ParsedBinHeader(
original_field=field,
measurement_type=groups["measurement_type"],
measurement_type=short_string_to_measurement_type(
groups["measurement_type"]
),
dot_name=groups["dot_name"],
dot_name2=groups["dot_name2"],
cpsd_type=cpsd_type,
@ -218,7 +281,13 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
def read_bin_csv(
csv_file: pathlib.Path,
) -> typing.Tuple[str, typing.Dict[str, typing.Any]]:
) -> typing.Tuple[MeasurementTypeEnum, typing.Dict[str, typing.Any]]:
"""
Read a binned csv file and return the measurement type and the binned data.
:param csv_file: The csv file to read.
:return: A tuple of the measurement type and the binned data.
"""
measurement_type = None
_logger.info(f"Assuming measurement type is {measurement_type} for now")
@ -296,7 +365,7 @@ def read_bin_csv(
class BinnedData:
dots_dict: typing.Dict
csv_dict: typing.Dict[str, typing.Any]
measurement_type: str
measurement_type: MeasurementTypeEnum
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup:
@ -317,7 +386,8 @@ class BinnedData:
stdev=stdev,
)
for f, v, stdev in zip(freqs, vs, stdevs)
]
],
_measurement_type=self.measurement_type,
)
def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
@ -332,49 +402,47 @@ class BinnedData:
def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup:
    """
    Collect the measurements of the named dots into a single MeasurementGroup.

    :param dot_names: Names of the dots whose measurements should be combined.
    :return: A MeasurementGroup of this data's measurement type containing each
        named dot's measurements, in the order the names were given; the group
        is empty when dot_names is empty.
    """
    _logger.debug(f"Constructing measurements for dots {dot_names=}")
    # Seed with an empty group that already carries our measurement type, so
    # that MeasurementGroup.add can compare types on every merge.
    ret = MeasurementGroup([], self.measurement_type)
    for dot_name in dot_names:
        ret = ret.add(self._dot_to_measurements(dot_name))
    return ret
def _cost_function(self, mg: MeasurementGroup):
dot_measurements = [m.dot_measurement for m in mg.measurements]
meas_array = numpy.array([m.v for m in dot_measurements])
# def _cost_function(self, mg: MeasurementGroup):
# meas_array = mg.meas_array()
_logger.debug(f"Obtained {meas_array=}")
# _logger.debug(f"Obtained {meas_array=}")
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
_logger.debug(f"Obtained {input_array=}")
# input_array = mg.input_array()
# _logger.debug(f"Obtained {input_array=}")
return CostFunction(self.measurement_type, input_array, meas_array)
# return CostFunction(self.measurement_type, input_array, meas_array)
def _stdev_cost_function(
self,
mg: MeasurementGroup,
use_log_noise: bool = False,
):
meas_array = numpy.array([m.dot_measurement.v for m in mg.measurements])
stdev_array = numpy.array([m.stdev for m in mg.measurements])
# def _stdev_cost_function(
# self,
# mg: MeasurementGroup,
# use_log_noise: bool = False,
# ):
# stdev_array = mg.stdev_array()
_logger.debug(f"Obtained {meas_array=}")
# meas_array = mg.meas_array()
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in mg.measurements]
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
_logger.debug(f"Obtained {input_array=}")
# _logger.debug(f"Obtained {meas_array=}")
return StDevUsingCostFunction(
self.measurement_type,
input_array,
meas_array,
stdev_array,
log_noise=use_log_noise,
)
# input_array = mg.input_array()
# _logger.debug(f"Obtained {input_array=}")
# return StDevUsingCostFunction(
# self.measurement_type,
# input_array,
# meas_array,
# stdev_array,
# log_noise=use_log_noise,
# )
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
    """
    Build a cost-function target filter for the named dots.

    :param dot_names: Names of the dots whose measurements feed the cost function.
    :param target_cost: Passed to CostFunctionTargetFilter together with the cost function.
    :return: A deepdog CostFunctionTargetFilter built from the group's cost
        function and target_cost.
    """
    measurements = self.measurements(dot_names)
    # The measurement group builds its own cost function.
    cost_function = measurements.cost_function()
    return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
        cost_function, target_cost
    )
@ -386,7 +454,7 @@ class BinnedData:
use_log_noise: bool = False,
):
measurements = self.measurements(dot_names)
cost_function = self._stdev_cost_function(measurements, use_log_noise)
cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
cost_function, target_cost
)

View File

@ -189,7 +189,7 @@
0.5,
]),
}),
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
})
# ---
# name: test_binned_data_dot_measurement_costs
@ -222,7 +222,7 @@
'cpsd_type': None,
'dot_name': 'dot1',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_dot1_mean',
'summary_stat': 'mean',
}),
@ -230,7 +230,7 @@
'cpsd_type': None,
'dot_name': 'dot1',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_dot1_stdev',
'summary_stat': 'stdev',
}),
@ -238,7 +238,7 @@
'cpsd_type': None,
'dot_name': 'dot2',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_dot2_mean',
'summary_stat': 'mean',
}),
@ -246,7 +246,7 @@
'cpsd_type': None,
'dot_name': 'dot2',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_dot2_stdev',
'summary_stat': 'stdev',
}),
@ -254,7 +254,7 @@
'cpsd_type': None,
'dot_name': 'line',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_line_mean',
'summary_stat': 'mean',
}),
@ -262,7 +262,7 @@
'cpsd_type': None,
'dot_name': 'line',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_line_stdev',
'summary_stat': 'stdev',
}),
@ -270,7 +270,7 @@
'cpsd_type': None,
'dot_name': 'triangle1',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_triangle1_mean',
'summary_stat': 'mean',
}),
@ -278,7 +278,7 @@
'cpsd_type': None,
'dot_name': 'triangle1',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_triangle1_stdev',
'summary_stat': 'stdev',
}),
@ -286,7 +286,7 @@
'cpsd_type': None,
'dot_name': 'triangle2',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_triangle2_mean',
'summary_stat': 'mean',
}),
@ -294,7 +294,7 @@
'cpsd_type': None,
'dot_name': 'triangle2',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_triangle2_stdev',
'summary_stat': 'stdev',
}),
@ -302,7 +302,7 @@
'cpsd_type': None,
'dot_name': 'uprise1',
'dot_name2': None,
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'APSD_V_uprise1_mean',
'summary_stat': 'mean',
}),
@ -311,7 +311,7 @@
'cpsd_type': 'correlation',
'dot_name': 'dot1',
'dot_name2': 'dot2',
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'CPSD_correlation_V_dot1_dot2_mean',
'summary_stat': 'mean',
}),
@ -319,7 +319,7 @@
'cpsd_type': 'correlation',
'dot_name': 'dot1',
'dot_name2': 'dot2',
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'CPSD_correlation_V_dot1_dot2_stdev',
'summary_stat': 'stdev',
}),
@ -327,7 +327,7 @@
'cpsd_type': 'phase',
'dot_name': 'dot1',
'dot_name2': 'dot2',
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'CPSD_phase_V_dot1_dot2_mean',
'summary_stat': 'mean',
}),
@ -335,7 +335,7 @@
'cpsd_type': 'phase',
'dot_name': 'dot1',
'dot_name2': 'dot2',
'measurement_type': 'V',
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
'original_field': 'CPSD_phase_V_dot1_dot2_stdev',
'summary_stat': 'stdev',
}),

View File

@ -95,14 +95,16 @@ def test_binned_data_dot_measurement_costs(snapshot):
measurements_ex = binned_ex.measurements(["dot1"])
_logger.warning(measurements_ex)
v_log_noise_stdev_cost_func = binned_v._stdev_cost_function(measurements_v, True)
ex_log_noise_stdev_cost_func = binned_ex._stdev_cost_function(measurements_ex, True)
v_linear_noise_stdev_cost_func = binned_v._stdev_cost_function(
measurements_v, False
v_log_noise_stdev_cost_func = measurements_v.stdev_cost_function(use_log_noise=True)
ex_log_noise_stdev_cost_func = measurements_ex.stdev_cost_function(
use_log_noise=True
)
ex_linear_noise_stdev_cost_func = binned_ex._stdev_cost_function(
measurements_ex, False
v_linear_noise_stdev_cost_func = measurements_v.stdev_cost_function(
use_log_noise=False
)
ex_linear_noise_stdev_cost_func = measurements_ex.stdev_cost_function(
use_log_noise=False
)
result_dict = {