feat!: changes api of time series method to be less dumb, breaks existing calls

This commit is contained in:
Deepak Mallubhotla 2024-04-20 22:16:31 -05:00
parent b4b25974c9
commit 450f0d5730
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
3 changed files with 56 additions and 32 deletions

View File

@ -60,7 +60,11 @@ def read_files(dipoles_file, dots_file):
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
)
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
_logger.info(dipoles)
dots = tantri.cli.file_importer.read_dots_json_file(dots_file)
for dipole in dipoles:
_logger.info(dipole)
for dot in dots:
_logger.info(dot)
@cli.command()

View File

@ -23,7 +23,7 @@ class DotPosition:
@dataclass
class Dipole:
class WrappedDipole:
# assumed len 3
p: numpy.ndarray
s: numpy.ndarray
@ -94,10 +94,29 @@ class Dipole:
self.state *= -1
def get_wrapped_dipoles(
dipole_tos: typing.Sequence[DipoleTO],
dots: typing.Sequence[DotPosition],
measurement_type: DipoleMeasurementType,
) -> typing.Sequence[WrappedDipole]:
return [
WrappedDipole(
p=dipole_to.p,
s=dipole_to.s,
w=dipole_to.w,
dot_positions=dots,
measurement_type=measurement_type,
)
for dipole_to in dipole_tos
]
class DipoleTimeSeries:
def __init__(
self,
dipoles: typing.Sequence[Dipole],
dipoles: typing.Sequence[DipoleTO],
dots: typing.Sequence[DotPosition],
measurement_type: DipoleMeasurementType,
dt: float,
rng_to_use: typing.Optional[numpy.random.Generator] = None,
):
@ -107,7 +126,7 @@ class DipoleTimeSeries:
else:
self.rng = rng_to_use
self.dipoles = dipoles
self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
self.state = 0
self.dt = dt
@ -125,7 +144,7 @@ class DipoleTimeSeries:
__all__ = [
"Dipole",
"WrappedDipole",
"DipoleTimeSeries",
"DipoleTO",
]

View File

@ -1,4 +1,9 @@
from tantri.dipoles import Dipole, DipoleTimeSeries, DotPosition, DipoleMeasurementType
from tantri.dipoles import (
DipoleTO,
DipoleTimeSeries,
DotPosition,
DipoleMeasurementType,
)
import numpy
@ -7,36 +12,34 @@ def test_dipoles_1(snapshot):
rng = numpy.random.default_rng(1234)
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
d1 = Dipole(
d1 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([5, 3, 2]),
15,
dots,
DipoleMeasurementType.ELECTRIC_POTENTIAL,
)
d2 = Dipole(
d2 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([-2, 3, 2]),
0.1,
dots,
DipoleMeasurementType.ELECTRIC_POTENTIAL,
)
d3 = Dipole(
d3 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([2, 1, 2]),
6,
dots,
DipoleMeasurementType.ELECTRIC_POTENTIAL,
)
d4 = Dipole(
d4 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([-5, -5, 2]),
50,
dots,
DipoleMeasurementType.ELECTRIC_POTENTIAL,
)
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng)
ts_gen = DipoleTimeSeries(
[d1, d2, d3, d4],
dots,
DipoleMeasurementType.ELECTRIC_POTENTIAL,
0.01,
rng_to_use=rng,
)
time_series = [ts_gen.transition() for i in range(50)]
assert time_series == snapshot
@ -47,36 +50,34 @@ def test_dipoles_1_field(snapshot):
rng = numpy.random.default_rng(1234)
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
d1 = Dipole(
d1 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([5, 3, 2]),
15,
dots,
DipoleMeasurementType.X_ELECTRIC_FIELD,
)
d2 = Dipole(
d2 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([-2, 3, 2]),
0.1,
dots,
DipoleMeasurementType.X_ELECTRIC_FIELD,
)
d3 = Dipole(
d3 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([2, 1, 2]),
6,
dots,
DipoleMeasurementType.X_ELECTRIC_FIELD,
)
d4 = Dipole(
d4 = DipoleTO(
numpy.array([0, 0, 10]),
numpy.array([-5, -5, 2]),
50,
dots,
DipoleMeasurementType.X_ELECTRIC_FIELD,
)
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng)
ts_gen = DipoleTimeSeries(
[d1, d2, d3, d4],
dots,
DipoleMeasurementType.X_ELECTRIC_FIELD,
0.01,
rng_to_use=rng,
)
time_series = [ts_gen.transition() for i in range(50)]
assert time_series == snapshot