diff --git a/tantri/cli/file_importer.py b/tantri/cli/file_importer.py index 1fca1e3..ea6b4b7 100755 --- a/tantri/cli/file_importer.py +++ b/tantri/cli/file_importer.py @@ -1,5 +1,8 @@ import json import logging +import tantri.cli.input_files +from typing import Sequence +import tantri.dipoles _logger = logging.getLogger(__name__) @@ -8,8 +11,14 @@ _logger = logging.getLogger(__name__) # TODO: if this ever matters, can improve file handling. -def read_json_file(filename): +def read_data_from_filename(filename: str): try: - return json.load(filename) - except Exception: - _logger.exception(f"failed on reading filename {filename}") + with open(filename, "r") as file: + return json.load(file) + except Exception as e: + _logger.error(f"failed to read the file {filename}, raising and aborting", exc_info=e) + + +def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]: + data = read_data_from_filename(filename) + return tantri.cli.input_files.rows_to_dots(data) diff --git a/tantri/cli/input_files/__init__.py b/tantri/cli/input_files/__init__.py new file mode 100755 index 0000000..96e998a --- /dev/null +++ b/tantri/cli/input_files/__init__.py @@ -0,0 +1,5 @@ +from tantri.cli.input_files.read_dots import rows_to_dots + +__all__ = [ + "rows_to_dots" +] diff --git a/tantri/cli/input_files/json_dots.py b/tantri/cli/input_files/json_dots.py deleted file mode 100755 index 2ee43a1..0000000 --- a/tantri/cli/input_files/json_dots.py +++ /dev/null @@ -1 +0,0 @@ -import tantri \ No newline at end of file diff --git a/tantri/cli/input_files/read_dots.py b/tantri/cli/input_files/read_dots.py new file mode 100755 index 0000000..83a1159 --- /dev/null +++ b/tantri/cli/input_files/read_dots.py @@ -0,0 +1,16 @@ +import tantri.dipoles +import numpy +from typing import Sequence + + +def row_to_dot(input_dict: dict) -> tantri.dipoles.DotPosition: + r = input_dict["r"] + if len(r) != 3: + raise ValueError(f"r parameter in input_dict [{input_dict}] does not have length 3") + label = input_dict["label"] + + return tantri.dipoles.DotPosition(numpy.array(r), label) + + +def rows_to_dots(dot_dict_array: Sequence[dict]) -> Sequence[tantri.dipoles.DotPosition]: + return [row_to_dot(input_dict) for input_dict in dot_dict_array] diff --git a/tests/cli/__snapshots__/test_read_dots_json.ambr b/tests/cli/__snapshots__/test_read_dots_json.ambr new file mode 100755 index 0000000..9ee40bc --- /dev/null +++ b/tests/cli/__snapshots__/test_read_dots_json.ambr @@ -0,0 +1,10 @@ +# serializer version: 1 +# name: test_from_json_string + list([ + DotPosition(r=array([0. , 1.5, 2. ]), label='dot1'), + DotPosition(r=array([-5., 3., -1.]), label='dot2'), + ]) +# --- +# name: test_parse_one_dot + DotPosition(r=array([1.5, 0. , 1. ]), label='dot1') +# --- diff --git a/tests/cli/test_read_dots_json.py b/tests/cli/test_read_dots_json.py index e69de29..72f7b02 100755 --- a/tests/cli/test_read_dots_json.py +++ b/tests/cli/test_read_dots_json.py @@ -0,0 +1,63 @@ +import json +import tantri.cli.input_files.read_dots as dots +import pytest + + +def test_parse_one_dot_failures(): + + dict = { + "r": [1.5, 0, 1] + # missing label + } + + with pytest.raises(KeyError, match=r".*label.*"): + dots.row_to_dot(dict) + + dict = { + # missing r + "label": "label" + } + + with pytest.raises(KeyError, match=r".*r.*"): + dots.row_to_dot(dict) + + +def test_parse_one_dot(snapshot): + + dict = { + "r": [1.5, 0, 1], + "label": "dot1", + } + + transformed = dots.row_to_dot(dict) + assert transformed == snapshot + + +def test_dot_r_wrong_length(): + + dict = { + "r": [1.5, 0], + "label": "dot1", + } + + with pytest.raises(ValueError, match=r".*does not have length 3.*"): + dots.row_to_dot(dict) + + +def test_from_json_string(snapshot): + dots_json = """[ + { + "r": [0, 1.5, 2], + "label": "dot1" + }, + { + "r": [-5, 3, -1.0], + "label": "dot2" + } + ] + """ + + parsed = json.loads(dots_json) + transformed = dots.rows_to_dots(parsed) + + assert transformed == snapshot