From a7508b8906923a579dfa05edc22963d2f0102be9 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 27 Mar 2022 17:23:01 -0500 Subject: [PATCH] feat!: adds pair inputs to array function --- pdme/measurement/__init__.py | 4 +- pdme/measurement/input_types.py | 22 +++++++++ pdme/measurement/oscillating_dipole.py | 17 +------ .../measurement/test_measurement_to_arrays.py | 45 +++++++++++++++++++ 4 files changed, 72 insertions(+), 16 deletions(-) create mode 100644 pdme/measurement/input_types.py create mode 100644 tests/measurement/test_measurement_to_arrays.py diff --git a/pdme/measurement/__init__.py b/pdme/measurement/__init__.py index 800b9b5..c15aca6 100644 --- a/pdme/measurement/__init__.py +++ b/pdme/measurement/__init__.py @@ -1,5 +1,7 @@ from pdme.measurement.dot_measure import DotMeasurement, DotRangeMeasurement from pdme.measurement.dot_pair_measure import DotPairMeasurement, DotPairRangeMeasurement from pdme.measurement.oscillating_dipole import OscillatingDipole, OscillatingDipoleArrangement +from pdme.measurement.input_types import DotInput, DotPairInput -__all__ = ['DotMeasurement', 'DotRangeMeasurement', 'DotPairMeasurement', 'DotPairRangeMeasurement', 'OscillatingDipole', 'OscillatingDipoleArrangement'] + +__all__ = ['DotMeasurement', 'DotRangeMeasurement', 'DotPairMeasurement', 'DotPairRangeMeasurement', 'OscillatingDipole', 'OscillatingDipoleArrangement', 'DotInput', 'DotPairInput'] diff --git a/pdme/measurement/input_types.py b/pdme/measurement/input_types.py new file mode 100644 index 0000000..ad8e121 --- /dev/null +++ b/pdme/measurement/input_types.py @@ -0,0 +1,22 @@ +import numpy.typing +from typing import Tuple, Sequence, Union +from pdme.measurement.dot_measure import DotRangeMeasurement +from pdme.measurement.dot_pair_measure import DotPairRangeMeasurement + + +DotInput = Tuple[numpy.typing.ArrayLike, float] +DotPairInput = Tuple[numpy.typing.ArrayLike, numpy.typing.ArrayLike, float] + + +def dot_inputs_to_array(dot_inputs: Sequence[DotInput]) -> numpy.ndarray: + return numpy.array([numpy.append(numpy.array(input[0]), input[1]) for input in dot_inputs]) + + +def dot_pair_inputs_to_array(pair_inputs: Sequence[DotPairInput]) -> numpy.ndarray: + return numpy.array([[numpy.append(numpy.array(input[0]), input[2]), numpy.append(numpy.array(input[1]), input[2])] for input in pair_inputs]) + + +def dot_range_measurements_low_high_arrays(dot_range_measurements: Union[Sequence[DotRangeMeasurement], Sequence[DotPairRangeMeasurement]]) -> Tuple[numpy.ndarray, numpy.ndarray]: + lows = [measurement.v_low for measurement in dot_range_measurements] + highs = [measurement.v_high for measurement in dot_range_measurements] + return (numpy.array(lows), numpy.array(highs)) diff --git a/pdme/measurement/oscillating_dipole.py b/pdme/measurement/oscillating_dipole.py index 8aea27b..8edaefe 100644 --- a/pdme/measurement/oscillating_dipole.py +++ b/pdme/measurement/oscillating_dipole.py @@ -1,13 +1,10 @@ from dataclasses import dataclass import numpy import numpy.typing -from typing import Sequence, List, Tuple +from typing import Sequence, List from pdme.measurement.dot_measure import DotMeasurement, DotRangeMeasurement from pdme.measurement.dot_pair_measure import DotPairMeasurement, DotPairRangeMeasurement - - -DotInput = Tuple[numpy.typing.ArrayLike, float] -DotPairInput = Tuple[numpy.typing.ArrayLike, numpy.typing.ArrayLike, float] +from pdme.measurement.input_types import DotInput, DotPairInput @dataclass @@ -59,16 +56,6 @@ class OscillatingDipole(): return self._alpha(r1) * self._alpha(r2) * self._b(f) -def dot_inputs_to_array(dot_inputs: Sequence[DotInput]) -> numpy.ndarray: - return numpy.array([numpy.append(numpy.array(input[0]), input[1]) for input in dot_inputs]) - - -def dot_range_measurements_low_high_arrays(dot_range_measurements: Sequence[DotRangeMeasurement]) -> Tuple[numpy.ndarray, numpy.ndarray]: - lows = [measurement.v_low for measurement in dot_range_measurements] - highs = [measurement.v_high for measurement in dot_range_measurements] - return (numpy.array(lows), numpy.array(highs)) - - class OscillatingDipoleArrangement(): ''' A collection of oscillating dipoles, which we are interested in being able to characterise. diff --git a/tests/measurement/test_measurement_to_arrays.py b/tests/measurement/test_measurement_to_arrays.py new file mode 100644 index 0000000..b7c7ecf --- /dev/null +++ b/tests/measurement/test_measurement_to_arrays.py @@ -0,0 +1,45 @@ +import numpy +import pdme.measurement.input_types +from pdme.measurement.dot_measure import DotRangeMeasurement +from pdme.measurement.dot_pair_measure import DotPairRangeMeasurement + + +def test_inputs_to_array(): + i1 = ([1, 2, 3], 5) + i2 = ([-1, 4, -2], 9) + + actual = pdme.measurement.input_types.dot_inputs_to_array([i1, i2]) + expected = numpy.array([[1, 2, 3, 5], [-1, 4, -2, 9]]) + + numpy.testing.assert_allclose(actual, expected, err_msg="Didn't convert to array properly") + + +def test_pair_inputs_to_array(): + i1 = ([1, 2, 3], [-1, 4, -2], 5) + i2 = ([-1, 4, -2], [6, 7, 8], 9) + + actual = pdme.measurement.input_types.dot_pair_inputs_to_array([i1, i2]) + expected = numpy.array([ + [[1, 2, 3, 5], [-1, 4, -2, 5]], + [[-1, 4, -2, 9], [6, 7, 8, 9]], + ]) + + numpy.testing.assert_allclose(actual, expected, err_msg="Didn't convert to array properly") + + +def test_ranges_to_array(): + m1 = DotRangeMeasurement(1, 2, 100, 1000) + m2 = DotRangeMeasurement(0.5, 3, 100, 1000) + + actual_lows, actual_highs = pdme.measurement.input_types.dot_range_measurements_low_high_arrays([m1, m2]) + numpy.testing.assert_allclose(actual_lows, [1, 0.5], err_msg="Lows were wrong") + numpy.testing.assert_allclose(actual_highs, [2, 3], err_msg="Highs were wrong") + + +def test_pair_ranges_to_array(): + m1 = DotPairRangeMeasurement(1, 2, 100, 1000, 10000) + m2 = DotPairRangeMeasurement(0.5, 3, 100, 1000, 10000) + + actual_lows, actual_highs = pdme.measurement.input_types.dot_range_measurements_low_high_arrays([m1, m2]) + numpy.testing.assert_allclose(actual_lows, [1, 0.5], err_msg="Lows were wrong") + numpy.testing.assert_allclose(actual_highs, [2, 3], err_msg="Highs were wrong")