688 lines
21 KiB
Python
688 lines
21 KiB
Python
from __future__ import annotations
|
|
|
|
# Postponed evaluation of annotations: useful for MeasurementGroup, whose method signatures refer to MeasurementGroup itself, so the type hint need not be written as a string.
|
|
|
|
import re
|
|
import numpy
|
|
import dataclasses
|
|
import typing
|
|
import json
|
|
import pathlib
|
|
import logging
|
|
import csv
|
|
import deepdog.direct_monte_carlo.dmc_filters
|
|
import deepdog.direct_monte_carlo.compose_filter
|
|
import deepdog.direct_monte_carlo.cost_function_filter
|
|
import pdme.util.fast_nonlocal_spectrum
|
|
|
|
# import tantri.cli
|
|
|
|
from kalpaa.config import MeasurementTypeEnum
|
|
import kalpaa.common.angles
|
|
|
|
import pdme
|
|
import pdme.util.fast_v_calc
|
|
import pdme.measurement
|
|
import pdme.measurement.input_types
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
X_ELECTRIC_FIELD = "Ex"
|
|
POTENTIAL = "V"
|
|
|
|
|
|
def short_string_to_measurement_type(short_string: str) -> MeasurementTypeEnum:
    """
    Translate a header short code into its MeasurementTypeEnum member.

    :param short_string: ``X_ELECTRIC_FIELD`` ("Ex") or ``POTENTIAL`` ("V").
    :return: The matching MeasurementTypeEnum member.
    :raises ValueError: If the code is not recognised.
    """
    candidates = (
        (X_ELECTRIC_FIELD, MeasurementTypeEnum.X_ELECTRIC_FIELD),
        (POTENTIAL, MeasurementTypeEnum.POTENTIAL),
    )
    for code, member in candidates:
        if short_string == code:
            return member
    raise ValueError(f"Could not find {short_string=}")
|
|
|
|
|
|
@dataclasses.dataclass
class Measurement:
    """
    One measured value with its standard deviation, attached to a dot or to a dot pair.

    BinnedData sets exactly one of ``dot_measurement`` / ``dot_pair_measurement``.
    MeasurementGroup.validate rejects groups that mix the two kinds.
    """

    # Single-dot measurement; None when this entry describes a dot pair.
    dot_measurement: typing.Optional[pdme.measurement.DotMeasurement]
    # Standard deviation associated with the measured value.
    stdev: float
    # Dot-pair measurement; left as None for single-dot entries.
    dot_pair_measurement: typing.Optional[
        pdme.measurement.DotPairMeasurement
    ] = dataclasses.field(default=None)
