feat: allows use of new event-based method for time series, despite it being slow
All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good
All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good
This commit is contained in:
parent
062dbf6216
commit
4623af69e7
@ -6,6 +6,7 @@ import tantri.cli.input_files.write_dipoles
|
|||||||
import tantri.cli.file_importer
|
import tantri.cli.file_importer
|
||||||
import json
|
import json
|
||||||
import tantri.dipoles
|
import tantri.dipoles
|
||||||
|
import tantri.dipoles.event_time_series
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
|
|
||||||
@ -97,6 +98,9 @@ def cli(log, log_file):
|
|||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
@click.option("--header-row/--no-header-row", default=False, help="Write a header row")
|
@click.option("--header-row/--no-header-row", default=False, help="Write a header row")
|
||||||
|
@click.option(
|
||||||
|
"--event-based/--no-event-based", default=False, help="Use new event-based method"
|
||||||
|
)
|
||||||
def write_time_series(
|
def write_time_series(
|
||||||
dipoles_file,
|
dipoles_file,
|
||||||
dots_file,
|
dots_file,
|
||||||
@ -106,6 +110,7 @@ def write_time_series(
|
|||||||
time_series_rng_seed,
|
time_series_rng_seed,
|
||||||
output_file,
|
output_file,
|
||||||
header_row,
|
header_row,
|
||||||
|
event_based,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate a time series for the passed in parameters.
|
Generate a time series for the passed in parameters.
|
||||||
@ -119,6 +124,7 @@ def write_time_series(
|
|||||||
time_series_rng_seed,
|
time_series_rng_seed,
|
||||||
output_file,
|
output_file,
|
||||||
header_row,
|
header_row,
|
||||||
|
event_based,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -131,6 +137,7 @@ def _write_time_series(
|
|||||||
time_series_rng_seed,
|
time_series_rng_seed,
|
||||||
output_file,
|
output_file,
|
||||||
header_row,
|
header_row,
|
||||||
|
new_method,
|
||||||
):
|
):
|
||||||
_logger.debug(
|
_logger.debug(
|
||||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||||
@ -156,21 +163,42 @@ def _write_time_series(
|
|||||||
f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}"
|
f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}"
|
||||||
)
|
)
|
||||||
|
|
||||||
_logger.debug(f"Got seed {time_series_rng_seed}")
|
if new_method:
|
||||||
if time_series_rng_seed is None:
|
_logger.info("Using new method")
|
||||||
time_series = tantri.dipoles.DipoleTimeSeries(
|
_logger.debug(f"Got seed {time_series_rng_seed}")
|
||||||
dipoles, dots, measurement_enum, delta_t
|
if time_series_rng_seed is None:
|
||||||
)
|
time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries(
|
||||||
else:
|
dipoles, dots, measurement_enum, delta_t, num_iterations
|
||||||
rng = numpy.random.default_rng(time_series_rng_seed)
|
)
|
||||||
time_series = tantri.dipoles.DipoleTimeSeries(
|
else:
|
||||||
dipoles, dots, measurement_enum, delta_t, rng
|
rng = numpy.random.default_rng(time_series_rng_seed)
|
||||||
)
|
time_series = tantri.dipoles.event_time_series.EventDipoleTimeSeries(
|
||||||
|
dipoles, dots, measurement_enum, delta_t, num_iterations, rng
|
||||||
|
)
|
||||||
|
output_series = time_series.create_time_series()
|
||||||
|
for time, time_series_dict in output_series:
|
||||||
|
values = ", ".join(str(time_series_dict[label]) for label in labels)
|
||||||
|
out.write(f"{time}, {values}\n")
|
||||||
|
|
||||||
for i in range(num_iterations):
|
else:
|
||||||
transition = time_series.transition()
|
# in the old method
|
||||||
transition_values = ", ".join(str(transition[label]) for label in labels)
|
_logger.debug(f"Got seed {time_series_rng_seed}")
|
||||||
out.write(f"{i * delta_t}, {transition_values}\n")
|
if time_series_rng_seed is None:
|
||||||
|
time_series = tantri.dipoles.DipoleTimeSeries(
|
||||||
|
dipoles, dots, measurement_enum, delta_t
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rng = numpy.random.default_rng(time_series_rng_seed)
|
||||||
|
time_series = tantri.dipoles.DipoleTimeSeries(
|
||||||
|
dipoles, dots, measurement_enum, delta_t, rng
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(num_iterations):
|
||||||
|
transition = time_series.transition()
|
||||||
|
transition_values = ", ".join(
|
||||||
|
str(transition[label]) for label in labels
|
||||||
|
)
|
||||||
|
out.write(f"{i * delta_t}, {transition_values}\n")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
@ -112,7 +112,6 @@ class EventDipoleTimeSeries:
|
|||||||
_logger.debug(f"Doing dipole {dipole}")
|
_logger.debug(f"Doing dipole {dipole}")
|
||||||
series = dipole.get_time_series(self.dt, self.num_samples, self.rng)
|
series = dipole.get_time_series(self.dt, self.num_samples, self.rng)
|
||||||
for time, meases in series:
|
for time, meases in series:
|
||||||
_logger.debug(f"Working on time {time}, got measures {meases}")
|
|
||||||
if time in collected_dictionary:
|
if time in collected_dictionary:
|
||||||
for k, v in meases.items():
|
for k, v in meases.items():
|
||||||
|
|
||||||
@ -133,8 +132,6 @@ def get_num_events_before(
|
|||||||
event_times: List = []
|
event_times: List = []
|
||||||
random_size = max(1, int(total_time // scale))
|
random_size = max(1, int(total_time // scale))
|
||||||
while sum(event_times) < total_time:
|
while sum(event_times) < total_time:
|
||||||
_logger.debug(sum(event_times))
|
|
||||||
_logger.debug(event_times)
|
|
||||||
event_times.extend(rng.exponential(scale=scale, size=random_size))
|
event_times.extend(rng.exponential(scale=scale, size=random_size))
|
||||||
|
|
||||||
accumulator = 0
|
accumulator = 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user