Adds bayes run

This commit is contained in:
Deepak Mallubhotla 2022-01-24 10:18:12 -06:00
parent 027797448e
commit 3397b3e2f9
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
4 changed files with 167 additions and 6 deletions

View File

@ -1,9 +1,13 @@
import logging import logging
from deepdog.meta import __version__ from deepdog.meta import __version__
from deepdog.bayes_run import BayesRun
def get_version(): def get_version():
return __version__ return __version__
__all__ = ["get_version", "BayesRun"]
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())

108
deepdog/bayes_run.py Normal file
View File

@ -0,0 +1,108 @@
import pdme.model
from typing import Sequence, Tuple, List
import datetime
import itertools
import csv
import logging
import numpy
import scipy.optimize
import multiprocessing
# TODO: remove hardcode
# Largest `cost` a successful solver result may have and still be counted as a
# match when BayesRun.go tallies successes.
COST_THRESHOLD = 1e-10
# TODO: It's garbage to have this here duplicated from pdme.
# Type alias for one dot input: an array-like paired with a float.
# Presumably (position, frequency) -- confirm against pdme's own definition.
# NOTE(review): only `import numpy` is visible in this file; `numpy.typing` is a
# submodule and may need an explicit `import numpy.typing` -- confirm.
DotInput = Tuple[numpy.typing.ArrayLike, float]
# Module-level logger named after this module.
_logger = logging.getLogger(__name__)
def get_a_result(discretisation, dots, index) -> Tuple[Tuple[int, ...], scipy.optimize.OptimizeResult]:
    '''
    Solve one discretisation cell and tag the answer with its index.

    Kept at module level so it can be handed to a multiprocessing pool.

    Parameters
    ----------
    discretisation
        Object exposing ``solve_for_index(dots, index)``.
    dots
        The dot measurements, passed through to the solver untouched.
    index
        The index of the cell to solve.

    Returns
    -------
    tuple
        ``(index, solution)`` where ``solution`` is whatever
        ``discretisation.solve_for_index(dots, index)`` returned.
    '''
    solution = discretisation.solve_for_index(dots, index)
    return index, solution
class BayesRun():
    '''
    A single Bayes run for a given set of dots.

    Parameters
    ----------
    dot_inputs : Sequence[DotInput]
        The dot inputs for this bayes run.
    discretisations_with_names : Sequence[Tuple[str, pdme.model.Discretisation]]
        ``(name, discretisation)`` pairs for the models to evaluate. The names
        are used as prefixes for the per-model CSV columns.
    actual_model : pdme.model.Model
        The model which is actually correct; it supplies the dipoles for
        each run via ``get_dipoles(run)``.
    filename_slug : str
        The filename slug to include after the timestamp in the output
        CSV filename.
    run_count : int
        The number of runs to do.

    Raises
    ------
    ValueError
        If ``discretisations_with_names`` is empty, since a uniform prior
        over zero models is undefined.
    '''
    def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int) -> None:
        self.dot_inputs = dot_inputs
        self.discretisations = [disc for (_, disc) in discretisations_with_names]
        self.model_names = [name for (name, _) in discretisations_with_names]
        self.actual_model = actual_model
        self.model_count = len(self.discretisations)
        if self.model_count == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError
            # from the uniform prior below.
            raise ValueError("BayesRun needs at least one discretisation to evaluate")
        self.run_count = run_count
        self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"]
        # When True, a model with zero successes in a run is counted as having
        # one, so its probability is never driven to exactly zero.
        self.compensate_zeros = True
        for name in self.model_names:
            self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"])

        # Start from a uniform prior over the models.
        self.probabilities = [1 / self.model_count] * self.model_count

        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.filename = f"{timestamp}-{filename_slug}.csv"

    @staticmethod
    def _count_successes(result) -> int:
        '''Count the (index, solver result) pairs that succeeded within COST_THRESHOLD.'''
        return sum(1 for _, val in result if val.success and val.cost <= COST_THRESHOLD)

    def _update_probabilities(self, successes: Sequence[int]) -> None:
        '''Reweight each model's probability by its success count and renormalise.'''
        success_weight = sum(succ * prob for succ, prob in zip(successes, self.probabilities))
        if success_weight == 0:
            # Only reachable when zero-compensation is off and no model with
            # non-zero probability succeeded; keep the old probabilities
            # rather than dividing by zero.
            _logger.warning("No model had any successes; leaving probabilities unchanged")
            return
        self.probabilities = [succ * old_prob / success_weight for succ, old_prob in zip(successes, self.probabilities)]

    def go(self) -> None:
        '''
        Do all the runs, appending one CSV row per run to ``self.filename``.

        For each run, dipoles are drawn from the actual model, every
        discretisation is solved over all of its indices in a worker pool,
        successes are tallied per model, and the model probabilities are
        updated from those tallies.
        '''
        # Write the header once up front. Rows are appended one at a time
        # below, so each finished run is on disk before the next one starts.
        with open(self.filename, "a", newline="") as outfile:
            writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
            writer.writeheader()

        # Leave one core free, but always use at least one worker.
        worker_count = max(multiprocessing.cpu_count() - 1, 1)

        # One pool for the whole call: the worker count does not change, so
        # there is no need to start new processes for every discretisation.
        with multiprocessing.Pool(worker_count) as pool:
            for run in range(1, self.run_count + 1):
                dipoles = self.actual_model.get_dipoles(run)
                dots = dipoles.get_dot_measurements(self.dot_inputs)
                _logger.info("Going to work on dipole at %s", dipoles.dipoles)

                results = []
                _logger.debug("Going to iterate over discretisations now")
                for disc_count, discretisation in enumerate(self.discretisations):
                    _logger.debug("Doing discretisation #%s", disc_count)
                    results.append(pool.starmap(get_a_result, zip(itertools.repeat(discretisation), itertools.repeat(dots), discretisation.all_indices())))

                _logger.debug("Done, constructing output now")
                # NOTE(review): only the first dipole is recorded in the row;
                # confirm this is intended for multi-dipole models.
                first_dipole = dipoles.dipoles[0]
                row = {
                    "dipole_moment": first_dipole.p,
                    "dipole_location": first_dipole.s,
                    "dipole_frequency": first_dipole.w
                }

                successes: List[int] = []
                for name, result in zip(self.model_names, results):
                    success = self._count_successes(result)
                    row[f"{name}_success"] = success
                    row[f"{name}_count"] = len(result)
                    successes.append(max(success, 1) if self.compensate_zeros else success)

                self._update_probabilities(successes)
                for name, probability in zip(self.model_names, self.probabilities):
                    row[f"{name}_prob"] = probability
                _logger.info(row)

                with open(self.filename, "a", newline="") as outfile:
                    writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
                    writer.writerow(row)