|
|
|
|
|
|
@dataclasses.dataclass
class MeasurementGroup:
    """
    A collection of measurements sharing one measurement type.

    A group holds either only single-dot or only dot-pair measurements. ``validate()``
    checks this and sets ``_using_pairs``. The private array helpers rely on that flag,
    so they must only be used after ``validate()``; the cost-function builders call it
    for you.
    """

    _measurements: typing.Sequence[Measurement]
    _measurement_type: MeasurementTypeEnum
    # Set by validate(); selects which Measurement field the array helpers read.
    _using_pairs: bool = dataclasses.field(init=False, default=False)

    def validate(self):
        """
        Check that the group is non-empty and homogeneous, and record whether it holds pairs.

        :raises ValueError: If the group is empty, mixes single and pair measurements, or
            contains a measurement that carries neither kind.
        """
        if not self._measurements:
            raise ValueError("Cannot have an empty measurement group")
        using_pairs = any(
            m.dot_pair_measurement is not None for m in self._measurements
        )
        using_singles = any(m.dot_measurement is not None for m in self._measurements)
        if using_pairs and using_singles:
            raise ValueError(
                "Cannot mix single and pair measurements in a single measurement group"
            )
        if not using_pairs and not using_singles:
            raise ValueError("Cannot have a measurement group with no measurements")
        # _meas_array/_input_array skip measurements lacking the relevant kind, but
        # _stdev_array does not, so an "empty" measurement would misalign the arrays.
        for index, m in enumerate(self._measurements):
            if m.dot_measurement is None and m.dot_pair_measurement is None:
                raise ValueError(
                    f"Measurement at index {index} has neither a dot nor a dot pair measurement"
                )
        self._using_pairs = using_pairs

    def add(self, other: MeasurementGroup) -> MeasurementGroup:
        """
        Return a new group containing this group's measurements followed by ``other``'s.

        :raises ValueError: If the two groups have different measurement types.
        """
        if other._measurement_type != self._measurement_type:
            raise ValueError(
                f"Cannot add {other._measurement_type=} to {self._measurement_type=}, as they have different measurement types"
            )

        # this is probably not conformant to the ideal contract for typing.Sequence
        new_measurements = [*self._measurements, *other._measurements]

        return MeasurementGroup(new_measurements, self._measurement_type)

    def _meas_array(self) -> numpy.ndarray:
        """Array of measured values ``v``, taken from pair or single measurements per _using_pairs."""
        if self._using_pairs:
            return numpy.array(
                [
                    m.dot_pair_measurement.v
                    for m in self._measurements
                    if m.dot_pair_measurement is not None
                ]
            )
        else:
            return numpy.array(
                [
                    m.dot_measurement.v
                    for m in self._measurements
                    if m.dot_measurement is not None
                ]
            )

    def _input_array(self) -> numpy.ndarray:
        """Input array built by pdme's dot (or dot pair) input helpers from positions and frequencies."""
        if self._using_pairs:
            return pdme.measurement.input_types.dot_pair_inputs_to_array(
                [
                    (
                        m.dot_pair_measurement.r1,
                        m.dot_pair_measurement.r2,
                        m.dot_pair_measurement.f,
                    )
                    for m in self._measurements
                    if m.dot_pair_measurement is not None
                ]
            )
        else:
            return pdme.measurement.input_types.dot_inputs_to_array(
                [
                    (m.dot_measurement.r, m.dot_measurement.f)
                    for m in self._measurements
                    if m.dot_measurement is not None
                ]
            )

    def _stdev_array(self) -> numpy.ndarray:
        """Array of per-measurement standard deviations, in group order."""
        return numpy.array([m.stdev for m in self._measurements])

    def cost_function(self):
        """
        Validate the group and build a relative-error CostFunction from it.

        Pair groups are flagged to CostFunction, which rejects them with
        NotImplementedError rather than evaluating pair inputs as single-dot inputs.
        """
        self.validate()
        meas_array = self._meas_array()

        _logger.debug(f"Obtained {meas_array=}")

        input_array = self._input_array()
        _logger.debug(f"Obtained {input_array=}")

        return CostFunction(
            self._measurement_type,
            input_array,
            meas_array,
            use_pair_measurement=self._using_pairs,
        )

    def stdev_cost_function(
        self,
        use_log_noise: bool = False,
    ):
        """
        Validate the group and build a StDevUsingCostFunction from it.

        :param use_log_noise: Forwarded as ``log_noise``.
        """
        self.validate()
        stdev_array = self._stdev_array()

        meas_array = self._meas_array()

        _logger.debug(f"Obtained {meas_array=}")

        input_array = self._input_array()
        _logger.debug(f"Obtained {input_array=}")

        return StDevUsingCostFunction(
            self._measurement_type,
            input_array,
            meas_array,
            stdev_array,
            log_noise=use_log_noise,
            use_pair_measurement=self._using_pairs,
        )
