Adds slight refactor for n as something model can provide
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good

This commit is contained in:
Deepak Mallubhotla 2022-01-02 16:47:04 -06:00
parent 3a03204d53
commit 8fee0a27d2
Signed by: deepak
GPG Key ID: 64BF53A3369104E7
4 changed files with 25 additions and 5 deletions

View File

@ -12,6 +12,9 @@ class Model():
def point_length(self) -> int:
raise NotImplementedError
def n(self) -> int:
raise NotImplementedError
def v_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> float:
raise NotImplementedError

View File

@ -29,10 +29,10 @@ class FixedZPlaneModel(Model):
self.xmax = xmax
self.ymin = ymin
self.ymax = ymax
self.n = n
self._n = n
def __repr__(self) -> str:
return f'FixedZPlaneModel({self.z}, {self.xmin}, {self.xmax}, {self.ymin}, {self.ymax}, {self.n})'
return f'FixedZPlaneModel({self.z}, {self.xmin}, {self.xmax}, {self.ymin}, {self.ymax}, {self.n()})'
def point_length(self) -> int:
'''
@ -41,6 +41,9 @@ class FixedZPlaneModel(Model):
'''
return 4
def n(self) -> int:
return self._n
def v_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> float:
p = numpy.array([0, 0, pt[0]])
s = numpy.array([pt[1], pt[2], self.z])

View File

@ -1,9 +1,17 @@
import numpy
import scipy.optimize
from pdme.model import Model
from pdme.measurement import DotMeasurement
from typing import Sequence
def sol(self, initial_dipole=(0.1, 0.1, 0.1), initial_position=(.1, .1, .1), initial_frequency=1, use_root=True):
initial = numpy.tile(numpy.concatenate((initial_dipole, initial_position, initial_frequency), axis=None), self.n)
def sol(model: Model, dots: Sequence[DotMeasurement], initial_pt=None):
if initial_pt is None:
initial = numpy.tile(.1, model.n() * model.point_length())
else:
if len(initial_pt) != model.point_length():
raise ValueError(f"The initial point {initial_pt} does not have the model's expected length: {model.point_length()}")
initial = numpy.tile(initial_pt, model.n())
result = scipy.optimize.least_squares(self.costs(), initial, jac=self.jac(), ftol=1e-15, gtol=3e-16)
result = scipy.optimize.least_squares(model.costs(dots), initial, jac=model.jac(dots), ftol=1e-15, gtol=3e-16)
return result

View File

@ -9,6 +9,12 @@ def test_model_interface_not_implemented_point_length():
model.point_length()
def test_model_interface_not_implemented_point_n():
model = Model()
with pytest.raises(NotImplementedError):
model.n()
def test_model_interface_not_implemented_cost():
model = Model()