feat: adds ability to parse dot information from json file

This commit is contained in:
Deepak Mallubhotla 2024-04-20 17:20:12 -05:00
parent 3b7dfd2462
commit 16dd254a87
Signed by: deepak
GPG Key ID: BEBAEBF28083E022
6 changed files with 107 additions and 5 deletions

View File

@ -1,5 +1,8 @@
import json import json
import logging import logging
import tantri.cli.input_files
from typing import Sequence
import tantri.dipoles
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -8,8 +11,14 @@ _logger = logging.getLogger(__name__)
# TODO: if this ever matters, can improve file handling. # TODO: if this ever matters, can improve file handling.
def read_json_file(filename): def read_data_from_filename(filename: str):
try: try:
return json.load(filename) with open(filename, "r") as file:
except Exception: return json.load(file)
_logger.exception(f"failed on reading filename {filename}") except Exception as e:
_logger.error(f"failed to read the file {filename}, raising and aborting", exc_info=e)
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
data = read_data_from_filename(filename)
return tantri.cli.input_files.rows_to_dots(data)

View File

@ -0,0 +1,5 @@
from tantri.cli.input_files.read_dots import rows_to_dots
__all__ = [
"rows_to_dots"
]

View File

@ -1 +0,0 @@
import tantri

View File

@ -0,0 +1,16 @@
import tantri.dipoles
import numpy
from typing import Sequence
def row_to_dot(input_dict: dict) -> tantri.dipoles.DotPosition:
r = input_dict["r"]
if len(r) != 3:
raise ValueError(f"r parameter in input_dict [{input_dict}] does not have length 3")
label = input_dict["label"]
return tantri.dipoles.DotPosition(numpy.array(r), label)
def rows_to_dots(dot_dict_array: Sequence[dict]) -> Sequence[tantri.dipoles.DotPosition]:
return [row_to_dot(input_dict) for input_dict in dot_dict_array]

View File

@ -0,0 +1,10 @@
# serializer version: 1
# name: test_from_json_string
list([
DotPosition(r=array([0. , 1.5, 2. ]), label='dot1'),
DotPosition(r=array([-5., 3., -1.]), label='dot2'),
])
# ---
# name: test_parse_one_dot
DotPosition(r=array([1.5, 0. , 1. ]), label='dot1')
# ---

View File

@ -0,0 +1,63 @@
import json
import tantri.cli.input_files.read_dots as dots
import pytest
def test_parse_one_dot_failures():
dict = {
"r": [1.5, 0, 1]
# missing label
}
with pytest.raises(KeyError, match=r".*label.*"):
dots.row_to_dot(dict)
dict = {
# missing r
"label": "label"
}
with pytest.raises(KeyError, match=r".*r.*"):
dots.row_to_dot(dict)
def test_parse_one_dot(snapshot):
dict = {
"r": [1.5, 0, 1],
"label": "dot1",
}
transformed = dots.row_to_dot(dict)
assert transformed == snapshot
def test_dot_r_wrong_length():
dict = {
"r": [1.5, 0],
"label": "dot1",
}
with pytest.raises(ValueError, match=r".*does not have length 3.*"):
dots.row_to_dot(dict)
def test_from_json_string(snapshot):
dots_json = """[
{
"r": [0, 1.5, 2],
"label": "dot1"
},
{
"r": [-5, 3, -1.0],
"label": "dot2"
}
]
"""
parsed = json.loads(dots_json)
transformed = dots.rows_to_dots(parsed)
assert transformed == snapshot