From 8845b2875f2c91c91dd3988fabda26400c59b2d7 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 19 May 2024 02:26:00 -0500 Subject: [PATCH] feat: adds a filter that works with cost functions --- .../cost_function_filter.py | 19 ++++++++++ .../test_cost_function_filter.py | 38 +++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 deepdog/direct_monte_carlo/cost_function_filter.py create mode 100644 tests/direct_monte_carlo/test_cost_function_filter.py diff --git a/deepdog/direct_monte_carlo/cost_function_filter.py b/deepdog/direct_monte_carlo/cost_function_filter.py new file mode 100644 index 0000000..dec937e --- /dev/null +++ b/deepdog/direct_monte_carlo/cost_function_filter.py @@ -0,0 +1,19 @@ +from deepdog.direct_monte_carlo.direct_mc import DirectMonteCarloFilter +from typing import Sequence, Callable +import numpy + +class CostFunctionTargetFilter(DirectMonteCarloFilter): + def __init__(self, cost_function: Callable[[numpy.ndarray], numpy.ndarray], target_cost: float): + """ + Filters dipoles by cost, only leaving dipoles with cost below target_cost + """ + self.cost_function = cost_function + self.target_cost = target_cost + + def filter_samples(self, samples: numpy.ndarray) -> numpy.ndarray: + current_sample = samples + + costs = self.cost_function(current_sample) + + current_sample = current_sample[costs < self.target_cost] + return current_sample diff --git a/tests/direct_monte_carlo/test_cost_function_filter.py b/tests/direct_monte_carlo/test_cost_function_filter.py new file mode 100644 index 0000000..82e3864 --- /dev/null +++ b/tests/direct_monte_carlo/test_cost_function_filter.py @@ -0,0 +1,38 @@ +import deepdog.direct_monte_carlo.cost_function_filter +import numpy + +def test_px_cost_function_filter_example(): + + dipoles_1 = [ + [1, 2, 3, 4, 5, 6, 7], + [2, 3, 2, 5, 4, 7, 6], + ] + + dipoles_2 = [ + [15, 9, 8, 7, 6, 5, 3], + [30, 4, 4, 7, 3, 1, 4], + ] + + dipoleses = numpy.array([dipoles_1, dipoles_2]) + + def cost_function(dipoleses: numpy.ndarray) -> numpy.ndarray: + return dipoleses[:, :, 0].max(axis=-1) + + + expected_costs = numpy.array([2, 30]) + + numpy.testing.assert_array_equal(cost_function(dipoleses), expected_costs) + + filter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 5) + + actual_filtered = filter.filter_samples(dipoleses) + expected_filtered = numpy.array([dipoles_1]) + assert actual_filtered.size != 0 + numpy.testing.assert_array_equal(actual_filtered, expected_filtered) + + + + filter_stricter = deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(cost_function, 0.5) + + actual_filtered_stricter = filter_stricter.filter_samples(dipoleses) + assert actual_filtered_stricter.size == 0