types: fixes tqdm type hints, man tqdm isn't worth it is it
All checks were successful
gitea-physics/tantri/pipeline/head This commit looks good

This commit is contained in:
Deepak Mallubhotla 2024-04-21 15:58:24 -05:00
parent 6334aae223
commit 4449b71be9
Signed by: deepak
GPG Key ID: BEBAEBF28083E022

View File

@ -8,7 +8,7 @@ import tantri.dipoles
import json
import tantri.dipoles.generation
import pathlib
import tqdm
import tqdm # type: ignore
_logger = logging.getLogger(__name__)
@ -39,8 +39,7 @@ def _set_up_logging(filename):
@click.option("--log-file", help="A filename to use for logging (implies --log)")
@click.version_option(tantri.get_version())
def cli(log, log_file):
"""Utilities for generating simulated TLS time series data.
"""
"""Utilities for generating simulated TLS time series data."""
if log or (log_file is not None):
# log file has been provided, let's log
_set_up_logging(log_file)
@ -98,12 +97,17 @@ def cli(log, log_file):
help="The output file to write, in csv format",
required=True,
)
@click.option(
"--header-row/--no-header-row",
default=False,
help="Write a header row"
)
def write_time_series(dipoles_file, dots_file, measurement_type, delta_t, num_iterations, time_series_rng_seed, output_file, header_row):
@click.option("--header-row/--no-header-row", default=False, help="Write a header row")
def write_time_series(
dipoles_file,
dots_file,
measurement_type,
delta_t,
num_iterations,
time_series_rng_seed,
output_file,
header_row,
):
"""
Generate a time series for the passed in parameters.
"""
@ -126,15 +130,21 @@ def write_time_series(dipoles_file, dots_file, measurement_type, delta_t, num_it
value_labels = ", ".join([f"{value_name}_{label}" for label in labels])
output_file.write(f"t (s), {value_labels}\n")
_logger.debug(f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}")
_logger.debug(
f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}"
)
_logger.debug(f"Got seed {time_series_rng_seed}")
if time_series_rng_seed is None:
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t)
time_series = tantri.dipoles.DipoleTimeSeries(
dipoles, dots, measurement_enum, delta_t
)
else:
rng = numpy.random.default_rng(time_series_rng_seed)
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t, rng)
time_series = tantri.dipoles.DipoleTimeSeries(
dipoles, dots, measurement_enum, delta_t, rng
)
for i in tqdm.trange(num_iterations):
transition = time_series.transition()
transition_values = ", ".join(str(transition[label]) for label in labels)