feat: allows for deepdog bayesrun with ss to not print csv to make snapshot testing possible

This commit is contained in:
Deepak Mallubhotla 2023-07-27 17:09:36 -05:00
parent e6defc7948
commit 8e6ead416c
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
5 changed files with 379 additions and 10 deletions

View File

@ -71,6 +71,7 @@ class BayesRunWithSubspaceSimulation:
ss_default_w_log_step=0.01, ss_default_w_log_step=0.01,
ss_default_upper_w_log_step=4, ss_default_upper_w_log_step=4,
ss_dump_last_generation=False, ss_dump_last_generation=False,
write_output_to_bayesruncsv=True,
) -> None: ) -> None:
self.dot_inputs = pdme.inputs.inputs_with_frequency_range( self.dot_inputs = pdme.inputs.inputs_with_frequency_range(
dot_positions, frequency_range dot_positions, frequency_range
@ -138,10 +139,18 @@ class BayesRunWithSubspaceSimulation:
self.run_count = run_count self.run_count = run_count
def go(self) -> None: self.write_output_to_csv = write_output_to_bayesruncsv
with open(self.filename, "a", newline="") as outfile:
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") def go(self) -> Sequence:
writer.writeheader()
if self.write_output_to_csv:
with open(self.filename, "a", newline="") as outfile:
writer = csv.DictWriter(
outfile, fieldnames=self.csv_fields, dialect="unix"
)
writer.writeheader()
return_result = []
for run in range(1, self.run_count + 1): for run in range(1, self.run_count + 1):
@ -222,12 +231,14 @@ class BayesRunWithSubspaceSimulation:
for name, probability in zip(self.model_names, self.probabilities): for name, probability in zip(self.model_names, self.probabilities):
row[f"{name}_prob"] = probability row[f"{name}_prob"] = probability
_logger.info(row) _logger.info(row)
return_result.append(row)
with open(self.filename, "a", newline="") as outfile: if self.write_output_to_csv:
writer = csv.DictWriter( with open(self.filename, "a", newline="") as outfile:
outfile, fieldnames=self.csv_fields, dialect="unix" writer = csv.DictWriter(
) outfile, fieldnames=self.csv_fields, dialect="unix"
writer.writerow(row) )
writer.writerow(row)
if self.use_end_threshold: if self.use_end_threshold:
max_prob = max(self.probabilities) max_prob = max(self.probabilities)
@ -236,3 +247,5 @@ class BayesRunWithSubspaceSimulation:
f"Aborting early, because {max_prob} is greater than {self.end_threshold}" f"Aborting early, because {max_prob} is greater than {self.end_threshold}"
) )
break break
return return_result

24
poetry.lock generated
View File

@ -92,6 +92,14 @@ category = "dev"
optional = false optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
[[package]]
name = "colored"
version = "1.4.4"
description = "Simple library for color and formatting to terminal"
category = "dev"
optional = false
python-versions = "*"
[[package]] [[package]]
name = "coverage" name = "coverage"
version = "7.2.7" version = "7.2.7"
@ -633,6 +641,18 @@ category = "dev"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
[[package]]
name = "syrupy"
version = "4.0.8"
description = "Pytest Snapshot Test Utility"
category = "dev"
optional = false
python-versions = ">=3.8.1,<4"
[package.dependencies]
colored = ">=1.3.92,<2.0.0"
pytest = ">=7.0.0,<8.0.0"
[[package]] [[package]]
name = "tomli" name = "tomli"
version = "2.0.1" version = "2.0.1"
@ -730,7 +750,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = ">=3.8.1,<3.10" python-versions = ">=3.8.1,<3.10"
content-hash = "111972d04616ce3ddfc9039a0b38c7eb7c4a41f10390139b27e958aedac7e979" content-hash = "e1531b1493bac50ffe5e8f9a46a64d9b66198f7021f6d643c72f21cb53dc77ec"
[metadata.files] [metadata.files]
black = [] black = []
@ -741,6 +761,7 @@ charset-normalizer = []
click = [] click = []
click-log = [] click-log = []
colorama = [] colorama = []
colored = []
coverage = [] coverage = []
cryptography = [] cryptography = []
docutils = [] docutils = []
@ -786,6 +807,7 @@ secretstorage = []
semver = [] semver = []
six = [] six = []
smmap = [] smmap = []
syrupy = []
tomli = [] tomli = []
tomlkit = [] tomlkit = []
tqdm = [] tqdm = []

View File

@ -17,6 +17,7 @@ pytest-cov = "^4.1.0"
mypy = "^0.971" mypy = "^0.971"
python-semantic-release = "^7.24.0" python-semantic-release = "^7.24.0"
black = "^22.3.0" black = "^22.3.0"
syrupy = "^4.0.8"
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]

View File

