diff --git a/tantri/cli/__init__.py b/tantri/cli/__init__.py index 92ed83c..3f7ec7a 100755 --- a/tantri/cli/__init__.py +++ b/tantri/cli/__init__.py @@ -1,6 +1,12 @@ import click import logging import tantri +import numpy +import tantri.cli.input_files.write_dipoles +import tantri.cli.file_importer +import json +import tantri.dipoles.generation +import pathlib _logger = logging.getLogger(__name__) @@ -30,21 +36,72 @@ def cli(log, log_file): if log or (log_file is not None): # log file has been provided, let's log _set_up_logging(log_file) - _logger.info("cli") @cli.command() @click.option( "--dipoles-file", default="dipoles.json", - help="Filename containing json array of dipoles", + show_default=True, + type=click.Path(exists=True, path_type=pathlib.Path), + help="File with json array of dipoles", ) @click.option( - "--dots-file", default="dots.json", help="Filename containing json array of dots" + "--dots-file", + default="dots.json", + show_default=True, + type=click.Path(exists=True, path_type=pathlib.Path), + help="File with json array of dots", ) -def hello(dipoles_file, dots_file): +def read_files(dipoles_file, dots_file): + click.echo("in hello") _logger.info("in hello") _logger.info( f"Received parameters [dipoles_file: {dipoles_file}] and [dots_file: {dots_file}]" ) - click.echo("in hello") + dipoles = tantri.cli.file_importer.read_dipoles_json_file(dipoles_file) + _logger.info(dipoles) + + +@cli.command() +@click.argument( + "generation_config", + type=click.Path(exists=True, path_type=pathlib.Path), +) +@click.argument( + "output_file", + type=click.File("w"), +) +@click.option( + "--override-rng-seed", type=int, help="Seed to override the generation config spec." +) +def generate_dipoles(generation_config, output_file, override_rng_seed): + """Generate random dipoles as described by GENERATION_CONFIG and output to OUTPUT_FILE. + + GENERATION_CONFIG should be a JSON file that matches the appropriate spec, and OUTPUT_FILE will contain JSON formatted contents. + OUTPUT_FILE will be overwritten, if it exists. + + If the --override-rng-seed is set, it's better to keep logs of the generation! + """ + _logger.debug( + f"generate_dipoles was called, with config file {click.format_filename(generation_config)}" + ) + _logger.debug(f"override_rng_seed: [{override_rng_seed}]") + + with open(generation_config, "r") as config_file: + data = json.load(config_file) + config = tantri.dipoles.generation.DipoleGenerationConfig(**data) + + override_rng = None + if override_rng_seed is not None: + _logger.info(f"Overriding the rng with a new one with seed {override_rng_seed}") + override_rng = numpy.random.default_rng(override_rng_seed) + _logger.debug(f"generating dipoles with config {config}...") + generated = tantri.dipoles.generation.make_dipoles(config, override_rng) + + output_file.write( + json.dumps( + [g.as_dict() for g in generated], + cls=tantri.cli.input_files.write_dipoles.NumpyEncoder, + ) + ) diff --git a/tantri/cli/file_importer.py b/tantri/cli/file_importer.py index 22de59d..3617777 100755 --- a/tantri/cli/file_importer.py +++ b/tantri/cli/file_importer.py @@ -1,3 +1,4 @@ +import pathlib import json import logging import tantri.cli.input_files @@ -11,7 +12,7 @@ _logger = logging.getLogger(__name__) # TODO: if this ever matters, can improve file handling. -def read_data_from_filename(filename: str): +def read_data_from_filename(filename: pathlib.Path): try: with open(filename, "r") as file: return json.load(file) @@ -21,11 +22,11 @@ def read_data_from_filename(filename: str): ) -def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: +def read_dots_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DotPosition]: data = read_data_from_filename(filename) return tantri.cli.input_files.rows_to_dots(data) -def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DipoleTO]: +def read_dipoles_json_file(filename: pathlib.Path) -> Sequence[tantri.dipoles.DipoleTO]: data = read_data_from_filename(filename) return tantri.cli.input_files.rows_to_dipoles(data) diff --git a/tantri/cli/input_files/write_dipoles.py b/tantri/cli/input_files/write_dipoles.py new file mode 100755 index 0000000..aec7add --- /dev/null +++ b/tantri/cli/input_files/write_dipoles.py @@ -0,0 +1,13 @@ +import numpy +import json + + +class NumpyEncoder(json.JSONEncoder): + """ + Stolen from https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable + """ + + def default(self, obj): + if isinstance(obj, numpy.ndarray): + return obj.tolist() + return json.JSONEncoder.default(self, obj) diff --git a/tantri/dipoles/generation/__init__.py b/tantri/dipoles/generation/__init__.py index 3079c5a..76ecb4e 100755 --- a/tantri/dipoles/generation/__init__.py +++ b/tantri/dipoles/generation/__init__.py @@ -1,6 +1,6 @@ import numpy from typing import Sequence, Optional -from dataclasses import dataclass +from dataclasses import dataclass, asdict from tantri.dipoles.types import DipoleTO from enum import Enum import logging @@ -43,6 +43,17 @@ class DipoleGenerationConfig: dipole_count: int generation_seed: int + def __post_init__(self): + # This allows us to transparently set this with a string, while providing early warning of a type error + self.orientation = Orientation(self.orientation) + + def as_dict(self) -> dict: + return_dict = asdict(self) + + return_dict["orientation"] = return_dict["orientation"].value + + return return_dict + def make_dipoles( config: DipoleGenerationConfig, @@ -55,6 +66,7 @@ def make_dipoles( ) rng = numpy.random.default_rng(config.generation_seed) else: + _logger.info("Using overridden rng, of unknown seed") rng = rng_override dipoles = [] diff --git a/tantri/dipoles/types.py b/tantri/dipoles/types.py index 030a788..d2d9d0a 100755 --- a/tantri/dipoles/types.py +++ b/tantri/dipoles/types.py @@ -1,5 +1,5 @@ import numpy -from dataclasses import dataclass +from dataclasses import dataclass, asdict # Lazily just separating this from Dipole where there's additional cached stuff, this is just a thing @@ -12,3 +12,6 @@ class DipoleTO: # should be 1/tau up to some pis w: float + + def as_dict(self) -> dict: + return asdict(self) diff --git a/tests/dipoles/generation/__snapshots__/test_serialization_dipole_generation.ambr b/tests/dipoles/generation/__snapshots__/test_serialization_dipole_generation.ambr index 1daed18..842f2ac 100755 --- a/tests/dipoles/generation/__snapshots__/test_serialization_dipole_generation.ambr +++ b/tests/dipoles/generation/__snapshots__/test_serialization_dipole_generation.ambr @@ -1,6 +1,6 @@ # serializer version: 1 # name: test_deserialise_json_to_generation_config - DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation='RANDOM', dipole_count=15, generation_seed=1234) + DipoleGenerationConfig(x_min=-5, x_max=5, y_min=-4.2, y_max=39, z_min=1.2, z_max=80, mag=1000, w_log_min=-5.5, w_log_max=6.6, orientation=, dipole_count=15, generation_seed=1234) # --- # name: test_serialise_generation_config_to_json ''' diff --git a/tests/dipoles/generation/test_serialization_dipole_generation.py b/tests/dipoles/generation/test_serialization_dipole_generation.py index d310fcd..8b1d359 100755 --- a/tests/dipoles/generation/test_serialization_dipole_generation.py +++ b/tests/dipoles/generation/test_serialization_dipole_generation.py @@ -1,6 +1,5 @@ import json from tantri.dipoles.generation import DipoleGenerationConfig, Orientation -import dataclasses def test_serialise_generation_config_to_json(snapshot): @@ -20,7 +19,7 @@ def test_serialise_generation_config_to_json(snapshot): generation_seed=1234, ) - config_json = json.dumps(dataclasses.asdict(config), indent="\t") + config_json = json.dumps(config.as_dict(), indent="\t") assert config_json == snapshot @@ -64,6 +63,6 @@ def test_serialise_deserialise_dipole_generation_config_back_and_forth(): generation_seed=1234, ) - config_json = json.dumps(dataclasses.asdict(config), indent="\t") + config_json = json.dumps(config.as_dict(), indent="\t") data = json.loads(config_json) assert config == DipoleGenerationConfig(**data)