feat: adds initial implementation of time series generation
Some checks failed
gitea-physics/tantri/pipeline/head There was a failure building this commit
Some checks failed
gitea-physics/tantri/pipeline/head There was a failure building this commit
This commit is contained in:
parent
9895c9d397
commit
7bde246a21
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from tantri.meta import __version__
|
from tantri.meta import __version__
|
||||||
|
from tantri.simple_telegraph_time_series import SimpleTelegraphTimeSeries
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version():
|
||||||
@ -7,3 +8,9 @@ def get_version():
|
|||||||
|
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SimpleTelegraphTimeSeries",
|
||||||
|
"__version__",
|
||||||
|
"get_version",
|
||||||
|
]
|
||||||
|
126
tantri/dipoles/__init__.py
Executable file
126
tantri/dipoles/__init__.py
Executable file
@ -0,0 +1,126 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import numpy
|
||||||
|
import numpy.random
|
||||||
|
import typing
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DipoleMeasurementType(Enum):
|
||||||
|
ELECTRIC_POTENTIAL = 1
|
||||||
|
X_ELECTRIC_FIELD = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DotPosition:
|
||||||
|
# assume len 3
|
||||||
|
r: numpy.ndarray
|
||||||
|
label: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Dipole:
|
||||||
|
# assumed len 3
|
||||||
|
p: numpy.ndarray
|
||||||
|
s: numpy.ndarray
|
||||||
|
|
||||||
|
# should be 1/tau up to some pis
|
||||||
|
w: float
|
||||||
|
|
||||||
|
# For caching purposes tell each dipole where the dots are
|
||||||
|
dot_positions: typing.Sequence[DotPosition]
|
||||||
|
|
||||||
|
measurement_type: DipoleMeasurementType
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Coerce the inputs into numpy arrays.
|
||||||
|
"""
|
||||||
|
self.p = numpy.array(self.p)
|
||||||
|
self.s = numpy.array(self.s)
|
||||||
|
|
||||||
|
self.state = 1
|
||||||
|
self.cache = {}
|
||||||
|
for pos in self.dot_positions:
|
||||||
|
if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
|
||||||
|
self.cache[pos.label] = self.potential(pos)
|
||||||
|
elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
|
||||||
|
self.cache[pos.label] = self.e_field_x(pos)
|
||||||
|
|
||||||
|
def potential(self, dot: DotPosition) -> float:
|
||||||
|
# let's assume single dot at origin for now
|
||||||
|
r_diff = self.s - dot.r
|
||||||
|
return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)
|
||||||
|
|
||||||
|
def e_field_x(self, dot: DotPosition) -> float:
|
||||||
|
# let's assume single dot at origin for now
|
||||||
|
r_diff = self.s - dot.r
|
||||||
|
norm = numpy.linalg.norm(r_diff)
|
||||||
|
|
||||||
|
return (
|
||||||
|
((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
def transition(
|
||||||
|
self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
|
||||||
|
) -> typing.Dict[str, float]:
|
||||||
|
rng: numpy.random.Generator
|
||||||
|
if rng_to_use is None:
|
||||||
|
rng = numpy.random.default_rng()
|
||||||
|
else:
|
||||||
|
rng = rng_to_use
|
||||||
|
# if on average flipping often, then just return 0, basically this dipole has been all used up.
|
||||||
|
# Facilitates going for different types of noise at very low freq?
|
||||||
|
if dt * 10 >= 1 / self.w:
|
||||||
|
# _logger.warning(
|
||||||
|
# f"delta t {dt} is too long compared to dipole frequency {self.w}"
|
||||||
|
# )
|
||||||
|
self.state = rng.integers(0, 1, endpoint=True)
|
||||||
|
else:
|
||||||
|
prob = dt * self.w
|
||||||
|
if rng.random() < prob:
|
||||||
|
# _logger.debug("flip!")
|
||||||
|
self.flip_state()
|
||||||
|
return {k: self.state * v for k, v in self.cache.items()}
|
||||||
|
|
||||||
|
def flip_state(self):
|
||||||
|
self.state *= -1
|
||||||
|
|
||||||
|
|
||||||
|
class DipoleTimeSeries:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dipoles: typing.Sequence[Dipole],
|
||||||
|
dt: float,
|
||||||
|
rng_to_use: typing.Optional[numpy.random.Generator] = None,
|
||||||
|
):
|
||||||
|
self.rng: numpy.random.Generator
|
||||||
|
if rng_to_use is None:
|
||||||
|
self.rng = numpy.random.default_rng()
|
||||||
|
else:
|
||||||
|
self.rng = rng_to_use
|
||||||
|
|
||||||
|
self.dipoles = dipoles
|
||||||
|
self.state = 0
|
||||||
|
self.dt = dt
|
||||||
|
|
||||||
|
def transition(self) -> typing.Dict[str, float]:
|
||||||
|
new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles]
|
||||||
|
|
||||||
|
ret = {}
|
||||||
|
for transition in new_vals:
|
||||||
|
for k, v in transition.items():
|
||||||
|
if k not in ret:
|
||||||
|
ret[k] = v
|
||||||
|
else:
|
||||||
|
ret[k] += v
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Dipole",
|
||||||
|
"DipoleTimeSeries",
|
||||||
|
]
|
25
tantri/simple_telegraph_time_series.py
Executable file
25
tantri/simple_telegraph_time_series.py
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
import numpy.random
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleTelegraphTimeSeries:
|
||||||
|
def __init__(
|
||||||
|
self, rng_to_use: numpy.random.Generator, transition_probability: float
|
||||||
|
):
|
||||||
|
if transition_probability > 1 or transition_probability < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Transition probabliity must be between 0 and 1, got {transition_probability} instead"
|
||||||
|
)
|
||||||
|
self.transition_probability = transition_probability
|
||||||
|
|
||||||
|
self.rng: numpy.random.Generator
|
||||||
|
if rng_to_use is None:
|
||||||
|
self.rng = numpy.random.default_rng()
|
||||||
|
else:
|
||||||
|
self.rng = rng_to_use
|
||||||
|
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
def transition(self) -> float:
|
||||||
|
if self.rng.random() < self.transition_probability:
|
||||||
|
self.state = 1 - self.state
|
||||||
|
return self.state
|
1029
tests/__snapshots__/test_simple_telegraph_time_series.ambr
Executable file
1029
tests/__snapshots__/test_simple_telegraph_time_series.ambr
Executable file
File diff suppressed because it is too large
Load Diff
25
tests/dipoles/__snapshots__/test_dipoles.ambr
Executable file
25
tests/dipoles/__snapshots__/test_dipoles.ambr
Executable file
@ -0,0 +1,25 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_dipoles_1
|
||||||
|
list([
|
||||||
|
1.1618574890412874,
|
||||||
|
1.1618574890412874,
|
||||||
|
1.1618574890412874,
|
||||||
|
1.1618574890412874,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.31962399244019385,
|
||||||
|
-0.4204260394683082,
|
||||||
|
])
|
||||||
|
# ---
|
14
tests/dipoles/test_dipoles.py
Executable file
14
tests/dipoles/test_dipoles.py
Executable file
@ -0,0 +1,14 @@
|
|||||||
|
from tantri.dipoles import Dipole, DipoleTimeSeries
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
def test_dipoles_1(snapshot):
|
||||||
|
d1 = Dipole(numpy.array([0, 0, 10]), numpy.array([5, 3, 2]), 15)
|
||||||
|
d2 = Dipole(numpy.array([0, 0, 10]), numpy.array([-2, 3, 2]), 0.1)
|
||||||
|
d3 = Dipole(numpy.array([0, 0, 10]), numpy.array([2, 1, 2]), 6)
|
||||||
|
d4 = Dipole(numpy.array([0, 0, 10]), numpy.array([-5, -5, 2]), 50)
|
||||||
|
|
||||||
|
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.001)
|
||||||
|
time_series = [ts_gen.transition() for i in range(20)]
|
||||||
|
|
||||||
|
assert time_series == snapshot
|
18
tests/test_simple_telegraph_time_series.py
Executable file
18
tests/test_simple_telegraph_time_series.py
Executable file
@ -0,0 +1,18 @@
|
|||||||
|
import tantri
|
||||||
|
import numpy.random
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_telegraph_basic_generation(snapshot):
|
||||||
|
rng = numpy.random.default_rng(1234)
|
||||||
|
ts_generator = tantri.SimpleTelegraphTimeSeries(rng, 0.8)
|
||||||
|
time_series = [ts_generator.transition() for i in range(20)]
|
||||||
|
|
||||||
|
assert time_series == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_telegraph_longer_generation(snapshot):
|
||||||
|
rng = numpy.random.default_rng(1234)
|
||||||
|
ts_generator = tantri.SimpleTelegraphTimeSeries(rng, 0.01)
|
||||||
|
time_series = [ts_generator.transition() for i in range(1000)]
|
||||||
|
|
||||||
|
assert time_series == snapshot
|
Loading…
x
Reference in New Issue
Block a user