# tantri/tantri/dipoles/time_series.py
#
# (File-viewer metadata from the original listing: 135 lines, 3.8 KiB, Python.)

from dataclasses import dataclass
import numpy
import numpy.random
import typing
from tantri.dipoles.types import DipoleTO, DotPosition, DipoleMeasurementType
import tantri.dipoles.supersample
@dataclass
class WrappedDipole:
    """
    A single two-state fluctuating dipole, with its signal at each dot precomputed.

    The dipole moment ``p`` sits at position ``s``. Its state is +1 or -1: it
    starts at +1, and ``transition`` either flips it with probability ``dt * w``
    or, for steps that are long compared to ``1 / w``, redraws it at random.

    Because ``p``, ``s`` and the dot positions never change, the +1-state signal
    at every dot is computed once in ``__post_init__`` and stored in
    ``self.cache`` keyed by each dot's ``label``. ``transition`` then only has to
    scale the cached values by the current state.
    """

    # assumed len 3
    p: numpy.ndarray
    s: numpy.ndarray
    # should be 1/tau up to some pis
    w: float
    # For caching purposes tell each dipole where the dots are
    # TODO: This can be done better by only passing into the time series the non-repeated p s and w,
    # TODO: and then creating a new wrapper type to include all the cached stuff.
    # TODO: Realistically, the dot positions and measurement type data should live in the time series.
    dot_positions: typing.Sequence[DotPosition]
    measurement_type: DipoleMeasurementType

    def __post_init__(self) -> None:
        """
        Prepare the dipole for use.

        Coerces ``p`` and ``s`` into numpy arrays, puts the dipole in the +1
        state, and precomputes the per-dot signal cache for ``measurement_type``.
        """
        self.p = numpy.array(self.p)
        self.s = numpy.array(self.s)
        # Two-state fluctuator: the state is always +1 or -1.
        self.state = 1
        self.cache: typing.Dict[str, float] = {}
        for pos in self.dot_positions:
            if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
                self.cache[pos.label] = self.potential(pos)
            elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
                self.cache[pos.label] = self.e_field_x(pos)
            # NOTE(review): any other measurement type leaves the cache empty,
            # so transition() returns {} — confirm that cannot happen.

    def potential(self, dot: DotPosition) -> float:
        """
        Return the electric potential at ``dot`` from this dipole in its +1 state.

        Computes ``p . r / |r|**3`` with ``r = s - dot.r``, the vector from the
        dot to the dipole. No physical prefactor is applied.

        NOTE(review): the textbook dipole potential uses ``r = dot.r - s``, which
        differs from this by an overall sign. Confirm the convention is intended.
        """
        r_diff = self.s - dot.r
        return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)

    def e_field_x(self, dot: DotPosition) -> float:
        """
        Return the x component of the electric field at ``dot`` from this dipole
        in its +1 state.

        Uses ``(3 (p . r) r / |r|**2 - p) / |r|**3`` with ``r = s - dot.r``. The
        expression is invariant under ``r -> -r``, so the orientation of ``r``
        does not matter here.
        """
        r_diff = self.s - dot.r
        norm = numpy.linalg.norm(r_diff)
        field = ((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
        return field[0]

    def transition(
        self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
    ) -> typing.Dict[str, float]:
        """
        Advance the dipole by one time step of length ``dt`` and return its signal.

        If ``dt`` is at least a tenth of ``1 / w``, the dipole would flip many
        times within one step. In that case its state is redrawn uniformly from
        {-1, +1}, independently of the previous step. Otherwise it flips with
        probability ``dt * w``.

        :param dt: length of the time step.
        :param rng_to_use: generator to draw from; a fresh
            ``numpy.random.default_rng()`` is used when omitted.
        :return: mapping from dot label to ``state * cached signal``.
        """
        rng: numpy.random.Generator = (
            numpy.random.default_rng() if rng_to_use is None else rng_to_use
        )
        if dt * 10 >= 1 / self.w:
            # Too fast to resolve at this dt: draw a fresh random state.
            # Map the {0, 1} draw onto {-1, +1}. A raw 0 state would bias the
            # signal and could never be left again by flip_state().
            self.state = 2 * int(rng.integers(0, 1, endpoint=True)) - 1
        elif rng.random() < dt * self.w:
            self.flip_state()
        return {k: self.state * v for k, v in self.cache.items()}

    def flip_state(self) -> None:
        """Switch the dipole between its +1 and -1 states."""
        self.state *= -1
def get_wrapped_dipoles(
    dipole_tos: typing.Sequence[DipoleTO],
    dots: typing.Sequence[DotPosition],
    measurement_type: DipoleMeasurementType,
) -> typing.Sequence[WrappedDipole]:
    """
    Wrap each dipole transfer object in a WrappedDipole.

    Every wrapper shares the same ``dots`` and ``measurement_type``, and
    precomputes its per-dot signal when it is constructed.

    :return: a list with one WrappedDipole per entry of ``dipole_tos``, in order.
    """

    def wrap(dipole: DipoleTO) -> WrappedDipole:
        # Copy the dipole's physical parameters and attach the shared dot info.
        return WrappedDipole(
            p=dipole.p,
            s=dipole.s,
            w=dipole.w,
            dot_positions=dots,
            measurement_type=measurement_type,
        )

    return list(map(wrap, dipole_tos))
class DipoleTimeSeries:
    """
    Summed fluctuating-dipole signal at a set of dots, advanced one step at a time.

    Each call to ``transition`` advances every wrapped dipole by ``self.dt`` and
    returns, for each dot label, the sum of all dipoles' contributions.

    ``self.dt`` is not the requested ``dt``. It is the ``super_dt`` returned by
    ``tantri.dipoles.supersample.get_supersample``, which is given the fastest
    dipole's ``w``. ``self.super_sample_ratio`` comes from the same call.
    Presumably it is the number of ``transition`` steps per requested ``dt``;
    confirm against ``supersample``.
    """

    def __init__(
        self,
        dipoles: typing.Sequence[DipoleTO],
        dots: typing.Sequence[DotPosition],
        measurement_type: DipoleMeasurementType,
        dt: float,
        rng_to_use: typing.Optional[numpy.random.Generator] = None,
    ):
        """
        :param dipoles: dipoles to simulate; must be non-empty.
        :param dots: dots at which the signal is evaluated.
        :param measurement_type: quantity each dipole contributes at the dots.
        :param dt: requested sampling interval (may be super-sampled).
        :param rng_to_use: generator shared by all dipoles; a fresh
            ``numpy.random.default_rng()`` is used when omitted.
        :raises ValueError: if ``dipoles`` is empty.
        """
        self.rng: numpy.random.Generator = (
            numpy.random.default_rng() if rng_to_use is None else rng_to_use
        )
        self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
        if not self.dipoles:
            # Without this check max() below fails with an opaque message.
            raise ValueError("DipoleTimeSeries requires at least one dipole")
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.state = 0
        # We may need to supersample, because of how dumb this process is.
        # The step size is chosen from the highest dipole frequency.
        max_frequency = max(d.w for d in self.dipoles)
        super_sample = tantri.dipoles.supersample.get_supersample(max_frequency, dt)
        self.dt = super_sample.super_dt
        self.super_sample_ratio = super_sample.super_sample_ratio

    def transition(self) -> typing.Dict[str, float]:
        """
        Advance every dipole by one step of ``self.dt`` and return the summed
        signal keyed by dot label.
        """
        ret: typing.Dict[str, float] = {}
        for dipole in self.dipoles:
            for label, value in dipole.transition(self.dt, self.rng).items():
                if label in ret:
                    ret[label] += value
                else:
                    ret[label] = value
        return ret