feat!: changes api of time series method to be less dumb, breaks existing calls
This commit is contained in:
parent
b4b25974c9
commit
450f0d5730
@ -60,7 +60,11 @@ def read_files(dipoles_file, dots_file):
|
|||||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||||
)
|
)
|
||||||
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
|
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
|
||||||
_logger.info(dipoles)
|
dots = tantri.cli.file_importer.read_dots_json_file(dots_file)
|
||||||
|
for dipole in dipoles:
|
||||||
|
_logger.info(dipole)
|
||||||
|
for dot in dots:
|
||||||
|
_logger.info(dot)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
@ -23,7 +23,7 @@ class DotPosition:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Dipole:
|
class WrappedDipole:
|
||||||
# assumed len 3
|
# assumed len 3
|
||||||
p: numpy.ndarray
|
p: numpy.ndarray
|
||||||
s: numpy.ndarray
|
s: numpy.ndarray
|
||||||
@ -94,10 +94,29 @@ class Dipole:
|
|||||||
self.state *= -1
|
self.state *= -1
|
||||||
|
|
||||||
|
|
||||||
|
def get_wrapped_dipoles(
|
||||||
|
dipole_tos: typing.Sequence[DipoleTO],
|
||||||
|
dots: typing.Sequence[DotPosition],
|
||||||
|
measurement_type: DipoleMeasurementType,
|
||||||
|
) -> typing.Sequence[WrappedDipole]:
|
||||||
|
return [
|
||||||
|
WrappedDipole(
|
||||||
|
p=dipole_to.p,
|
||||||
|
s=dipole_to.s,
|
||||||
|
w=dipole_to.w,
|
||||||
|
dot_positions=dots,
|
||||||
|
measurement_type=measurement_type,
|
||||||
|
)
|
||||||
|
for dipole_to in dipole_tos
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DipoleTimeSeries:
|
class DipoleTimeSeries:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dipoles: typing.Sequence[Dipole],
|
dipoles: typing.Sequence[DipoleTO],
|
||||||
|
dots: typing.Sequence[DotPosition],
|
||||||
|
measurement_type: DipoleMeasurementType,
|
||||||
dt: float,
|
dt: float,
|
||||||
rng_to_use: typing.Optional[numpy.random.Generator] = None,
|
rng_to_use: typing.Optional[numpy.random.Generator] = None,
|
||||||
):
|
):
|
||||||
@ -107,7 +126,7 @@ class DipoleTimeSeries:
|
|||||||
else:
|
else:
|
||||||
self.rng = rng_to_use
|
self.rng = rng_to_use
|
||||||
|
|
||||||
self.dipoles = dipoles
|
self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
|
||||||
self.state = 0
|
self.state = 0
|
||||||
self.dt = dt
|
self.dt = dt
|
||||||
|
|
||||||
@ -125,7 +144,7 @@ class DipoleTimeSeries:
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Dipole",
|
"WrappedDipole",
|
||||||
"DipoleTimeSeries",
|
"DipoleTimeSeries",
|
||||||
"DipoleTO",
|
"DipoleTO",
|
||||||
]
|
]
|
||||||
|
@ -1,4 +1,9 @@
|
|||||||
from tantri.dipoles import Dipole, DipoleTimeSeries, DotPosition, DipoleMeasurementType
|
from tantri.dipoles import (
|
||||||
|
DipoleTO,
|
||||||
|
DipoleTimeSeries,
|
||||||
|
DotPosition,
|
||||||
|
DipoleMeasurementType,
|
||||||
|
)
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
@ -7,36 +12,34 @@ def test_dipoles_1(snapshot):
|
|||||||
rng = numpy.random.default_rng(1234)
|
rng = numpy.random.default_rng(1234)
|
||||||
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
|
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
|
||||||
|
|
||||||
d1 = Dipole(
|
d1 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([5, 3, 2]),
|
numpy.array([5, 3, 2]),
|
||||||
15,
|
15,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
|
||||||
)
|
)
|
||||||
d2 = Dipole(
|
d2 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([-2, 3, 2]),
|
numpy.array([-2, 3, 2]),
|
||||||
0.1,
|
0.1,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
|
||||||
)
|
)
|
||||||
d3 = Dipole(
|
d3 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([2, 1, 2]),
|
numpy.array([2, 1, 2]),
|
||||||
6,
|
6,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
|
||||||
)
|
)
|
||||||
d4 = Dipole(
|
d4 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([-5, -5, 2]),
|
numpy.array([-5, -5, 2]),
|
||||||
50,
|
50,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng)
|
ts_gen = DipoleTimeSeries(
|
||||||
|
[d1, d2, d3, d4],
|
||||||
|
dots,
|
||||||
|
DipoleMeasurementType.ELECTRIC_POTENTIAL,
|
||||||
|
0.01,
|
||||||
|
rng_to_use=rng,
|
||||||
|
)
|
||||||
time_series = [ts_gen.transition() for i in range(50)]
|
time_series = [ts_gen.transition() for i in range(50)]
|
||||||
|
|
||||||
assert time_series == snapshot
|
assert time_series == snapshot
|
||||||
@ -47,36 +50,34 @@ def test_dipoles_1_field(snapshot):
|
|||||||
rng = numpy.random.default_rng(1234)
|
rng = numpy.random.default_rng(1234)
|
||||||
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
|
dots = [DotPosition(numpy.array([0, 0, 0]), "dot1")]
|
||||||
|
|
||||||
d1 = Dipole(
|
d1 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([5, 3, 2]),
|
numpy.array([5, 3, 2]),
|
||||||
15,
|
15,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
|
||||||
)
|
)
|
||||||
d2 = Dipole(
|
d2 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([-2, 3, 2]),
|
numpy.array([-2, 3, 2]),
|
||||||
0.1,
|
0.1,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
|
||||||
)
|
)
|
||||||
d3 = Dipole(
|
d3 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([2, 1, 2]),
|
numpy.array([2, 1, 2]),
|
||||||
6,
|
6,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
|
||||||
)
|
)
|
||||||
d4 = Dipole(
|
d4 = DipoleTO(
|
||||||
numpy.array([0, 0, 10]),
|
numpy.array([0, 0, 10]),
|
||||||
numpy.array([-5, -5, 2]),
|
numpy.array([-5, -5, 2]),
|
||||||
50,
|
50,
|
||||||
dots,
|
|
||||||
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.01, rng_to_use=rng)
|
ts_gen = DipoleTimeSeries(
|
||||||
|
[d1, d2, d3, d4],
|
||||||
|
dots,
|
||||||
|
DipoleMeasurementType.X_ELECTRIC_FIELD,
|
||||||
|
0.01,
|
||||||
|
rng_to_use=rng,
|
||||||
|
)
|
||||||
time_series = [ts_gen.transition() for i in range(50)]
|
time_series = [ts_gen.transition() for i in range(50)]
|
||||||
|
|
||||||
assert time_series == snapshot
|
assert time_series == snapshot
|
||||||
|
Loading…
x
Reference in New Issue
Block a user