chore: refactoring to help with getting dipole random generation going

This commit is contained in:
Deepak Mallubhotla 2024-04-20 19:29:12 -05:00
parent b5a3f76745
commit 2c9fa5ee87
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
8 changed files with 88 additions and 28 deletions

View File

@ -10,7 +10,7 @@ LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
def _set_up_logging(filename):
handlers = [logging.StreamHandler()]
if filename is not None:
handlers.append(logging.FileHandler(filename)) # type: ignore idk what the typing issue is
handlers.append(logging.FileHandler(filename))
logging.basicConfig(
level=logging.DEBUG,
format=LOG_PATTERN,
@ -21,7 +21,9 @@ def _set_up_logging(filename):
@click.group()
@click.option("--log", help="Enable logging to stream only", is_flag=True, default=False)
@click.option(
"--log", help="Enable logging to stream only", is_flag=True, default=False
)
@click.option("--log-file", help="A filename to use for logging (implies --log)")
@click.version_option(tantri.get_version())
def cli(log, log_file):
@ -32,9 +34,17 @@ def cli(log, log_file):
@cli.command()
@click.option("--dipoles-file", default="dipoles.json", help="Filename containing json array of dipoles")
@click.option("--dots-file", default="dots.json", help="Filename containing json array of dots")
@click.option(
"--dipoles-file",
default="dipoles.json",
help="Filename containing json array of dipoles",
)
@click.option(
"--dots-file", default="dots.json", help="Filename containing json array of dots"
)
def hello(dipoles_file, dots_file):
_logger.info("in hello")
_logger.info(f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]")
_logger.info(
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
)
click.echo("in hello")

View File

@ -16,7 +16,9 @@ def read_data_from_filename(filename: str):
with open(filename, "r") as file:
return json.load(file)
except Exception as e:
_logger.error(f"failed to read the file {filename}, raising and aborting", exc_info=e)
_logger.error(
f"failed to read the file {filename}, raising and aborting", exc_info=e
)
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
@ -24,6 +26,6 @@ def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
return tantri.cli.input_files.rows_to_dots(data)
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]:
data = read_data_from_filename(filename)
return tantri.cli.input_files.rows_to_dots(data)
return tantri.cli.input_files.rows_to_dipoles(data)

View File

@ -1,8 +1,7 @@
from tantri.cli.input_files.read_dots import rows_to_dots
from tantri.cli.input_files.read_dipoles import rows_to_dipoles, DipoleTO
from tantri.cli.input_files.read_dipoles import rows_to_dipoles
__all__ = [
"rows_to_dots",
"rows_to_dipoles",
"DipoleTO",
]

View File

@ -1,27 +1,18 @@
import numpy
from typing import Sequence
from dataclasses import dataclass
# Lazily just separating this from tantri.dipoles.Dipole where there's additional cached stuff, this is just a thing
# we can use as a DTO For dipole info.
@dataclass
class DipoleTO:
# assumed len 3
p: numpy.ndarray
s: numpy.ndarray
# should be 1/tau up to some pis
w: float
from tantri.dipoles import DipoleTO
def row_to_dipole(input_dict: dict) -> DipoleTO:
p = input_dict["p"]
if len(p) != 3:
raise ValueError(f"p parameter in input_dict [{input_dict}] does not have length 3")
raise ValueError(
f"p parameter in input_dict [{input_dict}] does not have length 3"
)
s = input_dict["s"]
if len(s) != 3:
raise ValueError(f"s parameter in input_dict [{input_dict}] does not have length 3")
raise ValueError(
f"s parameter in input_dict [{input_dict}] does not have length 3"
)
w = input_dict["w"]
return DipoleTO(p, s, w)

View File

@ -6,11 +6,15 @@ from typing import Sequence
def row_to_dot(input_dict: dict) -> tantri.dipoles.DotPosition:
r = input_dict["r"]
if len(r) != 3:
raise ValueError(f"r parameter in input_dict [{input_dict}] does not have length 3")
raise ValueError(
f"r parameter in input_dict [{input_dict}] does not have length 3"
)
label = input_dict["label"]
return tantri.dipoles.DotPosition(numpy.array(r), label)
def rows_to_dots(dot_dict_array: Sequence[dict]) -> Sequence[tantri.dipoles.DotPosition]:
def rows_to_dots(
dot_dict_array: Sequence[dict],
) -> Sequence[tantri.dipoles.DotPosition]:
return [row_to_dot(input_dict) for input_dict in dot_dict_array]

View File

@ -3,6 +3,7 @@ import numpy
import numpy.random
import typing
from enum import Enum
from tantri.dipoles.types import DipoleTO
import logging
@ -126,4 +127,5 @@ class DipoleTimeSeries:
__all__ = [
"Dipole",
"DipoleTimeSeries",
"DipoleTO",
]

View File

@ -0,0 +1,38 @@
import numpy
from typing import Sequence
from dataclasses import dataclass
from tantri.dipoles.types import DipoleTO
# stuff for generating random dipoles from parameters
# A description of the parameters needed to generate random dipoles
@dataclass
class DipoleGenerationConfig:
# assumed len 3
p: numpy.ndarray
s: numpy.ndarray
# should be 1/tau up to some pis
w: float
def row_to_dipole(input_dict: dict) -> DipoleTO:
p = input_dict["p"]
if len(p) != 3:
raise ValueError(
f"p parameter in input_dict [{input_dict}] does not have length 3"
)
s = input_dict["s"]
if len(s) != 3:
raise ValueError(
f"s parameter in input_dict [{input_dict}] does not have length 3"
)
w = input_dict["w"]
return DipoleTO(p, s, w)
def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]:
return [row_to_dipole(input_dict) for input_dict in dot_dict_array]

14
tantri/dipoles/types.py Executable file
View File

@ -0,0 +1,14 @@
import numpy
from dataclasses import dataclass
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
# we can use as a DTO For dipole info.
@dataclass
class DipoleTO:
# assumed len 3
p: numpy.ndarray
s: numpy.ndarray
# should be 1/tau up to some pis
w: float