feat!: adds writing of time series and removes some temporary commands from the cli
This commit is contained in:
parent
450f0d5730
commit
4c00fe845d
20
poetry.lock
generated
20
poetry.lock
generated
@ -461,6 +461,23 @@ category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.66.2"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
|
||||
notebook = ["ipywidgets (>=6)"]
|
||||
slack = ["slack-sdk"]
|
||||
telegram = ["requests"]
|
||||
|
||||
[[package]]
|
||||
name = "traitlets"
|
||||
version = "5.13.0"
|
||||
@ -492,7 +509,7 @@ python-versions = "*"
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8.1,<3.10"
|
||||
content-hash = "66c63e7a7a61bc34880f47e2568e2407f330c2691dabc0dace9aa15f03a97531"
|
||||
content-hash = "ba0e7fa989baea42343639637136dd80ce087eb7bfcfb32da50c9ddffc46ad71"
|
||||
|
||||
[metadata.files]
|
||||
appnope = []
|
||||
@ -534,6 +551,7 @@ six = []
|
||||
stack-data = []
|
||||
syrupy = []
|
||||
tomli = []
|
||||
tqdm = []
|
||||
traitlets = []
|
||||
typing-extensions = []
|
||||
wcwidth = []
|
||||
|
@ -11,6 +11,7 @@ python = ">=3.8.1,<3.10"
|
||||
numpy = "^1.22.3"
|
||||
scipy = "~1.10"
|
||||
click = "^8.1.7"
|
||||
tqdm = "^4.66.2"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = ">=6"
|
||||
|
@ -4,15 +4,21 @@ import tantri
|
||||
import numpy
|
||||
import tantri.cli.input_files.write_dipoles
|
||||
import tantri.cli.file_importer
|
||||
import tantri.dipoles
|
||||
import json
|
||||
import tantri.dipoles.generation
|
||||
import pathlib
|
||||
import tqdm
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
|
||||
|
||||
|
||||
POTENTIAL = "electric-potential"
|
||||
X_ELECTRIC_FIELD = "x-electric-field"
|
||||
|
||||
|
||||
def _set_up_logging(filename):
|
||||
handlers = [logging.StreamHandler()]
|
||||
if filename is not None:
|
||||
@ -53,18 +59,84 @@ def cli(log, log_file):
|
||||
type=click.Path(exists=True, path_type=pathlib.Path),
|
||||
help="File with json array of dots",
|
||||
)
|
||||
def read_files(dipoles_file, dots_file):
|
||||
click.echo("in hello")
|
||||
_logger.info("in hello")
|
||||
_logger.info(
|
||||
@click.option(
|
||||
"--measurement-type",
|
||||
type=click.Choice([POTENTIAL, X_ELECTRIC_FIELD]),
|
||||
default=POTENTIAL,
|
||||
help="The type of measurement to simulate",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--delta-t",
|
||||
"-t",
|
||||
type=float,
|
||||
default=1,
|
||||
help="The delta t between time series iterations.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--num-iterations",
|
||||
"-n",
|
||||
type=int,
|
||||
default=10,
|
||||
help="The number of iterations.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option(
|
||||
"--time-series-rng-seed",
|
||||
"-s",
|
||||
type=int,
|
||||
default=None,
|
||||
help="A seed to use to create an override default rng. You should set this.",
|
||||
)
|
||||
@click.option(
|
||||
"output_file",
|
||||
"-o",
|
||||
type=click.File("w"),
|
||||
help="The output file to write, in csv format",
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--header-row/--no-header-row",
|
||||
default=False,
|
||||
help="Write a header row"
|
||||
)
|
||||
def write_time_series(dipoles_file, dots_file, measurement_type, delta_t, num_iterations, time_series_rng_seed, output_file, header_row):
|
||||
"""
|
||||
Generate a time series for the passed in parameters.
|
||||
"""
|
||||
_logger.debug(
|
||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||
)
|
||||
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
|
||||
dots = tantri.cli.file_importer.read_dots_json_file(dots_file)
|
||||
for dipole in dipoles:
|
||||
_logger.info(dipole)
|
||||
for dot in dots:
|
||||
_logger.info(dot)
|
||||
|
||||
if measurement_type == POTENTIAL:
|
||||
measurement_enum = tantri.dipoles.DipoleMeasurementType.ELECTRIC_POTENTIAL
|
||||
value_name = "V"
|
||||
elif measurement_type == X_ELECTRIC_FIELD:
|
||||
measurement_enum = tantri.dipoles.DipoleMeasurementType.X_ELECTRIC_FIELD
|
||||
value_name = "Ex"
|
||||
|
||||
_logger.debug(f"Using measurement {measurement_enum.name}")
|
||||
labels = [dot.label for dot in dots]
|
||||
if header_row:
|
||||
value_labels = ", ".join([f"{value_name}_{label}" for label in labels])
|
||||
click.echo(f"t (s), {value_labels}")
|
||||
|
||||
_logger.debug(f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}")
|
||||
|
||||
_logger.debug(f"Got seed {time_series_rng_seed}")
|
||||
if time_series_rng_seed is None:
|
||||
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t)
|
||||
else:
|
||||
rng = numpy.random.default_rng(time_series_rng_seed)
|
||||
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t, rng)
|
||||
|
||||
for i in tqdm.trange(num_iterations):
|
||||
transition = time_series.transition()
|
||||
transition_values = ", ".join(str(transition[label]) for label in labels)
|
||||
output_file.write(f"{i * delta_t}, {transition_values}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
Loading…
x
Reference in New Issue
Block a user