"""Tests for ``tantri.dipoles.time_series.average_apsds``."""

import pytest
import numpy

import tantri.dipoles.time_series


def test_apsd_merge():
    """Averaging two APSD results on the same frequency grid averages elementwise.

    Both inputs share ``freqs``; the merged result should keep that grid and
    hold the per-key elementwise mean of the PSD arrays.
    """
    freqs = numpy.array([0.01, 0.1, 1, 10, 100])

    dict1 = {"t1": numpy.array([1, 2, 3, 4, 5])}
    a1 = tantri.dipoles.time_series.APSDResult(dict1, freqs)

    dict2 = {"t1": numpy.array([3, 4, 5, 6, 7])}
    a2 = tantri.dipoles.time_series.APSDResult(dict2, freqs)

    merged = tantri.dipoles.time_series.average_apsds([a1, a2])

    # Elementwise mean of [1..5] and [3..7] is exactly [2..6].
    expected = tantri.dipoles.time_series.APSDResult(
        psd_dict={
            "t1": numpy.array([2, 3, 4, 5, 6]),
        },
        freqs=freqs,
    )

    numpy.testing.assert_equal(merged.freqs, expected.freqs)
    numpy.testing.assert_equal(merged.psd_dict, expected.psd_dict)


def test_apsd_merge_mismatch_freqs():
    """Averaging APSD results with different frequency grids raises ValueError.

    The two inputs have the same PSD data and the same number of frequency
    points, but different frequency values.
    """
    # Named psd_dict (not ``dict``) to avoid shadowing the builtin.
    psd_dict = {"t1": numpy.array([1, 2, 3, 4, 5])}

    freqs1 = numpy.array([0.01, 0.1, 1, 10, 100])
    a1 = tantri.dipoles.time_series.APSDResult(psd_dict, freqs1)

    freqs2 = numpy.array([1, 3, 5, 7, 9])
    a2 = tantri.dipoles.time_series.APSDResult(psd_dict, freqs2)

    with pytest.raises(ValueError):
        tantri.dipoles.time_series.average_apsds([a1, a2])