def get_repeat_counting_mcmc_chain(
    self,
    seed,
    cost_function,
    chain_length,
    threshold_cost: float,
    stdevs: pdme.subspace_simulation.MCMCStandardDeviation,
    initial_cost: Optional[float] = None,
    rng_arg: Optional[numpy.random.Generator] = None,
) -> Tuple[int, List[Tuple[float, numpy.ndarray]]]:
    """
    Run a threshold-constrained Markov chain Monte Carlo walk and count repeats.

    This is a method of DipoleModel (pdme/model/model.py), hence the self parameter.

    Starting from seed (a list of dipoles), every step proposes a replacement for
    each dipole of the current state with self.markov_chain_monte_carlo_proposal,
    using the entry of stdevs at the same position, then sorts the proposals with
    pdme.subspace_simulation.sort_array_of_dipoles_by_frequency.
    The sorted proposal is accepted only if its cost is strictly below threshold_cost;
    otherwise the current state is recorded again and the repeat count goes up by one.

    cost_function is always called on a batch holding a single state (numpy.array([state])).
    For proposals the first element of its result is used; for the seed the result is
    kept whole and only squeezed to a scalar when it is written into the chain.

    initial_cost lets the caller skip evaluating cost_function on the seed.
    That cost is never compared against the threshold, it is only reported in chain
    entries emitted before the first accepted proposal.

    rng_arg is forwarded unchanged to the proposal function.

    NOTE: per the original author, the adaptive stdevs rely on an unwritten contract that
    the dipoles of a state are sorted by increasing frequency. Proposals are re-sorted
    here, but the seed is used exactly as given.

    Returns (repeat_count, chain), where chain is [ (cost: float, state: dipole_ndarray ) ]
    with exactly chain_length entries and repeat_count is the number of rejected proposals.
    """
    _logger.debug(
        f"Starting Markov Chain Monte Carlo with seed: {seed} for chain length {chain_length} and provided stdevs {stdevs}"
    )

    # Cost attached to the seed; only ever reported, never tested against the threshold.
    state_cost = (
        initial_cost
        if initial_cost is not None
        else cost_function(numpy.array([seed]))
    )
    state = seed
    repeats = 0
    chain: List[Tuple[float, numpy.ndarray]] = []

    for _ in range(chain_length):
        # Perturb every dipole of the current state with its own step size.
        proposals = []
        for position, current_dipole in enumerate(state):
            _logger.debug(position)
            _logger.debug(current_dipole)
            step_size = stdevs[position]
            proposals.append(
                self.markov_chain_monte_carlo_proposal(
                    current_dipole, step_size, rng_arg
                )
            )

        candidate = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
            proposals
        )
        candidate_cost = cost_function(numpy.array([candidate]))[0]

        # Strict "<" on purpose: a NaN cost compares False and is therefore rejected.
        if candidate_cost < threshold_cost:
            state, state_cost = candidate, candidate_cost
        else:
            # Proposal rejected: the previous sample is emitted again.
            repeats += 1

        # Exactly one entry per step, whichever state is current after the decision.
        chain.append((numpy.squeeze(state_cost).item(), state))

    return (repeats, chain)