feat!: adds writing of time series and removes some temporary commands from the cli
This commit is contained in:
parent
450f0d5730
commit
4c00fe845d
20
poetry.lock
generated
20
poetry.lock
generated
@ -461,6 +461,23 @@ category = "dev"
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tqdm"
|
||||||
|
version = "4.66.2"
|
||||||
|
description = "Fast, Extensible Progress Meter"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
|
||||||
|
notebook = ["ipywidgets (>=6)"]
|
||||||
|
slack = ["slack-sdk"]
|
||||||
|
telegram = ["requests"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "traitlets"
|
name = "traitlets"
|
||||||
version = "5.13.0"
|
version = "5.13.0"
|
||||||
@ -492,7 +509,7 @@ python-versions = "*"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = ">=3.8.1,<3.10"
|
python-versions = ">=3.8.1,<3.10"
|
||||||
content-hash = "66c63e7a7a61bc34880f47e2568e2407f330c2691dabc0dace9aa15f03a97531"
|
content-hash = "ba0e7fa989baea42343639637136dd80ce087eb7bfcfb32da50c9ddffc46ad71"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
appnope = []
|
appnope = []
|
||||||
@ -534,6 +551,7 @@ six = []
|
|||||||
stack-data = []
|
stack-data = []
|
||||||
syrupy = []
|
syrupy = []
|
||||||
tomli = []
|
tomli = []
|
||||||
|
tqdm = []
|
||||||
traitlets = []
|
traitlets = []
|
||||||
typing-extensions = []
|
typing-extensions = []
|
||||||
wcwidth = []
|
wcwidth = []
|
||||||
|
@ -11,6 +11,7 @@ python = ">=3.8.1,<3.10"
|
|||||||
numpy = "^1.22.3"
|
numpy = "^1.22.3"
|
||||||
scipy = "~1.10"
|
scipy = "~1.10"
|
||||||
click = "^8.1.7"
|
click = "^8.1.7"
|
||||||
|
tqdm = "^4.66.2"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = ">=6"
|
pytest = ">=6"
|
||||||
|
@ -4,15 +4,21 @@ import tantri
|
|||||||
import numpy
|
import numpy
|
||||||
import tantri.cli.input_files.write_dipoles
|
import tantri.cli.input_files.write_dipoles
|
||||||
import tantri.cli.file_importer
|
import tantri.cli.file_importer
|
||||||
|
import tantri.dipoles
|
||||||
import json
|
import json
|
||||||
import tantri.dipoles.generation
|
import tantri.dipoles.generation
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import tqdm
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
|
LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
|
||||||
|
|
||||||
|
|
||||||
|
POTENTIAL = "electric-potential"
|
||||||
|
X_ELECTRIC_FIELD = "x-electric-field"
|
||||||
|
|
||||||
|
|
||||||
def _set_up_logging(filename):
|
def _set_up_logging(filename):
|
||||||
handlers = [logging.StreamHandler()]
|
handlers = [logging.StreamHandler()]
|
||||||
if filename is not None:
|
if filename is not None:
|
||||||
@ -53,18 +59,84 @@ def cli(log, log_file):
|
|||||||
type=click.Path(exists=True, path_type=pathlib.Path),
|
type=click.Path(exists=True, path_type=pathlib.Path),
|
||||||
help="File with json array of dots",
|
help="File with json array of dots",
|
||||||
)
|
)
|
||||||
def read_files(dipoles_file, dots_file):
|
@click.option(
|
||||||
click.echo("in hello")
|
"--measurement-type",
|
||||||
_logger.info("in hello")
|
type=click.Choice([POTENTIAL, X_ELECTRIC_FIELD]),
|
||||||
_logger.info(
|
default=POTENTIAL,
|
||||||
|
help="The type of measurement to simulate",
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--delta-t",
|
||||||
|
"-t",
|
||||||
|
type=float,
|
||||||
|
default=1,
|
||||||
|
help="The delta t between time series iterations.",
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--num-iterations",
|
||||||
|
"-n",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="The number of iterations.",
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--time-series-rng-seed",
|
||||||
|
"-s",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="A seed to use to create an override default rng. You should set this.",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"output_file",
|
||||||
|
"-o",
|
||||||
|
type=click.File("w"),
|
||||||
|
help="The output file to write, in csv format",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--header-row/--no-header-row",
|
||||||
|
default=False,
|
||||||
|
help="Write a header row"
|
||||||
|
)
|
||||||
|
def write_time_series(dipoles_file, dots_file, measurement_type, delta_t, num_iterations, time_series_rng_seed, output_file, header_row):
|
||||||
|
"""
|
||||||
|
Generate a time series for the passed in parameters.
|
||||||
|
"""
|
||||||
|
_logger.debug(
|
||||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||||
)
|
)
|
||||||
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
|
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
|
||||||
dots = tantri.cli.file_importer.read_dots_json_file(dots_file)
|
dots = tantri.cli.file_importer.read_dots_json_file(dots_file)
|
||||||
for dipole in dipoles:
|
|
||||||
_logger.info(dipole)
|
if measurement_type == POTENTIAL:
|
||||||
for dot in dots:
|
measurement_enum = tantri.dipoles.DipoleMeasurementType.ELECTRIC_POTENTIAL
|
||||||
_logger.info(dot)
|
value_name = "V"
|
||||||
|
elif measurement_type == X_ELECTRIC_FIELD:
|
||||||
|
measurement_enum = tantri.dipoles.DipoleMeasurementType.X_ELECTRIC_FIELD
|
||||||
|
value_name = "Ex"
|
||||||
|
|
||||||
|
_logger.debug(f"Using measurement {measurement_enum.name}")
|
||||||
|
labels = [dot.label for dot in dots]
|
||||||
|
if header_row:
|
||||||
|
value_labels = ", ".join([f"{value_name}_{label}" for label in labels])
|
||||||
|
click.echo(f"t (s), {value_labels}")
|
||||||
|
|
||||||
|
_logger.debug(f"Going to simulate {num_iterations} iterations with a delta t of {delta_t}")
|
||||||
|
|
||||||
|
_logger.debug(f"Got seed {time_series_rng_seed}")
|
||||||
|
if time_series_rng_seed is None:
|
||||||
|
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t)
|
||||||
|
else:
|
||||||
|
rng = numpy.random.default_rng(time_series_rng_seed)
|
||||||
|
time_series = tantri.dipoles.DipoleTimeSeries(dipoles, dots, measurement_enum, delta_t, rng)
|
||||||
|
|
||||||
|
for i in tqdm.trange(num_iterations):
|
||||||
|
transition = time_series.transition()
|
||||||
|
transition_values = ", ".join(str(transition[label]) for label in labels)
|
||||||
|
output_file.write(f"{i * delta_t}, {transition_values}")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user