feat: adds dipole generation code

This commit is contained in:
Deepak Mallubhotla 2024-04-20 19:52:56 -05:00
parent ad8892aa9c
commit 6ef005fbe1
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
3 changed files with 194 additions and 21 deletions

View File

@ -1,38 +1,90 @@
import numpy
from typing import Sequence
from typing import Sequence, Optional
from dataclasses import dataclass
from tantri.dipoles.types import DipoleTO
from enum import Enum
import logging
_logger = logging.getLogger(__name__)
# stuff for generating random dipoles from parameters
class Orientation(Enum):
XY = 1
Z = 2
RANDOM = 3
# A description of the parameters needed to generate random dipoles
@dataclass
class DipoleGenerationConfig:
# assumed len 3
p: numpy.ndarray
s: numpy.ndarray
# note no actual checks anywhere that these are sensibly defined with min less than max etc.
x_min: float
x_max: float
y_min: float
y_max: float
z_min: float
z_max: float
# should be 1/tau up to some pis
w: float
mag: float
# these are log_10 of actual value
w_log_min: float
w_log_max: float
orientation: Orientation
dipole_count: int
generation_seed: int
def row_to_dipole(input_dict: dict) -> DipoleTO:
p = input_dict["p"]
if len(p) != 3:
raise ValueError(
f"p parameter in input_dict [{input_dict}] does not have length 3"
def make_dipoles(
config: DipoleGenerationConfig,
rng_override: Optional[numpy.random.Generator] = None,
) -> Sequence[DipoleTO]:
if rng_override is None:
_logger.info(
f"Using the seed [{config.generation_seed}] provided by configuration for dipole generation"
)
s = input_dict["s"]
if len(s) != 3:
raise ValueError(
f"s parameter in input_dict [{input_dict}] does not have length 3"
rng = numpy.random.default_rng(config.generation_seed)
else:
rng = rng_override
dipoles = []
for i in range(config.dipole_count):
sx = rng.uniform(config.x_min, config.x_max)
sy = rng.uniform(config.y_min, config.y_max)
sz = rng.uniform(config.z_min, config.z_max)
# orientation
# 0, 1, 2
# xy, z, random
if config.orientation is Orientation.RANDOM:
theta = numpy.arccos(2 * rng.random() - 1)
phi = 2 * numpy.pi * rng.random()
elif config.orientation is Orientation.Z:
theta = 0
phi = 0
elif config.orientation is Orientation.XY:
theta = numpy.pi / 2
phi = 2 * numpy.pi * rng.random()
else:
raise ValueError(
f"this shouldn't have happened, orientation index: {config}"
)
px = config.mag * numpy.cos(phi) * numpy.sin(theta)
py = config.mag * numpy.sin(phi) * numpy.sin(theta)
pz = config.mag * numpy.cos(theta)
w = 10 ** rng.uniform(config.w_log_min, config.w_log_max)
dipoles.append(
DipoleTO(numpy.array([px, py, pz]), numpy.array([sx, sy, sz]), w)
)
w = input_dict["w"]
return DipoleTO(p, s, w)
def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]:
return [row_to_dipole(input_dict) for input_dict in dot_dict_array]
return dipoles

View File

@ -0,0 +1,37 @@
# serializer version: 1
# name: test_generation_simple_random
list([
DipoleTO(p=array([-36.97888725, 79.75551306, -47.66151523]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.0022608601493359997),
DipoleTO(p=array([-82.1374974 , 31.92138656, -47.27003914]), s=array([-5.16467413, -1.81466071, 7.85631698]), w=0.06754798966620681),
DipoleTO(p=array([ -8.46760467, -94.37111521, 31.97486959]), s=array([7.27242593, 3.63757671, 6.69952525]), w=0.004658659556177238),
DipoleTO(p=array([-44.1712149 , -81.84888003, 36.73778182]), s=array([-6.55867631, 3.70414972, 4.24055463]), w=0.06808539217046054),
DipoleTO(p=array([99.76796192, 1.96376112, 6.5190043 ]), s=array([-8.79725374, 4.77769274, 5.75580651]), w=0.005672850212356041),
])
# ---
# name: test_generation_simple_xy
list([
DipoleTO(p=array([-7.33995988e+00, 9.97302611e+01, 6.12323400e-15]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.009063400581631654),
DipoleTO(p=array([ 9.74638419e+01, -2.23785504e+01, 6.12323400e-15]), s=array([-7.63817534, -2.58233707, 5.27413572]), w=0.00617944416992951),
DipoleTO(p=array([ 6.55436572e+01, -7.55250223e+01, 6.12323400e-15]), s=array([-1.17987756, 1.09870809, 7.45448519]), w=0.10583856430824175),
DipoleTO(p=array([4.70336990e+01, 8.82486893e+01, 6.12323400e-15]), s=array([3.19748696, 2.35757698, 4.89101463]), w=0.4085497202597084),
DipoleTO(p=array([-7.66420441e+01, -6.42339246e+01, 6.12323400e-15]), s=array([-8.79722684, 1.83688909, 6.68495208]), w=0.0015149975770328856),
])
# ---
# name: test_generation_simple_xy_override_rng
list([
DipoleTO(p=array([-7.33995988e+00, 9.97302611e+01, 6.12323400e-15]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.009063400581631654),
DipoleTO(p=array([ 9.74638419e+01, -2.23785504e+01, 6.12323400e-15]), s=array([-7.63817534, -2.58233707, 5.27413572]), w=0.00617944416992951),
DipoleTO(p=array([ 6.55436572e+01, -7.55250223e+01, 6.12323400e-15]), s=array([-1.17987756, 1.09870809, 7.45448519]), w=0.10583856430824175),
DipoleTO(p=array([4.70336990e+01, 8.82486893e+01, 6.12323400e-15]), s=array([3.19748696, 2.35757698, 4.89101463]), w=0.4085497202597084),
DipoleTO(p=array([-7.66420441e+01, -6.42339246e+01, 6.12323400e-15]), s=array([-8.79722684, 1.83688909, 6.68495208]), w=0.0015149975770328856),
])
# ---
# name: test_generation_simple_z
list([
DipoleTO(p=array([ 0., 0., 100.]), s=array([ 9.53399533, -1.19804265, 7.69298494]), w=0.006096453583832874),
DipoleTO(p=array([ 0., 0., 100.]), s=array([-3.61805883, -3.81908767, 4.96706517]), w=0.009028212784568558),
DipoleTO(p=array([ 0., 0., 100.]), s=array([ 9.2815849 , -2.36350196, 5.76402449]), w=0.06754798966620681),
DipoleTO(p=array([ 0., 0., 100.]), s=array([7.27242593, 3.63757671, 6.69952525]), w=0.09541640373507952),
DipoleTO(p=array([ 0., 0., 100.]), s=array([ 4.71515397, -2.77246342, 4.68826474]), w=0.4085497202597084),
])
# ---

View File

@ -0,0 +1,84 @@
from tantri.dipoles.generation import DipoleGenerationConfig, make_dipoles, Orientation
import numpy
def test_generation_simple_xy_override_rng(snapshot):
rng = numpy.random.default_rng(1234)
config = DipoleGenerationConfig(
x_min=-10,
x_max=10,
y_min=-5,
y_max=5,
z_min=4,
z_max=8,
mag=100,
w_log_min=-3,
w_log_max=0,
orientation=Orientation.XY,
dipole_count=5,
generation_seed=999999, # should not be used
)
dipoles = make_dipoles(config, rng)
assert dipoles == snapshot
def test_generation_simple_xy(snapshot):
config = DipoleGenerationConfig(
x_min=-10,
x_max=10,
y_min=-5,
y_max=5,
z_min=4,
z_max=8,
mag=100,
w_log_min=-3,
w_log_max=0,
orientation=Orientation.XY,
dipole_count=5,
generation_seed=1234,
)
dipoles = make_dipoles(config)
assert dipoles == snapshot
def test_generation_simple_z(snapshot):
config = DipoleGenerationConfig(
x_min=-10,
x_max=10,
y_min=-5,
y_max=5,
z_min=4,
z_max=8,
mag=100,
w_log_min=-3,
w_log_max=0,
orientation=Orientation.Z,
dipole_count=5,
generation_seed=1234,
)
dipoles = make_dipoles(config)
assert dipoles == snapshot
def test_generation_simple_random(snapshot):
config = DipoleGenerationConfig(
x_min=-10,
x_max=10,
y_min=-5,
y_max=5,
z_min=4,
z_max=8,
mag=100,
w_log_min=-3,
w_log_max=0,
orientation=Orientation.RANDOM,
dipole_count=5,
generation_seed=1234,
)
dipoles = make_dipoles(config)
assert dipoles == snapshot