feat: adds a filter that works with cost functions

This commit is contained in:
Deepak Mallubhotla 2024-05-19 02:26:00 -05:00
parent 72791f2d0f
commit 8845b2875f
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
2 changed files with 57 additions and 0 deletions

View File

@ -0,0 +1,19 @@
from deepdog.direct_monte_carlo.direct_mc import DirectMonteCarloFilter
from typing import Sequence, Callable
import numpy
class CostFunctionTargetFilter(DirectMonteCarloFilter):
def __init__(self, cost_function: Callable[[numpy.ndarray], numpy.ndarray], target_cost: float):
"""
Filters dipoles by cost, only leaving dipoles with cost below target_cost
"""
self.cost_function = cost_function
self.target_cost = target_cost
def filter_samples(self, samples: numpy.ndarray) -> numpy.ndarray:
current_sample = samples
costs = self.cost_function(current_sample)
current_sample = current_sample[costs < self.target_cost]
return current_sample

View File

@ -0,0 +1,38 @@
import deepdog.direct_monte_carlo.cost_function_filter
import numpy
def test_px_cost_function_filter_example():
dipoles_1 = [
[1, 2, 3, 4, 5, 6, 7],
[2, 3, 2, 5, 4, 7, 6],
]
dipoles_2 = [
[15, 9, 8, 7, 6, 5, 3],
[30, 4, 4, 7, 3, 1, 4],
]
dipoleses = numpy.array([dipoles_1, dipoles_2])
def cost_function(dipoleses: numpy.ndarray) -> numpy.ndarray:
return dipoleses[:, :, 0].max(axis=-1)
expected_costs = numpy.array([2, 30])
numpy.testing.assert_array_equal(cost_function(dipoleses), expected_costs)
filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 5)
actual_filtered = filter.filter_samples(dipoleses)
expected_filtered = numpy.array([dipoles_1])
assert actual_filtered.size != 0
numpy.testing.assert_array_equal(actual_filtered, expected_filtered)
filter_stricter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 0.5)
actual_filtered_stricter = filter_stricter.filter_samples(dipoleses)
assert actual_filtered_stricter.size == 0