Compare commits
14 Commits
129de4e025
...
5a5f5ede3a
Author | SHA1 | Date | |
---|---|---|---|
5a5f5ede3a | |||
f7559b2c4f | |||
9a7a3ff2c7 | |||
c4805806be | |||
161bcf42ad | |||
8e6ead416c | |||
e6defc7948 | |||
33d5da6a4f | |||
1110372a55 | |||
e6a00d6b8f | |||
57cd746e5c | |||
878e16286b | |||
4726ccfb8c | |||
598dad1e6d |
2
.flake8
2
.flake8
@ -1,3 +1,3 @@
|
||||
[flake8]
|
||||
ignore = W191, E501, W503
|
||||
ignore = W191, E501, W503, E203
|
||||
max-line-length = 120
|
||||
|
21
CHANGELOG.md
21
CHANGELOG.md
@ -2,6 +2,27 @@
|
||||
|
||||
All notable changes to this project will be documented in this file. See [standard-version](https://github.com/conventional-changelog/standard-version) for commit guidelines.
|
||||
|
||||
### [0.7.4](https://gitea.deepak.science:2222/physics/deepdog/compare/0.7.3...0.7.4) (2023-07-27)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* adds configurable chunk size for the initial mc level 0 SS stage cost calculation to reduce memory usage ([9a7a3ff](https://gitea.deepak.science:2222/physics/deepdog/commit/9a7a3ff2c7ebe81d5e10647ce39844c372ff7b07))
|
||||
* allows for deepdog bayesrun with ss to not print csv to make snapshot testing possible ([8e6ead4](https://gitea.deepak.science:2222/physics/deepdog/commit/8e6ead416c9eba56f568f648d0df44caaa510cfe))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* fixes bug if case of clamping necessary ([161bcf4](https://gitea.deepak.science:2222/physics/deepdog/commit/161bcf42addf331661c3929073688b9f2c13502c))
|
||||
* fixes bug with clamped probabilities being underestimated ([e6defc7](https://gitea.deepak.science:2222/physics/deepdog/commit/e6defc794871a48ac331023eb477bd235b78d6d0))
|
||||
|
||||
### [0.7.3](https://gitea.deepak.science:2222/physics/deepdog/compare/0.7.2...0.7.3) (2023-07-27)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* adds utility options and avoids memory leak ([598dad1](https://gitea.deepak.science:2222/physics/deepdog/commit/598dad1e6dc8fc0b7a5b4a90c8e17bf744e8d98c))
|
||||
|
||||
### [0.7.2](https://gitea.deepak.science:2222/physics/deepdog/compare/0.7.1...0.7.2) (2023-07-24)
|
||||
|
||||
|
||||
|
@ -70,6 +70,9 @@ class BayesRunWithSubspaceSimulation:
|
||||
ss_default_r_step=0.01,
|
||||
ss_default_w_log_step=0.01,
|
||||
ss_default_upper_w_log_step=4,
|
||||
ss_dump_last_generation=False,
|
||||
ss_initial_costs_chunk_size=100,
|
||||
write_output_to_bayesruncsv=True,
|
||||
) -> None:
|
||||
self.dot_inputs = pdme.inputs.inputs_with_frequency_range(
|
||||
dot_positions, frequency_range
|
||||
@ -133,13 +136,22 @@ class BayesRunWithSubspaceSimulation:
|
||||
self.ss_default_r_step = ss_default_r_step
|
||||
self.ss_default_w_log_step = ss_default_w_log_step
|
||||
self.ss_default_upper_w_log_step = ss_default_upper_w_log_step
|
||||
|
||||
self.ss_dump_last_generation = ss_dump_last_generation
|
||||
self.ss_initial_costs_chunk_size = ss_initial_costs_chunk_size
|
||||
self.run_count = run_count
|
||||
|
||||
def go(self) -> None:
|
||||
with open(self.filename, "a", newline="") as outfile:
|
||||
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||
writer.writeheader()
|
||||
self.write_output_to_csv = write_output_to_bayesruncsv
|
||||
|
||||
def go(self) -> Sequence:
|
||||
|
||||
if self.write_output_to_csv:
|
||||
with open(self.filename, "a", newline="") as outfile:
|
||||
writer = csv.DictWriter(
|
||||
outfile, fieldnames=self.csv_fields, dialect="unix"
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
return_result = []
|
||||
|
||||
for run in range(1, self.run_count + 1):
|
||||
|
||||
@ -172,6 +184,9 @@ class BayesRunWithSubspaceSimulation:
|
||||
self.ss_default_r_step,
|
||||
self.ss_default_w_log_step,
|
||||
self.ss_default_upper_w_log_step,
|
||||
initial_cost_chunk_size=self.ss_initial_costs_chunk_size,
|
||||
keep_probs_list=False,
|
||||
dump_last_generation_to_file=self.ss_dump_last_generation,
|
||||
)
|
||||
results.append(subset_run.execute())
|
||||
|
||||
@ -195,8 +210,14 @@ class BayesRunWithSubspaceSimulation:
|
||||
|
||||
for (name, result) in zip(self.model_names, results):
|
||||
if result.over_target_likelihood is None:
|
||||
clamped_likelihood = result.probs_list[-1][0] / CLAMPING_FACTOR
|
||||
_logger.warning(f"got a none result, clamping to {clamped_likelihood}")
|
||||
if result.lowest_likelihood is None:
|
||||
_logger.error(f"result {result} looks bad")
|
||||
clamped_likelihood = 10**-15
|
||||
else:
|
||||
clamped_likelihood = result.lowest_likelihood / CLAMPING_FACTOR
|
||||
_logger.warning(
|
||||
f"got a none result, clamping to {clamped_likelihood}"
|
||||
)
|
||||
else:
|
||||
clamped_likelihood = result.over_target_likelihood
|
||||
likelihoods.append(clamped_likelihood)
|
||||
@ -216,12 +237,14 @@ class BayesRunWithSubspaceSimulation:
|
||||
for name, probability in zip(self.model_names, self.probabilities):
|
||||
row[f"{name}_prob"] = probability
|
||||
_logger.info(row)
|
||||
return_result.append(row)
|
||||
|
||||
with open(self.filename, "a", newline="") as outfile:
|
||||
writer = csv.DictWriter(
|
||||
outfile, fieldnames=self.csv_fields, dialect="unix"
|
||||
)
|
||||
writer.writerow(row)
|
||||
if self.write_output_to_csv:
|
||||
with open(self.filename, "a", newline="") as outfile:
|
||||
writer = csv.DictWriter(
|
||||
outfile, fieldnames=self.csv_fields, dialect="unix"
|
||||
)
|
||||
writer.writerow(row)
|
||||
|
||||
if self.use_end_threshold:
|
||||
max_prob = max(self.probabilities)
|
||||
@ -230,3 +253,5 @@ class BayesRunWithSubspaceSimulation:
|
||||
f"Aborting early, because {max_prob} is greater than {self.end_threshold}"
|
||||
)
|
||||
break
|
||||
|
||||
return return_result
|
||||
|
@ -17,6 +17,7 @@ class SubsetSimulationResult:
|
||||
over_target_likelihood: Optional[float]
|
||||
under_target_cost: Optional[float]
|
||||
under_target_likelihood: Optional[float]
|
||||
lowest_likelihood: Optional[float]
|
||||
|
||||
|
||||
class SubsetSimulation:
|
||||
@ -37,6 +38,9 @@ class SubsetSimulation:
|
||||
default_r_step=0.01,
|
||||
default_w_log_step=0.01,
|
||||
default_upper_w_log_step=4,
|
||||
keep_probs_list=True,
|
||||
dump_last_generation_to_file=False,
|
||||
initial_cost_chunk_size=100,
|
||||
):
|
||||
name, model = model_name_pair
|
||||
self.model_name = name
|
||||
@ -79,6 +83,11 @@ class SubsetSimulation:
|
||||
self.target_cost = target_cost
|
||||
_logger.info(f"will stop at target cost {target_cost}")
|
||||
|
||||
self.keep_probs_list = keep_probs_list
|
||||
self.dump_last_generations = dump_last_generation_to_file
|
||||
|
||||
self.initial_cost_chunk_size = initial_cost_chunk_size
|
||||
|
||||
def execute(self) -> SubsetSimulationResult:
|
||||
|
||||
probs_list = []
|
||||
@ -90,7 +99,14 @@ class SubsetSimulation:
|
||||
)
|
||||
# _logger.debug(sample_dipoles)
|
||||
# _logger.debug(sample_dipoles.shape)
|
||||
costs = self.cost_function_to_use(sample_dipoles)
|
||||
|
||||
raw_costs = []
|
||||
_logger.debug(f"Using iterated cost function thing with chunk size {self.initial_cost_chunk_size}")
|
||||
|
||||
for x in range(0, len(sample_dipoles), self.initial_cost_chunk_size):
|
||||
_logger.debug(f"doing chunk {x}")
|
||||
raw_costs.extend(self.cost_function_to_use(sample_dipoles[x: x + self.initial_cost_chunk_size]))
|
||||
costs = numpy.array(raw_costs)
|
||||
|
||||
_logger.debug(f"costs: {costs}")
|
||||
sorted_indexes = costs.argsort()[::-1]
|
||||
@ -114,27 +130,42 @@ class SubsetSimulation:
|
||||
mcmc_rng = numpy.random.default_rng(self.mcmc_seed)
|
||||
|
||||
for i in range(self.m_max):
|
||||
next_seeds = all_chains[-self.n_c:]
|
||||
next_seeds = all_chains[-self.n_c :]
|
||||
|
||||
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
||||
probs_list.append(
|
||||
(
|
||||
((self.n_c * self.n_s - cost_index) / (self.n_c * self.n_s))
|
||||
/ (self.n_s ** (i)),
|
||||
cost_chain[0],
|
||||
i + 1,
|
||||
if self.dump_last_generations:
|
||||
_logger.info("writing out csv file")
|
||||
next_dipoles_seed_dipoles = numpy.array([n[1] for n in next_seeds])
|
||||
for n in range(self.model.n):
|
||||
_logger.info(f"{next_dipoles_seed_dipoles[:, n].shape}")
|
||||
numpy.savetxt(
|
||||
f"generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv",
|
||||
next_dipoles_seed_dipoles[:, n],
|
||||
delimiter=",",
|
||||
)
|
||||
|
||||
if self.keep_probs_list:
|
||||
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
||||
probs_list.append(
|
||||
(
|
||||
((self.n_c * self.n_s - cost_index) / (self.n_c * self.n_s))
|
||||
/ (self.n_s ** (i)),
|
||||
cost_chain[0],
|
||||
i + 1,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
next_seeds_as_array = numpy.array([s for _, s in next_seeds])
|
||||
|
||||
stdevs = self.get_stdevs_from_arrays(next_seeds_as_array)
|
||||
_logger.info(f"got stdevs: {stdevs.stdevs}")
|
||||
|
||||
_logger.debug("Starting the MCMC")
|
||||
all_chains = []
|
||||
for c, s in next_seeds:
|
||||
for seed_index, (c, s) in enumerate(next_seeds):
|
||||
# chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
|
||||
# until new version gotta do
|
||||
_logger.debug(
|
||||
f"\t{seed_index}: getting another chain from the next seed"
|
||||
)
|
||||
chain = self.model.get_mcmc_chain(
|
||||
s,
|
||||
self.cost_function_to_use,
|
||||
@ -150,10 +181,11 @@ class SubsetSimulation:
|
||||
except IndexError:
|
||||
filtered_cost = cost
|
||||
all_chains.append((filtered_cost, chained))
|
||||
|
||||
_logger.debug("finished mcmc")
|
||||
# _logger.debug(all_chains)
|
||||
|
||||
all_chains.sort(key=lambda c: c[0], reverse=True)
|
||||
_logger.debug("finished sorting all_chains")
|
||||
|
||||
threshold_cost = all_chains[-self.n_c][0]
|
||||
_logger.info(
|
||||
@ -169,14 +201,18 @@ class SubsetSimulation:
|
||||
|
||||
shorter_probs_list = []
|
||||
for cost_index, cost_chain in enumerate(all_chains):
|
||||
probs_list.append(
|
||||
(
|
||||
((self.n_c * self.n_s - cost_index) / (self.n_c * self.n_s))
|
||||
/ (self.n_s ** (i)),
|
||||
cost_chain[0],
|
||||
i + 1,
|
||||
if self.keep_probs_list:
|
||||
probs_list.append(
|
||||
(
|
||||
(
|
||||
(self.n_c * self.n_s - cost_index)
|
||||
/ (self.n_c * self.n_s)
|
||||
)
|
||||
/ (self.n_s ** (i)),
|
||||
cost_chain[0],
|
||||
i + 1,
|
||||
)
|
||||
)
|
||||
)
|
||||
shorter_probs_list.append(
|
||||
(
|
||||
cost_chain[0],
|
||||
@ -191,21 +227,23 @@ class SubsetSimulation:
|
||||
over_target_likelihood=shorter_probs_list[over_index - 1][1],
|
||||
under_target_cost=shorter_probs_list[over_index][0],
|
||||
under_target_likelihood=shorter_probs_list[over_index][1],
|
||||
lowest_likelihood=shorter_probs_list[-1][1],
|
||||
)
|
||||
return result
|
||||
|
||||
# _logger.debug([c[0] for c in all_chains[-n_c:]])
|
||||
_logger.info(f"doing level {i + 1}")
|
||||
|
||||
for cost_index, cost_chain in enumerate(all_chains):
|
||||
probs_list.append(
|
||||
(
|
||||
((self.n_c * self.n_s - cost_index) / (self.n_c * self.n_s))
|
||||
/ (self.n_s ** (self.m_max)),
|
||||
cost_chain[0],
|
||||
self.m_max + 1,
|
||||
if self.keep_probs_list:
|
||||
for cost_index, cost_chain in enumerate(all_chains):
|
||||
probs_list.append(
|
||||
(
|
||||
((self.n_c * self.n_s - cost_index) / (self.n_c * self.n_s))
|
||||
/ (self.n_s ** (self.m_max)),
|
||||
cost_chain[0],
|
||||
self.m_max + 1,
|
||||
)
|
||||
)
|
||||
)
|
||||
threshold_cost = all_chains[-self.n_c][0]
|
||||
_logger.info(
|
||||
f"final threshold cost: {threshold_cost}, at P = (1 / {self.n_s})^{self.m_max + 1}"
|
||||
@ -215,12 +253,16 @@ class SubsetSimulation:
|
||||
# for prob, prob_cost in probs_list:
|
||||
# _logger.info(f"\t{prob}: {prob_cost}")
|
||||
probs_list.sort(key=lambda c: c[0], reverse=True)
|
||||
|
||||
min_likelihood = ((1) / (self.n_c * self.n_s)) / (self.n_s ** (self.m_max))
|
||||
|
||||
result = SubsetSimulationResult(
|
||||
probs_list=probs_list,
|
||||
over_target_cost=None,
|
||||
over_target_likelihood=None,
|
||||
under_target_cost=None,
|
||||
under_target_likelihood=None,
|
||||
lowest_likelihood=min_likelihood,
|
||||
)
|
||||
return result
|
||||
|
||||
|
2
do.sh
2
do.sh
@ -18,7 +18,7 @@ test() {
|
||||
|
||||
fmt() {
|
||||
poetry run black .
|
||||
find . -type f -name "*.py" -exec sed -i -e 's/ /\t/g' {} \;
|
||||
find . -not \( -path "./.*" -type d -prune \) -type f -name "*.py" -exec sed -i -e 's/ /\t/g' {} \;
|
||||
}
|
||||
|
||||
release() {
|
||||
|
28
poetry.lock
generated
28
poetry.lock
generated
@ -92,6 +92,14 @@ category = "dev"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "1.4.4"
|
||||
description = "Simple library for color and formatting to terminal"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.2.7"
|
||||
@ -461,11 +469,11 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "3.0.0"
|
||||
version = "4.1.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[package.dependencies]
|
||||
coverage = {version = ">=5.2.1", extras = ["toml"]}
|
||||
@ -633,6 +641,18 @@ category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
|
||||
[[package]]
|
||||
name = "syrupy"
|
||||
version = "4.0.8"
|
||||
description = "Pytest Snapshot Test Utility"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4"
|
||||
|
||||
[package.dependencies]
|
||||
colored = ">=1.3.92,<2.0.0"
|
||||
pytest = ">=7.0.0,<8.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
@ -730,7 +750,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8.1,<3.10"
|
||||
content-hash = "0161af7edf18c16819f1ce083ab491c17c9809f2770219725131451b1a16a970"
|
||||
content-hash = "e1531b1493bac50ffe5e8f9a46a64d9b66198f7021f6d643c72f21cb53dc77ec"
|
||||
|
||||
[metadata.files]
|
||||
black = []
|
||||
@ -741,6 +761,7 @@ charset-normalizer = []
|
||||
click = []
|
||||
click-log = []
|
||||
colorama = []
|
||||
colored = []
|
||||
coverage = []
|
||||
cryptography = []
|
||||
docutils = []
|
||||
@ -786,6 +807,7 @@ secretstorage = []
|
||||
semver = []
|
||||
six = []
|
||||
smmap = []
|
||||
syrupy = []
|
||||
tomli = []
|
||||
tomlkit = []
|
||||
tqdm = []
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "deepdog"
|
||||
version = "0.7.2"
|
||||
version = "0.7.4"
|
||||
description = ""
|
||||
authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
|
||||
|
||||
@ -13,10 +13,11 @@ scipy = "1.10"
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = ">=6"
|
||||
flake8 = "^4.0.1"
|
||||
pytest-cov = "^3.0.0"
|
||||
pytest-cov = "^4.1.0"
|
||||
mypy = "^0.971"
|
||||
python-semantic-release = "^7.24.0"
|
||||
black = "^23.0.0"
|
||||
syrupy = "^4.0.8"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
177
tests/__snapshots__/test_bayes_run_with_ss.ambr
Normal file
177
tests/__snapshots__/test_bayes_run_with_ss.ambr
Normal file
@ -0,0 +1,177 @@
|
||||
# serializer version: 1
|
||||
# name: test_basic_analysis
|
||||
list([
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.3333333333333333,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.3333333333333333,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.3333333333333333,
|
||||
'dipole_frequency_1': 0.006029931414230269,
|
||||
'dipole_frequency_2': 85436.78758379082,
|
||||
'dipole_location_1': array([-4.76615152, -6.33160296, 5.29522808]),
|
||||
'dipole_location_2': array([-4.72700391, -2.06478573, 6.52467702]),
|
||||
'dipole_moment_1': array([ 860.14181416, -450.27082062, -239.60852996]),
|
||||
'dipole_moment_2': array([ 908.18325588, -208.52681777, -362.93214244]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.45,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.3103448275862069,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.9,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.6206896551724138,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.06896551724137932,
|
||||
'dipole_frequency_1': 102275.63477261562,
|
||||
'dipole_frequency_2': 1755280.9783485082,
|
||||
'dipole_location_1': array([ 4.71515397, -9.70362197, 5.43016546]),
|
||||
'dipole_location_2': array([3.42476038, 3.88562934, 5.15034328]),
|
||||
'dipole_moment_1': array([-502.60742674, -790.60222587, 349.7626267 ]),
|
||||
'dipole_moment_2': array([-192.42708465, -434.81009148, -879.7226844 ]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.7,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.6631578947368421,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.18947368421052635,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.7,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.1473684210526316,
|
||||
'dipole_frequency_1': 2896.799464036654,
|
||||
'dipole_frequency_2': 9.980565189326681e-05,
|
||||
'dipole_location_1': array([-4.97465789, 12.54716531, 6.06324588]),
|
||||
'dipole_location_2': array([ 9.84518459, -11.1183876 , 7.35028226]),
|
||||
'dipole_moment_1': array([997.67961917, 19.6376112 , 65.19004305]),
|
||||
'dipole_moment_2': array([305.63093655, 440.57669389, 844.08643362]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.663157894736842,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.18947368421052635,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.1473684210526316,
|
||||
'dipole_frequency_1': 1.4522667818288244,
|
||||
'dipole_frequency_2': 2704.9795645301197,
|
||||
'dipole_location_1': array([ 7.38183022, 16.6745801 , 7.10428414]),
|
||||
'dipole_location_2': array([-8.15636906, -9.56609132, 6.34141559]),
|
||||
'dipole_moment_1': array([-145.9924693 , 738.74936496, 657.97839986]),
|
||||
'dipole_moment_2': array([-960.16113239, 104.96824669, -258.98314046]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.9,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.9465776293823038,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.030050083472454105,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.1,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.02337228714524208,
|
||||
'dipole_frequency_1': 3827.2315421318913,
|
||||
'dipole_frequency_2': 1.9301094166184413e-05,
|
||||
'dipole_location_1': array([ 5.02067673, -0.9783039 , 6.1431897 ]),
|
||||
'dipole_location_2': array([ 4.66628999, 10.80907459, 7.21771744]),
|
||||
'dipole_moment_1': array([ 871.30659253, -299.17389491, -388.99846068]),
|
||||
'dipole_moment_2': array([-189.87268624, 677.28285845, 710.79975568]),
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_bayesss_with_tighter_cost
|
||||
list([
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.33333333333333337,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.33333333333333337,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.33333333333333337,
|
||||
'dipole_frequency_1': 0.006029931414230269,
|
||||
'dipole_frequency_2': 85436.78758379082,
|
||||
'dipole_location_1': array([-4.76615152, -6.33160296, 5.29522808]),
|
||||
'dipole_location_2': array([-4.72700391, -2.06478573, 6.52467702]),
|
||||
'dipole_moment_1': array([ 860.14181416, -450.27082062, -239.60852996]),
|
||||
'dipole_moment_2': array([ 908.18325588, -208.52681777, -362.93214244]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.0109375,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.1044776119402985,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.03125,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.2985074626865672,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.0625,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5970149253731344,
|
||||
'dipole_frequency_1': 102275.63477261562,
|
||||
'dipole_frequency_2': 1755280.9783485082,
|
||||
'dipole_location_1': array([ 4.71515397, -9.70362197, 5.43016546]),
|
||||
'dipole_location_2': array([3.42476038, 3.88562934, 5.15034328]),
|
||||
'dipole_moment_1': array([-502.60742674, -790.60222587, 349.7626267 ]),
|
||||
'dipole_moment_2': array([-192.42708465, -434.81009148, -879.7226844 ]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 7.291135021404688e-05,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.021875,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.4666326413699001,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.0125,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5332944472798858,
|
||||
'dipole_frequency_1': 2896.799464036654,
|
||||
'dipole_frequency_2': 9.980565189326681e-05,
|
||||
'dipole_location_1': array([-4.97465789, 12.54716531, 6.06324588]),
|
||||
'dipole_location_2': array([ 9.84518459, -11.1183876 , 7.35028226]),
|
||||
'dipole_moment_1': array([997.67961917, 19.6376112 , 65.19004305]),
|
||||
'dipole_moment_2': array([305.63093655, 440.57669389, 844.08643362]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 7.291135021404688e-05,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.4666326413699001,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.5332944472798858,
|
||||
'dipole_frequency_1': 1.4522667818288244,
|
||||
'dipole_frequency_2': 2704.9795645301197,
|
||||
'dipole_location_1': array([ 7.38183022, 16.6745801 , 7.10428414]),
|
||||
'dipole_location_2': array([-8.15636906, -9.56609132, 6.34141559]),
|
||||
'dipole_moment_1': array([-145.9924693 , 738.74936496, 657.97839986]),
|
||||
'dipole_moment_2': array([-960.16113239, 104.96824669, -258.98314046]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 0.175,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 0.00012008361740869356,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.05625,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.24702915581216964,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.15,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.7528507605704217,
|
||||
'dipole_frequency_1': 3827.2315421318913,
|
||||
'dipole_frequency_2': 1.9301094166184413e-05,
|
||||
'dipole_location_1': array([ 5.02067673, -0.9783039 , 6.1431897 ]),
|
||||
'dipole_location_2': array([ 4.66628999, 10.80907459, 7.21771744]),
|
||||
'dipole_moment_1': array([ 871.30659253, -299.17389491, -388.99846068]),
|
||||
'dipole_moment_2': array([-189.87268624, 677.28285845, 710.79975568]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 4.9116305003549454e-08,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 0.0109375,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.11316396672817797,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.028125,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.886835984155517,
|
||||
'dipole_frequency_1': 1.1715179359592061e-05,
|
||||
'dipole_frequency_2': 0.0019103783276337497,
|
||||
'dipole_location_1': array([-0.95736547, 1.09273812, 7.47158641]),
|
||||
'dipole_location_2': array([ -3.18510322, -15.64493131, 5.81623624]),
|
||||
'dipole_moment_1': array([-184.64961369, 956.56786553, 225.57136075]),
|
||||
'dipole_moment_2': array([ -34.63395137, 801.17771816, -597.42342885]),
|
||||
}),
|
||||
dict({
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedxy-pfixexp_3-dipole_count_2_prob': 1.977090156727901e-10,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_likelihood': 9.765625e-06,
|
||||
'connors_geom-5height-orientation_fixedz-pfixexp_3-dipole_count_2_prob': 0.00045552157211010855,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_likelihood': 0.002734375,
|
||||
'connors_geom-5height-orientation_free-pfixexp_3-dipole_count_2_prob': 0.9995444782301809,
|
||||
'dipole_frequency_1': 999786.9069039805,
|
||||
'dipole_frequency_2': 186034.67996840767,
|
||||
'dipole_location_1': array([-5.59679125, 6.3411602 , 5.33602522]),
|
||||
'dipole_location_2': array([-0.03412955, -6.83522954, 5.58551513]),
|
||||
'dipole_moment_1': array([826.38270589, 491.81526944, 274.24325726]),
|
||||
'dipole_moment_2': array([ 202.74745884, -656.07483714, -726.95204519]),
|
||||
}),
|
||||
])
|
||||
# ---
|
158
tests/test_bayes_run_with_ss.py
Normal file
158
tests/test_bayes_run_with_ss.py
Normal file
@ -0,0 +1,158 @@
|
||||
import deepdog
|
||||
import logging
|
||||
import logging.config
|
||||
|
||||
import numpy.random
|
||||
|
||||
from pdme.model import (
|
||||
LogSpacedRandomCountMultipleDipoleFixedMagnitudeModel,
|
||||
LogSpacedRandomCountMultipleDipoleFixedMagnitudeXYModel,
|
||||
LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel,
|
||||
)
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fixed_z_model_func(
|
||||
xmin,
|
||||
xmax,
|
||||
ymin,
|
||||
ymax,
|
||||
zmin,
|
||||
zmax,
|
||||
wexp_min,
|
||||
wexp_max,
|
||||
pfixed,
|
||||
n_max,
|
||||
prob_occupancy,
|
||||
):
|
||||
return LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel(
|
||||
xmin,
|
||||
xmax,
|
||||
ymin,
|
||||
ymax,
|
||||
zmin,
|
||||
zmax,
|
||||
wexp_min,
|
||||
wexp_max,
|
||||
pfixed,
|
||||
0,
|
||||
0,
|
||||
n_max,
|
||||
prob_occupancy,
|
||||
)
|
||||
|
||||
|
||||
def get_model(orientation):
|
||||
model_funcs = {
|
||||
"fixedz": fixed_z_model_func,
|
||||
"free": LogSpacedRandomCountMultipleDipoleFixedMagnitudeModel,
|
||||
"fixedxy": LogSpacedRandomCountMultipleDipoleFixedMagnitudeXYModel,
|
||||
}
|
||||
model = model_funcs[orientation](
|
||||
-10,
|
||||
10,
|
||||
-17.5,
|
||||
17.5,
|
||||
5,
|
||||
7.5,
|
||||
-5,
|
||||
6.5,
|
||||
10**3,
|
||||
2,
|
||||
0.99999999,
|
||||
)
|
||||
model.n = 2
|
||||
model.rng = numpy.random.default_rng(1234)
|
||||
|
||||
return (
|
||||
f"connors_geom-5height-orientation_{orientation}-pfixexp_{3}-dipole_count_{2}",
|
||||
model,
|
||||
)
|
||||
|
||||
|
||||
def test_basic_analysis(snapshot):
|
||||
|
||||
dot_positions = [[0, 0, 0], [0, 1, 0]]
|
||||
|
||||
freqs = [1, 10, 100]
|
||||
models = []
|
||||
|
||||
orientations = ["free", "fixedxy", "fixedz"]
|
||||
for orientation in orientations:
|
||||
models.append(get_model(orientation))
|
||||
|
||||
_logger.info(f"have {len(models)} models to look at")
|
||||
if len(models) == 1:
|
||||
_logger.info(f"only one model, name: {models[0][0]}")
|
||||
|
||||
square_run = deepdog.BayesRunWithSubspaceSimulation(
|
||||
dot_positions,
|
||||
freqs,
|
||||
models,
|
||||
models[0][1],
|
||||
filename_slug="test",
|
||||
end_threshold=0.9,
|
||||
ss_n_c=5,
|
||||
ss_n_s=2,
|
||||
ss_m_max=10,
|
||||
ss_target_cost=150,
|
||||
ss_level_0_seed=200,
|
||||
ss_mcmc_seed=20,
|
||||
ss_use_adaptive_steps=True,
|
||||
ss_default_phi_step=0.01,
|
||||
ss_default_theta_step=0.01,
|
||||
ss_default_r_step=0.01,
|
||||
ss_default_w_log_step=0.01,
|
||||
ss_default_upper_w_log_step=4,
|
||||
ss_dump_last_generation=False,
|
||||
write_output_to_bayesruncsv=False,
|
||||
ss_initial_costs_chunk_size=1000,
|
||||
)
|
||||
result = square_run.go()
|
||||
|
||||
assert result == snapshot
|
||||
|
||||
|
||||
def test_bayesss_with_tighter_cost(snapshot):
|
||||
|
||||
dot_positions = [[0, 0, 0], [0, 1, 0]]
|
||||
|
||||
freqs = [1, 10, 100]
|
||||
models = []
|
||||
|
||||
orientations = ["free", "fixedxy", "fixedz"]
|
||||
for orientation in orientations:
|
||||
models.append(get_model(orientation))
|
||||
|
||||
_logger.info(f"have {len(models)} models to look at")
|
||||
if len(models) == 1:
|
||||
_logger.info(f"only one model, name: {models[0][0]}")
|
||||
|
||||
square_run = deepdog.BayesRunWithSubspaceSimulation(
|
||||
dot_positions,
|
||||
freqs,
|
||||
models,
|
||||
models[0][1],
|
||||
filename_slug="test",
|
||||
end_threshold=0.9,
|
||||
ss_n_c=5,
|
||||
ss_n_s=2,
|
||||
ss_m_max=10,
|
||||
ss_target_cost=1.5,
|
||||
ss_level_0_seed=200,
|
||||
ss_mcmc_seed=20,
|
||||
ss_use_adaptive_steps=True,
|
||||
ss_default_phi_step=0.01,
|
||||
ss_default_theta_step=0.01,
|
||||
ss_default_r_step=0.01,
|
||||
ss_default_w_log_step=0.01,
|
||||
ss_default_upper_w_log_step=4,
|
||||
ss_dump_last_generation=False,
|
||||
write_output_to_bayesruncsv=False,
|
||||
ss_initial_costs_chunk_size=1
|
||||
)
|
||||
result = square_run.go()
|
||||
|
||||
assert result == snapshot
|
Loading…
x
Reference in New Issue
Block a user