chore: refactoring to help with getting dipole random generation going
This commit is contained in:
parent
b5a3f76745
commit
2c9fa5ee87
@ -10,7 +10,7 @@ LOG_PATTERN = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
|
|||||||
def _set_up_logging(filename):
|
def _set_up_logging(filename):
|
||||||
handlers = [logging.StreamHandler()]
|
handlers = [logging.StreamHandler()]
|
||||||
if filename is not None:
|
if filename is not None:
|
||||||
handlers.append(logging.FileHandler(filename)) # type: ignore idk what the typing issue is
|
handlers.append(logging.FileHandler(filename))
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
format=LOG_PATTERN,
|
format=LOG_PATTERN,
|
||||||
@ -21,7 +21,9 @@ def _set_up_logging(filename):
|
|||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.option("--log", help="Enable logging to stream only", is_flag=True, default=False)
|
@click.option(
|
||||||
|
"--log", help="Enable logging to stream only", is_flag=True, default=False
|
||||||
|
)
|
||||||
@click.option("--log-file", help="A filename to use for logging (implies --log)")
|
@click.option("--log-file", help="A filename to use for logging (implies --log)")
|
||||||
@click.version_option(tantri.get_version())
|
@click.version_option(tantri.get_version())
|
||||||
def cli(log, log_file):
|
def cli(log, log_file):
|
||||||
@ -32,9 +34,17 @@ def cli(log, log_file):
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("--dipoles-file", default="dipoles.json", help="Filename containing json array of dipoles")
|
@click.option(
|
||||||
@click.option("--dots-file", default="dots.json", help="Filename containing json array of dots")
|
"--dipoles-file",
|
||||||
|
default="dipoles.json",
|
||||||
|
help="Filename containing json array of dipoles",
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dots-file", default="dots.json", help="Filename containing json array of dots"
|
||||||
|
)
|
||||||
def hello(dipoles_file, dots_file):
|
def hello(dipoles_file, dots_file):
|
||||||
_logger.info("in hello")
|
_logger.info("in hello")
|
||||||
_logger.info(f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]")
|
_logger.info(
|
||||||
|
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||||
|
)
|
||||||
click.echo("in hello")
|
click.echo("in hello")
|
||||||
|
@ -16,7 +16,9 @@ def read_data_from_filename(filename: str):
|
|||||||
with open(filename, "r") as file:
|
with open(filename, "r") as file:
|
||||||
return json.load(file)
|
return json.load(file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_logger.error(f"failed to read the file {filename}, raising and aborting", exc_info=e)
|
_logger.error(
|
||||||
|
f"failed to read the file {filename}, raising and aborting", exc_info=e
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
||||||
@ -24,6 +26,6 @@ def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
|||||||
return tantri.cli.input_files.rows_to_dots(data)
|
return tantri.cli.input_files.rows_to_dots(data)
|
||||||
|
|
||||||
|
|
||||||
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]:
|
||||||
data = read_data_from_filename(filename)
|
data = read_data_from_filename(filename)
|
||||||
return tantri.cli.input_files.rows_to_dots(data)
|
return tantri.cli.input_files.rows_to_dipoles(data)
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
from tantri.cli.input_files.read_dots import rows_to_dots
|
from tantri.cli.input_files.read_dots import rows_to_dots
|
||||||
from tantri.cli.input_files.read_dipoles import rows_to_dipoles, DipoleTO
|
from tantri.cli.input_files.read_dipoles import rows_to_dipoles
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"rows_to_dots",
|
"rows_to_dots",
|
||||||
"rows_to_dipoles",
|
"rows_to_dipoles",
|
||||||
"DipoleTO",
|
|
||||||
]
|
]
|
||||||
|
@ -1,27 +1,18 @@
|
|||||||
import numpy
|
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
from dataclasses import dataclass
|
from tantri.dipoles import DipoleTO
|
||||||
|
|
||||||
|
|
||||||
# Lazily just separating this from tantri.dipoles.Dipole where there's additional cached stuff, this is just a thing
|
|
||||||
# we can use as a DTO For dipole info.
|
|
||||||
@dataclass
|
|
||||||
class DipoleTO:
|
|
||||||
# assumed len 3
|
|
||||||
p: numpy.ndarray
|
|
||||||
s: numpy.ndarray
|
|
||||||
|
|
||||||
# should be 1/tau up to some pis
|
|
||||||
w: float
|
|
||||||
|
|
||||||
|
|
||||||
def row_to_dipole(input_dict: dict) -> DipoleTO:
|
def row_to_dipole(input_dict: dict) -> DipoleTO:
|
||||||
p = input_dict["p"]
|
p = input_dict["p"]
|
||||||
if len(p) != 3:
|
if len(p) != 3:
|
||||||
raise ValueError(f"p parameter in input_dict [{input_dict}] does not have length 3")
|
raise ValueError(
|
||||||
|
f"p parameter in input_dict [{input_dict}] does not have length 3"
|
||||||
|
)
|
||||||
s = input_dict["s"]
|
s = input_dict["s"]
|
||||||
if len(s) != 3:
|
if len(s) != 3:
|
||||||
raise ValueError(f"s parameter in input_dict [{input_dict}] does not have length 3")
|
raise ValueError(
|
||||||
|
f"s parameter in input_dict [{input_dict}] does not have length 3"
|
||||||
|
)
|
||||||
w = input_dict["w"]
|
w = input_dict["w"]
|
||||||
|
|
||||||
return DipoleTO(p, s, w)
|
return DipoleTO(p, s, w)
|
||||||
|
@ -6,11 +6,15 @@ from typing import Sequence
|
|||||||
def row_to_dot(input_dict: dict) -> tantri.dipoles.DotPosition:
|
def row_to_dot(input_dict: dict) -> tantri.dipoles.DotPosition:
|
||||||
r = input_dict["r"]
|
r = input_dict["r"]
|
||||||
if len(r) != 3:
|
if len(r) != 3:
|
||||||
raise ValueError(f"r parameter in input_dict [{input_dict}] does not have length 3")
|
raise ValueError(
|
||||||
|
f"r parameter in input_dict [{input_dict}] does not have length 3"
|
||||||
|
)
|
||||||
label = input_dict["label"]
|
label = input_dict["label"]
|
||||||
|
|
||||||
return tantri.dipoles.DotPosition(numpy.array(r), label)
|
return tantri.dipoles.DotPosition(numpy.array(r), label)
|
||||||
|
|
||||||
|
|
||||||
def rows_to_dots(dot_dict_array: Sequence[dict]) -> Sequence[tantri.dipoles.DotPosition]:
|
def rows_to_dots(
|
||||||
|
dot_dict_array: Sequence[dict],
|
||||||
|
) -> Sequence[tantri.dipoles.DotPosition]:
|
||||||
return [row_to_dot(input_dict) for input_dict in dot_dict_array]
|
return [row_to_dot(input_dict) for input_dict in dot_dict_array]
|
||||||
|
@ -3,6 +3,7 @@ import numpy
|
|||||||
import numpy.random
|
import numpy.random
|
||||||
import typing
|
import typing
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from tantri.dipoles.types import DipoleTO
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -126,4 +127,5 @@ class DipoleTimeSeries:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Dipole",
|
"Dipole",
|
||||||
"DipoleTimeSeries",
|
"DipoleTimeSeries",
|
||||||
|
"DipoleTO",
|
||||||
]
|
]
|
||||||
|
38
tantri/dipoles/generation/__init__.py
Executable file
38
tantri/dipoles/generation/__init__.py
Executable file
@ -0,0 +1,38 @@
|
|||||||
|
import numpy
|
||||||
|
from typing import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from tantri.dipoles.types import DipoleTO
|
||||||
|
|
||||||
|
|
||||||
|
# stuff for generating random dipoles from parameters
|
||||||
|
|
||||||
|
|
||||||
|
# A description of the parameters needed to generate random dipoles
|
||||||
|
@dataclass
|
||||||
|
class DipoleGenerationConfig:
|
||||||
|
# assumed len 3
|
||||||
|
p: numpy.ndarray
|
||||||
|
s: numpy.ndarray
|
||||||
|
|
||||||
|
# should be 1/tau up to some pis
|
||||||
|
w: float
|
||||||
|
|
||||||
|
|
||||||
|
def row_to_dipole(input_dict: dict) -> DipoleTO:
|
||||||
|
p = input_dict["p"]
|
||||||
|
if len(p) != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"p parameter in input_dict [{input_dict}] does not have length 3"
|
||||||
|
)
|
||||||
|
s = input_dict["s"]
|
||||||
|
if len(s) != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"s parameter in input_dict [{input_dict}] does not have length 3"
|
||||||
|
)
|
||||||
|
w = input_dict["w"]
|
||||||
|
|
||||||
|
return DipoleTO(p, s, w)
|
||||||
|
|
||||||
|
|
||||||
|
def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]:
|
||||||
|
return [row_to_dipole(input_dict) for input_dict in dot_dict_array]
|
14
tantri/dipoles/types.py
Executable file
14
tantri/dipoles/types.py
Executable file
@ -0,0 +1,14 @@
|
|||||||
|
import numpy
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
|
||||||
|
# we can use as a DTO For dipole info.
|
||||||
|
@dataclass
|
||||||
|
class DipoleTO:
|
||||||
|
# assumed len 3
|
||||||
|
p: numpy.ndarray
|
||||||
|
s: numpy.ndarray
|
||||||
|
|
||||||
|
# should be 1/tau up to some pis
|
||||||
|
w: float
|
Loading…
x
Reference in New Issue
Block a user