feat: allows generation of random dipoles from a standard config

This commit is contained in:
Deepak Mallubhotla 2024-04-20 22:04:03 -05:00
parent 0a51d12cc0
commit b4b25974c9
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
7 changed files with 99 additions and 14 deletions

View File

@ -1,6 +1,12 @@
import click
import logging
import tantri
import numpy
import tantri.cli.input_files.write_dipoles
import tantri.cli.file_importer
import json
import tantri.dipoles.generation
import pathlib
_logger = logging.getLogger(__name__)
@ -30,21 +36,72 @@ def cli(log, log_file):
if log or (log_file is not None):
# log file has been provided, let's log
_set_up_logging(log_file)
_logger.info("cli")
@cli.command()
@click.option(
"--dipoles-file",
default="dipoles.json",
help="Filename containing json array of dipoles",
show_default=True,
type=click.Path(exists=True, path_type=pathlib.Path),
help="File with json array of dipoles",
)
@click.option(
"--dots-file", default="dots.json", help="Filename containing json array of dots"
"--dots-file",
default="dots.json",
show_default=True,
type=click.Path(exists=True, path_type=pathlib.Path),
help="File with json array of dots",
)
def hello(dipoles_file, dots_file):
def read_files(dipoles_file, dots_file):
click.echo("in hello")
_logger.info("in hello")
_logger.info(
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
)
click.echo("in hello")
dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
_logger.info(dipoles)
@cli.command()
@click.argument(
"generation_config",
type=click.Path(exists=True, path_type=pathlib.Path),
)
@click.argument(
"output_file",
type=click.File("w"),
)
@click.option(
"--override-rng-seed", type=int, help="Seed to override the generation config spec."
)
def generate_dipoles(generation_config, output_file, override_rng_seed):
"""Generate random dipoles as described by GENERATION_CONFIG and output to OUTPUT_FILE.
GENERATION_CONFIG should be a JSON file that matches the appropriate spec, and OUTPUT_FILE will contain JSON formatted contents.
OUTPUT_FILE will be overwritten, if it exists.
If the --override-rng-seed is set, it's better to keep logs of the generation!
"""
_logger.debug(
f"generate_dipoles was called, with config file {click.format_filename(generation_config)}"
)
_logger.debug(f"override_rng_seed: [{override_rng_seed}]")
with open(generation_config, "r") as config_file:
data = json.load(config_file)
config = tantri.dipoles.generation.DipoleGenerationConfig(**data)
override_rng = None
if override_rng_seed is not None:
_logger.info(f"Overriding the rng with a new one with seed {override_rng_seed}")
override_rng = numpy.random.default_rng(override_rng_seed)
_logger.debug(f"generating dipoles with config {config}...")
generated = tantri.dipoles.generation.make_dipoles(config, override_rng)
output_file.write(
json.dumps(
[g.as_dict() for g in generated],
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
)
)

View File

@ -1,3 +1,4 @@
import pathlib
import json
import logging
import tantri.cli.input_files
@ -11,7 +12,7 @@ _logger = logging.getLogger(__name__)
# TODO: if this ever matters, can improve file handling.
def read_data_from_filename(filename: str):
def read_data_from_filename(filename: pathlib.Path):
try:
with open(filename, "r") as file:
return json.load(file)
@ -21,11 +22,11 @@ def read_data_from_filename(filename: str):
)
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
def read_dots_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DotPosition]:
data = read_data_from_filename(filename)
return tantri.cli.input_files.rows_to_dots(data)
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]:
def read_dipoles_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DipoleTO]:
data = read_data_from_filename(filename)
return tantri.cli.input_files.rows_to_dipoles(data)

View File

@ -0,0 +1,13 @@
import numpy
import json
class NumpyEncoder(json.JSONEncoder):
"""
Stolen from https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
"""
def default(self, obj):
if isinstance(obj, numpy.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

View File

@ -1,6 +1,6 @@
import numpy
from typing import Sequence, Optional
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from tantri.dipoles.types import DipoleTO
from enum import Enum
import logging
@ -43,6 +43,17 @@ class DipoleGenerationConfig:
dipole_count: int
generation_seed: int
def __post_init__(self):
# This allows us to transparently set this with a string, while providing early warning of a type error
self.orientation = Orientation(self.orientation)
def as_dict(self) -> dict:
return_dict = asdict(self)
return_dict["orientation"] = return_dict["orientation"].value
return return_dict
def make_dipoles(
config: DipoleGenerationConfig,
@ -55,6 +66,7 @@ def make_dipoles(
)
rng = numpy.random.default_rng(config.generation_seed)
else:
_logger.info("Using overridden rng, of unknown seed")
rng = rng_override
dipoles = []

View File

@ -1,5 +1,5 @@
import numpy
from dataclasses import dataclass
from dataclasses import dataclass, asdict
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
@ -12,3 +12,6 @@ class DipoleTO:
# should be 1/tau up to some pis
w: float
def as_dict(self) -> dict:
return asdict(self)

View File

@ -1,6 +1,6 @@
# serializer version: 1
# name: test_deserialise_json_to_generation_config
DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation='RANDOM', dipole_count=15, generation_seed=1234)
DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation=<Orientation.RANDOM: 'RANDOM'>, dipole_count=15, generation_seed=1234)
# ---
# name: test_serialise_generation_config_to_json
'''

View File

@ -1,6 +1,5 @@
import json
from tantri.dipoles.generation import DipoleGenerationConfig, Orientation
import dataclasses
def test_serialise_generation_config_to_json(snapshot):
@ -20,7 +19,7 @@ def test_serialise_generation_config_to_json(snapshot):
generation_seed=1234,
)
config_json = json.dumps(dataclasses.asdict(config), indent="\t")
config_json = json.dumps(config.as_dict(), indent="\t")
assert config_json == snapshot
@ -64,6 +63,6 @@ def test_serialise_deserialise_dipole_generation_config_back_and_forth():
generation_seed=1234,
)
config_json = json.dumps(dataclasses.asdict(config), indent="\t")
config_json = json.dumps(config.as_dict(), indent="\t")
data = json.loads(config_json)
assert config == DipoleGenerationConfig(**data)