All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good
29 lines
681 B
Python
29 lines
681 B
Python
import typing
|
|
import tantri.util
|
|
import numpy
|
|
|
|
|
|
def test_mean_dict():
    """Check that ``tantri.util.dict_reduce`` combines dictionaries key by key.

    Two dictionaries with identical keys are reduced with an element-wise mean.
    The second dictionary's arrays are exactly twice the first's, so every
    averaged array is expected to equal 1.5 times the corresponding array of
    the first dictionary.
    """
    base = {
        "squares": numpy.array([1, 4, 9, 16]),
        "linear": numpy.array([1, 2, 3, 4, 5]),
    }
    # Same keys, same order, every entry scaled by two.
    doubled = {key: 2 * arr for key, arr in base.items()}

    def _average(arrays: typing.Sequence[numpy.ndarray]) -> numpy.ndarray:
        # Stack the inputs along a new leading axis, then average across it.
        return numpy.array(arrays).mean(axis=0)

    result = tantri.util.dict_reduce([base, doubled], _average)

    # mean(x, 2x) == 1.5x is exact for these small integers, so an exact
    # (non-tolerance) comparison is safe here.
    expected = {key: 1.5 * arr for key, arr in base.items()}
    numpy.testing.assert_equal(
        result, expected, "The reduced dictionary should have matched the expected"
    )
|