diff --git a/pathfinder/model/oscillating/dot.py b/pathfinder/model/oscillating/dot.py index a09b784..9578b1b 100644 --- a/pathfinder/model/oscillating/dot.py +++ b/pathfinder/model/oscillating/dot.py @@ -76,12 +76,12 @@ class DotMeasurement(): w_div = alpha**2 * (1 / numpy.pi) * ((f2 - w2) / ((f2 + w2)**2)) return numpy.concatenate((p_divs, r_divs, w_div), axis=None) - # - # def jac(self, pts: numpy.ndarray) -> numpy.ndarray: - # # 6 because dipole in 3d has 6 degrees of freedom. - # pt_length = 6 - # # creates numpy.ndarrays in groups of pt_length. - # # Will throw problems for irregular points, but that's okay for now. - # chunked_pts = [pts[i: i + pt_length] for i in range(0, len(pts), pt_length)] - # - # return numpy.append([], [self.jac_pt(pt) for pt in chunked_pts]) + + def jac(self, pts: numpy.ndarray) -> numpy.ndarray: + # 7 because oscillating dipole in 3d has 7 degrees of freedom. + pt_length = 7 + # creates numpy.ndarrays in groups of pt_length. + # Will throw problems for irregular points, but that's okay for now. + chunked_pts = [pts[i: i + pt_length] for i in range(0, len(pts), pt_length)] + + return numpy.append([], [self.jac_pt(pt) for pt in chunked_pts]) diff --git a/tests/model/oscillating/test_oscdot.py b/tests/model/oscillating/test_oscdot.py index ed12434..8317c09 100644 --- a/tests/model/oscillating/test_oscdot.py +++ b/tests/model/oscillating/test_oscdot.py @@ -25,10 +25,19 @@ def test_jac(): 8.603475350051953e-7 ] + expected_jac2 = [ + -4.428720903021825e-6, 3.5429767224174605e-6, 4.428720903021825e-6, + -6.804125751006259e-6, -4.0261099118380183e-7, 2.3754048479844336e-6, + 5.181603456535536e-6 + ] + dot = model.DotMeasurement(50, (-1, -1, -1), 11) # dipole located at (4, 5, 6) with p=(1, 2, 3) and w = 7 pt = numpy.array((1, 2, 3, 4, 5, 6, 7)) + pt2 = numpy.array((2, 5, 3, 4, -5, -6, 2)) + pts = numpy.append(pt, pt2) + expected_jac_all = expected_jac + expected_jac2 assert len(dot.jac_pt(pt)) == 7 numpy.testing.assert_allclose(dot.jac_pt(pt), expected_jac, err_msg="Jac pt doesn't match Mathematica result.") - # numpy.testing.assert_allclose(dot.jac(pts), jac_row_target, err_msg="whole row should match") + numpy.testing.assert_allclose(dot.jac(pts), expected_jac_all, err_msg="whole row should match")