From 7594f9b9b1691a67bdb26835b7bd637413f7a7cd Mon Sep 17 00:00:00 2001 From: Deepak Mallubhotla Date: Sun, 23 Feb 2025 19:02:13 -0600 Subject: [PATCH] feat: adds ability to parse pair measurements --- kalpaa/read_bin_csv.py | 65 +++++++++++--- .../__snapshots__/test_read_bin_csv.ambr | 90 +++++++++++++++++++ .../test_files/test_simple_pair_V.csv | 8 ++ tests/read_bin_csv/test_read_bin_csv.py | 12 +++ 4 files changed, 162 insertions(+), 13 deletions(-) create mode 100644 tests/read_bin_csv/test_files/test_simple_pair_V.csv diff --git a/kalpaa/read_bin_csv.py b/kalpaa/read_bin_csv.py index 4d2a8e5..791214e 100644 --- a/kalpaa/read_bin_csv.py +++ b/kalpaa/read_bin_csv.py @@ -334,6 +334,8 @@ def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]: class CSV_BinnedData: measurement_type: MeasurementTypeEnum single_dot_dict: typing.Dict[str, typing.Any] + pair_dot_dict: typing.Dict[typing.Tuple[str, str], typing.Any] + freqs: typing.Sequence[float] def read_bin_csv( @@ -363,9 +365,11 @@ def read_bin_csv( _logger.debug(f"Going to read frequencies from {freq_field=}") parsed_headers = {} + freq_list = [] aggregated_dict: typing.Dict[str, typing.Any] = { RETURNED_FREQUENCIES_KEY: [] } + pair_aggregated_dict: typing.Dict[typing.Tuple[str, str], typing.Any] = {} for field in remaining_fields: parsed_header = _parse_bin_header(field) @@ -374,17 +378,36 @@ def read_bin_csv( continue parsed_headers[field] = parsed_header - if parsed_header.dot_name not in aggregated_dict: - aggregated_dict[parsed_header.dot_name] = {} + # Get our dictionary structures set up by initialising empty dictionaries for each new field as we go + if parsed_header.pair: + if parsed_header.dot_name2 is None: + raise ValueError( + f"Pair measurement {field=} has no dot_name2, but it should" + ) + dot_names = (parsed_header.dot_name, parsed_header.dot_name2) + if dot_names not in pair_aggregated_dict: + pair_aggregated_dict[dot_names] = {} - if ( - parsed_header.summary_stat - not in aggregated_dict[parsed_header.dot_name] - ): - aggregated_dict[parsed_header.dot_name][ + if ( parsed_header.summary_stat - ] = [] + not in pair_aggregated_dict[dot_names] + ): + pair_aggregated_dict[dot_names][parsed_header.summary_stat] = [] + else: + if parsed_header.dot_name not in aggregated_dict: + aggregated_dict[parsed_header.dot_name] = {} + + if ( + parsed_header.summary_stat + not in aggregated_dict[parsed_header.dot_name] + ): + aggregated_dict[parsed_header.dot_name][ + parsed_header.summary_stat + ] = [] + + # Realistically we'll always have the same measurement type, but this warning may help us catch out cases where this didn't happen correctly + # We should only need to set it once, so the fact we keep checking is more about catching errors than anything else if measurement_type is not None: if measurement_type != parsed_header.measurement_type: _logger.warning( @@ -397,14 +420,27 @@ def read_bin_csv( for row in reader: # _logger.debug(f"Got {row=}") + freq_list.append(float(row[freq_field].strip())) + # don't need to set, but keep for legacy aggregated_dict[RETURNED_FREQUENCIES_KEY].append( float(row[freq_field].strip()) ) for field, parsed_header in parsed_headers.items(): - value = float(row[field].strip()) - aggregated_dict[parsed_header.dot_name][ - parsed_header.summary_stat - ].append(value) + if parsed_header.pair: + if parsed_header.dot_name2 is None: + raise ValueError( + f"Pair measurement {field=} has no dot_name2, but it should" + ) + value = float(row[field].strip()) + dot_names = (parsed_header.dot_name, parsed_header.dot_name2) + pair_aggregated_dict[dot_names][ + parsed_header.summary_stat + ].append(value) + else: + value = float(row[field].strip()) + aggregated_dict[parsed_header.dot_name][ + parsed_header.summary_stat + ].append(value) if measurement_type is None: raise ValueError( @@ -412,7 +448,10 @@ def read_bin_csv( ) return CSV_BinnedData( - measurement_type=measurement_type, single_dot_dict=aggregated_dict + measurement_type=measurement_type, + single_dot_dict=aggregated_dict, + freqs=freq_list, + pair_dot_dict=pair_aggregated_dict, ) except Exception as e: _logger.error( diff --git a/tests/read_bin_csv/__snapshots__/test_read_bin_csv.ambr b/tests/read_bin_csv/__snapshots__/test_read_bin_csv.ambr index 0add46d..e2829bc 100644 --- a/tests/read_bin_csv/__snapshots__/test_read_bin_csv.ambr +++ b/tests/read_bin_csv/__snapshots__/test_read_bin_csv.ambr @@ -341,3 +341,93 @@ }), ]) # --- +# name: test_read_csv_with_pairs + dict({ + 'freqs': list([ + 0.0125, + 0.024999999999999998, + 0.045, + 0.0775, + 0.1375, + 0.24749999999999997, + 0.41, + ]), + 'measurement_type': , + 'pair_dot_dict': dict({ + tuple( + 'dot1', + 'dot2', + ): dict({ + 'mean': list([ + 3.15, + 3.13, + 3.0, + 2.7, + 0.1, + 0.25, + 0.002, + ]), + 'stdev': list([ + 0.02, + 0.015, + 0.8, + 1.5, + 0.3, + 0.01, + 0.1, + ]), + }), + }), + 'single_dot_dict': dict({ + 'dot1': dict({ + 'mean': list([ + 10.638916947949246, + 4.808960230987057, + 1.8458074293863327, + 1.0990901962765007, + 0.6425140116757488, + 0.4844873135633905, + 0.448232552, + ]), + 'stdev': list([ + 5.688165841523548, + 1.5555855859097745, + 0.5112103163244077, + 0.37605535, + 0.1411676088216461, + 0.11795510686231957, + 0.081977941, + ]), + }), + 'dot2': dict({ + 'mean': list([ + 14.780311491085596, + 7.413101036489984, + 3.081527317039941, + 1.198719434472466, + 0.44608783800009594, + 0.16750150967807267, + 0.095604286, + ]), + 'stdev': list([ + 5.085761250807487, + 2.7753690312876014, + 1.3009911753215875, + 0.3361763625979774, + 0.18042157503806078, + 0.05820931, + 0.022567042968929727, + ]), + }), + 'frequencies': list([ + 0.0125, + 0.024999999999999998, + 0.045, + 0.0775, + 0.1375, + 0.24749999999999997, + 0.41, + ]), + }), + }) +# --- diff --git a/tests/read_bin_csv/test_files/test_simple_pair_V.csv b/tests/read_bin_csv/test_files/test_simple_pair_V.csv new file mode 100644 index 0000000..cf362f3 --- /dev/null +++ b/tests/read_bin_csv/test_files/test_simple_pair_V.csv @@ -0,0 +1,8 @@ +mean bin f (Hz), APSD_V_dot1_mean, APSD_V_dot1_stdev, APSD_V_dot2_mean, APSD_V_dot2_stdev,CPSD_phase_V_dot1_dot2_mean,CPSD_phase_V_dot1_dot2_stdev +0.0125, 10.638916947949246, 5.688165841523548, 14.780311491085596, 5.085761250807487,3.15,0.02 +0.024999999999999998, 4.808960230987057, 1.5555855859097745, 7.413101036489984, 2.7753690312876014,3.13,0.015 +0.045, 1.8458074293863327, 0.5112103163244077, 3.081527317039941, 1.3009911753215875,3,0.8 +0.0775, 1.0990901962765007,0.37605535, 1.198719434472466, 0.3361763625979774,2.7,1.5 +0.1375, 0.6425140116757488, 0.1411676088216461, 0.44608783800009594, 0.18042157503806078,0.1,0.3 +0.24749999999999997, 0.4844873135633905, 0.11795510686231957, 0.16750150967807267,0.05820931,0.25,0.01 +0.41,0.448232552,0.081977941,0.095604286, 0.022567042968929727,0.002,0.1 diff --git a/tests/read_bin_csv/test_read_bin_csv.py b/tests/read_bin_csv/test_read_bin_csv.py index 84dcd8e..aa4a526 100644 --- a/tests/read_bin_csv/test_read_bin_csv.py +++ b/tests/read_bin_csv/test_read_bin_csv.py @@ -131,3 +131,15 @@ def test_binned_data_dot_measurement_costs(snapshot): } assert result_dict == snapshot + + +def test_read_csv_with_pairs(snapshot): + + # dots_json = TEST_DATA_DIR / "dots.json" + # v_csv_file = TEST_DATA_DIR / "test_binned_apsd_V.csv" + # ex_csv_file = TEST_DATA_DIR / "test_binned_apsd_Ex.csv" + pair_data_csv = TEST_DATA_DIR / "test_simple_pair_V.csv" + + actual_read = kalpaa.read_bin_csv.read_bin_csv(pair_data_csv) + + assert dataclasses.asdict(actual_read) == snapshot