feat: adds utility functions for dealing with markov chain monte carlo

This commit is contained in:
Deepak Mallubhotla 2023-07-23 16:30:19 -05:00
parent e9bb62c0a0
commit feb0a5f645
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
5 changed files with 151 additions and 0 deletions

View File

@ -0,0 +1,42 @@
from dataclasses import dataclass
from typing import Sequence
import numpy
@dataclass
class DipoleStandardDeviation:
"""
contains the dipole standard deviation to be used in porposals for markov chain monte carlo
"""
p_phi_step: float
p_theta_step: float
rx_step: float
ry_step: float
rz_step: float
w_log_step: float
class MCMCStandardDeviation:
"""
wrapper for multiple standard deviations, allows for flexible length stuff
"""
def __init__(self, stdevs: Sequence[DipoleStandardDeviation]):
self.stdevs = stdevs
if len(stdevs) < 1:
raise ValueError(f"Got stdevs: {stdevs}, must have length > 1")
def __getitem__(self, key):
newkey = key % len(self.stdevs)
return self.stdevs[newkey]
def sort_array_of_dipoles_by_frequency(configuration) -> numpy.ndarray:
"""
Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array.
For each of the 8 samples, we want the 2 dipoles to be in order of frequency.
Utility function.
"""
return numpy.array(sorted(configuration, key=lambda l: l[6]))

View File

@ -0,0 +1,32 @@
# serializer version: 1
# name: test_sort_dipoles_by_freq
list([
list([
100.0,
200.0,
300.0,
400.0,
500.0,
600.0,
0.07,
]),
list([
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
]),
list([
10.0,
200.0,
30.0,
41.0,
315.0,
0.31,
100.0,
]),
])
# ---

View File

@ -0,0 +1,22 @@
# serializer version: 1
# name: test_return_four
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
# ---
# name: test_return_four.1
DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60)
# ---
# name: test_return_four.2
DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6)
# ---
# name: test_return_four.3
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
# ---
# name: test_return_four.4
DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60)
# ---
# name: test_return_four.5
DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6)
# ---
# name: test_return_one
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
# ---

View File

@ -0,0 +1,17 @@
import numpy
import pdme.subspace_simulation
def test_sort_dipoles_by_freq(snapshot):
orig = numpy.array(
[
[1, 2, 3, 4, 5, 6, 7],
[100, 200, 300, 400, 500, 600, 0.07],
[10, 200, 30, 41, 315, 0.31, 100],
]
)
actual_sorted = (
pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(orig)
)
assert actual_sorted.tolist() == snapshot

View File

@ -0,0 +1,38 @@
import pytest
import pdme.subspace_simulation
def test_empty():
with pytest.raises(ValueError):
pdme.subspace_simulation.MCMCStandardDeviation([])
def test_return_one(snapshot):
stdev = pdme.subspace_simulation.DipoleStandardDeviation(
1,
2,
3,
4,
5,
6,
)
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
assert stdevs[3] == snapshot
assert stdevs[3] == stdev
def test_return_four(snapshot):
stdev_list = [
pdme.subspace_simulation.DipoleStandardDeviation(1, 2, 3, 4, 5, 6),
pdme.subspace_simulation.DipoleStandardDeviation(10, 20, 30, 40, 50, 60),
pdme.subspace_simulation.DipoleStandardDeviation(0.1, 0.2, 0.3, 0.4, 0.5, 0.6),
]
stdevs = pdme.subspace_simulation.MCMCStandardDeviation(stdev_list)
assert stdevs[0] == snapshot
assert stdevs[1] == snapshot
assert stdevs[2] == snapshot
assert stdevs[3] == snapshot
assert stdevs[4] == snapshot
assert stdevs[5] == snapshot