feat!: adds writing of time series and removes some temporary commands from the cli

This commit is contained in:
Deepak Mallubhotla 2024-04-20 22:59:52 -05:00
parent 450f0d5730
commit 4c00fe845d
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
3 changed files with 100 additions and 9 deletions

20
poetry.lock generated
View File

@ -461,6 +461,23 @@ category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
[[package]]
name = "tqdm"
version = "4.66.2"
description = "Fast, Extensible Progress Meter"
category = "main"
optional = false
python-versions = ">=3.7"
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]] [[package]]
name = "traitlets" name = "traitlets"
version = "5.13.0" version = "5.13.0"
@ -492,7 +509,7 @@ python-versions = "*"
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = ">=3.8.1,<3.10" python-versions = ">=3.8.1,<3.10"
content-hash = "66c63e7a7a61bc34880f47e2568e2407f330c2691dabc0dace9aa15f03a97531" content-hash = "ba0e7fa989baea42343639637136dd80ce087eb7bfcfb32da50c9ddffc46ad71"
[metadata.files] [metadata.files]
appnope = [] appnope = []
@ -534,6 +551,7 @@ six = []
stack-data = [] stack-data = []
syrupy = [] syrupy = []
tomli = [] tomli = []
tqdm = []
traitlets = [] traitlets = []
typing-extensions = [] typing-extensions = []
wcwidth = [] wcwidth = []

View File

@ -11,6 +11,7 @@ python = ">=3.8.1,<3.10"
numpy = "^1.22.3" numpy = "^1.22.3"
scipy = "~1.10" scipy = "~1.10"
click = "^8.1.7" click = "^8.1.7"
tqdm = "^4.66.2"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = ">=6" pytest = ">=6"

View File

@ -4,15 +4,21 @@ import tantri
import numpy import numpy
import tantri.cli.input_files.write_dipoles import tantri.cli.input_files.write_dipoles
import tantri.cli.file_importer import tantri.cli.file_importer
import tantri.dipoles
import json import json
import tantri.dipoles.generation import tantri.dipoles.generation
import pathlib import pathlib
import tqdm
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s" LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
POTENTIAL = "electric-potential"
X_ELECTRIC_FIELD = "x-electric-field"
def _set_up_logging(filename): def _set_up_logging(filename):
handlers = [logging.StreamHandler()] handlers = [logging.StreamHandler()]
if filename is not None: if filename is not None:
@ -53,18 +59,84 @@ def cli(log, log_file):
type=click.Path(exists=True, path_type=pathlib.Path), type=click.Path(exists=True, path_type=pathlib.Path),
help="File with json array of dots", help="File with json array of dots",
) )
def read_files(dipoles_file, dots_file): @click.option(
click.echo("in hello") "--measurement-type",
_logger.info("in hello") type=click.Choice([POTENTIAL, X_ELECTRIC_FIELD]),
_logger.info( default=POTENTIAL,
help="The type of measurement to simulate",
show_default=True,
)
@click.option(
"--delta-t",
"-t",
type=float,
default=1,
help="The delta t between time series iterations.",
show_default=True,
)
@click.option(
"--num-iterations",
"-n",
type=int,
default=10,
help="The number of iterations.",
show_default=True,
)
@click.option(
"--time-series-rng-seed",
"-s",
type=int,
default=None,
help="A seed to use to create an override default rng. You should set this.",
)
@click.option(
"output_file",
"-o",
type=click.File("w"),
help="The output file to write, in csv format",
required=True,
)
@click.option(
"--header-row/--no-header-row",
default=False,
help="Write a header row"
)
def write_time_series(dipoles_file, dots_file, measurement_type, delta_t, num_iterations, time_series_rng_seed, output_file, header_row):
"""
Generate a time series for the passed in parameters.
"""
_logger.debug(
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]" f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
) )
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file) dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
dots = tantri.cli.file_importer.read_dots_json_file(dots_file) dots = tantri.cli.file_importer.read_dots_json_file(dots_file)
for dipole in dipoles:
_logger.info(dipole) if measurement_type == POTENTIAL:
for dot in dots: measurement_enum = tantri.dipoles.DipoleMeasurementType.ELECTRIC_POTENTIAL
_logger.info(dot) value_name = "V"
elif measurement_type == X_ELECTRIC_FIELD:
measurement_enum = tantri.dipoles.DipoleMeasurementType.X_ELECTRIC_FIELD
value_name = "Ex"
_logger.debug(f"Using measurement {measurement_enum.name}")
labels = [dot.label for dot in dots]
if header_row:
value_labels = ", ".join([f"{value_name}_{label}" for label in labels])
click.echo(f"t (s), {value_labels}")
_logger.debug(f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}")
_logger.debug(f"Got seed {time_series_rng_seed}")
if time_series_rng_seed is None:
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t)
else:
rng = numpy.random.default_rng(time_series_rng_seed)
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t, rng)
for i in tqdm.trange(num_iterations):
transition = time_series.transition()
transition_values = ", ".join(str(transition[label]) for label in labels)
output_file.write(f"{i * delta_t}, {transition_values}")
@cli.command() @cli.command()