feat!: switches over to pdme new stuff, uses models and scraps discretisations entirely
This commit is contained in:
parent
31070b5342
commit
6e29f7a702
@ -1,9 +1,7 @@
|
||||
import logging
|
||||
from deepdog.meta import __version__
|
||||
from deepdog.bayes_run import BayesRun
|
||||
from deepdog.alt_bayes_run import AltBayesRun
|
||||
from deepdog.alt_bayes_run_simulpairs import AltBayesRunSimulPairs
|
||||
from deepdog.diagnostic import Diagnostic
|
||||
from deepdog.bayes_run_simulpairs import BayesRunSimulPairs
|
||||
|
||||
|
||||
def get_version():
|
||||
@ -13,9 +11,7 @@ def get_version():
|
||||
__all__ = [
|
||||
"get_version",
|
||||
"BayesRun",
|
||||
"AltBayesRun",
|
||||
"AltBayesRunSimulPairs",
|
||||
"Diagnostic",
|
||||
"BayesRunSimulPairs",
|
||||
]
|
||||
|
||||
|
||||
|
@ -23,8 +23,8 @@ _logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_a_result(input) -> int:
|
||||
discretisation, dot_inputs, lows, highs, monte_carlo_count, max_frequency = input
|
||||
sample_dipoles = discretisation.get_model().get_n_single_dipoles(
|
||||
model, dot_inputs, lows, highs, monte_carlo_count, max_frequency = input
|
||||
sample_dipoles = model.get_model().get_n_single_dipoles(
|
||||
monte_carlo_count, max_frequency
|
||||
)
|
||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
||||
@ -33,7 +33,7 @@ def get_a_result(input) -> int:
|
||||
|
||||
def get_a_result_using_pairs(input) -> int:
|
||||
(
|
||||
discretisation,
|
||||
model,
|
||||
dot_inputs,
|
||||
pair_inputs,
|
||||
local_lows,
|
||||
@ -43,9 +43,7 @@ def get_a_result_using_pairs(input) -> int:
|
||||
monte_carlo_count,
|
||||
max_frequency,
|
||||
) = input
|
||||
sample_dipoles = discretisation.get_model().get_n_single_dipoles(
|
||||
monte_carlo_count, max_frequency
|
||||
)
|
||||
sample_dipoles = model.get_n_single_dipoles(monte_carlo_count, max_frequency)
|
||||
local_vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
||||
local_matches = pdme.util.fast_v_calc.between(local_vals, local_lows, local_highs)
|
||||
nonlocal_vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal(
|
||||
@ -58,7 +56,7 @@ def get_a_result_using_pairs(input) -> int:
|
||||
return numpy.count_nonzero(combined_matches)
|
||||
|
||||
|
||||
class AltBayesRun:
|
||||
class BayesRun:
|
||||
"""
|
||||
A single Bayes run for a given set of dots.
|
||||
|
||||
@ -67,11 +65,11 @@ class AltBayesRun:
|
||||
dot_inputs : Sequence[DotInput]
|
||||
The dot inputs for this bayes run.
|
||||
|
||||
discretisations_with_names : Sequence[Tuple(str, pdme.model.Model)]
|
||||
models_with_names : Sequence[Tuple(str, pdme.model.DipoleModel)]
|
||||
The models to evaluate.
|
||||
|
||||
actual_model_discretisation : pdme.model.Discretisation
|
||||
The discretisation for the model which is actually correct.
|
||||
actual_model : pdme.model.DipoleModel
|
||||
The model which is actually correct.
|
||||
|
||||
filename_slug : str
|
||||
The filename slug to include.
|
||||
@ -84,8 +82,8 @@ class AltBayesRun:
|
||||
self,
|
||||
dot_positions: Sequence[numpy.typing.ArrayLike],
|
||||
frequency_range: Sequence[float],
|
||||
discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]],
|
||||
actual_model: pdme.model.Model,
|
||||
models_with_names: Sequence[Tuple[str, pdme.model.DipoleModel]],
|
||||
actual_model: pdme.model.DipoleModel,
|
||||
filename_slug: str,
|
||||
run_count: int = 100,
|
||||
low_error: float = 0.9,
|
||||
@ -117,10 +115,10 @@ class AltBayesRun:
|
||||
pdme.measurement.input_types.dot_pair_inputs_to_array(self.dot_pair_inputs)
|
||||
)
|
||||
|
||||
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
||||
self.model_names = [name for (name, _) in discretisations_with_names]
|
||||
self.models = [model for (_, model) in models_with_names]
|
||||
self.model_names = [name for (name, _) in models_with_names]
|
||||
self.actual_model = actual_model
|
||||
self.model_count = len(self.discretisations)
|
||||
self.model_count = len(self.models)
|
||||
self.monte_carlo_count = monte_carlo_count
|
||||
self.monte_carlo_cycles = monte_carlo_cycles
|
||||
self.target_success = target_success
|
||||
@ -203,9 +201,9 @@ class AltBayesRun:
|
||||
_logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}")
|
||||
|
||||
results = []
|
||||
_logger.debug("Going to iterate over discretisations now")
|
||||
for disc_count, discretisation in enumerate(self.discretisations):
|
||||
_logger.debug(f"Doing discretisation #{disc_count}")
|
||||
_logger.debug("Going to iterate over models now")
|
||||
for model_count, model in enumerate(self.models):
|
||||
_logger.debug(f"Doing model #{model_count}")
|
||||
with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool:
|
||||
cycle_count = 0
|
||||
cycle_success = 0
|
||||
@ -223,7 +221,7 @@ class AltBayesRun:
|
||||
get_a_result_using_pairs,
|
||||
[
|
||||
(
|
||||
discretisation,
|
||||
model,
|
||||
self.dot_inputs_array,
|
||||
self.dot_pair_inputs_array,
|
||||
lows,
|
||||
@ -244,7 +242,7 @@ class AltBayesRun:
|
||||
get_a_result,
|
||||
[
|
||||
(
|
||||
discretisation,
|
||||
model,
|
||||
self.dot_inputs_array,
|
||||
lows,
|
||||
highs,
|
||||
|
@ -25,7 +25,7 @@ _logger = logging.getLogger(__name__)
|
||||
|
||||
def get_a_simul_result_using_pairs(input) -> numpy.ndarray:
|
||||
(
|
||||
discretisation,
|
||||
model,
|
||||
dot_inputs,
|
||||
pair_inputs,
|
||||
local_lows,
|
||||
@ -42,16 +42,12 @@ def get_a_simul_result_using_pairs(input) -> numpy.ndarray:
|
||||
local_total = 0
|
||||
combined_total = 0
|
||||
|
||||
sample_dipoles = discretisation.get_model().get_n_single_dipoles(
|
||||
sample_dipoles = model.get_monte_carlo_dipole_inputs(
|
||||
monte_carlo_count, max_frequency, rng_to_use=rng
|
||||
)
|
||||
local_vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(
|
||||
dot_inputs, sample_dipoles
|
||||
)
|
||||
local_matches = pdme.util.fast_v_calc.between(
|
||||
local_vals, local_lows, local_highs
|
||||
)
|
||||
nonlocal_vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal(
|
||||
local_vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(dot_inputs, sample_dipoles)
|
||||
local_matches = pdme.util.fast_v_calc.between(local_vals, local_lows, local_highs)
|
||||
nonlocal_vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal_dipoleses(
|
||||
pair_inputs, sample_dipoles
|
||||
)
|
||||
nonlocal_matches = pdme.util.fast_v_calc.between(
|
||||
@ -64,7 +60,7 @@ def get_a_simul_result_using_pairs(input) -> numpy.ndarray:
|
||||
return numpy.array([local_total, combined_total])
|
||||
|
||||
|
||||
class AltBayesRunSimulPairs:
|
||||
class BayesRunSimulPairs:
|
||||
"""
|
||||
A dual pairs-nonpairs Bayes run for a given set of dots.
|
||||
|
||||
@ -73,11 +69,11 @@ class AltBayesRunSimulPairs:
|
||||
dot_inputs : Sequence[DotInput]
|
||||
The dot inputs for this bayes run.
|
||||
|
||||
discretisations_with_names : Sequence[Tuple(str, pdme.model.Model)]
|
||||
models_with_names : Sequence[Tuple(str, pdme.model.DipoleModel)]
|
||||
The models to evaluate.
|
||||
|
||||
actual_model_discretisation : pdme.model.Discretisation
|
||||
The discretisation for the model which is actually correct.
|
||||
actual_model : pdme.model.DipoleModel
|
||||
The modoel for the model which is actually correct.
|
||||
|
||||
filename_slug : str
|
||||
The filename slug to include.
|
||||
@ -90,8 +86,8 @@ class AltBayesRunSimulPairs:
|
||||
self,
|
||||
dot_positions: Sequence[numpy.typing.ArrayLike],
|
||||
frequency_range: Sequence[float],
|
||||
discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]],
|
||||
actual_model: pdme.model.Model,
|
||||
models_with_names: Sequence[Tuple[str, pdme.model.DipoleModel]],
|
||||
actual_model: pdme.model.DipoleModel,
|
||||
filename_slug: str,
|
||||
run_count: int = 100,
|
||||
low_error: float = 0.9,
|
||||
@ -120,10 +116,10 @@ class AltBayesRunSimulPairs:
|
||||
pdme.measurement.input_types.dot_pair_inputs_to_array(self.dot_pair_inputs)
|
||||
)
|
||||
|
||||
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
||||
self.model_names = [name for (name, _) in discretisations_with_names]
|
||||
self.models = [mod for (_, mod) in models_with_names]
|
||||
self.model_names = [name for (name, _) in models_with_names]
|
||||
self.actual_model = actual_model
|
||||
self.model_count = len(self.discretisations)
|
||||
self.model_count = len(self.models)
|
||||
self.monte_carlo_count = monte_carlo_count
|
||||
self.monte_carlo_cycles = monte_carlo_cycles
|
||||
self.target_success = target_success
|
||||
@ -208,9 +204,9 @@ class AltBayesRunSimulPairs:
|
||||
|
||||
results_pairs = []
|
||||
results_no_pairs = []
|
||||
_logger.debug("Going to iterate over discretisations now")
|
||||
for disc_count, discretisation in enumerate(self.discretisations):
|
||||
_logger.debug(f"Doing discretisation #{disc_count}")
|
||||
_logger.debug("Going to iterate over models now")
|
||||
for model_count, model in enumerate(self.models):
|
||||
_logger.debug(f"Doing model #{model_count}")
|
||||
|
||||
core_count = multiprocessing.cpu_count() - 1 or 1
|
||||
with multiprocessing.Pool(core_count) as pool:
|
||||
@ -223,7 +219,9 @@ class AltBayesRunSimulPairs:
|
||||
<= self.target_success
|
||||
):
|
||||
_logger.debug(f"Starting cycle {cycles}")
|
||||
_logger.debug(f"(pair, no_pair) successes are {(cycle_success_pairs, cycle_success_no_pairs)}")
|
||||
_logger.debug(
|
||||
f"(pair, no_pair) successes are {(cycle_success_pairs, cycle_success_no_pairs)}"
|
||||
)
|
||||
cycles += 1
|
||||
current_success_pairs = 0
|
||||
current_success_no_pairs = 0
|
||||
@ -241,7 +239,7 @@ class AltBayesRunSimulPairs:
|
||||
get_a_simul_result_using_pairs,
|
||||
[
|
||||
(
|
||||
discretisation,
|
||||
model,
|
||||
self.dot_inputs_array,
|
||||
self.dot_pair_inputs_array,
|
||||
lows,
|
||||
|
8
poetry.lock
generated
8
poetry.lock
generated
@ -335,7 +335,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
|
||||
|
||||
[[package]]
|
||||
name = "pdme"
|
||||
version = "0.7.0"
|
||||
version = "0.8.1"
|
||||
description = "Python dipole model evaluator"
|
||||
category = "main"
|
||||
optional = false
|
||||
@ -740,7 +740,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.8,<3.10"
|
||||
content-hash = "f7a79a9d8f1a94f36207c3cee41551a6c8af10976562678c5f3d1462586a7703"
|
||||
content-hash = "891c2891a49d446e8a381212ba40c61941b9a23d7e85304b18cb024b4a341edb"
|
||||
|
||||
[metadata.files]
|
||||
atomicwrites = [
|
||||
@ -1025,8 +1025,8 @@ pathspec = [
|
||||
{file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
|
||||
]
|
||||
pdme = [
|
||||
{file = "pdme-0.7.0-py3-none-any.whl", hash = "sha256:28270dba845ce771e50df598c2d2ceee6b746fef94ba648b89c58784655d9631"},
|
||||
{file = "pdme-0.7.0.tar.gz", hash = "sha256:a6aa25a2d306d9222bddf8e0932c5b3c050e164c7609cc2b43108cffe62b5b23"},
|
||||
{file = "pdme-0.8.1-py3-none-any.whl", hash = "sha256:e20fbd0b1381957a2d80a48ed3b7a400dac6782972e1feb022f1924aabca4136"},
|
||||
{file = "pdme-0.8.1.tar.gz", hash = "sha256:69a69690dfede43a8b40bad8470477ff129cc07987e899abcbb16f01b0066c89"},
|
||||
]
|
||||
pkginfo = [
|
||||
{file = "pkginfo-1.8.2-py2.py3-none-any.whl", hash = "sha256:c24c487c6a7f72c66e816ab1796b96ac6c3d14d49338293d2141664330b55ffc"},
|
||||
|
@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla <dmallubhotla+github@gmail.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8,<3.10"
|
||||
pdme = "^0.7.0"
|
||||
pdme = "^0.8.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = ">=6"
|
||||
|
Loading…
x
Reference in New Issue
Block a user