feat: allows generation of random dipoles from a standard config

This commit is contained in:
Deepak Mallubhotla 2024-04-20 22:04:03 -05:00
parent 0a51d12cc0
commit b4b25974c9
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
7 changed files with 99 additions and 14 deletions

View File

@ -1,6 +1,12 @@
import click import click
import logging import logging
import tantri import tantri
import numpy
import tantri.cli.input_files.write_dipoles
import tantri.cli.file_importer
import json
import tantri.dipoles.generation
import pathlib
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -30,21 +36,72 @@ def cli(log, log_file):
if log or (log_file is not None): if log or (log_file is not None):
# log file has been provided, let's log # log file has been provided, let's log
_set_up_logging(log_file) _set_up_logging(log_file)
_logger.info("cli")
@cli.command() @cli.command()
@click.option( @click.option(
"--dipoles-file", "--dipoles-file",
default="dipoles.json", default="dipoles.json",
help="Filename containing json array of dipoles", show_default=True,
type=click.Path(exists=True, path_type=pathlib.Path),
help="File with json array of dipoles",
) )
@click.option( @click.option(
"--dots-file", default="dots.json", help="Filename containing json array of dots" "--dots-file",
default="dots.json",
show_default=True,
type=click.Path(exists=True, path_type=pathlib.Path),
help="File with json array of dots",
) )
def hello(dipoles_file, dots_file): def read_files(dipoles_file, dots_file):
click.echo("in hello")
_logger.info("in hello") _logger.info("in hello")
_logger.info( _logger.info(
f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]" f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]"
) )
click.echo("in hello") dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file)
_logger.info(dipoles)
@cli.command()
@click.argument(
"generation_config",
type=click.Path(exists=True, path_type=pathlib.Path),
)
@click.argument(
"output_file",
type=click.File("w"),
)
@click.option(
"--override-rng-seed", type=int, help="Seed to override the generation config spec."
)
def generate_dipoles(generation_config, output_file, override_rng_seed):
"""Generate random dipoles as described by GENERATION_CONFIG and output to OUTPUT_FILE.
GENERATION_CONFIG should be a JSON file that matches the appropriate spec, and OUTPUT_FILE will contain JSON formatted contents.
OUTPUT_FILE will be overwritten, if it exists.
If the --override-rng-seed is set, it's better to keep logs of the generation!
"""
_logger.debug(
f"generate_dipoles was called, with config file {click.format_filename(generation_config)}"
)
_logger.debug(f"override_rng_seed: [{override_rng_seed}]")
with open(generation_config, "r") as config_file:
data = json.load(config_file)
config = tantri.dipoles.generation.DipoleGenerationConfig(**data)
override_rng = None
if override_rng_seed is not None:
_logger.info(f"Overriding the rng with a new one with seed {override_rng_seed}")
override_rng = numpy.random.default_rng(override_rng_seed)
_logger.debug(f"generating dipoles with config {config}...")
generated = tantri.dipoles.generation.make_dipoles(config, override_rng)
output_file.write(
json.dumps(
[g.as_dict() for g in generated],
cls=tantri.cli.input_files.write_dipoles.NumpyEncoder,
)
)

View File

@ -1,3 +1,4 @@
import pathlib
import json import json
import logging import logging
import tantri.cli.input_files import tantri.cli.input_files
@ -11,7 +12,7 @@ _logger = logging.getLogger(__name__)
# TODO: if this ever matters, can improve file handling. # TODO: if this ever matters, can improve file handling.
def read_data_from_filename(filename: str): def read_data_from_filename(filename: pathlib.Path):
try: try:
with open(filename, "r") as file: with open(filename, "r") as file:
return json.load(file) return json.load(file)
@ -21,11 +22,11 @@ def read_data_from_filename(filename: str):
) )
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: def read_dots_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DotPosition]:
data = read_data_from_filename(filename) data = read_data_from_filename(filename)
return tantri.cli.input_files.rows_to_dots(data) return tantri.cli.input_files.rows_to_dots(data)
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]: def read_dipoles_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DipoleTO]:
data = read_data_from_filename(filename) data = read_data_from_filename(filename)
return tantri.cli.input_files.rows_to_dipoles(data) return tantri.cli.input_files.rows_to_dipoles(data)

View File

@ -0,0 +1,13 @@
import numpy
import json
class NumpyEncoder(json.JSONEncoder):
"""
Stolen from https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
"""
def default(self, obj):
if isinstance(obj, numpy.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

View File

@ -1,6 +1,6 @@
import numpy import numpy
from typing import Sequence, Optional from typing import Sequence, Optional
from dataclasses import dataclass from dataclasses import dataclass, asdict
from tantri.dipoles.types import DipoleTO from tantri.dipoles.types import DipoleTO
from enum import Enum from enum import Enum
import logging import logging
@ -43,6 +43,17 @@ class DipoleGenerationConfig:
dipole_count: int dipole_count: int
generation_seed: int generation_seed: int
def __post_init__(self):
# This allows us to transparently set this with a string, while providing early warning of a type error
self.orientation = Orientation(self.orientation)
def as_dict(self) -> dict:
return_dict = asdict(self)
return_dict["orientation"] = return_dict["orientation"].value
return return_dict
def make_dipoles( def make_dipoles(
config: DipoleGenerationConfig, config: DipoleGenerationConfig,
@ -55,6 +66,7 @@ def make_dipoles(
) )
rng = numpy.random.default_rng(config.generation_seed) rng = numpy.random.default_rng(config.generation_seed)
else: else:
_logger.info("Using overridden rng, of unknown seed")
rng = rng_override rng = rng_override
dipoles = [] dipoles = []

View File

@ -1,5 +1,5 @@
import numpy import numpy
from dataclasses import dataclass from dataclasses import dataclass, asdict
# Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing # Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing
@ -12,3 +12,6 @@ class DipoleTO:
# should be 1/tau up to some pis # should be 1/tau up to some pis
w: float w: float
def as_dict(self) -> dict:
return asdict(self)

View File

@ -1,6 +1,6 @@
# serializer version: 1 # serializer version: 1
# name: test_deserialise_json_to_generation_config # name: test_deserialise_json_to_generation_config
DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation='RANDOM', dipole_count=15, generation_seed=1234) DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation=<Orientation.RANDOM: 'RANDOM'>, dipole_count=15, generation_seed=1234)
# --- # ---
# name: test_serialise_generation_config_to_json # name: test_serialise_generation_config_to_json
''' '''

View File

@ -1,6 +1,5 @@
import json import json
from tantri.dipoles.generation import DipoleGenerationConfig, Orientation from tantri.dipoles.generation import DipoleGenerationConfig, Orientation
import dataclasses
def test_serialise_generation_config_to_json(snapshot): def test_serialise_generation_config_to_json(snapshot):
@ -20,7 +19,7 @@ def test_serialise_generation_config_to_json(snapshot):
generation_seed=1234, generation_seed=1234,
) )
config_json = json.dumps(dataclasses.asdict(config), indent="\t") config_json = json.dumps(config.as_dict(), indent="\t")
assert config_json == snapshot assert config_json == snapshot
@ -64,6 +63,6 @@ def test_serialise_deserialise_dipole_generation_config_back_and_forth():
generation_seed=1234, generation_seed=1234,
) )
config_json = json.dumps(dataclasses.asdict(config), indent="\t") config_json = json.dumps(config.as_dict(), indent="\t")
data = json.loads(config_json) data = json.loads(config_json)
assert config == DipoleGenerationConfig(**data) assert config == DipoleGenerationConfig(**data)