feat!: separates threshold cost and the seed_cost in mcmc
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
This commit is contained in:
parent
dc43e4bfbc
commit
ca710e359f
@ -47,6 +47,7 @@ class DipoleModel:
|
||||
seed,
|
||||
cost_function,
|
||||
chain_length,
|
||||
threshold_cost: float,
|
||||
stdevs: pdme.subspace_simulation.MCMCStandardDeviation,
|
||||
initial_cost: Optional[float] = None,
|
||||
rng_arg: Optional[numpy.random.Generator] = None,
|
||||
@ -73,10 +74,9 @@ class DipoleModel:
|
||||
chain: List[Tuple[float, numpy.ndarray]] = []
|
||||
current = seed
|
||||
if initial_cost is None:
|
||||
cost_to_compare = cost_function(current)
|
||||
current_cost = cost_function(current)
|
||||
else:
|
||||
cost_to_compare = initial_cost
|
||||
current_cost = cost_to_compare
|
||||
current_cost = initial_cost
|
||||
for i in range(chain_length):
|
||||
dips = []
|
||||
for dipole_index, dipole in enumerate(current):
|
||||
@ -90,7 +90,7 @@ class DipoleModel:
|
||||
dips
|
||||
)
|
||||
tentative_cost = cost_function(dips_array)
|
||||
if tentative_cost < cost_to_compare:
|
||||
if tentative_cost < threshold_cost:
|
||||
chain.append((tentative_cost, dips_array))
|
||||
current = dips_array
|
||||
current_cost = tentative_cost
|
||||
|
@ -84,7 +84,7 @@ def test_log_spaced_fixedxy_orientation_mcmc_basic(snapshot):
|
||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||
|
||||
chain = model.get_mcmc_chain(
|
||||
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||
)
|
||||
|
||||
assert chain == snapshot
|
||||
|
@ -84,7 +84,7 @@ def test_log_spaced_free_orientation_mcmc_basic(snapshot):
|
||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||
|
||||
chain = model.get_mcmc_chain(
|
||||
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||
)
|
||||
|
||||
assert chain == snapshot
|
||||
|
@ -92,7 +92,7 @@ def test_log_spaced_fixed_orientation_mcmc_basic(snapshot):
|
||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||
|
||||
chain = model.get_mcmc_chain(
|
||||
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||
)
|
||||
|
||||
assert chain == snapshot
|
||||
|
Loading…
x
Reference in New Issue
Block a user