tantri/tests/test_util.py
Deepak Mallubhotla 22bb3d876f
All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good
feat: adds psd calculation code and utilities to handle averaging of periodograms
2024-07-27 11:07:24 -05:00

29 lines
681 B
Python

import typing
import tantri.util
import numpy
def test_mean_dict():
    """Check that ``tantri.util.dict_reduce`` combines dicts key-by-key.

    Both input dictionaries share the keys ``"squares"`` and ``"linear"``.
    Every array in the second dictionary is exactly twice the matching array
    in the first, so reducing the pair with an element-wise mean is expected
    to give 1.5 times the first dictionary's arrays under the same keys.
    """
    first = {
        "squares": numpy.array([1, 4, 9, 16]),
        "linear": numpy.array([1, 2, 3, 4, 5]),
    }
    second = {
        "squares": numpy.array([2, 8, 18, 32]),
        "linear": numpy.array([2, 4, 6, 8, 10]),
    }

    # The mean of x and 2*x is 1.5*x.  Build the expectation up front, as
    # fresh arrays, so it does not depend on anything the reduction does.
    expected = {key: 1.5 * values for key, values in first.items()}

    def elementwise_mean(
        list_of_arrays: typing.Sequence[numpy.ndarray],
    ) -> numpy.ndarray:
        # Stack the equal-length arrays along a new leading axis, then
        # average across that axis to get one array of per-position means.
        stacked = numpy.array(list_of_arrays)
        return numpy.mean(stacked, axis=0)

    reduced = tantri.util.dict_reduce([first, second], elementwise_mean)

    numpy.testing.assert_equal(
        reduced, expected, "The reduced dictionary should have matched the expected"
    )