Compare commits

..

7 Commits

Author SHA1 Message Date
semantic-release
6dfc26104a 0.3.0
Some checks reported errors
gitea-physics/deepdog/pipeline/head This commit looks good
gitea-physics/deepdog/pipeline/tag Something is wrong with the build of this commit
Automatically generated by python-semantic-release
2022-02-14 09:57:42 -06:00
3a6be738b1 feat: Actually uses probabilities to update bayes
All checks were successful
gitea-physics/deepdog/pipeline/head This commit looks good
2022-02-14 09:51:02 -06:00
bd240900b4 fix: Actually logs end threshold 2022-02-14 09:50:08 -06:00
0e1fbec043 fix: Fixes bug with end_threshold and better error logging 2022-02-14 09:44:04 -06:00
3d3b1a83f6 feat: Adds end threshold for early abort 2022-02-14 09:27:40 -06:00
63cecba824 Created version 0.2.4
All checks were successful
gitea-physics/deepdog/pipeline/head This commit looks good
gitea-physics/deepdog/pipeline/tag This commit looks good
2022-02-06 19:50:07 -06:00
344998835d fix: Fixes linting 2022-02-06 19:49:02 -06:00
3 changed files with 25 additions and 8 deletions

View File

@@ -41,7 +41,7 @@ class BayesRun():
run_count: int
The number of runs to do.
'''
def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, max_frequency: float = None) -> None:
def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, max_frequency: float = None, end_threshold: float = None) -> None:
self.dot_inputs = dot_inputs
self.discretisations = [disc for (_, disc) in discretisations_with_names]
self.model_names = [name for (name, _) in discretisations_with_names]
@@ -50,7 +50,6 @@ class BayesRun():
self.run_count = run_count
self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"]
self.compensate_zeros = True
for name in self.model_names:
self.csv_fields.extend([f"{name}_success", f"{name}_count", f"{name}_prob"])
@@ -60,6 +59,14 @@ class BayesRun():
self.filename = f"{timestamp}-{filename_slug}.csv"
self.max_frequency = max_frequency
if end_threshold is not None:
if 0 < end_threshold < 1:
self.end_threshold: float = end_threshold
self.use_end_threshold = True
_logger.info(f"Will abort early, at {self.end_threshold}.")
else:
raise ValueError(f"end_threshold should be between 0 and 1, but is actually {end_threshold}")
def go(self) -> None:
with open(self.filename, "a", newline="") as outfile:
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
@@ -88,7 +95,8 @@ class BayesRun():
"dipole_location": dipoles.dipoles[0].s,
"dipole_frequency": dipoles.dipoles[0].w
}
successes: List[int] = []
successes: List[float] = []
counts: List[int] = []
for model_index, (name, result) in enumerate(zip(self.model_names, results)):
count = 0
success = 0
@@ -99,10 +107,11 @@ class BayesRun():
row[f"{name}_success"] = success
row[f"{name}_count"] = count
successes.append(max(success, 1))
successes.append(max(success, 0.5))
counts.append(count)
success_weight = sum([succ * prob for succ, prob in zip(successes, self.probabilities)])
new_probabilities = [succ * old_prob / success_weight for succ, old_prob in zip(successes, self.probabilities)]
success_weight = sum([(succ / count) * prob for succ, count, prob in zip(successes, counts, self.probabilities)])
new_probabilities = [(succ / count) * old_prob / success_weight for succ, count, old_prob in zip(successes, counts, self.probabilities)]
self.probabilities = new_probabilities
for name, probability in zip(self.model_names, self.probabilities):
row[f"{name}_prob"] = probability
@@ -111,3 +120,9 @@ class BayesRun():
with open(self.filename, "a", newline="") as outfile:
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
writer.writerow(row)
if self.use_end_threshold:
max_prob = max(self.probabilities)
if max_prob > self.end_threshold:
_logger.info(f"Aborting early, because {max_prob} is greater than {self.end_threshold}")
break

View File

@@ -41,7 +41,6 @@ class SingleDipoleDiagnostic():
self.s_result_z = self.result_dipole.s[2]
self.w_actual = self.actual_dipole.w
self.w_result = self.result_dipole.w
class Diagnostic():

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "deepdog"
version = "0.2.3"
version = "0.3.0"
description = ""
authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
@@ -32,3 +32,6 @@ module = [
"scipy.optimize"
]
ignore_missing_imports = true
[tool.semantic_release]
version_toml = "pyproject.toml:tool.poetry.version"