feat: allows generation of random dipoles from a standard config
This commit is contained in:
parent
0a51d12cc0
commit
b4b25974c9
@ -1,6 +1,12 @@
|
|||||||
import click
|
import click
|
||||||
import logging
|
import logging
|
||||||
import tantri
|
import tantri
|
||||||
|
import numpy
|
||||||
|
import tantri.cli.input_files.write_dipoles
|
||||||
|
import tantri.cli.file_importer
|
||||||
|
import json
|
||||||
|
import tantri.dipoles.generation
|
||||||
|
import pathlib
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -30,21 +36,72 @@ def cli(log, log_file):
|
|||||||
if log or (log_file is not None):
|
if log or (log_file is not None):
|
||||||
# log file has been provided, let's log
|
# log file has been provided, let's log
|
||||||
_set_up_logging(log_file)
|
_set_up_logging(log_file)
|
||||||
_logger.info("cli")
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
"--dipoles-file",
|
"--dipoles-file",
|
||||||
default="dipoles.json",
|
default="dipoles.json",
|
||||||
help="Filename containing json array of dipoles",
|
show_default=True,
|
||||||
|
type=click.Path(exists=True, path_type=pathlib.Path),
|
||||||
|
help="File with json array of dipoles",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--dots-file", default="dots.json", help="Filename containing json array of dots"
|
"--dots-file",
|
||||||
|
default="dots.json",
|
||||||
|
show_default=True,
|
||||||
|
type=click.Path(exists=True, path_type=pathlib.Path),
|
||||||
|
help="File with json array of dots",
|
||||||
)
|
)
|
||||||
def hello(dipoles_file, dots_file):
|
def read_files(dipoles_file, dots_file):
|
||||||
|
click.echo("in hello")
|
||||||
_logger.info("in hello")
|
_logger.info("in hello")
|
||||||
_logger.info(
|
_logger.info(
|
||||||
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
|
||||||
)
|
)
|
||||||
click.echo("in hello")
|
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
|
||||||
|
_logger.info(dipoles)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument(
|
||||||
|
"generation_config",
|
||||||
|
type=click.Path(exists=True, path_type=pathlib.Path),
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
|
"output_file",
|
||||||
|
type=click.File("w"),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--override-rng-seed", type=int, help="Seed to override the generation config spec."
|
||||||
|
)
|
||||||
|
def generate_dipoles(generation_config, output_file, override_rng_seed):
|
||||||
|
"""Generate random dipoles as described by GENERATION_CONFIG and output to OUTPUT_FILE.
|
||||||
|
|
||||||
|
GENERATION_CONFIG should be a JSON file that matches the appropriate spec, and OUTPUT_FILE will contain JSON formatted contents.
|
||||||
|
OUTPUT_FILE will be overwritten, if it exists.
|
||||||
|
|
||||||
|
If the --override-rng-seed is set, it's better to keep logs of the generation!
|
||||||
|
"""
|
||||||
|
_logger.debug(
|
||||||
|
f"generate_dipoles was called, with config file {click.format_filename(generation_config)}"
|
||||||
|
)
|
||||||
|
_logger.debug(f"override_rng_seed: [{override_rng_seed}]")
|
||||||
|
|
||||||
|
with open(generation_config, "r") as config_file:
|
||||||
|
data = json.load(config_file)
|
||||||
|
config = tantri.dipoles.generation.DipoleGenerationConfig(**data)
|
||||||
|
|
||||||
|
override_rng = None
|
||||||
|
if override_rng_seed is not None:
|
||||||
|
_logger.info(f"Overriding the rng with a new one with seed {override_rng_seed}")
|
||||||
|
override_rng = numpy.random.default_rng(override_rng_seed)
|
||||||
|
_logger.debug(f"generating dipoles with config {config}...")
|
||||||
|
generated = tantri.dipoles.generation.make_dipoles(config, override_rng)
|
||||||
|
|
||||||
|
output_file.write(
|
||||||
|
json.dumps(
|
||||||
|
[g.as_dict() for g in generated],
|
||||||
|
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import pathlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import tantri.cli.input_files
|
import tantri.cli.input_files
|
||||||
@ -11,7 +12,7 @@ _logger = logging.getLogger(__name__)
|
|||||||
# TODO: if this ever matters, can improve file handling.
|
# TODO: if this ever matters, can improve file handling.
|
||||||
|
|
||||||
|
|
||||||
def read_data_from_filename(filename: str):
|
def read_data_from_filename(filename: pathlib.Path):
|
||||||
try:
|
try:
|
||||||
with open(filename, "r") as file:
|
with open(filename, "r") as file:
|
||||||
return json.load(file)
|
return json.load(file)
|
||||||
@ -21,11 +22,11 @@ def read_data_from_filename(filename: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
def read_dots_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DotPosition]:
|
||||||
data = read_data_from_filename(filename)
|
data = read_data_from_filename(filename)
|
||||||
return tantri.cli.input_files.rows_to_dots(data)
|
return tantri.cli.input_files.rows_to_dots(data)
|
||||||
|
|
||||||
|
|
||||||
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]:
|
def read_dipoles_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DipoleTO]:
|
||||||
data = read_data_from_filename(filename)
|
data = read_data_from_filename(filename)
|
||||||
return tantri.cli.input_files.rows_to_dipoles(data)
|
return tantri.cli.input_files.rows_to_dipoles(data)
|
||||||
|
13
tantri/cli/input_files/write_dipoles.py
Executable file
13
tantri/cli/input_files/write_dipoles.py
Executable file
@ -0,0 +1,13 @@
|
|||||||
|
import numpy
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyEncoder(json.JSONEncoder):
|
||||||
|
"""
|
||||||
|
Stolen from https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
|
||||||
|
"""
|
||||||
|
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, numpy.ndarray):
|
||||||
|
return obj.tolist()
|
||||||
|
return json.JSONEncoder.default(self, obj)
|
@ -1,6 +1,6 @@
|
|||||||
import numpy
|
import numpy
|
||||||
from typing import Sequence, Optional
|
from typing import Sequence, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, asdict
|
||||||
from tantri.dipoles.types import DipoleTO
|
from tantri.dipoles.types import DipoleTO
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
@ -43,6 +43,17 @@ class DipoleGenerationConfig:
|
|||||||
dipole_count: int
|
dipole_count: int
|
||||||
generation_seed: int
|
generation_seed: int
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# This allows us to transparently set this with a string, while providing early warning of a type error
|
||||||
|
self.orientation = Orientation(self.orientation)
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return_dict = asdict(self)
|
||||||
|
|
||||||
|
return_dict["orientation"] = return_dict["orientation"].value
|
||||||
|
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
def make_dipoles(
|
def make_dipoles(
|
||||||
config: DipoleGenerationConfig,
|
config: DipoleGenerationConfig,
|
||||||
@ -55,6 +66,7 @@ def make_dipoles(
|
|||||||
)
|
)
|
||||||
rng = numpy.random.default_rng(config.generation_seed)
|
rng = numpy.random.default_rng(config.generation_seed)
|
||||||
else:
|
else:
|
||||||
|
_logger.info("Using overridden rng, of unknown seed")
|
||||||
rng = rng_override
|
rng = rng_override
|
||||||
|
|
||||||
dipoles = []
|
dipoles = []
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import numpy
|
import numpy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
|
||||||
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
|
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
|
||||||
@ -12,3 +12,6 @@ class DipoleTO:
|
|||||||
|
|
||||||
# should be 1/tau up to some pis
|
# should be 1/tau up to some pis
|
||||||
w: float
|
w: float
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return asdict(self)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_deserialise_json_to_generation_config
|
# name: test_deserialise_json_to_generation_config
|
||||||
DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation='RANDOM', dipole_count=15, generation_seed=1234)
|
DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation=<Orientation.RANDOM: 'RANDOM'>, dipole_count=15, generation_seed=1234)
|
||||||
# ---
|
# ---
|
||||||
# name: test_serialise_generation_config_to_json
|
# name: test_serialise_generation_config_to_json
|
||||||
'''
|
'''
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from tantri.dipoles.generation import DipoleGenerationConfig, Orientation
|
from tantri.dipoles.generation import DipoleGenerationConfig, Orientation
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
|
|
||||||
def test_serialise_generation_config_to_json(snapshot):
|
def test_serialise_generation_config_to_json(snapshot):
|
||||||
@ -20,7 +19,7 @@ def test_serialise_generation_config_to_json(snapshot):
|
|||||||
generation_seed=1234,
|
generation_seed=1234,
|
||||||
)
|
)
|
||||||
|
|
||||||
config_json = json.dumps(dataclasses.asdict(config), indent="\t")
|
config_json = json.dumps(config.as_dict(), indent="\t")
|
||||||
assert config_json == snapshot
|
assert config_json == snapshot
|
||||||
|
|
||||||
|
|
||||||
@ -64,6 +63,6 @@ def test_serialise_deserialise_dipole_generation_config_back_and_forth():
|
|||||||
generation_seed=1234,
|
generation_seed=1234,
|
||||||
)
|
)
|
||||||
|
|
||||||
config_json = json.dumps(dataclasses.asdict(config), indent="\t")
|
config_json = json.dumps(config.as_dict(), indent="\t")
|
||||||
data = json.loads(config_json)
|
data = json.loads(config_json)
|
||||||
assert config == DipoleGenerationConfig(**data)
|
assert config == DipoleGenerationConfig(**data)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user