feat: adds initial implementation of time series generation
Some checks failed
gitea-physics/tantri/pipeline/head There was a failure building this commit
Some checks failed
gitea-physics/tantri/pipeline/head There was a failure building this commit
This commit is contained in:
parent
9895c9d397
commit
7bde246a21
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from tantri.meta import __version__
|
||||
from tantri.simple_telegraph_time_series import SimpleTelegraphTimeSeries
|
||||
|
||||
|
||||
def get_version():
|
||||
@ -7,3 +8,9 @@ def get_version():
|
||||
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
||||
__all__ = [
|
||||
"SimpleTelegraphTimeSeries",
|
||||
"__version__",
|
||||
"get_version",
|
||||
]
|
||||
|
126
tantri/dipoles/__init__.py
Executable file
126
tantri/dipoles/__init__.py
Executable file
@ -0,0 +1,126 @@
|
||||
from dataclasses import dataclass
|
||||
import numpy
|
||||
import numpy.random
|
||||
import typing
|
||||
from enum import Enum
|
||||
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DipoleMeasurementType(Enum):
|
||||
ELECTRIC_POTENTIAL = 1
|
||||
X_ELECTRIC_FIELD = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DotPosition:
|
||||
# assume len 3
|
||||
r: numpy.ndarray
|
||||
label: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dipole:
|
||||
# assumed len 3
|
||||
p: numpy.ndarray
|
||||
s: numpy.ndarray
|
||||
|
||||
# should be 1/tau up to some pis
|
||||
w: float
|
||||
|
||||
# For caching purposes tell each dipole where the dots are
|
||||
dot_positions: typing.Sequence[DotPosition]
|
||||
|
||||
measurement_type: DipoleMeasurementType
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""
|
||||
Coerce the inputs into numpy arrays.
|
||||
"""
|
||||
self.p = numpy.array(self.p)
|
||||
self.s = numpy.array(self.s)
|
||||
|
||||
self.state = 1
|
||||
self.cache = {}
|
||||
for pos in self.dot_positions:
|
||||
if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
|
||||
self.cache[pos.label] = self.potential(pos)
|
||||
elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
|
||||
self.cache[pos.label] = self.e_field_x(pos)
|
||||
|
||||
def potential(self, dot: DotPosition) -> float:
|
||||
# let's assume single dot at origin for now
|
||||
r_diff = self.s - dot.r
|
||||
return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)
|
||||
|
||||
def e_field_x(self, dot: DotPosition) -> float:
|
||||
# let's assume single dot at origin for now
|
||||
r_diff = self.s - dot.r
|
||||
norm = numpy.linalg.norm(r_diff)
|
||||
|
||||
return (
|
||||
((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
|
||||
)[0]
|
||||
|
||||
def transition(
|
||||
self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
|
||||
) -> typing.Dict[str, float]:
|
||||
rng: numpy.random.Generator
|
||||
if rng_to_use is None:
|
||||
rng = numpy.random.default_rng()
|
||||
else:
|
||||
rng = rng_to_use
|
||||
# if on average flipping often, then just return 0, basically this dipole has been all used up.
|
||||
# Facilitates going for different types of noise at very low freq?
|
||||
if dt * 10 >= 1 / self.w:
|
||||
# _logger.warning(
|
||||
# f"delta t {dt} is too long compared to dipole frequency {self.w}"
|
||||
# )
|
||||
self.state = rng.integers(0, 1, endpoint=True)
|
||||
else:
|
||||
prob = dt * self.w
|
||||
if rng.random() < prob:
|
||||
# _logger.debug("flip!")
|
||||
self.flip_state()
|
||||
return {k: self.state * v for k, v in self.cache.items()}
|
||||
|
||||
def flip_state(self):
|
||||
self.state *= -1
|
||||
|
||||
|
||||
class DipoleTimeSeries:
|
||||
def __init__(
|
||||
self,
|
||||
dipoles: typing.Sequence[Dipole],
|
||||
dt: float,
|
||||
rng_to_use: typing.Optional[numpy.random.Generator] = None,
|
||||
):
|
||||
self.rng: numpy.random.Generator
|
||||
if rng_to_use is None:
|
||||
self.rng = numpy.random.default_rng()
|
||||
else:
|
||||
self.rng = rng_to_use
|
||||
|
||||
self.dipoles = dipoles
|
||||
self.state = 0
|
||||
self.dt = dt
|
||||
|
||||
def transition(self) -> typing.Dict[str, float]:
|
||||
new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles]
|
||||
|
||||
ret = {}
|
||||
for transition in new_vals:
|
||||
for k, v in transition.items():
|
||||
if k not in ret:
|
||||
ret[k] = v
|
||||
else:
|
||||
ret[k] += v
|
||||
return ret
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Dipole",
|
||||
"DipoleTimeSeries",
|
||||
]
|
25
tantri/simple_telegraph_time_series.py
Executable file
25
tantri/simple_telegraph_time_series.py
Executable file
@ -0,0 +1,25 @@
|
||||
import numpy.random
|
||||
|
||||
|
||||
class SimpleTelegraphTimeSeries:
|
||||
def __init__(
|
||||
self, rng_to_use: numpy.random.Generator, transition_probability: float
|
||||
):
|
||||
if transition_probability > 1 or transition_probability < 0:
|
||||
raise ValueError(
|
||||
f"Transition probabliity must be between 0 and 1, got {transition_probability} instead"
|
||||
)
|
||||
self.transition_probability = transition_probability
|
||||
|
||||
self.rng: numpy.random.Generator
|
||||
if rng_to_use is None:
|
||||
self.rng = numpy.random.default_rng()
|
||||
else:
|
||||
self.rng = rng_to_use
|
||||
|
||||
self.state = 0
|
||||
|
||||
def transition(self) -> float:
|
||||
if self.rng.random() < self.transition_probability:
|
||||
self.state = 1 - self.state
|
||||
return self.state
|
1029
tests/__snapshots__/test_simple_telegraph_time_series.ambr
Executable file
1029
tests/__snapshots__/test_simple_telegraph_time_series.ambr
Executable file
File diff suppressed because it is too large
Load Diff
25
tests/dipoles/__snapshots__/test_dipoles.ambr
Executable file
25
tests/dipoles/__snapshots__/test_dipoles.ambr
Executable file
@ -0,0 +1,25 @@
|
||||
# serializer version: 1
|
||||
# name: test_dipoles_1
|
||||
list([
|
||||
1.1618574890412874,
|
||||
1.1618574890412874,
|
||||
1.1618574890412874,
|
||||
1.1618574890412874,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.31962399244019385,
|
||||
-0.4204260394683082,
|
||||
])
|
||||
# ---
|
14
tests/dipoles/test_dipoles.py
Executable file
14
tests/dipoles/test_dipoles.py
Executable file
@ -0,0 +1,14 @@
|
||||
from tantri.dipoles import Dipole, DipoleTimeSeries
|
||||
import numpy
|
||||
|
||||
|
||||
def test_dipoles_1(snapshot):
|
||||
d1 = Dipole(numpy.array([0, 0, 10]), numpy.array([5, 3, 2]), 15)
|
||||
d2 = Dipole(numpy.array([0, 0, 10]), numpy.array([-2, 3, 2]), 0.1)
|
||||
d3 = Dipole(numpy.array([0, 0, 10]), numpy.array([2, 1, 2]), 6)
|
||||
d4 = Dipole(numpy.array([0, 0, 10]), numpy.array([-5, -5, 2]), 50)
|
||||
|
||||
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.001)
|
||||
time_series = [ts_gen.transition() for i in range(20)]
|
||||
|
||||
assert time_series == snapshot
|
18
tests/test_simple_telegraph_time_series.py
Executable file
18
tests/test_simple_telegraph_time_series.py
Executable file
@ -0,0 +1,18 @@
|
||||
import tantri
|
||||
import numpy.random
|
||||
|
||||
|
||||
def test_simple_telegraph_basic_generation(snapshot):
|
||||
rng = numpy.random.default_rng(1234)
|
||||
ts_generator = tantri.SimpleTelegraphTimeSeries(rng, 0.8)
|
||||
time_series = [ts_generator.transition() for i in range(20)]
|
||||
|
||||
assert time_series == snapshot
|
||||
|
||||
|
||||
def test_simple_telegraph_longer_generation(snapshot):
|
||||
rng = numpy.random.default_rng(1234)
|
||||
ts_generator = tantri.SimpleTelegraphTimeSeries(rng, 0.01)
|
||||
time_series = [ts_generator.transition() for i in range(1000)]
|
||||
|
||||
assert time_series == snapshot
|
Loading…
x
Reference in New Issue
Block a user