|
|
|
|
|
|
class CostFunction:
    """
    Relative root-mean-square cost of computed single-dot values against measurements.

    For each candidate the cost is ``sqrt(mean(((computed - actual) / actual) ** 2))``,
    with the mean taken over the last axis.
    """

    def __init__(
        self,
        measurement_type: MeasurementTypeEnum,
        dot_inputs_array: numpy.ndarray,
        actual_measurement_array: numpy.ndarray,
        use_pair_measurement: bool = False,
    ):
        """
        Construct a cost function that uses the measurements.

        :param measurement_type: The type of measurement we're using.
        :param dot_inputs_array: The array of dot inputs.
        :param actual_measurement_array: The actual measurements.
        :param use_pair_measurement: Whether to use pair measurements. (default false)
        :raises NotImplementedError: If ``use_pair_measurement`` is true.
        """
        _logger.info(f"Cost function with measurement type of {measurement_type}")
        self.measurement_type = measurement_type
        self.dot_inputs_array = dot_inputs_array
        self.actual_measurement_array = actual_measurement_array
        # Squared measurements are precomputed once as the relative-error denominator.
        self.actual_measurement_array2 = actual_measurement_array**2
        self.use_pair_measurement = use_pair_measurement
        if self.use_pair_measurement:
            raise NotImplementedError("Pair measurements are not yet supported")

    def __call__(self, dipoles_to_test):
        """
        Evaluate the cost for the candidate dipole configurations.

        :param dipoles_to_test: Passed straight to pdme's fast_v_calc helpers (shape is
            whatever those expect -- presumably a batch of dipole sets).
        :return: Relative RMS error per candidate, reduced over the last axis.
        :raises ValueError: If the measurement type is not supported.
        """
        if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
            vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
                self.dot_inputs_array, dipoles_to_test
            )
        elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
            vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
                self.dot_inputs_array, dipoles_to_test
            )
        else:
            # Previously fell through and failed with UnboundLocalError on `vals`.
            raise ValueError(f"Unsupported measurement type {self.measurement_type}")

        diffs = (
            vals - self.actual_measurement_array
        ) ** 2 / self.actual_measurement_array2
        return numpy.sqrt(diffs.mean(axis=-1))
