Adds discretisation class
This commit is contained in:
parent
71e0a81107
commit
f2f76cd8d8
@ -3,7 +3,7 @@ import numpy.random
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Tuple
|
||||
import scipy.optimize
|
||||
from pdme.model.model import Model
|
||||
from pdme.model.model import Model, Discretisation
|
||||
from pdme.measurement import DotMeasurement, OscillatingDipoleArrangement, OscillatingDipole
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ class FixedDipoleModel(Model):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FixedDipoleDiscretisation():
|
||||
class FixedDipoleDiscretisation(Discretisation):
|
||||
'''
|
||||
Representation of a discretisation of a FixedDipoleDiscretisation.
|
||||
Also captures a rough maximum value of dipole.
|
||||
@ -101,7 +101,7 @@ class FixedDipoleDiscretisation():
|
||||
self.y_step = (self.model.ymax - self.model.ymin) / self.num_y
|
||||
self.z_step = (self.model.zmax - self.model.zmin) / self.num_z
|
||||
|
||||
def bounds(self, index: Tuple[float, float, float]) -> Tuple:
|
||||
def bounds(self, index: Tuple[float, ...]) -> Tuple:
|
||||
xi, yi, zi = index
|
||||
|
||||
# For this model, a point is (sx, sx, sy, w).
|
||||
@ -121,7 +121,7 @@ class FixedDipoleDiscretisation():
|
||||
# see https://github.com/numpy/numpy/issues/20706 for why this is a mypy problem.
|
||||
return numpy.ndindex((self.num_x, self.num_y, self.num_z)) # type:ignore
|
||||
|
||||
def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, float, float]) -> scipy.optimize.OptimizeResult:
|
||||
def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, ...]) -> scipy.optimize.OptimizeResult:
|
||||
bounds = self.bounds(index)
|
||||
sx_mean = (bounds[0][0] + bounds[1][0]) / 2
|
||||
sy_mean = (bounds[0][1] + bounds[1][1]) / 2
|
||||
|
@ -3,7 +3,7 @@ import numpy.random
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Tuple
|
||||
import scipy.optimize
|
||||
from pdme.model.model import Model
|
||||
from pdme.model.model import Model, Discretisation
|
||||
from pdme.measurement import DotMeasurement, OscillatingDipole, OscillatingDipoleArrangement
|
||||
|
||||
|
||||
@ -106,7 +106,7 @@ class FixedMagnitudeModel(Model):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FixedMagnitudeDiscretisation():
|
||||
class FixedMagnitudeDiscretisation(Discretisation):
|
||||
'''
|
||||
Representation of a discretisation of a FixedMagnitudeDiscretisation.
|
||||
Also captures a rough maximum value of dipole.
|
||||
@ -141,7 +141,7 @@ class FixedMagnitudeDiscretisation():
|
||||
self.h_step = 2 / self.num_ptheta
|
||||
self.phi_step = 2 * numpy.pi / self.num_pphi
|
||||
|
||||
def bounds(self, index: Tuple[float, float, float, float, float]) -> Tuple:
|
||||
def bounds(self, index: Tuple[float, ...]) -> Tuple:
|
||||
pthetai, pphii, xi, yi, zi = index
|
||||
|
||||
# For this model, a point is (p_theta, p_phi, sx, sx, sy, w).
|
||||
@ -163,7 +163,7 @@ class FixedMagnitudeDiscretisation():
|
||||
# see https://github.com/numpy/numpy/issues/20706 for why this is a mypy problem.
|
||||
return numpy.ndindex((self.num_ptheta, self.num_pphi, self.num_x, self.num_y, self.num_z)) # type:ignore
|
||||
|
||||
def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, float, float, float, float]) -> scipy.optimize.OptimizeResult:
|
||||
def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, ...]) -> scipy.optimize.OptimizeResult:
|
||||
bounds = self.bounds(index)
|
||||
ptheta_mean = (bounds[0][0] + bounds[1][0]) / 2
|
||||
pphi_mean = (bounds[0][1] + bounds[1][1]) / 2
|
||||
|
@ -1,6 +1,6 @@
|
||||
import numpy
|
||||
import scipy.optimize
|
||||
from typing import Callable, Sequence
|
||||
from typing import Callable, Sequence, Tuple
|
||||
from pdme.measurement import DotMeasurement, OscillatingDipoleArrangement
|
||||
import pdme.util
|
||||
import logging
|
||||
@ -89,3 +89,14 @@ class Model():
|
||||
result = scipy.optimize.least_squares(self.costs(dots), initial, jac=self.jac(dots), ftol=1e-15, gtol=3e-16, xtol=None, bounds=bounds)
|
||||
result.normalised_x = pdme.util.normalise_point_list(result.x, self.point_length())
|
||||
return result
|
||||
|
||||
|
||||
class Discretisation():
|
||||
def bounds(self, index: Tuple[float, ...]) -> Tuple:
|
||||
raise NotImplementedError
|
||||
|
||||
def all_indices(self) -> numpy.ndindex:
|
||||
raise NotImplementedError
|
||||
|
||||
def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple) -> scipy.optimize.OptimizeResult:
|
||||
raise NotImplementedError
|
||||
|
@ -2,7 +2,7 @@ import numpy
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Tuple
|
||||
import scipy.optimize
|
||||
from pdme.model.model import Model
|
||||
from pdme.model.model import Model, Discretisation
|
||||
from pdme.measurement import DotMeasurement
|
||||
|
||||
|
||||
@ -71,7 +71,7 @@ class UnrestrictedModel(Model):
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnrestrictedDiscretisation():
|
||||
class UnrestrictedDiscretisation(Discretisation):
|
||||
'''
|
||||
Representation of a discretisation of a UnrestrictedModel.
|
||||
Also captures a rough maximum value of dipole.
|
||||
@ -113,7 +113,7 @@ class UnrestrictedDiscretisation():
|
||||
self.py_step = 2 * self.max_p / self.num_py
|
||||
self.pz_step = 2 * self.max_p / self.num_pz
|
||||
|
||||
def bounds(self, index: Tuple[float, float, float, float, float, float]) -> Tuple:
|
||||
def bounds(self, index: Tuple[float, ...]) -> Tuple:
|
||||
pxi, pyi, pzi, xi, yi, zi = index
|
||||
|
||||
# For this model, a point is (px, py, pz, sx, sx, sy, w).
|
||||
@ -135,7 +135,7 @@ class UnrestrictedDiscretisation():
|
||||
# see https://github.com/numpy/numpy/issues/20706 for why this is a mypy problem.
|
||||
return numpy.ndindex((self.num_px, self.num_py, self.num_pz, self.num_x, self.num_y, self.num_z)) # type:ignore
|
||||
|
||||
def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, float, float, float, float, float]) -> scipy.optimize.OptimizeResult:
|
||||
def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, ...]) -> scipy.optimize.OptimizeResult:
|
||||
bounds = self.bounds(index)
|
||||
px_mean = (bounds[0][0] + bounds[1][0]) / 2
|
||||
py_mean = (bounds[0][1] + bounds[1][1]) / 2
|
||||
|
Loading…
x
Reference in New Issue
Block a user