feat!: separates threshold cost and the seed_cost in mcmc
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good

This commit is contained in:
Deepak Mallubhotla 2023-07-23 20:31:53 -05:00
parent dc43e4bfbc
commit ca710e359f
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
4 changed files with 7 additions and 7 deletions

View File

@ -47,6 +47,7 @@ class DipoleModel:
seed,
cost_function,
chain_length,
threshold_cost: float,
stdevs: pdme.subspace_simulation.MCMCStandardDeviation,
initial_cost: Optional[float] = None,
rng_arg: Optional[numpy.random.Generator] = None,
@ -73,10 +74,9 @@ class DipoleModel:
chain: List[Tuple[float, numpy.ndarray]] = []
current = seed
if initial_cost is None:
cost_to_compare = cost_function(current)
current_cost = cost_function(current)
else:
cost_to_compare = initial_cost
current_cost = cost_to_compare
current_cost = initial_cost
for i in range(chain_length):
dips = []
for dipole_index, dipole in enumerate(current):
@ -90,7 +90,7 @@ class DipoleModel:
dips
)
tentative_cost = cost_function(dips_array)
if tentative_cost < cost_to_compare:
if tentative_cost < threshold_cost:
chain.append((tentative_cost, dips_array))
current = dips_array
current_cost = tentative_cost

View File

@ -84,7 +84,7 @@ def test_log_spaced_fixedxy_orientation_mcmc_basic(snapshot):
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
chain = model.get_mcmc_chain(
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
)
assert chain == snapshot

View File

@ -84,7 +84,7 @@ def test_log_spaced_free_orientation_mcmc_basic(snapshot):
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
chain = model.get_mcmc_chain(
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
)
assert chain == snapshot

View File

@ -92,7 +92,7 @@ def test_log_spaced_fixed_orientation_mcmc_basic(snapshot):
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
chain = model.get_mcmc_chain(
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
)
assert chain == snapshot