diff --git a/tantri/cli/__init__.py b/tantri/cli/__init__.py index 58bbbe6..92ed83c 100755 --- a/tantri/cli/__init__.py +++ b/tantri/cli/__init__.py @@ -10,7 +10,7 @@ LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s" def _set_up_logging(filename): handlers = [logging.StreamHandler()] if filename is not None: - handlers.append(logging.FileHandler(filename)) # type: ignore idk what the typing issue is + handlers.append(logging.FileHandler(filename)) logging.basicConfig( level=logging.DEBUG, format=LOG_PATTERN, @@ -21,7 +21,9 @@ def _set_up_logging(filename): @click.group() -@click.option("--log", help="Enable logging to stream only", is_flag=True, default=False) +@click.option( + "--log", help="Enable logging to stream only", is_flag=True, default=False +) @click.option("--log-file", help="A filename to use for logging (implies --log)") @click.version_option(tantri.get_version()) def cli(log, log_file): @@ -32,9 +34,17 @@ def cli(log, log_file): @cli.command() -@click.option("--dipoles-file", default="dipoles.json", help="Filename containing json array of dipoles") -@click.option("--dots-file", default="dots.json", help="Filename containing json array of dots") +@click.option( + "--dipoles-file", + default="dipoles.json", + help="Filename containing json array of dipoles", +) +@click.option( + "--dots-file", default="dots.json", help="Filename containing json array of dots" +) def hello(dipoles_file, dots_file): _logger.info("in hello") - _logger.info(f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]") + _logger.info( + f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]" + ) click.echo("in hello") diff --git a/tantri/cli/file_importer.py b/tantri/cli/file_importer.py index 9f12c6c..22de59d 100755 --- a/tantri/cli/file_importer.py +++ b/tantri/cli/file_importer.py @@ -16,7 +16,9 @@ def read_data_from_filename(filename: str): with open(filename, "r") as file: return json.load(file) except Exception as e: - _logger.error(f"failed to read the file {filename}, raising and aborting", exc_info=e) + _logger.error( + f"failed to read the file {filename}, raising and aborting", exc_info=e + ) def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: @@ -24,6 +26,6 @@ def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: return tantri.cli.input_files.rows_to_dots(data) -def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: +def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]: data = read_data_from_filename(filename) - return tantri.cli.input_files.rows_to_dots(data) + return tantri.cli.input_files.rows_to_dipoles(data) diff --git a/tantri/cli/input_files/__init__.py b/tantri/cli/input_files/__init__.py index 8c38a99..208b97c 100755 --- a/tantri/cli/input_files/__init__.py +++ b/tantri/cli/input_files/__init__.py @@ -1,8 +1,7 @@ from tantri.cli.input_files.read_dots import rows_to_dots -from tantri.cli.input_files.read_dipoles import rows_to_dipoles, DipoleTO +from tantri.cli.input_files.read_dipoles import rows_to_dipoles __all__ = [ "rows_to_dots", "rows_to_dipoles", - "DipoleTO", ] diff --git a/tantri/cli/input_files/read_dipoles.py b/tantri/cli/input_files/read_dipoles.py index aa24c89..73bf534 100755 --- a/tantri/cli/input_files/read_dipoles.py +++ b/tantri/cli/input_files/read_dipoles.py @@ -1,27 +1,18 @@ -import numpy from typing import Sequence -from dataclasses import dataclass - - -# Lazily just separating this from tantri.dipoles.Dipole where there's additional cached stuff, this is just a thing -# we can use as a DTO For dipole info. -@dataclass -class DipoleTO: - # assumed len 3 - p: numpy.ndarray - s: numpy.ndarray - - # should be 1/tau up to some pis - w: float +from tantri.dipoles import DipoleTO def row_to_dipole(input_dict: dict) -> DipoleTO: p = input_dict["p"] if len(p) != 3: - raise ValueError(f"p parameter in input_dict [{input_dict}] does not have length 3") + raise ValueError( + f"p parameter in input_dict [{input_dict}] does not have length 3" + ) s = input_dict["s"] if len(s) != 3: - raise ValueError(f"s parameter in input_dict [{input_dict}] does not have length 3") + raise ValueError( + f"s parameter in input_dict [{input_dict}] does not have length 3" + ) w = input_dict["w"] return DipoleTO(p, s, w) diff --git a/tantri/cli/input_files/read_dots.py b/tantri/cli/input_files/read_dots.py index 83a1159..61da95a 100755 --- a/tantri/cli/input_files/read_dots.py +++ b/tantri/cli/input_files/read_dots.py @@ -6,11 +6,15 @@ from typing import Sequence def row_to_dot(input_dict: dict) -> tantri.dipoles.DotPosition: r = input_dict["r"] if len(r) != 3: - raise ValueError(f"r parameter in input_dict [{input_dict}] does not have length 3") + raise ValueError( + f"r parameter in input_dict [{input_dict}] does not have length 3" + ) label = input_dict["label"] return tantri.dipoles.DotPosition(numpy.array(r), label) -def rows_to_dots(dot_dict_array: Sequence[dict]) -> Sequence[tantri.dipoles.DotPosition]: +def rows_to_dots( + dot_dict_array: Sequence[dict], +) -> Sequence[tantri.dipoles.DotPosition]: return [row_to_dot(input_dict) for input_dict in dot_dict_array] diff --git a/tantri/dipoles/__init__.py b/tantri/dipoles/__init__.py index 78b458f..a22c074 100755 --- a/tantri/dipoles/__init__.py +++ b/tantri/dipoles/__init__.py @@ -3,6 +3,7 @@ import numpy import numpy.random import typing from enum import Enum +from tantri.dipoles.types import DipoleTO import logging @@ -126,4 +127,5 @@ class DipoleTimeSeries: __all__ = [ "Dipole", "DipoleTimeSeries", + "DipoleTO", ] diff --git a/tantri/dipoles/generation/__init__.py b/tantri/dipoles/generation/__init__.py new file mode 100755 index 0000000..00fe610 --- /dev/null +++ b/tantri/dipoles/generation/__init__.py @@ -0,0 +1,38 @@ +import numpy +from typing import Sequence +from dataclasses import dataclass +from tantri.dipoles.types import DipoleTO + + +# stuff for generating random dipoles from parameters + + +# A description of the parameters needed to generate random dipoles +@dataclass +class DipoleGenerationConfig: + # assumed len 3 + p: numpy.ndarray + s: numpy.ndarray + + # should be 1/tau up to some pis + w: float + + +def row_to_dipole(input_dict: dict) -> DipoleTO: + p = input_dict["p"] + if len(p) != 3: + raise ValueError( + f"p parameter in input_dict [{input_dict}] does not have length 3" + ) + s = input_dict["s"] + if len(s) != 3: + raise ValueError( + f"s parameter in input_dict [{input_dict}] does not have length 3" + ) + w = input_dict["w"] + + return DipoleTO(p, s, w) + + +def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]: + return [row_to_dipole(input_dict) for input_dict in dot_dict_array] diff --git a/tantri/dipoles/types.py b/tantri/dipoles/types.py new file mode 100755 index 0000000..030a788 --- /dev/null +++ b/tantri/dipoles/types.py @@ -0,0 +1,14 @@ +import numpy +from dataclasses import dataclass + + +# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing +# we can use as a DTO For dipole info. +@dataclass +class DipoleTO: + # assumed len 3 + p: numpy.ndarray + s: numpy.ndarray + + # should be 1/tau up to some pis + w: float