Adds slight refactor for n as something model can provide
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
This commit is contained in:
@@ -12,6 +12,9 @@ class Model():
|
||||
def point_length(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def n(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def v_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> float:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ class FixedZPlaneModel(Model):
|
||||
self.xmax = xmax
|
||||
self.ymin = ymin
|
||||
self.ymax = ymax
|
||||
self.n = n
|
||||
self._n = n
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'FixedZPlaneModel({self.z}, {self.xmin}, {self.xmax}, {self.ymin}, {self.ymax}, {self.n})'
|
||||
return f'FixedZPlaneModel({self.z}, {self.xmin}, {self.xmax}, {self.ymin}, {self.ymax}, {self.n()})'
|
||||
|
||||
def point_length(self) -> int:
|
||||
'''
|
||||
@@ -41,6 +41,9 @@ class FixedZPlaneModel(Model):
|
||||
'''
|
||||
return 4
|
||||
|
||||
def n(self) -> int:
|
||||
return self._n
|
||||
|
||||
def v_for_point_at_dot(self, dot: DotMeasurement, pt: numpy.ndarray) -> float:
|
||||
p = numpy.array([0, 0, pt[0]])
|
||||
s = numpy.array([pt[1], pt[2], self.z])
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
import numpy
|
||||
import scipy.optimize
|
||||
from pdme.model import Model
|
||||
from pdme.measurement import DotMeasurement
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def sol(self, initial_dipole=(0.1, 0.1, 0.1), initial_position=(.1, .1, .1), initial_frequency=1, use_root=True):
|
||||
initial = numpy.tile(numpy.concatenate((initial_dipole, initial_position, initial_frequency), axis=None), self.n)
|
||||
def sol(model: Model, dots: Sequence[DotMeasurement], initial_pt=None):
|
||||
if initial_pt is None:
|
||||
initial = numpy.tile(.1, model.n() * model.point_length())
|
||||
else:
|
||||
if len(initial_pt) != model.point_length():
|
||||
raise ValueError(f"The initial point {initial_pt} does not have the model's expected length: {model.point_length()}")
|
||||
initial = numpy.tile(initial_pt, model.n())
|
||||
|
||||
result = scipy.optimize.least_squares(self.costs(), initial, jac=self.jac(), ftol=1e-15, gtol=3e-16)
|
||||
result = scipy.optimize.least_squares(model.costs(dots), initial, jac=model.jac(dots), ftol=1e-15, gtol=3e-16)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user