feat: adds utility functions for dealing with markov chain monte carlo
This commit is contained in:
parent
e9bb62c0a0
commit
feb0a5f645
42
pdme/subspace_simulation/__init__.py
Normal file
42
pdme/subspace_simulation/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
import numpy
|
||||
|
||||
|
||||
@dataclass
|
||||
class DipoleStandardDeviation:
|
||||
"""
|
||||
contains the dipole standard deviation to be used in porposals for markov chain monte carlo
|
||||
"""
|
||||
|
||||
p_phi_step: float
|
||||
p_theta_step: float
|
||||
rx_step: float
|
||||
ry_step: float
|
||||
rz_step: float
|
||||
w_log_step: float
|
||||
|
||||
|
||||
class MCMCStandardDeviation:
|
||||
"""
|
||||
wrapper for multiple standard deviations, allows for flexible length stuff
|
||||
"""
|
||||
|
||||
def __init__(self, stdevs: Sequence[DipoleStandardDeviation]):
|
||||
self.stdevs = stdevs
|
||||
if len(stdevs) < 1:
|
||||
raise ValueError(f"Got stdevs: {stdevs}, must have length > 1")
|
||||
|
||||
def __getitem__(self, key):
|
||||
newkey = key % len(self.stdevs)
|
||||
return self.stdevs[newkey]
|
||||
|
||||
|
||||
def sort_array_of_dipoles_by_frequency(configuration) -> numpy.ndarray:
|
||||
"""
|
||||
Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array.
|
||||
For each of the 8 samples, we want the 2 dipoles to be in order of frequency.
|
||||
|
||||
Utility function.
|
||||
"""
|
||||
return numpy.array(sorted(configuration, key=lambda l: l[6]))
|
@ -0,0 +1,32 @@
|
||||
# serializer version: 1
|
||||
# name: test_sort_dipoles_by_freq
|
||||
list([
|
||||
list([
|
||||
100.0,
|
||||
200.0,
|
||||
300.0,
|
||||
400.0,
|
||||
500.0,
|
||||
600.0,
|
||||
0.07,
|
||||
]),
|
||||
list([
|
||||
1.0,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
5.0,
|
||||
6.0,
|
||||
7.0,
|
||||
]),
|
||||
list([
|
||||
10.0,
|
||||
200.0,
|
||||
30.0,
|
||||
41.0,
|
||||
315.0,
|
||||
0.31,
|
||||
100.0,
|
||||
]),
|
||||
])
|
||||
# ---
|
22
tests/subspace_simulation/__snapshots__/test_stdevs.ambr
Normal file
22
tests/subspace_simulation/__snapshots__/test_stdevs.ambr
Normal file
@ -0,0 +1,22 @@
|
||||
# serializer version: 1
|
||||
# name: test_return_four
|
||||
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
|
||||
# ---
|
||||
# name: test_return_four.1
|
||||
DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60)
|
||||
# ---
|
||||
# name: test_return_four.2
|
||||
DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6)
|
||||
# ---
|
||||
# name: test_return_four.3
|
||||
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
|
||||
# ---
|
||||
# name: test_return_four.4
|
||||
DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60)
|
||||
# ---
|
||||
# name: test_return_four.5
|
||||
DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6)
|
||||
# ---
|
||||
# name: test_return_one
|
||||
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
|
||||
# ---
|
17
tests/subspace_simulation/test_sort_dipoles.py
Normal file
17
tests/subspace_simulation/test_sort_dipoles.py
Normal file
@ -0,0 +1,17 @@
|
||||
import numpy
|
||||
import pdme.subspace_simulation
|
||||
|
||||
|
||||
def test_sort_dipoles_by_freq(snapshot):
|
||||
orig = numpy.array(
|
||||
[
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
[100, 200, 300, 400, 500, 600, 0.07],
|
||||
[10, 200, 30, 41, 315, 0.31, 100],
|
||||
]
|
||||
)
|
||||
|
||||
actual_sorted = (
|
||||
pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(orig)
|
||||
)
|
||||
assert actual_sorted.tolist() == snapshot
|
38
tests/subspace_simulation/test_stdevs.py
Normal file
38
tests/subspace_simulation/test_stdevs.py
Normal file
@ -0,0 +1,38 @@
|
||||
import pytest
|
||||
import pdme.subspace_simulation
|
||||
|
||||
|
||||
def test_empty():
|
||||
with pytest.raises(ValueError):
|
||||
pdme.subspace_simulation.MCMCStandardDeviation([])
|
||||
|
||||
|
||||
def test_return_one(snapshot):
|
||||
stdev = pdme.subspace_simulation.DipoleStandardDeviation(
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
)
|
||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||
|
||||
assert stdevs[3] == snapshot
|
||||
assert stdevs[3] == stdev
|
||||
|
||||
|
||||
def test_return_four(snapshot):
|
||||
stdev_list = [
|
||||
pdme.subspace_simulation.DipoleStandardDeviation(1, 2, 3, 4, 5, 6),
|
||||
pdme.subspace_simulation.DipoleStandardDeviation(10, 20, 30, 40, 50, 60),
|
||||
pdme.subspace_simulation.DipoleStandardDeviation(0.1, 0.2, 0.3, 0.4, 0.5, 0.6),
|
||||
]
|
||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation(stdev_list)
|
||||
|
||||
assert stdevs[0] == snapshot
|
||||
assert stdevs[1] == snapshot
|
||||
assert stdevs[2] == snapshot
|
||||
assert stdevs[3] == snapshot
|
||||
assert stdevs[4] == snapshot
|
||||
assert stdevs[5] == snapshot
|
Loading…
x
Reference in New Issue
Block a user