feat!: separates threshold cost and the seed_cost in mcmc
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good

This commit is contained in:
Deepak Mallubhotla 2023-07-23 20:31:53 -05:00
parent dc43e4bfbc
commit ca710e359f
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
4 changed files with 7 additions and 7 deletions

View File

@ -47,6 +47,7 @@ class DipoleModel:
seed, seed,
cost_function, cost_function,
chain_length, chain_length,
threshold_cost: float,
stdevs: pdme.subspace_simulation.MCMCStandardDeviation, stdevs: pdme.subspace_simulation.MCMCStandardDeviation,
initial_cost: Optional[float] = None, initial_cost: Optional[float] = None,
rng_arg: Optional[numpy.random.Generator] = None, rng_arg: Optional[numpy.random.Generator] = None,
@ -73,10 +74,9 @@ class DipoleModel:
chain: List[Tuple[float, numpy.ndarray]] = [] chain: List[Tuple[float, numpy.ndarray]] = []
current = seed current = seed
if initial_cost is None: if initial_cost is None:
cost_to_compare = cost_function(current) current_cost = cost_function(current)
else: else:
cost_to_compare = initial_cost current_cost = initial_cost
current_cost = cost_to_compare
for i in range(chain_length): for i in range(chain_length):
dips = [] dips = []
for dipole_index, dipole in enumerate(current): for dipole_index, dipole in enumerate(current):
@ -90,7 +90,7 @@ class DipoleModel:
dips dips
) )
tentative_cost = cost_function(dips_array) tentative_cost = cost_function(dips_array)
if tentative_cost < cost_to_compare: if tentative_cost < threshold_cost:
chain.append((tentative_cost, dips_array)) chain.append((tentative_cost, dips_array))
current = dips_array current = dips_array
current_cost = tentative_cost current_cost = tentative_cost

View File

@ -84,7 +84,7 @@ def test_log_spaced_fixedxy_orientation_mcmc_basic(snapshot):
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev]) stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
chain = model.get_mcmc_chain( chain = model.get_mcmc_chain(
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515) seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
) )
assert chain == snapshot assert chain == snapshot

View File

@ -84,7 +84,7 @@ def test_log_spaced_free_orientation_mcmc_basic(snapshot):
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev]) stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
chain = model.get_mcmc_chain( chain = model.get_mcmc_chain(
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515) seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
) )
assert chain == snapshot assert chain == snapshot

View File

@ -92,7 +92,7 @@ def test_log_spaced_fixed_orientation_mcmc_basic(snapshot):
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev]) stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
chain = model.get_mcmc_chain( chain = model.get_mcmc_chain(
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515) seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
) )
assert chain == snapshot assert chain == snapshot