fmt: formatting, nicer log, removing comment
This commit is contained in:
parent
92b49fce7c
commit
8d04803eb3
@ -1,9 +1,14 @@
|
||||
from deepdog.direct_monte_carlo.direct_mc import DirectMonteCarloFilter
|
||||
from typing import Sequence, Callable
|
||||
from typing import Callable
|
||||
import numpy
|
||||
|
||||
|
||||
class CostFunctionTargetFilter(DirectMonteCarloFilter):
|
||||
def __init__(self, cost_function: Callable[[numpy.ndarray], numpy.ndarray], target_cost: float):
|
||||
def __init__(
|
||||
self,
|
||||
cost_function: Callable[[numpy.ndarray], numpy.ndarray],
|
||||
target_cost: float,
|
||||
):
|
||||
"""
|
||||
Filters dipoles by cost, only leaving dipoles with cost below target_cost
|
||||
"""
|
||||
@ -12,7 +17,7 @@ class CostFunctionTargetFilter(DirectMonteCarloFilter):
|
||||
|
||||
def filter_samples(self, samples: numpy.ndarray) -> numpy.ndarray:
|
||||
current_sample = samples
|
||||
|
||||
|
||||
costs = self.cost_function(current_sample)
|
||||
|
||||
current_sample = current_sample[costs < self.target_cost]
|
||||
|
@ -94,9 +94,8 @@ class SubsetSimulation:
|
||||
_logger.info(f"\tn_c: {self.n_c}")
|
||||
_logger.info(f"\tn_s: {self.n_s}")
|
||||
_logger.info(f"\tm: {self.m_max}")
|
||||
_logger.info(f"\tseeds:")
|
||||
_logger.info(f"\t\t{mcmc_seed=}")
|
||||
_logger.info(f"\t\t{level_0_seed=}")
|
||||
_logger.info(f"\t{mcmc_seed=}")
|
||||
_logger.info(f"\t{level_0_seed=}")
|
||||
_logger.info("let's do level 0...")
|
||||
|
||||
self.target_cost = target_cost
|
||||
@ -272,22 +271,6 @@ class SubsetSimulation:
|
||||
samples_generated += self.n_s
|
||||
samples_rejected += rejected_count
|
||||
|
||||
# for seed_index, (c, s) in enumerate(next_seeds):
|
||||
# # chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
|
||||
# # until new version gotta do
|
||||
# _logger.debug(
|
||||
# f"\t{seed_index}: getting another chain from the next seed"
|
||||
# )
|
||||
# rejected_count, chain = self.model.get_repeat_counting_mcmc_chain(
|
||||
# s,
|
||||
# self.cost_function_to_use,
|
||||
# self.n_s,
|
||||
# threshold_cost,
|
||||
# stdevs,
|
||||
# initial_cost=c,
|
||||
# rng_arg=mcmc_rng,
|
||||
# )
|
||||
|
||||
_logger.debug("finished mcmc")
|
||||
_logger.debug(f"{samples_rejected=} out of {samples_generated=}")
|
||||
if samples_rejected * 2 > samples_generated:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import deepdog.direct_monte_carlo.cost_function_filter
|
||||
import numpy
|
||||
|
||||
|
||||
def test_px_cost_function_filter_example():
|
||||
|
||||
dipoles_1 = [
|
||||
@ -18,21 +19,24 @@ def test_px_cost_function_filter_example():
|
||||
def cost_function(dipoleses: numpy.ndarray) -> numpy.ndarray:
|
||||
return dipoleses[:, :, 0].max(axis=-1)
|
||||
|
||||
|
||||
expected_costs = numpy.array([2, 30])
|
||||
|
||||
numpy.testing.assert_array_equal(cost_function(dipoleses), expected_costs)
|
||||
|
||||
filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 5)
|
||||
filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, 5
|
||||
)
|
||||
|
||||
actual_filtered = filter.filter_samples(dipoleses)
|
||||
expected_filtered = numpy.array([dipoles_1])
|
||||
assert actual_filtered.size != 0
|
||||
numpy.testing.assert_array_equal(actual_filtered, expected_filtered)
|
||||
|
||||
|
||||
|
||||
filter_stricter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 0.5)
|
||||
filter_stricter = (
|
||||
deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, 0.5
|
||||
)
|
||||
)
|
||||
|
||||
actual_filtered_stricter = filter_stricter.filter_samples(dipoleses)
|
||||
assert actual_filtered_stricter.size == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user