feat: adds dipole generation code
This commit is contained in:
parent
ad8892aa9c
commit
6ef005fbe1
@ -1,38 +1,90 @@
|
||||
import numpy
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Optional
|
||||
from dataclasses import dataclass
|
||||
from tantri.dipoles.types import DipoleTO
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
# stuff for generating random dipoles from parameters
|
||||
|
||||
|
||||
class Orientation(Enum):
|
||||
XY = 1
|
||||
Z = 2
|
||||
RANDOM = 3
|
||||
|
||||
|
||||
# A description of the parameters needed to generate random dipoles
|
||||
@dataclass
|
||||
class DipoleGenerationConfig:
|
||||
# assumed len 3
|
||||
p: numpy.ndarray
|
||||
s: numpy.ndarray
|
||||
# note no actual checks anywhere that these are sensibly defined with min less than max etc.
|
||||
x_min: float
|
||||
x_max: float
|
||||
y_min: float
|
||||
y_max: float
|
||||
z_min: float
|
||||
z_max: float
|
||||
|
||||
# should be 1/tau up to some pis
|
||||
w: float
|
||||
mag: float
|
||||
|
||||
# these are log_10 of actual value
|
||||
w_log_min: float
|
||||
w_log_max: float
|
||||
|
||||
orientation: Orientation
|
||||
|
||||
dipole_count: int
|
||||
generation_seed: int
|
||||
|
||||
|
||||
def row_to_dipole(input_dict: dict) -> DipoleTO:
|
||||
p = input_dict["p"]
|
||||
if len(p) != 3:
|
||||
raise ValueError(
|
||||
f"p parameter in input_dict [{input_dict}] does not have length 3"
|
||||
def make_dipoles(
|
||||
config: DipoleGenerationConfig,
|
||||
rng_override: Optional[numpy.random.Generator] = None,
|
||||
) -> Sequence[DipoleTO]:
|
||||
|
||||
if rng_override is None:
|
||||
_logger.info(
|
||||
f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation"
|
||||
)
|
||||
s = input_dict["s"]
|
||||
if len(s) != 3:
|
||||
raise ValueError(
|
||||
f"s parameter in input_dict [{input_dict}] does not have length 3"
|
||||
rng = numpy.random.default_rng(config.generation_seed)
|
||||
else:
|
||||
rng = rng_override
|
||||
|
||||
dipoles = []
|
||||
|
||||
for i in range(config.dipole_count):
|
||||
sx = rng.uniform(config.x_min, config.x_max)
|
||||
sy = rng.uniform(config.y_min, config.y_max)
|
||||
sz = rng.uniform(config.z_min, config.z_max)
|
||||
|
||||
# orientation
|
||||
# 0, 1, 2
|
||||
# xy, z, random
|
||||
|
||||
if config.orientation is Orientation.RANDOM:
|
||||
theta = numpy.arccos(2 * rng.random() - 1)
|
||||
phi = 2 * numpy.pi * rng.random()
|
||||
elif config.orientation is Orientation.Z:
|
||||
theta = 0
|
||||
phi = 0
|
||||
elif config.orientation is Orientation.XY:
|
||||
theta = numpy.pi / 2
|
||||
phi = 2 * numpy.pi * rng.random()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"this shouldn't have happened, orientation index: {config}"
|
||||
)
|
||||
|
||||
px = config.mag * numpy.cos(phi) * numpy.sin(theta)
|
||||
py = config.mag * numpy.sin(phi) * numpy.sin(theta)
|
||||
pz = config.mag * numpy.cos(theta)
|
||||
|
||||
w = 10 ** rng.uniform(config.w_log_min, config.w_log_max)
|
||||
|
||||
dipoles.append(
|
||||
DipoleTO(numpy.array([px, py, pz]), numpy.array([sx, sy, sz]), w)
|
||||
)
|
||||
w = input_dict["w"]
|
||||
|
||||
return DipoleTO(p, s, w)
|
||||
|
||||
|
||||
def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]:
|
||||
return [row_to_dipole(input_dict) for input_dict in dot_dict_array]
|
||||
return dipoles
|
||||
|
37
tests/dipoles/generation/__snapshots__/test_dipole_generation.ambr
Executable file
37
tests/dipoles/generation/__snapshots__/test_dipole_generation.ambr
Executable file
@ -0,0 +1,37 @@
|
||||
# serializer version: 1
|
||||
# name: test_generation_simple_random
|
||||
list([
|
||||
DipoleTO(p=array([-36.97888725, 79.75551306, -47.66151523]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.0022608601493359997),
|
||||
DipoleTO(p=array([-82.1374974 , 31.92138656, -47.27003914]), s=array([-5.16467413, -1.81466071, 7.85631698]), w=0.06754798966620681),
|
||||
DipoleTO(p=array([ -8.46760467, -94.37111521, 31.97486959]), s=array([7.27242593, 3.63757671, 6.69952525]), w=0.004658659556177238),
|
||||
DipoleTO(p=array([-44.1712149 , -81.84888003, 36.73778182]), s=array([-6.55867631, 3.70414972, 4.24055463]), w=0.06808539217046054),
|
||||
DipoleTO(p=array([99.76796192, 1.96376112, 6.5190043 ]), s=array([-8.79725374, 4.77769274, 5.75580651]), w=0.005672850212356041),
|
||||
])
|
||||
# ---
|
||||
# name: test_generation_simple_xy
|
||||
list([
|
||||
DipoleTO(p=array([-7.33995988e+00, 9.97302611e+01, 6.12323400e-15]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.009063400581631654),
|
||||
DipoleTO(p=array([ 9.74638419e+01, -2.23785504e+01, 6.12323400e-15]), s=array([-7.63817534, -2.58233707, 5.27413572]), w=0.00617944416992951),
|
||||
DipoleTO(p=array([ 6.55436572e+01, -7.55250223e+01, 6.12323400e-15]), s=array([-1.17987756, 1.09870809, 7.45448519]), w=0.10583856430824175),
|
||||
DipoleTO(p=array([4.70336990e+01, 8.82486893e+01, 6.12323400e-15]), s=array([3.19748696, 2.35757698, 4.89101463]), w=0.4085497202597084),
|
||||
DipoleTO(p=array([-7.66420441e+01, -6.42339246e+01, 6.12323400e-15]), s=array([-8.79722684, 1.83688909, 6.68495208]), w=0.0015149975770328856),
|
||||
])
|
||||
# ---
|
||||
# name: test_generation_simple_xy_override_rng
|
||||
list([
|
||||
DipoleTO(p=array([-7.33995988e+00, 9.97302611e+01, 6.12323400e-15]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.009063400581631654),
|
||||
DipoleTO(p=array([ 9.74638419e+01, -2.23785504e+01, 6.12323400e-15]), s=array([-7.63817534, -2.58233707, 5.27413572]), w=0.00617944416992951),
|
||||
DipoleTO(p=array([ 6.55436572e+01, -7.55250223e+01, 6.12323400e-15]), s=array([-1.17987756, 1.09870809, 7.45448519]), w=0.10583856430824175),
|
||||
DipoleTO(p=array([4.70336990e+01, 8.82486893e+01, 6.12323400e-15]), s=array([3.19748696, 2.35757698, 4.89101463]), w=0.4085497202597084),
|
||||
DipoleTO(p=array([-7.66420441e+01, -6.42339246e+01, 6.12323400e-15]), s=array([-8.79722684, 1.83688909, 6.68495208]), w=0.0015149975770328856),
|
||||
])
|
||||
# ---
|
||||
# name: test_generation_simple_z
|
||||
list([
|
||||
DipoleTO(p=array([ 0., 0., 100.]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.006096453583832874),
|
||||
DipoleTO(p=array([ 0., 0., 100.]), s=array([-3.61805883, -3.81908767, 4.96706517]), w=0.009028212784568558),
|
||||
DipoleTO(p=array([ 0., 0., 100.]), s=array([ 9.2815849 , -2.36350196, 5.76402449]), w=0.06754798966620681),
|
||||
DipoleTO(p=array([ 0., 0., 100.]), s=array([7.27242593, 3.63757671, 6.69952525]), w=0.09541640373507952),
|
||||
DipoleTO(p=array([ 0., 0., 100.]), s=array([ 4.71515397, -2.77246342, 4.68826474]), w=0.4085497202597084),
|
||||
])
|
||||
# ---
|
84
tests/dipoles/generation/test_dipole_generation.py
Executable file
84
tests/dipoles/generation/test_dipole_generation.py
Executable file
@ -0,0 +1,84 @@
|
||||
from tantri.dipoles.generation import DipoleGenerationConfig, make_dipoles, Orientation
|
||||
import numpy
|
||||
|
||||
def test_generation_simple_xy_override_rng(snapshot):
|
||||
|
||||
rng = numpy.random.default_rng(1234)
|
||||
config = DipoleGenerationConfig(
|
||||
x_min=-10,
|
||||
x_max=10,
|
||||
y_min=-5,
|
||||
y_max=5,
|
||||
z_min=4,
|
||||
z_max=8,
|
||||
mag=100,
|
||||
w_log_min=-3,
|
||||
w_log_max=0,
|
||||
orientation=Orientation.XY,
|
||||
dipole_count=5,
|
||||
generation_seed=999999, # should not be used
|
||||
)
|
||||
|
||||
dipoles = make_dipoles(config, rng)
|
||||
assert dipoles == snapshot
|
||||
|
||||
def test_generation_simple_xy(snapshot):
|
||||
|
||||
config = DipoleGenerationConfig(
|
||||
x_min=-10,
|
||||
x_max=10,
|
||||
y_min=-5,
|
||||
y_max=5,
|
||||
z_min=4,
|
||||
z_max=8,
|
||||
mag=100,
|
||||
w_log_min=-3,
|
||||
w_log_max=0,
|
||||
orientation=Orientation.XY,
|
||||
dipole_count=5,
|
||||
generation_seed=1234,
|
||||
)
|
||||
|
||||
dipoles = make_dipoles(config)
|
||||
assert dipoles == snapshot
|
||||
|
||||
def test_generation_simple_z(snapshot):
|
||||
|
||||
config = DipoleGenerationConfig(
|
||||
x_min=-10,
|
||||
x_max=10,
|
||||
y_min=-5,
|
||||
y_max=5,
|
||||
z_min=4,
|
||||
z_max=8,
|
||||
mag=100,
|
||||
w_log_min=-3,
|
||||
w_log_max=0,
|
||||
orientation=Orientation.Z,
|
||||
dipole_count=5,
|
||||
generation_seed=1234,
|
||||
)
|
||||
|
||||
dipoles = make_dipoles(config)
|
||||
assert dipoles == snapshot
|
||||
|
||||
def test_generation_simple_random(snapshot):
|
||||
|
||||
config = DipoleGenerationConfig(
|
||||
x_min=-10,
|
||||
x_max=10,
|
||||
y_min=-5,
|
||||
y_max=5,
|
||||
z_min=4,
|
||||
z_max=8,
|
||||
mag=100,
|
||||
w_log_min=-3,
|
||||
w_log_max=0,
|
||||
orientation=Orientation.RANDOM,
|
||||
dipole_count=5,
|
||||
generation_seed=1234,
|
||||
)
|
||||
|
||||
dipoles = make_dipoles(config)
|
||||
assert dipoles == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user