Adds bayes run
This commit is contained in:
parent
027797448e
commit
3397b3e2f9
@ -1,9 +1,13 @@
|
|||||||
import logging
|
import logging
|
||||||
from deepdog.meta import __version__
|
from deepdog.meta import __version__
|
||||||
|
from deepdog.bayes_run import BayesRun
|
||||||
|
|
||||||
|
|
||||||
def get_version():
|
def get_version():
|
||||||
return __version__
|
return __version__
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["get_version", "BayesRun"]
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
108
deepdog/bayes_run.py
Normal file
108
deepdog/bayes_run.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import pdme.model
|
||||||
|
from typing import Sequence, Tuple, List
|
||||||
|
import datetime
|
||||||
|
import itertools
|
||||||
|
import csv
|
||||||
|
import logging
|
||||||
|
import numpy
|
||||||
|
import scipy.optimize
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: remove hardcode
|
||||||
|
COST_THRESHOLD = 1e-10
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: It's garbage to have this here duplicated from pdme.
|
||||||
|
DotInput = Tuple[numpy.typing.ArrayLike, float]
|
||||||
|
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_a_result(discretisation, dots, index) -> Tuple[Tuple[int, ...], scipy.optimize.OptimizeResult]:
|
||||||
|
return (index, discretisation.solve_for_index(dots, index))
|
||||||
|
|
||||||
|
|
||||||
|
class BayesRun():
|
||||||
|
'''
|
||||||
|
A single Bayes run for a given set of dots.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dot_inputs : Sequence[DotInput]
|
||||||
|
The dot inputs for this bayes run.
|
||||||
|
discretisations_with_names : Sequence[Tuple(str, pdme.model.Model)]
|
||||||
|
The models to evaluate.
|
||||||
|
actual_model_discretisation : pdme.model.Discretisation
|
||||||
|
The discretisation for the model which is actually correct.
|
||||||
|
filename_slug : str
|
||||||
|
The filename slug to include.
|
||||||
|
run_count: int
|
||||||
|
The number of runs to do.
|
||||||
|
'''
|
||||||
|
def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int) -> None:
|
||||||
|
self.dot_inputs = dot_inputs
|
||||||
|
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
||||||
|
self.model_names = [name for (name, _) in discretisations_with_names]
|
||||||
|
self.actual_model = actual_model
|
||||||
|
self.model_count = len(self.discretisations)
|
||||||
|
self.run_count = run_count
|
||||||
|
self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"]
|
||||||
|
self.compensate_zeros = True
|
||||||
|
|
||||||
|
for name in self.model_names:
|
||||||
|
self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"])
|
||||||
|
|
||||||
|
self.probabilities = [1 / self.model_count] * self.model_count
|
||||||
|
|
||||||
|
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
self.filename = f"{timestamp}-{filename_slug}.csv"
|
||||||
|
|
||||||
|
def go(self) -> None:
|
||||||
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
|
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for run in range(1, self.run_count + 1):
|
||||||
|
dipoles = self.actual_model.get_dipoles(run)
|
||||||
|
|
||||||
|
dots = dipoles.get_dot_measurements(self.dot_inputs)
|
||||||
|
_logger.info(f"Going to work on dipole at {dipoles.dipoles}")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
_logger.debug("Going to iterate over discretisations now")
|
||||||
|
for disc_count, discretisation in enumerate(self.discretisations):
|
||||||
|
_logger.debug(f"Doing discretisation #{disc_count}")
|
||||||
|
with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool:
|
||||||
|
results.append(pool.starmap(get_a_result, zip(itertools.repeat(discretisation), itertools.repeat(dots), discretisation.all_indices())))
|
||||||
|
|
||||||
|
_logger.debug("Done, constructing output now")
|
||||||
|
row = {
|
||||||
|
"dipole_moment": dipoles.dipoles[0].p,
|
||||||
|
"dipole_location": dipoles.dipoles[0].s,
|
||||||
|
"dipole_frequency": dipoles.dipoles[0].w
|
||||||
|
}
|
||||||
|
successes: List[int] = []
|
||||||
|
for model_index, (name, result) in enumerate(zip(self.model_names, results)):
|
||||||
|
count = 0
|
||||||
|
success = 0
|
||||||
|
for idx, val in result:
|
||||||
|
count += 1
|
||||||
|
if val.success and val.cost <= COST_THRESHOLD:
|
||||||
|
success += 1
|
||||||
|
|
||||||
|
row[f"{name}_success"] = success
|
||||||
|
row[f"{name}_count"] = count
|
||||||
|
successes.append(max(success, 1))
|
||||||
|
|
||||||
|
success_weight = sum([succ * prob for succ, prob in zip(successes, self.probabilities)])
|
||||||
|
new_probabilities = [succ * old_prob / success_weight for succ, old_prob in zip(successes, self.probabilities)]
|
||||||
|
self.probabilities = new_probabilities
|
||||||
|
for name, probability in zip(self.model_names, self.probabilities):
|
||||||
|
row[f"{name}_prob"] = probability
|
||||||
|
_logger.info(row)
|
||||||
|
|
||||||
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
|
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||||
|
writer.writerow(row)
|
47
poetry.lock
generated
47
poetry.lock
generated
@ -117,7 +117,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pdme"
|
name = "pdme"
|
||||||
version = "0.0.1"
|
version = "0.4.1"
|
||||||
description = "Python dipole model evaluator"
|
description = "Python dipole model evaluator"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
@ -125,6 +125,7 @@ python-versions = ">=3.8,<3.10"
|
|||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
numpy = ">=1.21.1,<2.0.0"
|
numpy = ">=1.21.1,<2.0.0"
|
||||||
|
scipy = ">=1.5,<1.6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pluggy"
|
name = "pluggy"
|
||||||
@ -209,6 +210,17 @@ pytest = ">=4.6"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
|
testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scipy"
|
||||||
|
version = "1.5.4"
|
||||||
|
description = "SciPy: Scientific Library for Python"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
numpy = ">=1.14.5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml"
|
name = "toml"
|
||||||
version = "0.10.2"
|
version = "0.10.2"
|
||||||
@ -236,7 +248,7 @@ python-versions = ">=3.6"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.8,<3.10"
|
python-versions = "^3.8,<3.10"
|
||||||
content-hash = "380c692a00857c85c6767daa93b458f8d768a6b05dbd94a2e14fe84cc9d2f3ea"
|
content-hash = "d0cc6fcaa5f489054917fd097c7f86bb314b4f7b76c14cddecb9cd9ef1a24d72"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
atomicwrites = [
|
atomicwrites = [
|
||||||
@ -367,8 +379,8 @@ packaging = [
|
|||||||
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
|
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
|
||||||
]
|
]
|
||||||
pdme = [
|
pdme = [
|
||||||
{file = "pdme-0.0.1-py3-none-any.whl", hash = "sha256:f2b8fc969de63e66681f012a1f04eebb2b65e08ba5315dfbb8a19acbadc57858"},
|
{file = "pdme-0.4.1-py3-none-any.whl", hash = "sha256:c8a6b6a9755094a89f55a46b42ba2531424805747b8e23e078928c880721d84f"},
|
||||||
{file = "pdme-0.0.1.tar.gz", hash = "sha256:7ae4bbe49de630b1b4c2f30a8d343d532fa88ff7486c47f1d8d287a78d017f8b"},
|
{file = "pdme-0.4.1.tar.gz", hash = "sha256:9c892b23462d87f745735a6cd73346831a1fcb765a1c542dda5909f47b23fc7b"},
|
||||||
]
|
]
|
||||||
pluggy = [
|
pluggy = [
|
||||||
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
|
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
|
||||||
@ -398,6 +410,33 @@ pytest-cov = [
|
|||||||
{file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
|
{file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
|
||||||
{file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
|
{file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
|
||||||
]
|
]
|
||||||
|
scipy = [
|
||||||
|
{file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"},
|
||||||
|
{file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"},
|
||||||
|
{file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"},
|
||||||
|
{file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"},
|
||||||
|
{file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"},
|
||||||
|
{file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"},
|
||||||
|
{file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"},
|
||||||
|
{file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"},
|
||||||
|
{file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"},
|
||||||
|
{file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"},
|
||||||
|
{file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"},
|
||||||
|
{file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"},
|
||||||
|
{file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"},
|
||||||
|
{file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"},
|
||||||
|
{file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"},
|
||||||
|
{file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"},
|
||||||
|
{file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"},
|
||||||
|
{file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"},
|
||||||
|
{file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"},
|
||||||
|
{file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"},
|
||||||
|
{file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"},
|
||||||
|
{file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"},
|
||||||
|
{file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"},
|
||||||
|
{file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"},
|
||||||
|
{file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"},
|
||||||
|
]
|
||||||
toml = [
|
toml = [
|
||||||
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
||||||
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
||||||
|
@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.8,<3.10"
|
python = "^3.8,<3.10"
|
||||||
pdme = "^0.0.1"
|
pdme = "^0.4.1"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = ">=6"
|
pytest = ">=6"
|
||||||
@ -20,5 +20,15 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
addopts = "--junitxml pytest.xml --cov deepdog --cov-report=xml:coverage.xml --cov-fail-under=90 --cov-report=html"
|
addopts = "--junitxml pytest.xml --cov deepdog --cov-report=xml:coverage.xml --cov-report=html"
|
||||||
junit_family = "xunit1"
|
junit_family = "xunit1"
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
plugins = "numpy.typing.mypy_plugin"
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = [
|
||||||
|
"scipy",
|
||||||
|
"scipy.optimize"
|
||||||
|
]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
Loading…
x
Reference in New Issue
Block a user