feat!: changes api of time series method to be less dumb, breaks existing calls
This commit is contained in:
parent
b4b25974c9
commit
450f0d5730
@ -60,7 +60,11 @@ def read_files(dipoles_file, dots_file):
|
||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||
)
|
||||
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
|
||||
_logger.info(dipoles)
|
||||
dots = tantri.cli.file_importer.read_dots_json_file(dots_file)
|
||||
for dipole in dipoles:
|
||||
_logger.info(dipole)
|
||||
for dot in dots:
|
||||
_logger.info(dot)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
@ -23,7 +23,7 @@ class DotPosition:
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dipole:
|
||||
class WrappedDipole:
|
||||
# assumed len 3
|
||||
p: numpy.ndarray
|
||||
s: numpy.ndarray
|
||||
@ -94,10 +94,29 @@ class Dipole:
|
||||
self.state *= -1
|
||||
|
||||
|
||||
def get_wrapped_dipoles(
|
||||
dipole_tos: typing.Sequence[DipoleTO],
|
||||
dots: typing.Sequence[DotPosition],
|
||||
measurement_type: DipoleMeasurementType,
|
||||
) -> typing.Sequence[WrappedDipole]:
|
||||
return [
|
||||
WrappedDipole(
|
||||
p=dipole_to.p,
|
||||
s=dipole_to.s,
|
||||
w=dipole_to.w,
|
||||
dot_positions=dots,
|
||||
measurement_type=measurement_type,
|
||||
)
|
||||
for dipole_to in dipole_tos
|
||||
]
|
||||
|
||||
|
||||
class DipoleTimeSeries:
|
||||
def __init__(
|
||||
self,
|
||||
dipoles: typing.Sequence[Dipole],
|
||||
dipoles: typing.Sequence[DipoleTO],
|
||||
dots: typing.Sequence[DotPosition],
|
||||
measurement_type: DipoleMeasurementType,
|
||||
dt: float,
|
||||
rng_to_use: typing.Optional[numpy.random.Generator] = None,
|
||||
):
|
||||
@ -107,7 +126,7 @@ class DipoleTimeSeries:
|
||||
else:
|
||||
self.rng = rng_to_use
|
||||
|
||||
self.dipoles = dipoles
|
||||
self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
|
||||
self.state = 0
|
||||
self.dt = dt
|
||||
|
||||
@ -125,7 +144,7 @@ class DipoleTimeSeries:
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Dipole",
|
||||
"WrappedDipole",
|
||||
"DipoleTimeSeries",
|
||||
"DipoleTO",
|
||||
]
|
||||
|
@ -1,4 +1,9 @@
|
||||
from tantri.dipoles import Dipole, DipoleTimeSeries, DotPosition, DipoleMeasurementType
|
||||
from tantri.dipoles import (
|
||||
DipoleTO,
|
||||
DipoleTimeSeries,
|
||||
DotPosition,
|
||||
DipoleMeasurementType,
|
||||
)
|
||||
import numpy
|
||||
|
||||
|
||||
@ -7,36 +12,34 @@ def test_dipoles_1(snapshot):
|
||||
rng = numpy.random.default_rng(1234)
|
||||
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
|
||||
|
||||
d1 = Dipole(
|
||||
d1 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([5, 3, 2]),
|
||||
15,
|
||||
dots,
|
||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
||||
)
|
||||
d2 = Dipole(
|
||||
d2 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([-2, 3, 2]),
|
||||
0.1,
|
||||
dots,
|
||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
||||
)
|
||||
d3 = Dipole(
|
||||
d3 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([2, 1, 2]),
|
||||
6,
|
||||
dots,
|
||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
||||
)
|
||||
d4 = Dipole(
|
||||
d4 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([-5, -5, 2]),
|
||||
50,
|
||||
dots,
|
||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
||||
)
|
||||
|
||||
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng)
|
||||
ts_gen = DipoleTimeSeries(
|
||||
[d1, d2, d3, d4],
|
||||
dots,
|
||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
||||
0.01,
|
||||
rng_to_use=rng,
|
||||
)
|
||||
time_series = [ts_gen.transition() for i in range(50)]
|
||||
|
||||
assert time_series == snapshot
|
||||
@ -47,36 +50,34 @@ def test_dipoles_1_field(snapshot):
|
||||
rng = numpy.random.default_rng(1234)
|
||||
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
|
||||
|
||||
d1 = Dipole(
|
||||
d1 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([5, 3, 2]),
|
||||
15,
|
||||
dots,
|
||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
||||
)
|
||||
d2 = Dipole(
|
||||
d2 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([-2, 3, 2]),
|
||||
0.1,
|
||||
dots,
|
||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
||||
)
|
||||
d3 = Dipole(
|
||||
d3 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([2, 1, 2]),
|
||||
6,
|
||||
dots,
|
||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
||||
)
|
||||
d4 = Dipole(
|
||||
d4 = DipoleTO(
|
||||
numpy.array([0, 0, 10]),
|
||||
numpy.array([-5, -5, 2]),
|
||||
50,
|
||||
dots,
|
||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
||||
)
|
||||
|
||||
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng)
|
||||
ts_gen = DipoleTimeSeries(
|
||||
[d1, d2, d3, d4],
|
||||
dots,
|
||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
||||
0.01,
|
||||
rng_to_use=rng,
|
||||
)
|
||||
time_series = [ts_gen.transition() for i in range(50)]
|
||||
|
||||
assert time_series == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user