diff --git a/deepdog/subset_simulation/subset_simulation_impl.py b/deepdog/subset_simulation/subset_simulation_impl.py index c630fbf..11d6d22 100644 --- a/deepdog/subset_simulation/subset_simulation_impl.py +++ b/deepdog/subset_simulation/subset_simulation_impl.py @@ -1,9 +1,11 @@ import logging +import multiprocessing import numpy import pdme.measurement import pdme.measurement.input_types +import pdme.model import pdme.subspace_simulation -from typing import Sequence, Tuple, Optional +from typing import Sequence, Tuple, Optional, Callable, Union, List from dataclasses import dataclass @@ -18,20 +20,32 @@ class SubsetSimulationResult: under_target_cost: Optional[float] under_target_likelihood: Optional[float] lowest_likelihood: Optional[float] + messages: Sequence[str] + + +@dataclass +class MultiSubsetSimulationResult: + child_results: Sequence[SubsetSimulationResult] + model_name: str + estimated_likelihood: float + arithmetic_mean_estimated_likelihood: float + num_children: int + num_finished_children: int + clean_estimate: bool class SubsetSimulation: def __init__( self, model_name_pair, - dot_inputs, - actual_measurements: Sequence[pdme.measurement.DotMeasurement], + # actual_measurements: Sequence[pdme.measurement.DotMeasurement], + cost_function: Callable[[numpy.ndarray], numpy.ndarray], n_c: int, n_s: int, m_max: int, target_cost: Optional[float] = None, - level_0_seed: int = 200, - mcmc_seed: int = 20, + level_0_seed: Union[int, Sequence[int]] = 200, + mcmc_seed: Union[int, Sequence[int]] = 20, use_adaptive_steps=True, default_phi_step=0.01, default_theta_step=0.01, @@ -41,24 +55,26 @@ class SubsetSimulation: keep_probs_list=True, dump_last_generation_to_file=False, initial_cost_chunk_size=100, + cap_core_count: int = 0, # 0 means cap at num cores - 1 ): name, model = model_name_pair self.model_name = name self.model = model _logger.info(f"got model {self.model_name}") - self.dot_inputs_array = pdme.measurement.input_types.dot_inputs_to_array( - dot_inputs - ) + # dot_inputs = [(meas.r, meas.f) for meas in actual_measurements] + # self.dot_inputs_array = pdme.measurement.input_types.dot_inputs_to_array( + # dot_inputs + # ) # _logger.debug(f"actual measurements: {actual_measurements}") - self.actual_measurement_array = numpy.array([m.v for m in actual_measurements]) + # self.actual_measurement_array = numpy.array([m.v for m in actual_measurements]) - def cost_function_to_use(dipoles_to_test): - return pdme.subspace_simulation.proportional_costs_vs_actual_measurement( - self.dot_inputs_array, self.actual_measurement_array, dipoles_to_test - ) + # def cost_function_to_use(dipoles_to_test): + # return pdme.subspace_simulation.proportional_costs_vs_actual_measurement( + # self.dot_inputs_array, self.actual_measurement_array, dipoles_to_test + # ) - self.cost_function_to_use = cost_function_to_use + self.cost_function_to_use = cost_function self.n_c = n_c self.n_s = n_s @@ -78,6 +94,9 @@ class SubsetSimulation: _logger.info(f"\tn_c: {self.n_c}") _logger.info(f"\tn_s: {self.n_s}") _logger.info(f"\tm: {self.m_max}") + _logger.info(f"\tseeds:") + _logger.info(f"\t\t{mcmc_seed=}") + _logger.info(f"\t\t{level_0_seed=}") _logger.info("let's do level 0...") self.target_cost = target_cost @@ -88,10 +107,27 @@ class SubsetSimulation: self.initial_cost_chunk_size = initial_cost_chunk_size + self.cap_core_count = cap_core_count + + def _single_chain_gen(self, args: Tuple): + threshold_cost, stdevs, rng_seed, (c, s) = args + rng = numpy.random.default_rng(rng_seed) + return 
self.model.get_repeat_counting_mcmc_chain( + s, + self.cost_function_to_use, + self.n_s, + threshold_cost, + stdevs, + initial_cost=c, + rng_arg=rng, + ) + def execute(self) -> SubsetSimulationResult: probs_list = [] + output_messages = [] + sample_dipoles = self.model.get_monte_carlo_dipole_inputs( self.n_c * self.n_s, -1, @@ -106,19 +142,19 @@ class SubsetSimulation: ) for x in range(0, len(sample_dipoles), self.initial_cost_chunk_size): - _logger.debug(f"doing chunk {x}") - raw_costs.extend( + # _logger.debug(f"doing chunk {x}") + raw_costs.append( self.cost_function_to_use( sample_dipoles[x : x + self.initial_cost_chunk_size] ) ) - costs = numpy.array(raw_costs) + costs = numpy.concatenate(raw_costs) - _logger.debug(f"costs: {costs}") + # _logger.debug(f"costs: {costs}") sorted_indexes = costs.argsort()[::-1] - _logger.debug(costs[sorted_indexes]) - _logger.debug(sample_dipoles[sorted_indexes]) + # _logger.debug(costs[sorted_indexes]) + # _logger.debug(sample_dipoles[sorted_indexes]) sorted_costs = costs[sorted_indexes] sorted_dipoles = sample_dipoles[sorted_indexes] @@ -133,112 +169,65 @@ class SubsetSimulation: ) all_chains = list(zip(sorted_costs, all_dipoles)) - mcmc_rng = numpy.random.default_rng(self.mcmc_seed) + long_mcmc_rng = numpy.random.default_rng(self.mcmc_seed) + mcmc_rng_seed_sequence = numpy.random.SeedSequence(self.mcmc_seed) - for i in range(self.m_max): - next_seeds = all_chains[-self.n_c :] + # core count etc. logic here + core_count = multiprocessing.cpu_count() - 1 or 1 + if (self.cap_core_count >= 1) and (self.cap_core_count < core_count): + core_count = self.cap_core_count + _logger.info(f"Using {core_count} cores") - if self.dump_last_generations: - _logger.info("writing out csv file") - next_dipoles_seed_dipoles = numpy.array([n[1] for n in next_seeds]) - for n in range(self.model.n): - _logger.info(f"{next_dipoles_seed_dipoles[:, n].shape}") - numpy.savetxt( - f"generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv", - next_dipoles_seed_dipoles[:, n], - delimiter=",", - ) + with multiprocessing.Pool(core_count) as pool: + for i in range(self.m_max): + next_seeds = all_chains[-self.n_c :] - next_seeds_as_array = numpy.array([s for _, s in next_seeds]) - stdevs = self.get_stdevs_from_arrays(next_seeds_as_array) - _logger.info(f"got stdevs: {stdevs.stdevs}") - all_long_chains = [] - for seed_index, (c, s) in enumerate( - next_seeds[:: len(next_seeds) // 20] - ): - # chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs) - # until new version gotta do - _logger.debug(f"\t{seed_index}: doing long chain on the next seed") - - long_chain = self.model.get_mcmc_chain( - s, - self.cost_function_to_use, - 1000, - threshold_cost, - stdevs, - initial_cost=c, - rng_arg=mcmc_rng, - ) - for _, chained in long_chain: - all_long_chains.append(chained) - all_long_chains_array = numpy.array(all_long_chains) - for n in range(self.model.n): - _logger.info(f"{all_long_chains_array[:, n].shape}") - numpy.savetxt( - f"long_chain_generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv", - all_long_chains_array[:, n], - delimiter=",", - ) - - if self.keep_probs_list: - for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]): - probs_list.append( - ( - ((self.n_c * self.n_s - cost_index) / (self.n_c * self.n_s)) - / (self.n_s ** (i)), - cost_chain[0], - i + 1, + if self.dump_last_generations: + _logger.info("writing out csv file") + next_dipoles_seed_dipoles = numpy.array([n[1] for n in next_seeds]) + for n in 
range(self.model.n): + _logger.info(f"{next_dipoles_seed_dipoles[:, n].shape}") + numpy.savetxt( + f"generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv", + next_dipoles_seed_dipoles[:, n], + delimiter=",", ) - ) - next_seeds_as_array = numpy.array([s for _, s in next_seeds]) + next_seeds_as_array = numpy.array([s for _, s in next_seeds]) + stdevs = self.get_stdevs_from_arrays(next_seeds_as_array) + _logger.info(f"got stdevs: {stdevs.stdevs}") + all_long_chains = [] + for seed_index, (c, s) in enumerate( + next_seeds[:: len(next_seeds) // 20] + ): + # chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs) + # until new version gotta do + _logger.debug( + f"\t{seed_index}: doing long chain on the next seed" + ) - stdevs = self.get_stdevs_from_arrays(next_seeds_as_array) - _logger.info(f"got stdevs: {stdevs.stdevs}") - _logger.debug("Starting the MCMC") - all_chains = [] - for seed_index, (c, s) in enumerate(next_seeds): - # chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs) - # until new version gotta do - _logger.debug( - f"\t{seed_index}: getting another chain from the next seed" - ) - chain = self.model.get_mcmc_chain( - s, - self.cost_function_to_use, - self.n_s, - threshold_cost, - stdevs, - initial_cost=c, - rng_arg=mcmc_rng, - ) - for cost, chained in chain: - try: - filtered_cost = cost[0] - except (IndexError, TypeError): - filtered_cost = cost - all_chains.append((filtered_cost, chained)) - _logger.debug("finished mcmc") - # _logger.debug(all_chains) + long_chain = self.model.get_mcmc_chain( + s, + self.cost_function_to_use, + 1000, + threshold_cost, + stdevs, + initial_cost=c, + rng_arg=long_mcmc_rng, + ) + for _, chained in long_chain: + all_long_chains.append(chained) + all_long_chains_array = numpy.array(all_long_chains) + for n in range(self.model.n): + _logger.info(f"{all_long_chains_array[:, n].shape}") + numpy.savetxt( + f"long_chain_generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv", + all_long_chains_array[:, n], + delimiter=",", + ) - all_chains.sort(key=lambda c: c[0], reverse=True) - _logger.debug("finished sorting all_chains") - - threshold_cost = all_chains[-self.n_c][0] - _logger.info( - f"current threshold cost: {threshold_cost}, at P = (1 / {self.n_s})^{i + 1}" - ) - if (self.target_cost is not None) and (threshold_cost < self.target_cost): - _logger.info( - f"got a threshold cost {threshold_cost}, less than {self.target_cost}. 
will leave early" - ) - - cost_list = [c[0] for c in all_chains] - over_index = reverse_bisect_right(cost_list, self.target_cost) - - shorter_probs_list = [] - for cost_index, cost_chain in enumerate(all_chains): - if self.keep_probs_list: + if self.keep_probs_list: + for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]): probs_list.append( ( ( @@ -250,26 +239,121 @@ class SubsetSimulation: i + 1, ) ) - shorter_probs_list.append( - ( - cost_chain[0], - ((self.n_c * self.n_s - cost_index) / (self.n_c * self.n_s)) - / (self.n_s ** (i)), - ) - ) - # _logger.info(shorter_probs_list) - result = SubsetSimulationResult( - probs_list=probs_list, - over_target_cost=shorter_probs_list[over_index - 1][0], - over_target_likelihood=shorter_probs_list[over_index - 1][1], - under_target_cost=shorter_probs_list[over_index][0], - under_target_likelihood=shorter_probs_list[over_index][1], - lowest_likelihood=shorter_probs_list[-1][1], - ) - return result - # _logger.debug([c[0] for c in all_chains[-n_c:]]) - _logger.info(f"doing level {i + 1}") + next_seeds_as_array = numpy.array([s for _, s in next_seeds]) + + stdevs = self.get_stdevs_from_arrays(next_seeds_as_array) + _logger.debug(f"got stdevs, begin: {stdevs.stdevs[:10]}") + _logger.debug("Starting the MCMC") + all_chains = [] + + seeds = mcmc_rng_seed_sequence.spawn(len(next_seeds)) + pool_results = pool.imap_unordered( + self._single_chain_gen, + [ + (threshold_cost, stdevs, rng_seed, test_seed) + for rng_seed, test_seed in zip(seeds, next_seeds) + ], + chunksize=50, + ) + + # count for ergodicity analysis + samples_generated = 0 + samples_rejected = 0 + + for rejected_count, chain in pool_results: + for cost, chained in chain: + try: + filtered_cost = cost[0] + except (IndexError, TypeError): + filtered_cost = cost + all_chains.append((filtered_cost, chained)) + + samples_generated += self.n_s + samples_rejected += rejected_count + + # for seed_index, (c, s) in enumerate(next_seeds): + # # chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs) + # # until new version gotta do + # _logger.debug( + # f"\t{seed_index}: getting another chain from the next seed" + # ) + # rejected_count, chain = self.model.get_repeat_counting_mcmc_chain( + # s, + # self.cost_function_to_use, + # self.n_s, + # threshold_cost, + # stdevs, + # initial_cost=c, + # rng_arg=mcmc_rng, + # ) + + _logger.debug("finished mcmc") + _logger.debug(f"{samples_rejected=} out of {samples_generated=}") + if samples_rejected * 2 > samples_generated: + reject_ratio = samples_rejected / samples_generated + rejectionmessage = f"On level {i}, rejected {samples_rejected} out of {samples_generated}, {reject_ratio=} is too high and may indicate ergodicity problems" + output_messages.append(rejectionmessage) + _logger.warning(rejectionmessage) + # _logger.debug(all_chains) + + all_chains.sort(key=lambda c: c[0], reverse=True) + _logger.debug("finished sorting all_chains") + + threshold_cost = all_chains[-self.n_c][0] + _logger.info( + f"current threshold cost: {threshold_cost}, at P = (1 / {self.n_s})^{i + 1}" + ) + if (self.target_cost is not None) and ( + threshold_cost < self.target_cost + ): + _logger.info( + f"got a threshold cost {threshold_cost}, less than {self.target_cost}. 
will leave early" + ) + + cost_list = [c[0] for c in all_chains] + over_index = reverse_bisect_right(cost_list, self.target_cost) + + winner = all_chains[over_index][1] + _logger.info(f"Winner obtained: {winner}") + shorter_probs_list = [] + for cost_index, cost_chain in enumerate(all_chains): + if self.keep_probs_list: + probs_list.append( + ( + ( + (self.n_c * self.n_s - cost_index) + / (self.n_c * self.n_s) + ) + / (self.n_s ** (i)), + cost_chain[0], + i + 1, + ) + ) + shorter_probs_list.append( + ( + cost_chain[0], + ( + (self.n_c * self.n_s - cost_index) + / (self.n_c * self.n_s) + ) + / (self.n_s ** (i)), + ) + ) + # _logger.info(shorter_probs_list) + result = SubsetSimulationResult( + probs_list=probs_list, + over_target_cost=shorter_probs_list[over_index - 1][0], + over_target_likelihood=shorter_probs_list[over_index - 1][1], + under_target_cost=shorter_probs_list[over_index][0], + under_target_likelihood=shorter_probs_list[over_index][1], + lowest_likelihood=shorter_probs_list[-1][1], + messages=output_messages, + ) + return result + + # _logger.debug([c[0] for c in all_chains[-n_c:]]) + _logger.info(f"doing level {i + 1}") if self.keep_probs_list: for cost_index, cost_chain in enumerate(all_chains): @@ -300,6 +384,7 @@ class SubsetSimulation: under_target_cost=None, under_target_likelihood=None, lowest_likelihood=min_likelihood, + messages=output_messages, ) return result @@ -358,6 +443,112 @@ class SubsetSimulation: return stdevs +class MultiSubsetSimulations: + def __init__( + self, + model_name_pairs: Sequence[Tuple[str, pdme.model.DipoleModel]], + # actual_measurements: Sequence[pdme.measurement.DotMeasurement], + cost_function: Callable[[numpy.ndarray], numpy.ndarray], + num_runs: int, + n_c: int, + n_s: int, + m_max: int, + target_cost: float, + level_0_seed_seed: int = 200, + mcmc_seed_seed: int = 20, + use_adaptive_steps=True, + default_phi_step=0.01, + default_theta_step=0.01, + default_r_step=0.01, + default_w_log_step=0.01, + default_upper_w_log_step=4, + initial_cost_chunk_size=100, + cap_core_count: int = 0, # 0 means cap at num cores - 1 + ): + self.model_name_pairs = model_name_pairs + self.cost_function = cost_function + self.num_runs = num_runs + self.n_c = n_c + self.n_s = n_s + self.m_max = m_max + self.target_cost = target_cost # This is not optional here! 
+
+        self.level_0_seed_seed = level_0_seed_seed
+        self.mcmc_seed_seed = mcmc_seed_seed
+
+        self.use_adaptive_steps = use_adaptive_steps
+        self.default_phi_step = default_phi_step
+        self.default_theta_step = default_theta_step
+        self.default_r_step = default_r_step
+        self.default_w_log_step = default_w_log_step
+        self.default_upper_w_log_step = default_upper_w_log_step
+        self.initial_cost_chunk_size = initial_cost_chunk_size
+        self.cap_core_count = cap_core_count
+
+    def execute(self) -> Sequence[MultiSubsetSimulationResult]:
+        output: List[MultiSubsetSimulationResult] = []
+        for model_name_pair in self.model_name_pairs:
+            ss_results = [
+                SubsetSimulation(
+                    model_name_pair,
+                    self.cost_function,
+                    self.n_c,
+                    self.n_s,
+                    self.m_max,
+                    self.target_cost,
+                    level_0_seed=[run_index, self.level_0_seed_seed],
+                    mcmc_seed=[run_index, self.mcmc_seed_seed],
+                    use_adaptive_steps=self.use_adaptive_steps,
+                    default_phi_step=self.default_phi_step,
+                    default_theta_step=self.default_theta_step,
+                    default_r_step=self.default_r_step,
+                    default_w_log_step=self.default_w_log_step,
+                    default_upper_w_log_step=self.default_upper_w_log_step,
+                    keep_probs_list=False,
+                    dump_last_generation_to_file=False,
+                    initial_cost_chunk_size=self.initial_cost_chunk_size,
+                    cap_core_count=self.cap_core_count,
+                ).execute()
+                for run_index in range(self.num_runs)
+            ]
+            output.append(coalesce_ss_results(model_name_pair[0], ss_results))
+        return output
+
+
+def coalesce_ss_results(
+    model_name: str, results: Sequence[SubsetSimulationResult]
+) -> MultiSubsetSimulationResult:
+
+    num_finished = sum(1 for res in results if res.under_target_likelihood is not None)
+
+    estimated_likelihoods = numpy.array(
+        [
+            res.under_target_likelihood
+            if res.under_target_likelihood is not None
+            else res.lowest_likelihood
+            for res in results
+        ]
+    )
+
+    _logger.warning(estimated_likelihoods)
+    geometric_mean_estimated_likelihoods = numpy.exp(
+        numpy.log(estimated_likelihoods).mean()
+    )
+    _logger.warning(geometric_mean_estimated_likelihoods)
+    arithmetic_mean_estimated_likelihoods = estimated_likelihoods.mean()
+
+    result = MultiSubsetSimulationResult(
+        child_results=results,
+        model_name=model_name,
+        estimated_likelihood=geometric_mean_estimated_likelihoods,
+        arithmetic_mean_estimated_likelihood=arithmetic_mean_estimated_likelihoods,
+        num_children=len(results),
+        num_finished_children=num_finished,
+        clean_estimate=num_finished == len(results),
+    )
+    return result
+
+
 def reverse_bisect_right(a, x, lo=0, hi=None):
     """Return the index where to insert item x in list a, assuming a is sorted in descending order.
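Note on the parallelism introduced in subset_simulation_impl.py above: each level's seed states are farmed out to a multiprocessing.Pool via imap_unordered, and every chain gets its own RNG by spawning child seeds from a single numpy.random.SeedSequence built from mcmc_seed, so worker processes never share a random stream. Below is a minimal standalone sketch of that seeding pattern, not the deepdog API; _worker is a hypothetical stand-in for SubsetSimulation._single_chain_gen.

import multiprocessing

import numpy


def _worker(args):
    # Stand-in for _single_chain_gen: build a fresh Generator from the
    # spawned seed and do this chain's work with it.
    rng_seed, chain_index = args
    rng = numpy.random.default_rng(rng_seed)
    return chain_index, rng.random()


if __name__ == "__main__":
    mcmc_seed = 20
    seed_sequence = numpy.random.SeedSequence(mcmc_seed)

    n_chains = 8
    # spawn() yields statistically independent child seeds, one per chain.
    child_seeds = seed_sequence.spawn(n_chains)

    # Same core-count heuristic as the diff: all cores but one, minimum one.
    core_count = multiprocessing.cpu_count() - 1 or 1
    with multiprocessing.Pool(core_count) as pool:
        results = list(
            pool.imap_unordered(_worker, zip(child_seeds, range(n_chains)))
        )
    print(results)

Because the spawned children are independent, the per-level results stay reproducible from the single mcmc_seed even though imap_unordered returns chains in whatever order the workers finish.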
diff --git a/tests/subset_simulation/__snapshots__/test_subset_simulation_coalescing.ambr b/tests/subset_simulation/__snapshots__/test_subset_simulation_coalescing.ambr
new file mode 100644
index 0000000..6700216
--- /dev/null
+++ b/tests/subset_simulation/__snapshots__/test_subset_simulation_coalescing.ambr
@@ -0,0 +1,10 @@
+# serializer version: 1
+# name: test_subset_simulation_multi_result_coalescing_easy_arithmetic
+  MultiSubsetSimulationResult(child_results=[SubsetSimulationResult(probs_list=(), over_target_cost=1, over_target_likelihood=1, under_target_cost=0.99, under_target_likelihood=0.8, lowest_likelihood=0.5, messages=[]), SubsetSimulationResult(probs_list=(), over_target_cost=1, over_target_likelihood=1, under_target_cost=0.99, under_target_likelihood=0.6, lowest_likelihood=0.01, messages=[])], model_name='test', estimated_likelihood=0.6928203230275509, arithmetic_mean_estimated_likelihood=0.7, num_children=2, num_finished_children=2, clean_estimate=True)
+# ---
+# name: test_subset_simulation_multi_result_coalescing_easy_geometric
+  MultiSubsetSimulationResult(child_results=[SubsetSimulationResult(probs_list=(), over_target_cost=1, over_target_likelihood=1, under_target_cost=0.99, under_target_likelihood=0.1, lowest_likelihood=0.5, messages=[]), SubsetSimulationResult(probs_list=(), over_target_cost=1, over_target_likelihood=1, under_target_cost=0.99, under_target_likelihood=0.001, lowest_likelihood=0.01, messages=[])], model_name='test', estimated_likelihood=0.010000000000000004, arithmetic_mean_estimated_likelihood=0.0505, num_children=2, num_finished_children=2, clean_estimate=True)
+# ---
+# name: test_subset_simulation_multi_result_coalescing_include_dirty
+  MultiSubsetSimulationResult(child_results=[SubsetSimulationResult(probs_list=(), over_target_cost=1, over_target_likelihood=1, under_target_cost=0.99, under_target_likelihood=0.8, lowest_likelihood=0.5, messages=[]), SubsetSimulationResult(probs_list=(), over_target_cost=1, over_target_likelihood=1, under_target_cost=0.99, under_target_likelihood=0.08, lowest_likelihood=0.01, messages=[]), SubsetSimulationResult(probs_list=(), over_target_cost=None, over_target_likelihood=None, under_target_cost=None, under_target_likelihood=None, lowest_likelihood=0.0001, messages=[])], model_name='test', estimated_likelihood=0.01856635533445112, arithmetic_mean_estimated_likelihood=0.29336666666666666, num_children=3, num_finished_children=2, clean_estimate=False)
+# ---
diff --git a/tests/subset_simulation/test_subset_simulation_coalescing.py b/tests/subset_simulation/test_subset_simulation_coalescing.py
new file mode 100644
index 0000000..b79158d
--- /dev/null
+++ b/tests/subset_simulation/test_subset_simulation_coalescing.py
@@ -0,0 +1,92 @@
+import deepdog.subset_simulation.subset_simulation_impl as impl
+import numpy
+
+
+def test_subset_simulation_multi_result_coalescing_include_dirty(snapshot):
+    res1 = impl.SubsetSimulationResult(
+        probs_list=(),
+        over_target_cost=1,
+        over_target_likelihood=1,
+        under_target_cost=0.99,
+        under_target_likelihood=0.8,
+        lowest_likelihood=0.5,
+        messages=[],
+    )
+
+    res2 = impl.SubsetSimulationResult(
+        probs_list=(),
+        over_target_cost=1,
+        over_target_likelihood=1,
+        under_target_cost=0.99,
+        under_target_likelihood=0.08,
+        lowest_likelihood=0.01,
+        messages=[],
+    )
+
+    res3 = impl.SubsetSimulationResult(
+        probs_list=(),
+        over_target_cost=None,
+        over_target_likelihood=None,
+        under_target_cost=None,
+        under_target_likelihood=None,
+        lowest_likelihood=0.0001,
+        messages=[],
+    )
+
+    combined = impl.coalesce_ss_results("test", [res1, res2, res3])
+
+    assert combined == snapshot
+
+
+def test_subset_simulation_multi_result_coalescing_easy_arithmetic(snapshot):
+    res1 = impl.SubsetSimulationResult(
+        probs_list=(),
+        over_target_cost=1,
+        over_target_likelihood=1,
+        under_target_cost=0.99,
+        under_target_likelihood=0.8,
+        lowest_likelihood=0.5,
+        messages=[],
+    )
+
+    res2 = impl.SubsetSimulationResult(
+        probs_list=(),
+        over_target_cost=1,
+        over_target_likelihood=1,
+        under_target_cost=0.99,
+        under_target_likelihood=0.6,
+        lowest_likelihood=0.01,
+        messages=[],
+    )
+
+    combined = impl.coalesce_ss_results("test", [res1, res2])
+
+    assert combined.arithmetic_mean_estimated_likelihood == 0.7
+    assert combined == snapshot
+
+
+def test_subset_simulation_multi_result_coalescing_easy_geometric(snapshot):
+    res1 = impl.SubsetSimulationResult(
+        probs_list=(),
+        over_target_cost=1,
+        over_target_likelihood=1,
+        under_target_cost=0.99,
+        under_target_likelihood=0.1,
+        lowest_likelihood=0.5,
+        messages=[],
+    )
+
+    res2 = impl.SubsetSimulationResult(
+        probs_list=(),
+        over_target_cost=1,
+        over_target_likelihood=1,
+        under_target_cost=0.99,
+        under_target_likelihood=0.001,
+        lowest_likelihood=0.01,
+        messages=[],
+    )
+
+    combined = impl.coalesce_ss_results("test", [res1, res2])
+
+    numpy.testing.assert_allclose(combined.estimated_likelihood, 0.01)
+    assert combined == snapshot
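For reference, the snapshot values exercised above follow directly from the coalescing arithmetic in coalesce_ss_results: estimated_likelihood is the geometric mean of each child run's under_target_likelihood (with lowest_likelihood substituted for children that never reached the target cost, which also clears clean_estimate), and the arithmetic mean is reported alongside. A quick standalone check with plain numpy, using the same numbers as the two "easy" tests:

import numpy

# Finished child runs from the "easy" tests above.
arithmetic_case = numpy.array([0.8, 0.6])
geometric_case = numpy.array([0.1, 0.001])


def geometric_mean(likelihoods):
    # Same computation as coalesce_ss_results: take the mean in log space.
    return numpy.exp(numpy.log(likelihoods).mean())


print(geometric_mean(arithmetic_case))  # ~0.6928, snapshot estimated_likelihood
print(arithmetic_case.mean())           # 0.7, snapshot arithmetic mean
print(geometric_mean(geometric_case))   # ~0.01, snapshot estimated_likelihood
print(geometric_case.mean())            # 0.0505, snapshot arithmetic mean

The geometric mean is the natural average for likelihoods spanning orders of magnitude: the 0.1 and 0.001 case collapses to 0.01 instead of being dominated by the single larger run.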