diff --git a/pdme/subspace_simulation/__init__.py b/pdme/subspace_simulation/__init__.py index 83593e6..b898fbd 100644 --- a/pdme/subspace_simulation/__init__.py +++ b/pdme/subspace_simulation/__init__.py @@ -41,11 +41,30 @@ def sort_array_of_dipoles_by_frequency(configuration) -> numpy.ndarray: Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array. For each of the 8 samples, we want the 2 dipoles to be in order of frequency. + This just sorts each sample, the 2x7 array. + Utility function. """ return numpy.array(sorted(configuration, key=lambda l: l[6])) +def sort_array_of_dipoleses_by_frequency(configurations) -> numpy.ndarray: + """ + Say we have a situation of 2 dipoles, and we've created 8 samples. Then we'll have an (8, 2, 7) numpy array. + For each of the 8 samples, we want the 2 dipoles to be in order of frequency. + + This is the wrapper that sorts everything. + + Utility function. + """ + return numpy.array( + [ + sort_array_of_dipoles_by_frequency(configuration) + for configuration in configurations + ] + ) + + __all__ = [ "DipoleStandardDeviation", "MCMCStandardDeviation", diff --git a/tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr b/tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr index b7d67f9..fc491fc 100644 --- a/tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr +++ b/tests/subspace_simulation/__snapshots__/test_sort_dipoles.ambr @@ -30,3 +30,65 @@ ]), ]) # --- +# name: test_sort_dipoleses_by_freq + list([ + list([ + list([ + 100.0, + 200.0, + 300.0, + 400.0, + 500.0, + 600.0, + 0.07, + ]), + list([ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + ]), + list([ + 10.0, + 200.0, + 30.0, + 41.0, + 315.0, + 0.31, + 100.0, + ]), + ]), + list([ + list([ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 100.0, + ]), + list([ + 22.0, + 22.2, + 2.2, + 222.0, + 22.0, + 2.0, + 200.0, + ]), + list([ + 33.0, + 33.3, + 33.0, + 3.3, + 0.33, + 0.3, + 300.0, + ]), + ]), + ]) +# --- diff --git a/tests/subspace_simulation/test_sort_dipoles.py b/tests/subspace_simulation/test_sort_dipoles.py index 2d8c28e..450c2b4 100644 --- a/tests/subspace_simulation/test_sort_dipoles.py +++ b/tests/subspace_simulation/test_sort_dipoles.py @@ -13,3 +13,28 @@ def test_sort_dipoles_by_freq(snapshot): actual_sorted = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(orig) assert actual_sorted.tolist() == snapshot + + +def test_sort_dipoleses_by_freq(snapshot): + sample_1 = numpy.array( + [ + [1, 2, 3, 4, 5, 6, 7], + [100, 200, 300, 400, 500, 600, 0.07], + [10, 200, 30, 41, 315, 0.31, 100], + ] + ) + + sample_2 = numpy.array( + [ + [1, 1, 1, 1, 1, 1, 100], + [33, 33.3, 33, 3.3, 0.33, 0.3, 300], + [22, 22.2, 2.2, 222, 22, 2, 200], + ] + ) + + original_samples = numpy.array([sample_1, sample_2]) + + actual_sorted = pdme.subspace_simulation.sort_array_of_dipoleses_by_frequency( + original_samples + ) + assert actual_sorted.tolist() == snapshot