ref: refactoring types and stuff

This commit is contained in:
Deepak Mallubhotla 2024-05-01 23:43:54 -05:00
parent f55bbe66ee
commit ca371762ae
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
10 changed files with 306 additions and 253 deletions

View File

@ -7,6 +7,7 @@ import tantri.cli.file_importer
import tantri.dipoles
import json
import tantri.dipoles.generation
import tantri.dipoles.types
import pathlib
@ -205,7 +206,7 @@ def _generate_dipoles(generation_config, output_file, override_rng_seed):
with open(generation_config, "r") as config_file:
data = json.load(config_file)
config = tantri.dipoles.generation.DipoleGenerationConfig(**data)
config = tantri.dipoles.types.DipoleGenerationConfig(**data)
override_rng = None
if override_rng_seed is not None:

View File

@ -1,150 +1,12 @@
from dataclasses import dataclass
import numpy
import numpy.random
import typing
from enum import Enum
from tantri.dipoles.types import DipoleTO
import logging
_logger = logging.getLogger(__name__)
class DipoleMeasurementType(Enum):
    """Selects which quantity a dipole contributes at each measurement dot."""

    # The dipole's electric potential at the dot.
    ELECTRIC_POTENTIAL = 1
    # The x component of the dipole's electric field at the dot.
    X_ELECTRIC_FIELD = 2
@dataclass(frozen=True)
class DotPosition:
    """An immutable, labelled measurement location."""

    # Position vector of the dot; assumed to have length 3.
    r: numpy.ndarray
    # Key under which values measured at this dot are reported.
    label: str
