ref: refactoring types and stuff
This commit is contained in:
parent
f55bbe66ee
commit
ca371762ae
@ -7,6 +7,7 @@ import tantri.cli.file_importer
|
|||||||
import tantri.dipoles
|
import tantri.dipoles
|
||||||
import json
|
import json
|
||||||
import tantri.dipoles.generation
|
import tantri.dipoles.generation
|
||||||
|
import tantri.dipoles.types
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
|
|
||||||
@ -205,7 +206,7 @@ def _generate_dipoles(generation_config, output_file, override_rng_seed):
|
|||||||
|
|
||||||
with open(generation_config, "r") as config_file:
|
with open(generation_config, "r") as config_file:
|
||||||
data = json.load(config_file)
|
data = json.load(config_file)
|
||||||
config = tantri.dipoles.generation.DipoleGenerationConfig(**data)
|
config = tantri.dipoles.types.DipoleGenerationConfig(**data)
|
||||||
|
|
||||||
override_rng = None
|
override_rng = None
|
||||||
if override_rng_seed is not None:
|
if override_rng_seed is not None:
|
||||||
|
@ -1,150 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
from tantri.dipoles.types import DipoleTO, DotPosition, DipoleMeasurementType
|
||||||
import numpy
|
from tantri.dipoles.time_series import DipoleTimeSeries, WrappedDipole
|
||||||
import numpy.random
|
from tantri.dipoles.generation import make_dipoles
|
||||||
import typing
|
|
||||||
from enum import Enum
|
|
||||||
from tantri.dipoles.types import DipoleTO
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DipoleMeasurementType(Enum):
|
|
||||||
ELECTRIC_POTENTIAL = 1
|
|
||||||
X_ELECTRIC_FIELD = 2
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class DotPosition:
|
|
||||||
# assume len 3
|
|
||||||
r: numpy.ndarray
|
|
||||||
label: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WrappedDipole:
|
|
||||||
# assumed len 3
|
|
||||||
p: numpy.ndarray
|
|
||||||
s: numpy.ndarray
|
|
||||||
|
|
||||||
# should be 1/tau up to some pis
|
|
||||||
w: float
|
|
||||||
|
|
||||||
# For caching purposes tell each dipole where the dots are
|
|
||||||
# TODO: This can be done better by only passing into the time series the non-repeated p s and w,
|
|
||||||
# TODO: and then creating a new wrapper type to include all the cached stuff.
|
|
||||||
# TODO: Realistically, the dot positions and measurement type data should live in the time series.
|
|
||||||
dot_positions: typing.Sequence[DotPosition]
|
|
||||||
|
|
||||||
measurement_type: DipoleMeasurementType
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
"""
|
|
||||||
Coerce the inputs into numpy arrays.
|
|
||||||
"""
|
|
||||||
self.p = numpy.array(self.p)
|
|
||||||
self.s = numpy.array(self.s)
|
|
||||||
|
|
||||||
self.state = 1
|
|
||||||
self.cache = {}
|
|
||||||
for pos in self.dot_positions:
|
|
||||||
if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
|
|
||||||
self.cache[pos.label] = self.potential(pos)
|
|
||||||
elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
|
|
||||||
self.cache[pos.label] = self.e_field_x(pos)
|
|
||||||
|
|
||||||
def potential(self, dot: DotPosition) -> float:
|
|
||||||
# let's assume single dot at origin for now
|
|
||||||
r_diff = self.s - dot.r
|
|
||||||
return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)
|
|
||||||
|
|
||||||
def e_field_x(self, dot: DotPosition) -> float:
|
|
||||||
# let's assume single dot at origin for now
|
|
||||||
r_diff = self.s - dot.r
|
|
||||||
norm = numpy.linalg.norm(r_diff)
|
|
||||||
|
|
||||||
return (
|
|
||||||
((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
def transition(
|
|
||||||
self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
|
|
||||||
) -> typing.Dict[str, float]:
|
|
||||||
rng: numpy.random.Generator
|
|
||||||
if rng_to_use is None:
|
|
||||||
rng = numpy.random.default_rng()
|
|
||||||
else:
|
|
||||||
rng = rng_to_use
|
|
||||||
# if on average flipping often, then just return 0, basically this dipole has been all used up.
|
|
||||||
# Facilitates going for different types of noise at very low freq?
|
|
||||||
if dt * 10 >= 1 / self.w:
|
|
||||||
# _logger.warning(
|
|
||||||
# f"delta t {dt} is too long compared to dipole frequency {self.w}"
|
|
||||||
# )
|
|
||||||
self.state = rng.integers(0, 1, endpoint=True)
|
|
||||||
else:
|
|
||||||
prob = dt * self.w
|
|
||||||
if rng.random() < prob:
|
|
||||||
# _logger.debug("flip!")
|
|
||||||
self.flip_state()
|
|
||||||
return {k: self.state * v for k, v in self.cache.items()}
|
|
||||||
|
|
||||||
def flip_state(self):
|
|
||||||
self.state *= -1
|
|
||||||
|
|
||||||
|
|
||||||
def get_wrapped_dipoles(
|
|
||||||
dipole_tos: typing.Sequence[DipoleTO],
|
|
||||||
dots: typing.Sequence[DotPosition],
|
|
||||||
measurement_type: DipoleMeasurementType,
|
|
||||||
) -> typing.Sequence[WrappedDipole]:
|
|
||||||
return [
|
|
||||||
WrappedDipole(
|
|
||||||
p=dipole_to.p,
|
|
||||||
s=dipole_to.s,
|
|
||||||
w=dipole_to.w,
|
|
||||||
dot_positions=dots,
|
|
||||||
measurement_type=measurement_type,
|
|
||||||
)
|
|
||||||
for dipole_to in dipole_tos
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class DipoleTimeSeries:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dipoles: typing.Sequence[DipoleTO],
|
|
||||||
dots: typing.Sequence[DotPosition],
|
|
||||||
measurement_type: DipoleMeasurementType,
|
|
||||||
dt: float,
|
|
||||||
rng_to_use: typing.Optional[numpy.random.Generator] = None,
|
|
||||||
):
|
|
||||||
self.rng: numpy.random.Generator
|
|
||||||
if rng_to_use is None:
|
|
||||||
self.rng = numpy.random.default_rng()
|
|
||||||
else:
|
|
||||||
self.rng = rng_to_use
|
|
||||||
|
|
||||||
self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
|
|
||||||
self.state = 0
|
|
||||||
self.dt = dt
|
|
||||||
|
|
||||||
def transition(self) -> typing.Dict[str, float]:
|
|
||||||
new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles]
|
|
||||||
|
|
||||||
ret = {}
|
|
||||||
for transition in new_vals:
|
|
||||||
for k, v in transition.items():
|
|
||||||
if k not in ret:
|
|
||||||
ret[k] = v
|
|
||||||
else:
|
|
||||||
ret[k] += v
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"WrappedDipole",
|
"WrappedDipole",
|
||||||
"DipoleTimeSeries",
|
"DipoleTimeSeries",
|
||||||
"DipoleTO",
|
"DipoleTO",
|
||||||
|
"DotPosition",
|
||||||
|
"DipoleMeasurementType",
|
||||||
|
"make_dipoles",
|
||||||
]
|
]
|
||||||
|
@ -1,107 +1,5 @@
|
|||||||
import numpy
|
from tantri.dipoles.generation.generate_dipole_config import make_dipoles
|
||||||
from typing import Sequence, Optional
|
|
||||||
from dataclasses import dataclass, asdict
|
|
||||||
from tantri.dipoles.types import DipoleTO
|
|
||||||
from enum import Enum
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
# stuff for generating random dipoles from parameters
|
"make_dipoles",
|
||||||
|
]
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Orientation(str, Enum):
|
|
||||||
# Enum for orientation, making string for json serialisation purposes
|
|
||||||
#
|
|
||||||
# Note that htis might not be infinitely extensible?
|
|
||||||
# https://stackoverflow.com/questions/75040733/is-there-a-way-to-use-strenum-in-earlier-python-versions
|
|
||||||
XY = "XY"
|
|
||||||
Z = "Z"
|
|
||||||
RANDOM = "RANDOM"
|
|
||||||
|
|
||||||
|
|
||||||
# A description of the parameters needed to generate random dipoles
|
|
||||||
@dataclass
|
|
||||||
class DipoleGenerationConfig:
|
|
||||||
# note no actual checks anywhere that these are sensibly defined with min less than max etc.
|
|
||||||
x_min: float
|
|
||||||
x_max: float
|
|
||||||
y_min: float
|
|
||||||
y_max: float
|
|
||||||
z_min: float
|
|
||||||
z_max: float
|
|
||||||
|
|
||||||
mag: float
|
|
||||||
|
|
||||||
# these are log_10 of actual value
|
|
||||||
w_log_min: float
|
|
||||||
w_log_max: float
|
|
||||||
|
|
||||||
orientation: Orientation
|
|
||||||
|
|
||||||
dipole_count: int
|
|
||||||
generation_seed: int
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# This allows us to transparently set this with a string, while providing early warning of a type error
|
|
||||||
self.orientation = Orientation(self.orientation)
|
|
||||||
|
|
||||||
def as_dict(self) -> dict:
|
|
||||||
return_dict = asdict(self)
|
|
||||||
|
|
||||||
return_dict["orientation"] = return_dict["orientation"].value
|
|
||||||
|
|
||||||
return return_dict
|
|
||||||
|
|
||||||
|
|
||||||
def make_dipoles(
|
|
||||||
config: DipoleGenerationConfig,
|
|
||||||
rng_override: Optional[numpy.random.Generator] = None,
|
|
||||||
) -> Sequence[DipoleTO]:
|
|
||||||
|
|
||||||
if rng_override is None:
|
|
||||||
_logger.info(
|
|
||||||
f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation"
|
|
||||||
)
|
|
||||||
rng = numpy.random.default_rng(config.generation_seed)
|
|
||||||
else:
|
|
||||||
_logger.info("Using overridden rng, of unknown seed")
|
|
||||||
rng = rng_override
|
|
||||||
|
|
||||||
dipoles = []
|
|
||||||
|
|
||||||
for i in range(config.dipole_count):
|
|
||||||
sx = rng.uniform(config.x_min, config.x_max)
|
|
||||||
sy = rng.uniform(config.y_min, config.y_max)
|
|
||||||
sz = rng.uniform(config.z_min, config.z_max)
|
|
||||||
|
|
||||||
# orientation
|
|
||||||
# 0, 1, 2
|
|
||||||
# xy, z, random
|
|
||||||
|
|
||||||
if config.orientation is Orientation.RANDOM:
|
|
||||||
theta = numpy.arccos(2 * rng.random() - 1)
|
|
||||||
phi = 2 * numpy.pi * rng.random()
|
|
||||||
elif config.orientation is Orientation.Z:
|
|
||||||
theta = 0
|
|
||||||
phi = 0
|
|
||||||
elif config.orientation is Orientation.XY:
|
|
||||||
theta = numpy.pi / 2
|
|
||||||
phi = 2 * numpy.pi * rng.random()
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"this shouldn't have happened, orientation index: {config}"
|
|
||||||
)
|
|
||||||
|
|
||||||
px = config.mag * numpy.cos(phi) * numpy.sin(theta)
|
|
||||||
py = config.mag * numpy.sin(phi) * numpy.sin(theta)
|
|
||||||
pz = config.mag * numpy.cos(theta)
|
|
||||||
|
|
||||||
w = 10 ** rng.uniform(config.w_log_min, config.w_log_max)
|
|
||||||
|
|
||||||
dipoles.append(
|
|
||||||
DipoleTO(numpy.array([px, py, pz]), numpy.array([sx, sy, sz]), w)
|
|
||||||
)
|
|
||||||
|
|
||||||
return dipoles
|
|
||||||
|
61
tantri/dipoles/generation/generate_dipole_config.py
Normal file
61
tantri/dipoles/generation/generate_dipole_config.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import numpy
|
||||||
|
from typing import Sequence, Optional
|
||||||
|
from tantri.dipoles.types import DipoleTO, DipoleGenerationConfig, Orientation
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
# stuff for generating random dipoles from parameters
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def make_dipoles(
|
||||||
|
config: DipoleGenerationConfig,
|
||||||
|
rng_override: Optional[numpy.random.Generator] = None,
|
||||||
|
) -> Sequence[DipoleTO]:
|
||||||
|
|
||||||
|
if rng_override is None:
|
||||||
|
_logger.info(
|
||||||
|
f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation"
|
||||||
|
)
|
||||||
|
rng = numpy.random.default_rng(config.generation_seed)
|
||||||
|
else:
|
||||||
|
_logger.info("Using overridden rng, of unknown seed")
|
||||||
|
rng = rng_override
|
||||||
|
|
||||||
|
dipoles = []
|
||||||
|
|
||||||
|
for i in range(config.dipole_count):
|
||||||
|
sx = rng.uniform(config.x_min, config.x_max)
|
||||||
|
sy = rng.uniform(config.y_min, config.y_max)
|
||||||
|
sz = rng.uniform(config.z_min, config.z_max)
|
||||||
|
|
||||||
|
# orientation
|
||||||
|
# 0, 1, 2
|
||||||
|
# xy, z, random
|
||||||
|
|
||||||
|
if config.orientation is Orientation.RANDOM:
|
||||||
|
theta = numpy.arccos(2 * rng.random() - 1)
|
||||||
|
phi = 2 * numpy.pi * rng.random()
|
||||||
|
elif config.orientation is Orientation.Z:
|
||||||
|
theta = 0
|
||||||
|
phi = 0
|
||||||
|
elif config.orientation is Orientation.XY:
|
||||||
|
theta = numpy.pi / 2
|
||||||
|
phi = 2 * numpy.pi * rng.random()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"this shouldn't have happened, orientation index: {config}"
|
||||||
|
)
|
||||||
|
|
||||||
|
px = config.mag * numpy.cos(phi) * numpy.sin(theta)
|
||||||
|
py = config.mag * numpy.sin(phi) * numpy.sin(theta)
|
||||||
|
pz = config.mag * numpy.cos(theta)
|
||||||
|
|
||||||
|
w = 10 ** rng.uniform(config.w_log_min, config.w_log_max)
|
||||||
|
|
||||||
|
dipoles.append(
|
||||||
|
DipoleTO(numpy.array([px, py, pz]), numpy.array([sx, sy, sz]), w)
|
||||||
|
)
|
||||||
|
|
||||||
|
return dipoles
|
30
tantri/dipoles/supersample.py
Normal file
30
tantri/dipoles/supersample.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# import math
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# how many times faster than the max frequency we want to be, bigger is more accurate but 10 is probably fine
|
||||||
|
DESIRED_THRESHOLD = 10
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class SuperSample:
|
||||||
|
super_dt: float
|
||||||
|
super_sample_ratio: int
|
||||||
|
|
||||||
|
|
||||||
|
def get_supersample(max_frequency: float, dt: float) -> SuperSample:
|
||||||
|
# now we want to sample at least 10x faster than max_frequency, otherwise we're going to skew our statistics
|
||||||
|
# note that this is why if performance mattered we'd be optimising this to pre-gen our flip times with poisson statistics.
|
||||||
|
# so we want (1/dt) > 10 * max_freq
|
||||||
|
if DESIRED_THRESHOLD * dt * max_frequency < 1:
|
||||||
|
# can return unchanged
|
||||||
|
_logger.debug("no supersampling needed")
|
||||||
|
return SuperSample(super_dt=dt, super_sample_ratio=1)
|
||||||
|
else:
|
||||||
|
# else we want a such that a / dt > 10 * max_freq, or a > 10 * dt * max_freq, a = math.ceil(10 * dt * max_freq)
|
||||||
|
# a = math.ceil(DESIRED_THRESHOLD * dt * max_frequency)
|
||||||
|
# return SuperSample(super_dt=dt / a, super_sample_ratio=a)
|
||||||
|
return SuperSample(super_dt=dt, super_sample_ratio=1)
|
134
tantri/dipoles/time_series.py
Normal file
134
tantri/dipoles/time_series.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import numpy
|
||||||
|
import numpy.random
|
||||||
|
import typing
|
||||||
|
from tantri.dipoles.types import DipoleTO, DotPosition, DipoleMeasurementType
|
||||||
|
import tantri.dipoles.supersample
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WrappedDipole:
|
||||||
|
# assumed len 3
|
||||||
|
p: numpy.ndarray
|
||||||
|
s: numpy.ndarray
|
||||||
|
|
||||||
|
# should be 1/tau up to some pis
|
||||||
|
w: float
|
||||||
|
|
||||||
|
# For caching purposes tell each dipole where the dots are
|
||||||
|
# TODO: This can be done better by only passing into the time series the non-repeated p s and w,
|
||||||
|
# TODO: and then creating a new wrapper type to include all the cached stuff.
|
||||||
|
# TODO: Realistically, the dot positions and measurement type data should live in the time series.
|
||||||
|
dot_positions: typing.Sequence[DotPosition]
|
||||||
|
|
||||||
|
measurement_type: DipoleMeasurementType
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Coerce the inputs into numpy arrays.
|
||||||
|
"""
|
||||||
|
self.p = numpy.array(self.p)
|
||||||
|
self.s = numpy.array(self.s)
|
||||||
|
|
||||||
|
self.state = 1
|
||||||
|
self.cache = {}
|
||||||
|
for pos in self.dot_positions:
|
||||||
|
if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL:
|
||||||
|
self.cache[pos.label] = self.potential(pos)
|
||||||
|
elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD:
|
||||||
|
self.cache[pos.label] = self.e_field_x(pos)
|
||||||
|
|
||||||
|
def potential(self, dot: DotPosition) -> float:
|
||||||
|
# let's assume single dot at origin for now
|
||||||
|
r_diff = self.s - dot.r
|
||||||
|
return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3)
|
||||||
|
|
||||||
|
def e_field_x(self, dot: DotPosition) -> float:
|
||||||
|
# let's assume single dot at origin for now
|
||||||
|
r_diff = self.s - dot.r
|
||||||
|
norm = numpy.linalg.norm(r_diff)
|
||||||
|
|
||||||
|
return (
|
||||||
|
((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3)
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
def transition(
|
||||||
|
self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None
|
||||||
|
) -> typing.Dict[str, float]:
|
||||||
|
rng: numpy.random.Generator
|
||||||
|
if rng_to_use is None:
|
||||||
|
rng = numpy.random.default_rng()
|
||||||
|
else:
|
||||||
|
rng = rng_to_use
|
||||||
|
# if on average flipping often, then just return 0, basically this dipole has been all used up.
|
||||||
|
# Facilitates going for different types of noise at very low freq?
|
||||||
|
if dt * 10 >= 1 / self.w:
|
||||||
|
# _logger.warning(
|
||||||
|
# f"delta t {dt} is too long compared to dipole frequency {self.w}"
|
||||||
|
# )
|
||||||
|
self.state = rng.integers(0, 1, endpoint=True)
|
||||||
|
else:
|
||||||
|
prob = dt * self.w
|
||||||
|
if rng.random() < prob:
|
||||||
|
# _logger.debug("flip!")
|
||||||
|
self.flip_state()
|
||||||
|
return {k: self.state * v for k, v in self.cache.items()}
|
||||||
|
|
||||||
|
def flip_state(self):
|
||||||
|
self.state *= -1
|
||||||
|
|
||||||
|
|
||||||
|
def get_wrapped_dipoles(
|
||||||
|
dipole_tos: typing.Sequence[DipoleTO],
|
||||||
|
dots: typing.Sequence[DotPosition],
|
||||||
|
measurement_type: DipoleMeasurementType,
|
||||||
|
) -> typing.Sequence[WrappedDipole]:
|
||||||
|
return [
|
||||||
|
WrappedDipole(
|
||||||
|
p=dipole_to.p,
|
||||||
|
s=dipole_to.s,
|
||||||
|
w=dipole_to.w,
|
||||||
|
dot_positions=dots,
|
||||||
|
measurement_type=measurement_type,
|
||||||
|
)
|
||||||
|
for dipole_to in dipole_tos
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class DipoleTimeSeries:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dipoles: typing.Sequence[DipoleTO],
|
||||||
|
dots: typing.Sequence[DotPosition],
|
||||||
|
measurement_type: DipoleMeasurementType,
|
||||||
|
dt: float,
|
||||||
|
rng_to_use: typing.Optional[numpy.random.Generator] = None,
|
||||||
|
):
|
||||||
|
self.rng: numpy.random.Generator
|
||||||
|
if rng_to_use is None:
|
||||||
|
self.rng = numpy.random.default_rng()
|
||||||
|
else:
|
||||||
|
self.rng = rng_to_use
|
||||||
|
|
||||||
|
self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type)
|
||||||
|
self.state = 0
|
||||||
|
|
||||||
|
# we may need to supersample, because of how dumb this process is.
|
||||||
|
# let's find our highest frequency
|
||||||
|
max_frequency = max(d.w for d in self.dipoles)
|
||||||
|
|
||||||
|
super_sample = tantri.dipoles.supersample.get_supersample(max_frequency, dt)
|
||||||
|
self.dt = super_sample.super_dt
|
||||||
|
self.super_sample_ratio = super_sample.super_sample_ratio
|
||||||
|
|
||||||
|
def transition(self) -> typing.Dict[str, float]:
|
||||||
|
new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles]
|
||||||
|
|
||||||
|
ret = {}
|
||||||
|
for transition in new_vals:
|
||||||
|
for k, v in transition.items():
|
||||||
|
if k not in ret:
|
||||||
|
ret[k] = v
|
||||||
|
else:
|
||||||
|
ret[k] += v
|
||||||
|
return ret
|
@ -1,5 +1,6 @@
|
|||||||
import numpy
|
import numpy
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
|
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
|
||||||
@ -15,3 +16,59 @@ class DipoleTO:
|
|||||||
|
|
||||||
def as_dict(self) -> dict:
|
def as_dict(self) -> dict:
|
||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
class Orientation(str, Enum):
|
||||||
|
# Enum for orientation, making string for json serialisation purposes
|
||||||
|
#
|
||||||
|
# Note that htis might not be infinitely extensible?
|
||||||
|
# https://stackoverflow.com/questions/75040733/is-there-a-way-to-use-strenum-in-earlier-python-versions
|
||||||
|
XY = "XY"
|
||||||
|
Z = "Z"
|
||||||
|
RANDOM = "RANDOM"
|
||||||
|
|
||||||
|
|
||||||
|
# A description of the parameters needed to generate random dipoles
|
||||||
|
@dataclass
|
||||||
|
class DipoleGenerationConfig:
|
||||||
|
# note no actual checks anywhere that these are sensibly defined with min less than max etc.
|
||||||
|
x_min: float
|
||||||
|
x_max: float
|
||||||
|
y_min: float
|
||||||
|
y_max: float
|
||||||
|
z_min: float
|
||||||
|
z_max: float
|
||||||
|
|
||||||
|
mag: float
|
||||||
|
|
||||||
|
# these are log_10 of actual value
|
||||||
|
w_log_min: float
|
||||||
|
w_log_max: float
|
||||||
|
|
||||||
|
orientation: Orientation
|
||||||
|
|
||||||
|
dipole_count: int
|
||||||
|
generation_seed: int
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# This allows us to transparently set this with a string, while providing early warning of a type error
|
||||||
|
self.orientation = Orientation(self.orientation)
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return_dict = asdict(self)
|
||||||
|
|
||||||
|
return_dict["orientation"] = return_dict["orientation"].value
|
||||||
|
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DipoleMeasurementType(Enum):
|
||||||
|
ELECTRIC_POTENTIAL = 1
|
||||||
|
X_ELECTRIC_FIELD = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DotPosition:
|
||||||
|
# assume len 3
|
||||||
|
r: numpy.ndarray
|
||||||
|
label: str
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from tantri.dipoles.generation import DipoleGenerationConfig, make_dipoles, Orientation
|
from tantri.dipoles.types import DipoleGenerationConfig, Orientation
|
||||||
|
from tantri.dipoles.generation import make_dipoles
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from tantri.dipoles.generation import DipoleGenerationConfig, Orientation
|
from tantri.dipoles.types import DipoleGenerationConfig, Orientation
|
||||||
|
|
||||||
|
|
||||||
def test_serialise_generation_config_to_json(snapshot):
|
def test_serialise_generation_config_to_json(snapshot):
|
||||||
|
9
tests/dipoles/test_supersample.py
Normal file
9
tests/dipoles/test_supersample.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from tantri.dipoles.supersample import get_supersample, SuperSample
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_supersample():
|
||||||
|
# let's say we go really really slow
|
||||||
|
dt = 1
|
||||||
|
max_frequency = 0.0001
|
||||||
|
|
||||||
|
assert get_supersample(max_frequency, dt) == SuperSample(1, 1)
|
Loading…
x
Reference in New Issue
Block a user