feat: adds utility function for sorting samples by frequency for subspace simulation

This commit is contained in:
Deepak Mallubhotla 2024-05-16 19:20:55 -05:00
parent 387a607e09
commit e5fc1207a8
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
3 changed files with 106 additions and 0 deletions

View File

@ -41,11 +41,30 @@ def sort_array_of_dipoles_by_frequency(configuration) -> numpy.ndarray:
Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array.
For each of the 8 samples, we want the 2 dipoles to be in order of frequency.
This just sorts each sample, the 2x7 array.
Utility function.
"""
return numpy.array(sorted(configuration, key=lambda l: l[6]))
def sort_array_of_dipoleses_by_frequency(configurations) -> numpy.ndarray:
"""
Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array.
For each of the 8 samples, we want the 2 dipoles to be in order of frequency.
This is the wrapper that sorts everything.
Utility function.
"""
return numpy.array(
[
sort_array_of_dipoles_by_frequency(configuration)
for configuration in configurations
]
)
__all__ = [
"DipoleStandardDeviation",
"MCMCStandardDeviation",

View File

@ -30,3 +30,65 @@
]),
])
# ---
# name: test_sort_dipoleses_by_freq
list([
list([
list([
100.0,
200.0,
300.0,
400.0,
500.0,
600.0,
0.07,
]),
list([
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
]),
list([
10.0,
200.0,
30.0,
41.0,
315.0,
0.31,
100.0,
]),
]),
list([
list([
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
100.0,
]),
list([
22.0,
22.2,
2.2,
222.0,
22.0,
2.0,
200.0,
]),
list([
33.0,
33.3,
33.0,
3.3,
0.33,
0.3,
300.0,
]),
]),
])
# ---

View File

@ -13,3 +13,28 @@ def test_sort_dipoles_by_freq(snapshot):
actual_sorted = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(orig)
assert actual_sorted.tolist() == snapshot
def test_sort_dipoleses_by_freq(snapshot):
sample_1 = numpy.array(
[
[1, 2, 3, 4, 5, 6, 7],
[100, 200, 300, 400, 500, 600, 0.07],
[10, 200, 30, 41, 315, 0.31, 100],
]
)
sample_2 = numpy.array(
[
[1, 1, 1, 1, 1, 1, 100],
[33, 33.3, 33, 3.3, 0.33, 0.3, 300],
[22, 22.2, 2.2, 222, 22, 2, 200],
]
)
original_samples = numpy.array([sample_1, sample_2])
actual_sorted = pdme.subspace_simulation.sort_array_of_dipoleses_by_frequency(
original_samples
)
assert actual_sorted.tolist() == snapshot