Adds some gradient descent stuff

This commit is contained in:
Deepak Mallubhotla 2021-08-02 08:57:07 -05:00
parent 472ff166cc
commit 026519697b
Signed by: deepak
GPG Key ID: 64BF53A3369104E7
7 changed files with 290 additions and 130 deletions

View File

@ -1,2 +1,3 @@
[flake8]
ignore = W191
ignore = W191, E501
max-line-length = 120

View File

@ -0,0 +1,25 @@
import numpy
def find_sols(cost_array, jacobian, step_size=0.001, max_iterations=5, initial=(0, 0), desired_cost=1e-6, step_size_tries=10):
desired_cost_squared = desired_cost**2
current = numpy.array(initial)
iterations = 0
curr_cost = numpy.array(cost_array(current))
while iterations < max_iterations:
curr_cost = numpy.array(cost_array(current))
if curr_cost.dot(curr_cost) < desired_cost_squared:
return ("Finished early", current)
gradient = .5 * numpy.matmul(numpy.transpose(jacobian(current)), curr_cost)
next = current - step_size * gradient
next_cost = numpy.array(cost_array(next))
tries = step_size_tries
current_step = step_size
while tries > 0 and next_cost.dot(next_cost) > curr_cost.dot(curr_cost):
current_step = current_step / 10
next = current - current_step / 10 * gradient
next_cost = numpy.array(cost_array(next))
tries -= 1
current = next
iterations += 1
return ("Ran out of iterations", current)

View File

@ -0,0 +1,9 @@
import numpy # type: ignore
def circle_cost(x: float, y: float) -> float:
return 4 - x**2 - y**2
def circle_gradient(x: float, y: float) -> numpy.ndarray:
return numpy.array([-2 * x, -2 * y])

238
poetry.lock generated
View File

