Adds some gradient descent stuff
This commit is contained in:
parent
472ff166cc
commit
026519697b
25
pathfinder/gradient_descent.py
Normal file
25
pathfinder/gradient_descent.py
Normal file
@ -0,0 +1,25 @@
|
||||
import numpy
|
||||
|
||||
|
||||
def find_sols(cost_array, jacobian, step_size=0.001, max_iterations=5, initial=(0, 0), desired_cost=1e-6, step_size_tries=10):
|
||||
desired_cost_squared = desired_cost**2
|
||||
current = numpy.array(initial)
|
||||
iterations = 0
|
||||
curr_cost = numpy.array(cost_array(current))
|
||||
while iterations < max_iterations:
|
||||
curr_cost = numpy.array(cost_array(current))
|
||||
if curr_cost.dot(curr_cost) < desired_cost_squared:
|
||||
return ("Finished early", current)
|
||||
gradient = .5 * numpy.matmul(numpy.transpose(jacobian(current)), curr_cost)
|
||||
next = current - step_size * gradient
|
||||
next_cost = numpy.array(cost_array(next))
|
||||
tries = step_size_tries
|
||||
current_step = step_size
|
||||
while tries > 0 and next_cost.dot(next_cost) > curr_cost.dot(curr_cost):
|
||||
current_step = current_step / 10
|
||||
next = current - current_step / 10 * gradient
|
||||
next_cost = numpy.array(cost_array(next))
|
||||
tries -= 1
|
||||
current = next
|
||||
iterations += 1
|
||||
return ("Ran out of iterations", current)
|
@ -0,0 +1,9 @@
|
||||
import numpy # type: ignore
|
||||
|
||||
|
||||
def circle_cost(x: float, y: float) -> float:
|
||||
return 4 - x**2 - y**2
|
||||
|
||||
|
||||
def circle_gradient(x: float, y: float) -> numpy.ndarray:
|
||||
return numpy.array([-2 * x, -2 * y])
|
238
poetry.lock
generated
238
poetry.lock
generated
@ -1,107 +1,80 @@
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Atomic file writes."
|
||||
marker = "sys_platform == \"win32\""
|
||||
name = "atomicwrites"
|
||||
version = "1.4.0"
|
||||
description = "Atomic file writes."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
version = "1.4.0"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Classes Without Boilerplate"
|
||||
name = "attrs"
|
||||
version = "21.2.0"
|
||||
description = "Classes Without Boilerplate"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
version = "21.2.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"]
|
||||
dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "furo", "sphinx", "sphinx-notfound-page", "pre-commit"]
|
||||
docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
|
||||
tests = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"]
|
||||
tests_no_zope = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"]
|
||||
tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface"]
|
||||
tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins"]
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Cross-platform colored terminal text."
|
||||
marker = "sys_platform == \"win32\""
|
||||
name = "colorama"
|
||||
version = "0.4.4"
|
||||
description = "Cross-platform colored terminal text."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
version = "0.4.4"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Code coverage measurement for Python"
|
||||
name = "coverage"
|
||||
version = "5.5"
|
||||
description = "Code coverage measurement for Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4"
|
||||
version = "5.5"
|
||||
|
||||
[package.extras]
|
||||
toml = ["toml"]
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "the modular source code checker: pep8 pyflakes and co"
|
||||
name = "flake8"
|
||||
version = "3.9.2"
|
||||
description = "the modular source code checker: pep8 pyflakes and co"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
|
||||
version = "3.9.2"
|
||||
|
||||
[package.dependencies]
|
||||
mccabe = ">=0.6.0,<0.7.0"
|
||||
pycodestyle = ">=2.7.0,<2.8.0"
|
||||
pyflakes = ">=2.3.0,<2.4.0"
|
||||
|
||||
[package.dependencies.importlib-metadata]
|
||||
python = "<3.8"
|
||||
version = "*"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Read metadata from Python packages"
|
||||
marker = "python_version < \"3.8\""
|
||||
name = "importlib-metadata"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
version = "4.6.1"
|
||||
|
||||
[package.dependencies]
|
||||
zipp = ">=0.5"
|
||||
|
||||
[package.dependencies.typing-extensions]
|
||||
python = "<3.8"
|
||||
version = ">=3.6.4"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
|
||||
perf = ["ipython"]
|
||||
testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"]
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "iniconfig: brain-dead simple config-ini parsing"
|
||||
name = "iniconfig"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "1.1.1"
|
||||
|
||||
[[package]]
|
||||
description = "iniconfig: brain-dead simple config-ini parsing"
|
||||
category = "dev"
|
||||
description = "McCabe checker, plugin for flake8"
|
||||
name = "mccabe"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "0.6.1"
|
||||
|
||||
[[package]]
|
||||
name = "mccabe"
|
||||
version = "0.6.1"
|
||||
description = "McCabe checker, plugin for flake8"
|
||||
category = "dev"
|
||||
description = "Optional static typing for Python"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "mypy"
|
||||
version = "0.790"
|
||||
description = "Optional static typing for Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
version = "0.790"
|
||||
|
||||
[package.dependencies]
|
||||
mypy-extensions = ">=0.4.3,<0.5.0"
|
||||
@ -112,104 +85,103 @@ typing-extensions = ">=3.7.4"
|
||||
dmypy = ["psutil (>=4.0)"]
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Experimental type system extensions for programs checked with the mypy typechecker."
|
||||
name = "mypy-extensions"
|
||||
version = "0.4.3"
|
||||
description = "Experimental type system extensions for programs checked with the mypy typechecker."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "0.4.3"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Core utilities for Python packages"
|
||||
name = "numpy"
|
||||
version = "1.21.1"
|
||||
description = "NumPy is the fundamental package for array computing with Python."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "21.0"
|
||||
description = "Core utilities for Python packages"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
version = "21.0"
|
||||
|
||||
[package.dependencies]
|
||||
pyparsing = ">=2.0.2"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
name = "pluggy"
|
||||
version = "0.13.1"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
version = "0.13.1"
|
||||
|
||||
[package.dependencies]
|
||||
[package.dependencies.importlib-metadata]
|
||||
python = "<3.8"
|
||||
version = ">=0.12"
|
||||
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "tox"]
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "library with cross-python path, ini-parsing, io, code, log facilities"
|
||||
name = "py"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
version = "1.10.0"
|
||||
description = "library with cross-python path, ini-parsing, io, code, log facilities"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Python style guide checker"
|
||||
name = "pycodestyle"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
version = "2.7.0"
|
||||
|
||||
[[package]]
|
||||
description = "Python style guide checker"
|
||||
category = "dev"
|
||||
description = "passive checker of Python programs"
|
||||
name = "pyflakes"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
version = "2.3.1"
|
||||
|
||||
[[package]]
|
||||
name = "pyflakes"
|
||||
version = "2.3.1"
|
||||
description = "passive checker of Python programs"
|
||||
category = "dev"
|
||||
description = "Python parsing module"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
|
||||
[[package]]
|
||||
name = "pyparsing"
|
||||
version = "2.4.7"
|
||||
description = "Python parsing module"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
version = "2.4.7"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
name = "pytest"
|
||||
version = "6.2.4"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
version = "6.2.4"
|
||||
|
||||
[package.dependencies]
|
||||
atomicwrites = ">=1.0"
|
||||
atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
|
||||
attrs = ">=19.2.0"
|
||||
colorama = "*"
|
||||
colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=0.12,<1.0.0a1"
|
||||
py = ">=1.8.2"
|
||||
toml = "*"
|
||||
|
||||
[package.dependencies.importlib-metadata]
|
||||
python = "<3.8"
|
||||
version = ">=0.12"
|
||||
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
name = "pytest-cov"
|
||||
version = "2.12.1"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
version = "2.12.1"
|
||||
|
||||
[package.dependencies]
|
||||
coverage = ">=5.2.1"
|
||||
@ -220,45 +192,33 @@ toml = "*"
|
||||
testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "Python Library for Tom's Obvious, Minimal Language"
|
||||
name = "toml"
|
||||
version = "0.10.2"
|
||||
description = "Python Library for Tom's Obvious, Minimal Language"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
version = "0.10.2"
|
||||
|
||||
[[package]]
|
||||
category = "dev"
|
||||
description = "a fork of Python 2 and 3 ast modules with type comment support"
|
||||
name = "typed-ast"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "1.4.3"
|
||||
|
||||
[[package]]
|
||||
description = "a fork of Python 2 and 3 ast modules with type comment support"
|
||||
category = "dev"
|
||||
description = "Backported and Experimental Type Hints for Python 3.5+"
|
||||
name = "typing-extensions"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "3.10.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "3.10.0.0"
|
||||
description = "Backported and Experimental Type Hints for Python 3.5+"
|
||||
category = "dev"
|
||||
description = "Backport of pathlib-compatible object wrapper for zip files"
|
||||
marker = "python_version < \"3.8\""
|
||||
name = "zipp"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
version = "3.5.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"]
|
||||
testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-cov", "pytest-enabler (>=1.0.1)", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"]
|
||||
python-versions = "*"
|
||||
|
||||
[metadata]
|
||||
content-hash = "b3de1ca8a251edc8dae2c93e878deb3131846055ac2a6d6b34e43ea20a8f1ea2"
|
||||
python-versions = "^3.7"
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "1ede17bb5e095a5778c8840f952675b181c00d9c9434f11a16652886a77af0a2"
|
||||
|
||||
[metadata.files]
|
||||
atomicwrites = [
|
||||
@ -331,10 +291,6 @@ flake8 = [
|
||||
{file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"},
|
||||
{file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"},
|
||||
]
|
||||
importlib-metadata = [
|
||||
{file = "importlib_metadata-4.6.1-py3-none-any.whl", hash = "sha256:9f55f560e116f8643ecf2922d9cd3e1c7e8d52e683178fecd9d08f6aa357e11e"},
|
||||
{file = "importlib_metadata-4.6.1.tar.gz", hash = "sha256:079ada16b7fc30dfbb5d13399a5113110dab1aa7c2bc62f66af75f0b717c8cac"},
|
||||
]
|
||||
iniconfig = [
|
||||
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
|
||||
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
|
||||
@ -363,6 +319,36 @@ mypy-extensions = [
|
||||
{file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
|
||||
{file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
|
||||
]
|
||||
numpy = [
|
||||
{file = "numpy-1.21.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38e8648f9449a549a7dfe8d8755a5979b45b3538520d1e735637ef28e8c2dc50"},
|
||||
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fd7d7409fa643a91d0a05c7554dd68aa9c9bb16e186f6ccfe40d6e003156e33a"},
|
||||
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a75b4498b1e93d8b700282dc8e655b8bd559c0904b3910b144646dbbbc03e062"},
|
||||
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1412aa0aec3e00bc23fbb8664d76552b4efde98fb71f60737c83efbac24112f1"},
|
||||
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e46ceaff65609b5399163de5893d8f2a82d3c77d5e56d976c8b5fb01faa6b671"},
|
||||
{file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c6a2324085dd52f96498419ba95b5777e40b6bcbc20088fddb9e8cbb58885e8e"},
|
||||
{file = "numpy-1.21.1-cp37-cp37m-win32.whl", hash = "sha256:73101b2a1fef16602696d133db402a7e7586654682244344b8329cdcbbb82172"},
|
||||
{file = "numpy-1.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7a708a79c9a9d26904d1cca8d383bf869edf6f8e7650d85dbc77b041e8c5a0f8"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95b995d0c413f5d0428b3f880e8fe1660ff9396dcd1f9eedbc311f37b5652e16"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:635e6bd31c9fb3d475c8f44a089569070d10a9ef18ed13738b03049280281267"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a3d5fb89bfe21be2ef47c0614b9c9c707b7362386c9a3ff1feae63e0267ccb6"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a326af80e86d0e9ce92bcc1e65c8ff88297de4fa14ee936cb2293d414c9ec63"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:791492091744b0fe390a6ce85cc1bf5149968ac7d5f0477288f78c89b385d9af"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0318c465786c1f63ac05d7c4dbcecd4d2d7e13f0959b01b534ea1e92202235c5"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a513bd9c1551894ee3d31369f9b07460ef223694098cf27d399513415855b68"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:91c6f5fc58df1e0a3cc0c3a717bb3308ff850abdaa6d2d802573ee2b11f674a8"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-win32.whl", hash = "sha256:978010b68e17150db8765355d1ccdd450f9fc916824e8c4e35ee620590e234cd"},
|
||||
{file = "numpy-1.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:9749a40a5b22333467f02fe11edc98f022133ee1bfa8ab99bda5e5437b831214"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d7a4aeac3b94af92a9373d6e77b37691b86411f9745190d2c351f410ab3a791f"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9e7912a56108aba9b31df688a4c4f5cb0d9d3787386b87d504762b6754fbb1b"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25b40b98ebdd272bc3020935427a4530b7d60dfbe1ab9381a39147834e985eac"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a92c5aea763d14ba9d6475803fc7904bda7decc2a0a68153f587ad82941fec1"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a0f648eb28bae4bcb204e6fd14603de2908de982e761a2fc78efe0f19e96e1"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01f28075a92eede918b965e86e8f0ba7b7797a95aa8d35e1cc8821f5fc3ad6a"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-win32.whl", hash = "sha256:88c0b89ad1cc24a5efbb99ff9ab5db0f9a86e9cc50240177a571fbe9c2860ac2"},
|
||||
{file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"},
|
||||
{file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"},
|
||||
{file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"},
|
||||
]
|
||||
packaging = [
|
||||
{file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"},
|
||||
{file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"},
|
||||
@ -436,7 +422,3 @@ typing-extensions = [
|
||||
{file = "typing_extensions-3.10.0.0-py3-none-any.whl", hash = "sha256:779383f6086d90c99ae41cf0ff39aac8a7937a9283ce0a414e5dd782f4c94a84"},
|
||||
{file = "typing_extensions-3.10.0.0.tar.gz", hash = "sha256:50b6f157849174217d0656f99dc82fe932884fb250826c18350e159ec6cdf342"},
|
||||
]
|
||||
zipp = [
|
||||
{file = "zipp-3.5.0-py3-none-any.whl", hash = "sha256:957cfda87797e389580cb8b9e3870841ca991e2125350677b2ca83a0e99390a3"},
|
||||
{file = "zipp-3.5.0.tar.gz", hash = "sha256:f5812b1e007e48cff63449a5e9f4e7ebea716b4111f9c4f9a645f91d579bf0c4"},
|
||||
]
|
||||
|
@ -5,7 +5,8 @@ description = ""
|
||||
authors = ["Deepak <dmallubhotla+github@gmail.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.7"
|
||||
python = "^3.8"
|
||||
numpy = "^1.21.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = ">=6"
|
||||
@ -16,3 +17,6 @@ mypy = "^0.790"
|
||||
[build-system]
|
||||
requires = ["poetry>=0.12"]
|
||||
build-backend = "poetry.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
29
scripts/patch.sh
Normal file
29
scripts/patch.sh
Normal file
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
set -Eeuo pipefail
|
||||
|
||||
if [ -z "$(git status --porcelain)" ]; then
|
||||
# Working directory clean
|
||||
branch_name=$(git symbolic-ref -q HEAD)
|
||||
branch_name=${branch_name##refs/heads/}
|
||||
branch_name=${branch_name:-HEAD}
|
||||
|
||||
poetry version patch
|
||||
version=`sed 's/version = "\([0-9]*.[0-9]*.[0-9]*\)"/\1/p' -n <pyproject.toml`
|
||||
read -p "Create commit for version $version? " -n 1 -r
|
||||
echo # (optional) move to a new line
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]
|
||||
then
|
||||
# do dangerous stuff
|
||||
echo "Creating a new patch"
|
||||
git add pyproject.toml
|
||||
git commit -m "Created version $version"
|
||||
git tag -a "$version" -m "patch.sh created version $version"
|
||||
git push --tags
|
||||
else
|
||||
echo "Surrendering, clean up by reverting pyproject.toml..."
|
||||
exit 2
|
||||
fi
|
||||
else
|
||||
echo "Can't create patch version, working tree unclean..."
|
||||
exit 1
|
||||
fi
|
110
tests/test_gradient_descent.py
Normal file
110
tests/test_gradient_descent.py
Normal file
@ -0,0 +1,110 @@
|
||||
import numpy
|
||||
import pathfinder.gradient_descent
|
||||
|
||||
|
||||
def circ_cost(radius, center=(0, 0)):
|
||||
|
||||
def cf(pt):
|
||||
pt2 = numpy.array(pt) - numpy.array(center)
|
||||
return (radius**2 - pt2.dot(pt2))
|
||||
|
||||
return cf
|
||||
|
||||
|
||||
def test_circ_cost():
|
||||
cost = circ_cost(5)
|
||||
actual = cost([3, 4])
|
||||
expected = 0
|
||||
assert actual == expected
|
||||
|
||||
cost = circ_cost(13, [12, 5])
|
||||
actual = cost([0, 0])
|
||||
expected = 0
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_find_sols():
|
||||
c1 = circ_cost(5)
|
||||
c2 = circ_cost(13, [8, -8])
|
||||
|
||||
def costs(pt):
|
||||
return numpy.array(
|
||||
[c1(pt), c2(pt)]
|
||||
)
|
||||
|
||||
def jac(pt):
|
||||
x, y = pt
|
||||
return numpy.array([[-2 * x, -2 * y], [-2 * (x - 8), -2 * (y + 8)]])
|
||||
|
||||
print(costs([3, 4]))
|
||||
result = pathfinder.gradient_descent.find_sols(costs, jac, step_size=0.01, max_iterations=5000, initial=(2, 10), desired_cost=1e-6)
|
||||
print(result)
|
||||
|
||||
|
||||
def dipole_cost(vn, xn_raw):
|
||||
xn = numpy.array(xn_raw)
|
||||
|
||||
def dc(pt):
|
||||
p = pt[0:3]
|
||||
s = pt[3:6]
|
||||
|
||||
diff = xn - s
|
||||
return (vn * (numpy.linalg.norm(diff)**3)) - p.dot(diff)
|
||||
|
||||
return dc
|
||||
|
||||
|
||||
def test_actual_dipole_finding():
|
||||
def c0(pt):
|
||||
p = pt[0:3]
|
||||
return (p.dot(p) - 35)
|
||||
|
||||
v1 = -0.0554777
|
||||
v2 = -0.0601857
|
||||
v3 = -0.0636403
|
||||
v4 = -0.0648838
|
||||
v5 = -0.0629715
|
||||
|
||||
# the 0 here is a red herring for index purposes later
|
||||
vns = [0, v1, v2, v3, v4, v5]
|
||||
# the 0 here is a red herring
|
||||
xns = [numpy.array([0, 0, n]) for n in range(0, 6)]
|
||||
|
||||
c1 = dipole_cost(v1, [0, 0, 1])
|
||||
c2 = dipole_cost(v2, [0, 0, 2])
|
||||
c3 = dipole_cost(v3, [0, 0, 3])
|
||||
c4 = dipole_cost(v4, [0, 0, 4])
|
||||
c5 = dipole_cost(v5, [0, 0, 5])
|
||||
|
||||
def costs(pt):
|
||||
return numpy.array(
|
||||
[c0(pt), c1(pt), c2(pt), c3(pt), c4(pt), c5(pt)]
|
||||
)
|
||||
|
||||
def jac_row(n):
|
||||
def jr(pt):
|
||||
p = pt[0:3]
|
||||
s = pt[3:6]
|
||||
vn = vns[n]
|
||||
xn = xns[n]
|
||||
diff = xn - s
|
||||
return [
|
||||
-diff[0], -diff[1], -diff[2],
|
||||
p[0] - vn * 3 * numpy.linalg.norm(diff) * (diff)[0],
|
||||
p[1] - vn * 3 * numpy.linalg.norm(diff) * (diff)[1],
|
||||
p[2] - vn * 3 * numpy.linalg.norm(diff) * (diff)[2]
|
||||
]
|
||||
return jr
|
||||
|
||||
def jac(pt):
|
||||
return numpy.array([
|
||||
[2 * pt[0], 2 * pt[1], 2 * pt[2], 0, 0, 0],
|
||||
jac_row(1)(pt),
|
||||
jac_row(2)(pt),
|
||||
jac_row(3)(pt),
|
||||
jac_row(4)(pt),
|
||||
jac_row(5)(pt),
|
||||
])
|
||||
|
||||
result = pathfinder.gradient_descent.find_sols(costs, jac, step_size=0.01, max_iterations=10000, initial=(1, 2, 3, 4, 5, 6), desired_cost=1e-6, step_size_tries=25)
|
||||
print(result)
|
Reference in New Issue
Block a user