@dataclass
class WrappedDipole:
    """A two-state dipole together with its cached contribution at each dot.

    On construction the quantity selected by ``measurement_type`` is computed
    once for every entry of ``dot_positions`` and stored in ``self.cache``
    keyed by the dot's label. ``transition`` then only updates ``self.state``
    and rescales the cached values by it.
    """

    # assumed len 3
    p: numpy.ndarray
    s: numpy.ndarray

    # should be 1/tau up to some pis
    w: float

    # For caching purposes tell each dipole where the dots are
    # TODO: This can be done better by only passing into the time series the non-repeated p s and w,
    # TODO: and then creating a new wrapper type to include all the cached stuff.
    # TODO: Realistically, the dot positions and measurement type data should live in the time series.
    dot_positions: typing.Sequence[DotPosition]

    measurement_type: DipoleMeasurementType

    def __post_init__(self) -> None:
        """Coerce ``p`` and ``s`` into numpy arrays and fill the per-dot cache."""
        self.p = numpy.array(self.p)
        self.s = numpy.array(self.s)

        self.state = 1
        self.cache = {}
        for pos in self.dot_positions:
            # Any other measurement type leaves the dot without a cache entry.
            if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
                self.cache[pos.label] = self.potential(pos)
            elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
                self.cache[pos.label] = self.e_field_x(pos)

    def potential(self, dot: DotPosition) -> float:
        """Return ``p . r_diff / |r_diff|**3`` with ``r_diff = s - dot.r``."""
        # let's assume single dot at origin for now
        r_diff = self.s - dot.r
        return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)

    def e_field_x(self, dot: DotPosition) -> float:
        """Return component 0 of ``(3 (p . r) r / |r|**2 - p) / |r|**3``, ``r = s - dot.r``."""
        # let's assume single dot at origin for now
        r_diff = self.s - dot.r
        norm = numpy.linalg.norm(r_diff)

        return (
            ((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
        )[0]

    def transition(
        self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
    ) -> typing.Dict[str, float]:
        """Advance the dipole by ``dt`` and return its value at every cached dot.

        Args:
            dt: Length of the time step.
            rng_to_use: Generator to draw from; a fresh ``default_rng()`` is
                created when omitted.

        Returns:
            A dict mapping each dot label to ``self.state`` times the cached value.
        """
        rng = numpy.random.default_rng() if rng_to_use is None else rng_to_use

        # if on average flipping often, then just return 0, basically this dipole has been all used up.
        # Facilitates going for different types of noise at very low freq?
        # A dipole with w == 0 never flips; checking it first avoids the
        # ZeroDivisionError that 1 / self.w would otherwise raise.
        if self.w != 0 and dt * 10 >= 1 / self.w:
            # NOTE(review): integers(0, 1, endpoint=True) draws from {0, 1}, so
            # the state here is 0 or 1 rather than +/-1 - confirm this is intended.
            self.state = rng.integers(0, 1, endpoint=True)
        else:
            prob = dt * self.w
            if rng.random() < prob:
                self.flip_state()
        return {k: self.state * v for k, v in self.cache.items()}

    def flip_state(self):
        """Negate the current state."""
        self.state *= -1
def get_wrapped_dipoles(
    dipole_tos: typing.Sequence[DipoleTO],
    dots: typing.Sequence[DotPosition],
    measurement_type: DipoleMeasurementType,
) -> typing.Sequence[WrappedDipole]:
    """Wrap each transfer object in a ``WrappedDipole`` sharing the same dots.

    Every wrapper receives the same ``dots`` sequence and ``measurement_type``;
    only ``p``, ``s`` and ``w`` are taken from the individual transfer object.
    The returned list preserves the order of ``dipole_tos``.
    """
    shared = {"dot_positions": dots, "measurement_type": measurement_type}
    wrapped = []
    for source in dipole_tos:
        wrapped.append(WrappedDipole(p=source.p, s=source.s, w=source.w, **shared))
    return wrapped
class DipoleTimeSeries:
    """Steps a collection of dipoles together and sums their values per dot."""

    def __init__(
        self,
        dipoles: typing.Sequence[DipoleTO],
        dots: typing.Sequence[DotPosition],
        measurement_type: DipoleMeasurementType,
        dt: float,
        rng_to_use: typing.Optional[numpy.random.Generator] = None,
    ):
        """Wrap ``dipoles`` for the given dots and remember the step ``dt``.

        A fresh ``numpy.random.default_rng()`` is used when ``rng_to_use`` is
        not supplied.
        """
        self.rng: numpy.random.Generator = (
            numpy.random.default_rng() if rng_to_use is None else rng_to_use
        )

        self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
        self.state = 0
        self.dt = dt

    def transition(self) -> typing.Dict[str, float]:
        """Advance every dipole by ``self.dt`` and return the summed value per dot label."""
        totals = {}
        for dipole in self.dipoles:
            contribution = dipole.transition(self.dt, self.rng)
            for label, value in contribution.items():
                if label in totals:
                    totals[label] += value
                else:
                    totals[label] = value
        return totals
from tantri.dipoles.types import DipoleTO, DotPosition, DipoleMeasurementType
from tantri.dipoles.time_series import DipoleTimeSeries, WrappedDipole
from tantri.dipoles.generation import make_dipoles
# Names exported by a star-import of this module.
__all__ = [
    "WrappedDipole",
    "DipoleTimeSeries",
    "DipoleTO",
    "DotPosition",
    "DipoleMeasurementType",
    "make_dipoles",
]

View File

@ -1,107 +1,5 @@
import numpy
from typing import Sequence, Optional
from dataclasses import dataclass, asdict
from tantri.dipoles.types import DipoleTO
from enum import Enum
import logging
from tantri.dipoles.generation.generate_dipole_config import make_dipoles
# stuff for generating random dipoles from parameters
_logger = logging.getLogger(__name__)
class Orientation(str, Enum):
    """Allowed orientations for generated dipoles.

    The ``str`` mix-in is there for json serialisation purposes: members
    behave as their plain string values.

    Note that this might not be infinitely extensible?
    https://stackoverflow.com/questions/75040733/is-there-a-way-to-use-strenum-in-earlier-python-versions
    """

    XY = "XY"
    Z = "Z"
    RANDOM = "RANDOM"
# A description of the parameters needed to generate random dipoles
@dataclass
class DipoleGenerationConfig:
    """Parameters needed to generate a set of random dipoles.

    Nothing here is validated; in particular there is no check that each
    ``*_min`` is actually less than its ``*_max``.
    """

    # Bounds of the box that positions are drawn from.
    x_min: float
    x_max: float
    y_min: float
    y_max: float
    z_min: float
    z_max: float

    mag: float

    # these are log_10 of actual value
    w_log_min: float
    w_log_max: float

    orientation: Orientation

    dipole_count: int
    generation_seed: int

    def __post_init__(self):
        """Normalise ``orientation`` to an ``Orientation`` member.

        This lets callers pass a plain string transparently, while a value
        that is not a valid orientation fails early, at construction time.
        """
        self.orientation = Orientation(self.orientation)

    def as_dict(self) -> dict:
        """Return the fields as a dict with ``orientation`` replaced by its string value."""
        return {**asdict(self), "orientation": self.orientation.value}
def _draw_angles(config, rng):
    """Return ``(theta, phi)`` for one dipole according to ``config.orientation``."""
    # orientation
    # 0, 1, 2
    # xy, z, random
    if config.orientation is Orientation.RANDOM:
        theta = numpy.arccos(2 * rng.random() - 1)
        phi = 2 * numpy.pi * rng.random()
        return theta, phi
    if config.orientation is Orientation.Z:
        return 0, 0
    if config.orientation is Orientation.XY:
        return numpy.pi / 2, 2 * numpy.pi * rng.random()
    raise ValueError(f"this shouldn't have happened, orientation index: {config}")


def make_dipoles(
    config: DipoleGenerationConfig,
    rng_override: Optional[numpy.random.Generator] = None,
) -> Sequence[DipoleTO]:
    """Generate ``config.dipole_count`` random dipoles.

    Positions are drawn uniformly inside the configured box, moments have
    magnitude ``config.mag`` and a direction chosen per ``config.orientation``,
    and ``w`` is ``10 ** uniform(w_log_min, w_log_max)``.

    Args:
        config: The generation parameters.
        rng_override: Generator to use instead of one seeded from
            ``config.generation_seed``.

    Raises:
        ValueError: If ``config.orientation`` is not a known orientation.
    """
    if rng_override is not None:
        _logger.info("Using overridden rng, of unknown seed")
        rng = rng_override
    else:
        _logger.info(
            f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation"
        )
        rng = numpy.random.default_rng(config.generation_seed)

    generated = []
    for _ in range(config.dipole_count):
        # Draw order matters for reproducibility: position, angles, then w.
        sx = rng.uniform(config.x_min, config.x_max)
        sy = rng.uniform(config.y_min, config.y_max)
        sz = rng.uniform(config.z_min, config.z_max)

        theta, phi = _draw_angles(config, rng)

        moment = numpy.array(
            [
                config.mag * numpy.cos(phi) * numpy.sin(theta),
                config.mag * numpy.sin(phi) * numpy.sin(theta),
                config.mag * numpy.cos(theta),
            ]
        )

        w = 10 ** rng.uniform(config.w_log_min, config.w_log_max)

        generated.append(DipoleTO(moment, numpy.array([sx, sy, sz]), w))

    return generated
# The only name exported by a star-import of this module.
__all__ = [
    "make_dipoles",
]

View File

@ -0,0 +1,61 @@
import numpy
from typing import Sequence, Optional
from tantri.dipoles.types import DipoleTO, DipoleGenerationConfig, Orientation
import logging
# stuff for generating random dipoles from parameters
_logger = logging.getLogger(__name__)
def _draw_angles(config, rng):
    """Return ``(theta, phi)`` for one dipole according to ``config.orientation``."""
    # orientation
    # 0, 1, 2
    # xy, z, random
    if config.orientation is Orientation.RANDOM:
        theta = numpy.arccos(2 * rng.random() - 1)
        phi = 2 * numpy.pi * rng.random()
        return theta, phi
    if config.orientation is Orientation.Z:
        return 0, 0
    if config.orientation is Orientation.XY:
        return numpy.pi / 2, 2 * numpy.pi * rng.random()
    raise ValueError(f"this shouldn't have happened, orientation index: {config}")


def make_dipoles(
    config: DipoleGenerationConfig,
    rng_override: Optional[numpy.random.Generator] = None,
) -> Sequence[DipoleTO]:
    """Generate ``config.dipole_count`` random dipoles.

    Positions are drawn uniformly inside the configured box, moments have
    magnitude ``config.mag`` and a direction chosen per ``config.orientation``,
    and ``w`` is ``10 ** uniform(w_log_min, w_log_max)``.

    Args:
        config: The generation parameters.
        rng_override: Generator to use instead of one seeded from
            ``config.generation_seed``.

    Raises:
        ValueError: If ``config.orientation`` is not a known orientation.
    """
    if rng_override is not None:
        _logger.info("Using overridden rng, of unknown seed")
        rng = rng_override
    else:
        _logger.info(
            f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation"
        )
        rng = numpy.random.default_rng(config.generation_seed)

    generated = []
    for _ in range(config.dipole_count):
        # Draw order matters for reproducibility: position, angles, then w.
        sx = rng.uniform(config.x_min, config.x_max)
        sy = rng.uniform(config.y_min, config.y_max)
        sz = rng.uniform(config.z_min, config.z_max)

        theta, phi = _draw_angles(config, rng)

        moment = numpy.array(
            [
                config.mag * numpy.cos(phi) * numpy.sin(theta),
                config.mag * numpy.sin(phi) * numpy.sin(theta),
                config.mag * numpy.cos(theta),
            ]
        )

        w = 10 ** rng.uniform(config.w_log_min, config.w_log_max)

        generated.append(DipoleTO(moment, numpy.array([sx, sy, sz]), w))

    return generated

View File

@ -0,0 +1,30 @@
import dataclasses
import logging
# import math
_logger = logging.getLogger(__name__)
# how many times faster than the max frequency we want to be, bigger is more accurate but 10 is probably fine
DESIRED_THRESHOLD = 10
@dataclasses.dataclass
class SuperSample:
    """Result of a supersampling decision."""

    # Time step to actually simulate with.
    super_dt: float
    # Ratio reported alongside super_dt; 1 when the step is returned unchanged.
    super_sample_ratio: int
def get_supersample(max_frequency: float, dt: float) -> SuperSample:
    """Decide how ``dt`` should be supersampled for the given ``max_frequency``.

    We want to sample at least ``DESIRED_THRESHOLD`` times faster than
    ``max_frequency`` (i.e. ``1 / dt > DESIRED_THRESHOLD * max_frequency``),
    otherwise the statistics get skewed. If performance mattered, the flip
    times would instead be pre-generated with Poisson statistics.

    Currently ``dt`` is always returned unchanged with a ratio of 1; the two
    cases differ only in whether a debug message is logged.
    """
    if DESIRED_THRESHOLD * dt * max_frequency < 1:
        # Already sampling fast enough, can return unchanged.
        _logger.debug("no supersampling needed")

    # TODO: when the threshold is not met we want a such that
    # a / dt > DESIRED_THRESHOLD * max_frequency, i.e.
    # a = math.ceil(DESIRED_THRESHOLD * dt * max_frequency), and would return
    # SuperSample(super_dt=dt / a, super_sample_ratio=a). That path is disabled
    # for now, so both cases hand back the same unchanged step.
    return SuperSample(super_dt=dt, super_sample_ratio=1)

View File

@ -0,0 +1,134 @@
from dataclasses import dataclass
import numpy
import numpy.random
import typing
from tantri.dipoles.types import DipoleTO, DotPosition, DipoleMeasurementType
import tantri.dipoles.supersample
@dataclass
class WrappedDipole:
    """A two-state dipole together with its cached contribution at each dot.

    On construction the quantity selected by ``measurement_type`` is computed
    once for every entry of ``dot_positions`` and stored in ``self.cache``
    keyed by the dot's label. ``transition`` then only updates ``self.state``
    and rescales the cached values by it.
    """

    # assumed len 3
    p: numpy.ndarray
    s: numpy.ndarray

    # should be 1/tau up to some pis
    w: float

    # For caching purposes tell each dipole where the dots are
    # TODO: This can be done better by only passing into the time series the non-repeated p s and w,
    # TODO: and then creating a new wrapper type to include all the cached stuff.
    # TODO: Realistically, the dot positions and measurement type data should live in the time series.
    dot_positions: typing.Sequence[DotPosition]

    measurement_type: DipoleMeasurementType

    def __post_init__(self) -> None:
        """Coerce ``p`` and ``s`` into numpy arrays and fill the per-dot cache."""
        self.p = numpy.array(self.p)
        self.s = numpy.array(self.s)

        self.state = 1
        self.cache = {}
        for pos in self.dot_positions:
            # Any other measurement type leaves the dot without a cache entry.
            if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
                self.cache[pos.label] = self.potential(pos)
            elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
                self.cache[pos.label] = self.e_field_x(pos)

    def potential(self, dot: DotPosition) -> float:
        """Return ``p . r_diff / |r_diff|**3`` with ``r_diff = s - dot.r``."""
        # let's assume single dot at origin for now
        r_diff = self.s - dot.r
        return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)

    def e_field_x(self, dot: DotPosition) -> float:
        """Return component 0 of ``(3 (p . r) r / |r|**2 - p) / |r|**3``, ``r = s - dot.r``."""
        # let's assume single dot at origin for now
        r_diff = self.s - dot.r
        norm = numpy.linalg.norm(r_diff)

        return (
            ((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
        )[0]

    def transition(
        self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
    ) -> typing.Dict[str, float]:
        """Advance the dipole by ``dt`` and return its value at every cached dot.

        Args:
            dt: Length of the time step.
            rng_to_use: Generator to draw from; a fresh ``default_rng()`` is
                created when omitted.

        Returns:
            A dict mapping each dot label to ``self.state`` times the cached value.
        """
        rng = numpy.random.default_rng() if rng_to_use is None else rng_to_use

        # if on average flipping often, then just return 0, basically this dipole has been all used up.
        # Facilitates going for different types of noise at very low freq?
        # A dipole with w == 0 never flips; checking it first avoids the
        # ZeroDivisionError that 1 / self.w would otherwise raise.
        if self.w != 0 and dt * 10 >= 1 / self.w:
            # NOTE(review): integers(0, 1, endpoint=True) draws from {0, 1}, so
            # the state here is 0 or 1 rather than +/-1 - confirm this is intended.
            self.state = rng.integers(0, 1, endpoint=True)
        else:
            prob = dt * self.w
            if rng.random() < prob:
                self.flip_state()
        return {k: self.state * v for k, v in self.cache.items()}

    def flip_state(self):
        """Negate the current state."""
        self.state *= -1
def get_wrapped_dipoles(
    dipole_tos: typing.Sequence[DipoleTO],
    dots: typing.Sequence[DotPosition],
    measurement_type: DipoleMeasurementType,
) -> typing.Sequence[WrappedDipole]:
    """Wrap each transfer object in a ``WrappedDipole`` sharing the same dots.

    Every wrapper receives the same ``dots`` sequence and ``measurement_type``;
    only ``p``, ``s`` and ``w`` are taken from the individual transfer object.
    The returned list preserves the order of ``dipole_tos``.
    """
    shared = {"dot_positions": dots, "measurement_type": measurement_type}
    wrapped = []
    for source in dipole_tos:
        wrapped.append(WrappedDipole(p=source.p, s=source.s, w=source.w, **shared))
    return wrapped
class DipoleTimeSeries:
    """Steps a collection of dipoles together and sums their values per dot."""

    def __init__(
        self,
        dipoles: typing.Sequence[DipoleTO],
        dots: typing.Sequence[DotPosition],
        measurement_type: DipoleMeasurementType,
        dt: float,
        rng_to_use: typing.Optional[numpy.random.Generator] = None,
    ):
        """Wrap ``dipoles`` for the given dots and pick the step to simulate with.

        Args:
            dipoles: Transfer objects to wrap; may be empty.
            dots: Measurement locations shared by every dipole.
            measurement_type: Quantity to evaluate at each dot.
            dt: Requested time step; the step actually stored in ``self.dt``
                is whatever ``get_supersample`` returns for it.
            rng_to_use: Generator to draw from; a fresh ``default_rng()`` is
                created when omitted.
        """
        self.rng: numpy.random.Generator = (
            numpy.random.default_rng() if rng_to_use is None else rng_to_use
        )

        self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
        self.state = 0

        # we may need to supersample, because of how dumb this process is.
        # let's find our highest frequency
        # default=0.0 keeps an empty dipole list usable: a bare max() over an
        # empty iterable raises ValueError.
        max_frequency = max((d.w for d in self.dipoles), default=0.0)

        super_sample = tantri.dipoles.supersample.get_supersample(max_frequency, dt)
        self.dt = super_sample.super_dt
        self.super_sample_ratio = super_sample.super_sample_ratio

    def transition(self) -> typing.Dict[str, float]:
        """Advance every dipole by ``self.dt`` and return the summed value per dot label."""
        new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles]

        ret = {}
        for transition in new_vals:
            for k, v in transition.items():
                if k not in ret:
                    ret[k] = v
                else:
                    ret[k] += v
        return ret

View File

@ -1,5 +1,6 @@
import numpy
from dataclasses import dataclass, asdict
from enum import Enum
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
@ -15,3 +16,59 @@ class DipoleTO:
def as_dict(self) -> dict:
    """Return this object's fields as a plain dict, as built by ``dataclasses.asdict``."""
    return asdict(self)
class Orientation(str, Enum):
    """Allowed orientations for generated dipoles.

    The ``str`` mix-in is there for json serialisation purposes: members
    behave as their plain string values.

    Note that this might not be infinitely extensible?
    https://stackoverflow.com/questions/75040733/is-there-a-way-to-use-strenum-in-earlier-python-versions
    """

    XY = "XY"
    Z = "Z"
    RANDOM = "RANDOM"
# A description of the parameters needed to generate random dipoles
@dataclass
class DipoleGenerationConfig:
    """Parameters needed to generate a set of random dipoles.

    Nothing here is validated; in particular there is no check that each
    ``*_min`` is actually less than its ``*_max``.
    """

    # Bounds of the box that positions are drawn from.
    x_min: float
    x_max: float
    y_min: float
    y_max: float
    z_min: float
    z_max: float

    mag: float

    # these are log_10 of actual value
    w_log_min: float
    w_log_max: float

    orientation: Orientation

    dipole_count: int
    generation_seed: int

    def __post_init__(self):
        """Normalise ``orientation`` to an ``Orientation`` member.

        This lets callers pass a plain string transparently, while a value
        that is not a valid orientation fails early, at construction time.
        """
        self.orientation = Orientation(self.orientation)

    def as_dict(self) -> dict:
        """Return the fields as a dict with ``orientation`` replaced by its string value."""
        return {**asdict(self), "orientation": self.orientation.value}
class DipoleMeasurementType(Enum):
    """Selects which quantity a dipole contributes at each measurement dot."""

    # The dipole's electric potential at the dot.
    ELECTRIC_POTENTIAL = 1
    # The x component of the dipole's electric field at the dot.
    X_ELECTRIC_FIELD = 2
@dataclass(frozen=True)
class DotPosition:
    """An immutable, labelled measurement location."""

    # Position vector of the dot; assumed to have length 3.
    r: numpy.ndarray
    # Key under which values measured at this dot are reported.
    label: str

View File

@ -1,4 +1,5 @@
from tantri.dipoles.generation import DipoleGenerationConfig, make_dipoles, Orientation
from tantri.dipoles.types import DipoleGenerationConfig, Orientation
from tantri.dipoles.generation import make_dipoles
import numpy

View File

@ -1,5 +1,5 @@
import json
from tantri.dipoles.generation import DipoleGenerationConfig, Orientation
from tantri.dipoles.types import DipoleGenerationConfig, Orientation
def test_serialise_generation_config_to_json(snapshot):

View File

@ -0,0 +1,9 @@
from tantri.dipoles.supersample import get_supersample, SuperSample
def test_raw_supersample():
    """A very slow maximum frequency leaves the requested time step untouched."""
    # let's say we go really really slow
    slow_max_frequency = 0.0001
    time_step = 1

    unchanged = SuperSample(1, 1)
    assert get_supersample(slow_max_frequency, time_step) == unchanged