diff --git a/pynam/dielectric/nam_dielectric_coefficient_approximator.py b/pynam/dielectric/nam_dielectric_coefficient_approximator.py index deedb80..f3a26db 100644 --- a/pynam/dielectric/nam_dielectric_coefficient_approximator.py +++ b/pynam/dielectric/nam_dielectric_coefficient_approximator.py @@ -36,9 +36,9 @@ class NamDielectricCoefficients(object): self.d = d self.u_l = np.real((-self.c + 1j * self.d) / (-self.a + 1j * self.b)) - def eps(self): + def eps(self, u_c: float): - def piecewise_eps(u: float, u_c: float): + def piecewise_eps(u: float): # todo add check for u_c vs u_l if u < self.u_l: return -self.a + 1j * self.b diff --git a/tests/dielectric/test_nam_dielectric_coefficient_approximator.py b/tests/dielectric/test_nam_dielectric_coefficient_approximator.py index 48e6f74..2a604a7 100644 --- a/tests/dielectric/test_nam_dielectric_coefficient_approximator.py +++ b/tests/dielectric/test_nam_dielectric_coefficient_approximator.py @@ -15,7 +15,8 @@ import pynam.dielectric.nam_dielectric_coefficient_approximator ]) def test_dedimensionalise_parameters(test_input, expected): - actual_parameters = pynam.dielectric.nam_dielectric_coefficient_approximator.get_dedimensionalised_parameters(*test_input) + actual_parameters = pynam.dielectric.nam_dielectric_coefficient_approximator.get_dedimensionalised_parameters( + *test_input) np.testing.assert_almost_equal( actual_parameters.xi, expected[0], @@ -52,7 +53,8 @@ def test_dedimensionalise_parameters(test_input, expected): ]) def test_nam_coefficients(test_input, expected): - actual_coefficients = pynam.dielectric.nam_dielectric_coefficient_approximator.get_nam_dielectric_coefficients(*test_input) + actual_coefficients = pynam.dielectric.nam_dielectric_coefficient_approximator.get_nam_dielectric_coefficients( + *test_input) np.testing.assert_allclose( actual_coefficients.a, expected[0], @@ -86,19 +88,19 @@ def test_nam_eps(): 2e6, 0.8e11, 1e11, - 3e8).eps() + 3e8).eps(u_c) np.testing.assert_allclose( - eps_to_test(10, u_c), -3.789672906817707e10 + 3.257134605133221e8j, + eps_to_test(10), -3.789672906817707e10 + 3.257134605133221e8j, rtol=1e-3, err_msg='below u_l bad' ) np.testing.assert_allclose( - eps_to_test(1e10, u_c), -2.655709887616547e8 + 2.302290450767144e6j, + eps_to_test(1e10), -2.655709887616547e8 + 2.302290450767144e6j, rtol=1e-3, err_msg='linear region bad' ) np.testing.assert_allclose( - eps_to_test(1e17, u_c), 1, + eps_to_test(1e17), 1, rtol=1e-6, err_msg='above cutoff bad' )