Slight refactor to pull solutions into model
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import numpy
|
||||
import scipy.optimize
|
||||
from typing import Callable, Sequence
|
||||
from pdme.measurement import DotMeasurement
|
||||
import logging
|
||||
@@ -72,3 +73,14 @@ class Model():
|
||||
return numpy.array([self.jac_for_dot(dot, pts) for dot in dots])
|
||||
|
||||
return jac_to_return
|
||||
|
||||
def solve(self, dots: Sequence[DotMeasurement], initial_pt: numpy.ndarray = None, bounds=(-numpy.inf, numpy.inf)) -> scipy.optimize.OptimizeResult:
|
||||
if initial_pt is None:
|
||||
initial = numpy.tile(.1, self.n() * self.point_length())
|
||||
else:
|
||||
if len(initial_pt) != self.point_length():
|
||||
raise ValueError(f"The initial point {initial_pt} does not have the model's expected length: {self.point_length()}")
|
||||
initial = numpy.tile(initial_pt, self.n())
|
||||
|
||||
result = scipy.optimize.least_squares(self.costs(dots), initial, jac=self.jac(dots), ftol=1e-15, gtol=3e-16, bounds=bounds)
|
||||
return result
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from pdme.solver.solver import sol
|
||||
|
||||
__all__ = ["sol"]
|
||||
@@ -1,17 +0,0 @@
|
||||
import numpy
|
||||
import scipy.optimize
|
||||
from pdme.model import Model
|
||||
from pdme.measurement import DotMeasurement
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def sol(model: Model, dots: Sequence[DotMeasurement], initial_pt=None, bounds=(-numpy.inf, numpy.inf)):
|
||||
if initial_pt is None:
|
||||
initial = numpy.tile(.1, model.n() * model.point_length())
|
||||
else:
|
||||
if len(initial_pt) != model.point_length():
|
||||
raise ValueError(f"The initial point {initial_pt} does not have the model's expected length: {model.point_length()}")
|
||||
initial = numpy.tile(initial_pt, model.n())
|
||||
|
||||
result = scipy.optimize.least_squares(model.costs(dots), initial, jac=model.jac(dots), ftol=1e-15, gtol=3e-16, bounds=bounds)
|
||||
return result
|
||||
Reference in New Issue
Block a user