Adds sol method to the model and removes skipped test
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
from typing import Callable, Sequence
|
from typing import Callable, Sequence
|
||||||
import numpy
|
import numpy
|
||||||
|
import scipy.optimize
|
||||||
from pathfinder.model.dot import DotMeasurement
|
from pathfinder.model.dot import DotMeasurement
|
||||||
|
|
||||||
|
|
||||||
@@ -35,3 +35,7 @@ class DotDipoleModel():
|
|||||||
return numpy.array([dot.jac(pts) for dot in self.dots])
|
return numpy.array([dot.jac(pts) for dot in self.dots])
|
||||||
|
|
||||||
return jac_to_return
|
return jac_to_return
|
||||||
|
|
||||||
|
def sol(self):
|
||||||
|
initial = numpy.zeros(6 * self.n)
|
||||||
|
return scipy.optimize.least_squares(self.costs(), initial, jac=self.jac(), ftol=1e-15, gtol=3e-16)
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import numpy
|
import numpy
|
||||||
import pathfinder.model as model
|
import pathfinder.model as model
|
||||||
import scipy.optimize
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def test_dotdipolemodel_repr():
|
def test_dotdipolemodel_repr():
|
||||||
@@ -49,8 +47,7 @@ def print_result(msg, result):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="old")
|
def test_dot_dipole_model_solution():
|
||||||
def test_dot_dipole_model_jac():
|
|
||||||
v1 = -0.05547767706400186526225414
|
v1 = -0.05547767706400186526225414
|
||||||
v2 = -0.06018573388098888319642888
|
v2 = -0.06018573388098888319642888
|
||||||
v3 = -0.06364032191901859480476888
|
v3 = -0.06364032191901859480476888
|
||||||
@@ -59,6 +56,7 @@ def test_dot_dipole_model_jac():
|
|||||||
v6 = -0.05735489606460216
|
v6 = -0.05735489606460216
|
||||||
v7 = -0.07237320672886623
|
v7 = -0.07237320672886623
|
||||||
v8 = -0.1082531754730548
|
v8 = -0.1082531754730548
|
||||||
|
v9 = -0.04471694936155558
|
||||||
|
|
||||||
c1 = model.DotMeasurement(v1, [0, 0, 1])
|
c1 = model.DotMeasurement(v1, [0, 0, 1])
|
||||||
c2 = model.DotMeasurement(v2, [0, 0, 2])
|
c2 = model.DotMeasurement(v2, [0, 0, 2])
|
||||||
@@ -68,61 +66,11 @@ def test_dot_dipole_model_jac():
|
|||||||
c6 = model.DotMeasurement(v6, [0, 0, 6])
|
c6 = model.DotMeasurement(v6, [0, 0, 6])
|
||||||
c7 = model.DotMeasurement(v7, [1, 1, 7])
|
c7 = model.DotMeasurement(v7, [1, 1, 7])
|
||||||
c8 = model.DotMeasurement(v8, [1, 2, 3])
|
c8 = model.DotMeasurement(v8, [1, 2, 3])
|
||||||
|
c9 = model.DotMeasurement(v9, [0, -1, 0])
|
||||||
|
|
||||||
mod = model.DotDipoleModel([c1, c2, c3, c4, c5, c6], 1)
|
expected_result = numpy.array([1, 3, 5, 5, 6, 7])
|
||||||
res = scipy.optimize.least_squares(mod.costs(), numpy.array([1, 2, 3, 4, 5, 6]), jac=mod.jac(), ftol=1e-12)
|
|
||||||
print_result("6 dots, least sq", res)
|
|
||||||
mod2 = model.DotDipoleModel([c1, c2, c3, c4, c5, c6, c7, c8], 1)
|
|
||||||
res2 = scipy.optimize.least_squares(mod2.costs(), numpy.array([0, 0, 0, 0, 0, 0]), jac=mod2.jac(), ftol=1e-12, gtol=3e-16)
|
|
||||||
print_result("7 dots, least squares", res2)
|
|
||||||
print(mod2.costs()(res2.x))
|
|
||||||
print(mod2.costs()(numpy.array([1, 3, 5, 5, 6, 7])))
|
|
||||||
|
|
||||||
|
mod = model.DotDipoleModel([c1, c2, c3, c4, c5, c6, c7, c8, c9], 1)
|
||||||
@pytest.mark.skip(reason="bad test")
|
res = mod.sol()
|
||||||
def test_dot_2dipoles_model_jac():
|
assert res.success, "The solution for a single dipole should have succeeded."
|
||||||
dots_12andone = [
|
numpy.testing.assert_allclose(res.x, expected_result, err_msg="Dipole wasn't as expected.")
|
||||||
model.DotMeasurement(-0.1978319326584865, [-4, -1, 0]),
|
|
||||||
model.DotMeasurement(-0.1273171638293727, [-4, -5, 0]),
|
|
||||||
model.DotMeasurement(-0.05545025617224288, [-4, -9, 0]),
|
|
||||||
model.DotMeasurement(-0.4960209997369774, [-1, -1, 0]),
|
|
||||||
model.DotMeasurement(-0.1763373289754278, [-1, -5, 0]),
|
|
||||||
model.DotMeasurement(-0.04946346672578462, [-1, -9, 0]),
|
|
||||||
model.DotMeasurement(-0.5633156098386561, [1, -1, 0]),
|
|
||||||
model.DotMeasurement(-0.113765134433174, [1, -5, 0]),
|
|
||||||
model.DotMeasurement(-0.0294893572499722, [1, -9, 0]),
|
|
||||||
model.DotMeasurement(0.7794941616360612, [4, -1, 0]),
|
|
||||||
model.DotMeasurement(0.1110683086477768, [4, -5, 0]),
|
|
||||||
model.DotMeasurement(0.01183272220840589, [4, -9, 0]),
|
|
||||||
model.DotMeasurement(-0.1096485462119833, [1, 1, 1]),
|
|
||||||
model.DotMeasurement(-0.3925851888077783, [1, 1, -1])
|
|
||||||
]
|
|
||||||
# dots = [
|
|
||||||
# model.Dot(-0.1978319326584865, [-4, -1, 0]),
|
|
||||||
# model.Dot(-0.1273171638293727, [-4, -5, 0]),
|
|
||||||
# model.Dot(-0.05545025617224288, [-4, -9, 0]),
|
|
||||||
# model.Dot(-0.4960209997369774, [-1, -1, 0]),
|
|
||||||
# model.Dot(-0.1763373289754278, [-1, -5, 0]),
|
|
||||||
# model.Dot(-0.04946346672578462, [-1, -9, 0]),
|
|
||||||
# model.Dot(-0.5633156098386561, [1, -1, 0]),
|
|
||||||
# model.Dot(-0.113765134433174, [1, -5, 0]),
|
|
||||||
# model.Dot(-0.0294893572499722, [1, -9, 0]),
|
|
||||||
# model.Dot(0.7794941616360612, [4, -1, 0]),
|
|
||||||
# model.Dot(0.1110683086477768, [4, -5, 0]),
|
|
||||||
# model.Dot(0.01183272220840589, [4, -9, 0]),
|
|
||||||
# model.Dot(-0.0261092728062841, [-4, -13, 0]),
|
|
||||||
# model.Dot(-0.02035559593904894, [-1, -13, 0]),
|
|
||||||
# model.Dot(-0.01296894715810522, [1, -13, 0]),
|
|
||||||
# model.Dot(0.0003306001626435171, [4, -13, 0]),
|
|
||||||
# model.Dot(0.2612817068810759, [7, -1, 0]),
|
|
||||||
# model.Dot(0.1203841445911355, [7, -5, 0]),
|
|
||||||
# model.Dot(0.03425933872931543, [7, -9, 0]),
|
|
||||||
# model.Dot(0.01068547688208644, [7, -13, 0]),
|
|
||||||
# ]
|
|
||||||
|
|
||||||
mod = model.DotDipoleModel(dots_12andone, 2)
|
|
||||||
res = scipy.optimize.least_squares(mod.costs(), numpy.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), jac=mod.jac(), ftol=1e-12)
|
|
||||||
print_result("6 dots, least sq", res)
|
|
||||||
print(mod.costs()(res.x))
|
|
||||||
for val in res.x:
|
|
||||||
print(val)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user