fmt: formatting, nicer log, removing comment

This commit is contained in:
Deepak Mallubhotla 2024-05-19 02:29:59 -05:00
parent 92b49fce7c
commit 8d04803eb3
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
3 changed files with 19 additions and 27 deletions

View File

@ -1,9 +1,14 @@
from deepdog.direct_monte_carlo.direct_mc import DirectMonteCarloFilter
from typing import Sequence, Callable
from typing import Callable
import numpy
class CostFunctionTargetFilter(DirectMonteCarloFilter):
def __init__(self, cost_function: Callable[[numpy.ndarray], numpy.ndarray], target_cost: float):
def __init__(
self,
cost_function: Callable[[numpy.ndarray], numpy.ndarray],
target_cost: float,
):
"""
Filters dipoles by cost, only leaving dipoles with cost below target_cost
"""

View File

@ -94,9 +94,8 @@ class SubsetSimulation:
_logger.info(f"\tn_c: {self.n_c}")
_logger.info(f"\tn_s: {self.n_s}")
_logger.info(f"\tm: {self.m_max}")
_logger.info(f"\tseeds:")
_logger.info(f"\t\t{mcmc_seed=}")
_logger.info(f"\t\t{level_0_seed=}")
_logger.info(f"\t{mcmc_seed=}")
_logger.info(f"\t{level_0_seed=}")
_logger.info("let's do level 0...")
self.target_cost = target_cost
@ -272,22 +271,6 @@ class SubsetSimulation:
samples_generated += self.n_s
samples_rejected += rejected_count
# for seed_index, (c, s) in enumerate(next_seeds):
# # chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
# # until new version gotta do
# _logger.debug(
# f"\t{seed_index}: getting another chain from the next seed"
# )
# rejected_count, chain = self.model.get_repeat_counting_mcmc_chain(
# s,
# self.cost_function_to_use,
# self.n_s,
# threshold_cost,
# stdevs,
# initial_cost=c,
# rng_arg=mcmc_rng,
# )
_logger.debug("finished mcmc")
_logger.debug(f"{samples_rejected=} out of {samples_generated=}")
if samples_rejected * 2 > samples_generated:

View File

@ -1,6 +1,7 @@
import deepdog.direct_monte_carlo.cost_function_filter
import numpy
def test_px_cost_function_filter_example():
dipoles_1 = [
@ -18,21 +19,24 @@ def test_px_cost_function_filter_example():
def cost_function(dipoleses: numpy.ndarray) -> numpy.ndarray:
return dipoleses[:, :, 0].max(axis=-1)
expected_costs = numpy.array([2, 30])
numpy.testing.assert_array_equal(cost_function(dipoleses), expected_costs)
filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 5)
filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
cost_function, 5
)
actual_filtered = filter.filter_samples(dipoleses)
expected_filtered = numpy.array([dipoles_1])
assert actual_filtered.size != 0
numpy.testing.assert_array_equal(actual_filtered, expected_filtered)
filter_stricter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 0.5)
filter_stricter = (
deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
cost_function, 0.5
)
)
actual_filtered_stricter = filter_stricter.filter_samples(dipoleses)
assert actual_filtered_stricter.size == 0