chore: refactoring to help with getting dipole random generation going
This commit is contained in:
parent
b5a3f76745
commit
2c9fa5ee87
@ -10,7 +10,7 @@ LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
|
||||
def _set_up_logging(filename):
|
||||
handlers = [logging.StreamHandler()]
|
||||
if filename is not None:
|
||||
handlers.append(logging.FileHandler(filename)) # type: ignore idk what the typing issue is
|
||||
handlers.append(logging.FileHandler(filename))
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format=LOG_PATTERN,
|
||||
@ -21,7 +21,9 @@ def _set_up_logging(filename):
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option("--log", help="Enable logging to stream only", is_flag=True, default=False)
|
||||
@click.option(
|
||||
"--log", help="Enable logging to stream only", is_flag=True, default=False
|
||||
)
|
||||
@click.option("--log-file", help="A filename to use for logging (implies --log)")
|
||||
@click.version_option(tantri.get_version())
|
||||
def cli(log, log_file):
|
||||
@ -32,9 +34,17 @@ def cli(log, log_file):
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--dipoles-file", default="dipoles.json", help="Filename containing json array of dipoles")
|
||||
@click.option("--dots-file", default="dots.json", help="Filename containing json array of dots")
|
||||
@click.option(
|
||||
"--dipoles-file",
|
||||
default="dipoles.json",
|
||||
help="Filename containing json array of dipoles",
|
||||
)
|
||||
@click.option(
|
||||
"--dots-file", default="dots.json", help="Filename containing json array of dots"
|
||||
)
|
||||
def hello(dipoles_file, dots_file):
|
||||
_logger.info("in hello")
|
||||
_logger.info(f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]")
|
||||
_logger.info(
|
||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||
)
|
||||
click.echo("in hello")
|
||||
|
@ -16,7 +16,9 @@ def read_data_from_filename(filename: str):
|
||||
with open(filename, "r") as file:
|
||||
return json.load(file)
|
||||
except Exception as e:
|
||||
_logger.error(f"failed to read the file {filename}, raising and aborting", exc_info=e)
|
||||
_logger.error(
|
||||
f"failed to read the file {filename}, raising and aborting", exc_info=e
|
||||
)
|
||||
|
||||
|
||||
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
||||
@ -24,6 +26,6 @@ def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
||||
return tantri.cli.input_files.rows_to_dots(data)
|
||||
|
||||
|
||||
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
||||
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]:
|
||||
data = read_data_from_filename(filename)
|
||||
return tantri.cli.input_files.rows_to_dots(data)
|
||||
return tantri.cli.input_files.rows_to_dipoles(data)
|
||||
|
@ -1,8 +1,7 @@
|
||||
from tantri.cli.input_files.read_dots import rows_to_dots
|
||||
from tantri.cli.input_files.read_dipoles import rows_to_dipoles, DipoleTO
|
||||
from tantri.cli.input_files.read_dipoles import rows_to_dipoles
|
||||
|
||||
__all__ = [
|
||||
"rows_to_dots",
|
||||
"rows_to_dipoles",
|
||||
"DipoleTO",
|
||||
]
|
||||
|
@ -1,27 +1,18 @@
|
||||
import numpy
|
||||
from typing import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# Lazily just separating this from tantri.dipoles.Dipole where there's additional cached stuff, this is just a thing
|
||||
# we can use as a DTO For dipole info.
|
||||
@dataclass
|
||||
class DipoleTO:
|
||||
# assumed len 3
|
||||
p: numpy.ndarray
|
||||
s: numpy.ndarray
|
||||
|
||||
# should be 1/tau up to some pis
|
||||
w: float
|
||||
from tantri.dipoles import DipoleTO
|
||||
|
||||
|
||||
def row_to_dipole(input_dict: dict) -> DipoleTO:
|
||||
p = input_dict["p"]
|
||||
if len(p) != 3:
|
||||
raise ValueError(f"p parameter in input_dict [{input_dict}] does not have length 3")
|
||||
raise ValueError(
|
||||
f"p parameter in input_dict [{input_dict}] does not have length 3"
|
||||
)
|
||||
s = input_dict["s"]
|
||||
if len(s) != 3:
|
||||
raise ValueError(f"s parameter in input_dict [{input_dict}] does not have length 3")
|
||||
raise ValueError(
|
||||
f"s parameter in input_dict [{input_dict}] does not have length 3"
|
||||
)
|
||||
w = input_dict["w"]
|
||||
|
||||
return DipoleTO(p, s, w)
|
||||
|
@ -6,11 +6,15 @@ from typing import Sequence
|
||||
def row_to_dot(input_dict: dict) -> tantri.dipoles.DotPosition:
|
||||
r = input_dict["r"]
|
||||
if len(r) != 3:
|
||||
raise ValueError(f"r parameter in input_dict [{input_dict}] does not have length 3")
|
||||
raise ValueError(
|
||||
f"r parameter in input_dict [{input_dict}] does not have length 3"
|
||||
)
|
||||
label = input_dict["label"]
|
||||
|
||||
return tantri.dipoles.DotPosition(numpy.array(r), label)
|
||||
|
||||
|
||||
def rows_to_dots(dot_dict_array: Sequence[dict]) -> Sequence[tantri.dipoles.DotPosition]:
|
||||
def rows_to_dots(
|
||||
dot_dict_array: Sequence[dict],
|
||||
) -> Sequence[tantri.dipoles.DotPosition]:
|
||||
return [row_to_dot(input_dict) for input_dict in dot_dict_array]
|
||||
|
@ -3,6 +3,7 @@ import numpy
|
||||
import numpy.random
|
||||
import typing
|
||||
from enum import Enum
|
||||
from tantri.dipoles.types import DipoleTO
|
||||
|
||||
import logging
|
||||
|
||||
@ -126,4 +127,5 @@ class DipoleTimeSeries:
|
||||
__all__ = [
|
||||
"Dipole",
|
||||
"DipoleTimeSeries",
|
||||
"DipoleTO",
|
||||
]
|
||||
|
38
tantri/dipoles/generation/__init__.py
Executable file
38
tantri/dipoles/generation/__init__.py
Executable file
@ -0,0 +1,38 @@
|
||||
import numpy
|
||||
from typing import Sequence
|
||||
from dataclasses import dataclass
|
||||
from tantri.dipoles.types import DipoleTO
|
||||
|
||||
|
||||
# stuff for generating random dipoles from parameters
|
||||
|
||||
|
||||
# A description of the parameters needed to generate random dipoles
|
||||
@dataclass
|
||||
class DipoleGenerationConfig:
|
||||
# assumed len 3
|
||||
p: numpy.ndarray
|
||||
s: numpy.ndarray
|
||||
|
||||
# should be 1/tau up to some pis
|
||||
w: float
|
||||
|
||||
|
||||
def row_to_dipole(input_dict: dict) -> DipoleTO:
|
||||
p = input_dict["p"]
|
||||
if len(p) != 3:
|
||||
raise ValueError(
|
||||
f"p parameter in input_dict [{input_dict}] does not have length 3"
|
||||
)
|
||||
s = input_dict["s"]
|
||||
if len(s) != 3:
|
||||
raise ValueError(
|
||||
f"s parameter in input_dict [{input_dict}] does not have length 3"
|
||||
)
|
||||
w = input_dict["w"]
|
||||
|
||||
return DipoleTO(p, s, w)
|
||||
|
||||
|
||||
def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]:
|
||||
return [row_to_dipole(input_dict) for input_dict in dot_dict_array]
|
14
tantri/dipoles/types.py
Executable file
14
tantri/dipoles/types.py
Executable file
@ -0,0 +1,14 @@
|
||||
import numpy
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
|
||||
# we can use as a DTO For dipole info.
|
||||
@dataclass
|
||||
class DipoleTO:
|
||||
# assumed len 3
|
||||
p: numpy.ndarray
|
||||
s: numpy.ndarray
|
||||
|
||||
# should be 1/tau up to some pis
|
||||
w: float
|
Loading…
x
Reference in New Issue
Block a user