diff --git a/poetry.lock b/poetry.lock index 2d92d40..e445af7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -461,6 +461,23 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "tqdm" +version = "4.66.2" +description = "Fast, Extensible Progress Meter" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "traitlets" version = "5.13.0" @@ -492,7 +509,7 @@ python-versions = "*" [metadata] lock-version = "1.1" python-versions = ">=3.8.1,<3.10" -content-hash = "66c63e7a7a61bc34880f47e2568e2407f330c2691dabc0dace9aa15f03a97531" +content-hash = "ba0e7fa989baea42343639637136dd80ce087eb7bfcfb32da50c9ddffc46ad71" [metadata.files] appnope = [] @@ -534,6 +551,7 @@ six = [] stack-data = [] syrupy = [] tomli = [] +tqdm = [] traitlets = [] typing-extensions = [] wcwidth = [] diff --git a/pyproject.toml b/pyproject.toml index 1f1d0ff..b025eef 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ python = ">=3.8.1,<3.10" numpy = "^1.22.3" scipy = "~1.10" click = "^8.1.7" +tqdm = "^4.66.2" [tool.poetry.dev-dependencies] pytest = ">=6" diff --git a/tantri/cli/__init__.py b/tantri/cli/__init__.py index 658e50c..3b0d73b 100755 --- a/tantri/cli/__init__.py +++ b/tantri/cli/__init__.py @@ -4,15 +4,21 @@ import tantri import numpy import tantri.cli.input_files.write_dipoles import tantri.cli.file_importer +import tantri.dipoles import json import tantri.dipoles.generation import pathlib +import tqdm _logger = logging.getLogger(__name__) LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s" +POTENTIAL = "electric-potential" +X_ELECTRIC_FIELD = "x-electric-field" + + def _set_up_logging(filename): handlers = [logging.StreamHandler()] if filename is not None: @@ -53,18 +59,84 @@ def cli(log, log_file): type=click.Path(exists=True, path_type=pathlib.Path), help="File with json array of dots", ) -def read_files(dipoles_file, dots_file): - click.echo("in hello") - _logger.info("in hello") - _logger.info( +@click.option( + "--measurement-type", + type=click.Choice([POTENTIAL, X_ELECTRIC_FIELD]), + default=POTENTIAL, + help="The type of measurement to simulate", + show_default=True, +) +@click.option( + "--delta-t", + "-t", + type=float, + default=1, + help="The delta t between time series iterations.", + show_default=True, +) +@click.option( + "--num-iterations", + "-n", + type=int, + default=10, + help="The number of iterations.", + show_default=True, +) +@click.option( + "--time-series-rng-seed", + "-s", + type=int, + default=None, + help="A seed to use to create an override default rng. You should set this.", +) +@click.option( + "output_file", + "-o", + type=click.File("w"), + help="The output file to write, in csv format", + required=True, +) +@click.option( + "--header-row/--no-header-row", + default=False, + help="Write a header row" +) +def write_time_series(dipoles_file, dots_file, measurement_type, delta_t, num_iterations, time_series_rng_seed, output_file, header_row): + """ + Generate a time series for the passed in parameters. + """ + _logger.debug( f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]" ) dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file) dots = tantri.cli.file_importer.read_dots_json_file(dots_file) - for dipole in dipoles: - _logger.info(dipole) - for dot in dots: - _logger.info(dot) + + if measurement_type == POTENTIAL: + measurement_enum = tantri.dipoles.DipoleMeasurementType.ELECTRIC_POTENTIAL + value_name = "V" + elif measurement_type == X_ELECTRIC_FIELD: + measurement_enum = tantri.dipoles.DipoleMeasurementType.X_ELECTRIC_FIELD + value_name = "Ex" + + _logger.debug(f"Using measurement {measurement_enum.name}") + labels = [dot.label for dot in dots] + if header_row: + value_labels = ", ".join([f"{value_name}_{label}" for label in labels]) + click.echo(f"t (s), {value_labels}") + + _logger.debug(f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}") + + _logger.debug(f"Got seed {time_series_rng_seed}") + if time_series_rng_seed is None: + time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t) + else: + rng = numpy.random.default_rng(time_series_rng_seed) + time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t, rng) + + for i in tqdm.trange(num_iterations): + transition = time_series.transition() + transition_values = ", ".join(str(transition[label]) for label in labels) + output_file.write(f"{i * delta_t}, {transition_values}") @cli.command()