From 8e6ead416c9eba56f568f648d0df44caaa510cfe Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Thu, 27 Jul 2023 17:09:36 -0500 Subject: [PATCH] feat: allows for deepdog bayesrun with ss to not print csv to make snapshot testing possible --- deepdog/bayes_run_with_ss.py | 31 ++- poetry.lock | 24 ++- pyproject.toml | 1 + .../__snapshots__/test_bayes_run_with_ss.ambr | 177 ++++++++++++++++++ tests/test_bayes_run_with_ss.py | 156 +++++++++++++++ 5 files changed, 379 insertions(+), 10 deletions(-) create mode 100644 tests/__snapshots__/test_bayes_run_with_ss.ambr create mode 100644 tests/test_bayes_run_with_ss.py diff --git a/deepdog/bayes_run_with_ss.py b/deepdog/bayes_run_with_ss.py index 885885f..f8f0a48 100644 --- a/deepdog/bayes_run_with_ss.py +++ b/deepdog/bayes_run_with_ss.py @@ -71,6 +71,7 @@ class BayesRunWithSubspaceSimulation: ss_default_w_log_step=0.01, ss_default_upper_w_log_step=4, ss_dump_last_generation=False, + write_output_to_bayesruncsv=True, ) -> None: self.dot_inputs = pdme.inputs.inputs_with_frequency_range( dot_positions, frequency_range @@ -138,10 +139,18 @@ class BayesRunWithSubspaceSimulation: self.run_count = run_count - def go(self) -> None: - with open(self.filename, "a", newline="") as outfile: - writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") - writer.writeheader() + self.write_output_to_csv = write_output_to_bayesruncsv + + def go(self) -> Sequence: + + if self.write_output_to_csv: + with open(self.filename, "a", newline="") as outfile: + writer = csv.DictWriter( + outfile, fieldnames=self.csv_fields, dialect="unix" + ) + writer.writeheader() + + return_result = [] for run in range(1, self.run_count + 1): @@ -222,12 +231,14 @@ class BayesRunWithSubspaceSimulation: for name, probability in zip(self.model_names, self.probabilities): row[f"{name}_prob"] = probability _logger.info(row) + return_result.append(row) - with open(self.filename, "a", newline="") as outfile: - writer = csv.DictWriter( - outfile, fieldnames=self.csv_fields, dialect="unix" - ) - writer.writerow(row) + if self.write_output_to_csv: + with open(self.filename, "a", newline="") as outfile: + writer = csv.DictWriter( + outfile, fieldnames=self.csv_fields, dialect="unix" + ) + writer.writerow(row) if self.use_end_threshold: max_prob = max(self.probabilities) @@ -236,3 +247,5 @@ class BayesRunWithSubspaceSimulation: f"Aborting early, because {max_prob} is greater than {self.end_threshold}" ) break + + return return_result diff --git a/poetry.lock b/poetry.lock index 4f73dad..e2f1de2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -92,6 +92,14 @@ category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +[[package]] +name = "colored" +version = "1.4.4" +description = "Simple library for color and formatting to terminal" +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "coverage" version = "7.2.7" @@ -633,6 +641,18 @@ category = "dev" optional = false python-versions = ">=3.6" +[[package]] +name = "syrupy" +version = "4.0.8" +description = "Pytest Snapshot Test Utility" +category = "dev" +optional = false +python-versions = ">=3.8.1,<4" + +[package.dependencies] +colored = ">=1.3.92,<2.0.0" +pytest = ">=7.0.0,<8.0.0" + [[package]] name = "tomli" version = "2.0.1" @@ -730,7 +750,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co [metadata] lock-version = "1.1" python-versions = ">=3.8.1,<3.10" -content-hash = "111972d04616ce3ddfc9039a0b38c7eb7c4a41f10390139b27e958aedac7e979" +content-hash = "e1531b1493bac50ffe5e8f9a46a64d9b66198f7021f6d643c72f21cb53dc77ec" [metadata.files] black = [] @@ -741,6 +761,7 @@ charset-normalizer = [] click = [] click-log = [] colorama = [] +colored = [] coverage = [] cryptography = [] docutils = [] @@ -786,6 +807,7 @@ secretstorage = [] semver = [] six = [] smmap = [] +syrupy = [] tomli = [] tomlkit = [] tqdm = [] diff --git a/pyproject.toml b/pyproject.toml index 89a47ec..26b71f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ pytest-cov = "^4.1.0" mypy = "^0.971" python-semantic-release = "^7.24.0" black = "^22.3.0" +syrupy = "^4.0.8" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/__snapshots__/test_bayes_run_with_ss.ambr b/tests/__snapshots__/test_bayes_run_with_ss.ambr new file mode 100644 index 0000000..5c5e384 --- /dev/null +++ b/tests/__snapshots__/test_bayes_run_with_ss.ambr @@ -0,0 +1,177 @@ +# serializer version: 1 +# name: test_basic_analysis + list([ + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.3333333333333333, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.3333333333333333, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.3333333333333333, + 'dipole_frequency_1': 0.006029931414230269, + 'dipole_frequency_2': 85436.78758379082, + 'dipole_location_1': array([-4.76615152, -6.33160296, 5.29522808]), + 'dipole_location_2': array([-4.72700391, -2.06478573, 6.52467702]), + 'dipole_moment_1': array([ 860.14181416, -450.27082062, -239.60852996]), + 'dipole_moment_2': array([ 908.18325588, -208.52681777, -362.93214244]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.45, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.3103448275862069, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.9, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.6206896551724138, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.06896551724137932, + 'dipole_frequency_1': 102275.63477261562, + 'dipole_frequency_2': 1755280.9783485082, + 'dipole_location_1': array([ 4.71515397, -9.70362197, 5.43016546]), + 'dipole_location_2': array([3.42476038, 3.88562934, 5.15034328]), + 'dipole_moment_1': array([-502.60742674, -790.60222587, 349.7626267 ]), + 'dipole_moment_2': array([-192.42708465, -434.81009148, -879.7226844 ]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.7, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.6631578947368421, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.18947368421052635, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.7, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.1473684210526316, + 'dipole_frequency_1': 2896.799464036654, + 'dipole_frequency_2': 9.980565189326681e-05, + 'dipole_location_1': array([-4.97465789, 12.54716531, 6.06324588]), + 'dipole_location_2': array([ 9.84518459, -11.1183876 , 7.35028226]), + 'dipole_moment_1': array([997.67961917, 19.6376112 , 65.19004305]), + 'dipole_moment_2': array([305.63093655, 440.57669389, 844.08643362]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.663157894736842, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.18947368421052635, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.1473684210526316, + 'dipole_frequency_1': 1.4522667818288244, + 'dipole_frequency_2': 2704.9795645301197, + 'dipole_location_1': array([ 7.38183022, 16.6745801 , 7.10428414]), + 'dipole_location_2': array([-8.15636906, -9.56609132, 6.34141559]), + 'dipole_moment_1': array([-145.9924693 , 738.74936496, 657.97839986]), + 'dipole_moment_2': array([-960.16113239, 104.96824669, -258.98314046]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.9, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.9465776293823038, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.030050083472454105, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.02337228714524208, + 'dipole_frequency_1': 3827.2315421318913, + 'dipole_frequency_2': 1.9301094166184413e-05, + 'dipole_location_1': array([ 5.02067673, -0.9783039 , 6.1431897 ]), + 'dipole_location_2': array([ 4.66628999, 10.80907459, 7.21771744]), + 'dipole_moment_1': array([ 871.30659253, -299.17389491, -388.99846068]), + 'dipole_moment_2': array([-189.87268624, 677.28285845, 710.79975568]), + }), + ]) +# --- +# name: test_bayesss_with_tighter_cost + list([ + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.33333333333333337, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.33333333333333337, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.33333333333333337, + 'dipole_frequency_1': 0.006029931414230269, + 'dipole_frequency_2': 85436.78758379082, + 'dipole_location_1': array([-4.76615152, -6.33160296, 5.29522808]), + 'dipole_location_2': array([-4.72700391, -2.06478573, 6.52467702]), + 'dipole_moment_1': array([ 860.14181416, -450.27082062, -239.60852996]), + 'dipole_moment_2': array([ 908.18325588, -208.52681777, -362.93214244]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.0109375, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.1044776119402985, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.03125, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.2985074626865672, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.0625, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5970149253731344, + 'dipole_frequency_1': 102275.63477261562, + 'dipole_frequency_2': 1755280.9783485082, + 'dipole_location_1': array([ 4.71515397, -9.70362197, 5.43016546]), + 'dipole_location_2': array([3.42476038, 3.88562934, 5.15034328]), + 'dipole_moment_1': array([-502.60742674, -790.60222587, 349.7626267 ]), + 'dipole_moment_2': array([-192.42708465, -434.81009148, -879.7226844 ]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 7.291135021404688e-05, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.021875, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.4666326413699001, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.0125, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5332944472798858, + 'dipole_frequency_1': 2896.799464036654, + 'dipole_frequency_2': 9.980565189326681e-05, + 'dipole_location_1': array([-4.97465789, 12.54716531, 6.06324588]), + 'dipole_location_2': array([ 9.84518459, -11.1183876 , 7.35028226]), + 'dipole_moment_1': array([997.67961917, 19.6376112 , 65.19004305]), + 'dipole_moment_2': array([305.63093655, 440.57669389, 844.08643362]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 7.291135021404688e-05, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.4666326413699001, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5332944472798858, + 'dipole_frequency_1': 1.4522667818288244, + 'dipole_frequency_2': 2704.9795645301197, + 'dipole_location_1': array([ 7.38183022, 16.6745801 , 7.10428414]), + 'dipole_location_2': array([-8.15636906, -9.56609132, 6.34141559]), + 'dipole_moment_1': array([-145.9924693 , 738.74936496, 657.97839986]), + 'dipole_moment_2': array([-960.16113239, 104.96824669, -258.98314046]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.175, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.00012008361740869356, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.05625, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.24702915581216964, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.15, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.7528507605704217, + 'dipole_frequency_1': 3827.2315421318913, + 'dipole_frequency_2': 1.9301094166184413e-05, + 'dipole_location_1': array([ 5.02067673, -0.9783039 , 6.1431897 ]), + 'dipole_location_2': array([ 4.66628999, 10.80907459, 7.21771744]), + 'dipole_moment_1': array([ 871.30659253, -299.17389491, -388.99846068]), + 'dipole_moment_2': array([-189.87268624, 677.28285845, 710.79975568]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 4.9116305003549454e-08, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.0109375, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.11316396672817797, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.028125, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.886835984155517, + 'dipole_frequency_1': 1.1715179359592061e-05, + 'dipole_frequency_2': 0.0019103783276337497, + 'dipole_location_1': array([-0.95736547, 1.09273812, 7.47158641]), + 'dipole_location_2': array([ -3.18510322, -15.64493131, 5.81623624]), + 'dipole_moment_1': array([-184.64961369, 956.56786553, 225.57136075]), + 'dipole_moment_2': array([ -34.63395137, 801.17771816, -597.42342885]), + }), + dict({ + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 1.977090156727901e-10, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06, + 'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.00045552157211010855, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.002734375, + 'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.9995444782301809, + 'dipole_frequency_1': 999786.9069039805, + 'dipole_frequency_2': 186034.67996840767, + 'dipole_location_1': array([-5.59679125, 6.3411602 , 5.33602522]), + 'dipole_location_2': array([-0.03412955, -6.83522954, 5.58551513]), + 'dipole_moment_1': array([826.38270589, 491.81526944, 274.24325726]), + 'dipole_moment_2': array([ 202.74745884, -656.07483714, -726.95204519]), + }), + ]) +# --- diff --git a/tests/test_bayes_run_with_ss.py b/tests/test_bayes_run_with_ss.py new file mode 100644 index 0000000..ca2bcda --- /dev/null +++ b/tests/test_bayes_run_with_ss.py @@ -0,0 +1,156 @@ +import deepdog +import logging +import logging.config + +import numpy.random + +from pdme.model import ( + LogSpacedRandomCountMultipleDipoleFixedMagnitudeModel, + LogSpacedRandomCountMultipleDipoleFixedMagnitudeXYModel, + LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel, +) + + +_logger = logging.getLogger(__name__) + + +def fixed_z_model_func( + xmin, + xmax, + ymin, + ymax, + zmin, + zmax, + wexp_min, + wexp_max, + pfixed, + n_max, + prob_occupancy, +): + return LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel( + xmin, + xmax, + ymin, + ymax, + zmin, + zmax, + wexp_min, + wexp_max, + pfixed, + 0, + 0, + n_max, + prob_occupancy, + ) + + +def get_model(orientation): + model_funcs = { + "fixedz": fixed_z_model_func, + "free": LogSpacedRandomCountMultipleDipoleFixedMagnitudeModel, + "fixedxy": LogSpacedRandomCountMultipleDipoleFixedMagnitudeXYModel, + } + model = model_funcs[orientation]( + -10, + 10, + -17.5, + 17.5, + 5, + 7.5, + -5, + 6.5, + 10**3, + 2, + 0.99999999, + ) + model.n = 2 + model.rng = numpy.random.default_rng(1234) + + return ( + f"connors_geom-5height-orientation_{orientation}-pfixexp_{3}-dipole_count_{2}", + model, + ) + + +def test_basic_analysis(snapshot): + + dot_positions = [[0, 0, 0], [0, 1, 0]] + + freqs = [1, 10, 100] + models = [] + + orientations = ["free", "fixedxy", "fixedz"] + for orientation in orientations: + models.append(get_model(orientation)) + + _logger.info(f"have {len(models)} models to look at") + if len(models) == 1: + _logger.info(f"only one model, name: {models[0][0]}") + + square_run = deepdog.BayesRunWithSubspaceSimulation( + dot_positions, + freqs, + models, + models[0][1], + filename_slug="test", + end_threshold=0.9, + ss_n_c=5, + ss_n_s=2, + ss_m_max=10, + ss_target_cost=150, + ss_level_0_seed=200, + ss_mcmc_seed=20, + ss_use_adaptive_steps=True, + ss_default_phi_step=0.01, + ss_default_theta_step=0.01, + ss_default_r_step=0.01, + ss_default_w_log_step=0.01, + ss_default_upper_w_log_step=4, + ss_dump_last_generation=False, + write_output_to_bayesruncsv=False, + ) + result = square_run.go() + + assert result == snapshot + + +def test_bayesss_with_tighter_cost(snapshot): + + dot_positions = [[0, 0, 0], [0, 1, 0]] + + freqs = [1, 10, 100] + models = [] + + orientations = ["free", "fixedxy", "fixedz"] + for orientation in orientations: + models.append(get_model(orientation)) + + _logger.info(f"have {len(models)} models to look at") + if len(models) == 1: + _logger.info(f"only one model, name: {models[0][0]}") + + square_run = deepdog.BayesRunWithSubspaceSimulation( + dot_positions, + freqs, + models, + models[0][1], + filename_slug="test", + end_threshold=0.9, + ss_n_c=5, + ss_n_s=2, + ss_m_max=10, + ss_target_cost=1.5, + ss_level_0_seed=200, + ss_mcmc_seed=20, + ss_use_adaptive_steps=True, + ss_default_phi_step=0.01, + ss_default_theta_step=0.01, + ss_default_r_step=0.01, + ss_default_w_log_step=0.01, + ss_default_upper_w_log_step=4, + ss_dump_last_generation=False, + write_output_to_bayesruncsv=False, + ) + result = square_run.go() + + assert result == snapshot