feat: adds mcmc chain that returns number of repeats
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
This commit is contained in:
@@ -99,3 +99,65 @@ class DipoleModel:
|
|||||||
else:
|
else:
|
||||||
chain.append((numpy.squeeze(current_cost).item(), current))
|
chain.append((numpy.squeeze(current_cost).item(), current))
|
||||||
return chain
|
return chain
|
||||||
|
|
||||||
|
def get_repeat_counting_mcmc_chain(
|
||||||
|
self,
|
||||||
|
seed,
|
||||||
|
cost_function,
|
||||||
|
chain_length,
|
||||||
|
threshold_cost: float,
|
||||||
|
stdevs: pdme.subspace_simulation.MCMCStandardDeviation,
|
||||||
|
initial_cost: Optional[float] = None,
|
||||||
|
rng_arg: Optional[numpy.random.Generator] = None,
|
||||||
|
) -> Tuple[int, List[Tuple[float, numpy.ndarray]]]:
|
||||||
|
"""
|
||||||
|
performs constrained markov chain monte carlo starting on seed parameter.
|
||||||
|
The cost function given is used as a constrained to condition the chain;
|
||||||
|
a new state is only accepted if cost_function(state) < cost_function(previous_state).
|
||||||
|
The stdevs passed in are the stdevs we're expected to use.
|
||||||
|
|
||||||
|
Because we're using this for subspace simulation where our proposal function is not too important, we're in good shape.
|
||||||
|
Note that for our adaptive stdevs to work, there's an unwritten contract that we sort each dipole in the state by frequency (increasing).
|
||||||
|
|
||||||
|
The seed is a list of dipoles, and each chain state is a list of dipoles as well.
|
||||||
|
|
||||||
|
initial_cost is a performance guy that lets you pre-populate the initial cost used to define the condition.
|
||||||
|
Probably premature optimisation.
|
||||||
|
|
||||||
|
Chain has type of [ (cost: float, state: dipole_ndarray ) ] format,
|
||||||
|
returning (repeat_count, chain) to keep track of number of repeats
|
||||||
|
"""
|
||||||
|
_logger.debug(
|
||||||
|
f"Starting Markov Chain Monte Carlo with seed: {seed} for chain length {chain_length} and provided stdevs {stdevs}"
|
||||||
|
)
|
||||||
|
chain: List[Tuple[float, numpy.ndarray]] = []
|
||||||
|
if initial_cost is None:
|
||||||
|
current_cost = cost_function(numpy.array([seed]))
|
||||||
|
else:
|
||||||
|
current_cost = initial_cost
|
||||||
|
current = seed
|
||||||
|
repeat_event_count = 0
|
||||||
|
for _ in range(chain_length):
|
||||||
|
dips = []
|
||||||
|
for dipole_index, dipole in enumerate(current):
|
||||||
|
_logger.debug(dipole_index)
|
||||||
|
_logger.debug(dipole)
|
||||||
|
stdev = stdevs[dipole_index]
|
||||||
|
tentative_dip = self.markov_chain_monte_carlo_proposal(
|
||||||
|
dipole, stdev, rng_arg
|
||||||
|
)
|
||||||
|
|
||||||
|
dips.append(tentative_dip)
|
||||||
|
dips_array = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
|
||||||
|
dips
|
||||||
|
)
|
||||||
|
tentative_cost = cost_function(numpy.array([dips_array]))[0]
|
||||||
|
if tentative_cost < threshold_cost:
|
||||||
|
chain.append((numpy.squeeze(tentative_cost).item(), dips_array))
|
||||||
|
current = dips_array
|
||||||
|
current_cost = tentative_cost
|
||||||
|
else:
|
||||||
|
# repeating a sample, increase count
|
||||||
|
repeat_event_count += 1
|
||||||
|
chain.append((numpy.squeeze(current_cost).item(), current))
|
||||||
|
return (repeat_event_count, chain)
|
||||||
|
|||||||
Reference in New Issue
Block a user