fix: fixes stupid cost shape issue
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good

This commit is contained in:
Deepak Mallubhotla 2023-07-26 21:12:04 -05:00
parent 74c1b01a6c
commit ed9dd2c94f
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
7 changed files with 44 additions and 36 deletions

View File

@ -88,12 +88,14 @@ class DipoleModel:
)
dips.append(tentative_dip)
dips_array = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(dips)
dips_array = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
dips
)
tentative_cost = cost_function(numpy.array([dips_array]))[0]
if tentative_cost < threshold_cost:
chain.append((tentative_cost, dips_array))
chain.append((numpy.squeeze(tentative_cost).item(), dips_array))
current = dips_array
current_cost = tentative_cost
else:
chain.append((current_cost, current))
chain.append((numpy.squeeze(current_cost).item(), current))
return chain

View File

@ -2,52 +2,52 @@
# name: test_log_spaced_fixedxy_orientation_mcmc_basic
list([
tuple(
3984.461796564996,
3984.461796565,
array([[ 9.55610128, 2.94634152, 0. , 9.21529051, -2.46576127,
2.42481096, 9.19034554]]),
),
tuple(
8583.990878715194,
8583.9908787152,
array([[ 9.99991539, 0.04113671, 0. , 8.71258954, -2.26599865,
2.60452102, 6.37042214]]),
),
tuple(
6215.637661601595,
6215.6376616016,
array([[ 9.81950685, -1.89137124, 0. , 8.90637055, -2.48043039,
2.28444435, 8.84239221]]),
),
tuple(
424.73328465980165,
424.7332846598,
array([[ 1.00028483, 9.94984574, 0. , 8.53064898, -2.59230757,
2.33774773, 8.6714416 ]]),
),
tuple(
300.9220380848663,
300.9220380849,
array([[ 1.4003442 , 9.90146636, 0. , 8.05557992, -2.6753126 ,
2.65915755, 13.02021385]]),
),
tuple(
2400.010727708547,
2400.0107277085,
array([[ 9.97761813, 0.66868263, 0. , 8.69171028, -2.73145011,
2.90140456, 19.94999593]]),
),
tuple(
5001.462051130342,
5001.4620511303,
array([[ 9.93976109, -1.09596962, 0. , 8.95245025, -2.59409162,
2.90140456, 9.75535945]]),
),
tuple(
195.21980744877803,
195.2198074488,
array([[ 0.20690762, 9.99785923, 0. , 9.59636585, -2.83240984,
2.90140456, 16.14771567]]),
),
tuple(
2698.258844497963,
2698.258844498,
array([[-9.68130127, -2.50447712, 0. , 8.94823619, -2.92889659,
2.77065328, 13.63173263]]),
),
tuple(
1193.698547394381,
1193.6985473944,
array([[-6.16597091, -7.87278875, 0. , 9.62210721, -2.75993924,
2.77065328, 5.64553534]]),
),

View File

@ -2,52 +2,52 @@
# name: test_log_spaced_free_orientation_mcmc_basic
list([
tuple(
array([3167.67112687]),
3167.6711268743,
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
2.11809123, 16.17452242]]),
),
tuple(
array([3167.67112687]),
3167.6711268743,
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
2.11809123, 16.17452242]]),
),
tuple(
array([3167.67112687]),
3167.6711268743,
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
2.11809123, 16.17452242]]),
),
tuple(
736.0306527136138,
736.0306527136,
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
2.43615641, 12.92143144]]),
),
tuple(
736.0306527136138,
736.0306527136,
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
2.43615641, 12.92143144]]),
),
tuple(
736.0306527136138,
736.0306527136,
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
2.43615641, 12.92143144]]),
),
tuple(
2248.0779986277157,
2248.0779986277,
array([[-1.71755535, -5.59925137, 8.10545419, -4.03306318, -1.81098441,
2.77407111, 32.28020575]]),
),
tuple(
1663.3106727359873,
1663.310672736,
array([[-5.16785855, 2.7558756 , 8.10545419, -3.34620897, -1.74763642,
2.42770463, 52.98214008]]),
),
tuple(
1329.2704143918077,
1329.2704143918,
array([[ -1.39600464, 9.69718343, -2.00394725, -2.59147366,
-1.91246681, 2.07361175, 123.01833742]]),
),
tuple(
355.769559189747,
355.7695591897,
array([[ 9.76047401, 0.84696075, -2.00394725, -3.04310053,
-1.99338573, 2.1185589 , 271.35743739]]),
),

View File

@ -2,52 +2,52 @@
# name: test_log_spaced_fixed_orientation_mcmc_basic
list([
tuple(
array([50.56831193]),
50.5683119299,
array([[ 0. , 0. , 10. , -2.3960853 , 4.23246234,
2.26169242, 39.39900844]]),
),
tuple(
array([50.56831193]),
50.5683119299,
array([[ 0. , 0. , 10. , -2.3960853 , 4.23246234,
2.26169242, 39.39900844]]),
),
tuple(
47.4086545539552,
47.408654554,
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
2.21309317, 47.82371559]]),
),
tuple(
47.4086545539552,
47.408654554,
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
2.21309317, 47.82371559]]),
),
tuple(
47.4086545539552,
47.408654554,
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
2.21309317, 47.82371559]]),
),
tuple(
47.4086545539552,
47.408654554,
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
2.21309317, 47.82371559]]),
),
tuple(
22.9327902846994,
22.9327902847,
array([[ 0. , 0. , 10. , -1.63019717, 3.97041764,
2.53115835, 38.2051999 ]]),
),
tuple(
28.811977332207675,
28.8119773322,
array([[ 0. , 0. , 10. , -1.14570315, 4.07709911,
2.48697441, 49.58615195]]),
),
tuple(
28.811977332207675,
28.8119773322,
array([[ 0. , 0. , 10. , -1.14570315, 4.07709911,
2.48697441, 49.58615195]]),
),
tuple(
40.97406005434711,
40.9740600543,
array([[ 0. , 0. , 10. , -0.50178755, 3.83878089,
2.93560796, 82.07827571]]),
),

View File

@ -95,4 +95,6 @@ def test_log_spaced_fixedxy_orientation_mcmc_basic(snapshot):
rng_arg=numpy.random.default_rng(1515),
)
assert chain == snapshot
chain_rounded = [(round(cost, 10), dipoles) for (cost, dipoles) in chain]
assert chain_rounded == snapshot

View File

@ -95,4 +95,6 @@ def test_log_spaced_free_orientation_mcmc_basic(snapshot):
rng_arg=numpy.random.default_rng(1515),
)
assert chain == snapshot
chain_rounded = [(round(cost, 10), dipoles) for (cost, dipoles) in chain]
assert chain_rounded == snapshot

View File

@ -103,4 +103,6 @@ def test_log_spaced_fixed_orientation_mcmc_basic(snapshot):
rng_arg=numpy.random.default_rng(1515),
)
assert chain == snapshot
chain_rounded = [(round(cost, 10), dipoles) for (cost, dipoles) in chain]
assert chain_rounded == snapshot