feat: adds utility functions for dealing with markov chain monte carlo
This commit is contained in:
parent
e9bb62c0a0
commit
feb0a5f645
42
pdme/subspace_simulation/__init__.py
Normal file
42
pdme/subspace_simulation/__init__.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Sequence
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DipoleStandardDeviation:
|
||||||
|
"""
|
||||||
|
contains the dipole standard deviation to be used in porposals for markov chain monte carlo
|
||||||
|
"""
|
||||||
|
|
||||||
|
p_phi_step: float
|
||||||
|
p_theta_step: float
|
||||||
|
rx_step: float
|
||||||
|
ry_step: float
|
||||||
|
rz_step: float
|
||||||
|
w_log_step: float
|
||||||
|
|
||||||
|
|
||||||
|
class MCMCStandardDeviation:
|
||||||
|
"""
|
||||||
|
wrapper for multiple standard deviations, allows for flexible length stuff
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stdevs: Sequence[DipoleStandardDeviation]):
|
||||||
|
self.stdevs = stdevs
|
||||||
|
if len(stdevs) < 1:
|
||||||
|
raise ValueError(f"Got stdevs: {stdevs}, must have length > 1")
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
newkey = key % len(self.stdevs)
|
||||||
|
return self.stdevs[newkey]
|
||||||
|
|
||||||
|
|
||||||
|
def sort_array_of_dipoles_by_frequency(configuration) -> numpy.ndarray:
|
||||||
|
"""
|
||||||
|
Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array.
|
||||||
|
For each of the 8 samples, we want the 2 dipoles to be in order of frequency.
|
||||||
|
|
||||||
|
Utility function.
|
||||||
|
"""
|
||||||
|
return numpy.array(sorted(configuration, key=lambda l: l[6]))
|
@ -0,0 +1,32 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_sort_dipoles_by_freq
|
||||||
|
list([
|
||||||
|
list([
|
||||||
|
100.0,
|
||||||
|
200.0,
|
||||||
|
300.0,
|
||||||
|
400.0,
|
||||||
|
500.0,
|
||||||
|
600.0,
|
||||||
|
0.07,
|
||||||
|
]),
|
||||||
|
list([
|
||||||
|
1.0,
|
||||||
|
2.0,
|
||||||
|
3.0,
|
||||||
|
4.0,
|
||||||
|
5.0,
|
||||||
|
6.0,
|
||||||
|
7.0,
|
||||||
|
]),
|
||||||
|
list([
|
||||||
|
10.0,
|
||||||
|
200.0,
|
||||||
|
30.0,
|
||||||
|
41.0,
|
||||||
|
315.0,
|
||||||
|
0.31,
|
||||||
|
100.0,
|
||||||
|
]),
|
||||||
|
])
|
||||||
|
# ---
|
22
tests/subspace_simulation/__snapshots__/test_stdevs.ambr
Normal file
22
tests/subspace_simulation/__snapshots__/test_stdevs.ambr
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_return_four
|
||||||
|
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
|
||||||
|
# ---
|
||||||
|
# name: test_return_four.1
|
||||||
|
DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60)
|
||||||
|
# ---
|
||||||
|
# name: test_return_four.2
|
||||||
|
DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6)
|
||||||
|
# ---
|
||||||
|
# name: test_return_four.3
|
||||||
|
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
|
||||||
|
# ---
|
||||||
|
# name: test_return_four.4
|
||||||
|
DipoleStandardDeviation(p_phi_step=10, p_theta_step=20, rx_step=30, ry_step=40, rz_step=50, w_log_step=60)
|
||||||
|
# ---
|
||||||
|
# name: test_return_four.5
|
||||||
|
DipoleStandardDeviation(p_phi_step=0.1, p_theta_step=0.2, rx_step=0.3, ry_step=0.4, rz_step=0.5, w_log_step=0.6)
|
||||||
|
# ---
|
||||||
|
# name: test_return_one
|
||||||
|
DipoleStandardDeviation(p_phi_step=1, p_theta_step=2, rx_step=3, ry_step=4, rz_step=5, w_log_step=6)
|
||||||
|
# ---
|
17
tests/subspace_simulation/test_sort_dipoles.py
Normal file
17
tests/subspace_simulation/test_sort_dipoles.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import numpy
|
||||||
|
import pdme.subspace_simulation
|
||||||
|
|
||||||
|
|
||||||
|
def test_sort_dipoles_by_freq(snapshot):
|
||||||
|
orig = numpy.array(
|
||||||
|
[
|
||||||
|
[1, 2, 3, 4, 5, 6, 7],
|
||||||
|
[100, 200, 300, 400, 500, 600, 0.07],
|
||||||
|
[10, 200, 30, 41, 315, 0.31, 100],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
actual_sorted = (
|
||||||
|
pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(orig)
|
||||||
|
)
|
||||||
|
assert actual_sorted.tolist() == snapshot
|
38
tests/subspace_simulation/test_stdevs.py
Normal file
38
tests/subspace_simulation/test_stdevs.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import pytest
|
||||||
|
import pdme.subspace_simulation
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
pdme.subspace_simulation.MCMCStandardDeviation([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_one(snapshot):
|
||||||
|
stdev = pdme.subspace_simulation.DipoleStandardDeviation(
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
)
|
||||||
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||||
|
|
||||||
|
assert stdevs[3] == snapshot
|
||||||
|
assert stdevs[3] == stdev
|
||||||
|
|
||||||
|
|
||||||
|
def test_return_four(snapshot):
|
||||||
|
stdev_list = [
|
||||||
|
pdme.subspace_simulation.DipoleStandardDeviation(1, 2, 3, 4, 5, 6),
|
||||||
|
pdme.subspace_simulation.DipoleStandardDeviation(10, 20, 30, 40, 50, 60),
|
||||||
|
pdme.subspace_simulation.DipoleStandardDeviation(0.1, 0.2, 0.3, 0.4, 0.5, 0.6),
|
||||||
|
]
|
||||||
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation(stdev_list)
|
||||||
|
|
||||||
|
assert stdevs[0] == snapshot
|
||||||
|
assert stdevs[1] == snapshot
|
||||||
|
assert stdevs[2] == snapshot
|
||||||
|
assert stdevs[3] == snapshot
|
||||||
|
assert stdevs[4] == snapshot
|
||||||
|
assert stdevs[5] == snapshot
|
Loading…
x
Reference in New Issue
Block a user