From 3d3b1a83f630bfa50da802905b5b6f6d6a78a7fc Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Mon, 14 Feb 2022 09:27:40 -0600 Subject: [PATCH] feat: Adds end threshold for early abort --- deepdog/bayes_run.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/deepdog/bayes_run.py b/deepdog/bayes_run.py index 88c4656..ae816de 100644 --- a/deepdog/bayes_run.py +++ b/deepdog/bayes_run.py @@ -41,7 +41,7 @@ class BayesRun(): run_count: int The number of runs to do. ''' - def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, max_frequency: float = None) -> None: + def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, max_frequency: float = None, end_threshold: float = None) -> None: self.dot_inputs = dot_inputs self.discretisations = [disc for (_, disc) in discretisations_with_names] self.model_names = [name for (name, _) in discretisations_with_names] @@ -50,7 +50,6 @@ class BayesRun(): self.run_count = run_count self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"] self.compensate_zeros = True - for name in self.model_names: self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"]) @@ -60,6 +59,13 @@ class BayesRun(): self.filename = f"{timestamp}-{filename_slug}.csv" self.max_frequency = max_frequency + if end_threshold is not None: + if 0 < self.end_threshold < 1: + self.end_threshold: float = end_threshold + self.use_end_threshold = True + else: + raise ValueError(f"{end_threshold} should be between 0 and 1") + def go(self) -> None: with open(self.filename, "a", newline="") as outfile: writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") @@ -111,3 +117,9 @@ class BayesRun(): with open(self.filename, "a", newline="") as outfile: writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") writer.writerow(row) + + if self.use_end_threshold: + max_prob = max(self.probabilities) + if max_prob > self.end_threshold: + _logger.info(f"Aborting early, because {max_prob} is greater than {self.end_threshold}") + break