feat: allows use of new event-based method for time series, despite it being slow
All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good
All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good
This commit is contained in:
parent
062dbf6216
commit
4623af69e7
@ -6,6 +6,7 @@ import tantri.cli.input_files.write_dipoles
|
||||
import tantri.cli.file_importer
|
||||
import json
|
||||
import tantri.dipoles
|
||||
import tantri.dipoles.event_time_series
|
||||
import pathlib
|
||||
|
||||
|
||||
@ -97,6 +98,9 @@ def cli(log, log_file):
|
||||
required=True,
|
||||
)
|
||||
@click.option("--header-row/--no-header-row", default=False, help="Write a header row")
|
||||
@click.option(
|
||||
"--event-based/--no-event-based", default=False, help="Use new event-based method"
|
||||
)
|
||||
def write_time_series(
|
||||
dipoles_file,
|
||||
dots_file,
|
||||
@ -106,6 +110,7 @@ def write_time_series(
|
||||
time_series_rng_seed,
|
||||
output_file,
|
||||
header_row,
|
||||
event_based,
|
||||
):
|
||||
"""
|
||||
Generate a time series for the passed in parameters.
|
||||
@ -119,6 +124,7 @@ def write_time_series(
|
||||
time_series_rng_seed,
|
||||
output_file,
|
||||
header_row,
|
||||
event_based,
|
||||
)
|
||||
|
||||
|
||||
@ -131,6 +137,7 @@ def _write_time_series(
|
||||
time_series_rng_seed,
|
||||
output_file,
|
||||
header_row,
|
||||
new_method,
|
||||
):
|
||||
_logger.debug(
|
||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||
@ -156,6 +163,25 @@ def _write_time_series(
|
||||
f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}"
|
||||
)
|
||||
|
||||
if new_method:
|
||||
_logger.info("Using new method")
|
||||
_logger.debug(f"Got seed {time_series_rng_seed}")
|
||||
if time_series_rng_seed is None:
|
||||
time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries(
|
||||
dipoles, dots, measurement_enum, delta_t, num_iterations
|
||||
)
|
||||
else:
|
||||
rng = numpy.random.default_rng(time_series_rng_seed)
|
||||
time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries(
|
||||
dipoles, dots, measurement_enum, delta_t, num_iterations, rng
|
||||
)
|
||||
output_series = time_series.create_time_series()
|
||||
for time, time_series_dict in output_series:
|
||||
values = ", ".join(str(time_series_dict[label]) for label in labels)
|
||||
out.write(f"{time}, {values}\n")
|
||||
|
||||
else:
|
||||
# in the old method
|
||||
_logger.debug(f"Got seed {time_series_rng_seed}")
|
||||
if time_series_rng_seed is None:
|
||||
time_series = tantri.dipoles.DipoleTimeSeries(
|
||||
@ -169,7 +195,9 @@ def _write_time_series(
|
||||
|
||||
for i in range(num_iterations):
|
||||
transition = time_series.transition()
|
||||
transition_values = ", ".join(str(transition[label]) for label in labels)
|
||||
transition_values = ", ".join(
|
||||
str(transition[label]) for label in labels
|
||||
)
|
||||
out.write(f"{i * delta_t}, {transition_values}\n")
|
||||
|
||||
|
||||
|
@ -112,7 +112,6 @@ class EventDipoleTimeSeries:
|
||||
_logger.debug(f"Doing dipole {dipole}")
|
||||
series = dipole.get_time_series(self.dt, self.num_samples, self.rng)
|
||||
for time, meases in series:
|
||||
_logger.debug(f"Working on time {time}, got measures {meases}")
|
||||
if time in collected_dictionary:
|
||||
for k, v in meases.items():
|
||||
|
||||
@ -133,8 +132,6 @@ def get_num_events_before(
|
||||
event_times: List = []
|
||||
random_size = max(1, int(total_time // scale))
|
||||
while sum(event_times) < total_time:
|
||||
_logger.debug(sum(event_times))
|
||||
_logger.debug(event_times)
|
||||
event_times.extend(rng.exponential(scale=scale, size=random_size))
|
||||
|
||||
accumulator = 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user