From 5d0a7a4be09c58f8f8f859384f01d7912a98b8b9 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sat, 7 May 2022 18:45:58 -0500 Subject: [PATCH] feat!: bayes run now handles multidipoles with changes to output file format etc. --- deepdog/bayes_run.py | 150 ++++++++++++++++++------------------------- 1 file changed, 64 insertions(+), 86 deletions(-) diff --git a/deepdog/bayes_run.py b/deepdog/bayes_run.py index 83d2350..2348acb 100644 --- a/deepdog/bayes_run.py +++ b/deepdog/bayes_run.py @@ -23,9 +23,11 @@ _logger = logging.getLogger(__name__) def get_a_result(input) -> int: - model, dot_inputs, lows, highs, monte_carlo_count, max_frequency = input - sample_dipoles = model.get_model().get_n_single_dipoles( - monte_carlo_count, max_frequency + model, dot_inputs, lows, highs, monte_carlo_count, max_frequency, seed = input + + rng = numpy.random.default_rng(seed) + sample_dipoles = model.get_monte_carlo_dipole_inputs( + monte_carlo_count, max_frequency, rng_to_use=rng ) vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles) return numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs)) @@ -88,8 +90,6 @@ class BayesRun: run_count: int = 100, low_error: float = 0.9, high_error: float = 1.1, - pairs_high_error=None, - pairs_low_error=None, monte_carlo_count: int = 10000, monte_carlo_cycles: int = 10, target_success: int = 100, @@ -97,7 +97,6 @@ class BayesRun: max_frequency: float = 20, end_threshold: float = None, chunksize: int = CHUNKSIZE, - use_pairs: bool = False, ) -> None: self.dot_inputs = pdme.inputs.inputs_with_frequency_range( dot_positions, frequency_range @@ -106,18 +105,16 @@ class BayesRun: self.dot_inputs ) - self.use_pairs = use_pairs - - self.dot_pair_inputs = pdme.inputs.input_pairs_with_frequency_range( - dot_positions, frequency_range - ) - self.dot_pair_inputs_array = ( - pdme.measurement.input_types.dot_pair_inputs_to_array(self.dot_pair_inputs) - ) - self.models = [model for (_, model) in models_with_names] self.model_names = [name for (name, _) in models_with_names] self.actual_model = actual_model + + self.n: int + try: + self.n = self.actual_model.n # type: ignore + except AttributeError: + self.n = 1 + self.model_count = len(self.models) self.monte_carlo_count = monte_carlo_count self.monte_carlo_cycles = monte_carlo_cycles @@ -126,15 +123,16 @@ class BayesRun: self.run_count = run_count self.low_error = low_error self.high_error = high_error - if pairs_low_error is None: - self.pairs_low_error = self.low_error - else: - self.pairs_low_error = pairs_low_error - if pairs_high_error is None: - self.pairs_high_error = self.high_error - else: - self.pairs_high_error = pairs_high_error - self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"] + + self.csv_fields = [] + for i in range(self.n): + self.csv_fields.extend( + [ + f"dipole_moment_{i+1}", + f"dipole_location_{i+1}", + f"dipole_frequency_{i+1}", + ] + ) self.compensate_zeros = True self.chunksize = chunksize for name in self.model_names: @@ -143,10 +141,7 @@ class BayesRun: self.probabilities = [1 / self.model_count] * self.model_count timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - if self.use_pairs: - self.filename = f"{timestamp}-{filename_slug}.altbayes.pairs.csv" - else: - self.filename = f"{timestamp}-{filename_slug}.altbayes.csv" + self.filename = f"{timestamp}-{filename_slug}.bayesrun.csv" self.max_frequency = max_frequency if end_threshold is not None: @@ -179,29 +174,17 @@ class BayesRun: dots ) - pair_lows, pair_highs = (None, None) - if self.use_pairs: - pair_measurements = ( - actual_dipoles.get_percent_range_dot_pair_measurements( - self.dot_pair_inputs, - self.pairs_low_error, - self.pairs_high_error, - ) - ) - ( - pair_lows, - pair_highs, - ) = pdme.measurement.input_types.dot_range_measurements_low_high_arrays( - pair_measurements - ) - _logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}") + # define a new seed sequence for each run + seed_sequence = numpy.random.SeedSequence(run) + results = [] _logger.debug("Going to iterate over models now") for model_count, model in enumerate(self.models): _logger.debug(f"Doing model #{model_count}") - with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool: + core_count = multiprocessing.cpu_count() - 1 or 1 + with multiprocessing.Pool(core_count) as pool: cycle_count = 0 cycle_success = 0 cycles = 0 @@ -212,55 +195,50 @@ class BayesRun: cycles += 1 current_success = 0 cycle_count += self.monte_carlo_count * self.monte_carlo_cycles - if self.use_pairs: - current_success = sum( - pool.imap_unordered( - get_a_result_using_pairs, - [ - ( - model, - self.dot_inputs_array, - self.dot_pair_inputs_array, - lows, - highs, - pair_lows, - pair_highs, - self.monte_carlo_count, - self.max_frequency, - ) - ] - * self.monte_carlo_cycles, - self.chunksize, - ) - ) - else: - current_success = sum( - pool.imap_unordered( - get_a_result, - [ - ( - model, - self.dot_inputs_array, - lows, - highs, - self.monte_carlo_count, - self.max_frequency, - ) - ] - * self.monte_carlo_cycles, - self.chunksize, - ) + + # generate a seed from the sequence for each core. + # note this needs to be inside the loop for monte carlo cycle steps! + # that way we get more stuff. + seeds = seed_sequence.spawn(self.monte_carlo_cycles) + + current_success = sum( + pool.imap_unordered( + get_a_result, + [ + ( + model, + self.dot_inputs_array, + lows, + highs, + self.monte_carlo_count, + self.max_frequency, + seed, + ) + for seed in seeds + ], + self.chunksize, ) + ) cycle_success += current_success results.append((cycle_count, cycle_success)) _logger.debug("Done, constructing output now") row = { - "dipole_moment": actual_dipoles.dipoles[0].p, - "dipole_location": actual_dipoles.dipoles[0].s, - "dipole_frequency": actual_dipoles.dipoles[0].w, + "dipole_moment_1": actual_dipoles.dipoles[0].p, + "dipole_location_1": actual_dipoles.dipoles[0].s, + "dipole_frequency_1": actual_dipoles.dipoles[0].w, } + for i in range(1, self.n): + try: + current_dipoles = actual_dipoles.dipoles[i] + row[f"dipole_moment_{i+1}"] = current_dipoles.p + row[f"dipole_location_{i+1}"] = current_dipoles.s + row[f"dipole_frequency_{i+1}"] = current_dipoles.w + except IndexError: + _logger.info(f"Not writing anymore, saw end after {i}") + break + successes: List[float] = [] counts: List[int] = [] for model_index, (name, (count, result)) in enumerate(