Adds unrestricted model and adds util for normalising
Some checks failed
gitea-physics/pdme/pipeline/head There was a failure building this commit
Some checks failed
gitea-physics/pdme/pipeline/head There was a failure building this commit
This commit is contained in:
parent
5869691634
commit
b8bbdf29f4
@ -1,4 +1,5 @@
|
||||
from pdme.model.model import Model
|
||||
from pdme.model.fixed_z_plane_model import FixedZPlaneModel
|
||||
from pdme.model.unrestricted_model import UnrestrictedModel
|
||||
|
||||
__all__ = ["Model", "FixedZPlaneModel"]
|
||||
__all__ = ["Model", "FixedZPlaneModel", "UnrestrictedModel"]
|
||||
|
@ -2,6 +2,7 @@ import numpy
|
||||
import scipy.optimize
|
||||
from typing import Callable, Sequence
|
||||
from pdme.measurement import DotMeasurement
|
||||
import pdme.util
|
||||
import logging
|
||||
|
||||
|
||||
@ -83,4 +84,5 @@ class Model():
|
||||
initial = numpy.tile(initial_pt, self.n())
|
||||
|
||||
result = scipy.optimize.least_squares(self.costs(dots), initial, jac=self.jac(dots), ftol=1e-15, gtol=3e-16, bounds=bounds)
|
||||
result.normalised_x = pdme.util.normalise_point_list(result.x, self.point_length())
|
||||
return result
|
||||
|
60
pdme/model/unrestricted_model.py
Normal file
60
pdme/model/unrestricted_model.py
Normal file
@ -0,0 +1,60 @@
|
||||
import numpy
|
||||
from pdme.model.model import Model
|
||||
from pdme.measurement import DotMeasurement
|
||||
|
||||
|
||||
class UnrestrictedModel(Model):
|
||||
'''
|
||||
Model of oscillating dipoles with no restrictions.
|
||||
Additionally, each dipole is assumed to be orientated in the plus or minus z direction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n : int
|
||||
The number of dipoles to assume.
|
||||
'''
|
||||
def __init__(self, n: int) -> None:
|
||||
self._n = n
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'UnrestrictedModel({self.n()})'
|
||||
|
||||
def point_length(self) -> int:
|
||||
'''
|
||||
Dipole is unconstrained in this model.
|
||||
All seven degrees of freedom: (px, py, pz, sx, sy, sz, w).
|
||||
'''
|
||||
return 7
|
||||
|
||||
def n(self) -> int:
|
||||
return self._n
|
||||
|
||||
def v_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> float:
|
||||
p = pt[0:3]
|
||||
s = pt[3:6]
|
||||
w = pt[6]
|
||||
|
||||
diff = dot.r - s
|
||||
alpha = p.dot(diff) / (numpy.linalg.norm(diff)**3)
|
||||
b = (1 / numpy.pi) * (w / (w**2 + dot.f**2))
|
||||
return alpha**2 * b
|
||||
|
||||
def jac_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> numpy.ndarray:
|
||||
p = pt[0:3]
|
||||
s = pt[3:6]
|
||||
w = pt[6]
|
||||
|
||||
diff = dot.r - s
|
||||
alpha = p.dot(diff) / (numpy.linalg.norm(diff)**3)
|
||||
b = (1 / numpy.pi) * (w / (w**2 + dot.f**2))
|
||||
|
||||
p_divs = 2 * alpha * diff / (numpy.linalg.norm(diff)**3) * b
|
||||
|
||||
r_divs = (-p / (numpy.linalg.norm(diff)**3) + 3 * p.dot(diff) * diff / (numpy.linalg.norm(diff)**5)) * 2 * alpha * b
|
||||
|
||||
f2 = dot.f**2
|
||||
w2 = w**2
|
||||
|
||||
w_div = alpha**2 * (1 / numpy.pi) * ((f2 - w2) / ((f2 + w2)**2))
|
||||
|
||||
return numpy.concatenate((p_divs, r_divs, w_div), axis=None)
|
3
pdme/util/__init__.py
Normal file
3
pdme/util/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from pdme.util.normal_form import normalise_point_list
|
||||
|
||||
__all__ = ["normalise_point_list"]
|
20
pdme/util/normal_form.py
Normal file
20
pdme/util/normal_form.py
Normal file
@ -0,0 +1,20 @@
|
||||
import numpy
|
||||
import operator
|
||||
|
||||
|
||||
# flips px, py, pz
|
||||
SIGN_ARRAY = numpy.array((-1, -1, -1, 1, 1, 1, 1))
|
||||
|
||||
|
||||
def flip_chunk_to_positive_px(pt: numpy.ndarray) -> numpy.ndarray:
|
||||
if pt[0] > 0:
|
||||
return pt
|
||||
else:
|
||||
return SIGN_ARRAY * pt
|
||||
|
||||
|
||||
def normalise_point_list(pts: numpy.ndarray, pt_length) -> numpy.ndarray:
|
||||
chunked_pts = [flip_chunk_to_positive_px(pts[i: i + pt_length]) for i in range(0, len(pts), pt_length)]
|
||||
range_to_length = list(range(pt_length))
|
||||
rotated_range = range_to_length[pt_length - 1:] + range_to_length[0:pt_length - 1]
|
||||
return numpy.concatenate(sorted(chunked_pts, key=lambda x: tuple(round(val, 3) for val in operator.itemgetter(*rotated_range)(x))), axis=None)
|
24
tests/model/test_unrestricted_basic_solve.py
Normal file
24
tests/model/test_unrestricted_basic_solve.py
Normal file
@ -0,0 +1,24 @@
|
||||
from pdme.model import UnrestrictedModel
|
||||
from pdme.measurement import OscillatingDipole, OscillatingDipoleArrangement
|
||||
import logging
|
||||
import numpy
|
||||
import itertools
|
||||
|
||||
|
||||
def test_unrestricted_model_solve_basic():
|
||||
# Initialise our dipole arrangement and create dot measurements along a square.
|
||||
dipoles = OscillatingDipoleArrangement([OscillatingDipole((.2, 0, 2), (1, 2, 4), 1)])
|
||||
dot_inputs = list(itertools.chain.from_iterable(
|
||||
(([1, 2, 0.01], f), ([1, 1, -0.2], f), ([1.5, 2, 0.01], f), ([1.5, 1, -0.2], f), ([2, 1, 0], f), ([2, 2, 0], f), ([0, 2, -.1], f), ([0, 1, 0.04], f), ([2, 0, 0], f), ([1, 0, 0], f)) for f in numpy.arange(1, 10, 2)
|
||||
))
|
||||
dots = dipoles.get_dot_measurements(dot_inputs)
|
||||
|
||||
model = UnrestrictedModel(1)
|
||||
|
||||
# from the dipole, these are the unspecified variables in ((0, 0, 2), (1, 2, 4), 1)
|
||||
expected_solution = [0.2, 0, 2, 1, 2, 4, 1]
|
||||
|
||||
result = model.solve(dots)
|
||||
logging.info(result)
|
||||
assert result.success
|
||||
numpy.testing.assert_allclose(result.normalised_x, expected_solution, err_msg="Even well specified problem solution was wrong.", rtol=1e-6, atol=1e-11)
|
38
tests/model/test_unrestricted_model.py
Normal file
38
tests/model/test_unrestricted_model.py
Normal file
@ -0,0 +1,38 @@
|
||||
from pdme.model import UnrestrictedModel
|
||||
from pdme.measurement import DotMeasurement
|
||||
import logging
|
||||
import numpy
|
||||
|
||||
|
||||
def test_unrestricted_plane_model_repr():
|
||||
model = UnrestrictedModel(6)
|
||||
assert repr(model) == "UnrestrictedModel(6)"
|
||||
|
||||
|
||||
def test_unrestricted_model_cost_and_jac_single():
|
||||
model = UnrestrictedModel(1)
|
||||
measured_v = 0.000191292 # from dipole with p=(0, 0, 2) at (1, 2, 4) with w = 1
|
||||
dot = DotMeasurement(measured_v, (1, 2, 0), 5)
|
||||
pt = [0, 0, 2, 2, 2, 4, 2]
|
||||
|
||||
cost_function = model.costs([dot])
|
||||
|
||||
expected_cost = [0.0000946746]
|
||||
actual_cost = cost_function(pt)
|
||||
|
||||
numpy.testing.assert_allclose(actual_cost, expected_cost, err_msg="Cost wasn't as expected.", rtol=1e-6, atol=1e-11)
|
||||
|
||||
jac_function = model.jac([dot])
|
||||
|
||||
expected_jac = [
|
||||
[
|
||||
0.00007149165379592005, 0, 0.0002859666151836802,
|
||||
-0.0001009293935942401, 0, -0.0002607342667851202,
|
||||
0.0001035396365320221
|
||||
]
|
||||
]
|
||||
actual_jac = jac_function(pt)
|
||||
|
||||
logging.warning(actual_jac)
|
||||
|
||||
numpy.testing.assert_allclose(actual_jac, expected_jac, err_msg="Jac wasn't as expected.", rtol=1e-6, atol=1e-11)
|
Loading…
x
Reference in New Issue
Block a user