From 450f0d573025f50bf03cd326f0a74080f2253a06 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sat, 20 Apr 2024 22:16:31 -0500 Subject: [PATCH] feat!: changes api of time series method to be less dumb, breaks existing calls --- tantri/cli/__init__.py | 6 +++- tantri/dipoles/__init__.py | 27 ++++++++++++++--- tests/dipoles/test_dipoles.py | 55 ++++++++++++++++++----------------- 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/tantri/cli/__init__.py b/tantri/cli/__init__.py index 3f7ec7a..658e50c 100755 --- a/tantri/cli/__init__.py +++ b/tantri/cli/__init__.py @@ -60,7 +60,11 @@ def read_files(dipoles_file, dots_file): f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]" ) dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file) - _logger.info(dipoles) + dots = tantri.cli.file_importer.read_dots_json_file(dots_file) + for dipole in dipoles: + _logger.info(dipole) + for dot in dots: + _logger.info(dot) @cli.command() diff --git a/tantri/dipoles/__init__.py b/tantri/dipoles/__init__.py index a22c074..e17696a 100755 --- a/tantri/dipoles/__init__.py +++ b/tantri/dipoles/__init__.py @@ -23,7 +23,7 @@ class DotPosition: @dataclass -class Dipole: +class WrappedDipole: # assumed len 3 p: numpy.ndarray s: numpy.ndarray @@ -94,10 +94,29 @@ class Dipole: self.state *= -1 +def get_wrapped_dipoles( + dipole_tos: typing.Sequence[DipoleTO], + dots: typing.Sequence[DotPosition], + measurement_type: DipoleMeasurementType, +) -> typing.Sequence[WrappedDipole]: + return [ + WrappedDipole( + p=dipole_to.p, + s=dipole_to.s, + w=dipole_to.w, + dot_positions=dots, + measurement_type=measurement_type, + ) + for dipole_to in dipole_tos + ] + + class DipoleTimeSeries: def __init__( self, - dipoles: typing.Sequence[Dipole], + dipoles: typing.Sequence[DipoleTO], + dots: typing.Sequence[DotPosition], + measurement_type: DipoleMeasurementType, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None, ): @@ -107,7 +126,7 @@ class DipoleTimeSeries: else: self.rng = rng_to_use - self.dipoles = dipoles + self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type) self.state = 0 self.dt = dt @@ -125,7 +144,7 @@ class DipoleTimeSeries: __all__ = [ - "Dipole", + "WrappedDipole", "DipoleTimeSeries", "DipoleTO", ] diff --git a/tests/dipoles/test_dipoles.py b/tests/dipoles/test_dipoles.py index f7af87c..0757622 100755 --- a/tests/dipoles/test_dipoles.py +++ b/tests/dipoles/test_dipoles.py @@ -1,4 +1,9 @@ -from tantri.dipoles import Dipole, DipoleTimeSeries, DotPosition, DipoleMeasurementType +from tantri.dipoles import ( + DipoleTO, + DipoleTimeSeries, + DotPosition, + DipoleMeasurementType, +) import numpy @@ -7,36 +12,34 @@ def test_dipoles_1(snapshot): rng = numpy.random.default_rng(1234) dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")] - d1 = Dipole( + d1 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([5, 3, 2]), 15, - dots, - DipoleMeasurementType.ELECTRIC_POTENTIAL, ) - d2 = Dipole( + d2 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([-2, 3, 2]), 0.1, - dots, - DipoleMeasurementType.ELECTRIC_POTENTIAL, ) - d3 = Dipole( + d3 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([2, 1, 2]), 6, - dots, - DipoleMeasurementType.ELECTRIC_POTENTIAL, ) - d4 = Dipole( + d4 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([-5, -5, 2]), 50, - dots, - DipoleMeasurementType.ELECTRIC_POTENTIAL, ) - ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng) + ts_gen = DipoleTimeSeries( + [d1, d2, d3, d4], + dots, + DipoleMeasurementType.ELECTRIC_POTENTIAL, + 0.01, + rng_to_use=rng, + ) time_series = [ts_gen.transition() for i in range(50)] assert time_series == snapshot @@ -47,36 +50,34 @@ def test_dipoles_1_field(snapshot): rng = numpy.random.default_rng(1234) dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")] - d1 = Dipole( + d1 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([5, 3, 2]), 15, - dots, - DipoleMeasurementType.X_ELECTRIC_FIELD, ) - d2 = Dipole( + d2 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([-2, 3, 2]), 0.1, - dots, - DipoleMeasurementType.X_ELECTRIC_FIELD, ) - d3 = Dipole( + d3 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([2, 1, 2]), 6, - dots, - DipoleMeasurementType.X_ELECTRIC_FIELD, ) - d4 = Dipole( + d4 = DipoleTO( numpy.array([0, 0, 10]), numpy.array([-5, -5, 2]), 50, - dots, - DipoleMeasurementType.X_ELECTRIC_FIELD, ) - ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng) + ts_gen = DipoleTimeSeries( + [d1, d2, d3, d4], + dots, + DipoleMeasurementType.X_ELECTRIC_FIELD, + 0.01, + rng_to_use=rng, + ) time_series = [ts_gen.transition() for i in range(50)] assert time_series == snapshot