@ -1,107 +1,80 @@
[[package]]
category = "dev"
description = "Atomic file writes."
marker = "sys_platform == \"win32\""
name = "atomicwrites"
version = "1.4.0"
description = "Atomic file writes."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "1.4.0"
[[package]]
category = "dev"
description = "Classes Without Boilerplate"
name = "attrs"
version = "21.2.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "21.2.0"
[package.extras]
dev = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"]
dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"]
docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
tests = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"]
tests_no_zope = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"]
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"]
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"]
[[package]]
category = "dev"
description = "Cross-platform colored terminal text."
marker = "sys_platform == \"win32\""
name = "colorama"
version = "0.4.4"
description = "Cross-platform colored terminal text."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "0.4.4"
[[package]]
category = "dev"
description = "Code coverage measurement for Python"
name = "coverage"
version = "5.5"
description = "Code coverage measurement for Python"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
version = "5.5"
[package.extras]
toml = ["toml"]
[[package]]
category = "dev"
description = "the modular source code checker: pep8 pyflakes and co"
name = "flake8"
version = "3.9.2"
description = "the modular source code checker: pep8 pyflakes and co"
category = "dev"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
version = "3.9.2"
[package.dependencies]
mccabe = ">=0.6.0,<0.7.0"
pycodestyle = ">=2.7.0,<2.8.0"
pyflakes = ">=2.3.0,<2.4.0"
[package.dependencies.importlib-metadata]
python = "<3.8"
version = "*"
[[package]]
category = "dev"
description = "Read metadata from Python packages"
marker = "python_version < \"3.8\""
name = "importlib-metadata"
optional = false
python-versions = ">=3.6"
version = "4.6.1"
[package.dependencies]
zipp = ">=0.5"
[package.dependencies.typing-extensions]
python = "<3.8"
version = ">=3.6.4"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
perf = ["ipython"]
testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"]
[[package]]
category = "dev"
description = "iniconfig: brain-dead simple config-ini parsing"
name = "iniconfig"
optional = false
python-versions = "*"
version = "1.1.1"
[[package]]
description = "iniconfig: brain-dead simple config-ini parsing"
category = "dev"
description = "McCabe checker, plugin for flake8"
name = "mccabe"
optional = false
python-versions = "*"
version = "0.6.1"
[[package]]
name = "mccabe"
version = "0.6.1"
description = "McCabe checker, plugin for flake8"
category = "dev"
description = "Optional static typing for Python"
optional = false
python-versions = "*"
[[package]]
name = "mypy"
version = "0.790"
description = "Optional static typing for Python"
category = "dev"
optional = false
python-versions = ">=3.5"
version = "0.790"
[package.dependencies]
mypy-extensions = ">=0.4.3,<0.5.0"
@ -112,104 +85,103 @@ typing-extensions = ">=3.7.4"
dmypy = ["psutil (>=4.0)"]
[[package]]
category = "dev"
description = "Experimental type system extensions for programs checked with the mypy typechecker."
name = "mypy-extensions"
version = "0.4.3"
description = "Experimental type system extensions for programs checked with the mypy typechecker."
category = "dev"
optional = false
python-versions = "*"
version = "0.4.3"
[[package]]
category = "dev"
description = "Core utilities for Python packages"
name = "numpy"
version = "1.21.1"
description = "NumPy is the fundamental package for array computing with Python."
category = "main"
optional = false
python-versions = ">=3.7"
[[package]]
name = "packaging"
version = "21.0"
description = "Core utilities for Python packages"
category = "dev"
optional = false
python-versions = ">=3.6"
version = "21.0"
[package.dependencies]
pyparsing = ">=2.0.2"
[[package]]
category = "dev"
description = "plugin and hook calling mechanisms for python"
name = "pluggy"
version = "0.13.1"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "0.13.1"
[package.dependencies]
[package.dependencies.importlib-metadata]
python = "<3.8"
version = ">=0.12"
[package.extras]
dev = ["pre-commit", "tox"]
[[package]]
category = "dev"
description = "library with cross-python path, ini-parsing, io, code, log facilities"
name = "py"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "1.10.0"
description = "library with cross-python path, ini-parsing, io, code, log facilities"
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
category = "dev"
description = "Python style guide checker"
name = "pycodestyle"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "2.7.0"
[[package]]
description = "Python style guide checker"
category = "dev"
description = "passive checker of Python programs"
name = "pyflakes"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "2.3.1"
[[package]]
name = "pyflakes"
version = "2.3.1"
description = "passive checker of Python programs"
category = "dev"
description = "Python parsing module"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "pyparsing"
version = "2.4.7"
description = "Python parsing module"
category = "dev"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
version = "2.4.7"
[[package]]
category = "dev"
description = "pytest: simple powerful testing with Python"
name = "pytest"
version = "6.2.4"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.6"
version = "6.2.4"
[package.dependencies]
atomicwrites = ">=1.0"
atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
attrs = ">=19.2.0"
colorama = "*"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<1.0.0a1"
py = ">=1.8.2"
toml = "*"
[package.dependencies.importlib-metadata]
python = "<3.8"
version = ">=0.12"
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
[[package]]
category = "dev"
description = "Pytest plugin for measuring coverage."
name = "pytest-cov"
version = "2.12.1"
description = "Pytest plugin for measuring coverage."
category = "dev"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "2.12.1"
[package.dependencies]
coverage = ">=5.2.1"
@ -220,45 +192,33 @@ toml = "*"
testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
[[package]]
category = "dev"
description = "Python Library for Tom's Obvious, Minimal Language"
name = "toml"
version = "0.10.2"
description = "Python Library for Tom's Obvious, Minimal Language"
category = "dev"
optional = false
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
version = "0.10.2"
[[package]]
category = "dev"
description = "a fork of Python 2 and 3 ast modules with type comment support"
name = "typed-ast"
optional = false
python-versions = "*"
version = "1.4.3"
[[package]]
description = "a fork of Python 2 and 3 ast modules with type comment support"
category = "dev"
description = "Backported and Experimental Type Hints for Python 3.5+"
name = "typing-extensions"
optional = false
python-versions = "*"
version = "3.10.0.0"
[[package]]
name = "typing-extensions"
version = "3.10.0.0"
description = "Backported and Experimental Type Hints for Python 3.5+"
category = "dev"
description = "Backport of pathlib-compatible object wrapper for zip files"
marker = "python_version < \"3.8\""
name = "zipp"
optional = false
python-versions = ">=3.6"
version = "3.5.0"
[package.extras]
docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"]
python-versions = "*"
[metadata]
content-hash = "b3de1ca8a251edc8dae2c93e878deb3131846055ac2a6d6b34e43ea20a8f1ea2"
python-versions = "^3.7"
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "1ede17bb5e095a5778c8840f952675b181c00d9c9434f11a16652886a77af0a2"
[metadata.files]
atomicwrites = [
@ -331,10 +291,6 @@ flake8 = [
{file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"},
{file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"},
]
importlib-metadata = [
{file = "importlib_metadata-4.6.1-py3-none-any.whl", hash = "sha256:9f55f560e116f8643ecf2922d9cd3e1c7e8d52e683178fecd9d08f6aa357e11e"},
{file = "importlib_metadata-4.6.1.tar.gz", hash = "sha256:079ada16b7fc30dfbb5d13399a5113110dab1aa7c2bc62f66af75f0b717c8cac"},
]
iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
@ -363,6 +319,36 @@ mypy-extensions = [
{file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
{file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
]
numpy = [
{file = "numpy-1.21.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38e8648f9449a549a7dfe8d8755a5979b45b3538520d1e735637ef28e8c2dc50"},
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fd7d7409fa643a91d0a05c7554dd68aa9c9bb16e186f6ccfe40d6e003156e33a"},
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a75b4498b1e93d8b700282dc8e655b8bd559c0904b3910b144646dbbbc03e062"},
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1412aa0aec3e00bc23fbb8664d76552b4efde98fb71f60737c83efbac24112f1"},
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e46ceaff65609b5399163de5893d8f2a82d3c77d5e56d976c8b5fb01faa6b671"},
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c6a2324085dd52f96498419ba95b5777e40b6bcbc20088fddb9e8cbb58885e8e"},
{file = "numpy-1.21.1-cp37-cp37m-win32.whl", hash = "sha256:73101b2a1fef16602696d133db402a7e7586654682244344b8329cdcbbb82172"},
{file = "numpy-1.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7a708a79c9a9d26904d1cca8d383bf869edf6f8e7650d85dbc77b041e8c5a0f8"},
{file = "numpy-1.21.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95b995d0c413f5d0428b3f880e8fe1660ff9396dcd1f9eedbc311f37b5652e16"},
{file = "numpy-1.21.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:635e6bd31c9fb3d475c8f44a089569070d10a9ef18ed13738b03049280281267"},
{file = "numpy-1.21.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a3d5fb89bfe21be2ef47c0614b9c9c707b7362386c9a3ff1feae63e0267ccb6"},
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a326af80e86d0e9ce92bcc1e65c8ff88297de4fa14ee936cb2293d414c9ec63"},
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:791492091744b0fe390a6ce85cc1bf5149968ac7d5f0477288f78c89b385d9af"},
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0318c465786c1f63ac05d7c4dbcecd4d2d7e13f0959b01b534ea1e92202235c5"},
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a513bd9c1551894ee3d31369f9b07460ef223694098cf27d399513415855b68"},
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:91c6f5fc58df1e0a3cc0c3a717bb3308ff850abdaa6d2d802573ee2b11f674a8"},
{file = "numpy-1.21.1-cp38-cp38-win32.whl", hash = "sha256:978010b68e17150db8765355d1ccdd450f9fc916824e8c4e35ee620590e234cd"},
{file = "numpy-1.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:9749a40a5b22333467f02fe11edc98f022133ee1bfa8ab99bda5e5437b831214"},
{file = "numpy-1.21.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d7a4aeac3b94af92a9373d6e77b37691b86411f9745190d2c351f410ab3a791f"},
{file = "numpy-1.21.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9e7912a56108aba9b31df688a4c4f5cb0d9d3787386b87d504762b6754fbb1b"},
{file = "numpy-1.21.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25b40b98ebdd272bc3020935427a4530b7d60dfbe1ab9381a39147834e985eac"},
{file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a92c5aea763d14ba9d6475803fc7904bda7decc2a0a68153f587ad82941fec1"},
{file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a0f648eb28bae4bcb204e6fd14603de2908de982e761a2fc78efe0f19e96e1"},
{file = "numpy-1.21.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01f28075a92eede918b965e86e8f0ba7b7797a95aa8d35e1cc8821f5fc3ad6a"},
{file = "numpy-1.21.1-cp39-cp39-win32.whl", hash = "sha256:88c0b89ad1cc24a5efbb99ff9ab5db0f9a86e9cc50240177a571fbe9c2860ac2"},
{file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"},
{file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"},
{file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"},
]
packaging = [
{file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"},
{file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"},
@ -436,7 +422,3 @@ typing-extensions = [
{file = "typing_extensions-3.10.0.0-py3-none-any.whl", hash = "sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84"},
{file = "typing_extensions-3.10.0.0.tar.gz", hash = "sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342"},
]
zipp = [
{file = "zipp-3.5.0-py3-none-any.whl", hash = "sha256:957cfda87797e389580cb8b9e3870841ca991e2125350677b2ca83a0e99390a3"},
{file = "zipp-3.5.0.tar.gz", hash = "sha256:f5812b1e007e48cff63449a5e9f4e7ebea716b4111f9c4f9a645f91d579bf0c4"},
]

View File

@ -5,7 +5,8 @@ description = ""
authors = ["Deepak <dmallubhotla+github@gmail.com>"]
[tool.poetry.dependencies]
python = "^3.7"
python = "^3.8"
numpy = "^1.21.1"
[tool.poetry.dev-dependencies]
pytest = ">=6"
@ -16,3 +17,6 @@ mypy = "^0.790"
[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"
[tool.pytest.ini_options]
testpaths = ["tests"]

29
scripts/patch.sh Normal file
View File

@ -0,0 +1,29 @@
#!/usr/bin/env bash
set -Eeuo pipefail
if [ -z "$(git status --porcelain)" ]; then
# Working directory clean
branch_name=$(git symbolic-ref -q HEAD)
branch_name=${branch_name##refs/heads/}
branch_name=${branch_name:-HEAD}
poetry version patch
version=`sed 's/version = "\([0-9]*.[0-9]*.[0-9]*\)"/\1/p' -n <pyproject.toml`
read -p "Create commit for version $version? " -n 1 -r
echo # (optional) move to a new line
if [[ $REPLY =~ ^[Yy]$ ]]
then
# do dangerous stuff
echo "Creating a new patch"
git add pyproject.toml
git commit -m "Created version $version"
git tag -a "$version" -m "patch.sh created version $version"
git push --tags
else
echo "Surrendering, clean up by reverting pyproject.toml..."
exit 2
fi
else
echo "Can't create patch version, working tree unclean..."
exit 1
fi

View File

@ -0,0 +1,110 @@
import numpy
import pathfinder.gradient_descent
def circ_cost(radius, center=(0, 0)):
def cf(pt):
pt2 = numpy.array(pt) - numpy.array(center)
return (radius**2 - pt2.dot(pt2))
return cf
def test_circ_cost():
cost = circ_cost(5)
actual = cost([3, 4])
expected = 0
assert actual == expected
cost = circ_cost(13, [12, 5])
actual = cost([0, 0])
expected = 0
assert actual == expected
def test_find_sols():
c1 = circ_cost(5)
c2 = circ_cost(13, [8, -8])
def costs(pt):
return numpy.array(
[c1(pt), c2(pt)]
)
def jac(pt):
x, y = pt
return numpy.array([[-2 * x, -2 * y], [-2 * (x - 8), -2 * (y + 8)]])
print(costs([3, 4]))
result = pathfinder.gradient_descent.find_sols(costs, jac, step_size=0.01, max_iterations=5000, initial=(2, 10), desired_cost=1e-6)
print(result)
def dipole_cost(vn, xn_raw):
xn = numpy.array(xn_raw)
def dc(pt):
p = pt[0:3]
s = pt[3:6]
diff = xn - s
return (vn * (numpy.linalg.norm(diff)**3)) - p.dot(diff)
return dc
def test_actual_dipole_finding():
def c0(pt):
p = pt[0:3]
return (p.dot(p) - 35)
v1 = -0.0554777
v2 = -0.0601857
v3 = -0.0636403
v4 = -0.0648838
v5 = -0.0629715
# the 0 here is a red herring for index purposes later
vns = [0, v1, v2, v3, v4, v5]
# the 0 here is a red herring
xns = [numpy.array([0, 0, n]) for n in range(0, 6)]
c1 = dipole_cost(v1, [0, 0, 1])
c2 = dipole_cost(v2, [0, 0, 2])
c3 = dipole_cost(v3, [0, 0, 3])
c4 = dipole_cost(v4, [0, 0, 4])
c5 = dipole_cost(v5, [0, 0, 5])
def costs(pt):
return numpy.array(
[c0(pt), c1(pt), c2(pt), c3(pt), c4(pt), c5(pt)]
)
def jac_row(n):
def jr(pt):
p = pt[0:3]
s = pt[3:6]
vn = vns[n]
xn = xns[n]
diff = xn - s
return [
-diff[0], -diff[1], -diff[2],
p[0] - vn * 3 * numpy.linalg.norm(diff) * (diff)[0],
p[1] - vn * 3 * numpy.linalg.norm(diff) * (diff)[1],
p[2] - vn * 3 * numpy.linalg.norm(diff) * (diff)[2]
]
return jr
def jac(pt):
return numpy.array([
[2 * pt[0], 2 * pt[1], 2 * pt[2], 0, 0, 0],
jac_row(1)(pt),
jac_row(2)(pt),
jac_row(3)(pt),
jac_row(4)(pt),
jac_row(5)(pt),
])
result = pathfinder.gradient_descent.find_sols(costs, jac, step_size=0.01, max_iterations=10000, initial=(1, 2, 3, 4, 5, 6), desired_cost=1e-6, step_size_tries=25)
print(result)