From 8d04803eb341c63d25b347b011e6a0b709c9438f Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 19 May 2024 02:29:59 -0500 Subject: [PATCH] fmt: formatting, nicer log, removing comment --- .../cost_function_filter.py | 11 +++++++--- .../subset_simulation_impl.py | 21 ++----------------- .../test_cost_function_filter.py | 14 ++++++++----- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/deepdog/direct_monte_carlo/cost_function_filter.py b/deepdog/direct_monte_carlo/cost_function_filter.py index dec937e..c9e4d0f 100644 --- a/deepdog/direct_monte_carlo/cost_function_filter.py +++ b/deepdog/direct_monte_carlo/cost_function_filter.py @@ -1,9 +1,14 @@ from deepdog.direct_monte_carlo.direct_mc import DirectMonteCarloFilter -from typing import Sequence, Callable +from typing import Callable import numpy + class CostFunctionTargetFilter(DirectMonteCarloFilter): - def __init__(self, cost_function: Callable[[numpy.ndarray], numpy.ndarray], target_cost: float): + def __init__( + self, + cost_function: Callable[[numpy.ndarray], numpy.ndarray], + target_cost: float, + ): """ Filters dipoles by cost, only leaving dipoles with cost below target_cost """ @@ -12,7 +17,7 @@ class CostFunctionTargetFilter(DirectMonteCarloFilter): def filter_samples(self, samples: numpy.ndarray) -> numpy.ndarray: current_sample = samples - + costs = self.cost_function(current_sample) current_sample = current_sample[costs < self.target_cost] diff --git a/deepdog/subset_simulation/subset_simulation_impl.py b/deepdog/subset_simulation/subset_simulation_impl.py index 11d6d22..9344645 100644 --- a/deepdog/subset_simulation/subset_simulation_impl.py +++ b/deepdog/subset_simulation/subset_simulation_impl.py @@ -94,9 +94,8 @@ class SubsetSimulation: _logger.info(f"\tn_c: {self.n_c}") _logger.info(f"\tn_s: {self.n_s}") _logger.info(f"\tm: {self.m_max}") - _logger.info(f"\tseeds:") - _logger.info(f"\t\t{mcmc_seed=}") - _logger.info(f"\t\t{level_0_seed=}") + _logger.info(f"\t{mcmc_seed=}") + _logger.info(f"\t{level_0_seed=}") _logger.info("let's do level 0...") self.target_cost = target_cost @@ -272,22 +271,6 @@ class SubsetSimulation: samples_generated += self.n_s samples_rejected += rejected_count - # for seed_index, (c, s) in enumerate(next_seeds): - # # chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs) - # # until new version gotta do - # _logger.debug( - # f"\t{seed_index}: getting another chain from the next seed" - # ) - # rejected_count, chain = self.model.get_repeat_counting_mcmc_chain( - # s, - # self.cost_function_to_use, - # self.n_s, - # threshold_cost, - # stdevs, - # initial_cost=c, - # rng_arg=mcmc_rng, - # ) - _logger.debug("finished mcmc") _logger.debug(f"{samples_rejected=} out of {samples_generated=}") if samples_rejected * 2 > samples_generated: diff --git a/tests/direct_monte_carlo/test_cost_function_filter.py b/tests/direct_monte_carlo/test_cost_function_filter.py index 82e3864..2e6482c 100644 --- a/tests/direct_monte_carlo/test_cost_function_filter.py +++ b/tests/direct_monte_carlo/test_cost_function_filter.py @@ -1,6 +1,7 @@ import deepdog.direct_monte_carlo.cost_function_filter import numpy + def test_px_cost_function_filter_example(): dipoles_1 = [ @@ -18,21 +19,24 @@ def test_px_cost_function_filter_example(): def cost_function(dipoleses: numpy.ndarray) -> numpy.ndarray: return dipoleses[:, :, 0].max(axis=-1) - expected_costs = numpy.array([2, 30]) numpy.testing.assert_array_equal(cost_function(dipoleses), expected_costs) - filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 5) + filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter( + cost_function, 5 + ) actual_filtered = filter.filter_samples(dipoleses) expected_filtered = numpy.array([dipoles_1]) assert actual_filtered.size != 0 numpy.testing.assert_array_equal(actual_filtered, expected_filtered) - - - filter_stricter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 0.5) + filter_stricter = ( + deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter( + cost_function, 0.5 + ) + ) actual_filtered_stricter = filter_stricter.filter_samples(dipoleses) assert actual_filtered_stricter.size == 0