diff --git a/pdme/util/fast_nonlocal_spectrum.py b/pdme/util/fast_nonlocal_spectrum.py index 79998ff..a9f838d 100644 --- a/pdme/util/fast_nonlocal_spectrum.py +++ b/pdme/util/fast_nonlocal_spectrum.py @@ -98,3 +98,11 @@ def fast_s_nonlocal_dipoleses( _logger.debug(f"Raw pair calc: [{alphses1 * alphses2 * bses}]") return numpy.einsum("...j->...", alphses1 * alphses2 * bses) + + +def signarg(x, **kwargs): + """ + uses numpy.sign to implement Arg for real numbers only. Should return pi for negative inputs, 0 for positive. + Passes through args to numpy.sign + """ + return numpy.pi * (numpy.sign(x, **kwargs) - 1) / (-2) diff --git a/tests/util/__snapshots__/test_fast_nonlocal_spectrum.ambr b/tests/util/__snapshots__/test_fast_nonlocal_spectrum.ambr new file mode 100644 index 0000000..82a208f --- /dev/null +++ b/tests/util/__snapshots__/test_fast_nonlocal_spectrum.ambr @@ -0,0 +1,20 @@ +# serializer version: 1 +# name: test_arg + list([ + list([ + -0.0, + -0.0, + -0.0, + ]), + list([ + 3.141592653589793, + -0.0, + -0.0, + ]), + list([ + -0.0, + -0.0, + 3.141592653589793, + ]), + ]) +# --- diff --git a/tests/util/test_fast_nonlocal_spectrum.py b/tests/util/test_fast_nonlocal_spectrum.py index 49d6ae8..2f1a6c1 100644 --- a/tests/util/test_fast_nonlocal_spectrum.py +++ b/tests/util/test_fast_nonlocal_spectrum.py @@ -53,3 +53,11 @@ def test_fast_nonlocal_frequency_check(): with pytest.raises(ValueError): pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal(dot_pairs, dipoles) + + +def test_arg(snapshot): + + test_input = numpy.array([[1, 2, 3], [-1, 1, 3], [3, 5, -1]]) + + actual_result = pdme.util.fast_nonlocal_spectrum.signarg(test_input) + assert actual_result.tolist() == snapshot