Adds slight refactor for n as something model can provide
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
This commit is contained in:
parent
3a03204d53
commit
8fee0a27d2
@ -12,6 +12,9 @@ class Model():
|
||||
def point_length(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def n(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def v_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -29,10 +29,10 @@ class FixedZPlaneModel(Model):
|
||||
self.xmax = xmax
|
||||
self.ymin = ymin
|
||||
self.ymax = ymax
|
||||
self.n = n
|
||||
self._n = n
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'FixedZPlaneModel({self.z}, {self.xmin}, {self.xmax}, {self.ymin}, {self.ymax}, {self.n})'
|
||||
return f'FixedZPlaneModel({self.z}, {self.xmin}, {self.xmax}, {self.ymin}, {self.ymax}, {self.n()})'
|
||||
|
||||
def point_length(self) -> int:
|
||||
'''
|
||||
@ -41,6 +41,9 @@ class FixedZPlaneModel(Model):
|
||||
'''
|
||||
return 4
|
||||
|
||||
def n(self) -> int:
|
||||
return self._n
|
||||
|
||||
def v_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> float:
|
||||
p = numpy.array([0, 0, pt[0]])
|
||||
s = numpy.array([pt[1], pt[2], self.z])
|
||||
|
@ -1,9 +1,17 @@
|
||||
import numpy
|
||||
import scipy.optimize
|
||||
from pdme.model import Model
|
||||
from pdme.measurement import DotMeasurement
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def sol(self, initial_dipole=(0.1, 0.1, 0.1), initial_position=(.1, .1, .1), initial_frequency=1, use_root=True):
|
||||
initial = numpy.tile(numpy.concatenate((initial_dipole, initial_position, initial_frequency), axis=None), self.n)
|
||||
def sol(model: Model, dots: Sequence[DotMeasurement], initial_pt=None):
|
||||
if initial_pt is None:
|
||||
initial = numpy.tile(.1, model.n() * model.point_length())
|
||||
else:
|
||||
if len(initial_pt) != model.point_length():
|
||||
raise ValueError(f"The initial point {initial_pt} does not have the model's expected length: {model.point_length()}")
|
||||
initial = numpy.tile(initial_pt, model.n())
|
||||
|
||||
result = scipy.optimize.least_squares(self.costs(), initial, jac=self.jac(), ftol=1e-15, gtol=3e-16)
|
||||
result = scipy.optimize.least_squares(model.costs(dots), initial, jac=model.jac(dots), ftol=1e-15, gtol=3e-16)
|
||||
return result
|
||||
|
@ -9,6 +9,12 @@ def test_model_interface_not_implemented_point_length():
|
||||
model.point_length()
|
||||
|
||||
|
||||
def test_model_interface_not_implemented_point_n():
|
||||
model = Model()
|
||||
with pytest.raises(NotImplementedError):
|
||||
model.n()
|
||||
|
||||
|
||||
def test_model_interface_not_implemented_cost():
|
||||
model = Model()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user