feat: adds a filter that works with cost functions
This commit is contained in:
parent
72791f2d0f
commit
8845b2875f
19
deepdog/direct_monte_carlo/cost_function_filter.py
Normal file
19
deepdog/direct_monte_carlo/cost_function_filter.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from deepdog.direct_monte_carlo.direct_mc import DirectMonteCarloFilter
|
||||||
|
from typing import Sequence, Callable
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
class CostFunctionTargetFilter(DirectMonteCarloFilter):
|
||||||
|
def __init__(self, cost_function: Callable[[numpy.ndarray], numpy.ndarray], target_cost: float):
|
||||||
|
"""
|
||||||
|
Filters dipoles by cost, only leaving dipoles with cost below target_cost
|
||||||
|
"""
|
||||||
|
self.cost_function = cost_function
|
||||||
|
self.target_cost = target_cost
|
||||||
|
|
||||||
|
def filter_samples(self, samples: numpy.ndarray) -> numpy.ndarray:
|
||||||
|
current_sample = samples
|
||||||
|
|
||||||
|
costs = self.cost_function(current_sample)
|
||||||
|
|
||||||
|
current_sample = current_sample[costs < self.target_cost]
|
||||||
|
return current_sample
|
38
tests/direct_monte_carlo/test_cost_function_filter.py
Normal file
38
tests/direct_monte_carlo/test_cost_function_filter.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import deepdog.direct_monte_carlo.cost_function_filter
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
def test_px_cost_function_filter_example():
|
||||||
|
|
||||||
|
dipoles_1 = [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7],
|
||||||
|
[2, 3, 2, 5, 4, 7, 6],
|
||||||
|
]
|
||||||
|
|
||||||
|
dipoles_2 = [
|
||||||
|
[15, 9, 8, 7, 6, 5, 3],
|
||||||
|
[30, 4, 4, 7, 3, 1, 4],
|
||||||
|
]
|
||||||
|
|
||||||
|
dipoleses = numpy.array([dipoles_1, dipoles_2])
|
||||||
|
|
||||||
|
def cost_function(dipoleses: numpy.ndarray) -> numpy.ndarray:
|
||||||
|
return dipoleses[:, :, 0].max(axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
expected_costs = numpy.array([2, 30])
|
||||||
|
|
||||||
|
numpy.testing.assert_array_equal(cost_function(dipoleses), expected_costs)
|
||||||
|
|
||||||
|
filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 5)
|
||||||
|
|
||||||
|
actual_filtered = filter.filter_samples(dipoleses)
|
||||||
|
expected_filtered = numpy.array([dipoles_1])
|
||||||
|
assert actual_filtered.size != 0
|
||||||
|
numpy.testing.assert_array_equal(actual_filtered, expected_filtered)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
filter_stricter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 0.5)
|
||||||
|
|
||||||
|
actual_filtered_stricter = filter_stricter.filter_samples(dipoleses)
|
||||||
|
assert actual_filtered_stricter.size == 0
|
Loading…
x
Reference in New Issue
Block a user