feat: Adds alt bayes solver with monte carlo sampler
This commit is contained in:
parent
d078004773
commit
7284dbeb34
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from deepdog.meta import __version__
|
from deepdog.meta import __version__
|
||||||
from deepdog.bayes_run import BayesRun
|
from deepdog.bayes_run import BayesRun
|
||||||
|
from deepdog.alt_bayes_run import AltBayesRun
|
||||||
from deepdog.diagnostic import Diagnostic
|
from deepdog.diagnostic import Diagnostic
|
||||||
|
|
||||||
|
|
||||||
@ -8,7 +9,7 @@ def get_version():
|
|||||||
return __version__
|
return __version__
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["get_version", "BayesRun", "Diagnostic"]
|
__all__ = ["get_version", "BayesRun", "AltBayesRun", "Diagnostic"]
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
124
deepdog/alt_bayes_run.py
Normal file
124
deepdog/alt_bayes_run.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import pdme.model
|
||||||
|
import pdme.measurement.oscillating_dipole
|
||||||
|
import pdme.util.fast_v_calc
|
||||||
|
from typing import Sequence, Tuple, List
|
||||||
|
import datetime
|
||||||
|
import csv
|
||||||
|
import logging
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: remove hardcode
|
||||||
|
COST_THRESHOLD = 1e-10
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: It's garbage to have this here duplicated from pdme.
|
||||||
|
DotInput = Tuple[numpy.typing.ArrayLike, float]
|
||||||
|
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AltBayesRun():
|
||||||
|
'''
|
||||||
|
A single Bayes run for a given set of dots.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dot_inputs : Sequence[DotInput]
|
||||||
|
The dot inputs for this bayes run.
|
||||||
|
discretisations_with_names : Sequence[Tuple(str, pdme.model.Model)]
|
||||||
|
The models to evaluate.
|
||||||
|
actual_model_discretisation : pdme.model.Discretisation
|
||||||
|
The discretisation for the model which is actually correct.
|
||||||
|
filename_slug : str
|
||||||
|
The filename slug to include.
|
||||||
|
run_count: int
|
||||||
|
The number of runs to do.
|
||||||
|
'''
|
||||||
|
def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, low_error: float = 0.9, high_error: float = 1.1, monte_carlo_count: int = 10000, max_frequency: float = 20, end_threshold: float = None) -> None:
|
||||||
|
self.dot_inputs = dot_inputs
|
||||||
|
self.dot_inputs_array = pdme.measurement.oscillating_dipole.dot_inputs_to_array(dot_inputs)
|
||||||
|
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
||||||
|
self.model_names = [name for (name, _) in discretisations_with_names]
|
||||||
|
self.actual_model = actual_model
|
||||||
|
self.model_count = len(self.discretisations)
|
||||||
|
self.monte_carlo_count = monte_carlo_count
|
||||||
|
self.run_count = run_count
|
||||||
|
self.low_error = low_error
|
||||||
|
self.high_error = high_error
|
||||||
|
self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"]
|
||||||
|
self.compensate_zeros = True
|
||||||
|
for name in self.model_names:
|
||||||
|
self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"])
|
||||||
|
|
||||||
|
self.probabilities = [1 / self.model_count] * self.model_count
|
||||||
|
|
||||||
|
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
self.filename = f"{timestamp}-{filename_slug}.altbayes.csv"
|
||||||
|
self.max_frequency = max_frequency
|
||||||
|
|
||||||
|
if end_threshold is not None:
|
||||||
|
if 0 < end_threshold < 1:
|
||||||
|
self.end_threshold: float = end_threshold
|
||||||
|
self.use_end_threshold = True
|
||||||
|
_logger.info(f"Will abort early, at {self.end_threshold}.")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"end_threshold should be between 0 and 1, but is actually {end_threshold}")
|
||||||
|
|
||||||
|
def go(self) -> None:
|
||||||
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
|
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for run in range(1, self.run_count + 1):
|
||||||
|
|
||||||
|
rng = numpy.random.default_rng()
|
||||||
|
frequency = rng.uniform(1, self.max_frequency)
|
||||||
|
|
||||||
|
# Generate the actual dipoles
|
||||||
|
actual_dipoles = self.actual_model.get_dipoles(frequency)
|
||||||
|
|
||||||
|
dots = actual_dipoles.get_percent_range_dot_measurements(self.dot_inputs, self.low_error, self.high_error)
|
||||||
|
lows, highs = pdme.measurement.oscillating_dipole.dot_range_measurements_low_high_arrays(dots)
|
||||||
|
_logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
_logger.debug("Going to iterate over discretisations now")
|
||||||
|
for disc_count, discretisation in enumerate(self.discretisations):
|
||||||
|
_logger.debug(f"Doing discretisation #{disc_count}")
|
||||||
|
sample_dipoles = discretisation.get_model().get_n_single_dipoles(self.monte_carlo_count, self.max_frequency)
|
||||||
|
vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(self.dot_inputs_array, sample_dipoles)
|
||||||
|
results.append(numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs)))
|
||||||
|
|
||||||
|
_logger.debug("Done, constructing output now")
|
||||||
|
row = {
|
||||||
|
"dipole_moment": actual_dipoles.dipoles[0].p,
|
||||||
|
"dipole_location": actual_dipoles.dipoles[0].s,
|
||||||
|
"dipole_frequency": actual_dipoles.dipoles[0].w
|
||||||
|
}
|
||||||
|
successes: List[float] = []
|
||||||
|
counts: List[int] = []
|
||||||
|
for model_index, (name, result) in enumerate(zip(self.model_names, results)):
|
||||||
|
|
||||||
|
row[f"{name}_success"] = result
|
||||||
|
row[f"{name}_count"] = self.monte_carlo_count
|
||||||
|
successes.append(max(result, 0.5))
|
||||||
|
counts.append(self.monte_carlo_count)
|
||||||
|
|
||||||
|
success_weight = sum([(succ / count) * prob for succ, count, prob in zip(successes, counts, self.probabilities)])
|
||||||
|
new_probabilities = [(succ / count) * old_prob / success_weight for succ, count, old_prob in zip(successes, counts, self.probabilities)]
|
||||||
|
self.probabilities = new_probabilities
|
||||||
|
for name, probability in zip(self.model_names, self.probabilities):
|
||||||
|
row[f"{name}_prob"] = probability
|
||||||
|
_logger.info(row)
|
||||||
|
|
||||||
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
|
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
if self.use_end_threshold:
|
||||||
|
max_prob = max(self.probabilities)
|
||||||
|
if max_prob > self.end_threshold:
|
||||||
|
_logger.info(f"Aborting early, because {max_prob} is greater than {self.end_threshold}")
|
||||||
|
break
|
8
poetry.lock
generated
8
poetry.lock
generated
@ -304,7 +304,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pdme"
|
name = "pdme"
|
||||||
version = "0.5.3"
|
version = "0.5.4"
|
||||||
description = "Python dipole model evaluator"
|
description = "Python dipole model evaluator"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = false
|
optional = false
|
||||||
@ -697,7 +697,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.8,<3.10"
|
python-versions = "^3.8,<3.10"
|
||||||
content-hash = "9ab90751e7341b5d3bf97445385aee4201807ffddc9ead94a9e2244ac797f85d"
|
content-hash = "ba794e6f69d42e44e2b1abc40731fd78d2f417cdc9be12d6e752dbcfa95adaad"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
atomicwrites = [
|
atomicwrites = [
|
||||||
@ -952,8 +952,8 @@ packaging = [
|
|||||||
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
|
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
|
||||||
]
|
]
|
||||||
pdme = [
|
pdme = [
|
||||||
{file = "pdme-0.5.3-py3-none-any.whl", hash = "sha256:91f2340e392159b688f3dd56a43be2093e9acdd573d1614aa8937316155552fd"},
|
{file = "pdme-0.5.4-py3-none-any.whl", hash = "sha256:90ba75efbec04f5a505c9c5228824538510149e218372e758c15150d5967d92b"},
|
||||||
{file = "pdme-0.5.3.tar.gz", hash = "sha256:ccb7e173849a6650f79767f4fc1b74ff820d22f5d8c0937c9ffb0d818fc4f147"},
|
{file = "pdme-0.5.4.tar.gz", hash = "sha256:82bff2ccc8f38996c23b43ab7d7dc80d87a6b340492f368861e7748105b50174"},
|
||||||
]
|
]
|
||||||
pkginfo = [
|
pkginfo = [
|
||||||
{file = "pkginfo-1.8.2-py2.py3-none-any.whl", hash = "sha256:c24c487c6a7f72c66e816ab1796b96ac6c3d14d49338293d2141664330b55ffc"},
|
{file = "pkginfo-1.8.2-py2.py3-none-any.whl", hash = "sha256:c24c487c6a7f72c66e816ab1796b96ac6c3d14d49338293d2141664330b55ffc"},
|
||||||
|
@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
|
|||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.8,<3.10"
|
python = "^3.8,<3.10"
|
||||||
pdme = "^0.5.3"
|
pdme = "^0.5.4"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = ">=6"
|
pytest = ">=6"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user