def get_a_result_fast_filter(input) -> int:
    """
    Count the Monte Carlo dipole samples that pass every measurement window,
    discarding failed samples after each dot input instead of evaluating all
    dot inputs for all samples.

    The single ``input`` argument is a 6-tuple so the function can be handed
    directly to ``pool.imap_unordered``:

    ``(model, dot_inputs, lows, highs, monte_carlo_count, seed)``

    - ``model``: object providing
      ``get_monte_carlo_dipole_inputs(count, max_frequency, rng_to_use=...)``.
    - ``dot_inputs``, ``lows``, ``highs``: equal-length sequences; entry ``i``
      of ``lows``/``highs`` is the exclusive window for ``dot_inputs[i]``.
    - ``monte_carlo_count``: number of samples requested from the model.
    - ``seed``: seed passed to ``numpy.random.default_rng``.

    Returns the number of samples that satisfy ``low < value < high`` for
    every dot input.

    Raises ``ValueError`` if ``dot_inputs``, ``lows`` and ``highs`` differ in
    length.
    """
    model, dot_inputs, lows, highs, monte_carlo_count, seed = input

    # zip() stops at the shortest sequence, which would silently skip filters
    # and over-count the surviving samples, so reject mismatched input early.
    if not (len(dot_inputs) == len(lows) == len(highs)):
        raise ValueError(
            "dot_inputs, lows and highs must have the same length, got "
            f"{len(dot_inputs)}, {len(lows)} and {len(highs)}"
        )

    rng = numpy.random.default_rng(seed)
    # TODO: A long term refactor is to pull the frequency stuff out from here.
    # The None stands for max_frequency, which is unneeded in the actually
    # useful models.
    current_sample = model.get_monte_carlo_dipole_inputs(
        monte_carlo_count, None, rng_to_use=rng
    )

    for di, low, high in zip(dot_inputs, lows, highs):
        # Once nothing survives, the remaining dot inputs cannot change the
        # count, so skip the remaining potential calculations.
        if len(current_sample) < 1:
            break
        vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
            numpy.array([di]), current_sample
        )
        # The mask is reduced over axis 1 and then used to index the samples,
        # so vals is presumably 2-D with one row per sample -- TODO confirm
        # against pdme.
        in_window = numpy.all((vals > low) & (vals < high), axis=1)
        current_sample = current_sample[in_window]
    return len(current_sample)
seeds = seed_sequence.spawn(self.monte_carlo_cycles) + if self.use_fast_filter: + result_func = get_a_result_fast_filter + else: + result_func = get_a_result current_success = sum( pool.imap_unordered( - get_a_result, + result_func, [ ( model,