|
|
|
|
|
|
class StDevUsingCostFunction:
    """
    Root-mean-square cost of computed values against measurements, with each squared
    residual scaled by the measurement's squared standard deviation (or a log-space
    analogue when log noise is enabled).
    """

    def __init__(
        self,
        measurement_type: MeasurementTypeEnum,
        dot_inputs_array: numpy.ndarray,
        actual_measurement_array: numpy.ndarray,
        actual_stdev_array: numpy.ndarray,
        log_noise: bool = False,
        use_pair_measurement: bool = False,
    ):
        """
        Construct a cost function that uses the standard deviation of the measurements.

        :param measurement_type: The type of measurement we're using.
        :param dot_inputs_array: The array of dot inputs. (may be actually inputses for pair measurements)
        :param actual_measurement_array: The actual measurements.
        :param actual_stdev_array: The actual standard deviations.
        :param use_pair_measurement: Whether to use pair measurements. (default false)
        :param log_noise: Whether to use log noise. (default false but we should probably use it)
        """
        _logger.info(f"Cost function with measurement type of {measurement_type}")
        self.measurement_type = measurement_type
        self.dot_inputs_array = dot_inputs_array
        self.actual_measurement_array = actual_measurement_array
        self.actual_measurement_array2 = actual_measurement_array**2
        self.actual_stdev_array = actual_stdev_array
        self.actual_stdev_array2 = actual_stdev_array**2

        self.use_log_noise = log_noise
        if self.use_log_noise:
            self.log_actual, self.log_denom2 = self._log_terms(
                self.actual_measurement_array, self.actual_stdev_array
            )
        else:
            # The log terms are only read when log noise is enabled. Zero or negative
            # measurements (e.g. phase data for pair measurements) would otherwise emit
            # spurious divide/invalid RuntimeWarnings here for values nobody uses.
            with numpy.errstate(divide="ignore", invalid="ignore"):
                self.log_actual, self.log_denom2 = self._log_terms(
                    self.actual_measurement_array, self.actual_stdev_array
                )
        self.use_pair_measurement = use_pair_measurement

    @staticmethod
    def _log_terms(actual, stdev):
        """Return ``(log(actual), (log(stdev + actual) - log(actual)) ** 2)``."""
        log_actual = numpy.log(actual)
        log_denom2 = (numpy.log(stdev + actual) - log_actual) ** 2
        return log_actual, log_denom2

    def __call__(self, dipoles_to_test):
        """
        Evaluate the cost for the candidate dipole configurations.

        :return: Per-candidate RMS of the scaled residuals, reduced over the last axis.
        :raises ValueError: If the measurement type is not supported.
        """
        if self.use_pair_measurement:
            return self._pair_cost(dipoles_to_test)
        return self._single_cost(dipoles_to_test)

    def _pair_cost(self, dipoles_to_test):
        """Cost from the angular distance between ``signarg`` of the computed pair spectra and the measurements."""
        # We're going to just use phase data, rather than correlation data for now.
        # We'll probably need to do some re-architecting later to get the phase vs
        # correlation flag to propagate here. Log noise is not applied to phase data,
        # which is wrapped but linear.
        if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
            vals = pdme.util.fast_nonlocal_spectrum.fast_s_spin_qubit_tarucha_nonlocal_dipoleses(
                self.dot_inputs_array, dipoles_to_test
            )
        elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
            vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal_dipoleses(
                self.dot_inputs_array, dipoles_to_test
            )
        else:
            raise ValueError(f"Unsupported measurement type {self.measurement_type}")

        sign_vals = pdme.util.fast_nonlocal_spectrum.signarg(vals)
        diffs = (
            kalpaa.common.angles.shortest_angular_distance(
                sign_vals, self.actual_measurement_array
            )
            ** 2
        )
        scaled_diffs = diffs / self.actual_stdev_array2
        return numpy.sqrt(scaled_diffs.mean(axis=-1))

    def _single_cost(self, dipoles_to_test):
        """Cost from single-dot values, optionally compared in log space."""
        if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
            vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
                self.dot_inputs_array, dipoles_to_test
            )
        elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
            vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
                self.dot_inputs_array, dipoles_to_test
            )
        else:
            raise ValueError(f"Unsupported measurement type {self.measurement_type}")

        if self.use_log_noise:
            diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
        else:
            diffs = (
                (vals - self.actual_measurement_array) ** 2
            ) / self.actual_stdev_array2

        return numpy.sqrt(diffs.mean(axis=-1))
