From feb0a5f6453dcb5e71a07c7749cd579dab15171c Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 23 Jul 2023 16:30:19 -0500 Subject: [PATCH] feat: adds utility functions for dealing with markov chain monte carlo --- pdme/subspace_simulation/__init__.py | 42 +++++++++++++++++++ .../__snapshots__/test_sort_dipoles.ambr | 32 ++++++++++++++ .../__snapshots__/test_stdevs.ambr | 22 ++++++++++ .../subspace_simulation/test_sort_dipoles.py | 17 ++++++++ tests/subspace_simulation/test_stdevs.py | 38 +++++++++++++++++ 5 files changed, 151 insertions(+) create mode 100644 pdme/subspace_simulation/__init__.py create mode 100644 tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr create mode 100644 tests/subspace_simulation/__snapshots__/test_stdevs.ambr create mode 100644 tests/subspace_simulation/test_sort_dipoles.py create mode 100644 tests/subspace_simulation/test_stdevs.py diff --git a/pdme/subspace_simulation/__init__.py b/pdme/subspace_simulation/__init__.py new file mode 100644 index 0000000..b7467bc --- /dev/null +++ b/pdme/subspace_simulation/__init__.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from typing import Sequence +import numpy + + +@dataclass +class DipoleStandardDeviation: + """ + contains the dipole standard deviation to be used in porposals for markov chain monte carlo + """ + + p_phi_step: float + p_theta_step: float + rx_step: float + ry_step: float + rz_step: float + w_log_step: float + + +class MCMCStandardDeviation: + """ + wrapper for multiple standard deviations, allows for flexible length stuff + """ + + def __init__(self, stdevs: Sequence[DipoleStandardDeviation]): + self.stdevs = stdevs + if len(stdevs) < 1: + raise ValueError(f"Got stdevs: {stdevs}, must have length > 1") + + def __getitem__(self, key): + newkey = key % len(self.stdevs) + return self.stdevs[newkey] + + +def sort_array_of_dipoles_by_frequency(configuration) -> numpy.ndarray: + """ + Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array. + For each of the 8 samples, we want the 2 dipoles to be in order of frequency. + + Utility function. + """ + return numpy.array(sorted(configuration, key=lambda l: l[6])) diff --git a/tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr b/tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr new file mode 100644 index 0000000..b7d67f9 --- /dev/null +++ b/tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr @@ -0,0 +1,32 @@ +# serializer version: 1 +# name: test_sort_dipoles_by_freq + list([ + list([ + 100.0, + 200.0, + 300.0, + 400.0, + 500.0, + 600.0, + 0.07, + ]), + list([ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + ]), + list([ + 10.0, + 200.0, + 30.0, + 41.0, + 315.0, + 0.31, + 100.0, + ]), + ]) +# --- diff --git a/tests/subspace_simulation/__snapshots__/test_stdevs.ambr b/tests/subspace_simulation/__snapshots__/test_stdevs.ambr new file mode 100644 index 0000000..6bae65b --- /dev/null +++ b/tests/subspace_simulation/__snapshots__/test_stdevs.ambr @@ -0,0 +1,22 @@ +# serializer version: 1 +# name: test_return_four + DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6) +# --- +# name: test_return_four.1 + DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60) +# --- +# name: test_return_four.2 + DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6) +# --- +# name: test_return_four.3 + DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6) +# --- +# name: test_return_four.4 + DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60) +# --- +# name: test_return_four.5 + DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6) +# --- +# name: test_return_one + DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6) +# --- diff --git a/tests/subspace_simulation/test_sort_dipoles.py b/tests/subspace_simulation/test_sort_dipoles.py new file mode 100644 index 0000000..550c358 --- /dev/null +++ b/tests/subspace_simulation/test_sort_dipoles.py @@ -0,0 +1,17 @@ +import numpy +import pdme.subspace_simulation + + +def test_sort_dipoles_by_freq(snapshot): + orig = numpy.array( + [ + [1, 2, 3, 4, 5, 6, 7], + [100, 200, 300, 400, 500, 600, 0.07], + [10, 200, 30, 41, 315, 0.31, 100], + ] + ) + + actual_sorted = ( + pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(orig) + ) + assert actual_sorted.tolist() == snapshot diff --git a/tests/subspace_simulation/test_stdevs.py b/tests/subspace_simulation/test_stdevs.py new file mode 100644 index 0000000..cda5d87 --- /dev/null +++ b/tests/subspace_simulation/test_stdevs.py @@ -0,0 +1,38 @@ +import pytest +import pdme.subspace_simulation + + +def test_empty(): + with pytest.raises(ValueError): + pdme.subspace_simulation.MCMCStandardDeviation([]) + + +def test_return_one(snapshot): + stdev = pdme.subspace_simulation.DipoleStandardDeviation( + 1, + 2, + 3, + 4, + 5, + 6, + ) + stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev]) + + assert stdevs[3] == snapshot + assert stdevs[3] == stdev + + +def test_return_four(snapshot): + stdev_list = [ + pdme.subspace_simulation.DipoleStandardDeviation(1, 2, 3, 4, 5, 6), + pdme.subspace_simulation.DipoleStandardDeviation(10, 20, 30, 40, 50, 60), + pdme.subspace_simulation.DipoleStandardDeviation(0.1, 0.2, 0.3, 0.4, 0.5, 0.6), + ] + stdevs = pdme.subspace_simulation.MCMCStandardDeviation(stdev_list) + + assert stdevs[0] == snapshot + assert stdevs[1] == snapshot + assert stdevs[2] == snapshot + assert stdevs[3] == snapshot + assert stdevs[4] == snapshot + assert stdevs[5] == snapshot