From 4623af69e7971eb72fdd6509d28ca41d8125a906 Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Mon, 1 Jul 2024 03:11:26 +0000 Subject: [PATCH] feat: allows use of new event-based method for time series, despite it being slow --- tantri/cli/__init__.py | 56 +++++++++++++++++++++-------- tantri/dipoles/event_time_series.py | 3 -- 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/tantri/cli/__init__.py b/tantri/cli/__init__.py index 19ca99a..3e454a1 100755 --- a/tantri/cli/__init__.py +++ b/tantri/cli/__init__.py @@ -6,6 +6,7 @@ import tantri.cli.input_files.write_dipoles import tantri.cli.file_importer import json import tantri.dipoles +import tantri.dipoles.event_time_series import pathlib @@ -97,6 +98,9 @@ def cli(log, log_file): required=True, ) @click.option("--header-row/--no-header-row", default=False, help="Write a header row") +@click.option( + "--event-based/--no-event-based", default=False, help="Use new event-based method" +) def write_time_series( dipoles_file, dots_file, @@ -106,6 +110,7 @@ def write_time_series( time_series_rng_seed, output_file, header_row, + event_based, ): """ Generate a time series for the passed in parameters. @@ -119,6 +124,7 @@ def write_time_series( time_series_rng_seed, output_file, header_row, + event_based, ) @@ -131,6 +137,7 @@ def _write_time_series( time_series_rng_seed, output_file, header_row, + new_method, ): _logger.debug( f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]" @@ -156,21 +163,42 @@ def _write_time_series( f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}" ) - _logger.debug(f"Got seed {time_series_rng_seed}") - if time_series_rng_seed is None: - time_series = tantri.dipoles.DipoleTimeSeries( - dipoles, dots, measurement_enum, delta_t - ) - else: - rng = numpy.random.default_rng(time_series_rng_seed) - time_series = tantri.dipoles.DipoleTimeSeries( - dipoles, dots, measurement_enum, delta_t, rng - ) + if new_method: + _logger.info("Using new method") + _logger.debug(f"Got seed {time_series_rng_seed}") + if time_series_rng_seed is None: + time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries( + dipoles, dots, measurement_enum, delta_t, num_iterations + ) + else: + rng = numpy.random.default_rng(time_series_rng_seed) + time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries( + dipoles, dots, measurement_enum, delta_t, num_iterations, rng + ) + output_series = time_series.create_time_series() + for time, time_series_dict in output_series: + values = ", ".join(str(time_series_dict[label]) for label in labels) + out.write(f"{time}, {values}\n") - for i in range(num_iterations): - transition = time_series.transition() - transition_values = ", ".join(str(transition[label]) for label in labels) - out.write(f"{i * delta_t}, {transition_values}\n") + else: + # in the old method + _logger.debug(f"Got seed {time_series_rng_seed}") + if time_series_rng_seed is None: + time_series = tantri.dipoles.DipoleTimeSeries( + dipoles, dots, measurement_enum, delta_t + ) + else: + rng = numpy.random.default_rng(time_series_rng_seed) + time_series = tantri.dipoles.DipoleTimeSeries( + dipoles, dots, measurement_enum, delta_t, rng + ) + + for i in range(num_iterations): + transition = time_series.transition() + transition_values = ", ".join( + str(transition[label]) for label in labels + ) + out.write(f"{i * delta_t}, {transition_values}\n") @cli.command() diff --git a/tantri/dipoles/event_time_series.py b/tantri/dipoles/event_time_series.py index 673d263..000ccda 100644 --- a/tantri/dipoles/event_time_series.py +++ b/tantri/dipoles/event_time_series.py @@ -112,7 +112,6 @@ class EventDipoleTimeSeries: _logger.debug(f"Doing dipole {dipole}") series = dipole.get_time_series(self.dt, self.num_samples, self.rng) for time, meases in series: - _logger.debug(f"Working on time {time}, got measures {meases}") if time in collected_dictionary: for k, v in meases.items(): @@ -133,8 +132,6 @@ def get_num_events_before( event_times: List = [] random_size = max(1, int(total_time // scale)) while sum(event_times) < total_time: - _logger.debug(sum(event_times)) - _logger.debug(event_times) event_times.extend(rng.exponential(scale=scale, size=random_size)) accumulator = 0