diff --git a/tantri/cli/__init__.py b/tantri/cli/__init__.py index 2ae7993..a767597 100755 --- a/tantri/cli/__init__.py +++ b/tantri/cli/__init__.py @@ -7,6 +7,7 @@ import tantri.cli.file_importer import tantri.dipoles import json import tantri.dipoles.generation +import tantri.dipoles.types import pathlib @@ -205,7 +206,7 @@ def _generate_dipoles(generation_config, output_file, override_rng_seed): with open(generation_config, "r") as config_file: data = json.load(config_file) - config = tantri.dipoles.generation.DipoleGenerationConfig(**data) + config = tantri.dipoles.types.DipoleGenerationConfig(**data) override_rng = None if override_rng_seed is not None: diff --git a/tantri/dipoles/__init__.py b/tantri/dipoles/__init__.py index e17696a..190b0eb 100755 --- a/tantri/dipoles/__init__.py +++ b/tantri/dipoles/__init__.py @@ -1,150 +1,12 @@ -from dataclasses import dataclass -import numpy -import numpy.random -import typing -from enum import Enum -from tantri.dipoles.types import DipoleTO - -import logging - -_logger = logging.getLogger(__name__) - - -class DipoleMeasurementType(Enum): - ELECTRIC_POTENTIAL = 1 - X_ELECTRIC_FIELD = 2 - - -@dataclass(frozen=True) -class DotPosition: - # assume len 3 - r: numpy.ndarray - label: str - - -@dataclass -class WrappedDipole: - # assumed len 3 - p: numpy.ndarray - s: numpy.ndarray - - # should be 1/tau up to some pis - w: float - - # For caching purposes tell each dipole where the dots are - # TODO: This can be done better by only passing into the time series the non-repeated p s and w, - # TODO: and then creating a new wrapper type to include all the cached stuff. - # TODO: Realistically, the dot positions and measurement type data should live in the time series. - dot_positions: typing.Sequence[DotPosition] - - measurement_type: DipoleMeasurementType - - def __post_init__(self) -> None: - """ - Coerce the inputs into numpy arrays. - """ - self.p = numpy.array(self.p) - self.s = numpy.array(self.s) - - self.state = 1 - self.cache = {} - for pos in self.dot_positions: - if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL: - self.cache[pos.label] = self.potential(pos) - elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD: - self.cache[pos.label] = self.e_field_x(pos) - - def potential(self, dot: DotPosition) -> float: - # let's assume single dot at origin for now - r_diff = self.s - dot.r - return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3) - - def e_field_x(self, dot: DotPosition) -> float: - # let's assume single dot at origin for now - r_diff = self.s - dot.r - norm = numpy.linalg.norm(r_diff) - - return ( - ((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3) - )[0] - - def transition( - self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None - ) -> typing.Dict[str, float]: - rng: numpy.random.Generator - if rng_to_use is None: - rng = numpy.random.default_rng() - else: - rng = rng_to_use - # if on average flipping often, then just return 0, basically this dipole has been all used up. - # Facilitates going for different types of noise at very low freq? - if dt * 10 >= 1 / self.w: - # _logger.warning( - # f"delta t {dt} is too long compared to dipole frequency {self.w}" - # ) - self.state = rng.integers(0, 1, endpoint=True) - else: - prob = dt * self.w - if rng.random() < prob: - # _logger.debug("flip!") - self.flip_state() - return {k: self.state * v for k, v in self.cache.items()} - - def flip_state(self): - self.state *= -1 - - -def get_wrapped_dipoles( - dipole_tos: typing.Sequence[DipoleTO], - dots: typing.Sequence[DotPosition], - measurement_type: DipoleMeasurementType, -) -> typing.Sequence[WrappedDipole]: - return [ - WrappedDipole( - p=dipole_to.p, - s=dipole_to.s, - w=dipole_to.w, - dot_positions=dots, - measurement_type=measurement_type, - ) - for dipole_to in dipole_tos - ] - - -class DipoleTimeSeries: - def __init__( - self, - dipoles: typing.Sequence[DipoleTO], - dots: typing.Sequence[DotPosition], - measurement_type: DipoleMeasurementType, - dt: float, - rng_to_use: typing.Optional[numpy.random.Generator] = None, - ): - self.rng: numpy.random.Generator - if rng_to_use is None: - self.rng = numpy.random.default_rng() - else: - self.rng = rng_to_use - - self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type) - self.state = 0 - self.dt = dt - - def transition(self) -> typing.Dict[str, float]: - new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles] - - ret = {} - for transition in new_vals: - for k, v in transition.items(): - if k not in ret: - ret[k] = v - else: - ret[k] += v - return ret - +from tantri.dipoles.types import DipoleTO, DotPosition, DipoleMeasurementType +from tantri.dipoles.time_series import DipoleTimeSeries, WrappedDipole +from tantri.dipoles.generation import make_dipoles __all__ = [ "WrappedDipole", "DipoleTimeSeries", "DipoleTO", + "DotPosition", + "DipoleMeasurementType", + "make_dipoles", ] diff --git a/tantri/dipoles/generation/__init__.py b/tantri/dipoles/generation/__init__.py index 76ecb4e..0bd5c82 100755 --- a/tantri/dipoles/generation/__init__.py +++ b/tantri/dipoles/generation/__init__.py @@ -1,107 +1,5 @@ -import numpy -from typing import Sequence, Optional -from dataclasses import dataclass, asdict -from tantri.dipoles.types import DipoleTO -from enum import Enum -import logging +from tantri.dipoles.generation.generate_dipole_config import make_dipoles - -# stuff for generating random dipoles from parameters - -_logger = logging.getLogger(__name__) - - -class Orientation(str, Enum): - # Enum for orientation, making string for json serialisation purposes - # - # Note that htis might not be infinitely extensible? - # https://stackoverflow.com/questions/75040733/is-there-a-way-to-use-strenum-in-earlier-python-versions - XY = "XY" - Z = "Z" - RANDOM = "RANDOM" - - -# A description of the parameters needed to generate random dipoles -@dataclass -class DipoleGenerationConfig: - # note no actual checks anywhere that these are sensibly defined with min less than max etc. - x_min: float - x_max: float - y_min: float - y_max: float - z_min: float - z_max: float - - mag: float - - # these are log_10 of actual value - w_log_min: float - w_log_max: float - - orientation: Orientation - - dipole_count: int - generation_seed: int - - def __post_init__(self): - # This allows us to transparently set this with a string, while providing early warning of a type error - self.orientation = Orientation(self.orientation) - - def as_dict(self) -> dict: - return_dict = asdict(self) - - return_dict["orientation"] = return_dict["orientation"].value - - return return_dict - - -def make_dipoles( - config: DipoleGenerationConfig, - rng_override: Optional[numpy.random.Generator] = None, -) -> Sequence[DipoleTO]: - - if rng_override is None: - _logger.info( - f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation" - ) - rng = numpy.random.default_rng(config.generation_seed) - else: - _logger.info("Using overridden rng, of unknown seed") - rng = rng_override - - dipoles = [] - - for i in range(config.dipole_count): - sx = rng.uniform(config.x_min, config.x_max) - sy = rng.uniform(config.y_min, config.y_max) - sz = rng.uniform(config.z_min, config.z_max) - - # orientation - # 0, 1, 2 - # xy, z, random - - if config.orientation is Orientation.RANDOM: - theta = numpy.arccos(2 * rng.random() - 1) - phi = 2 * numpy.pi * rng.random() - elif config.orientation is Orientation.Z: - theta = 0 - phi = 0 - elif config.orientation is Orientation.XY: - theta = numpy.pi / 2 - phi = 2 * numpy.pi * rng.random() - else: - raise ValueError( - f"this shouldn't have happened, orientation index: {config}" - ) - - px = config.mag * numpy.cos(phi) * numpy.sin(theta) - py = config.mag * numpy.sin(phi) * numpy.sin(theta) - pz = config.mag * numpy.cos(theta) - - w = 10 ** rng.uniform(config.w_log_min, config.w_log_max) - - dipoles.append( - DipoleTO(numpy.array([px, py, pz]), numpy.array([sx, sy, sz]), w) - ) - - return dipoles +__all__ = [ + "make_dipoles", +] diff --git a/tantri/dipoles/generation/generate_dipole_config.py b/tantri/dipoles/generation/generate_dipole_config.py new file mode 100644 index 0000000..1e2403c --- /dev/null +++ b/tantri/dipoles/generation/generate_dipole_config.py @@ -0,0 +1,61 @@ +import numpy +from typing import Sequence, Optional +from tantri.dipoles.types import DipoleTO, DipoleGenerationConfig, Orientation +import logging + + +# stuff for generating random dipoles from parameters + +_logger = logging.getLogger(__name__) + + +def make_dipoles( + config: DipoleGenerationConfig, + rng_override: Optional[numpy.random.Generator] = None, +) -> Sequence[DipoleTO]: + + if rng_override is None: + _logger.info( + f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation" + ) + rng = numpy.random.default_rng(config.generation_seed) + else: + _logger.info("Using overridden rng, of unknown seed") + rng = rng_override + + dipoles = [] + + for i in range(config.dipole_count): + sx = rng.uniform(config.x_min, config.x_max) + sy = rng.uniform(config.y_min, config.y_max) + sz = rng.uniform(config.z_min, config.z_max) + + # orientation + # 0, 1, 2 + # xy, z, random + + if config.orientation is Orientation.RANDOM: + theta = numpy.arccos(2 * rng.random() - 1) + phi = 2 * numpy.pi * rng.random() + elif config.orientation is Orientation.Z: + theta = 0 + phi = 0 + elif config.orientation is Orientation.XY: + theta = numpy.pi / 2 + phi = 2 * numpy.pi * rng.random() + else: + raise ValueError( + f"this shouldn't have happened, orientation index: {config}" + ) + + px = config.mag * numpy.cos(phi) * numpy.sin(theta) + py = config.mag * numpy.sin(phi) * numpy.sin(theta) + pz = config.mag * numpy.cos(theta) + + w = 10 ** rng.uniform(config.w_log_min, config.w_log_max) + + dipoles.append( + DipoleTO(numpy.array([px, py, pz]), numpy.array([sx, sy, sz]), w) + ) + + return dipoles diff --git a/tantri/dipoles/supersample.py b/tantri/dipoles/supersample.py new file mode 100644 index 0000000..2f12ea6 --- /dev/null +++ b/tantri/dipoles/supersample.py @@ -0,0 +1,30 @@ +import dataclasses +import logging + +# import math + +_logger = logging.getLogger(__name__) + +# how many times faster than the max frequency we want to be, bigger is more accurate but 10 is probably fine +DESIRED_THRESHOLD = 10 + + +@dataclasses.dataclass +class SuperSample: + super_dt: float + super_sample_ratio: int + + +def get_supersample(max_frequency: float, dt: float) -> SuperSample: + # now we want to sample at least 10x faster than max_frequency, otherwise we're going to skew our statistics + # note that this is why if performance mattered we'd be optimising this to pre-gen our flip times with poisson statistics. + # so we want (1/dt) > 10 * max_freq + if DESIRED_THRESHOLD * dt * max_frequency < 1: + # can return unchanged + _logger.debug("no supersampling needed") + return SuperSample(super_dt=dt, super_sample_ratio=1) + else: + # else we want a such that a / dt > 10 * max_freq, or a > 10 * dt * max_freq, a = math.ceil(10 * dt * max_freq) + # a = math.ceil(DESIRED_THRESHOLD * dt * max_frequency) + # return SuperSample(super_dt=dt / a, super_sample_ratio=a) + return SuperSample(super_dt=dt, super_sample_ratio=1) diff --git a/tantri/dipoles/time_series.py b/tantri/dipoles/time_series.py new file mode 100644 index 0000000..b8a6c1d --- /dev/null +++ b/tantri/dipoles/time_series.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass +import numpy +import numpy.random +import typing +from tantri.dipoles.types import DipoleTO, DotPosition, DipoleMeasurementType +import tantri.dipoles.supersample + + +@dataclass +class WrappedDipole: + # assumed len 3 + p: numpy.ndarray + s: numpy.ndarray + + # should be 1/tau up to some pis + w: float + + # For caching purposes tell each dipole where the dots are + # TODO: This can be done better by only passing into the time series the non-repeated p s and w, + # TODO: and then creating a new wrapper type to include all the cached stuff. + # TODO: Realistically, the dot positions and measurement type data should live in the time series. + dot_positions: typing.Sequence[DotPosition] + + measurement_type: DipoleMeasurementType + + def __post_init__(self) -> None: + """ + Coerce the inputs into numpy arrays. + """ + self.p = numpy.array(self.p) + self.s = numpy.array(self.s) + + self.state = 1 + self.cache = {} + for pos in self.dot_positions: + if self.measurement_type is DipoleMeasurementType.ELECTRIC_POTENTIAL: + self.cache[pos.label] = self.potential(pos) + elif self.measurement_type is DipoleMeasurementType.X_ELECTRIC_FIELD: + self.cache[pos.label] = self.e_field_x(pos) + + def potential(self, dot: DotPosition) -> float: + # let's assume single dot at origin for now + r_diff = self.s - dot.r + return self.p.dot(r_diff) / (numpy.linalg.norm(r_diff) ** 3) + + def e_field_x(self, dot: DotPosition) -> float: + # let's assume single dot at origin for now + r_diff = self.s - dot.r + norm = numpy.linalg.norm(r_diff) + + return ( + ((3 * self.p.dot(r_diff) * r_diff / (norm**2)) - self.p) / (norm**3) + )[0] + + def transition( + self, dt: float, rng_to_use: typing.Optional[numpy.random.Generator] = None + ) -> typing.Dict[str, float]: + rng: numpy.random.Generator + if rng_to_use is None: + rng = numpy.random.default_rng() + else: + rng = rng_to_use + # if on average flipping often, then just return 0, basically this dipole has been all used up. + # Facilitates going for different types of noise at very low freq? + if dt * 10 >= 1 / self.w: + # _logger.warning( + # f"delta t {dt} is too long compared to dipole frequency {self.w}" + # ) + self.state = rng.integers(0, 1, endpoint=True) + else: + prob = dt * self.w + if rng.random() < prob: + # _logger.debug("flip!") + self.flip_state() + return {k: self.state * v for k, v in self.cache.items()} + + def flip_state(self): + self.state *= -1 + + +def get_wrapped_dipoles( + dipole_tos: typing.Sequence[DipoleTO], + dots: typing.Sequence[DotPosition], + measurement_type: DipoleMeasurementType, +) -> typing.Sequence[WrappedDipole]: + return [ + WrappedDipole( + p=dipole_to.p, + s=dipole_to.s, + w=dipole_to.w, + dot_positions=dots, + measurement_type=measurement_type, + ) + for dipole_to in dipole_tos + ] + + +class DipoleTimeSeries: + def __init__( + self, + dipoles: typing.Sequence[DipoleTO], + dots: typing.Sequence[DotPosition], + measurement_type: DipoleMeasurementType, + dt: float, + rng_to_use: typing.Optional[numpy.random.Generator] = None, + ): + self.rng: numpy.random.Generator + if rng_to_use is None: + self.rng = numpy.random.default_rng() + else: + self.rng = rng_to_use + + self.dipoles = get_wrapped_dipoles(dipoles, dots, measurement_type) + self.state = 0 + + # we may need to supersample, because of how dumb this process is. + # let's find our highest frequency + max_frequency = max(d.w for d in self.dipoles) + + super_sample = tantri.dipoles.supersample.get_supersample(max_frequency, dt) + self.dt = super_sample.super_dt + self.super_sample_ratio = super_sample.super_sample_ratio + + def transition(self) -> typing.Dict[str, float]: + new_vals = [dipole.transition(self.dt, self.rng) for dipole in self.dipoles] + + ret = {} + for transition in new_vals: + for k, v in transition.items(): + if k not in ret: + ret[k] = v + else: + ret[k] += v + return ret diff --git a/tantri/dipoles/types.py b/tantri/dipoles/types.py index d2d9d0a..841cd0a 100755 --- a/tantri/dipoles/types.py +++ b/tantri/dipoles/types.py @@ -1,5 +1,6 @@ import numpy from dataclasses import dataclass, asdict +from enum import Enum # Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing @@ -15,3 +16,59 @@ class DipoleTO: def as_dict(self) -> dict: return asdict(self) + + +class Orientation(str, Enum): + # Enum for orientation, making string for json serialisation purposes + # + # Note that htis might not be infinitely extensible? + # https://stackoverflow.com/questions/75040733/is-there-a-way-to-use-strenum-in-earlier-python-versions + XY = "XY" + Z = "Z" + RANDOM = "RANDOM" + + +# A description of the parameters needed to generate random dipoles +@dataclass +class DipoleGenerationConfig: + # note no actual checks anywhere that these are sensibly defined with min less than max etc. + x_min: float + x_max: float + y_min: float + y_max: float + z_min: float + z_max: float + + mag: float + + # these are log_10 of actual value + w_log_min: float + w_log_max: float + + orientation: Orientation + + dipole_count: int + generation_seed: int + + def __post_init__(self): + # This allows us to transparently set this with a string, while providing early warning of a type error + self.orientation = Orientation(self.orientation) + + def as_dict(self) -> dict: + return_dict = asdict(self) + + return_dict["orientation"] = return_dict["orientation"].value + + return return_dict + + +class DipoleMeasurementType(Enum): + ELECTRIC_POTENTIAL = 1 + X_ELECTRIC_FIELD = 2 + + +@dataclass(frozen=True) +class DotPosition: + # assume len 3 + r: numpy.ndarray + label: str diff --git a/tests/dipoles/generation/test_dipole_generation.py b/tests/dipoles/generation/test_dipole_generation.py index cc0d03c..c007c3b 100755 --- a/tests/dipoles/generation/test_dipole_generation.py +++ b/tests/dipoles/generation/test_dipole_generation.py @@ -1,4 +1,5 @@ -from tantri.dipoles.generation import DipoleGenerationConfig, make_dipoles, Orientation +from tantri.dipoles.types import DipoleGenerationConfig, Orientation +from tantri.dipoles.generation import make_dipoles import numpy diff --git a/tests/dipoles/generation/test_serialization_dipole_generation.py b/tests/dipoles/generation/test_serialization_dipole_generation.py index 8b1d359..11c8749 100755 --- a/tests/dipoles/generation/test_serialization_dipole_generation.py +++ b/tests/dipoles/generation/test_serialization_dipole_generation.py @@ -1,5 +1,5 @@ import json -from tantri.dipoles.generation import DipoleGenerationConfig, Orientation +from tantri.dipoles.types import DipoleGenerationConfig, Orientation def test_serialise_generation_config_to_json(snapshot): diff --git a/tests/dipoles/test_supersample.py b/tests/dipoles/test_supersample.py new file mode 100644 index 0000000..dbea378 --- /dev/null +++ b/tests/dipoles/test_supersample.py @@ -0,0 +1,9 @@ +from tantri.dipoles.supersample import get_supersample, SuperSample + + +def test_raw_supersample(): + # let's say we go really really slow + dt = 1 + max_frequency = 0.0001 + + assert get_supersample(max_frequency, dt) == SuperSample(1, 1)