|
|
|
|
|
|
# Key under which read_bin_csv also stores the frequency column inside the
# single-dot dictionary it returns (kept for legacy callers).
RETURNED_FREQUENCIES_KEY: str = "frequencies"
|
|
|
|
|
|
def read_dots_json(json_file: pathlib.Path) -> typing.Dict:
    """
    Read a dots JSON file and return a mapping from dot label to its ``r`` entry.

    The file is decoded as UTF-8 (the encoding mandated for JSON by RFC 8259), so the
    result no longer depends on the platform's locale encoding.

    :param json_file: Path to the JSON file; its content is reshaped by _reshape_dots_dict.
    :return: ``{label: r}`` for every entry in the file.
    :raises Exception: Anything raised while opening, decoding or reshaping is logged
        and re-raised unchanged.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as file:
            return _reshape_dots_dict(json.load(file))
    except Exception as e:
        _logger.error(
            f"Had a bad time reading the dots file {json_file}, sorry.", exc_info=e
        )
        # Bare raise keeps the original traceback intact.
        raise
|
|
|
|
|
|
def _reshape_dots_dict(dots_dict: typing.Sequence[typing.Dict]) -> typing.Dict:
|
|
ret = {}
|
|
for dot in dots_dict:
|
|
ret[dot["label"]] = dot["r"]
|
|
return ret
|
|
|
|
|
|
# Column-header patterns for binned CSV files, applied with ``re.match`` in
# _parse_bin_header.
#
# Single-dot (APSD) header, e.g. ``APSD_V_dot1_mean``. The measurement type is
# restricted to word characters without underscores: it is always a short code such
# as ``V`` or ``Ex``. With a plain greedy ``\w+`` the type swallowed part of any dot
# name containing an underscore (``APSD_V_dot_1_mean`` parsed as type ``V_dot`` and
# dot ``1``), which then failed in short_string_to_measurement_type.
BINNED_HEADER_REGEX = r"\s*APSD_(?P<measurement_type>[^\W_]+)_(?P<dot_name>\w+)_(?P<summary_stat>mean|stdev)\s*"

# Dot-pair (CPSD) header, e.g. ``CPSD_phase_V_dot1_dot2_mean``. Dot names that
# themselves contain underscores remain ambiguous here, because the split between
# the two names cannot be recovered from the header alone.
PAIR_MEASUREMENT_BINNED_HEADER_REGEX = r"\s*CPSD_(?P<cpsd_type>correlation|phase)_(?P<measurement_type>[^\W_]+)_(?P<dot_name>\w+)_(?P<dot_name2>\w+)_(?P<summary_stat>mean|stdev)\s*"
|
|
|
|
|
|
@dataclasses.dataclass
class ParsedBinHeader:
    """The pieces of one binned-CSV column header, as produced by _parse_bin_header."""

    # Raw header text the remaining fields were parsed from.
    original_field: str
    measurement_type: MeasurementTypeEnum
    # "mean" or "stdev", as captured by the header regexes.
    summary_stat: str
    dot_name: str
    # The two fields below are only populated for pair (CPSD) headers.
    dot_name2: typing.Optional[str] = dataclasses.field(default=None)
    cpsd_type: typing.Optional[
        typing.Literal["correlation", "phase"]
    ] = dataclasses.field(default=None)

    @property
    def pair(self) -> bool:
        """True when the header names a second dot, i.e. describes a dot-pair measurement."""
        has_second_dot = self.dot_name2 is not None
        return has_second_dot
|
|
|
|
|
|
def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
    """
    Parse a binned header field into a ParsedBinHeader object.

    The single-dot (APSD) pattern is tried first, then the dot-pair (CPSD) pattern.
    Return None if the field matches neither. A recognised header whose measurement
    type code is unknown raises ValueError via short_string_to_measurement_type.
    """
    single_match = re.match(BINNED_HEADER_REGEX, field)
    if single_match is not None:
        single_groups = single_match.groupdict()
        return ParsedBinHeader(
            original_field=field,
            measurement_type=short_string_to_measurement_type(
                single_groups["measurement_type"]
            ),
            dot_name=single_groups["dot_name"],
            summary_stat=single_groups["summary_stat"],
        )

    pair_match = re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field)
    if pair_match is None:
        _logger.debug(f"Could not parse {field=}")
        return None

    pair_groups = pair_match.groupdict()
    # The regex alternation only admits these two values; the cast informs the type checker.
    cpsd_type = typing.cast(
        typing.Literal["correlation", "phase"], pair_groups["cpsd_type"]
    )
    return ParsedBinHeader(
        original_field=field,
        measurement_type=short_string_to_measurement_type(
            pair_groups["measurement_type"]
        ),
        dot_name=pair_groups["dot_name"],
        dot_name2=pair_groups["dot_name2"],
        cpsd_type=cpsd_type,
        summary_stat=pair_groups["summary_stat"],
    )
|
|
|
|
|
|
@dataclasses.dataclass
class CSV_BinnedData:
    """Everything read_bin_csv extracts from a binned CSV file."""

    # Measurement type detected from the column headers.
    measurement_type: MeasurementTypeEnum
    # dot name -> {summary stat -> values}; also carries the legacy frequencies key.
    single_dot_dict: typing.Dict[str, typing.Any]
    # (dot name, second dot name) -> {summary stat -> values}.
    pair_dot_dict: typing.Dict[typing.Tuple[str, str], typing.Any]
    # Frequencies from the first CSV column, one per data row.
    freqs: typing.Sequence[float]
|
|
|
|
|
|
def read_bin_csv(
    csv_file: pathlib.Path,
) -> CSV_BinnedData:
    """
    Read a binned csv file and return the measurement type and the binned data.

    The first column is read as the frequencies. Every other column whose header is
    recognised by _parse_bin_header is collected under its dot (or dot pair) and its
    summary statistic. Unrecognised columns are skipped with a warning.

    :param csv_file: The csv file to read.
    :return: A CSV_BinnedData containing:
        - the measurement type;
        - the single-dot data (``{dot_name: {summary_stat: [values]}}`` plus the legacy
          ``RETURNED_FREQUENCIES_KEY`` list);
        - the dot-pair data (``{(dot_name, dot_name2): {summary_stat: [values]}}``);
        - the frequencies.
    :raises ValueError: If the file has no header row, if no column header could be
        parsed, or if a single-dot column uses the reserved frequencies key as its dot
        name. Any exception raised while reading is logged and re-raised.
    """
    measurement_type: typing.Optional[MeasurementTypeEnum] = None
    try:
        with open(csv_file, "r", newline="") as file:
            reader = csv.DictReader(file)
            fields = reader.fieldnames

            if fields is None:
                raise ValueError(
                    f"Really wanted our fields for file {file=} to be non-None, but they're None"
                )
            freq_field = fields[0]
            remaining_fields = fields[1:]
            _logger.debug(f"Going to read frequencies from {freq_field=}")

            freq_list: typing.List[float] = []
            aggregated_dict: typing.Dict[str, typing.Any] = {
                RETURNED_FREQUENCIES_KEY: []
            }
            pair_aggregated_dict: typing.Dict[typing.Tuple[str, str], typing.Any] = {}
            # Header -> the list that column's values are appended to. Resolving this
            # once keeps the per-row loop free of header logic. Repeated headers share
            # one target list.
            column_targets: typing.Dict[str, typing.List[float]] = {}

            for field in remaining_fields:
                parsed_header = _parse_bin_header(field)
                if parsed_header is None:
                    _logger.warning(f"Could not parse {field=}")
                    continue

                if parsed_header.pair:
                    if parsed_header.dot_name2 is None:
                        raise ValueError(
                            f"Pair measurement {field=} has no dot_name2, but it should"
                        )
                    dot_names = (parsed_header.dot_name, parsed_header.dot_name2)
                    stats = pair_aggregated_dict.setdefault(dot_names, {})
                else:
                    if parsed_header.dot_name == RETURNED_FREQUENCIES_KEY:
                        # Would collide with the legacy frequencies list in the same dict.
                        raise ValueError(
                            f"Dot name in {field=} clashes with the reserved key {RETURNED_FREQUENCIES_KEY!r}"
                        )
                    stats = aggregated_dict.setdefault(parsed_header.dot_name, {})
                column_targets[field] = stats.setdefault(parsed_header.summary_stat, [])

                # Realistically we'll always have the same measurement type, but this
                # warning may help catch cases where that didn't happen. The last parsed
                # header's type wins.
                if (
                    measurement_type is not None
                    and measurement_type != parsed_header.measurement_type
                ):
                    _logger.warning(
                        f"Attempted to set already set measurement type {measurement_type}. Allowing the switch to {parsed_header.measurement_type}, but it's problematic"
                    )
                measurement_type = parsed_header.measurement_type

            _logger.debug("finished parsing headers")

            for row in reader:
                freq = float(row[freq_field].strip())
                freq_list.append(freq)
                # don't need to set, but keep for legacy
                aggregated_dict[RETURNED_FREQUENCIES_KEY].append(freq)
                for field, target in column_targets.items():
                    target.append(float(row[field].strip()))

        if measurement_type is None:
            raise ValueError(
                f"For some reason {measurement_type=} is None? We want to know our measurement type."
            )
        _logger.debug(f"Read binned data with measurement type {measurement_type}")

        return CSV_BinnedData(
            measurement_type=measurement_type,
            single_dot_dict=aggregated_dict,
            freqs=freq_list,
            pair_dot_dict=pair_aggregated_dict,
        )
    except Exception as e:
        _logger.error(
            f"Had a bad time reading the binned data {csv_file}, sorry.", exc_info=e
        )
        raise
|
|
|
|
|
|
@dataclasses.dataclass
class BinnedData:
    """
    Dot locations plus binned CSV data.

    Builds measurement groups and deepdog cost-function filters for selected dots or
    dot pairs.
    """

    # dot label -> the dot's ``r`` entry from the dots JSON file.
    dots_dict: typing.Dict
    # dot name -> {summary stat -> values}; also carries the legacy frequencies key.
    csv_dict: typing.Dict[str, typing.Any]
    measurement_type: MeasurementTypeEnum
    # (dot name, second dot name) -> {summary stat -> values}.
    pair_dict: typing.Dict[typing.Tuple[str, str], typing.Any]
    freq_list: typing.Sequence[float]

    def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup:
        """
        Build a single-dot MeasurementGroup from the dot's per-frequency means and stdevs.

        :raises KeyError: If the dot is missing from the dots file or the CSV data.
        """
        if dot_name not in self.dots_dict:
            raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
        if dot_name not in self.csv_dict:
            raise KeyError(f"Could not find {dot_name=} in {self.csv_dict=}")

        dot_r = self.dots_dict[dot_name]
        freqs = self.freq_list
        vs = self.csv_dict[dot_name]["mean"]
        stdevs = self.csv_dict[dot_name]["stdev"]

        return MeasurementGroup(
            [
                Measurement(
                    dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
                    stdev=stdev,
                )
                for f, v, stdev in zip(freqs, vs, stdevs)
            ],
            _measurement_type=self.measurement_type,
        )

    def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
        """
        Return the stdev column recorded for ``dot_name``.

        :raises KeyError: If the dot is missing from the dots file or the CSV data.
        """
        if dot_name not in self.dots_dict:
            raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
        if dot_name not in self.csv_dict:
            raise KeyError(f"Could not find {dot_name=} in {self.csv_dict=}")

        return self.csv_dict[dot_name]["stdev"]

    def _pair_to_measurements(
        self, dot_pair_name: typing.Tuple[str, str]
    ) -> MeasurementGroup:
        """
        Build a dot-pair MeasurementGroup from the pair's per-frequency means and stdevs.

        :raises KeyError: If the pair is missing from the CSV data, or either dot is
            missing from the dots file.
        """
        if dot_pair_name not in self.pair_dict:
            raise KeyError(f"Could not find {dot_pair_name=} in {self.pair_dict=}")

        dot_name1, dot_name2 = dot_pair_name
        if dot_name1 not in self.dots_dict:
            raise KeyError(f"Could not find {dot_name1=} in {self.dots_dict=}")
        if dot_name2 not in self.dots_dict:
            raise KeyError(f"Could not find {dot_name2=} in {self.dots_dict=}")

        dot_r1 = self.dots_dict[dot_name1]
        dot_r2 = self.dots_dict[dot_name2]
        freqs = self.freq_list
        vs = self.pair_dict[dot_pair_name]["mean"]
        stdevs = self.pair_dict[dot_pair_name]["stdev"]

        return MeasurementGroup(
            [
                Measurement(
                    dot_measurement=None,
                    dot_pair_measurement=pdme.measurement.DotPairMeasurement(
                        f=f, v=v, r1=dot_r1, r2=dot_r2
                    ),
                    stdev=stdev,
                )
                for f, v, stdev in zip(freqs, vs, stdevs)
            ],
            _measurement_type=self.measurement_type,
        )

    def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup:
        """Concatenate the single-dot measurement groups of ``dot_names``, in order."""
        _logger.debug(f"Constructing measurements for dots {dot_names=}")
        ret = MeasurementGroup([], self.measurement_type)
        for dot_name in dot_names:
            ret = ret.add(self._dot_to_measurements(dot_name))
        return ret

    def pair_measurements(
        self, dot_pair_names: typing.Sequence[typing.Tuple[str, str]]
    ) -> MeasurementGroup:
        """Concatenate the dot-pair measurement groups of ``dot_pair_names``, in order."""
        _logger.debug(f"Constructing measurements for dot pairs {dot_pair_names=}")
        ret = MeasurementGroup([], self.measurement_type)
        for dot_pair_name in dot_pair_names:
            ret = ret.add(self._pair_to_measurements(dot_pair_name))
        return ret

    def _get_measurement_from_dot_name_or_pair(
        self,
        dot_names_or_pairs: typing.Union[
            typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
        ],
    ) -> MeasurementGroup:
        """
        Build a single-dot group from a sequence of names, or a pair group from a sequence of tuples.

        The first element decides which kind is expected; all elements must then be of
        that kind.

        :raises ValueError: If the sequence is empty or mixes names and tuples.
        """
        if not dot_names_or_pairs:
            # Previously an IndexError from the [0] lookup below.
            raise ValueError("Need at least one dot name or dot pair to build measurements")

        if isinstance(dot_names_or_pairs[0], str):
            _logger.debug("first item was a string, assuming we're specifying strings")
            # we expect all strings, fail if otherwise
            names = []
            for dn in dot_names_or_pairs:
                if not isinstance(dn, str):
                    raise ValueError(f"Expected all strings in {dot_names_or_pairs=}")
                names.append(dn)
            _logger.debug(f"Constructing measurements for dots {names=}")
            return self.measurements(names)
        else:
            _logger.debug("trying out pairs")
            pairs = []
            for dn in dot_names_or_pairs:
                if not isinstance(dn, tuple):
                    raise ValueError(f"Expected all tuples in {dot_names_or_pairs=}")
                pairs.append(dn)
            _logger.debug(f"Constructing measurements for dot pairs {pairs=}")
            return self.pair_measurements(pairs)

    def cost_function_filter(
        self,
        dot_names_or_pairs: typing.Union[
            typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
        ],
        target_cost: float,
    ):
        """Wrap the group's relative-error cost function in a deepdog CostFunctionTargetFilter with ``target_cost``."""
        measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs)
        cost_function = measurements.cost_function()
        return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
            cost_function, target_cost
        )

    def stdev_cost_function_filter(
        self,
        dot_names_or_pairs: typing.Union[
            typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
        ],
        target_cost: float,
        use_log_noise: bool = False,
    ):
        """Wrap the group's stdev-scaled cost function in a deepdog CostFunctionTargetFilter with ``target_cost``."""
        measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs)
        cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
        return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
            cost_function, target_cost
        )
|
|
|
|
|
|
def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> BinnedData:
    """
    Load the dots JSON file, then the binned CSV file, and combine them into a BinnedData.

    :param json_file: Dots file, read with read_dots_json.
    :param csv_file: Binned data file, read with read_bin_csv.
    """
    dots_by_label = read_dots_json(json_file)
    binned = read_bin_csv(csv_file)
    return BinnedData(
        dots_dict=dots_by_label,
        csv_dict=binned.single_dot_dict,
        measurement_type=binned.measurement_type,
        pair_dict=binned.pair_dot_dict,
        freq_list=binned.freqs,
    )
|