diff --git a/deepdog/__init__.py b/deepdog/__init__.py index 53b58dc..369c2ce 100644 --- a/deepdog/__init__.py +++ b/deepdog/__init__.py @@ -1,6 +1,7 @@ import logging from deepdog.meta import __version__ from deepdog.bayes_run import BayesRun +from deepdog.alt_bayes_run import AltBayesRun from deepdog.diagnostic import Diagnostic @@ -8,7 +9,7 @@ def get_version(): return __version__ -__all__ = ["get_version", "BayesRun", "Diagnostic"] +__all__ = ["get_version", "BayesRun", "AltBayesRun", "Diagnostic"] logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/deepdog/alt_bayes_run.py b/deepdog/alt_bayes_run.py new file mode 100644 index 0000000..154006a --- /dev/null +++ b/deepdog/alt_bayes_run.py @@ -0,0 +1,124 @@ +import pdme.model +import pdme.measurement.oscillating_dipole +import pdme.util.fast_v_calc +from typing import Sequence, Tuple, List +import datetime +import csv +import logging +import numpy + + +# TODO: remove hardcode +COST_THRESHOLD = 1e-10 + + +# TODO: It's garbage to have this here duplicated from pdme. +DotInput = Tuple[numpy.typing.ArrayLike, float] + + +_logger = logging.getLogger(__name__) + + +class AltBayesRun(): + ''' + A single Bayes run for a given set of dots. + + Parameters + ---------- + dot_inputs : Sequence[DotInput] + The dot inputs for this bayes run. + discretisations_with_names : Sequence[Tuple(str, pdme.model.Model)] + The models to evaluate. + actual_model_discretisation : pdme.model.Discretisation + The discretisation for the model which is actually correct. + filename_slug : str + The filename slug to include. + run_count: int + The number of runs to do. + ''' + def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, low_error: float = 0.9, high_error: float = 1.1, monte_carlo_count: int = 10000, max_frequency: float = 20, end_threshold: float = None) -> None: + self.dot_inputs = dot_inputs + self.dot_inputs_array = pdme.measurement.oscillating_dipole.dot_inputs_to_array(dot_inputs) + self.discretisations = [disc for (_, disc) in discretisations_with_names] + self.model_names = [name for (name, _) in discretisations_with_names] + self.actual_model = actual_model + self.model_count = len(self.discretisations) + self.monte_carlo_count = monte_carlo_count + self.run_count = run_count + self.low_error = low_error + self.high_error = high_error + self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"] + self.compensate_zeros = True + for name in self.model_names: + self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"]) + + self.probabilities = [1 / self.model_count] * self.model_count + + timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + self.filename = f"{timestamp}-{filename_slug}.altbayes.csv" + self.max_frequency = max_frequency + + if end_threshold is not None: + if 0 < end_threshold < 1: + self.end_threshold: float = end_threshold + self.use_end_threshold = True + _logger.info(f"Will abort early, at {self.end_threshold}.") + else: + raise ValueError(f"end_threshold should be between 0 and 1, but is actually {end_threshold}") + + def go(self) -> None: + with open(self.filename, "a", newline="") as outfile: + writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") + writer.writeheader() + + for run in range(1, self.run_count + 1): + + rng = numpy.random.default_rng() + frequency = rng.uniform(1, self.max_frequency) + + # Generate the actual dipoles + actual_dipoles = self.actual_model.get_dipoles(frequency) + + dots = actual_dipoles.get_percent_range_dot_measurements(self.dot_inputs, self.low_error, self.high_error) + lows, highs = pdme.measurement.oscillating_dipole.dot_range_measurements_low_high_arrays(dots) + _logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}") + + results = [] + _logger.debug("Going to iterate over discretisations now") + for disc_count, discretisation in enumerate(self.discretisations): + _logger.debug(f"Doing discretisation #{disc_count}") + sample_dipoles = discretisation.get_model().get_n_single_dipoles(self.monte_carlo_count, self.max_frequency) + vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(self.dot_inputs_array, sample_dipoles) + results.append(numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs))) + + _logger.debug("Done, constructing output now") + row = { + "dipole_moment": actual_dipoles.dipoles[0].p, + "dipole_location": actual_dipoles.dipoles[0].s, + "dipole_frequency": actual_dipoles.dipoles[0].w + } + successes: List[float] = [] + counts: List[int] = [] + for model_index, (name, result) in enumerate(zip(self.model_names, results)): + + row[f"{name}_success"] = result + row[f"{name}_count"] = self.monte_carlo_count + successes.append(max(result, 0.5)) + counts.append(self.monte_carlo_count) + + success_weight = sum([(succ / count) * prob for succ, count, prob in zip(successes, counts, self.probabilities)]) + new_probabilities = [(succ / count) * old_prob / success_weight for succ, count, old_prob in zip(successes, counts, self.probabilities)] + self.probabilities = new_probabilities + for name, probability in zip(self.model_names, self.probabilities): + row[f"{name}_prob"] = probability + _logger.info(row) + + with open(self.filename, "a", newline="") as outfile: + writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") + writer.writerow(row) + + if self.use_end_threshold: + max_prob = max(self.probabilities) + if max_prob > self.end_threshold: + _logger.info(f"Aborting early, because {max_prob} is greater than {self.end_threshold}") + break diff --git a/poetry.lock b/poetry.lock index f590881..b14e346 100644 --- a/poetry.lock +++ b/poetry.lock @@ -304,7 +304,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" [[package]] name = "pdme" -version = "0.5.3" +version = "0.5.4" description = "Python dipole model evaluator" category = "main" optional = false @@ -697,7 +697,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.8,<3.10" -content-hash = "9ab90751e7341b5d3bf97445385aee4201807ffddc9ead94a9e2244ac797f85d" +content-hash = "ba794e6f69d42e44e2b1abc40731fd78d2f417cdc9be12d6e752dbcfa95adaad" [metadata.files] atomicwrites = [ @@ -952,8 +952,8 @@ packaging = [ {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] pdme = [ - {file = "pdme-0.5.3-py3-none-any.whl", hash = "sha256:91f2340e392159b688f3dd56a43be2093e9acdd573d1614aa8937316155552fd"}, - {file = "pdme-0.5.3.tar.gz", hash = "sha256:ccb7e173849a6650f79767f4fc1b74ff820d22f5d8c0937c9ffb0d818fc4f147"}, + {file = "pdme-0.5.4-py3-none-any.whl", hash = "sha256:90ba75efbec04f5a505c9c5228824538510149e218372e758c15150d5967d92b"}, + {file = "pdme-0.5.4.tar.gz", hash = "sha256:82bff2ccc8f38996c23b43ab7d7dc80d87a6b340492f368861e7748105b50174"}, ] pkginfo = [ {file = "pkginfo-1.8.2-py2.py3-none-any.whl", hash = "sha256:c24c487c6a7f72c66e816ab1796b96ac6c3d14d49338293d2141664330b55ffc"}, diff --git a/pyproject.toml b/pyproject.toml index ed74de4..c59704f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla "] [tool.poetry.dependencies] python = "^3.8,<3.10" -pdme = "^0.5.3" +pdme = "^0.5.4" [tool.poetry.dev-dependencies] pytest = ">=6"