feat: adds dipole reading from json
This commit is contained in:
parent
16dd254a87
commit
c63d553c7b
@ -22,3 +22,8 @@ def read_data_from_filename(filename: str):
|
|||||||
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
def read_dots_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
||||||
data = read_data_from_filename(filename)
|
data = read_data_from_filename(filename)
|
||||||
return tantri.cli.input_files.rows_to_dots(data)
|
return tantri.cli.input_files.rows_to_dots(data)
|
||||||
|
|
||||||
|
|
||||||
|
def read_dipoles_json_file(filename: str) -> Sequence[tantri.dipoles.DotPosition]:
|
||||||
|
data = read_data_from_filename(filename)
|
||||||
|
return tantri.cli.input_files.rows_to_dots(data)
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
from tantri.cli.input_files.read_dots import rows_to_dots
|
from tantri.cli.input_files.read_dots import rows_to_dots
|
||||||
|
from tantri.cli.input_files.read_dipoles import rows_to_dipoles, DipoleTO
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"rows_to_dots"
|
"rows_to_dots",
|
||||||
|
"rows_to_dipoles",
|
||||||
|
"DipoleTO",
|
||||||
]
|
]
|
||||||
|
31
tantri/cli/input_files/read_dipoles.py
Executable file
31
tantri/cli/input_files/read_dipoles.py
Executable file
@ -0,0 +1,31 @@
|
|||||||
|
import numpy
|
||||||
|
from typing import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily just separating this from tantri.dipoles.Dipole where there's additional cached stuff, this is just a thing
|
||||||
|
# we can use as a DTO For dipole info.
|
||||||
|
@dataclass
|
||||||
|
class DipoleTO:
|
||||||
|
# assumed len 3
|
||||||
|
p: numpy.ndarray
|
||||||
|
s: numpy.ndarray
|
||||||
|
|
||||||
|
# should be 1/tau up to some pis
|
||||||
|
w: float
|
||||||
|
|
||||||
|
|
||||||
|
def row_to_dipole(input_dict: dict) -> DipoleTO:
|
||||||
|
p = input_dict["p"]
|
||||||
|
if len(p) != 3:
|
||||||
|
raise ValueError(f"p parameter in input_dict [{input_dict}] does not have length 3")
|
||||||
|
s = input_dict["s"]
|
||||||
|
if len(s) != 3:
|
||||||
|
raise ValueError(f"s parameter in input_dict [{input_dict}] does not have length 3")
|
||||||
|
w = input_dict["w"]
|
||||||
|
|
||||||
|
return DipoleTO(p, s, w)
|
||||||
|
|
||||||
|
|
||||||
|
def rows_to_dipoles(dot_dict_array: Sequence[dict]) -> Sequence[DipoleTO]:
|
||||||
|
return [row_to_dipole(input_dict) for input_dict in dot_dict_array]
|
10
tests/cli/__snapshots__/test_read_dipoles_json.ambr
Executable file
10
tests/cli/__snapshots__/test_read_dipoles_json.ambr
Executable file
@ -0,0 +1,10 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_dipoles_from_json_string
|
||||||
|
list([
|
||||||
|
DipoleTO(p=[0, 0, 1000], s=[-1.5, 0.0, 3.5], w=0.11),
|
||||||
|
DipoleTO(p=[1000, 5050, 1234], s=[-66, 323, 431], w=220),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_parse_one_dipole_happy_path
|
||||||
|
DipoleTO(p=[150, 15.0, -21], s=[1.5, 0, 1], w=1513)
|
||||||
|
# ---
|
87
tests/cli/test_read_dipoles_json.py
Executable file
87
tests/cli/test_read_dipoles_json.py
Executable file
@ -0,0 +1,87 @@
|
|||||||
|
import json
|
||||||
|
import tantri.cli.input_files.read_dipoles as read_dipoles
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_one_dipole_failures():
|
||||||
|
|
||||||
|
dict = {
|
||||||
|
"s": [1.5, 0, 1],
|
||||||
|
# missing p
|
||||||
|
"w": 1.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match=r".*p.*"):
|
||||||
|
read_dipoles.row_to_dipole(dict)
|
||||||
|
|
||||||
|
dict = {
|
||||||
|
# missing s
|
||||||
|
"p": [150, 15.0, -21],
|
||||||
|
"w": 1.5,
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match=r".*s.*"):
|
||||||
|
read_dipoles.row_to_dipole(dict)
|
||||||
|
|
||||||
|
dict = {
|
||||||
|
"s": [1.5, 0, 1],
|
||||||
|
"p": [150, 15.0, -21],
|
||||||
|
# missing w
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match=r".*w.*"):
|
||||||
|
read_dipoles.row_to_dipole(dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_one_dipole_happy_path(snapshot):
|
||||||
|
|
||||||
|
dict = {
|
||||||
|
"s": [1.5, 0, 1],
|
||||||
|
"p": [150, 15.0, -21],
|
||||||
|
"w": 1513,
|
||||||
|
}
|
||||||
|
|
||||||
|
transformed = read_dipoles.row_to_dipole(dict)
|
||||||
|
assert transformed == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_dot_r_wrong_length():
|
||||||
|
|
||||||
|
dict = {
|
||||||
|
"s": [1.5, 0, 1, 5],
|
||||||
|
"p": [150, 15.0, -21],
|
||||||
|
"w": 1513,
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*does not have length 3.*"):
|
||||||
|
read_dipoles.row_to_dipole(dict)
|
||||||
|
|
||||||
|
dict = {
|
||||||
|
"s": [1.5, 0, 1],
|
||||||
|
"p": [150, 15.0, -21, 235, 23],
|
||||||
|
"w": 1513,
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=r".*does not have length 3.*"):
|
||||||
|
read_dipoles.row_to_dipole(dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dipoles_from_json_string(snapshot):
|
||||||
|
read_dipoles_json = """[
|
||||||
|
{
|
||||||
|
"s": [-1.5, 0.0, 3.5],
|
||||||
|
"p": [0, 0, 1000],
|
||||||
|
"w": 0.11
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"s": [-66, 323, 431],
|
||||||
|
"p": [1000, 5050, 1234],
|
||||||
|
"w": 220
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
parsed = json.loads(read_dipoles_json)
|
||||||
|
transformed = read_dipoles.rows_to_dipoles(parsed)
|
||||||
|
|
||||||
|
assert transformed == snapshot
|
Loading…
x
Reference in New Issue
Block a user