feat: Adds alt bayes solver with monte carlo sampler
This commit is contained in:
parent
d078004773
commit
7284dbeb34
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from deepdog.meta import __version__
|
||||
from deepdog.bayes_run import BayesRun
|
||||
from deepdog.alt_bayes_run import AltBayesRun
|
||||
from deepdog.diagnostic import Diagnostic
|
||||
|
||||
|
||||
@ -8,7 +9,7 @@ def get_version():
|
||||
return __version__
|
||||
|
||||
|
||||
__all__ = ["get_version", "BayesRun", "Diagnostic"]
|
||||
__all__ = ["get_version", "BayesRun", "AltBayesRun", "Diagnostic"]
|
||||
|
||||
|
||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||
|
124
deepdog/alt_bayes_run.py
Normal file
124
deepdog/alt_bayes_run.py
Normal file
@ -0,0 +1,124 @@
|
||||
import pdme.model
|
||||
import pdme.measurement.oscillating_dipole
|
||||
import pdme.util.fast_v_calc
|
||||
from typing import Sequence, Tuple, List
|
||||
import datetime
|
||||
import csv
|
||||
import logging
|
||||
import numpy
|
||||
|
||||
|
||||
# TODO: remove hardcode
|
||||
COST_THRESHOLD = 1e-10
|
||||
|
||||
|
||||
# TODO: It's garbage to have this here duplicated from pdme.
|
||||
DotInput = Tuple[numpy.typing.ArrayLike, float]
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AltBayesRun():
|
||||
'''
|
||||
A single Bayes run for a given set of dots.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dot_inputs : Sequence[DotInput]
|
||||
The dot inputs for this bayes run.
|
||||
discretisations_with_names : Sequence[Tuple(str, pdme.model.Model)]
|
||||
The models to evaluate.
|
||||
actual_model_discretisation : pdme.model.Discretisation
|
||||
The discretisation for the model which is actually correct.
|
||||
filename_slug : str
|
||||
The filename slug to include.
|
||||
run_count: int
|
||||
The number of runs to do.
|
||||
'''
|
||||
def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, low_error: float = 0.9, high_error: float = 1.1, monte_carlo_count: int = 10000, max_frequency: float = 20, end_threshold: float = None) -> None:
|
||||
self.dot_inputs = dot_inputs
|
||||
self.dot_inputs_array = pdme.measurement.oscillating_dipole.dot_inputs_to_array(dot_inputs)
|
||||
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
||||
self.model_names = [name for (name, _) in discretisations_with_names]
|
||||
self.actual_model = actual_model
|
||||
self.model_count = len(self.discretisations)
|
||||
self.monte_carlo_count = monte_carlo_count
|
||||
self.run_count = run_count
|
||||
self.low_error = low_error
|
||||
self.high_error = high_error
|
||||
self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"]
|
||||
self.compensate_zeros = True
|
||||
for name in self.model_names:
|
||||
self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"])
|
||||
|
||||
self.probabilities = [1 / self.model_count] * self.model_count
|
||||
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self.filename = f"{timestamp}-{filename_slug}.altbayes.csv"
|
||||
self.max_frequency = max_frequency
|
||||
|
||||
if end_threshold is not None:
|
||||
if 0 < end_threshold < 1:
|
||||
self.end_threshold: float = end_threshold
|
||||
self.use_end_threshold = True
|
||||
_logger.info(f"Will abort early, at {self.end_threshold}.")
|
||||
else:
|
||||
raise ValueError(f"end_threshold should be between 0 and 1, but is actually {end_threshold}")
|
||||
|
||||
def go(self) -> None:
|
||||
with open(self.filename, "a", newline="") as outfile:
|
||||
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||
writer.writeheader()
|
||||
|
||||
for run in range(1, self.run_count + 1):
|
||||
|
||||
rng = numpy.random.default_rng()
|
||||
frequency = rng.uniform(1, self.max_frequency)
|
||||
|
||||
# Generate the actual dipoles
|
||||
actual_dipoles = self.actual_model.get_dipoles(frequency)
|
||||
|
||||
dots = actual_dipoles.get_percent_range_dot_measurements(self.dot_inputs, self.low_error, self.high_error)
|
||||
lows, highs = pdme.measurement.oscillating_dipole.dot_range_measurements_low_high_arrays(dots)
|
||||
_logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}")
|
||||
|
||||
results = []
|
||||
_logger.debug("Going to iterate over discretisations now")
|
||||
for disc_count, discretisation in enumerate(self.discretisations):
|
||||
_logger.debug(f"Doing discretisation #{disc_count}")
|
||||
sample_dipoles = discretisation.get_model().get_n_single_dipoles(self.monte_carlo_count, self.max_frequency)
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(self.dot_inputs_array, sample_dipoles)
|
||||
results.append(numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs)))
|
||||
|
||||
_logger.debug("Done, constructing output now")
|
||||
row = {
|
||||
"dipole_moment": actual_dipoles.dipoles[0].p,
|
||||
"dipole_location": actual_dipoles.dipoles[0].s,
|
||||
"dipole_frequency": actual_dipoles.dipoles[0].w
|
||||
}
|
||||
successes: List[float] = []
|
||||
counts: List[int] = []
|
||||
for model_index, (name, result) in enumerate(zip(self.model_names, results)):
|
||||
|
||||
row[f"{name}_success"] = result
|
||||
row[f"{name}_count"] = self.monte_carlo_count
|
||||
successes.append(max(result, 0.5))
|
||||
counts.append(self.monte_carlo_count)
|
||||
|
||||
success_weight = sum([(succ / count) * prob for succ, count, prob in zip(successes, counts, self.probabilities)])
|
||||
new_probabilities = [(succ / count) * old_prob / success_weight for succ, count, old_prob in zip(successes, counts, self.probabilities)]
|
||||
self.probabilities = new_probabilities
|
||||
for name, probability in zip(self.model_names, self.probabilities):
|
||||
row[f"{name}_prob"] = probability
|
||||
_logger.info(row)
|
||||
|
||||
with open(self.filename, "a", newline="") as outfile:
|
||||
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||
writer.writerow(row)
|
||||
|
||||
if self.use_end_threshold:
|
||||
max_prob = max(self.probabilities)
|
||||
if max_prob > self.end_threshold:
|
||||
_logger.info(f"Aborting early, because {max_prob} is greater than {self.end_threshold}")
|
||||
break
|
8
poetry.lock
generated
8
poetry.lock
generated
@ -304,7 +304,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
|
||||
|
||||
[[package]]
|
||||
name = "pdme"
|
||||
version = "0.5.3"
|
||||
version = "0.5.4"
|
||||
description = "Python dipole model evaluator"
|
||||
category = "main"
|
||||
optional = false
|
||||
@ -697,7 +697,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8,<3.10"
|
||||
content-hash = "9ab90751e7341b5d3bf97445385aee4201807ffddc9ead94a9e2244ac797f85d"
|
||||
content-hash = "ba794e6f69d42e44e2b1abc40731fd78d2f417cdc9be12d6e752dbcfa95adaad"
|
||||
|
||||
[metadata.files]
|
||||
atomicwrites = [
|
||||
@ -952,8 +952,8 @@ packaging = [
|
||||
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
|
||||
]
|
||||
pdme = [
|
||||
{file = "pdme-0.5.3-py3-none-any.whl", hash = "sha256:91f2340e392159b688f3dd56a43be2093e9acdd573d1614aa8937316155552fd"},
|
||||
{file = "pdme-0.5.3.tar.gz", hash = "sha256:ccb7e173849a6650f79767f4fc1b74ff820d22f5d8c0937c9ffb0d818fc4f147"},
|
||||
{file = "pdme-0.5.4-py3-none-any.whl", hash = "sha256:90ba75efbec04f5a505c9c5228824538510149e218372e758c15150d5967d92b"},
|
||||
{file = "pdme-0.5.4.tar.gz", hash = "sha256:82bff2ccc8f38996c23b43ab7d7dc80d87a6b340492f368861e7748105b50174"},
|
||||
]
|
||||
pkginfo = [
|
||||
{file = "pkginfo-1.8.2-py2.py3-none-any.whl", hash = "sha256:c24c487c6a7f72c66e816ab1796b96ac6c3d14d49338293d2141664330b55ffc"},
|
||||
|
@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8,<3.10"
|
||||
pdme = "^0.5.3"
|
||||
pdme = "^0.5.4"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = ">=6"
|
||||
|
Loading…
x
Reference in New Issue
Block a user