feat: allows use of new event-based method for time series, despite it being slow
All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good

This commit is contained in:
Deepak Mallubhotla 2024-07-01 03:11:26 +00:00
parent 062dbf6216
commit 4623af69e7
Signed by: deepak
GPG Key ID: 8F904A3FC7021497
2 changed files with 42 additions and 17 deletions

View File

@ -6,6 +6,7 @@ import tantri.cli.input_files.write_dipoles
import tantri.cli.file_importer
import json
import tantri.dipoles
import tantri.dipoles.event_time_series
import pathlib
@ -97,6 +98,9 @@ def cli(log, log_file):
required=True,
)
@click.option("--header-row/--no-header-row", default=False, help="Write a header row")
@click.option(
"--event-based/--no-event-based", default=False, help="Use new event-based method"
)
def write_time_series(
dipoles_file,
dots_file,
@ -106,6 +110,7 @@ def write_time_series(
time_series_rng_seed,
output_file,
header_row,
event_based,
):
"""
Generate a time series for the passed in parameters.
@ -119,6 +124,7 @@ def write_time_series(
time_series_rng_seed,
output_file,
header_row,
event_based,
)
@ -131,6 +137,7 @@ def _write_time_series(
time_series_rng_seed,
output_file,
header_row,
new_method,
):
_logger.debug(
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
@ -156,6 +163,25 @@ def _write_time_series(
f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}"
)
if new_method:
_logger.info("Using new method")
_logger.debug(f"Got seed {time_series_rng_seed}")
if time_series_rng_seed is None:
time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries(
dipoles, dots, measurement_enum, delta_t, num_iterations
)
else:
rng = numpy.random.default_rng(time_series_rng_seed)
time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries(
dipoles, dots, measurement_enum, delta_t, num_iterations, rng
)
output_series = time_series.create_time_series()
for time, time_series_dict in output_series:
values = ", ".join(str(time_series_dict[label]) for label in labels)
out.write(f"{time}, {values}\n")
else:
# in the old method
_logger.debug(f"Got seed {time_series_rng_seed}")
if time_series_rng_seed is None:
time_series = tantri.dipoles.DipoleTimeSeries(
@ -169,7 +195,9 @@ def _write_time_series(
for i in range(num_iterations):
transition = time_series.transition()
transition_values = ", ".join(str(transition[label]) for label in labels)
transition_values = ", ".join(
str(transition[label]) for label in labels
)
out.write(f"{i * delta_t}, {transition_values}\n")

View File

@ -112,7 +112,6 @@ class EventDipoleTimeSeries:
_logger.debug(f"Doing dipole {dipole}")
series = dipole.get_time_series(self.dt, self.num_samples, self.rng)
for time, meases in series:
_logger.debug(f"Working on time {time}, got measures {meases}")
if time in collected_dictionary:
for k, v in meases.items():
@ -133,8 +132,6 @@ def get_num_events_before(
event_times: List = []
random_size = max(1, int(total_time // scale))
while sum(event_times) < total_time:
_logger.debug(sum(event_times))
_logger.debug(event_times)
event_times.extend(rng.exponential(scale=scale, size=random_size))
accumulator = 0