Compare commits
40 Commits
Author | SHA1 | Date | |
---|---|---|---|
72c76ec95d | |||
4d2d865e9d | |||
c331bc057f | |||
b685c63efe | |||
f2f326dcfe | |||
96742769be | |||
747fbd4bc2 | |||
ae652779f5 | |||
50c7bc5f0c | |||
a9f40c665e | |||
c99d1104e0 | |||
d4631dbf59 | |||
39d11a59b3 | |||
5ad98f01fe | |||
31ba558364 | |||
3fe55dbb67 | |||
6d24f96b0f | |||
fa40ceb00a | |||
7594f9b9b1 | |||
785746049a | |||
e83706b2b1 | |||
42dddcae02 | |||
1e840a8c32 | |||
9bda1a48c6 | |||
8e122bad15 | |||
969001e864 | |||
55c4476d75 | |||
d738451482 | |||
541a3af3cd | |||
8eec78465f | |||
d591f3e367 | |||
35cb4405fb | |||
5a63b2cb13 | |||
6525caa9d9 | |||
2d5a82926e | |||
aea26dfa16 | |||
e3d3625c92 | |||
30c760f7ec | |||
007e6e0f6d | |||
176987f40c |
11
.gitignore
vendored
11
.gitignore
vendored
@ -142,7 +142,7 @@ dmypy.json
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
*.csv
|
||||
# *.csv
|
||||
|
||||
local_scripts/
|
||||
|
||||
@ -152,6 +152,9 @@ logs/
|
||||
out/
|
||||
*.xlsx
|
||||
|
||||
kalpa.toml
|
||||
indexes.json
|
||||
dots.json
|
||||
/kalpa.toml
|
||||
/indexes.json
|
||||
/dots.json
|
||||
|
||||
# nix
|
||||
result
|
||||
|
37
CHANGELOG.md
37
CHANGELOG.md
@ -2,6 +2,43 @@
|
||||
|
||||
All notable changes to this project will be documented in this file. See [standard-version](https://github.com/conventional-changelog/standard-version) for commit guidelines.
|
||||
|
||||
## [1.2.0](https://gitea.deepak.science:2222/physics/kalpa/compare/1.1.0...1.2.0) (2025-03-03)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* adds ability to specify which dots are used manually, and better handles collating and coalescing results ([9674276](https://gitea.deepak.science:2222/physics/kalpa/commit/96742769bedad928890a27153e7c0952a6fc1cdb))
|
||||
|
||||
## [1.1.0](https://gitea.deepak.science:2222/physics/kalpa/compare/1.0.1...1.1.0) (2025-03-02)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add flag to disable logging to stderr ([39d11a5](https://gitea.deepak.science:2222/physics/kalpa/commit/39d11a59b3bec87a52950b8c525d3e141981613e))
|
||||
* adds ability to do phase cost functions ([fa40ceb](https://gitea.deepak.science:2222/physics/kalpa/commit/fa40ceb00a2734d5116a3475df5ec3d0f57ba08d))
|
||||
* adds ability to manually specify measurement files ([3fe55db](https://gitea.deepak.science:2222/physics/kalpa/commit/3fe55dbb671d24391197a2f6db8bbbfcab174577))
|
||||
* adds ability to override geometry info from config file ([6d24f96](https://gitea.deepak.science:2222/physics/kalpa/commit/6d24f96b0f71842f918e350e0232c5b6ae945d33))
|
||||
* adds ability to parse pair measurement headers, though without processing yet ([8e122ba](https://gitea.deepak.science:2222/physics/kalpa/commit/8e122bad153d7bff78d2e7b39f52aa1455ecb356))
|
||||
* adds ability to parse pair measurements ([7594f9b](https://gitea.deepak.science:2222/physics/kalpa/commit/7594f9b9b1691a67bdb26835b7bd637413f7a7cd))
|
||||
* adds new type for measurement groups to facilitate future refactors, technically breaking but not for what public interface should be ([42dddca](https://gitea.deepak.science:2222/physics/kalpa/commit/42dddcae0214fc576120531e3a15a7f9995dd126))
|
||||
* adds resumptions and ability to check completed work ([ae65277](https://gitea.deepak.science:2222/physics/kalpa/commit/ae652779f56ccdec9c81d637dccf92b26e60d44f))
|
||||
* parse cpsd type from header ([9bda1a4](https://gitea.deepak.science:2222/physics/kalpa/commit/9bda1a48c6b3acce29983a37c0099ead94550021))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* better handling of files in stages and remove crufty code ([d4631db](https://gitea.deepak.science:2222/physics/kalpa/commit/d4631dbf595d7d06855cdb3123dc302342153b25))
|
||||
* better inference with overrides ([31ba558](https://gitea.deepak.science:2222/physics/kalpa/commit/31ba558364a9be25a52cf216204f99c23240bc14))
|
||||
* fixes for override measurement file finding ([5ad98f0](https://gitea.deepak.science:2222/physics/kalpa/commit/5ad98f01feefe20af7674c2be4e93762f035f444))
|
||||
* fixes unused import error ([c99d110](https://gitea.deepak.science:2222/physics/kalpa/commit/c99d1104e0e5c9bdb22d93398eba96fc6d1a193f))
|
||||
|
||||
### [1.0.1](https://gitea.deepak.science:2222/physics/kalpa/compare/1.0.0...1.0.1) (2025-02-22)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* add label to logs, and some minor formatting ([007e6e0](https://gitea.deepak.science:2222/physics/kalpa/commit/007e6e0f6dc1660c9088927f4c750877194027e1))
|
||||
|
||||
## 1.0.0 (2025-02-22)
|
||||
|
||||
|
||||
|
71
flake.nix
71
flake.nix
@ -13,28 +13,66 @@
|
||||
let
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
poetry2nix = poetry2nixSrc.lib.mkPoetry2Nix { inherit pkgs; };
|
||||
kalpaaApp = poetry2nix.mkPoetryApplication {
|
||||
projectDir = self;
|
||||
python = pkgs.python39;
|
||||
preferWheels = true;
|
||||
};
|
||||
kalpaaEnv = poetry2nix.mkPoetryEnv {
|
||||
projectDir = self;
|
||||
python = pkgs.python39;
|
||||
preferWheels = true;
|
||||
overrides = poetry2nix.overrides.withDefaults (self: super: {
|
||||
});
|
||||
};
|
||||
kalpaa-docker-image = pkgs.dockerTools.buildLayeredImage {
|
||||
name = "kalpaa";
|
||||
tag = "latest";
|
||||
|
||||
|
||||
contents = [
|
||||
|
||||
# some stuff that dockertools provides?
|
||||
# pkgs.dockerTools.usrBinEnv
|
||||
# pkgs.dockerTools.binSh
|
||||
# pkgs.dockerTools.caCertificates
|
||||
# pkgs.dockerTools.fakeNss
|
||||
|
||||
pkgs.bash
|
||||
pkgs.coreutils
|
||||
# pkgs.cacert
|
||||
# pkgs.gnutar
|
||||
# pkgs.gzip
|
||||
# pkgs.gnused
|
||||
# pkgs.gnugrep
|
||||
pkgs.uv
|
||||
kalpaaApp
|
||||
|
||||
];
|
||||
|
||||
config = {
|
||||
Cmd = [ "/bin/bash" ];
|
||||
Env = [
|
||||
"PATH=/bin"
|
||||
];
|
||||
WorkingDir = "/workspace";
|
||||
};
|
||||
};
|
||||
in {
|
||||
packages = {
|
||||
kalpaApp = poetry2nix.mkPoetryApplication {
|
||||
projectDir = self;
|
||||
python = pkgs.python39;
|
||||
preferWheels = true;
|
||||
};
|
||||
kalpaEnv = poetry2nix.mkPoetryEnv {
|
||||
projectDir = self;
|
||||
python = pkgs.python39;
|
||||
preferWheels = true;
|
||||
overrides = poetry2nix.overrides.withDefaults (self: super: {
|
||||
});
|
||||
};
|
||||
default = self.packages.${system}.kalpaEnv;
|
||||
inherit kalpaaEnv;
|
||||
inherit kalpaaApp;
|
||||
inherit kalpaa-docker-image;
|
||||
default = self.packages.${system}.kalpaaEnv;
|
||||
};
|
||||
|
||||
|
||||
devShells.default = pkgs.mkShell {
|
||||
inputsFrom = [ self.packages.${system}.kalpaEnv ];
|
||||
inputsFrom = [ self.packages.${system}.kalpaaEnv ];
|
||||
buildInputs = [
|
||||
pkgs.poetry
|
||||
self.packages.${system}.kalpaEnv
|
||||
self.packages.${system}.kalpaApp
|
||||
self.packages.${system}.kalpaaEnv
|
||||
self.packages.${system}.kalpaaApp
|
||||
pkgs.just
|
||||
pkgs.nodejs
|
||||
];
|
||||
@ -42,6 +80,7 @@
|
||||
export DO_NIX_CUSTOM=1
|
||||
'';
|
||||
};
|
||||
|
||||
}
|
||||
);
|
||||
}
|
||||
|
21
justfile
21
justfile
@ -70,3 +70,24 @@ release version="":
|
||||
|
||||
# htmlcov:
|
||||
# poetry run pytest --cov-report=html
|
||||
|
||||
# build docker image
|
||||
build-container:
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
nix build .#kalpaa-docker-image
|
||||
|
||||
# load the image into docker
|
||||
load-container:
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
docker load < result
|
||||
|
||||
# build and load in one step
|
||||
build-load-container: build-container load-container
|
||||
echo "Image loaded successfully!"
|
||||
|
||||
exec-container:
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
docker run -it -v $(pwd)/kalpaa.toml:/workspace/kalpaa.toml -v $(pwd)/dots.json:/workspace/dots.json -v $(pwd)/indexes.json:/workspace/indexes.json kalpaa /bin/bash
|
||||
|
25
kalpaa/common/angles.py
Normal file
25
kalpaa/common/angles.py
Normal file
@ -0,0 +1,25 @@
|
||||
import numpy
|
||||
|
||||
import logging
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def shortest_angular_distance(
|
||||
angles_1: numpy.ndarray, angles_2: numpy.ndarray
|
||||
) -> numpy.ndarray:
|
||||
"""
|
||||
Compute the shortest angular distance, pairwise, between two sets of angles.
|
||||
Assuming that angles in radians, and that the shape of our arrays is what we expect.
|
||||
|
||||
:param angles_1: one array of angles
|
||||
:param angles_2: the other array of angles
|
||||
:return: array of differences numpy.ndarray
|
||||
"""
|
||||
|
||||
result = (angles_1 - angles_2 + numpy.pi) % (2 * numpy.pi) - numpy.pi
|
||||
# _logger.debug(f"{angles_1=}")
|
||||
# _logger.debug(f"{angles_2=}")
|
||||
# _logger.debug(f"{result=}")
|
||||
|
||||
return result
|
@ -6,19 +6,18 @@ import typing
|
||||
|
||||
def set_up_logging(
|
||||
config: kalpaa.config.Config,
|
||||
log_stream: bool,
|
||||
log_file: typing.Optional[str],
|
||||
create_logfile_parents: bool = True,
|
||||
):
|
||||
handlers: typing.List[logging.Handler]
|
||||
if log_file is None:
|
||||
handlers = [
|
||||
logging.StreamHandler(),
|
||||
]
|
||||
else:
|
||||
handlers: typing.List[logging.Handler] = []
|
||||
if log_stream:
|
||||
handlers.append(logging.StreamHandler())
|
||||
if log_file is not None:
|
||||
if create_logfile_parents:
|
||||
# create any parent directories for the log file if needed.
|
||||
pathlib.Path(log_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
handlers = [logging.StreamHandler(), logging.FileHandler(log_file)]
|
||||
handlers.append(logging.FileHandler(log_file))
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format=config.general_config.log_pattern,
|
||||
|
142
kalpaa/completions/__init__.py
Normal file
142
kalpaa/completions/__init__.py
Normal file
@ -0,0 +1,142 @@
|
||||
import pathlib
|
||||
import kalpaa.config
|
||||
import logging
|
||||
from enum import Enum
|
||||
import filecmp
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
KALPAA_COMPLETE = "kalpaa.complete"
|
||||
COMPLETIONS_DIR = "completions"
|
||||
|
||||
# let us implement our own stuff later, this just handles checking if a thing exists or not.
|
||||
|
||||
|
||||
class CompletionsStatus(Enum):
|
||||
NOT_COMPLETE = "not_complete"
|
||||
INVALID = "invalid"
|
||||
COMPLETE = "complete"
|
||||
|
||||
|
||||
def _cwd_file_matches_previous(root_dir: pathlib.Path, file_name: str) -> bool:
|
||||
"""
|
||||
Compare the file in the current working directory with the file in the target root.
|
||||
|
||||
Returns true if they match (meaning continuation is possible), false otherwise.
|
||||
|
||||
Should do byte-by-byte comparison
|
||||
|
||||
:param cwd_file_name: the file name in the current working directory
|
||||
:param root_file_name: the file name in the target root
|
||||
:return: True if the files match, False otherwise
|
||||
"""
|
||||
current_file = pathlib.Path.cwd() / file_name
|
||||
copied_file = root_dir / file_name
|
||||
|
||||
result = filecmp.cmp(current_file, copied_file, shallow=False)
|
||||
_logger.debug(f"Compared {current_file} with {copied_file}, got {result}")
|
||||
return result
|
||||
|
||||
|
||||
def check_completion_file(config: kalpaa.Config, filename: str) -> CompletionsStatus:
|
||||
"""
|
||||
Check if the completion file exists for a given filename.
|
||||
|
||||
:param config: the config object
|
||||
:param filename: the filename to check
|
||||
:return: the completion status
|
||||
"""
|
||||
if not config.general_config.check_completions:
|
||||
_logger.debug("Not checking completions")
|
||||
return CompletionsStatus.NOT_COMPLETE
|
||||
|
||||
root_dir = config.general_config.root_directory
|
||||
completions_dir = root_dir / COMPLETIONS_DIR
|
||||
|
||||
# completions_dir.mkdir(exist_ok=True, parents=True)
|
||||
if not completions_dir.is_dir():
|
||||
_logger.debug(
|
||||
f"Completions dir {completions_dir=} does not exist and it should, invalid!"
|
||||
)
|
||||
return CompletionsStatus.INVALID
|
||||
|
||||
complete_file = completions_dir / filename
|
||||
if complete_file.exists():
|
||||
_logger.info(f"Found {complete_file}, exiting")
|
||||
return CompletionsStatus.COMPLETE
|
||||
else:
|
||||
_logger.info(f"Did not find {complete_file}, continuing")
|
||||
return CompletionsStatus.NOT_COMPLETE
|
||||
|
||||
|
||||
def set_completion_file(config: kalpaa.Config, filename: str):
|
||||
"""
|
||||
Set the completion file for a given filename.
|
||||
|
||||
:param config: the config object
|
||||
:param filename: the filename to set
|
||||
"""
|
||||
if not config.general_config.check_completions:
|
||||
_logger.debug("Not checking completions or setting them")
|
||||
return
|
||||
root_dir = config.general_config.root_directory
|
||||
completions_dir = root_dir / COMPLETIONS_DIR
|
||||
completions_dir.mkdir(exist_ok=True, parents=True)
|
||||
complete_file = completions_dir / filename
|
||||
complete_file.touch()
|
||||
_logger.info(f"Set {complete_file}")
|
||||
|
||||
|
||||
def check_initial_completions(
|
||||
config_file: str, config: kalpaa.Config
|
||||
) -> CompletionsStatus:
|
||||
"""
|
||||
Check if the completion files exist.
|
||||
|
||||
First check if the out dir has been created.
|
||||
If not, then we can run as normal.
|
||||
|
||||
If the out dir exists, check whether the config file matches the one we are using.
|
||||
If not, we have an invalid case and should error (don't want to change settings when resuming!).
|
||||
|
||||
Finally, check whether a kalpaa.complete file exists, and if so then exit.
|
||||
"""
|
||||
|
||||
root_dir = config.general_config.root_directory
|
||||
_logger.debug(f"Checking completions for {root_dir=}")
|
||||
|
||||
if not config.general_config.check_completions:
|
||||
_logger.debug("Not checking completions")
|
||||
return CompletionsStatus.NOT_COMPLETE
|
||||
if not root_dir.is_dir():
|
||||
_logger.debug(f"Root dir {root_dir} does not exist, continuing")
|
||||
return CompletionsStatus.NOT_COMPLETE
|
||||
|
||||
# check if the config file matches
|
||||
|
||||
files_to_check = [
|
||||
config.general_config.indexes_json_name,
|
||||
config.general_config.dots_json_name,
|
||||
config_file,
|
||||
]
|
||||
|
||||
for file in files_to_check:
|
||||
if (root_dir / file).exists():
|
||||
_logger.info(f"Checking {file}, which exists")
|
||||
if not _cwd_file_matches_previous(root_dir, file):
|
||||
_logger.error(f"Config file {file} does not match copied config")
|
||||
return CompletionsStatus.INVALID
|
||||
else:
|
||||
_logger.debug(
|
||||
f"Config file {file} does not exist, expect it will be created this run"
|
||||
)
|
||||
|
||||
completions_dir = root_dir / COMPLETIONS_DIR
|
||||
completions_dir.mkdir(exist_ok=True, parents=True)
|
||||
complete_file = completions_dir / KALPAA_COMPLETE
|
||||
if complete_file.exists():
|
||||
_logger.info(f"Found {complete_file}, exiting")
|
||||
return CompletionsStatus.COMPLETE
|
||||
else:
|
||||
_logger.info(f"Did not find {complete_file}, continuing")
|
||||
return CompletionsStatus.NOT_COMPLETE
|
@ -4,9 +4,11 @@ from kalpaa.config.config import (
|
||||
GeneralConfig,
|
||||
TantriConfig,
|
||||
GenerationConfig,
|
||||
DefaultModelParamConfig,
|
||||
DeepdogConfig,
|
||||
Config,
|
||||
ReducedModelParams,
|
||||
OVERRIDE_MEASUREMENT_DIR_NAME,
|
||||
)
|
||||
from kalpaa.config.config_reader import (
|
||||
read_config_dict,
|
||||
@ -21,6 +23,7 @@ __all__ = [
|
||||
"GeneralConfig",
|
||||
"TantriConfig",
|
||||
"GenerationConfig",
|
||||
"DefaultModelParamConfig",
|
||||
"DeepdogConfig",
|
||||
"Config",
|
||||
"ReducedModelParams",
|
||||
@ -28,4 +31,5 @@ __all__ = [
|
||||
"serialize_config",
|
||||
"read_config",
|
||||
"read_general_config_dict",
|
||||
"OVERRIDE_MEASUREMENT_DIR_NAME",
|
||||
]
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import deepdog.indexify
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, asdict
|
||||
import typing
|
||||
import tantri.dipoles.types
|
||||
import pathlib
|
||||
@ -10,121 +10,6 @@ import logging
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MeasurementTypeEnum(Enum):
|
||||
POTENTIAL = "electric-potential"
|
||||
X_ELECTRIC_FIELD = "x-electric-field"
|
||||
|
||||
|
||||
class SkipToStage(IntEnum):
|
||||
# shouldn't need this lol
|
||||
STAGE_01 = 0
|
||||
STAGE_02 = 1
|
||||
STAGE_03 = 2
|
||||
STAGE_04 = 3
|
||||
|
||||
|
||||
# Copy over some random constants to see if they're ever reused
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneralConfig:
|
||||
dots_json_name: str = "dots.json"
|
||||
indexes_json_name: str = "indexes.json"
|
||||
out_dir_name: str = "out"
|
||||
log_pattern: str = (
|
||||
"%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
|
||||
)
|
||||
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
|
||||
root_directory: pathlib.Path = pathlib.Path.cwd()
|
||||
|
||||
mega_merged_name: str = "mega_merged_coalesced.csv"
|
||||
mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"
|
||||
|
||||
skip_to_stage: typing.Optional[int] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TantriConfig:
|
||||
index_seed_starter: int = 31415
|
||||
num_seeds: int = 100
|
||||
delta_t: float = 0.05
|
||||
num_iterations: int = 100000
|
||||
# sample_rate = 10
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationConfig:
|
||||
# Interact with indexes.json, probably should be a subset
|
||||
counts: typing.Sequence[int] = field(default_factory=lambda: [1, 10])
|
||||
orientations: typing.Sequence[tantri.dipoles.types.Orientation] = field(
|
||||
default_factory=lambda: [
|
||||
tantri.dipoles.types.Orientation.RANDOM,
|
||||
tantri.dipoles.types.Orientation.Z,
|
||||
tantri.dipoles.types.Orientation.XY,
|
||||
]
|
||||
)
|
||||
# TODO: what's replica here?
|
||||
num_replicas: int = 3
|
||||
|
||||
# the above three can be overrided with manually specified configurations
|
||||
override_dipole_configs: typing.Optional[
|
||||
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
||||
] = None
|
||||
|
||||
tantri_configs: typing.List[TantriConfig] = field(
|
||||
default_factory=lambda: [TantriConfig()]
|
||||
)
|
||||
|
||||
num_bin_time_series: int = 25
|
||||
bin_log_width: float = 0.25
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeepdogConfig:
|
||||
"""
|
||||
Class that holds all of the computational parameters
|
||||
"""
|
||||
|
||||
costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
|
||||
target_success: int = 1000
|
||||
max_monte_carlo_cycles_steps: int = 20
|
||||
# Whether to use a log log cost function
|
||||
use_log_noise: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Config:
|
||||
generation_config: GenerationConfig = GenerationConfig()
|
||||
general_config: GeneralConfig = GeneralConfig()
|
||||
deepdog_config: DeepdogConfig = DeepdogConfig()
|
||||
|
||||
def absify(self, filename: str) -> pathlib.Path:
|
||||
ret = (self.general_config.root_directory / filename).resolve()
|
||||
_logger.debug(f"Absifying {filename=}, for root directory {self.general_config.root_directory}, geting {ret}")
|
||||
return ret
|
||||
|
||||
def get_out_dir_path(self) -> pathlib.Path:
|
||||
return self.absify(self.general_config.out_dir_name)
|
||||
|
||||
def get_dots_json_path(self) -> pathlib.Path:
|
||||
return self.absify(self.general_config.dots_json_name)
|
||||
|
||||
def indexifier(self) -> deepdog.indexify.Indexifier:
|
||||
with self.absify(self.general_config.indexes_json_name).open(
|
||||
"r"
|
||||
) as indexify_json_file:
|
||||
indexify_spec = json.load(indexify_json_file)
|
||||
indexify_data = indexify_spec["indexes"]
|
||||
if "seed_spec" in indexify_spec:
|
||||
seed_spec = indexify_spec["seed_spec"]
|
||||
indexify_data[seed_spec["field_name"]] = list(
|
||||
range(seed_spec["num_seeds"])
|
||||
)
|
||||
|
||||
_logger.info(f"loading indexifier with data {indexify_data=}")
|
||||
return deepdog.indexify.Indexifier(indexify_data)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReducedModelParams:
|
||||
"""
|
||||
@ -161,3 +46,161 @@ class ReducedModelParams:
|
||||
"generation_seed": seed,
|
||||
}
|
||||
return output_dict
|
||||
|
||||
|
||||
class MeasurementTypeEnum(Enum):
|
||||
POTENTIAL = "electric-potential"
|
||||
X_ELECTRIC_FIELD = "x-electric-field"
|
||||
|
||||
|
||||
class SkipToStage(IntEnum):
|
||||
# shouldn't need this lol
|
||||
STAGE_01 = 0
|
||||
STAGE_02 = 1
|
||||
STAGE_03 = 2
|
||||
STAGE_04 = 3
|
||||
|
||||
|
||||
OVERRIDE_MEASUREMENT_DIR_NAME = "override_measurements"
|
||||
# Copy over some random constants to see if they're ever reused
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeneralConfig:
|
||||
dots_json_name: str = "dots.json"
|
||||
indexes_json_name: str = "indexes.json"
|
||||
out_dir_name: str = "out"
|
||||
log_pattern: str = (
|
||||
"%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s"
|
||||
)
|
||||
measurement_type: MeasurementTypeEnum = MeasurementTypeEnum.X_ELECTRIC_FIELD
|
||||
root_directory: pathlib.Path = pathlib.Path.cwd()
|
||||
|
||||
mega_merged_name: str = "mega_merged_coalesced.csv"
|
||||
mega_merged_inferenced_name: str = "mega_merged_coalesced_inferenced.csv"
|
||||
|
||||
skip_to_stage: typing.Optional[int] = None
|
||||
|
||||
# if true check for existence of completion sentinel files before running
|
||||
check_completions: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DefaultModelParamConfig:
|
||||
x_min: float = -20
|
||||
x_max: float = 20
|
||||
y_min: float = -10
|
||||
y_max: float = 10
|
||||
z_min: float = 5
|
||||
z_max: float = 6.5
|
||||
w_log_min: float = -5
|
||||
w_log_max: float = 1
|
||||
|
||||
def reduced_model_params(self, **kwargs) -> ReducedModelParams:
|
||||
self_params = asdict(self)
|
||||
merged = {**self_params, **kwargs}
|
||||
return ReducedModelParams(**merged)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TantriConfig:
|
||||
index_seed_starter: int = 31415
|
||||
num_seeds: int = 100
|
||||
delta_t: float = 0.05
|
||||
num_iterations: int = 100000
|
||||
# sample_rate = 10
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationConfig:
|
||||
# Interact with indexes.json, probably should be a subset
|
||||
counts: typing.Sequence[int] = field(default_factory=lambda: [1, 10])
|
||||
orientations: typing.Sequence[tantri.dipoles.types.Orientation] = field(
|
||||
default_factory=lambda: [
|
||||
tantri.dipoles.types.Orientation.RANDOM,
|
||||
tantri.dipoles.types.Orientation.Z,
|
||||
tantri.dipoles.types.Orientation.XY,
|
||||
]
|
||||
)
|
||||
# TODO: what's replica here?
|
||||
num_replicas: int = 3
|
||||
|
||||
# the above three can be overrided with manually specified configurations
|
||||
override_dipole_configs: typing.Optional[
|
||||
typing.Mapping[str, typing.Sequence[tantri.dipoles.types.DipoleTO]]
|
||||
] = None
|
||||
|
||||
override_measurement_filesets: typing.Optional[
|
||||
typing.Mapping[str, typing.Sequence[str]]
|
||||
] = None
|
||||
|
||||
tantri_configs: typing.List[TantriConfig] = field(
|
||||
default_factory=lambda: [TantriConfig()]
|
||||
)
|
||||
|
||||
num_bin_time_series: int = 25
|
||||
bin_log_width: float = 0.25
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeepdogConfig:
|
||||
"""
|
||||
Class that holds all of the computational parameters
|
||||
"""
|
||||
|
||||
costs_to_try: typing.Sequence[float] = field(default_factory=lambda: [10, 1, 0.1])
|
||||
target_success: int = 1000
|
||||
max_monte_carlo_cycles_steps: int = 20
|
||||
# Whether to use a log log cost function
|
||||
use_log_noise: bool = False
|
||||
|
||||
# Manually specifying which dots to use
|
||||
# Outer layer is multiple configurations, within that is which dots to combine, then the inner layer is to distinguish single dots and pairs.
|
||||
# example:
|
||||
# [
|
||||
# [ ["dot1"]], # first one is to use just dot1
|
||||
# [ ["dot1"], ["dot2"] ] # second one is to use dot1 and dot2
|
||||
# [ ["dot1", "dot2"] ] # third one is to use dot1 and dot2 as a pair
|
||||
# ]
|
||||
manual_dot_seeds: typing.Optional[
|
||||
typing.Sequence[typing.Sequence[typing.Sequence[str]]]
|
||||
] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Config:
|
||||
generation_config: GenerationConfig = GenerationConfig()
|
||||
general_config: GeneralConfig = GeneralConfig()
|
||||
deepdog_config: DeepdogConfig = DeepdogConfig()
|
||||
default_model_param_config: DefaultModelParamConfig = DefaultModelParamConfig()
|
||||
|
||||
def absify(self, filename: str) -> pathlib.Path:
|
||||
ret = (self.general_config.root_directory / filename).resolve()
|
||||
_logger.debug(
|
||||
f"Absifying {filename=}, for root directory {self.general_config.root_directory}, geting {ret}"
|
||||
)
|
||||
return ret
|
||||
|
||||
def get_out_dir_path(self) -> pathlib.Path:
|
||||
return self.absify(self.general_config.out_dir_name)
|
||||
|
||||
def get_dots_json_path(self) -> pathlib.Path:
|
||||
return self.absify(self.general_config.dots_json_name)
|
||||
|
||||
def get_override_dir_path(self) -> pathlib.Path:
|
||||
return self.absify(OVERRIDE_MEASUREMENT_DIR_NAME)
|
||||
|
||||
def indexifier(self) -> deepdog.indexify.Indexifier:
|
||||
with self.absify(self.general_config.indexes_json_name).open(
|
||||
"r"
|
||||
) as indexify_json_file:
|
||||
indexify_spec = json.load(indexify_json_file)
|
||||
indexify_data = indexify_spec["indexes"]
|
||||
if "seed_spec" in indexify_spec:
|
||||
seed_spec = indexify_spec["seed_spec"]
|
||||
indexify_data[seed_spec["field_name"]] = list(
|
||||
range(seed_spec["num_seeds"])
|
||||
)
|
||||
|
||||
_logger.info(f"loading indexifier with data {indexify_data=}")
|
||||
return deepdog.indexify.Indexifier(indexify_data)
|
||||
|
0
kalpaa/py.typed
Normal file
0
kalpaa/py.typed
Normal file
@ -1,3 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# useful for measurementgroup which is a type that has itself in method signatures, avoids having to manually specify the typehint as a string
|
||||
|
||||
import re
|
||||
import numpy
|
||||
import dataclasses
|
||||
@ -9,9 +13,13 @@ import csv
|
||||
import deepdog.direct_monte_carlo.dmc_filters
|
||||
import deepdog.direct_monte_carlo.compose_filter
|
||||
import deepdog.direct_monte_carlo.cost_function_filter
|
||||
import pdme.util.fast_nonlocal_spectrum
|
||||
|
||||
# import tantri.cli
|
||||
|
||||
from kalpaa.config import MeasurementTypeEnum
|
||||
import kalpaa.common.angles
|
||||
|
||||
import pdme
|
||||
import pdme.util.fast_v_calc
|
||||
import pdme.measurement
|
||||
@ -23,26 +31,164 @@ X_ELECTRIC_FIELD = "Ex"
|
||||
POTENTIAL = "V"
|
||||
|
||||
|
||||
def short_string_to_measurement_type(short_string: str) -> MeasurementTypeEnum:
|
||||
if short_string == X_ELECTRIC_FIELD:
|
||||
return MeasurementTypeEnum.X_ELECTRIC_FIELD
|
||||
elif short_string == POTENTIAL:
|
||||
return MeasurementTypeEnum.POTENTIAL
|
||||
else:
|
||||
raise ValueError(f"Could not find {short_string=}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Measurement:
|
||||
dot_measurement: pdme.measurement.DotMeasurement
|
||||
dot_measurement: typing.Optional[pdme.measurement.DotMeasurement]
|
||||
stdev: float
|
||||
dot_pair_measurement: typing.Optional[pdme.measurement.DotPairMeasurement] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MeasurementGroup:
|
||||
_measurements: typing.Sequence[Measurement]
|
||||
_measurement_type: MeasurementTypeEnum
|
||||
_using_pairs: bool = dataclasses.field(init=False, default=False)
|
||||
|
||||
def validate(self):
|
||||
if not self._measurements:
|
||||
raise ValueError("Cannot have an empty measurement group")
|
||||
using_pairs = any(
|
||||
m.dot_pair_measurement is not None for m in self._measurements
|
||||
)
|
||||
using_singles = any(m.dot_measurement is not None for m in self._measurements)
|
||||
if using_pairs and using_singles:
|
||||
raise ValueError(
|
||||
"Cannot mix single and pair measurements in a single measurement group"
|
||||
)
|
||||
if not using_pairs and not using_singles:
|
||||
raise ValueError("Cannot have a measurement group with no measurements")
|
||||
self._using_pairs = using_pairs
|
||||
|
||||
def add(self, other: MeasurementGroup) -> MeasurementGroup:
|
||||
|
||||
if other._measurement_type != self._measurement_type:
|
||||
raise ValueError(
|
||||
f"Cannot add {other._measurement_type=} to {self._measurement_type=}, as they have different measurement types"
|
||||
)
|
||||
|
||||
# this is probably not conformant to the ideal contract for typing.Sequence
|
||||
new_measurements = [*self._measurements, *other._measurements]
|
||||
|
||||
return MeasurementGroup(new_measurements, self._measurement_type)
|
||||
|
||||
def _meas_array(self) -> numpy.ndarray:
|
||||
if self._using_pairs:
|
||||
return numpy.array(
|
||||
[
|
||||
m.dot_pair_measurement.v
|
||||
for m in self._measurements
|
||||
if m.dot_pair_measurement is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
return numpy.array(
|
||||
[
|
||||
m.dot_measurement.v
|
||||
for m in self._measurements
|
||||
if m.dot_measurement is not None
|
||||
]
|
||||
)
|
||||
|
||||
def _input_array(self) -> numpy.ndarray:
|
||||
if self._using_pairs:
|
||||
return pdme.measurement.input_types.dot_pair_inputs_to_array(
|
||||
[
|
||||
(
|
||||
m.dot_pair_measurement.r1,
|
||||
m.dot_pair_measurement.r2,
|
||||
m.dot_pair_measurement.f,
|
||||
)
|
||||
for m in self._measurements
|
||||
if m.dot_pair_measurement is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
return pdme.measurement.input_types.dot_inputs_to_array(
|
||||
[
|
||||
(m.dot_measurement.r, m.dot_measurement.f)
|
||||
for m in self._measurements
|
||||
if m.dot_measurement is not None
|
||||
]
|
||||
)
|
||||
|
||||
def _stdev_array(self) -> numpy.ndarray:
|
||||
return numpy.array([m.stdev for m in self._measurements])
|
||||
|
||||
def cost_function(self):
|
||||
self.validate()
|
||||
meas_array = self._meas_array()
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
input_array = self._input_array()
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return CostFunction(self._measurement_type, input_array, meas_array)
|
||||
|
||||
def stdev_cost_function(
|
||||
self,
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
self.validate()
|
||||
stdev_array = self._stdev_array()
|
||||
|
||||
meas_array = self._meas_array()
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
input_array = self._input_array()
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return StDevUsingCostFunction(
|
||||
self._measurement_type,
|
||||
input_array,
|
||||
meas_array,
|
||||
stdev_array,
|
||||
log_noise=use_log_noise and not self._using_pairs,
|
||||
use_pair_measurement=self._using_pairs,
|
||||
)
|
||||
|
||||
|
||||
class CostFunction:
|
||||
def __init__(self, measurement_type, dot_inputs_array, actual_measurement_array):
|
||||
def __init__(
|
||||
self,
|
||||
measurement_type: MeasurementTypeEnum,
|
||||
dot_inputs_array: numpy.ndarray,
|
||||
actual_measurement_array: numpy.ndarray,
|
||||
use_pair_measurement: bool = False,
|
||||
):
|
||||
"""
|
||||
Construct a cost function that uses the measurements.
|
||||
|
||||
:param measurement_type: The type of measurement we're using.
|
||||
:param dot_inputs_array: The array of dot inputs.
|
||||
:param actual_measurement_array: The actual measurements.
|
||||
:param use_pair_measurement: Whether to use pair measurements. (default false)
|
||||
"""
|
||||
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
||||
self.measurement_type = measurement_type
|
||||
self.dot_inputs_array = dot_inputs_array
|
||||
self.actual_measurement_array = actual_measurement_array
|
||||
self.actual_measurement_array2 = actual_measurement_array**2
|
||||
self.use_pair_measurement = use_pair_measurement
|
||||
if self.use_pair_measurement:
|
||||
raise NotImplementedError("Pair measurements are not yet supported")
|
||||
|
||||
def __call__(self, dipoles_to_test):
|
||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||
if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
|
||||
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
elif self.measurement_type == POTENTIAL:
|
||||
elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
@ -55,12 +201,23 @@ class CostFunction:
|
||||
class StDevUsingCostFunction:
|
||||
def __init__(
|
||||
self,
|
||||
measurement_type,
|
||||
dot_inputs_array,
|
||||
actual_measurement_array,
|
||||
actual_stdev_array,
|
||||
measurement_type: MeasurementTypeEnum,
|
||||
dot_inputs_array: numpy.ndarray,
|
||||
actual_measurement_array: numpy.ndarray,
|
||||
actual_stdev_array: numpy.ndarray,
|
||||
log_noise: bool = False,
|
||||
use_pair_measurement: bool = False,
|
||||
):
|
||||
"""
|
||||
Construct a cost function that uses the standard deviation of the measurements.
|
||||
|
||||
:param measurement_type: The type of measurement we're using.
|
||||
:param dot_inputs_array: The array of dot inputs. (may be actually inputses for pair measurements)
|
||||
:param actual_measurement_array: The actual measurements.
|
||||
:param actual_stdev_array: The actual standard deviations.
|
||||
:param use_pair_measurement: Whether to use pair measurements. (default false)
|
||||
:param log_noise: Whether to use log noise. (default false but we should probably use it)
|
||||
"""
|
||||
_logger.info(f"Cost function with measurement type of {measurement_type}")
|
||||
self.measurement_type = measurement_type
|
||||
self.dot_inputs_array = dot_inputs_array
|
||||
@ -75,31 +232,58 @@ class StDevUsingCostFunction:
|
||||
numpy.log(self.actual_stdev_array + self.actual_measurement_array)
|
||||
- numpy.log(self.actual_measurement_array)
|
||||
) ** 2
|
||||
# if self.use_log_noise:
|
||||
# _logger.debug("remove these debugs later")
|
||||
# _logger.debug(self.actual_measurement_array)
|
||||
# _logger.debug(self.actual_stdev_array)
|
||||
# _logger.debug(self.log_actual)
|
||||
# _logger.debug(self.log_denom2)
|
||||
self.use_pair_measurement = use_pair_measurement
|
||||
|
||||
def __call__(self, dipoles_to_test):
|
||||
if self.measurement_type == X_ELECTRIC_FIELD:
|
||||
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
elif self.measurement_type == POTENTIAL:
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
if self.use_pair_measurement:
|
||||
# We're going to just use phase data, rather than correlation data for now.
|
||||
# We'll probably need to do some re-architecting later to get the phase vs correlation flag to propagate here
|
||||
# if self.use_log_noise:
|
||||
# _logger.info("No log noise for phase data, which is wrapped but linear")
|
||||
|
||||
if self.use_log_noise:
|
||||
diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
|
||||
else:
|
||||
if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
|
||||
vals = pdme.util.fast_nonlocal_spectrum.fast_s_spin_qubit_tarucha_nonlocal_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
|
||||
vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
|
||||
# _logger.debug(f"Got {vals=}")
|
||||
|
||||
sign_vals = pdme.util.fast_nonlocal_spectrum.signarg(vals)
|
||||
|
||||
# _logger.debug(f"Got {sign_vals=}")
|
||||
diffs = (
|
||||
(vals - self.actual_measurement_array) ** 2
|
||||
) / self.actual_stdev_array2
|
||||
kalpaa.common.angles.shortest_angular_distance(
|
||||
sign_vals, self.actual_measurement_array
|
||||
)
|
||||
** 2
|
||||
)
|
||||
# _logger.debug(f"Got {diffs=}")
|
||||
scaled_diffs = diffs / self.actual_stdev_array2
|
||||
# _logger.debug(f"Got {scaled_diffs=}")
|
||||
return numpy.sqrt(scaled_diffs.mean(axis=-1))
|
||||
|
||||
return numpy.sqrt(diffs.mean(axis=-1))
|
||||
else:
|
||||
if self.measurement_type == MeasurementTypeEnum.X_ELECTRIC_FIELD:
|
||||
vals = pdme.util.fast_v_calc.fast_efieldxs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
elif self.measurement_type == MeasurementTypeEnum.POTENTIAL:
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||
self.dot_inputs_array, dipoles_to_test
|
||||
)
|
||||
|
||||
if self.use_log_noise:
|
||||
diffs = ((numpy.log(vals) - self.log_actual) ** 2) / self.log_denom2
|
||||
else:
|
||||
diffs = (
|
||||
(vals - self.actual_measurement_array) ** 2
|
||||
) / self.actual_stdev_array2
|
||||
|
||||
return numpy.sqrt(diffs.mean(axis=-1))
|
||||
|
||||
|
||||
# the key for frequencies in what we return
|
||||
@ -125,19 +309,79 @@ def _reshape_dots_dict(dots_dict: typing.Sequence[typing.Dict]) -> typing.Dict:
|
||||
|
||||
|
||||
BINNED_HEADER_REGEX = r"\s*APSD_(?P<measurement_type>\w+)_(?P<dot_name>\w+)_(?P<summary_stat>mean|stdev)\s*"
|
||||
PAIR_MEASUREMENT_BINNED_HEADER_REGEX = r"\s*CPSD_(?P<cpsd_type>correlation|phase)_(?P<measurement_type>\w+)_(?P<dot_name>\w+)_(?P<dot_name2>\w+)_(?P<summary_stat>mean|stdev)\s*"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ParsedBinHeader:
|
||||
original_field: str
|
||||
measurement_type: str
|
||||
dot_name: str
|
||||
measurement_type: MeasurementTypeEnum
|
||||
summary_stat: str
|
||||
dot_name: str
|
||||
# only used for pair measurements
|
||||
dot_name2: typing.Optional[str] = None
|
||||
cpsd_type: typing.Optional[typing.Literal["correlation", "phase"]] = None
|
||||
|
||||
@property
|
||||
def pair(self) -> bool:
|
||||
return self.dot_name2 is not None
|
||||
|
||||
|
||||
def _parse_bin_header(field: str) -> typing.Optional[ParsedBinHeader]:
|
||||
"""
|
||||
Parse a binned header field into a ParsedBinHeader object.
|
||||
|
||||
Return None if the field does not match the expected format (and thus no match).
|
||||
"""
|
||||
if (match := re.match(BINNED_HEADER_REGEX, field)) is not None:
|
||||
match_groups = match.groupdict()
|
||||
return ParsedBinHeader(
|
||||
original_field=field,
|
||||
measurement_type=short_string_to_measurement_type(
|
||||
match_groups["measurement_type"]
|
||||
),
|
||||
dot_name=match_groups["dot_name"],
|
||||
summary_stat=match_groups["summary_stat"],
|
||||
)
|
||||
elif (
|
||||
pair_match := re.match(PAIR_MEASUREMENT_BINNED_HEADER_REGEX, field)
|
||||
) is not None:
|
||||
groups = pair_match.groupdict()
|
||||
cpsd_type = typing.cast(
|
||||
typing.Literal["correlation", "phase"], groups["cpsd_type"]
|
||||
)
|
||||
return ParsedBinHeader(
|
||||
original_field=field,
|
||||
measurement_type=short_string_to_measurement_type(
|
||||
groups["measurement_type"]
|
||||
),
|
||||
dot_name=groups["dot_name"],
|
||||
dot_name2=groups["dot_name2"],
|
||||
cpsd_type=cpsd_type,
|
||||
summary_stat=groups["summary_stat"],
|
||||
)
|
||||
else:
|
||||
_logger.debug(f"Could not parse {field=}")
|
||||
return None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CSV_BinnedData:
|
||||
measurement_type: MeasurementTypeEnum
|
||||
single_dot_dict: typing.Dict[str, typing.Any]
|
||||
pair_dot_dict: typing.Dict[typing.Tuple[str, str], typing.Any]
|
||||
freqs: typing.Sequence[float]
|
||||
|
||||
|
||||
def read_bin_csv(
|
||||
csv_file: pathlib.Path,
|
||||
) -> typing.Tuple[str, typing.Dict[str, typing.Any]]:
|
||||
) -> CSV_BinnedData:
|
||||
"""
|
||||
Read a binned csv file and return the measurement type and the binned data.
|
||||
|
||||
:param csv_file: The csv file to read.
|
||||
:return: A tuple of the measurement type and the binned data.
|
||||
"""
|
||||
|
||||
measurement_type = None
|
||||
_logger.info(f"Assuming measurement type is {measurement_type} for now")
|
||||
@ -148,7 +392,7 @@ def read_bin_csv(
|
||||
|
||||
if fields is None:
|
||||
raise ValueError(
|
||||
f"Really wanted our fields for fiel {file=} to be non-None, but they're None"
|
||||
f"Really wanted our fields for file {file=} to be non-None, but they're None"
|
||||
)
|
||||
freq_field = fields[0]
|
||||
|
||||
@ -156,35 +400,49 @@ def read_bin_csv(
|
||||
_logger.debug(f"Going to read frequencies from {freq_field=}")
|
||||
|
||||
parsed_headers = {}
|
||||
freq_list = []
|
||||
aggregated_dict: typing.Dict[str, typing.Any] = {
|
||||
RETURNED_FREQUENCIES_KEY: []
|
||||
}
|
||||
pair_aggregated_dict: typing.Dict[typing.Tuple[str, str], typing.Any] = {}
|
||||
|
||||
for field in remaining_fields:
|
||||
match = re.match(BINNED_HEADER_REGEX, field)
|
||||
if match is None:
|
||||
parsed_header = _parse_bin_header(field)
|
||||
if parsed_header is None:
|
||||
_logger.warning(f"Could not parse {field=}")
|
||||
continue
|
||||
match_groups = match.groupdict()
|
||||
parsed_header = ParsedBinHeader(
|
||||
field,
|
||||
match_groups["measurement_type"],
|
||||
match_groups["dot_name"],
|
||||
match_groups["summary_stat"],
|
||||
)
|
||||
parsed_headers[field] = parsed_header
|
||||
|
||||
if parsed_header.dot_name not in aggregated_dict:
|
||||
aggregated_dict[parsed_header.dot_name] = {}
|
||||
# Get our dictionary structures set up by initialising empty dictionaries for each new field as we go
|
||||
if parsed_header.pair:
|
||||
if parsed_header.dot_name2 is None:
|
||||
raise ValueError(
|
||||
f"Pair measurement {field=} has no dot_name2, but it should"
|
||||
)
|
||||
dot_names = (parsed_header.dot_name, parsed_header.dot_name2)
|
||||
if dot_names not in pair_aggregated_dict:
|
||||
pair_aggregated_dict[dot_names] = {}
|
||||
|
||||
if (
|
||||
parsed_header.summary_stat
|
||||
not in aggregated_dict[parsed_header.dot_name]
|
||||
):
|
||||
aggregated_dict[parsed_header.dot_name][
|
||||
if (
|
||||
parsed_header.summary_stat
|
||||
] = []
|
||||
not in pair_aggregated_dict[dot_names]
|
||||
):
|
||||
pair_aggregated_dict[dot_names][parsed_header.summary_stat] = []
|
||||
|
||||
else:
|
||||
if parsed_header.dot_name not in aggregated_dict:
|
||||
aggregated_dict[parsed_header.dot_name] = {}
|
||||
|
||||
if (
|
||||
parsed_header.summary_stat
|
||||
not in aggregated_dict[parsed_header.dot_name]
|
||||
):
|
||||
aggregated_dict[parsed_header.dot_name][
|
||||
parsed_header.summary_stat
|
||||
] = []
|
||||
|
||||
# Realistically we'll always have the same measurement type, but this warning may help us catch out cases where this didn't happen correctly
|
||||
# We should only need to set it once, so the fact we keep checking is more about catching errors than anything else
|
||||
if measurement_type is not None:
|
||||
if measurement_type != parsed_header.measurement_type:
|
||||
_logger.warning(
|
||||
@ -197,20 +455,39 @@ def read_bin_csv(
|
||||
|
||||
for row in reader:
|
||||
# _logger.debug(f"Got {row=}")
|
||||
freq_list.append(float(row[freq_field].strip()))
|
||||
# don't need to set, but keep for legacy
|
||||
aggregated_dict[RETURNED_FREQUENCIES_KEY].append(
|
||||
float(row[freq_field].strip())
|
||||
)
|
||||
for field, parsed_header in parsed_headers.items():
|
||||
value = float(row[field].strip())
|
||||
aggregated_dict[parsed_header.dot_name][
|
||||
parsed_header.summary_stat
|
||||
].append(value)
|
||||
if parsed_header.pair:
|
||||
if parsed_header.dot_name2 is None:
|
||||
raise ValueError(
|
||||
f"Pair measurement {field=} has no dot_name2, but it should"
|
||||
)
|
||||
value = float(row[field].strip())
|
||||
dot_names = (parsed_header.dot_name, parsed_header.dot_name2)
|
||||
pair_aggregated_dict[dot_names][
|
||||
parsed_header.summary_stat
|
||||
].append(value)
|
||||
else:
|
||||
value = float(row[field].strip())
|
||||
aggregated_dict[parsed_header.dot_name][
|
||||
parsed_header.summary_stat
|
||||
].append(value)
|
||||
|
||||
if measurement_type is None:
|
||||
raise ValueError(
|
||||
f"For some reason {measurement_type=} is None? We want to know our measurement type."
|
||||
)
|
||||
return measurement_type, aggregated_dict
|
||||
|
||||
return CSV_BinnedData(
|
||||
measurement_type=measurement_type,
|
||||
single_dot_dict=aggregated_dict,
|
||||
freqs=freq_list,
|
||||
pair_dot_dict=pair_aggregated_dict,
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.error(
|
||||
f"Had a bad time reading the binned data {csv_file}, sorry.", exc_info=e
|
||||
@ -222,27 +499,32 @@ def read_bin_csv(
|
||||
class BinnedData:
|
||||
dots_dict: typing.Dict
|
||||
csv_dict: typing.Dict[str, typing.Any]
|
||||
measurement_type: str
|
||||
measurement_type: MeasurementTypeEnum
|
||||
pair_dict: typing.Dict[typing.Tuple[str, str], typing.Any]
|
||||
freq_list: typing.Sequence[float]
|
||||
|
||||
# we're ignoring stdevs for the current moment, as in the calculator single_dipole_matches.py script.
|
||||
def _dot_to_measurement(self, dot_name: str) -> typing.Sequence[Measurement]:
|
||||
def _dot_to_measurements(self, dot_name: str) -> MeasurementGroup:
|
||||
if dot_name not in self.dots_dict:
|
||||
raise KeyError(f"Could not find {dot_name=} in {self.dots_dict=}")
|
||||
if dot_name not in self.csv_dict:
|
||||
raise KeyError(f"Could not find {dot_name=} in {self.csv_dict=}")
|
||||
|
||||
dot_r = self.dots_dict[dot_name]
|
||||
freqs = self.csv_dict[RETURNED_FREQUENCIES_KEY]
|
||||
freqs = self.freq_list
|
||||
vs = self.csv_dict[dot_name]["mean"]
|
||||
stdevs = self.csv_dict[dot_name]["stdev"]
|
||||
|
||||
return [
|
||||
Measurement(
|
||||
dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
|
||||
stdev=stdev,
|
||||
)
|
||||
for f, v, stdev in zip(freqs, vs, stdevs)
|
||||
]
|
||||
return MeasurementGroup(
|
||||
[
|
||||
Measurement(
|
||||
dot_measurement=pdme.measurement.DotMeasurement(f=f, v=v, r=dot_r),
|
||||
stdev=stdev,
|
||||
)
|
||||
for f, v, stdev in zip(freqs, vs, stdevs)
|
||||
],
|
||||
_measurement_type=self.measurement_type,
|
||||
)
|
||||
|
||||
def _dot_to_stdev(self, dot_name: str) -> typing.Sequence[float]:
|
||||
if dot_name not in self.dots_dict:
|
||||
@ -254,60 +536,140 @@ class BinnedData:
|
||||
|
||||
return stdevs
|
||||
|
||||
def measurements(
|
||||
self, dot_names: typing.Sequence[str]
|
||||
) -> typing.Sequence[Measurement]:
|
||||
_logger.debug(f"Constructing measurements for dots {dot_names=}")
|
||||
ret: typing.List[Measurement] = []
|
||||
for dot_name in dot_names:
|
||||
ret.extend(self._dot_to_measurement(dot_name))
|
||||
return ret
|
||||
def _pair_to_measurements(
|
||||
self, dot_pair_name: typing.Tuple[str, str]
|
||||
) -> MeasurementGroup:
|
||||
if dot_pair_name not in self.pair_dict:
|
||||
raise KeyError(f"Could not find {dot_pair_name=} in {self.pair_dict=}")
|
||||
|
||||
def _cost_function(self, measurements: typing.Sequence[Measurement]):
|
||||
dot_measurements = [m.dot_measurement for m in measurements]
|
||||
meas_array = numpy.array([m.v for m in dot_measurements])
|
||||
dot_name1, dot_name2 = dot_pair_name
|
||||
if dot_name1 not in self.dots_dict:
|
||||
raise KeyError(f"Could not find {dot_name1=} in {self.dots_dict=}")
|
||||
if dot_name2 not in self.dots_dict:
|
||||
raise KeyError(f"Could not find {dot_name2=} in {self.dots_dict=}")
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
dot_r1 = self.dots_dict[dot_name1]
|
||||
dot_r2 = self.dots_dict[dot_name2]
|
||||
freqs = self.freq_list
|
||||
vs = self.pair_dict[dot_pair_name]["mean"]
|
||||
stdevs = self.pair_dict[dot_pair_name]["stdev"]
|
||||
|
||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements]
|
||||
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return CostFunction(self.measurement_type, input_array, meas_array)
|
||||
|
||||
def _stdev_cost_function(
|
||||
self,
|
||||
measurements: typing.Sequence[Measurement],
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
meas_array = numpy.array([m.dot_measurement.v for m in measurements])
|
||||
stdev_array = numpy.array([m.stdev for m in measurements])
|
||||
|
||||
_logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
inputs = [(m.dot_measurement.r, m.dot_measurement.f) for m in measurements]
|
||||
input_array = pdme.measurement.input_types.dot_inputs_to_array(inputs)
|
||||
_logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
return StDevUsingCostFunction(
|
||||
self.measurement_type, input_array, meas_array, stdev_array, use_log_noise
|
||||
return MeasurementGroup(
|
||||
[
|
||||
Measurement(
|
||||
dot_measurement=None,
|
||||
dot_pair_measurement=pdme.measurement.DotPairMeasurement(
|
||||
f=f, v=v, r1=dot_r1, r2=dot_r2
|
||||
),
|
||||
stdev=stdev,
|
||||
)
|
||||
for f, v, stdev in zip(freqs, vs, stdevs)
|
||||
],
|
||||
_measurement_type=self.measurement_type,
|
||||
)
|
||||
|
||||
def cost_function_filter(self, dot_names: typing.Sequence[str], target_cost: float):
|
||||
measurements = self.measurements(dot_names)
|
||||
cost_function = self._cost_function(measurements)
|
||||
def measurements(self, dot_names: typing.Sequence[str]) -> MeasurementGroup:
|
||||
_logger.debug(f"Constructing measurements for dots {dot_names=}")
|
||||
ret = MeasurementGroup([], self.measurement_type)
|
||||
_logger.debug
|
||||
for dot_name in dot_names:
|
||||
ret = ret.add(self._dot_to_measurements(dot_name))
|
||||
return ret
|
||||
|
||||
def pair_measurements(
|
||||
self, dot_pair_names: typing.Sequence[typing.Tuple[str, str]]
|
||||
) -> MeasurementGroup:
|
||||
_logger.debug(f"Constructing measurements for dot pairs {dot_pair_names=}")
|
||||
ret = MeasurementGroup([], self.measurement_type)
|
||||
_logger.debug
|
||||
for dot_pair_name in dot_pair_names:
|
||||
ret = ret.add(self._pair_to_measurements(dot_pair_name))
|
||||
return ret
|
||||
|
||||
# def _cost_function(self, mg: MeasurementGroup):
|
||||
# meas_array = mg.meas_array()
|
||||
|
||||
# _logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
# input_array = mg.input_array()
|
||||
# _logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
# return CostFunction(self.measurement_type, input_array, meas_array)
|
||||
|
||||
# def _stdev_cost_function(
|
||||
# self,
|
||||
# mg: MeasurementGroup,
|
||||
# use_log_noise: bool = False,
|
||||
# ):
|
||||
# stdev_array = mg.stdev_array()
|
||||
|
||||
# meas_array = mg.meas_array()
|
||||
|
||||
# _logger.debug(f"Obtained {meas_array=}")
|
||||
|
||||
# input_array = mg.input_array()
|
||||
# _logger.debug(f"Obtained {input_array=}")
|
||||
|
||||
# return StDevUsingCostFunction(
|
||||
# self.measurement_type,
|
||||
# input_array,
|
||||
# meas_array,
|
||||
# stdev_array,
|
||||
# log_noise=use_log_noise,
|
||||
# )
|
||||
|
||||
def _get_measurement_from_dot_name_or_pair(
|
||||
self,
|
||||
dot_names_or_pairs: typing.Union[
|
||||
typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
|
||||
],
|
||||
) -> MeasurementGroup:
|
||||
"""
|
||||
check if dot_names_or_pairs is a list of strings or a list of tuples of strings, then return the appropriate measurement group
|
||||
"""
|
||||
if isinstance(dot_names_or_pairs[0], str):
|
||||
_logger.debug("first item was a string, assuming we're specifying strings")
|
||||
# we expect all strings, fail if otherwise
|
||||
names = []
|
||||
for dn in dot_names_or_pairs:
|
||||
if not isinstance(dn, str):
|
||||
raise ValueError(f"Expected all strings in {dot_names_or_pairs=}")
|
||||
names.append(dn)
|
||||
_logger.debug(f"Constructing measurements for dots {names=}")
|
||||
return self.measurements(names)
|
||||
else:
|
||||
_logger.debug("trying out pairs")
|
||||
pairs = []
|
||||
for dn in dot_names_or_pairs:
|
||||
if not isinstance(dn, tuple):
|
||||
raise ValueError(f"Expected all tuples in {dot_names_or_pairs=}")
|
||||
pairs.append(dn)
|
||||
_logger.debug(f"Constructing measurements for dot pairs {pairs=}")
|
||||
return self.pair_measurements(pairs)
|
||||
|
||||
def cost_function_filter(
|
||||
self,
|
||||
dot_names_or_pairs: typing.Union[
|
||||
typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
|
||||
],
|
||||
target_cost: float,
|
||||
):
|
||||
measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs)
|
||||
cost_function = measurements.cost_function()
|
||||
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, target_cost
|
||||
)
|
||||
|
||||
def stdev_cost_function_filter(
|
||||
self,
|
||||
dot_names: typing.Sequence[str],
|
||||
dot_names_or_pairs: typing.Union[
|
||||
typing.Sequence[str], typing.Sequence[typing.Tuple[str, str]]
|
||||
],
|
||||
target_cost: float,
|
||||
use_log_noise: bool = False,
|
||||
):
|
||||
measurements = self.measurements(dot_names)
|
||||
cost_function = self._stdev_cost_function(measurements, use_log_noise)
|
||||
measurements = self._get_measurement_from_dot_name_or_pair(dot_names_or_pairs)
|
||||
cost_function = measurements.stdev_cost_function(use_log_noise=use_log_noise)
|
||||
return deepdog.direct_monte_carlo.cost_function_filter.CostFunctionTargetFilter(
|
||||
cost_function, target_cost
|
||||
)
|
||||
@ -315,22 +677,11 @@ class BinnedData:
|
||||
|
||||
def read_dots_and_binned(json_file: pathlib.Path, csv_file: pathlib.Path) -> BinnedData:
|
||||
dots = read_dots_json(json_file)
|
||||
measurement_type, binned = read_bin_csv(csv_file)
|
||||
csv_data = read_bin_csv(csv_file)
|
||||
return BinnedData(
|
||||
measurement_type=measurement_type, dots_dict=dots, csv_dict=binned
|
||||
measurement_type=csv_data.measurement_type,
|
||||
dots_dict=dots,
|
||||
csv_dict=csv_data.single_dot_dict,
|
||||
freq_list=csv_data.freqs,
|
||||
pair_dict=csv_data.pair_dot_dict,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
print(read_dots_json(pathlib.Path("dots.json")))
|
||||
# print(read_bin_csv(pathlib.Path("binned-0.01-10000-50-12345.csv")))
|
||||
binned_data = read_dots_and_binned(
|
||||
pathlib.Path("dots.json"), pathlib.Path("binned-0.01-10000-50-12345.csv")
|
||||
)
|
||||
_logger.info(binned_data)
|
||||
for entry in binned_data.measurements(["uprise1", "dot1"]):
|
||||
_logger.info(entry)
|
||||
filter = binned_data.cost_function_filter(["uprise1", "dot1"], 0.5)
|
||||
_logger.info(filter)
|
||||
|
@ -10,27 +10,74 @@ import kalpaa.stages.stage04
|
||||
import kalpaa.common
|
||||
import kalpaa.config
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
import kalpaa.completions
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
class Runnable(Protocol):
|
||||
config: kalpaa.Config
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
class Completable:
|
||||
def __init__(self, runnable: Runnable, completion_name: str):
|
||||
self.runnable = runnable
|
||||
self.completion_name = completion_name
|
||||
|
||||
def run(self):
|
||||
_logger.info(
|
||||
f"Running {self.runnable} with completion name {self.completion_name}"
|
||||
)
|
||||
completions = kalpaa.completions.check_completion_file(
|
||||
self.runnable.config, self.completion_name
|
||||
)
|
||||
if completions == kalpaa.completions.CompletionsStatus.COMPLETE:
|
||||
_logger.info(f"Skipping {self.completion_name}")
|
||||
return
|
||||
elif completions == kalpaa.completions.CompletionsStatus.INVALID:
|
||||
_logger.error(f"Invalid completion status for {self.completion_name}")
|
||||
raise ValueError(f"Invalid completion status for {self.completion_name}")
|
||||
else:
|
||||
_logger.debug(f"Not completed for {self.completion_name}, running")
|
||||
self.runnable.run()
|
||||
_logger.info(f"Setting completion for {self.completion_name}")
|
||||
kalpaa.completions.set_completion_file(
|
||||
self.runnable.config, self.completion_name
|
||||
)
|
||||
|
||||
|
||||
# try not to use this out side of main or when defining config stuff pls
|
||||
# import numpy
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Runner:
|
||||
class Runner(Runnable):
|
||||
def __init__(self, config: kalpaa.Config):
|
||||
self.config = config
|
||||
_logger.info(f"Initialising runner with {config=}")
|
||||
|
||||
def run(self):
|
||||
|
||||
if self.config.general_config.skip_to_stage is not None:
|
||||
stage01 = Completable(
|
||||
kalpaa.stages.stage01.Stage01Runner(self.config), "stage01.complete"
|
||||
)
|
||||
stage02 = Completable(
|
||||
kalpaa.stages.stage02.Stage02Runner(self.config), "stage02.complete"
|
||||
)
|
||||
stage03 = Completable(
|
||||
kalpaa.stages.stage03.Stage03Runner(self.config), "stage03.complete"
|
||||
)
|
||||
stage04 = Completable(
|
||||
kalpaa.stages.stage04.Stage04Runner(self.config), "stage04.complete"
|
||||
)
|
||||
|
||||
stage01 = kalpaa.stages.stage01.Stage01Runner(self.config)
|
||||
stage02 = kalpaa.stages.stage02.Stage02Runner(self.config)
|
||||
stage03 = kalpaa.stages.stage03.Stage03Runner(self.config)
|
||||
stage04 = kalpaa.stages.stage04.Stage04Runner(self.config)
|
||||
if self.config.general_config.skip_to_stage is not None:
|
||||
|
||||
stages = [stage01, stage02, stage03, stage04]
|
||||
|
||||
@ -44,20 +91,19 @@ class Runner:
|
||||
# standard run, can keep old
|
||||
|
||||
_logger.info("*** Beginning Stage 01 ***")
|
||||
stage01 = kalpaa.stages.stage01.Stage01Runner(self.config)
|
||||
stage01.run()
|
||||
|
||||
_logger.info("*** Beginning Stage 02 ***")
|
||||
stage02 = kalpaa.stages.stage02.Stage02Runner(self.config)
|
||||
stage02.run()
|
||||
|
||||
_logger.info("*** Beginning Stage 03 ***")
|
||||
stage03 = kalpaa.stages.stage03.Stage03Runner(self.config)
|
||||
stage03.run()
|
||||
|
||||
_logger.info("*** Beginning Stage 04 ***")
|
||||
stage04 = kalpaa.stages.stage04.Stage04Runner(self.config)
|
||||
stage04.run()
|
||||
kalpaa.completions.set_completion_file(
|
||||
self.config, kalpaa.completions.KALPAA_COMPLETE
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -73,6 +119,12 @@ def parse_args():
|
||||
default=None,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-stream",
|
||||
action="store_true",
|
||||
help="Log to stream",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
@ -123,15 +175,58 @@ def main():
|
||||
|
||||
_logger.info(skip)
|
||||
|
||||
kalpaa.common.set_up_logging(config, str(root / f"logs/kalpaa.log"))
|
||||
overridden_config = dataclasses.replace(
|
||||
config,
|
||||
general_config=dataclasses.replace(
|
||||
config.general_config, root_directory=root.resolve(), skip_to_stage=skip
|
||||
),
|
||||
)
|
||||
|
||||
_logger.info(f"Root dir is {root}, copying over {config.general_config.indexes_json_name}, {config.general_config.dots_json_name} and {args.config_file}")
|
||||
for file in [config.general_config.indexes_json_name, config.general_config.dots_json_name, args.config_file]:
|
||||
kalpaa.common.set_up_logging(
|
||||
config,
|
||||
log_stream=args.log_stream,
|
||||
log_file=str(root / f"logs/kalpaa_{label}.log"),
|
||||
)
|
||||
|
||||
completions_status = kalpaa.completions.check_initial_completions(
|
||||
args.config_file, overridden_config
|
||||
)
|
||||
if completions_status == kalpaa.completions.CompletionsStatus.COMPLETE:
|
||||
_logger.info("All stages complete, exiting")
|
||||
return
|
||||
elif completions_status == kalpaa.completions.CompletionsStatus.INVALID:
|
||||
_logger.error("Invalid completion status, exiting")
|
||||
raise ValueError("Invalid completion status")
|
||||
|
||||
# otherwise good to go
|
||||
|
||||
_logger.info(
|
||||
f"Root dir is {root}, copying over {overridden_config.general_config.indexes_json_name}, {overridden_config.general_config.dots_json_name} and {args.config_file}"
|
||||
)
|
||||
for file in [
|
||||
overridden_config.general_config.indexes_json_name,
|
||||
overridden_config.general_config.dots_json_name,
|
||||
args.config_file,
|
||||
]:
|
||||
_logger.info(f"Copying {file} to {root}")
|
||||
(root / file).write_text((pathlib.Path.cwd() / file).read_text())
|
||||
|
||||
|
||||
overridden_config = dataclasses.replace(config, general_config=dataclasses.replace(config.general_config, root_directory=root.resolve(), skip_to_stage=skip))
|
||||
if overridden_config.generation_config.override_measurement_filesets is not None:
|
||||
_logger.info(
|
||||
f"Overriding measurements with {overridden_config.generation_config.override_measurement_filesets}"
|
||||
)
|
||||
override_directory = root / kalpaa.config.OVERRIDE_MEASUREMENT_DIR_NAME
|
||||
override_directory.mkdir(exist_ok=True, parents=True)
|
||||
for (
|
||||
key,
|
||||
files,
|
||||
) in overridden_config.generation_config.override_measurement_filesets.items():
|
||||
_logger.info(f"Copying for {key=}, {files} to {override_directory}")
|
||||
for file in files:
|
||||
fileset_dir = override_directory / key
|
||||
fileset_dir.mkdir(exist_ok=True, parents=True)
|
||||
_logger.info(f"Copying {file} to {override_directory}")
|
||||
(fileset_dir / file).write_text((pathlib.Path.cwd() / file).read_text())
|
||||
|
||||
_logger.info(f"Got {config=}")
|
||||
runner = Runner(overridden_config)
|
||||
|
@ -102,7 +102,7 @@ class Stage01Runner:
|
||||
dipoles_json = directory / "dipoles.json"
|
||||
|
||||
with open(config_json, "w") as conf_file:
|
||||
params = kalpaa.ReducedModelParams(
|
||||
params = self.config.default_model_param_config.reduced_model_params(
|
||||
count=count, orientation=tantri.dipoles.types.Orientation(orientation)
|
||||
)
|
||||
_logger.debug(f"Got params {params=}")
|
||||
@ -163,14 +163,6 @@ class Stage01Runner:
|
||||
# config_json = directory / "generation_config.json"
|
||||
dipoles_json = directory / "dipoles.json"
|
||||
|
||||
# with open(config_json, "w") as conf_file:
|
||||
# params = kalpa.ReducedModelParams(
|
||||
# count=count, orientation=tantri.dipoles.types.Orientation(orientation)
|
||||
# )
|
||||
# _logger.debug(f"Got params {params=}")
|
||||
# json.dump(params.config_dict(seed), conf_file)
|
||||
# # json.dump(kalpa.common.model_config_dict(count, orientation, seed), conf_file)
|
||||
|
||||
# the original logic looked like this:
|
||||
# tantri.cli._generate_dipoles(config_json, dipoles_json, (seed, replica, 1))
|
||||
# We're replicating the bit that wrote the dipoles here, but that's a refactor opportunity
|
||||
@ -208,7 +200,27 @@ class Stage01Runner:
|
||||
|
||||
def run(self):
|
||||
seed_index = 0
|
||||
if self.config.generation_config.override_dipole_configs is None:
|
||||
if self.config.generation_config.override_measurement_filesets is not None:
|
||||
for (
|
||||
override_name
|
||||
) in self.config.generation_config.override_measurement_filesets.keys():
|
||||
# don't need to do anything with the files, just create the out dir
|
||||
out = self.config.get_out_dir_path()
|
||||
directory = out / f"{override_name}"
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
elif self.config.generation_config.override_dipole_configs is not None:
|
||||
_logger.debug(
|
||||
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
|
||||
)
|
||||
for (
|
||||
override_name,
|
||||
override_dipoles,
|
||||
) in self.config.generation_config.override_dipole_configs.items():
|
||||
self.generate_override_dipole(
|
||||
seed_index, override_name, override_dipoles
|
||||
)
|
||||
else:
|
||||
# should be by default
|
||||
_logger.debug("no override needed!")
|
||||
for count in self.config.generation_config.counts:
|
||||
@ -221,17 +233,6 @@ class Stage01Runner:
|
||||
seed_index, count, orientation, replica
|
||||
)
|
||||
seed_index += 1
|
||||
else:
|
||||
_logger.debug(
|
||||
f"Dipole generation override received: {self.config.generation_config.override_dipole_configs}"
|
||||
)
|
||||
for (
|
||||
override_name,
|
||||
override_dipoles,
|
||||
) in self.config.generation_config.override_dipole_configs.items():
|
||||
self.generate_override_dipole(
|
||||
seed_index, override_name, override_dipoles
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -274,50 +275,3 @@ def parse_args():
|
||||
# tantri.cli._generate_dipoles(config_json, dipoles_json, (index, replica, 1))
|
||||
|
||||
# tantri.cli._write_apsd(dipoles_json, DOTS, X_ELECTRIC_FIELD, DELTA_T, NUM_ITERATIONS, NUM_BIN_TS, (index, replica, 2), output_csv, binned_csv, BIN_WIDTH_LOG, True)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
tantri_configs = [
|
||||
kalpaa.TantriConfig(31415, 100, 5, 100000),
|
||||
kalpaa.TantriConfig(314, 100, 0.00005, 100000),
|
||||
]
|
||||
generation_config = kalpaa.GenerationConfig(
|
||||
tantri_configs=tantri_configs,
|
||||
counts=[1],
|
||||
num_replicas=3,
|
||||
orientations=[tantri.dipoles.types.Orientation.Z],
|
||||
)
|
||||
|
||||
config = kalpaa.Config(generation_config=generation_config)
|
||||
|
||||
kalpaa.common.set_up_logging(config, args.log_file)
|
||||
|
||||
_logger.info("Generating our data, for the following iterations")
|
||||
|
||||
_logger.info(config)
|
||||
# _logger.info(f"{COUNTS=}")
|
||||
# _logger.info(f"{ORIENTATIONS=}")
|
||||
# _logger.info(f"{NUM_REPLICAS=}")
|
||||
|
||||
# _logger.info("Our parameters used: ")
|
||||
|
||||
# _logger.info(f"\t{INDEX_STARTER=}")
|
||||
|
||||
# _logger.info(f"\t{NUM_SEEDS=}")
|
||||
# # these are obviously not independent but it's just easier than thinking about floats to define them both here
|
||||
# _logger.info(f"\t{DELTA_T=}")
|
||||
# _logger.info(f"\t{SAMPLE_RATE=}")
|
||||
# _logger.info(f"\t{NUM_ITERATIONS=}")
|
||||
|
||||
# # for binnng
|
||||
# _logger.info(f"\t{NUM_BIN_TS=}")
|
||||
# _logger.info(f"\t{BIN_WIDTH_LOG=}")
|
||||
|
||||
runner = Stage01Runner(config)
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -17,6 +17,7 @@ import json
|
||||
|
||||
import kalpaa
|
||||
import kalpaa.common
|
||||
import kalpaa.completions
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@ -82,7 +83,7 @@ class Stage02Runner:
|
||||
)
|
||||
_logger.info(f"Got dots {self.dots=}")
|
||||
|
||||
def _dots_to_include(self, current_dot: str) -> typing.Sequence[str]:
|
||||
def _dots_to_include(self, current_dot: str) -> typing.List[str]:
|
||||
if current_dot == "dot1":
|
||||
return ["dot1"]
|
||||
if current_dot == "dot2":
|
||||
@ -90,7 +91,9 @@ class Stage02Runner:
|
||||
else:
|
||||
return ["dot1", "dot2", current_dot]
|
||||
|
||||
def run_in_subdir(self, subdir: pathlib.Path):
|
||||
def run_in_subdir(
|
||||
self, subdir: pathlib.Path, override_key: typing.Optional[str] = None
|
||||
):
|
||||
with kalpaa.common.new_cd(subdir):
|
||||
_logger.debug(f"Running inside {subdir=}")
|
||||
|
||||
@ -99,22 +102,75 @@ class Stage02Runner:
|
||||
_logger.debug(f"Have {num_jobs=}")
|
||||
seed_index = 0
|
||||
for job_index in range(num_jobs):
|
||||
|
||||
_logger.debug(f"Working on {job_index=}")
|
||||
completion_name = f"stage02.job_{job_index}.complete"
|
||||
completion = kalpaa.completions.check_completion_file(
|
||||
self.config, completion_name
|
||||
)
|
||||
if completion == kalpaa.completions.CompletionsStatus.COMPLETE:
|
||||
_logger.info(f"Skipping {completion_name}")
|
||||
continue
|
||||
elif completion == kalpaa.completions.CompletionsStatus.INVALID:
|
||||
_logger.error(f"Invalid completion status for {completion_name}")
|
||||
raise ValueError(f"Invalid completion status for {completion_name}")
|
||||
|
||||
for cost in self.config.deepdog_config.costs_to_try:
|
||||
for dot in self.dots:
|
||||
if self.config.deepdog_config.manual_dot_seeds is not None:
|
||||
for config_i, manual_config in enumerate(
|
||||
self.config.deepdog_config.manual_dot_seeds
|
||||
):
|
||||
|
||||
seed_index += 1
|
||||
seed_index += 1
|
||||
# validate config
|
||||
|
||||
combined_dot_name = ",".join(
|
||||
[d for d in self._dots_to_include(dot.label)]
|
||||
)
|
||||
trial_name = (
|
||||
f"{dot.label}-{combined_dot_name}-{cost}-{job_index}"
|
||||
)
|
||||
_logger.info(f"Working on {trial_name=}")
|
||||
_logger.debug(f"Have {seed_index=}")
|
||||
self.single_run_in_subdir(
|
||||
job_index, cost, dot.label, trial_name, seed_index
|
||||
)
|
||||
dot_label = str(config_i) + str(manual_config).translate(
|
||||
str.maketrans("", "", "[]\",' ")
|
||||
)
|
||||
dot_set = set()
|
||||
for dot_entry in manual_config:
|
||||
for dot_name in dot_entry:
|
||||
dot_set.add(dot_name)
|
||||
_logger.info(f"Dot set {dot_set=}")
|
||||
dot_included = ",".join([d for d in sorted(dot_set)])
|
||||
trial_name = (
|
||||
f"{dot_label}-{dot_included}-{cost}-{job_index}"
|
||||
)
|
||||
|
||||
_logger.info(f"Working on {trial_name=}")
|
||||
_logger.debug(f"Have {seed_index=}")
|
||||
self.single_run_in_subdir(
|
||||
job_index,
|
||||
cost,
|
||||
dot_label,
|
||||
trial_name,
|
||||
seed_index,
|
||||
override_name=override_key,
|
||||
dot_spec=manual_config,
|
||||
)
|
||||
else:
|
||||
for dot in self.dots:
|
||||
|
||||
seed_index += 1
|
||||
|
||||
combined_dot_name = ",".join(
|
||||
[d for d in self._dots_to_include(dot.label)]
|
||||
)
|
||||
trial_name = (
|
||||
f"{dot.label}-{combined_dot_name}-{cost}-{job_index}"
|
||||
)
|
||||
|
||||
_logger.info(f"Working on {trial_name=}")
|
||||
_logger.debug(f"Have {seed_index=}")
|
||||
self.single_run_in_subdir(
|
||||
job_index,
|
||||
cost,
|
||||
dot.label,
|
||||
trial_name,
|
||||
seed_index,
|
||||
override_name=override_key,
|
||||
)
|
||||
kalpaa.completions.set_completion_file(self.config, completion_name)
|
||||
|
||||
def single_run_in_subdir(
|
||||
self,
|
||||
@ -123,6 +179,8 @@ class Stage02Runner:
|
||||
dot_name: str,
|
||||
trial_name: str,
|
||||
seed_index: int,
|
||||
override_name: typing.Optional[str] = None,
|
||||
dot_spec: typing.Optional[typing.Sequence[typing.Sequence[str]]] = None,
|
||||
):
|
||||
# _logger.info(f"Got job index {job_index}")
|
||||
# NOTE This guy runs inside subdirs, obviously. In something like <kalpa>/out/z-10-2/dipoles
|
||||
@ -137,17 +195,58 @@ class Stage02Runner:
|
||||
f"Have {self.config.generation_config.tantri_configs} as our tantri_configs"
|
||||
)
|
||||
num_tantri_configs = len(self.config.generation_config.tantri_configs)
|
||||
binned_datas = [
|
||||
kalpaa.read_dots_and_binned(
|
||||
self.config.get_dots_json_path(),
|
||||
pathlib.Path("..")
|
||||
/ kalpaa.common.tantri_binned_output_name(tantri_index),
|
||||
)
|
||||
for tantri_index in range(num_tantri_configs)
|
||||
]
|
||||
|
||||
dot_names = self._dots_to_include(dot_name)
|
||||
_logger.debug(f"Got dot names {dot_names}")
|
||||
if override_name is not None:
|
||||
if self.config.generation_config.override_measurement_filesets is None:
|
||||
raise ValueError(
|
||||
"override_name provided but no override_measurement_filesets, shouldn't be possible to get here"
|
||||
)
|
||||
_logger.info(f"Time to read override measurement fileset {override_name}")
|
||||
override_dir = self.config.get_override_dir_path() / override_name
|
||||
override_measurements = (
|
||||
self.config.generation_config.override_measurement_filesets[
|
||||
override_name
|
||||
]
|
||||
)
|
||||
_logger.info(f"Finding files {override_measurements} in {override_dir}")
|
||||
binned_datas = [
|
||||
kalpaa.read_dots_and_binned(
|
||||
self.config.get_dots_json_path(),
|
||||
override_dir / measurement,
|
||||
)
|
||||
for measurement in override_measurements
|
||||
]
|
||||
else:
|
||||
|
||||
binned_datas = [
|
||||
kalpaa.read_dots_and_binned(
|
||||
self.config.get_dots_json_path(),
|
||||
pathlib.Path("..")
|
||||
/ kalpaa.common.tantri_binned_output_name(tantri_index),
|
||||
)
|
||||
for tantri_index in range(num_tantri_configs)
|
||||
]
|
||||
|
||||
single_dot_names: typing.List[str] = []
|
||||
pair_dot_names: typing.List[typing.Tuple[str, str]] = []
|
||||
if dot_spec is not None:
|
||||
_logger.info(f"Received dot_spec {dot_spec}, validating")
|
||||
for dot_entry in dot_spec:
|
||||
_logger.debug(f"Working on {dot_entry=}")
|
||||
if len(dot_entry) not in (1, 2):
|
||||
raise ValueError(
|
||||
f"Invalid dot spec {dot_spec}, {dot_entry} has wrong length"
|
||||
)
|
||||
|
||||
if len(dot_entry) == 1:
|
||||
_logger.debug(f"Adding {dot_entry[0]} to single_dot_names")
|
||||
single_dot_names.append(dot_entry[0])
|
||||
else:
|
||||
pair_dot_names.append((dot_entry[0], dot_entry[1]))
|
||||
else:
|
||||
single_dot_names = self._dots_to_include(dot_name)
|
||||
pair_dot_names = []
|
||||
_logger.debug(f"Got dot names {single_dot_names=}, {pair_dot_names=}")
|
||||
|
||||
models = []
|
||||
|
||||
@ -162,6 +261,7 @@ class Stage02Runner:
|
||||
seed = seed_index
|
||||
|
||||
# TODO find way to store this as a global config file
|
||||
# TODO refactor to account for missing entries, (ex. if occupancy of 150 should use next highest value)
|
||||
occupancies_dict = {
|
||||
1: (500, 1000),
|
||||
2: (250, 2000),
|
||||
@ -172,12 +272,14 @@ class Stage02Runner:
|
||||
17: (50, 10000),
|
||||
31: (50, 10000),
|
||||
56: (25, 20000),
|
||||
100: (5, 100000),
|
||||
100: (2, 250000),
|
||||
161: (1, 500000),
|
||||
200: (1, 500000),
|
||||
}
|
||||
|
||||
mccount, mccountcycles = occupancies_dict[avg_filled]
|
||||
|
||||
model_params = kalpaa.ReducedModelParams(
|
||||
model_params = self.config.default_model_param_config.reduced_model_params(
|
||||
count=avg_filled, log_magnitude=log_magnitude, orientation=orientation
|
||||
)
|
||||
|
||||
@ -196,16 +298,31 @@ class Stage02Runner:
|
||||
write_successes_to_file=True,
|
||||
tag=trial_name,
|
||||
write_bayesrun_file=True,
|
||||
bayesrun_file_timestamp=False,
|
||||
skip_if_exists=True, # Can't see why we wouldn't want this, maybe hook to check_completions later
|
||||
)
|
||||
|
||||
_logger.info(f"{deepdog_config=}")
|
||||
|
||||
stdev_cost_function_filters = [
|
||||
b.stdev_cost_function_filter(
|
||||
dot_names, cost, self.config.deepdog_config.use_log_noise
|
||||
)
|
||||
for b in binned_datas
|
||||
]
|
||||
stdev_cost_function_filters = []
|
||||
|
||||
if len(pair_dot_names):
|
||||
pair_stdev_cost_function_filters = [
|
||||
b.stdev_cost_function_filter(
|
||||
pair_dot_names, cost, self.config.deepdog_config.use_log_noise
|
||||
)
|
||||
for b in binned_datas
|
||||
]
|
||||
stdev_cost_function_filters.extend(pair_stdev_cost_function_filters)
|
||||
|
||||
if len(single_dot_names):
|
||||
single_stdev_cost_function_filters = [
|
||||
b.stdev_cost_function_filter(
|
||||
single_dot_names, cost, self.config.deepdog_config.use_log_noise
|
||||
)
|
||||
for b in binned_datas
|
||||
]
|
||||
stdev_cost_function_filters.extend(single_stdev_cost_function_filters)
|
||||
|
||||
_logger.debug(f"{stdev_cost_function_filters=}")
|
||||
combining_filter = deepdog.direct_monte_carlo.compose_filter.ComposedDMCFilter(
|
||||
@ -221,16 +338,27 @@ class Stage02Runner:
|
||||
_logger.info(results)
|
||||
|
||||
def run(self):
|
||||
"""Going to iterate over every folder in out_dir, and execute the subdir stuff inside dirs like <kalpa>/out/z-10-2/dipoles"""
|
||||
out_dir_path = self.config.get_out_dir_path()
|
||||
subdirs = [child for child in out_dir_path.iterdir() if child.is_dir]
|
||||
# _logger.info(f"Going to execute within each of the directories in {subdirs=}")
|
||||
for subdir in subdirs:
|
||||
# skip try finally for now just blow up if problem
|
||||
_logger.debug(f"Running for {subdir=}")
|
||||
dipoles_dir = subdir / "dipoles"
|
||||
dipoles_dir.mkdir(exist_ok=True, parents=False)
|
||||
self.run_in_subdir(subdir / "dipoles")
|
||||
if self.config.generation_config.override_measurement_filesets is not None:
|
||||
_logger.info("Using override configuration.")
|
||||
for (
|
||||
override_name
|
||||
) in self.config.generation_config.override_measurement_filesets.keys():
|
||||
subdir = self.config.get_out_dir_path() / override_name
|
||||
dipoles_dir = subdir / "dipoles"
|
||||
dipoles_dir.mkdir(exist_ok=True, parents=False)
|
||||
self.run_in_subdir(dipoles_dir, override_key=override_name)
|
||||
|
||||
else:
|
||||
"""Going to iterate over every folder in out_dir, and execute the subdir stuff inside dirs like <kalpa>/out/z-10-2/dipoles"""
|
||||
out_dir_path = self.config.get_out_dir_path()
|
||||
subdirs = [child for child in out_dir_path.iterdir() if child.is_dir]
|
||||
# _logger.info(f"Going to execute within each of the directories in {subdirs=}")
|
||||
for subdir in subdirs:
|
||||
# skip try finally for now just blow up if problem
|
||||
_logger.debug(f"Running for {subdir=}")
|
||||
dipoles_dir = subdir / "dipoles"
|
||||
dipoles_dir.mkdir(exist_ok=True, parents=False)
|
||||
self.run_in_subdir(subdir / "dipoles")
|
||||
|
||||
|
||||
def parse_args():
|
||||
@ -247,28 +375,3 @@ def parse_args():
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
tantri_configs = [
|
||||
kalpaa.TantriConfig(31415, 100, 5, 100000),
|
||||
kalpaa.TantriConfig(314, 100, 0.00005, 100000),
|
||||
]
|
||||
generation_config = kalpaa.GenerationConfig(tantri_configs=tantri_configs)
|
||||
|
||||
config = kalpaa.Config(generation_config=generation_config)
|
||||
|
||||
kalpaa.common.set_up_logging(config, args.log_file)
|
||||
|
||||
_logger.info("Generating our data, for the following iterations")
|
||||
|
||||
_logger.info(config)
|
||||
|
||||
runner = Stage02Runner(config)
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -24,6 +24,7 @@ import tantri.dipoles.types
|
||||
# folder in curr dir
|
||||
import kalpaa
|
||||
import kalpaa.common
|
||||
import kalpaa.completions
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@ -48,17 +49,24 @@ OUT_FIELDNAMES = [
|
||||
]
|
||||
|
||||
|
||||
def coalesced_filename(dot_name, target_cost) -> str:
|
||||
return f"coalesced-{dot_name}-{target_cost}.csv"
|
||||
def coalesced_filename(subdir_name: str) -> str:
|
||||
return f"coalesced-{subdir_name}.csv"
|
||||
|
||||
|
||||
def read_coalesced_csv(parent_path: pathlib.Path, dot_name: str, target_cost):
|
||||
def read_coalesced_csv(parent_path: pathlib.Path, subdir_name: str):
|
||||
|
||||
# csv_name = f"coalesced-{dot_name}-{target_cost}.csv"
|
||||
csv_path = parent_path / coalesced_filename(dot_name, target_cost)
|
||||
csv_path = parent_path / coalesced_filename(subdir_name)
|
||||
_logger.debug(f"{csv_path=}")
|
||||
with csv_path.open("r", newline="") as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
out_list = []
|
||||
|
||||
subdir_split = subdir_name.rsplit("-", 1)
|
||||
|
||||
dot_name = subdir_split[0]
|
||||
target_cost = subdir_split[1]
|
||||
_logger.debug(f"{dot_name=}, {target_cost=} for subdir_name {subdir_name=}")
|
||||
for row in reader:
|
||||
row["dot_name"] = dot_name
|
||||
row["target_cost"] = target_cost
|
||||
@ -86,12 +94,15 @@ class Stage03Runner:
|
||||
with out_path.open("w", newline="") as outfile:
|
||||
writer = csv.DictWriter(outfile, OUT_FIELDNAMES)
|
||||
writer.writeheader()
|
||||
for dot in self.dots:
|
||||
for cost in self.config.deepdog_config.costs_to_try:
|
||||
_logger.info(f"Reading {dot=} {cost=}")
|
||||
rows = read_coalesced_csv(sorted_dir, dot, cost)
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
for subdir in sorted_dir.iterdir():
|
||||
if not subdir.is_dir():
|
||||
_logger.info(f"That's not a dir {subdir=}")
|
||||
continue
|
||||
subdir_name = subdir.name
|
||||
_logger.info(f"Reading for {subdir_name=}")
|
||||
rows = read_coalesced_csv(sorted_dir, subdir_name)
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
|
||||
def run_in_subdir(self, subdir: pathlib.Path):
|
||||
"""
|
||||
@ -101,33 +112,49 @@ class Stage03Runner:
|
||||
|
||||
_logger.debug(f"Running inside {subdir=}")
|
||||
|
||||
kalpaa.stages.stage03_1.move_all_in_dipoles(subdir / "dipoles")
|
||||
subdir_name = subdir.name
|
||||
completion_name = f"stage03_1.job_{subdir_name}.complete"
|
||||
completion = kalpaa.completions.check_completion_file(
|
||||
self.config, completion_name
|
||||
)
|
||||
if completion == kalpaa.completions.CompletionsStatus.COMPLETE:
|
||||
_logger.info(f"Skipping {completion_name}")
|
||||
# continue
|
||||
elif completion == kalpaa.completions.CompletionsStatus.INVALID:
|
||||
_logger.error(f"Invalid completion status for {completion_name}")
|
||||
raise ValueError(f"Invalid completion status for {completion_name}")
|
||||
else:
|
||||
_logger.info(f"Moving dipoles for {subdir=}")
|
||||
kalpaa.stages.stage03_1.move_all_in_dipoles(subdir / "dipoles")
|
||||
kalpaa.completions.set_completion_file(self.config, completion_name)
|
||||
|
||||
seed_index = 0
|
||||
|
||||
sorted_dir = pathlib.Path(kalpaa.common.sorted_bayesruns_name())
|
||||
_logger.info(f"{sorted_dir.resolve()}")
|
||||
|
||||
for cost in self.config.deepdog_config.costs_to_try:
|
||||
for dot in self.dots:
|
||||
for sorted_subdir in sorted_dir.iterdir():
|
||||
if not subdir.is_dir():
|
||||
_logger.info(f"That's not a dir {subdir=}")
|
||||
continue
|
||||
|
||||
seed_index += 1
|
||||
# TODO pull out
|
||||
sorted_subdir = sorted_dir / f"{dot}-{cost}"
|
||||
seed_index += 1
|
||||
# TODO pull out
|
||||
# sorted_subdir = sorted_dir / f"{dot}-{cost}"
|
||||
|
||||
# TODO need to refactor deepdog probs method so I don't have to dump into args like this
|
||||
probs_args = argparse.Namespace()
|
||||
probs_args.bayesrun_directory = sorted_subdir
|
||||
probs_args.indexify_json = self.config.absify(
|
||||
self.config.general_config.indexes_json_name
|
||||
)
|
||||
probs_args.coalesced_keys = ""
|
||||
probs_args.uncoalesced_outfile = None
|
||||
probs_args.coalesced_outfile = sorted_dir / coalesced_filename(
|
||||
dot, cost
|
||||
)
|
||||
# TODO need to refactor deepdog probs method so I don't have to dump into args like this
|
||||
probs_args = argparse.Namespace()
|
||||
probs_args.bayesrun_directory = sorted_subdir
|
||||
probs_args.indexify_json = self.config.absify(
|
||||
self.config.general_config.indexes_json_name
|
||||
)
|
||||
probs_args.coalesced_keys = ""
|
||||
probs_args.uncoalesced_outfile = None
|
||||
probs_args.coalesced_outfile = sorted_dir / coalesced_filename(
|
||||
sorted_subdir.name
|
||||
)
|
||||
|
||||
deepdog.cli.probs.main.main(probs_args)
|
||||
deepdog.cli.probs.main.main(probs_args)
|
||||
|
||||
self.merge_coalesceds(sorted_dir)
|
||||
|
||||
@ -161,33 +188,3 @@ def parse_args():
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
tantri_configs = [
|
||||
kalpaa.TantriConfig(31415, 100, 5, 100000),
|
||||
kalpaa.TantriConfig(314, 100, 0.00005, 100000),
|
||||
]
|
||||
generation_config = kalpaa.GenerationConfig(
|
||||
tantri_configs=tantri_configs,
|
||||
counts=[1],
|
||||
num_replicas=3,
|
||||
orientations=[tantri.dipoles.types.Orientation.Z],
|
||||
)
|
||||
|
||||
config = kalpaa.Config(generation_config=generation_config)
|
||||
|
||||
kalpaa.common.set_up_logging(config, args.log_file)
|
||||
|
||||
_logger.info("Generating our data, for the following iterations")
|
||||
|
||||
_logger.info(config)
|
||||
|
||||
runner = Stage03Runner(config)
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -5,7 +5,6 @@ import csv
|
||||
import kalpaa
|
||||
import kalpaa.common
|
||||
import kalpaa.inference_coalesce
|
||||
import tantri.dipoles.types
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@ -155,8 +154,29 @@ class Stage04Runner:
|
||||
writer = csv.DictWriter(outfile, MERGED_OUT_FIELDNAMES)
|
||||
writer.writeheader()
|
||||
|
||||
if self.config.generation_config.override_dipole_configs is None:
|
||||
if self.config.generation_config.override_dipole_configs is not None:
|
||||
override_names = (
|
||||
self.config.generation_config.override_dipole_configs.keys()
|
||||
)
|
||||
elif (
|
||||
self.config.generation_config.override_measurement_filesets is not None
|
||||
):
|
||||
override_names = (
|
||||
self.config.generation_config.override_measurement_filesets.keys()
|
||||
)
|
||||
else:
|
||||
override_names = None
|
||||
|
||||
if override_names is not None:
|
||||
_logger.debug(
|
||||
f"We had overridden dipole config, using override {override_names}"
|
||||
)
|
||||
for override_name in override_names:
|
||||
_logger.info(f"Working for subdir {override_name}")
|
||||
rows = self.read_merged_coalesced_csv_override(override_name)
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
else:
|
||||
for count in self.config.generation_config.counts:
|
||||
for orientation in self.config.generation_config.orientations:
|
||||
for replica in range(
|
||||
@ -169,21 +189,9 @@ class Stage04Runner:
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
|
||||
else:
|
||||
_logger.debug(
|
||||
f"We had overridden dipole config, using override {self.config.generation_config.override_dipole_configs}"
|
||||
)
|
||||
for (
|
||||
override_name
|
||||
) in self.config.generation_config.override_dipole_configs.keys():
|
||||
_logger.info(f"Working for subdir {override_name}")
|
||||
rows = self.read_merged_coalesced_csv_override(override_name)
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
|
||||
# merge with inference
|
||||
|
||||
if self.config.generation_config.override_dipole_configs is None:
|
||||
if override_names is None:
|
||||
|
||||
with megamerged_path.open(mode="r", newline="") as infile:
|
||||
# Note that if you pass in fieldnames to a DictReader it doesn't skip. So this is bad:
|
||||
@ -234,33 +242,3 @@ def parse_args():
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
tantri_configs = [
|
||||
kalpaa.TantriConfig(31415, 100, 5, 100000),
|
||||
kalpaa.TantriConfig(314, 100, 0.00005, 100000),
|
||||
]
|
||||
generation_config = kalpaa.GenerationConfig(
|
||||
tantri_configs=tantri_configs,
|
||||
counts=[1],
|
||||
num_replicas=3,
|
||||
orientations=[tantri.dipoles.types.Orientation.Z],
|
||||
)
|
||||
|
||||
config = kalpaa.Config(generation_config=generation_config)
|
||||
|
||||
kalpaa.common.set_up_logging(config, args.log_file)
|
||||
|
||||
_logger.info("Generating our data, for the following iterations")
|
||||
|
||||
_logger.info(config)
|
||||
|
||||
runner = Stage04Runner(config)
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
6
poetry.lock
generated
6
poetry.lock
generated
@ -174,13 +174,13 @@ dev = ["black", "coveralls", "mypy", "pre-commit", "pylint", "pytest (>=5)", "py
|
||||
|
||||
[[package]]
|
||||
name = "deepdog"
|
||||
version = "1.5.0"
|
||||
version = "1.7.0"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<3.10,>=3.8.1"
|
||||
files = [
|
||||
{file = "deepdog-1.5.0-py3-none-any.whl", hash = "sha256:b645fdc32a1933e17b4a76f97b5399d77e698fb10c6386f3fbcdb1fe9c5caf08"},
|
||||
{file = "deepdog-1.5.0.tar.gz", hash = "sha256:9012a9d375fce178fd222dd818a21c49ef4ce4127a65f5a3ad6ae16f5e96d1c5"},
|
||||
{file = "deepdog-1.7.0-py3-none-any.whl", hash = "sha256:53944ec281abf0118ff94033e7b7d73e13805cf6ee15489859a43a250968d45e"},
|
||||
{file = "deepdog-1.7.0.tar.gz", hash = "sha256:cb859f00da24117f49ddf544784dba4ff0df7a25fed83e2d9479fb55110a21d0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "kalpaa"
|
||||
version = "1.0.0"
|
||||
version = "1.2.0"
|
||||
description = "Groups up and runs full run."
|
||||
authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
|
||||
readme = "README.md"
|
||||
|
@ -14,7 +14,7 @@ if [ -z "$(git status --porcelain)" ]; then
|
||||
release_needed=false
|
||||
if \
|
||||
{ git log "$( git describe --tags --abbrev=0 )..HEAD" --format='%s' | cut -d: -f1 | sort -u | sed -e 's/([^)]*)//' | grep -q -i -E '^feat|fix|perf|refactor|revert$' ; } || \
|
||||
{ git log "$( git describe --tags --abbrev=0 )..HEAD" --format='%s' | cut -d: -f1 | sort -u | sed -e 's/([^)]*)//' | grep -q -E '\!$' ; } || \
|
||||
{ git log "$( git describe --tags --abbrev=0 )..HEAD" --format='%s' | cut -d: -f1 | sort -u | sed -e 's/([^)]*)//' | grep -q -E '!$' ; } || \
|
||||
{ git log "$( git describe --tags --abbrev=0 )..HEAD" --format='%b' | grep -q -E '^BREAKING CHANGE:' ; }
|
||||
then
|
||||
release_needed=true
|
||||
|
0
tests/common/__init__.py
Normal file
0
tests/common/__init__.py
Normal file
15
tests/common/__snapshots__/test_angles.ambr
Normal file
15
tests/common/__snapshots__/test_angles.ambr
Normal file
@ -0,0 +1,15 @@
|
||||
# serializer version: 1
|
||||
# name: test_angles_in_range
|
||||
list([
|
||||
0.0,
|
||||
-3.0,
|
||||
0.28318530717958623,
|
||||
-2.7168146928204138,
|
||||
0.5663706143591725,
|
||||
-2.4336293856408275,
|
||||
0.8495559215387587,
|
||||
-2.1504440784612413,
|
||||
1.132741228718345,
|
||||
-1.867258771281655,
|
||||
])
|
||||
# ---
|
13
tests/common/test_angles.py
Normal file
13
tests/common/test_angles.py
Normal file
@ -0,0 +1,13 @@
|
||||
import numpy
|
||||
import kalpaa.common.angles
|
||||
|
||||
|
||||
def test_angles_in_range(snapshot):
|
||||
angles_1 = numpy.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
angles_2 = numpy.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) * 4
|
||||
|
||||
result = kalpaa.common.angles.shortest_angular_distance(angles_1, angles_2)
|
||||
|
||||
assert (result >= -numpy.pi).all()
|
||||
assert (result < numpy.pi).all()
|
||||
assert result.tolist() == snapshot
|
309
tests/config/__snapshots__/test_toml_as_dict_snaps.ambr
Normal file
309
tests/config/__snapshots__/test_toml_as_dict_snaps.ambr
Normal file
@ -0,0 +1,309 @@
|
||||
# serializer version: 1
|
||||
# name: test_parse_config_all_fields_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
20,
|
||||
2,
|
||||
0.2,
|
||||
]),
|
||||
'manual_dot_seeds': None,
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 2000,
|
||||
'use_log_noise': True,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1,
|
||||
'w_log_min': -5,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 6.5,
|
||||
'z_min': 5,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'check_completions': False,
|
||||
'dots_json_name': 'test_dots.json',
|
||||
'indexes_json_name': 'test_indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
|
||||
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
|
||||
'mega_merged_name': 'test_mega_merged.csv',
|
||||
'out_dir_name': 'out',
|
||||
'root_directory': PosixPath('test_root'),
|
||||
'skip_to_stage': 1,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 3,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': dict({
|
||||
'scenario1': list([
|
||||
dict({
|
||||
'p': array([3, 5, 7]),
|
||||
's': array([2, 4, 6]),
|
||||
'w': 10,
|
||||
}),
|
||||
dict({
|
||||
'p': array([30, 50, 70]),
|
||||
's': array([20, 40, 60]),
|
||||
'w': 10.55,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.01,
|
||||
'index_seed_starter': 15151,
|
||||
'num_iterations': 100,
|
||||
'num_seeds': 5,
|
||||
}),
|
||||
dict({
|
||||
'delta_t': 1,
|
||||
'index_seed_starter': 1234,
|
||||
'num_iterations': 200,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_parse_config_few_fields_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
5,
|
||||
2,
|
||||
1,
|
||||
0.5,
|
||||
0.2,
|
||||
]),
|
||||
'manual_dot_seeds': None,
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 2000,
|
||||
'use_log_noise': True,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1,
|
||||
'w_log_min': -5,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 6.5,
|
||||
'z_min': 5,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'check_completions': False,
|
||||
'dots_json_name': 'dots.json',
|
||||
'indexes_json_name': 'indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
|
||||
'mega_merged_name': 'mega_merged_coalesced.csv',
|
||||
'out_dir_name': 'out',
|
||||
'root_directory': PosixPath('test_root1'),
|
||||
'skip_to_stage': None,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 2,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': dict({
|
||||
'scenario1': list([
|
||||
dict({
|
||||
'p': array([3, 5, 7]),
|
||||
's': array([2, 4, 6]),
|
||||
'w': 10,
|
||||
}),
|
||||
dict({
|
||||
'p': array([30, 50, 70]),
|
||||
's': array([20, 40, 60]),
|
||||
'w': 10.55,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.01,
|
||||
'index_seed_starter': 15151,
|
||||
'num_iterations': 100,
|
||||
'num_seeds': 5,
|
||||
}),
|
||||
dict({
|
||||
'delta_t': 1,
|
||||
'index_seed_starter': 1234,
|
||||
'num_iterations': 200,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_parse_config_geom_params_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
5,
|
||||
2,
|
||||
1,
|
||||
0.5,
|
||||
0.2,
|
||||
]),
|
||||
'manual_dot_seeds': None,
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 2000,
|
||||
'use_log_noise': True,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1.5,
|
||||
'w_log_min': -3,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 2,
|
||||
'z_min': 0,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'check_completions': False,
|
||||
'dots_json_name': 'dots.json',
|
||||
'indexes_json_name': 'indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'mega_merged_inferenced_name': 'mega_merged_coalesced_inferenced.csv',
|
||||
'mega_merged_name': 'mega_merged_coalesced.csv',
|
||||
'out_dir_name': 'out',
|
||||
'root_directory': PosixPath('test_root1'),
|
||||
'skip_to_stage': None,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
5,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 2,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': dict({
|
||||
'scenario1': list([
|
||||
dict({
|
||||
'p': array([3, 5, 7]),
|
||||
's': array([2, 4, 6]),
|
||||
'w': 10,
|
||||
}),
|
||||
dict({
|
||||
'p': array([30, 50, 70]),
|
||||
's': array([20, 40, 60]),
|
||||
'w': 10.55,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.01,
|
||||
'index_seed_starter': 15151,
|
||||
'num_iterations': 100,
|
||||
'num_seeds': 5,
|
||||
}),
|
||||
dict({
|
||||
'delta_t': 1,
|
||||
'index_seed_starter': 1234,
|
||||
'num_iterations': 200,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_parse_config_toml_as_dict
|
||||
dict({
|
||||
'deepdog_config': dict({
|
||||
'costs_to_try': list([
|
||||
10,
|
||||
1,
|
||||
0.1,
|
||||
]),
|
||||
'manual_dot_seeds': None,
|
||||
'max_monte_carlo_cycles_steps': 20,
|
||||
'target_success': 1000,
|
||||
'use_log_noise': False,
|
||||
}),
|
||||
'default_model_param_config': dict({
|
||||
'w_log_max': 1,
|
||||
'w_log_min': -5,
|
||||
'x_max': 20,
|
||||
'x_min': -20,
|
||||
'y_max': 10,
|
||||
'y_min': -10,
|
||||
'z_max': 6.5,
|
||||
'z_min': 5,
|
||||
}),
|
||||
'general_config': dict({
|
||||
'check_completions': False,
|
||||
'dots_json_name': 'test_dots.json',
|
||||
'indexes_json_name': 'test_indexes.json',
|
||||
'log_pattern': '%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s',
|
||||
'measurement_type': <MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>,
|
||||
'mega_merged_inferenced_name': 'test_mega_merged_inferenced.csv',
|
||||
'mega_merged_name': 'test_mega_merged.csv',
|
||||
'out_dir_name': 'test_out',
|
||||
'root_directory': PosixPath('test_root'),
|
||||
'skip_to_stage': 1,
|
||||
}),
|
||||
'generation_config': dict({
|
||||
'bin_log_width': 0.25,
|
||||
'counts': list([
|
||||
1,
|
||||
10,
|
||||
]),
|
||||
'num_bin_time_series': 25,
|
||||
'num_replicas': 3,
|
||||
'orientations': list([
|
||||
<Orientation.RANDOM: 'RANDOM'>,
|
||||
<Orientation.Z: 'Z'>,
|
||||
<Orientation.XY: 'XY'>,
|
||||
]),
|
||||
'override_dipole_configs': None,
|
||||
'override_measurement_filesets': None,
|
||||
'tantri_configs': list([
|
||||
dict({
|
||||
'delta_t': 0.05,
|
||||
'index_seed_starter': 31415,
|
||||
'num_iterations': 100000,
|
||||
'num_seeds': 100,
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
@ -1,10 +1,13 @@
|
||||
# serializer version: 1
|
||||
# name: test_parse_config_all_fields_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True))
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1, check_completions=False), deepdog_config=DeepdogConfig(costs_to_try=[20, 2, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True, manual_dot_seeds=None), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
# ---
|
||||
# name: test_parse_config_few_fields_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True))
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None, check_completions=False), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True, manual_dot_seeds=None), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
# ---
|
||||
# name: test_parse_config_geom_params_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 5, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=2, override_dipole_configs={'scenario1': [DipoleTO(p=array([3, 5, 7]), s=array([2, 4, 6]), w=10), DipoleTO(p=array([30, 50, 70]), s=array([20, 40, 60]), w=10.55)]}, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=15151, num_seeds=5, delta_t=0.01, num_iterations=100), TantriConfig(index_seed_starter=1234, num_seeds=100, delta_t=1, num_iterations=200)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='dots.json', indexes_json_name='indexes.json', out_dir_name='out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.POTENTIAL: 'electric-potential'>, root_directory=PosixPath('test_root1'), mega_merged_name='mega_merged_coalesced.csv', mega_merged_inferenced_name='mega_merged_coalesced_inferenced.csv', skip_to_stage=None, check_completions=False), deepdog_config=DeepdogConfig(costs_to_try=[5, 2, 1, 0.5, 0.2], target_success=2000, max_monte_carlo_cycles_steps=20, use_log_noise=True, manual_dot_seeds=None), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=0, z_max=2, w_log_min=-3, w_log_max=1.5))
|
||||
# ---
|
||||
# name: test_parse_config_toml
|
||||
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False))
|
||||
Config(generation_config=GenerationConfig(counts=[1, 10], orientations=[<Orientation.RANDOM: 'RANDOM'>, <Orientation.Z: 'Z'>, <Orientation.XY: 'XY'>], num_replicas=3, override_dipole_configs=None, override_measurement_filesets=None, tantri_configs=[TantriConfig(index_seed_starter=31415, num_seeds=100, delta_t=0.05, num_iterations=100000)], num_bin_time_series=25, bin_log_width=0.25), general_config=GeneralConfig(dots_json_name='test_dots.json', indexes_json_name='test_indexes.json', out_dir_name='test_out', log_pattern='%(asctime)s | %(process)d | %(levelname)-7s | %(name)s:%(lineno)d | %(message)s', measurement_type=<MeasurementTypeEnum.X_ELECTRIC_FIELD: 'x-electric-field'>, root_directory=PosixPath('test_root'), mega_merged_name='test_mega_merged.csv', mega_merged_inferenced_name='test_mega_merged_inferenced.csv', skip_to_stage=1, check_completions=False), deepdog_config=DeepdogConfig(costs_to_try=[10, 1, 0.1], target_success=1000, max_monte_carlo_cycles_steps=20, use_log_noise=False, manual_dot_seeds=None), default_model_param_config=DefaultModelParamConfig(x_min=-20, x_max=20, y_min=-10, y_max=10, z_min=5, z_max=6.5, w_log_min=-5, w_log_max=1))
|
||||
# ---
|
||||
|
28
tests/config/test_files/test_config_geom_params.toml
Normal file
28
tests/config/test_files/test_config_geom_params.toml
Normal file
@ -0,0 +1,28 @@
|
||||
[general_config]
|
||||
root_directory = "test_root1"
|
||||
measurement_type = "electric-potential"
|
||||
|
||||
[generation_config]
|
||||
counts = [1, 5, 10]
|
||||
num_replicas = 2
|
||||
tantri_configs = [
|
||||
{index_seed_starter = 15151, num_seeds = 5, delta_t = 0.01, num_iterations = 100},
|
||||
{index_seed_starter = 1234, num_seeds = 100, delta_t = 1, num_iterations = 200}
|
||||
]
|
||||
|
||||
[generation_config.override_dipole_configs]
|
||||
scenario1 = [
|
||||
{p = [3, 5, 7], s = [2, 4, 6], w = 10},
|
||||
{p = [30, 50, 70], s = [20, 40, 60], w = 10.55},
|
||||
]
|
||||
|
||||
[deepdog_config]
|
||||
costs_to_try = [5, 2, 1, 0.5, 0.2]
|
||||
target_success = 2000
|
||||
use_log_noise = true
|
||||
|
||||
[default_model_param_config]
|
||||
z_min = 0
|
||||
z_max = 2
|
||||
w_log_min = -3
|
||||
w_log_max = 1.5
|
22
tests/config/test_model_params.py
Normal file
22
tests/config/test_model_params.py
Normal file
@ -0,0 +1,22 @@
|
||||
import kalpaa.config
|
||||
|
||||
|
||||
def test_model_param_default_works():
|
||||
TEST_ZMIN = -99
|
||||
TEST_ZMAX = 99
|
||||
TEST_REAL_ZMAX = 55
|
||||
TEST_COUNT = 58
|
||||
model_param_config = kalpaa.config.DefaultModelParamConfig(
|
||||
z_min=TEST_ZMIN,
|
||||
z_max=TEST_ZMAX,
|
||||
)
|
||||
|
||||
actual_params = model_param_config.reduced_model_params(
|
||||
count=TEST_COUNT, z_max=TEST_REAL_ZMAX
|
||||
)
|
||||
|
||||
assert actual_params.z_min == TEST_ZMIN
|
||||
assert actual_params.z_max == TEST_REAL_ZMAX # want later value to override
|
||||
assert actual_params.x_min == -20
|
||||
assert actual_params.x_max == 20
|
||||
assert actual_params.count == TEST_COUNT
|
35
tests/config/test_toml_as_dict_snaps.py
Normal file
35
tests/config/test_toml_as_dict_snaps.py
Normal file
@ -0,0 +1,35 @@
|
||||
import pathlib
|
||||
from dataclasses import asdict
|
||||
|
||||
import kalpaa.config.config_reader
|
||||
|
||||
|
||||
TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files"
|
||||
|
||||
|
||||
def test_parse_config_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
||||
|
||||
|
||||
def test_parse_config_all_fields_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config_all_fields.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
||||
|
||||
|
||||
def test_parse_config_few_fields_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config_few_fields.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
||||
|
||||
|
||||
def test_parse_config_geom_params_toml_as_dict(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config_geom_params.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert asdict(actual_config) == snapshot
|
@ -80,3 +80,10 @@ def test_parse_config_few_fields_toml(snapshot):
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert actual_config == snapshot
|
||||
|
||||
|
||||
def test_parse_config_geom_params_toml(snapshot):
|
||||
test_config_file = TEST_DATA_DIR / "test_config_geom_params.toml"
|
||||
actual_config = kalpaa.config.config_reader.read_config(test_config_file)
|
||||
|
||||
assert actual_config == snapshot
|
||||
|
460
tests/read_bin_csv/__snapshots__/test_read_bin_csv.ambr
Normal file
460
tests/read_bin_csv/__snapshots__/test_read_bin_csv.ambr
Normal file
@ -0,0 +1,460 @@
|
||||
# serializer version: 1
|
||||
# name: test_binned_data_dot_measurement
|
||||
dict({
|
||||
'csv_dict': dict({
|
||||
'dot1': dict({
|
||||
'mean': list([
|
||||
10.638916947949246,
|
||||
4.808960230987057,
|
||||
1.8458074293863327,
|
||||
1.0990901962765007,
|
||||
0.6425140116757488,
|
||||
0.4844873135633905,
|
||||
0.44823255155896,
|
||||
]),
|
||||
'stdev': list([
|
||||
5.688165841523548,
|
||||
1.5555855859097745,
|
||||
0.5112103163244077,
|
||||
0.376055350460546,
|
||||
0.1411676088216461,
|
||||
0.11795510686231957,
|
||||
0.081977940913005,
|
||||
]),
|
||||
}),
|
||||
'dot2': dict({
|
||||
'mean': list([
|
||||
14.780311491085596,
|
||||
7.413101036489984,
|
||||
3.081527317039941,
|
||||
1.198719434472466,
|
||||
0.44608783800009594,
|
||||
0.16750150967807267,
|
||||
0.095604286225167,
|
||||
]),
|
||||
'stdev': list([
|
||||
5.085761250807487,
|
||||
2.7753690312876014,
|
||||
1.3009911753215875,
|
||||
0.3361763625979774,
|
||||
0.18042157503806078,
|
||||
0.05820931036468994,
|
||||
0.022567042968929727,
|
||||
]),
|
||||
}),
|
||||
'frequencies': list([
|
||||
0.0125,
|
||||
0.024999999999999998,
|
||||
0.045,
|
||||
0.0775,
|
||||
0.1375,
|
||||
0.24749999999999997,
|
||||
0.41,
|
||||
]),
|
||||
'line': dict({
|
||||
'mean': list([
|
||||
38.270876407794276,
|
||||
18.464083270261195,
|
||||
7.958206578940247,
|
||||
3.1067340783807977,
|
||||
1.1918940813568575,
|
||||
0.4599545316713671,
|
||||
0.2695902593679458,
|
||||
]),
|
||||
'stdev': list([
|
||||
11.39990688883059,
|
||||
6.7851330102447776,
|
||||
3.2965800683487467,
|
||||
0.9134607210867323,
|
||||
0.478463968609111,
|
||||
0.15476866608286458,
|
||||
0.06281745354709284,
|
||||
]),
|
||||
}),
|
||||
'triangle1': dict({
|
||||
'mean': list([
|
||||
11.899895021145635,
|
||||
5.913586933800311,
|
||||
2.395916783991011,
|
||||
0.9150694736063263,
|
||||
0.3292608404218125,
|
||||
0.11634897814927361,
|
||||
0.06146961050341923,
|
||||
]),
|
||||
'stdev': list([
|
||||
4.53573184571792,
|
||||
2.2751779798574585,
|
||||
1.0246980114887267,
|
||||
0.24363144876739817,
|
||||
0.1333924934596153,
|
||||
0.04295850059608126,
|
||||
0.015083740543610176,
|
||||
]),
|
||||
}),
|
||||
'triangle2': dict({
|
||||
'mean': list([
|
||||
33.66541768670497,
|
||||
16.223538221497495,
|
||||
6.890979907540165,
|
||||
2.633138291665575,
|
||||
0.9770973161110396,
|
||||
0.3543639501963306,
|
||||
0.1926321299071961,
|
||||
]),
|
||||
'stdev': list([
|
||||
11.231365853486137,
|
||||
6.031211083456,
|
||||
2.8522163160388114,
|
||||
0.7603021217632877,
|
||||
0.39370370747073685,
|
||||
0.12577555122177078,
|
||||
0.04748629565145606,
|
||||
]),
|
||||
}),
|
||||
'uprise1': dict({
|
||||
'mean': list([
|
||||
13.361715690908607,
|
||||
6.661543725531067,
|
||||
2.703848365637358,
|
||||
1.0340536040727817,
|
||||
0.37076116047359375,
|
||||
0.13049802854728143,
|
||||
0.06873250192025933,
|
||||
]),
|
||||
'stdev': list([
|
||||
5.0675284566338625,
|
||||
2.5702840629870978,
|
||||
1.157751077467679,
|
||||
0.27799313793346814,
|
||||
0.14947611970183075,
|
||||
0.04802345307476847,
|
||||
0.016919878842437803,
|
||||
]),
|
||||
}),
|
||||
'uprise2': dict({
|
||||
'mean': list([
|
||||
45.90844164247454,
|
||||
21.788765787102644,
|
||||
9.361785651884627,
|
||||
3.5839262514780574,
|
||||
1.3225984280531975,
|
||||
0.47113614904712353,
|
||||
0.2523607031606806,
|
||||
]),
|
||||
'stdev': list([
|
||||
14.970571928118593,
|
||||
7.992391910434401,
|
||||
3.830813979817984,
|
||||
1.0495809786215216,
|
||||
0.5297312429960187,
|
||||
0.16713468399902373,
|
||||
0.0630760820406672,
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
'dots_dict': dict({
|
||||
'dot1': list([
|
||||
5,
|
||||
0,
|
||||
0,
|
||||
]),
|
||||
'dot2': list([
|
||||
-5,
|
||||
0,
|
||||
0,
|
||||
]),
|
||||
'line': list([
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
]),
|
||||
'triangle1': list([
|
||||
-5,
|
||||
5,
|
||||
0,
|
||||
]),
|
||||
'triangle2': list([
|
||||
0,
|
||||
3.5,
|
||||
0,
|
||||
]),
|
||||
'uprise1': list([
|
||||
-5,
|
||||
5,
|
||||
0.5,
|
||||
]),
|
||||
'uprise2': list([
|
||||
0,
|
||||
3.5,
|
||||
0.5,
|
||||
]),
|
||||
}),
|
||||
'freq_list': list([
|
||||
0.0125,
|
||||
0.024999999999999998,
|
||||
0.045,
|
||||
0.0775,
|
||||
0.1375,
|
||||
0.24749999999999997,
|
||||
0.41,
|
||||
]),
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'pair_dict': dict({
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_binned_data_dot_measurement_costs
|
||||
dict({
|
||||
'linear': dict({
|
||||
'ex': dict({
|
||||
'one': array([ 5.11479963, 152.2645574 ]),
|
||||
'two': array([155.75850994]),
|
||||
}),
|
||||
'v': dict({
|
||||
'one': array([3.58557609, 3.61940193]),
|
||||
'two': array([3.38552928]),
|
||||
}),
|
||||
}),
|
||||
'log': dict({
|
||||
'ex': dict({
|
||||
'one': array([ 8.69396606, 24.39534105]),
|
||||
'two': array([24.51951851]),
|
||||
}),
|
||||
'v': dict({
|
||||
'one': array([12.90271504, 14.61900084]),
|
||||
'two': array([9.85889607]),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_binned_data_dot_pair_measurement_costs
|
||||
dict({
|
||||
'linear': dict({
|
||||
'v': dict({
|
||||
'one': array([ 1.12885689, 16.07423316]),
|
||||
'two': array([16.07423316]),
|
||||
}),
|
||||
}),
|
||||
'log': dict({
|
||||
'v': dict({
|
||||
'one': array([ 1.12885689, 16.07423316]),
|
||||
'two': array([16.07423316]),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_parse_headers
|
||||
list([
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot1_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot1_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'dot2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_dot2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'line',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_line_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'line',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_line_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle1_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle1_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'triangle2',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_triangle2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': None,
|
||||
'dot_name': 'uprise1',
|
||||
'dot_name2': None,
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'APSD_V_uprise1_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
None,
|
||||
dict({
|
||||
'cpsd_type': 'correlation',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_correlation_V_dot1_dot2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': 'correlation',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_correlation_V_dot1_dot2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': 'phase',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_phase_V_dot1_dot2_mean',
|
||||
'summary_stat': 'mean',
|
||||
}),
|
||||
dict({
|
||||
'cpsd_type': 'phase',
|
||||
'dot_name': 'dot1',
|
||||
'dot_name2': 'dot2',
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'original_field': 'CPSD_phase_V_dot1_dot2_stdev',
|
||||
'summary_stat': 'stdev',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_read_csv_with_pairs
|
||||
dict({
|
||||
'freqs': list([
|
||||
0.0125,
|
||||
0.024999999999999998,
|
||||
0.045,
|
||||
0.0775,
|
||||
0.1375,
|
||||
0.24749999999999997,
|
||||
0.41,
|
||||
]),
|
||||
'measurement_type': <MeasurementTypeEnum.POTENTIAL: 'electric-potential'>,
|
||||
'pair_dot_dict': dict({
|
||||
tuple(
|
||||
'dot1',
|
||||
'dot2',
|
||||
): dict({
|
||||
'mean': list([
|
||||
3.15,
|
||||
3.13,
|
||||
3.0,
|
||||
2.7,
|
||||
0.1,
|
||||
0.25,
|
||||
0.002,
|
||||
]),
|
||||
'stdev': list([
|
||||
0.1,
|
||||
0.11,
|
||||
0.8,
|
||||
1.5,
|
||||
2.0,
|
||||
2.0,
|
||||
1.5,
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
'single_dot_dict': dict({
|
||||
'dot1': dict({
|
||||
'mean': list([
|
||||
10.638916947949246,
|
||||
4.808960230987057,
|
||||
1.8458074293863327,
|
||||
1.0990901962765007,
|
||||
0.6425140116757488,
|
||||
0.4844873135633905,
|
||||
0.448232552,
|
||||
]),
|
||||
'stdev': list([
|
||||
5.688165841523548,
|
||||
1.5555855859097745,
|
||||
0.5112103163244077,
|
||||
0.37605535,
|
||||
0.1411676088216461,
|
||||
0.11795510686231957,
|
||||
0.081977941,
|
||||
]),
|
||||
}),
|
||||
'dot2': dict({
|
||||
'mean': list([
|
||||
14.780311491085596,
|
||||
7.413101036489984,
|
||||
3.081527317039941,
|
||||
1.198719434472466,
|
||||
0.44608783800009594,
|
||||
0.16750150967807267,
|
||||
0.095604286,
|
||||
]),
|
||||
'stdev': list([
|
||||
5.085761250807487,
|
||||
2.7753690312876014,
|
||||
1.3009911753215875,
|
||||
0.3361763625979774,
|
||||
0.18042157503806078,
|
||||
0.05820931,
|
||||
0.022567042968929727,
|
||||
]),
|
||||
}),
|
||||
'frequencies': list([
|
||||
0.0125,
|
||||
0.024999999999999998,
|
||||
0.045,
|
||||
0.0775,
|
||||
0.1375,
|
||||
0.24749999999999997,
|
||||
0.41,
|
||||
]),
|
||||
}),
|
||||
})
|
||||
# ---
|
30
tests/read_bin_csv/test_files/dots.json
Normal file
30
tests/read_bin_csv/test_files/dots.json
Normal file
@ -0,0 +1,30 @@
|
||||
[
|
||||
{
|
||||
"r": [5, 0, 0],
|
||||
"label": "dot1"
|
||||
},
|
||||
{
|
||||
"r": [-5, 0, 0],
|
||||
"label": "dot2"
|
||||
},
|
||||
{
|
||||
"r": [0, 0, 0],
|
||||
"label": "line"
|
||||
},
|
||||
{
|
||||
"r": [-5, 5, 0],
|
||||
"label": "triangle1"
|
||||
},
|
||||
{
|
||||
"r": [0, 3.5, 0],
|
||||
"label": "triangle2"
|
||||
},
|
||||
{
|
||||
"r": [-5, 5, 0.5],
|
||||
"label": "uprise1"
|
||||
},
|
||||
{
|
||||
"r": [0, 3.5, 0.5],
|
||||
"label": "uprise2"
|
||||
}
|
||||
]
|
8
tests/read_bin_csv/test_files/test_binned_apsd_Ex.csv
Normal file
8
tests/read_bin_csv/test_files/test_binned_apsd_Ex.csv
Normal file
@ -0,0 +1,8 @@
|
||||
mean bin f (Hz), APSD_Ex_dot1_mean, APSD_Ex_dot1_stdev, APSD_Ex_dot2_mean, APSD_Ex_dot2_stdev, APSD_Ex_line_mean, APSD_Ex_line_stdev, APSD_Ex_triangle1_mean, APSD_Ex_triangle1_stdev, APSD_Ex_triangle2_mean, APSD_Ex_triangle2_stdev, APSD_Ex_uprise1_mean, APSD_Ex_uprise1_stdev, APSD_Ex_uprise2_mean, APSD_Ex_uprise2_stdev
|
||||
0.0125, 0.0011193742876607947, 7.308063002906329e-05, 0.00025308842133615314, 7.358112457611491e-05, 0.0002730912968430376, 1.5169478503125244e-05, 5.9269065861281914e-05, 1.5702545414499775e-05, 0.0001141958169440423, 7.65214751430245e-06, 5.2821694680807203e-05, 1.4064198180241187e-05, 0.00010211448239039132, 6.71970072461593e-06
|
||||
0.024999999999999998, 0.001412604489329749, 0.00010570879878990876, 0.00013944508257776318, 9.200512894443294e-05, 0.0003291803444964181, 4.8109121440041884e-05, 4.264173761272405e-05, 2.1587183300665237e-05, 0.00013608574913899783, 2.1517500402085424e-05, 3.767673858289332e-05, 1.9266237269735494e-05, 0.00012184117806180041, 1.9109518704617513e-05
|
||||
0.045, 0.0012373217155713524, 0.0002899722649129945, 6.922355178721988e-05, 1.7755523180885715e-05, 0.000272700633559799, 6.414101070632063e-05, 2.6370731010957502e-05, 6.399841060844057e-06, 0.0001119715478076307, 2.6347911201381e-05, 2.3161333453135425e-05, 5.626881722609435e-06, 0.00010032448012855114, 2.3606143309976283e-05
|
||||
0.0775, 0.001112910278683047, 0.00023923894218936458, 5.8856159419426865e-05, 1.0212484008992887e-05, 0.00024448074444704794, 5.056251714395904e-05, 2.304063457218835e-05, 4.225116662857298e-06, 0.0001003381671115233, 2.0674173561984387e-05, 2.0223453630183885e-05, 3.7012422664279698e-06, 8.990567067147231e-05, 1.853194873840372e-05
|
||||
0.1375, 0.0010269808087721118, 0.00021311411519638195, 5.378486118291285e-05, 1.1600120081594358e-05, 0.0002266575298509303, 4.7459339601317906e-05, 2.1318465024274433e-05, 4.551555868181482e-06, 9.304972164771508e-05, 1.9498238568083357e-05, 1.870759251880998e-05, 3.995274554481259e-06, 8.337261927681384e-05, 1.746904722945101e-05
|
||||
0.24749999999999997, 0.0009456006414492572, 0.00018207143834945663, 4.841142457620057e-05, 9.624450747042635e-06, 0.00020826892321528504, 4.027313841923547e-05, 1.9384878389355338e-05, 3.8063957960494524e-06, 8.548080636563674e-05, 1.6536471875696628e-05, 1.7006684558357824e-05, 3.3404547377359696e-06, 7.659277051995944e-05, 1.4816395341788071e-05
|
||||
0.41, 0.0008277349012033521, 0.00016315667168352693, 4.2106112574100974e-05, 8.190855171238974e-06, 0.00018218546262776102, 3.583769342624307e-05, 1.690642674427495e-05, 3.3043159836773044e-06, 7.476985259654567e-05, 1.4705141013296952e-05, 1.4831311923299011e-05, 2.898385689290891e-06, 6.699602547507511e-05, 1.3176515249800673e-05
|
|
8
tests/read_bin_csv/test_files/test_binned_apsd_V.csv
Normal file
8
tests/read_bin_csv/test_files/test_binned_apsd_V.csv
Normal file
@ -0,0 +1,8 @@
|
||||
mean bin f (Hz), APSD_V_dot1_mean, APSD_V_dot1_stdev, APSD_V_dot2_mean, APSD_V_dot2_stdev, APSD_V_line_mean, APSD_V_line_stdev, APSD_V_triangle1_mean, APSD_V_triangle1_stdev, APSD_V_triangle2_mean, APSD_V_triangle2_stdev, APSD_V_uprise1_mean, APSD_V_uprise1_stdev, APSD_V_uprise2_mean, APSD_V_uprise2_stdev
|
||||
0.0125, 10.638916947949246, 5.688165841523548, 14.780311491085596, 5.085761250807487, 38.270876407794276, 11.39990688883059, 11.899895021145635, 4.53573184571792, 33.66541768670497, 11.231365853486137, 13.361715690908607, 5.0675284566338625, 45.90844164247454, 14.970571928118593
|
||||
0.024999999999999998, 4.808960230987057, 1.5555855859097745, 7.413101036489984, 2.7753690312876014, 18.464083270261195, 6.7851330102447776, 5.913586933800311, 2.2751779798574585, 16.223538221497495, 6.031211083456, 6.661543725531067, 2.5702840629870978, 21.788765787102644, 7.992391910434401
|
||||
0.045, 1.8458074293863327, 0.5112103163244077, 3.081527317039941, 1.3009911753215875, 7.958206578940247, 3.2965800683487467, 2.395916783991011, 1.0246980114887267, 6.890979907540165, 2.8522163160388114, 2.703848365637358, 1.157751077467679, 9.361785651884627, 3.830813979817984
|
||||
0.0775, 1.0990901962765007, 0.376055350460546, 1.198719434472466, 0.3361763625979774, 3.1067340783807977, 0.9134607210867323, 0.9150694736063263, 0.24363144876739817, 2.633138291665575, 0.7603021217632877, 1.0340536040727817, 0.27799313793346814, 3.5839262514780574, 1.0495809786215216
|
||||
0.1375, 0.6425140116757488, 0.1411676088216461, 0.44608783800009594, 0.18042157503806078, 1.1918940813568575, 0.478463968609111, 0.3292608404218125, 0.1333924934596153, 0.9770973161110396, 0.39370370747073685, 0.37076116047359375, 0.14947611970183075, 1.3225984280531975, 0.5297312429960187
|
||||
0.24749999999999997, 0.4844873135633905, 0.11795510686231957, 0.16750150967807267, 0.05820931036468994, 0.4599545316713671, 0.15476866608286458, 0.11634897814927361, 0.04295850059608126, 0.3543639501963306, 0.12577555122177078, 0.13049802854728143, 0.04802345307476847, 0.47113614904712353, 0.16713468399902373
|
||||
0.41, 0.44823255155896, 0.081977940913005, 0.095604286225167, 0.022567042968929727, 0.2695902593679458, 0.06281745354709284, 0.06146961050341923, 0.015083740543610176, 0.1926321299071961, 0.04748629565145606, 0.06873250192025933, 0.016919878842437803, 0.2523607031606806, 0.0630760820406672
|
|
8
tests/read_bin_csv/test_files/test_simple_pair_V.csv
Normal file
8
tests/read_bin_csv/test_files/test_simple_pair_V.csv
Normal file
@ -0,0 +1,8 @@
|
||||
mean bin f (Hz), APSD_V_dot1_mean, APSD_V_dot1_stdev, APSD_V_dot2_mean, APSD_V_dot2_stdev,CPSD_phase_V_dot1_dot2_mean,CPSD_phase_V_dot1_dot2_stdev
|
||||
0.0125, 10.638916947949246, 5.688165841523548, 14.780311491085596, 5.085761250807487,3.15,0.1
|
||||
0.024999999999999998, 4.808960230987057, 1.5555855859097745, 7.413101036489984, 2.7753690312876014,3.13,0.11
|
||||
0.045, 1.8458074293863327, 0.5112103163244077, 3.081527317039941, 1.3009911753215875,3,0.8
|
||||
0.0775, 1.0990901962765007,0.37605535, 1.198719434472466, 0.3361763625979774,2.7,1.5
|
||||
0.1375, 0.6425140116757488, 0.1411676088216461, 0.44608783800009594, 0.18042157503806078,0.1,2
|
||||
0.24749999999999997, 0.4844873135633905, 0.11795510686231957, 0.16750150967807267,0.05820931,0.25,2
|
||||
0.41,0.448232552,0.081977941,0.095604286, 0.022567042968929727,0.002,1.5
|
|
191
tests/read_bin_csv/test_read_bin_csv.py
Normal file
191
tests/read_bin_csv/test_read_bin_csv.py
Normal file
@ -0,0 +1,191 @@
|
||||
import re
|
||||
import kalpaa.read_bin_csv
|
||||
import pathlib
|
||||
import dataclasses
|
||||
import logging
|
||||
import typing
|
||||
import numpy
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TEST_DATA_DIR = pathlib.Path(__file__).resolve().parent / "test_files"
|
||||
|
||||
|
||||
def test_regex_matches():
|
||||
apsd_v_1 = "APSD_V_dot1_mean"
|
||||
|
||||
actual_match1 = re.match(kalpaa.read_bin_csv.BINNED_HEADER_REGEX, apsd_v_1)
|
||||
|
||||
# For reference, REGEX is currently: APSD_(?P<measurement_type>\w+)_(?P<dot_name>\w+)_(?P<summary_stat>mean|stdev)\s*
|
||||
assert actual_match1 is not None
|
||||
groups = actual_match1.groupdict()
|
||||
assert groups["measurement_type"] == "V"
|
||||
assert groups["dot_name"] == "dot1"
|
||||
assert groups["summary_stat"] == "mean"
|
||||
|
||||
|
||||
def test_parse_headers(snapshot):
|
||||
example_headers = [
|
||||
# using these headers from recent run: APSD_V_dot1_mean, APSD_V_dot1_stdev, APSD_V_dot2_mean, APSD_V_dot2_stdev, APSD_V_line_mean, APSD_V_line_stdev, APSD_V_triangle1_mean, APSD_V_triangle1_stdev, APSD_V_triangle2_mean, APSD_V_triangle2_stdev, APSD_V_uprise1_mean, APSD_V_uprise1_stdev, APSD_V_uprise2_mean, APSD_V_uprise2_stdev
|
||||
"APSD_V_dot1_mean",
|
||||
"APSD_V_dot1_stdev",
|
||||
"APSD_V_dot2_mean",
|
||||
"APSD_V_dot2_stdev",
|
||||
"APSD_V_line_mean",
|
||||
"APSD_V_line_stdev",
|
||||
"APSD_V_triangle1_mean",
|
||||
"APSD_V_triangle1_stdev",
|
||||
"APSD_V_triangle2_mean",
|
||||
"APSD_V_triangle2_stdev",
|
||||
"APSD_V_uprise1_mean",
|
||||
"This is not a valid header",
|
||||
"CPSD_correlation_V_dot1_dot2_mean",
|
||||
"CPSD_correlation_V_dot1_dot2_stdev",
|
||||
"CPSD_phase_V_dot1_dot2_mean",
|
||||
"CPSD_phase_V_dot1_dot2_stdev",
|
||||
]
|
||||
|
||||
# force logger to be used for now
|
||||
_logger.debug("parsing headers for test")
|
||||
|
||||
def null_asdict(dataclz) -> typing.Optional[dict]:
|
||||
if dataclz is None:
|
||||
return None
|
||||
return dataclasses.asdict(dataclz)
|
||||
|
||||
actual_parsed = [
|
||||
null_asdict(kalpaa.read_bin_csv._parse_bin_header(h)) for h in example_headers
|
||||
]
|
||||
assert actual_parsed == snapshot
|
||||
|
||||
|
||||
def test_binned_data_dot_measurement(snapshot):
|
||||
|
||||
dots_json = TEST_DATA_DIR / "dots.json"
|
||||
csv_file = TEST_DATA_DIR / "test_binned_apsd_V.csv"
|
||||
|
||||
actual_read = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, csv_file)
|
||||
|
||||
assert dataclasses.asdict(actual_read) == snapshot
|
||||
|
||||
|
||||
def test_binned_data_dot_measurement_costs(snapshot):
|
||||
|
||||
dots_json = TEST_DATA_DIR / "dots.json"
|
||||
v_csv_file = TEST_DATA_DIR / "test_binned_apsd_V.csv"
|
||||
ex_csv_file = TEST_DATA_DIR / "test_binned_apsd_Ex.csv"
|
||||
|
||||
# it's overkill but while we have the mental model of what the input form is of these numpy arrays we should record it!
|
||||
test_dipole_1 = [10, 20, 30, 0.4, 0.5, 0.6, 0.7]
|
||||
test_dipole_2 = [15, 25, 35, -4, 0, 0, 0.11]
|
||||
test_dipoles_configuration_1 = [test_dipole_1]
|
||||
test_one_dipole_config2 = [test_dipole_2]
|
||||
all_one_dipole_configs = numpy.array(
|
||||
[test_dipoles_configuration_1, test_one_dipole_config2]
|
||||
)
|
||||
|
||||
test_two_dipole_config1 = [test_dipole_1, test_dipole_2]
|
||||
all_two_dipole_configs = numpy.array([test_two_dipole_config1])
|
||||
|
||||
binned_v = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, v_csv_file)
|
||||
measurements_v = binned_v.measurements(["dot1"])
|
||||
|
||||
binned_ex = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, ex_csv_file)
|
||||
measurements_ex = binned_ex.measurements(["dot1"])
|
||||
_logger.warning(measurements_ex)
|
||||
|
||||
v_log_noise_stdev_cost_func = measurements_v.stdev_cost_function(use_log_noise=True)
|
||||
ex_log_noise_stdev_cost_func = measurements_ex.stdev_cost_function(
|
||||
use_log_noise=True
|
||||
)
|
||||
|
||||
v_linear_noise_stdev_cost_func = measurements_v.stdev_cost_function(
|
||||
use_log_noise=False
|
||||
)
|
||||
ex_linear_noise_stdev_cost_func = measurements_ex.stdev_cost_function(
|
||||
use_log_noise=False
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
"log": {
|
||||
"v": {
|
||||
"one": v_log_noise_stdev_cost_func(all_one_dipole_configs),
|
||||
"two": v_log_noise_stdev_cost_func(all_two_dipole_configs),
|
||||
},
|
||||
"ex": {
|
||||
"one": ex_log_noise_stdev_cost_func(all_one_dipole_configs),
|
||||
"two": ex_log_noise_stdev_cost_func(all_two_dipole_configs),
|
||||
},
|
||||
},
|
||||
"linear": {
|
||||
"v": {
|
||||
"one": v_linear_noise_stdev_cost_func(all_one_dipole_configs),
|
||||
"two": v_linear_noise_stdev_cost_func(all_two_dipole_configs),
|
||||
},
|
||||
"ex": {
|
||||
"one": ex_linear_noise_stdev_cost_func(all_one_dipole_configs),
|
||||
"two": ex_linear_noise_stdev_cost_func(all_two_dipole_configs),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert result_dict == snapshot
|
||||
|
||||
|
||||
def test_read_csv_with_pairs(snapshot):
|
||||
|
||||
# dots_json = TEST_DATA_DIR / "dots.json"
|
||||
# v_csv_file = TEST_DATA_DIR / "test_binned_apsd_V.csv"
|
||||
# ex_csv_file = TEST_DATA_DIR / "test_binned_apsd_Ex.csv"
|
||||
pair_data_csv = TEST_DATA_DIR / "test_simple_pair_V.csv"
|
||||
|
||||
actual_read = kalpaa.read_bin_csv.read_bin_csv(pair_data_csv)
|
||||
|
||||
assert dataclasses.asdict(actual_read) == snapshot
|
||||
|
||||
|
||||
def test_binned_data_dot_pair_measurement_costs(snapshot):
|
||||
|
||||
dots_json = TEST_DATA_DIR / "dots.json"
|
||||
v_csv_file = TEST_DATA_DIR / "test_simple_pair_V.csv"
|
||||
|
||||
# it's overkill but while we have the mental model of what the input form is of these numpy arrays we should record it!
|
||||
test_dipole_1 = [10, 20, 30, 0.4, 0.5, 0.6, 0.7]
|
||||
test_dipole_2 = [15, 25, 35, -6, 0, 0, 0.11]
|
||||
test_dipoles_configuration_1 = [test_dipole_1]
|
||||
test_one_dipole_config2 = [test_dipole_2]
|
||||
all_one_dipole_configs = numpy.array(
|
||||
[test_dipoles_configuration_1, test_one_dipole_config2]
|
||||
)
|
||||
|
||||
test_two_dipole_config1 = [test_dipole_1, test_dipole_2]
|
||||
all_two_dipole_configs = numpy.array([test_two_dipole_config1])
|
||||
|
||||
binned_v = kalpaa.read_bin_csv.read_dots_and_binned(dots_json, v_csv_file)
|
||||
measurements_v = binned_v._get_measurement_from_dot_name_or_pair([("dot1", "dot2")])
|
||||
|
||||
_logger.info(f"measurements_v: {measurements_v}")
|
||||
|
||||
v_log_noise_stdev_cost_func = measurements_v.stdev_cost_function(use_log_noise=True)
|
||||
|
||||
v_linear_noise_stdev_cost_func = measurements_v.stdev_cost_function(
|
||||
use_log_noise=False
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
"log": {
|
||||
"v": {
|
||||
"one": v_log_noise_stdev_cost_func(all_one_dipole_configs),
|
||||
"two": v_log_noise_stdev_cost_func(all_two_dipole_configs),
|
||||
},
|
||||
},
|
||||
"linear": {
|
||||
"v": {
|
||||
"one": v_linear_noise_stdev_cost_func(all_one_dipole_configs),
|
||||
"two": v_linear_noise_stdev_cost_func(all_two_dipole_configs),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
assert result_dict == snapshot
|
Loading…
x
Reference in New Issue
Block a user