feat: adds initial implementation of time series generation
Some checks failed
gitea-physics/tantri/pipeline/head There was a failure building this commit

This commit is contained in:
Deepak Mallubhotla 2024-04-20 13:52:40 -05:00
parent 9895c9d397
commit 7bde246a21
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
7 changed files with 1244 additions and 0 deletions

View File

@ -1,5 +1,6 @@
import logging
from tantri.meta import __version__
from tantri.simple_telegraph_time_series import SimpleTelegraphTimeSeries
def get_version():
@ -7,3 +8,9 @@ def get_version():
logging.getLogger(__name__).addHandler(logging.NullHandler())
__all__ = [
"SimpleTelegraphTimeSeries",
"__version__",
"get_version",
]

126
tantri/dipoles/__init__.py Executable file
View File

@ -0,0 +1,126 @@
from dataclasses import dataclass
import numpy
import numpy.random
import typing
from enum import Enum
import logging
_logger = logging.getLogger(__name__)
class DipoleMeasurementType(Enum):
ELECTRIC_POTENTIAL = 1
X_ELECTRIC_FIELD = 2
@dataclass(frozen=True)
class DotPosition:
# assume len 3
r: numpy.ndarray
label: str
@dataclass
class Dipole:
# assumed len 3
p: numpy.ndarray
s: numpy.ndarray
# should be 1/tau up to some pis
w: float
# For caching purposes tell each dipole where the dots are
dot_positions: typing.Sequence[DotPosition]
measurement_type: DipoleMeasurementType
def __post_init__(self) -> None:
"""
Coerce the inputs into numpy arrays.
"""
self.p = numpy.array(self.p)
self.s = numpy.array(self.s)
self.state = 1
self.cache = {}
for pos in self.dot_positions:
if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
self.cache[pos.label] = self.potential(pos)
elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
self.cache[pos.label] = self.e_field_x(pos)
def potential(self, dot: DotPosition) -> float:
# let's assume single dot at origin for now
r_diff = self.s - dot.r
return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)
def e_field_x(self, dot: DotPosition) -> float:
# let's assume single dot at origin for now
r_diff = self.s - dot.r
norm = numpy.linalg.norm(r_diff)
return (
((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
)[0]
def transition(
self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
) -> typing.Dict[str, float]:
rng: numpy.random.Generator
if rng_to_use is None:
rng = numpy.random.default_rng()
else:
rng = rng_to_use
# if on average flipping often, then just return 0, basically this dipole has been all used up.
# Facilitates going for different types of noise at very low freq?
if dt * 10 >= 1 / self.w:
# _logger.warning(
# f"delta t {dt} is too long compared to dipole frequency {self.w}"
# )
self.state = rng.integers(0, 1, endpoint=True)
else:
prob = dt * self.w
if rng.random() < prob:
# _logger.debug("flip!")
self.flip_state()
return {k: self.state * v for k, v in self.cache.items()}
def flip_state(self):
self.state *= -1
class DipoleTimeSeries:
def __init__(
self,
dipoles: typing.Sequence[Dipole],
dt: float,
rng_to_use: typing.Optional[numpy.random.Generator] = None,
):
self.rng: numpy.random.Generator
if rng_to_use is None:
self.rng = numpy.random.default_rng()
else:
self.rng = rng_to_use
self.dipoles = dipoles
self.state = 0
self.dt = dt
def transition(self) -> typing.Dict[str, float]:
new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles]
ret = {}
for transition in new_vals:
for k, v in transition.items():
if k not in ret:
ret[k] = v
else:
ret[k] += v
return ret
__all__ = [
"Dipole",
"DipoleTimeSeries",
]

View File

@ -0,0 +1,25 @@
import numpy.random
class SimpleTelegraphTimeSeries:
def __init__(
self, rng_to_use: numpy.random.Generator, transition_probability: float
):
if transition_probability > 1 or transition_probability < 0:
raise ValueError(
f"Transition probabliity must be between 0 and 1, got {transition_probability} instead"
)
self.transition_probability = transition_probability
self.rng: numpy.random.Generator
if rng_to_use is None:
self.rng = numpy.random.default_rng()
else:
self.rng = rng_to_use
self.state = 0
def transition(self) -> float:
if self.rng.random() < self.transition_probability:
self.state = 1 - self.state
return self.state

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,25 @@
# serializer version: 1
# name: test_dipoles_1
list([
1.1618574890412874,
1.1618574890412874,
1.1618574890412874,
1.1618574890412874,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.31962399244019385,
-0.4204260394683082,
])
# ---

14
tests/dipoles/test_dipoles.py Executable file
View File

@ -0,0 +1,14 @@
from tantri.dipoles import Dipole, DipoleTimeSeries
import numpy
def test_dipoles_1(snapshot):
d1 = Dipole(numpy.array([0, 0, 10]), numpy.array([5, 3, 2]), 15)
d2 = Dipole(numpy.array([0, 0, 10]), numpy.array([-2, 3, 2]), 0.1)
d3 = Dipole(numpy.array([0, 0, 10]), numpy.array([2, 1, 2]), 6)
d4 = Dipole(numpy.array([0, 0, 10]), numpy.array([-5, -5, 2]), 50)
ts_gen = DipoleTimeSeries([d1, d2, d3, d4], 0.001)
time_series = [ts_gen.transition() for i in range(20)]
assert time_series == snapshot

View File

@ -0,0 +1,18 @@
import tantri
import numpy.random
def test_simple_telegraph_basic_generation(snapshot):
rng = numpy.random.default_rng(1234)
ts_generator = tantri.SimpleTelegraphTimeSeries(rng, 0.8)
time_series = [ts_generator.transition() for i in range(20)]
assert time_series == snapshot
def test_simple_telegraph_longer_generation(snapshot):
rng = numpy.random.default_rng(1234)
ts_generator = tantri.SimpleTelegraphTimeSeries(rng, 0.01)
time_series = [ts_generator.transition() for i in range(1000)]
assert time_series == snapshot