@ -0,0 +1,177 @@
# serializer version: 1
# name: test_basic_analysis
list([
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.3333333333333333,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.3333333333333333,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.3333333333333333,
'dipole_frequency_1': 0.006029931414230269,
'dipole_frequency_2': 85436.78758379082,
'dipole_location_1': array([-4.76615152, -6.33160296, 5.29522808]),
'dipole_location_2': array([-4.72700391, -2.06478573, 6.52467702]),
'dipole_moment_1': array([ 860.14181416, -450.27082062, -239.60852996]),
'dipole_moment_2': array([ 908.18325588, -208.52681777, -362.93214244]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.45,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.3103448275862069,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.9,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.6206896551724138,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.06896551724137932,
'dipole_frequency_1': 102275.63477261562,
'dipole_frequency_2': 1755280.9783485082,
'dipole_location_1': array([ 4.71515397, -9.70362197, 5.43016546]),
'dipole_location_2': array([3.42476038, 3.88562934, 5.15034328]),
'dipole_moment_1': array([-502.60742674, -790.60222587, 349.7626267 ]),
'dipole_moment_2': array([-192.42708465, -434.81009148, -879.7226844 ]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.7,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.6631578947368421,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.18947368421052635,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.7,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.1473684210526316,
'dipole_frequency_1': 2896.799464036654,
'dipole_frequency_2': 9.980565189326681e-05,
'dipole_location_1': array([-4.97465789, 12.54716531, 6.06324588]),
'dipole_location_2': array([ 9.84518459, -11.1183876 , 7.35028226]),
'dipole_moment_1': array([997.67961917, 19.6376112 , 65.19004305]),
'dipole_moment_2': array([305.63093655, 440.57669389, 844.08643362]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.663157894736842,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.18947368421052635,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.1473684210526316,
'dipole_frequency_1': 1.4522667818288244,
'dipole_frequency_2': 2704.9795645301197,
'dipole_location_1': array([ 7.38183022, 16.6745801 , 7.10428414]),
'dipole_location_2': array([-8.15636906, -9.56609132, 6.34141559]),
'dipole_moment_1': array([-145.9924693 , 738.74936496, 657.97839986]),
'dipole_moment_2': array([-960.16113239, 104.96824669, -258.98314046]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.9,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.9465776293823038,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.030050083472454105,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.02337228714524208,
'dipole_frequency_1': 3827.2315421318913,
'dipole_frequency_2': 1.9301094166184413e-05,
'dipole_location_1': array([ 5.02067673, -0.9783039 , 6.1431897 ]),
'dipole_location_2': array([ 4.66628999, 10.80907459, 7.21771744]),
'dipole_moment_1': array([ 871.30659253, -299.17389491, -388.99846068]),
'dipole_moment_2': array([-189.87268624, 677.28285845, 710.79975568]),
}),
])
# ---
# name: test_bayesss_with_tighter_cost
list([
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.33333333333333337,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.33333333333333337,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.33333333333333337,
'dipole_frequency_1': 0.006029931414230269,
'dipole_frequency_2': 85436.78758379082,
'dipole_location_1': array([-4.76615152, -6.33160296, 5.29522808]),
'dipole_location_2': array([-4.72700391, -2.06478573, 6.52467702]),
'dipole_moment_1': array([ 860.14181416, -450.27082062, -239.60852996]),
'dipole_moment_2': array([ 908.18325588, -208.52681777, -362.93214244]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.0109375,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.1044776119402985,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.03125,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.2985074626865672,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.0625,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5970149253731344,
'dipole_frequency_1': 102275.63477261562,
'dipole_frequency_2': 1755280.9783485082,
'dipole_location_1': array([ 4.71515397, -9.70362197, 5.43016546]),
'dipole_location_2': array([3.42476038, 3.88562934, 5.15034328]),
'dipole_moment_1': array([-502.60742674, -790.60222587, 349.7626267 ]),
'dipole_moment_2': array([-192.42708465, -434.81009148, -879.7226844 ]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 7.291135021404688e-05,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.021875,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.4666326413699001,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.0125,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5332944472798858,
'dipole_frequency_1': 2896.799464036654,
'dipole_frequency_2': 9.980565189326681e-05,
'dipole_location_1': array([-4.97465789, 12.54716531, 6.06324588]),
'dipole_location_2': array([ 9.84518459, -11.1183876 , 7.35028226]),
'dipole_moment_1': array([997.67961917, 19.6376112 , 65.19004305]),
'dipole_moment_2': array([305.63093655, 440.57669389, 844.08643362]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 7.291135021404688e-05,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.4666326413699001,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5332944472798858,
'dipole_frequency_1': 1.4522667818288244,
'dipole_frequency_2': 2704.9795645301197,
'dipole_location_1': array([ 7.38183022, 16.6745801 , 7.10428414]),
'dipole_location_2': array([-8.15636906, -9.56609132, 6.34141559]),
'dipole_moment_1': array([-145.9924693 , 738.74936496, 657.97839986]),
'dipole_moment_2': array([-960.16113239, 104.96824669, -258.98314046]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.175,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.00012008361740869356,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.05625,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.24702915581216964,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.15,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.7528507605704217,
'dipole_frequency_1': 3827.2315421318913,
'dipole_frequency_2': 1.9301094166184413e-05,
'dipole_location_1': array([ 5.02067673, -0.9783039 , 6.1431897 ]),
'dipole_location_2': array([ 4.66628999, 10.80907459, 7.21771744]),
'dipole_moment_1': array([ 871.30659253, -299.17389491, -388.99846068]),
'dipole_moment_2': array([-189.87268624, 677.28285845, 710.79975568]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 4.9116305003549454e-08,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.0109375,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.11316396672817797,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.028125,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.886835984155517,
'dipole_frequency_1': 1.1715179359592061e-05,
'dipole_frequency_2': 0.0019103783276337497,
'dipole_location_1': array([-0.95736547, 1.09273812, 7.47158641]),
'dipole_location_2': array([ -3.18510322, -15.64493131, 5.81623624]),
'dipole_moment_1': array([-184.64961369, 956.56786553, 225.57136075]),
'dipole_moment_2': array([ -34.63395137, 801.17771816, -597.42342885]),
}),
dict({
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 1.977090156727901e-10,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.00045552157211010855,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.002734375,
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.9995444782301809,
'dipole_frequency_1': 999786.9069039805,
'dipole_frequency_2': 186034.67996840767,
'dipole_location_1': array([-5.59679125, 6.3411602 , 5.33602522]),
'dipole_location_2': array([-0.03412955, -6.83522954, 5.58551513]),
'dipole_moment_1': array([826.38270589, 491.81526944, 274.24325726]),
'dipole_moment_2': array([ 202.74745884, -656.07483714, -726.95204519]),
}),
])
# ---

View File

@ -0,0 +1,156 @@
import deepdog
import logging
import logging.config
import numpy.random
from pdme.model import (
LogSpacedRandomCountMultipleDipoleFixedMagnitudeModel,
LogSpacedRandomCountMultipleDipoleFixedMagnitudeXYModel,
LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel,
)
_logger = logging.getLogger(__name__)
def fixed_z_model_func(
xmin,
xmax,
ymin,
ymax,
zmin,
zmax,
wexp_min,
wexp_max,
pfixed,
n_max,
prob_occupancy,
):
return LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel(
xmin,
xmax,
ymin,
ymax,
zmin,
zmax,
wexp_min,
wexp_max,
pfixed,
0,
0,
n_max,
prob_occupancy,
)
def get_model(orientation):
model_funcs = {
"fixedz": fixed_z_model_func,
"free": LogSpacedRandomCountMultipleDipoleFixedMagnitudeModel,
"fixedxy": LogSpacedRandomCountMultipleDipoleFixedMagnitudeXYModel,
}
model = model_funcs[orientation](
-10,
10,
-17.5,
17.5,
5,
7.5,
-5,
6.5,
10**3,
2,
0.99999999,
)
model.n = 2
model.rng = numpy.random.default_rng(1234)
return (
f"connors_geom-5height-orientation_{orientation}-pfixexp_{3}-dipole_count_{2}",
model,
)
def test_basic_analysis(snapshot):
dot_positions = [[0, 0, 0], [0, 1, 0]]
freqs = [1, 10, 100]
models = []
orientations = ["free", "fixedxy", "fixedz"]
for orientation in orientations:
models.append(get_model(orientation))
_logger.info(f"have {len(models)} models to look at")
if len(models) == 1:
_logger.info(f"only one model, name: {models[0][0]}")
square_run = deepdog.BayesRunWithSubspaceSimulation(
dot_positions,
freqs,
models,
models[0][1],
filename_slug="test",
end_threshold=0.9,
ss_n_c=5,
ss_n_s=2,
ss_m_max=10,
ss_target_cost=150,
ss_level_0_seed=200,
ss_mcmc_seed=20,
ss_use_adaptive_steps=True,
ss_default_phi_step=0.01,
ss_default_theta_step=0.01,
ss_default_r_step=0.01,
ss_default_w_log_step=0.01,
ss_default_upper_w_log_step=4,
ss_dump_last_generation=False,
write_output_to_bayesruncsv=False,
)
result = square_run.go()
assert result == snapshot
def test_bayesss_with_tighter_cost(snapshot):
dot_positions = [[0, 0, 0], [0, 1, 0]]
freqs = [1, 10, 100]
models = []
orientations = ["free", "fixedxy", "fixedz"]
for orientation in orientations:
models.append(get_model(orientation))
_logger.info(f"have {len(models)} models to look at")
if len(models) == 1:
_logger.info(f"only one model, name: {models[0][0]}")
square_run = deepdog.BayesRunWithSubspaceSimulation(
dot_positions,
freqs,
models,
models[0][1],
filename_slug="test",
end_threshold=0.9,
ss_n_c=5,
ss_n_s=2,
ss_m_max=10,
ss_target_cost=1.5,
ss_level_0_seed=200,
ss_mcmc_seed=20,
ss_use_adaptive_steps=True,
ss_default_phi_step=0.01,
ss_default_theta_step=0.01,
ss_default_r_step=0.01,
ss_default_w_log_step=0.01,
ss_default_upper_w_log_step=4,
ss_dump_last_generation=False,
write_output_to_bayesruncsv=False,
)
result = square_run.go()
assert result == snapshot