Includes tests with scipy methods

This commit is contained in:
Deepak Mallubhotla 2021-08-20 07:52:26 -05:00
parent c51c477579
commit 56e88759fe
Signed by: deepak
GPG Key ID: 64BF53A3369104E7
5 changed files with 227 additions and 4 deletions

View File

@ -0,0 +1,16 @@
class DipoleModel():
'''
Model object represents a physical dipole finding problem.
Parameters
----------
n : int
The number of dipoles expected.
m: int
The number of dots used to sample the potential.
'''
def __init__(self, n, m):
self.n = n
serf.m = m

42
poetry.lock generated
View File

@ -191,6 +191,17 @@ toml = "*"
[package.extras] [package.extras]
testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"] testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
[[package]]
name = "scipy"
version = "1.5.4"
description = "SciPy: Scientific Library for Python"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
numpy = ">=1.14.5"
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.10.2" version = "0.10.2"
@ -217,8 +228,8 @@ python-versions = "*"
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.8" python-versions = "^3.8,<3.10"
content-hash = "1ede17bb5e095a5778c8840f952675b181c00d9c9434f11a16652886a77af0a2" content-hash = "bde9b5d449e7257dc8c24675658295cf82950d7ec381d873e936ad7cc4bcf6d8"
[metadata.files] [metadata.files]
atomicwrites = [ atomicwrites = [
@ -381,6 +392,33 @@ pytest-cov = [
{file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"}, {file = "pytest-cov-2.12.1.tar.gz", hash = "sha256:261ceeb8c227b726249b376b8526b600f38667ee314f910353fa318caa01f4d7"},
{file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"}, {file = "pytest_cov-2.12.1-py2.py3-none-any.whl", hash = "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a"},
] ]
scipy = [
{file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"},
{file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"},
{file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"},
{file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"},
{file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"},
{file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"},
{file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"},
{file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"},
{file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"},
{file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"},
{file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"},
{file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"},
{file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"},
{file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"},
{file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"},
{file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"},
{file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"},
{file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"},
{file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"},
{file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"},
{file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"},
{file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"},
{file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"},
{file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"},
{file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"},
]
toml = [ toml = [
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},

View File

@ -5,8 +5,9 @@ description = ""
authors = ["Deepak <dmallubhotla+github@gmail.com>"] authors = ["Deepak <dmallubhotla+github@gmail.com>"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8" python = "^3.8,<3.10"
numpy = "^1.21.1" numpy = "^1.21.1"
scipy = "~1.5"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = ">=6" pytest = ">=6"

View File

@ -108,7 +108,10 @@ def test_actual_dipole_finding():
jac_row(5)(pt), jac_row(5)(pt),
]) ])
_, _, result = pathfinder.gradient_descent.find_sols(costs, jac, step_size=1, max_iterations=10000, initial=(1, 2, 3, 4, 5, 6), desired_cost=1e-6, step_size_tries=30) _, iterations, result = pathfinder.gradient_descent.find_sols(costs, jac, step_size=1, max_iterations=10000, initial=(1, 2, 3, 4, 5, 6), desired_cost=1e-6, step_size_tries=30)
print(f"\n\n{iterations} iterations")
print(result)
print("\n")
numpy.testing.assert_allclose( numpy.testing.assert_allclose(
result, (1, 3, 5, 5, 6, 7), result, (1, 3, 5, 5, 6, 7),
rtol=5e-2, err_msg='the result was off', verbose=True rtol=5e-2, err_msg='the result was off', verbose=True

View File

@ -0,0 +1,165 @@
import numpy
import scipy.optimize
import pytest
def circ_cost(radius, center=(0, 0)):
def cf(pt):
pt2 = numpy.array(pt) - numpy.array(center)
return (radius**2 - pt2.dot(pt2))
return cf
def test_circ_cost():
cost = circ_cost(5)
actual = cost([3, 4])
expected = 0
assert actual == expected
cost = circ_cost(13, [12, 5])
actual = cost([0, 0])
expected = 0
assert actual == expected
def test_find_sols():
c1 = circ_cost(5)
c2 = circ_cost(13, [8, -8])
def costs(pt):
return numpy.array(
[c1(pt), c2(pt)]
)
def jac(pt):
x, y = pt
return numpy.array([[-2 * x, -2 * y], [-2 * (x - 8), -2 * (y + 8)]])
print(scipy.optimize.minimize(lambda x: costs(x).dot(costs(x)), numpy.array([1, 2])))
#
# message, iterations, result = pathfinder.gradient_descent.find_sols(costs, jac, step_size=0.01, max_iterations=5000, initial=(2, 10), desired_cost=1e-6)
# numpy.testing.assert_almost_equal(
# result, (3, 4),
# decimal=7, err_msg='the result was off', verbose=True
# )
def dipole_cost(vn, xn_raw):
xn = numpy.array(xn_raw)
def dc(pt):
p = pt[0:3]
s = pt[3:6]
diff = xn - s
return (vn * (numpy.linalg.norm(diff)**3)) - p.dot(diff)
return dc
def test_actual_dipole_finding():
def c0(pt):
p = pt[0:3]
return (p.dot(p) - 35)
v1 = -0.05547767706400186526225414
v2 = -0.06018573388098888319642888
v3 = -0.06364032191901859480476888
v4 = -0.06488383879243851188402150
v5 = -0.06297148063759813929659130
v6 = -0.05735489606460216
v7 = -0.07237320672886623
# the 0 here is a red herring for index purposes later
vns = [0, v1, v2, v3, v4, v5]
# the 0 here is a red herring
xns = [numpy.array([0, 0, n]) for n in range(0, 6)]
# the 0 here is a red herring for index purposes later
vns2 = [0, v1, v2, v3, v4, v5, v6, v7]
# the 0 here is a red herring
xns2 = [numpy.array([0, 0, n]) for n in range(0, 7)]
xns2.append([1, 1, 7])
c1 = dipole_cost(v1, [0, 0, 1])
c2 = dipole_cost(v2, [0, 0, 2])
c3 = dipole_cost(v3, [0, 0, 3])
c4 = dipole_cost(v4, [0, 0, 4])
c5 = dipole_cost(v5, [0, 0, 5])
c6 = dipole_cost(v6, [0, 0, 6])
c6 = dipole_cost(v6, [0, 0, 6])
c7 = dipole_cost(v7, [1, 1, 7])
def costs(pt):
return numpy.array(
[c0(pt), c1(pt), c2(pt), c3(pt), c4(pt), c5(pt)]
)
def costs2(pt):
return numpy.array(
[c0(pt), c1(pt), c2(pt), c3(pt), c4(pt), c5(pt), c6(pt), c7(pt)]
)
def jac_row(n):
def jr(pt):
p = pt[0:3]
s = pt[3:6]
vn = vns2[n]
xn = xns2[n]
diff = xn - s
return [
-diff[0], -diff[1], -diff[2],
p[0] - vn * 3 * numpy.linalg.norm(diff) * (diff)[0],
p[1] - vn * 3 * numpy.linalg.norm(diff) * (diff)[1],
p[2] - vn * 3 * numpy.linalg.norm(diff) * (diff)[2]
]
return jr
def jac(pt):
return numpy.array([
[2 * pt[0], 2 * pt[1], 2 * pt[2], 0, 0, 0],
jac_row(1)(pt),
jac_row(2)(pt),
jac_row(3)(pt),
jac_row(4)(pt),
jac_row(5)(pt),
])
def jac2(pt):
return numpy.array([
[2 * pt[0], 2 * pt[1], 2 * pt[2], 0, 0, 0],
jac_row(1)(pt),
jac_row(2)(pt),
jac_row(3)(pt),
jac_row(4)(pt),
jac_row(5)(pt),
jac_row(6)(pt),
jac_row(7)(pt),
])
def print_result(msg, result):
print(msg)
print(f"\tResult: {result.x}")
print(f"\tSuccess: {result.success}. {result.message}")
try:
print(f"\tFunc evals: {result.nfev}")
except AttributeError as e:
pass
try:
print(f"\tJacb evals: {result.njev}")
except AttributeError as e:
pass
print("Minimising the squared costs")
print(scipy.optimize.minimize(lambda x: costs(x).dot(costs(x)), numpy.array([1, 2, 3, 4, 5, 6])))
# print(scipy.optimize.broyden1(costs, numpy.array([1, 2, 3, 4, 5, 6])))
# print(scipy.optimize.newton_krylov(costs, numpy.array([1, 2, 3, 4, 5, 6])))
# print(scipy.optimize.anderson(costs, numpy.array([1, 2, 3, 4, 5, 6])))
print_result("Using root", scipy.optimize.root(costs, numpy.array([1, 2, 3, 4, 5, 6])))
print_result("Using root with jacobian", scipy.optimize.root(costs, numpy.array([1, 2, 3, 4, 5, 6]), jac=jac, tol=1e-12))
print_result("Using least squares", scipy.optimize.least_squares(costs, numpy.array([1, 2, 3, 4, 5, 6]), gtol=1e-12))
print_result("Using least squares, with jacobian", scipy.optimize.least_squares(costs, numpy.array([1, 2, 3, 4, 5, 6]), jac=jac, ftol=3e-16, gtol=3e-16, xtol=3e-16))
print_result("Using least squares, with jacobian, lm", scipy.optimize.least_squares(costs, numpy.array([1, 2, 3, 4, 5, 6]), jac=jac, ftol=3e-16, gtol=3e-16, xtol=3e-16, method="lm"))
print_result("Using least squares extra dot", scipy.optimize.least_squares(costs2, numpy.array([1, 2, 3, 4, 5, 6])))
print_result("Using least squares extra dot, with jacobian", scipy.optimize.least_squares(costs2, numpy.array([1, 2, 3, 4, 5, 6]), jac=jac2, ftol=1e-12))
print(scipy.optimize.least_squares(costs2, numpy.array([1, 2, 3, 4, 5, 6]), jac=jac2, ftol=1e-12).x[0])