47
poetry.lock generated
View File

@ -117,7 +117,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
[[package]] [[package]]
name = "pdme" name = "pdme"
version = "0.0.1" version = "0.4.1"
description = "Python dipole model evaluator" description = "Python dipole model evaluator"
category = "main" category = "main"
optional = false optional = false
@ -125,6 +125,7 @@ python-versions = ">=3.8,<3.10"
[package.dependencies] [package.dependencies]
numpy = ">=1.21.1,<2.0.0" numpy = ">=1.21.1,<2.0.0"
scipy = ">=1.5,<1.6"
[[package]] [[package]]
name = "pluggy" name = "pluggy"
@ -209,6 +210,17 @@ pytest = ">=4.6"
[package.extras] [package.extras]
testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"] testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
[[package]]
name = "scipy"
version = "1.5.4"
description = "SciPy: Scientific Library for Python"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
numpy = ">=1.14.5"
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.10.2" version = "0.10.2"
@ -236,7 +248,7 @@ python-versions = ">=3.6"
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.8,<3.10" python-versions = "^3.8,<3.10"
content-hash = "380c692a00857c85c6767daa93b458f8d768a6b05dbd94a2e14fe84cc9d2f3ea" content-hash = "d0cc6fcaa5f489054917fd097c7f86bb314b4f7b76c14cddecb9cd9ef1a24d72"
[metadata.files] [metadata.files]
atomicwrites = [ atomicwrites = [
@ -367,8 +379,8 @@ packaging = [
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
] ]
pdme = [ pdme = [
{file = "pdme-0.0.1-py3-none-any.whl", hash = "sha256:f2b8fc969de63e66681f012a1f04eebb2b65e08ba5315dfbb8a19acbadc57858"}, {file = "pdme-0.4.1-py3-none-any.whl", hash = "sha256:c8a6b6a9755094a89f55a46b42ba2531424805747b8e23e078928c880721d84f"},
{file = "pdme-0.0.1.tar.gz", hash = "sha256:7ae4bbe49de630b1b4c2f30a8d343d532fa88ff7486c47f1d8d287a78d017f8b"}, {file = "pdme-0.4.1.tar.gz", hash = "sha256:9c892b23462d87f745735a6cd73346831a1fcb765a1c542dda5909f47b23fc7b"},
] ]
pluggy = [ pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
@ -398,6 +410,33 @@ pytest-cov = [
{file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
{file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
] ]
scipy = [
{file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"},
{file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"},
{file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"},
{file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"},
{file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"},
{file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"},
{file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"},
{file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"},
{file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"},
{file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"},
{file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"},
{file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"},
{file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"},
{file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"},
{file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"},
{file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"},
{file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"},
{file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"},
{file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"},
{file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"},
{file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"},
{file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"},
{file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"},
{file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"},
{file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"},
]
toml = [ toml = [
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},

View File

@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8,<3.10" python = "^3.8,<3.10"
pdme = "^0.0.1" pdme = "^0.4.1"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = ">=6" pytest = ">=6"
@ -20,5 +20,15 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options] [tool.pytest.ini_options]
testpaths = ["tests"] testpaths = ["tests"]
addopts = "--junitxml pytest.xml --cov deepdog --cov-report=xml:coverage.xml --cov-fail-under=90 --cov-report=html" addopts = "--junitxml pytest.xml --cov deepdog --cov-report=xml:coverage.xml --cov-report=html"
junit_family = "xunit1" junit_family = "xunit1"
[tool.mypy]
plugins = "numpy.typing.mypy_plugin"
[[tool.mypy.overrides]]
module = [
"scipy",
"scipy.optimize"
]
ignore_missing_imports = true