diff --git a/deepdog/__init__.py b/deepdog/__init__.py index 52019a1..e590cf3 100644 --- a/deepdog/__init__.py +++ b/deepdog/__init__.py @@ -1,9 +1,13 @@ import logging from deepdog.meta import __version__ +from deepdog.bayes_run import BayesRun def get_version(): return __version__ +__all__ = ["get_version", "BayesRun"] + + logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/deepdog/bayes_run.py b/deepdog/bayes_run.py new file mode 100644 index 0000000..c734df6 --- /dev/null +++ b/deepdog/bayes_run.py @@ -0,0 +1,108 @@ +import pdme.model +from typing import Sequence, Tuple, List +import datetime +import itertools +import csv +import logging +import numpy +import scipy.optimize +import multiprocessing + + +# TODO: remove hardcode +COST_THRESHOLD = 1e-10 + + +# TODO: It's garbage to have this here duplicated from pdme. +DotInput = Tuple[numpy.typing.ArrayLike, float] + + +_logger = logging.getLogger(__name__) + + +def get_a_result(discretisation, dots, index) -> Tuple[Tuple[int, ...], scipy.optimize.OptimizeResult]: + return (index, discretisation.solve_for_index(dots, index)) + + +class BayesRun(): + ''' + A single Bayes run for a given set of dots. + + Parameters + ---------- + dot_inputs : Sequence[DotInput] + The dot inputs for this bayes run. + discretisations_with_names : Sequence[Tuple(str, pdme.model.Model)] + The models to evaluate. + actual_model_discretisation : pdme.model.Discretisation + The discretisation for the model which is actually correct. + filename_slug : str + The filename slug to include. + run_count: int + The number of runs to do. + ''' + def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int) -> None: + self.dot_inputs = dot_inputs + self.discretisations = [disc for (_, disc) in discretisations_with_names] + self.model_names = [name for (name, _) in discretisations_with_names] + self.actual_model = actual_model + self.model_count = len(self.discretisations) + self.run_count = run_count + self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"] + self.compensate_zeros = True + + for name in self.model_names: + self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"]) + + self.probabilities = [1 / self.model_count] * self.model_count + + timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + self.filename = f"{timestamp}-{filename_slug}.csv" + + def go(self) -> None: + with open(self.filename, "a", newline="") as outfile: + writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") + writer.writeheader() + + for run in range(1, self.run_count + 1): + dipoles = self.actual_model.get_dipoles(run) + + dots = dipoles.get_dot_measurements(self.dot_inputs) + _logger.info(f"Going to work on dipole at {dipoles.dipoles}") + + results = [] + _logger.debug("Going to iterate over discretisations now") + for disc_count, discretisation in enumerate(self.discretisations): + _logger.debug(f"Doing discretisation #{disc_count}") + with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool: + results.append(pool.starmap(get_a_result, zip(itertools.repeat(discretisation), itertools.repeat(dots), discretisation.all_indices()))) + + _logger.debug("Done, constructing output now") + row = { + "dipole_moment": dipoles.dipoles[0].p, + "dipole_location": dipoles.dipoles[0].s, + "dipole_frequency": dipoles.dipoles[0].w + } + successes: List[int] = [] + for model_index, (name, result) in enumerate(zip(self.model_names, results)): + count = 0 + success = 0 + for idx, val in result: + count += 1 + if val.success and val.cost <= COST_THRESHOLD: + success += 1 + + row[f"{name}_success"] = success + row[f"{name}_count"] = count + successes.append(max(success, 1)) + + success_weight = sum([succ * prob for succ, prob in zip(successes, self.probabilities)]) + new_probabilities = [succ * old_prob / success_weight for succ, old_prob in zip(successes, self.probabilities)] + self.probabilities = new_probabilities + for name, probability in zip(self.model_names, self.probabilities): + row[f"{name}_prob"] = probability + _logger.info(row) + + with open(self.filename, "a", newline="") as outfile: + writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") + writer.writerow(row) diff --git a/poetry.lock b/poetry.lock index 50f1c54..2a90829 100644 --- a/poetry.lock +++ b/poetry.lock @@ -117,7 +117,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" [[package]] name = "pdme" -version = "0.0.1" +version = "0.4.1" description = "Python dipole model evaluator" category = "main" optional = false @@ -125,6 +125,7 @@ python-versions = ">=3.8,<3.10" [package.dependencies] numpy = ">=1.21.1,<2.0.0" +scipy = ">=1.5,<1.6" [[package]] name = "pluggy" @@ -209,6 +210,17 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "six", "pytest-xdist", "virtualenv"] +[[package]] +name = "scipy" +version = "1.5.4" +description = "SciPy: Scientific Library for Python" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +numpy = ">=1.14.5" + [[package]] name = "toml" version = "0.10.2" @@ -236,7 +248,7 @@ python-versions = ">=3.6" [metadata] lock-version = "1.1" python-versions = "^3.8,<3.10" -content-hash = "380c692a00857c85c6767daa93b458f8d768a6b05dbd94a2e14fe84cc9d2f3ea" +content-hash = "d0cc6fcaa5f489054917fd097c7f86bb314b4f7b76c14cddecb9cd9ef1a24d72" [metadata.files] atomicwrites = [ @@ -367,8 +379,8 @@ packaging = [ {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] pdme = [ - {file = "pdme-0.0.1-py3-none-any.whl", hash = "sha256:f2b8fc969de63e66681f012a1f04eebb2b65e08ba5315dfbb8a19acbadc57858"}, - {file = "pdme-0.0.1.tar.gz", hash = "sha256:7ae4bbe49de630b1b4c2f30a8d343d532fa88ff7486c47f1d8d287a78d017f8b"}, + {file = "pdme-0.4.1-py3-none-any.whl", hash = "sha256:c8a6b6a9755094a89f55a46b42ba2531424805747b8e23e078928c880721d84f"}, + {file = "pdme-0.4.1.tar.gz", hash = "sha256:9c892b23462d87f745735a6cd73346831a1fcb765a1c542dda5909f47b23fc7b"}, ] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, @@ -398,6 +410,33 @@ pytest-cov = [ {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, ] +scipy = [ + {file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"}, + {file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"}, + {file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"}, + {file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"}, + {file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"}, + {file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"}, + {file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"}, + {file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"}, + {file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"}, + {file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"}, + {file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"}, + {file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"}, + {file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"}, + {file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"}, + {file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"}, + {file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"}, + {file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"}, + {file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"}, + {file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"}, + {file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"}, + {file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"}, + {file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"}, + {file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"}, + {file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"}, + {file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"}, +] toml = [ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, diff --git a/pyproject.toml b/pyproject.toml index 35c16ad..1769722 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla "] [tool.poetry.dependencies] python = "^3.8,<3.10" -pdme = "^0.0.1" +pdme = "^0.4.1" [tool.poetry.dev-dependencies] pytest = ">=6" @@ -20,5 +20,15 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] testpaths = ["tests"] -addopts = "--junitxml pytest.xml --cov deepdog --cov-report=xml:coverage.xml --cov-fail-under=90 --cov-report=html" +addopts = "--junitxml pytest.xml --cov deepdog --cov-report=xml:coverage.xml --cov-report=html" junit_family = "xunit1" + +[tool.mypy] +plugins = "numpy.typing.mypy_plugin" + +[[tool.mypy.overrides]] +module = [ + "scipy", + "scipy.optimize" +] +ignore_missing_imports = true