diff --git a/tantri/cli/file_importer.py b/tantri/cli/file_importer.py index ea6b4b7..9f12c6c 100755 --- a/tantri/cli/file_importer.py +++ b/tantri/cli/file_importer.py @@ -22,3 +22,8 @@ def read_data_from_filename(filename: str): def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: data = read_data_from_filename(filename) return tantri.cli.input_files.rows_to_dots(data) + + +def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: + data = read_data_from_filename(filename) + return tantri.cli.input_files.rows_to_dots(data) diff --git a/tantri/cli/input_files/__init__.py b/tantri/cli/input_files/__init__.py index 96e998a..8c38a99 100755 --- a/tantri/cli/input_files/__init__.py +++ b/tantri/cli/input_files/__init__.py @@ -1,5 +1,8 @@ from tantri.cli.input_files.read_dots import rows_to_dots +from tantri.cli.input_files.read_dipoles import rows_to_dipoles, DipoleTO __all__ = [ - "rows_to_dots" + "rows_to_dots", + "rows_to_dipoles", + "DipoleTO", ] diff --git a/tantri/cli/input_files/read_dipoles.py b/tantri/cli/input_files/read_dipoles.py new file mode 100755 index 0000000..aa24c89 --- /dev/null +++ b/tantri/cli/input_files/read_dipoles.py @@ -0,0 +1,31 @@ +import numpy +from typing import Sequence +from dataclasses import dataclass + + +# Lazily just separating this from tantri.dipoles.Dipole where there's additional cached stuff, this is just a thing +# we can use as a DTO For dipole info. +@dataclass +class DipoleTO: + # assumed len 3 + p: numpy.ndarray + s: numpy.ndarray + + # should be 1/tau up to some pis + w: float + + +def row_to_dipole(input_dict: dict) -> DipoleTO: + p = input_dict["p"] + if len(p) != 3: + raise ValueError(f"p parameter in input_dict [{input_dict}] does not have length 3") + s = input_dict["s"] + if len(s) != 3: + raise ValueError(f"s parameter in input_dict [{input_dict}] does not have length 3") + w = input_dict["w"] + + return DipoleTO(p, s, w) + + +def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]: + return [row_to_dipole(input_dict) for input_dict in dot_dict_array] diff --git a/tests/cli/__snapshots__/test_read_dipoles_json.ambr b/tests/cli/__snapshots__/test_read_dipoles_json.ambr new file mode 100755 index 0000000..05c4816 --- /dev/null +++ b/tests/cli/__snapshots__/test_read_dipoles_json.ambr @@ -0,0 +1,10 @@ +# serializer version: 1 +# name: test_dipoles_from_json_string + list([ + DipoleTO(p=[0, 0, 1000], s=[-1.5, 0.0, 3.5], w=0.11), + DipoleTO(p=[1000, 5050, 1234], s=[-66, 323, 431], w=220), + ]) +# --- +# name: test_parse_one_dipole_happy_path + DipoleTO(p=[150, 15.0, -21], s=[1.5, 0, 1], w=1513) +# --- diff --git a/tests/cli/test_read_dipoles_json.py b/tests/cli/test_read_dipoles_json.py new file mode 100755 index 0000000..553e3a2 --- /dev/null +++ b/tests/cli/test_read_dipoles_json.py @@ -0,0 +1,87 @@ +import json +import tantri.cli.input_files.read_dipoles as read_dipoles +import pytest + + +def test_parse_one_dipole_failures(): + + dict = { + "s": [1.5, 0, 1], + # missing p + "w": 1.5, + } + + with pytest.raises(KeyError, match=r".*p.*"): + read_dipoles.row_to_dipole(dict) + + dict = { + # missing s + "p": [150, 15.0, -21], + "w": 1.5, + } + + with pytest.raises(KeyError, match=r".*s.*"): + read_dipoles.row_to_dipole(dict) + + dict = { + "s": [1.5, 0, 1], + "p": [150, 15.0, -21], + # missing w + } + + with pytest.raises(KeyError, match=r".*w.*"): + read_dipoles.row_to_dipole(dict) + + +def test_parse_one_dipole_happy_path(snapshot): + + dict = { + "s": [1.5, 0, 1], + "p": [150, 15.0, -21], + "w": 1513, + } + + transformed = read_dipoles.row_to_dipole(dict) + assert transformed == snapshot + + +def test_dot_r_wrong_length(): + + dict = { + "s": [1.5, 0, 1, 5], + "p": [150, 15.0, -21], + "w": 1513, + } + + with pytest.raises(ValueError, match=r".*does not have length 3.*"): + read_dipoles.row_to_dipole(dict) + + dict = { + "s": [1.5, 0, 1], + "p": [150, 15.0, -21, 235, 23], + "w": 1513, + } + + with pytest.raises(ValueError, match=r".*does not have length 3.*"): + read_dipoles.row_to_dipole(dict) + + +def test_dipoles_from_json_string(snapshot): + read_dipoles_json = """[ + { + "s": [-1.5, 0.0, 3.5], + "p": [0, 0, 1000], + "w": 0.11 + }, + { + "s": [-66, 323, 431], + "p": [1000, 5050, 1234], + "w": 220 + } + ] + """ + + parsed = json.loads(read_dipoles_json) + transformed = read_dipoles.rows_to_dipoles(parsed) + + assert transformed == snapshot