From f2f76cd8d80ad7de8a8732527629063827265541 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 23 Jan 2022 16:24:53 -0600 Subject: [PATCH] Adds discretisation class --- pdme/model/fixed_dipole_model.py | 8 ++++---- pdme/model/fixed_magnitude_model.py | 8 ++++---- pdme/model/model.py | 13 ++++++++++++- pdme/model/unrestricted_model.py | 8 ++++---- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pdme/model/fixed_dipole_model.py b/pdme/model/fixed_dipole_model.py index cb66190..0f2b6f1 100644 --- a/pdme/model/fixed_dipole_model.py +++ b/pdme/model/fixed_dipole_model.py @@ -3,7 +3,7 @@ import numpy.random from dataclasses import dataclass from typing import Sequence, Tuple import scipy.optimize -from pdme.model.model import Model +from pdme.model.model import Model, Discretisation from pdme.measurement import DotMeasurement, OscillatingDipoleArrangement, OscillatingDipole @@ -74,7 +74,7 @@ class FixedDipoleModel(Model): @dataclass -class FixedDipoleDiscretisation(): +class FixedDipoleDiscretisation(Discretisation): ''' Representation of a discretisation of a FixedDipoleDiscretisation. Also captures a rough maximum value of dipole. @@ -101,7 +101,7 @@ class FixedDipoleDiscretisation(): self.y_step = (self.model.ymax - self.model.ymin) / self.num_y self.z_step = (self.model.zmax - self.model.zmin) / self.num_z - def bounds(self, index: Tuple[float, float, float]) -> Tuple: + def bounds(self, index: Tuple[float, ...]) -> Tuple: xi, yi, zi = index # For this model, a point is (sx, sx, sy, w). @@ -121,7 +121,7 @@ class FixedDipoleDiscretisation(): # see https://github.com/numpy/numpy/issues/20706 for why this is a mypy problem. return numpy.ndindex((self.num_x, self.num_y, self.num_z)) # type:ignore - def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, float, float]) -> scipy.optimize.OptimizeResult: + def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, ...]) -> scipy.optimize.OptimizeResult: bounds = self.bounds(index) sx_mean = (bounds[0][0] + bounds[1][0]) / 2 sy_mean = (bounds[0][1] + bounds[1][1]) / 2 diff --git a/pdme/model/fixed_magnitude_model.py b/pdme/model/fixed_magnitude_model.py index beb3352..3783039 100644 --- a/pdme/model/fixed_magnitude_model.py +++ b/pdme/model/fixed_magnitude_model.py @@ -3,7 +3,7 @@ import numpy.random from dataclasses import dataclass from typing import Sequence, Tuple import scipy.optimize -from pdme.model.model import Model +from pdme.model.model import Model, Discretisation from pdme.measurement import DotMeasurement, OscillatingDipole, OscillatingDipoleArrangement @@ -106,7 +106,7 @@ class FixedMagnitudeModel(Model): @dataclass -class FixedMagnitudeDiscretisation(): +class FixedMagnitudeDiscretisation(Discretisation): ''' Representation of a discretisation of a FixedMagnitudeDiscretisation. Also captures a rough maximum value of dipole. @@ -141,7 +141,7 @@ class FixedMagnitudeDiscretisation(): self.h_step = 2 / self.num_ptheta self.phi_step = 2 * numpy.pi / self.num_pphi - def bounds(self, index: Tuple[float, float, float, float, float]) -> Tuple: + def bounds(self, index: Tuple[float, ...]) -> Tuple: pthetai, pphii, xi, yi, zi = index # For this model, a point is (p_theta, p_phi, sx, sx, sy, w). @@ -163,7 +163,7 @@ class FixedMagnitudeDiscretisation(): # see https://github.com/numpy/numpy/issues/20706 for why this is a mypy problem. return numpy.ndindex((self.num_ptheta, self.num_pphi, self.num_x, self.num_y, self.num_z)) # type:ignore - def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, float, float, float, float]) -> scipy.optimize.OptimizeResult: + def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, ...]) -> scipy.optimize.OptimizeResult: bounds = self.bounds(index) ptheta_mean = (bounds[0][0] + bounds[1][0]) / 2 pphi_mean = (bounds[0][1] + bounds[1][1]) / 2 diff --git a/pdme/model/model.py b/pdme/model/model.py index 3616246..fe0f788 100644 --- a/pdme/model/model.py +++ b/pdme/model/model.py @@ -1,6 +1,6 @@ import numpy import scipy.optimize -from typing import Callable, Sequence +from typing import Callable, Sequence, Tuple from pdme.measurement import DotMeasurement, OscillatingDipoleArrangement import pdme.util import logging @@ -89,3 +89,14 @@ class Model(): result = scipy.optimize.least_squares(self.costs(dots), initial, jac=self.jac(dots), ftol=1e-15, gtol=3e-16, xtol=None, bounds=bounds) result.normalised_x = pdme.util.normalise_point_list(result.x, self.point_length()) return result + + +class Discretisation(): + def bounds(self, index: Tuple[float, ...]) -> Tuple: + raise NotImplementedError + + def all_indices(self) -> numpy.ndindex: + raise NotImplementedError + + def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple) -> scipy.optimize.OptimizeResult: + raise NotImplementedError diff --git a/pdme/model/unrestricted_model.py b/pdme/model/unrestricted_model.py index e2dc0d1..5d34caf 100644 --- a/pdme/model/unrestricted_model.py +++ b/pdme/model/unrestricted_model.py @@ -2,7 +2,7 @@ import numpy from dataclasses import dataclass from typing import Sequence, Tuple import scipy.optimize -from pdme.model.model import Model +from pdme.model.model import Model, Discretisation from pdme.measurement import DotMeasurement @@ -71,7 +71,7 @@ class UnrestrictedModel(Model): @dataclass -class UnrestrictedDiscretisation(): +class UnrestrictedDiscretisation(Discretisation): ''' Representation of a discretisation of a UnrestrictedModel. Also captures a rough maximum value of dipole. @@ -113,7 +113,7 @@ class UnrestrictedDiscretisation(): self.py_step = 2 * self.max_p / self.num_py self.pz_step = 2 * self.max_p / self.num_pz - def bounds(self, index: Tuple[float, float, float, float, float, float]) -> Tuple: + def bounds(self, index: Tuple[float, ...]) -> Tuple: pxi, pyi, pzi, xi, yi, zi = index # For this model, a point is (px, py, pz, sx, sx, sy, w). @@ -135,7 +135,7 @@ class UnrestrictedDiscretisation(): # see https://github.com/numpy/numpy/issues/20706 for why this is a mypy problem. return numpy.ndindex((self.num_px, self.num_py, self.num_pz, self.num_x, self.num_y, self.num_z)) # type:ignore - def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, float, float, float, float, float]) -> scipy.optimize.OptimizeResult: + def solve_for_index(self, dots: Sequence[DotMeasurement], index: Tuple[float, ...]) -> scipy.optimize.OptimizeResult: bounds = self.bounds(index) px_mean = (bounds[0][0] + bounds[1][0]) / 2 py_mean = (bounds[0][1] + bounds[1][1]) / 2