diff --git a/deepdog/alt_bayes_run.py b/deepdog/alt_bayes_run.py index e963209..7599b68 100644 --- a/deepdog/alt_bayes_run.py +++ b/deepdog/alt_bayes_run.py @@ -1,6 +1,9 @@ +import pdme.inputs import pdme.model +import pdme.measurement.input_types import pdme.measurement.oscillating_dipole import pdme.util.fast_v_calc +import pdme.util.fast_nonlocal_spectrum from typing import Sequence, Tuple, List import datetime import csv @@ -26,6 +29,17 @@ def get_a_result(input) -> int: return numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs)) +def get_a_result_using_pairs(input) -> int: + discretisation, dot_inputs, pair_inputs, local_lows, local_highs, nonlocal_lows, nonlocal_highs, monte_carlo_count, max_frequency = input + sample_dipoles = discretisation.get_model().get_n_single_dipoles(monte_carlo_count, max_frequency) + local_vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles) + local_matches = pdme.util.fast_v_calc.between(local_vals, local_lows, local_highs) + nonlocal_vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal(pair_inputs, sample_dipoles) + nonlocal_matches = pdme.util.fast_v_calc.between(nonlocal_vals, nonlocal_lows, nonlocal_highs) + combined_matches = numpy.logical_and(local_matches, nonlocal_matches) + return numpy.count_nonzero(combined_matches) + + class AltBayesRun(): ''' A single Bayes run for a given set of dots. @@ -43,9 +57,15 @@ class AltBayesRun(): run_count: int The number of runs to do. ''' - def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, low_error: float = 0.9, high_error: float = 1.1, monte_carlo_count: int = 10000, monte_carlo_cycles: int = 10, max_frequency: float = 20, end_threshold: float = None, chunksize: int = CHUNKSIZE) -> None: - self.dot_inputs = dot_inputs - self.dot_inputs_array = pdme.measurement.oscillating_dipole.dot_inputs_to_array(dot_inputs) + def __init__(self, dot_positions: Sequence[numpy.typing.ArrayLike], frequency_range: Sequence[float], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int = 100, low_error: float = 0.9, high_error: float = 1.1, pairs_high_error = None, pairs_low_error = None, monte_carlo_count: int = 10000, monte_carlo_cycles: int = 10, max_frequency: float = 20, end_threshold: float = None, chunksize: int = CHUNKSIZE, use_pairs: bool = False) -> None: + self.dot_inputs = pdme.inputs.inputs_with_frequency_range(dot_positions, frequency_range) + self.dot_inputs_array = pdme.measurement.input_types.dot_inputs_to_array(self.dot_inputs) + + self.use_pairs = use_pairs + + self.dot_pair_inputs = pdme.inputs.input_pairs_with_frequency_range(dot_positions, frequency_range) + self.dot_pair_inputs_array = pdme.measurement.input_types.dot_pair_inputs_to_array(self.dot_pair_inputs) + self.discretisations = [disc for (_, disc) in discretisations_with_names] self.model_names = [name for (name, _) in discretisations_with_names] self.actual_model = actual_model @@ -55,6 +75,14 @@ class AltBayesRun(): self.run_count = run_count self.low_error = low_error self.high_error = high_error + if pairs_low_error is None: + self.pairs_low_error = self.low_error + else: + self.pairs_low_error = pairs_low_error + if pairs_high_error is None: + self.pairs_high_error = self.high_error + else: + self.pairs_high_error = pairs_high_error self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"] self.compensate_zeros = True self.chunksize = chunksize @@ -64,7 +92,10 @@ class AltBayesRun(): self.probabilities = [1 / self.model_count] * self.model_count timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") - self.filename = f"{timestamp}-{filename_slug}.altbayes.csv" + if self.use_pairs: + self.filename = f"{timestamp}-{filename_slug}.altbayes.pairs.csv" + else: + self.filename = f"{timestamp}-{filename_slug}.altbayes.csv" self.max_frequency = max_frequency if end_threshold is not None: @@ -89,7 +120,13 @@ class AltBayesRun(): actual_dipoles = self.actual_model.get_dipoles(frequency) dots = actual_dipoles.get_percent_range_dot_measurements(self.dot_inputs, self.low_error, self.high_error) - lows, highs = pdme.measurement.oscillating_dipole.dot_range_measurements_low_high_arrays(dots) + lows, highs = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(dots) + + pair_lows, pair_highs = (None, None) + if self.use_pairs: + pair_measurements = actual_dipoles.get_percent_range_dot_pair_measurements(self.dot_pair_inputs, self.pairs_low_error, self.pairs_high_error) + pair_lows, pair_highs = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(pair_measurements) + _logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}") results = [] @@ -97,9 +134,14 @@ class AltBayesRun(): for disc_count, discretisation in enumerate(self.discretisations): _logger.debug(f"Doing discretisation #{disc_count}") with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool: - results.append(sum( - pool.imap_unordered(get_a_result, [(discretisation, self.dot_inputs_array, lows, highs, self.monte_carlo_count, self.max_frequency)] * self.monte_carlo_cycles, self.chunksize) - )) + if self.use_pairs: + results.append(sum( + pool.imap_unordered(get_a_result_using_pairs, [(discretisation, self.dot_inputs_array, self.dot_pair_inputs_array, lows, highs, pair_lows, pair_highs, self.monte_carlo_count, self.max_frequency)] * self.monte_carlo_cycles, self.chunksize) + )) + else: + results.append(sum( + pool.imap_unordered(get_a_result, [(discretisation, self.dot_inputs_array, lows, highs, self.monte_carlo_count, self.max_frequency)] * self.monte_carlo_cycles, self.chunksize) + )) _logger.debug("Done, constructing output now") row = { diff --git a/poetry.lock b/poetry.lock index b14e346..018dcf7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -304,7 +304,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" [[package]] name = "pdme" -version = "0.5.4" +version = "0.6.0" description = "Python dipole model evaluator" category = "main" optional = false @@ -312,7 +312,7 @@ python-versions = ">=3.8,<3.10" [package.dependencies] numpy = ">=1.21.1,<2.0.0" -scipy = ">=1.5,<1.6" +scipy = ">=1.8,<1.9" [[package]] name = "pkginfo" @@ -532,14 +532,14 @@ idna2008 = ["idna"] [[package]] name = "scipy" -version = "1.5.4" +version = "1.8.0" description = "SciPy: Scientific Library for Python" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8,<3.11" [package.dependencies] -numpy = ">=1.14.5" +numpy = ">=1.17.3,<1.25.0" [[package]] name = "secretstorage" @@ -697,7 +697,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.8,<3.10" -content-hash = "ba794e6f69d42e44e2b1abc40731fd78d2f417cdc9be12d6e752dbcfa95adaad" +content-hash = "b70066bcf7d66ae8156df30babbe85c6e60b914da41ed5d7af1054c24897051d" [metadata.files] atomicwrites = [ @@ -952,8 +952,8 @@ packaging = [ {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] pdme = [ - {file = "pdme-0.5.4-py3-none-any.whl", hash = "sha256:90ba75efbec04f5a505c9c5228824538510149e218372e758c15150d5967d92b"}, - {file = "pdme-0.5.4.tar.gz", hash = "sha256:82bff2ccc8f38996c23b43ab7d7dc80d87a6b340492f368861e7748105b50174"}, + {file = "pdme-0.6.0-py3-none-any.whl", hash = "sha256:eb06864332768e96e3729cb9c08231f46994c928246d3148782537ace02a50f1"}, + {file = "pdme-0.6.0.tar.gz", hash = "sha256:b2955c3ef4951e168d5298cee2e04d57b7bd069454c0815d7e8997f63a44c0c4"}, ] pkginfo = [ {file = "pkginfo-1.8.2-py2.py3-none-any.whl", hash = "sha256:c24c487c6a7f72c66e816ab1796b96ac6c3d14d49338293d2141664330b55ffc"}, @@ -1024,31 +1024,29 @@ rfc3986 = [ {file = "rfc3986-2.0.0.tar.gz", hash = "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c"}, ] scipy = [ - {file = "scipy-1.5.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4f12d13ffbc16e988fa40809cbbd7a8b45bc05ff6ea0ba8e3e41f6f4db3a9e47"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:a254b98dbcc744c723a838c03b74a8a34c0558c9ac5c86d5561703362231107d"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:368c0f69f93186309e1b4beb8e26d51dd6f5010b79264c0f1e9ca00cd92ea8c9"}, - {file = "scipy-1.5.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:4598cf03136067000855d6b44d7a1f4f46994164bcd450fb2c3d481afc25dd06"}, - {file = "scipy-1.5.4-cp36-cp36m-win32.whl", hash = "sha256:e98d49a5717369d8241d6cf33ecb0ca72deee392414118198a8e5b4c35c56340"}, - {file = "scipy-1.5.4-cp36-cp36m-win_amd64.whl", hash = "sha256:65923bc3809524e46fb7eb4d6346552cbb6a1ffc41be748535aa502a2e3d3389"}, - {file = "scipy-1.5.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9ad4fcddcbf5dc67619379782e6aeef41218a79e17979aaed01ed099876c0e62"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:25b241034215247481f53355e05f9e25462682b13bd9191359075682adcd9554"}, - {file = "scipy-1.5.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c"}, - {file = "scipy-1.5.4-cp37-cp37m-win32.whl", hash = "sha256:d6d25c41a009e3c6b7e757338948d0076ee1dd1770d1c09ec131f11946883c54"}, - {file = "scipy-1.5.4-cp37-cp37m-win_amd64.whl", hash = "sha256:2c872de0c69ed20fb1a9b9cf6f77298b04a26f0b8720a5457be08be254366c6e"}, - {file = "scipy-1.5.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e360cb2299028d0b0d0f65a5c5e51fc16a335f1603aa2357c25766c8dab56938"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3397c129b479846d7eaa18f999369a24322d008fac0782e7828fa567358c36ce"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:168c45c0c32e23f613db7c9e4e780bc61982d71dcd406ead746c7c7c2f2004ce"}, - {file = "scipy-1.5.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:213bc59191da2f479984ad4ec39406bf949a99aba70e9237b916ce7547b6ef42"}, - {file = "scipy-1.5.4-cp38-cp38-win32.whl", hash = "sha256:634568a3018bc16a83cda28d4f7aed0d803dd5618facb36e977e53b2df868443"}, - {file = "scipy-1.5.4-cp38-cp38-win_amd64.whl", hash = "sha256:b03c4338d6d3d299e8ca494194c0ae4f611548da59e3c038813f1a43976cb437"}, - {file = "scipy-1.5.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3d5db5d815370c28d938cf9b0809dade4acf7aba57eaf7ef733bfedc9b2474c4"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:6b0ceb23560f46dd236a8ad4378fc40bad1783e997604ba845e131d6c680963e"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660"}, - {file = "scipy-1.5.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:8c8d6ca19c8497344b810b0b0344f8375af5f6bb9c98bd42e33f747417ab3f57"}, - {file = "scipy-1.5.4-cp39-cp39-win32.whl", hash = "sha256:d84cadd7d7998433334c99fa55bcba0d8b4aeff0edb123b2a1dfcface538e474"}, - {file = "scipy-1.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:cc1f78ebc982cd0602c9a7615d878396bec94908db67d4ecddca864d049112f2"}, - {file = "scipy-1.5.4.tar.gz", hash = "sha256:4a453d5e5689de62e5d38edf40af3f17560bfd63c9c5bd228c18c1f99afa155b"}, + {file = "scipy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:87b01c7d5761e8a266a0fbdb9d88dcba0910d63c1c671bdb4d99d29f469e9e03"}, + {file = "scipy-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ae3e327da323d82e918e593460e23babdce40d7ab21490ddf9fc06dec6b91a18"}, + {file = "scipy-1.8.0-cp310-cp310-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:16e09ef68b352d73befa8bcaf3ebe25d3941fe1a58c82909d5589856e6bc8174"}, + {file = "scipy-1.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c17a1878d00a5dd2797ccd73623ceca9d02375328f6218ee6d921e1325e61aff"}, + {file = "scipy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:937d28722f13302febde29847bbe554b89073fbb924a30475e5ed7b028898b5f"}, + {file = "scipy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:8f4d059a97b29c91afad46b1737274cb282357a305a80bdd9e8adf3b0ca6a3f0"}, + {file = "scipy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:38aa39b6724cb65271e469013aeb6f2ce66fd44f093e241c28a9c6bc64fd79ed"}, + {file = "scipy-1.8.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:559a8a4c03a5ba9fe3232f39ed24f86457e4f3f6c0abbeae1fb945029f092720"}, + {file = "scipy-1.8.0-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:f4a6d3b9f9797eb2d43938ac2c5d96d02aed17ef170c8b38f11798717523ddba"}, + {file = "scipy-1.8.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:92b2c2af4183ed09afb595709a8ef5783b2baf7f41e26ece24e1329c109691a7"}, + {file = "scipy-1.8.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a279e27c7f4566ef18bab1b1e2c37d168e365080974758d107e7d237d3f0f484"}, + {file = "scipy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad5be4039147c808e64f99c0e8a9641eb5d2fa079ff5894dcd8240e94e347af4"}, + {file = "scipy-1.8.0-cp38-cp38-win32.whl", hash = "sha256:3d9dd6c8b93a22bf9a3a52d1327aca7e092b1299fb3afc4f89e8eba381be7b59"}, + {file = "scipy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:5e73343c5e0d413c1f937302b2e04fb07872f5843041bcfd50699aef6e95e399"}, + {file = "scipy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:de2e80ee1d925984c2504812a310841c241791c5279352be4707cdcd7c255039"}, + {file = "scipy-1.8.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:c2bae431d127bf0b1da81fc24e4bba0a84d058e3a96b9dd6475dfcb3c5e8761e"}, + {file = "scipy-1.8.0-cp39-cp39-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:723b9f878095ed994756fa4ee3060c450e2db0139c5ba248ee3f9628bd64e735"}, + {file = "scipy-1.8.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:011d4386b53b933142f58a652aa0f149c9b9242abd4f900b9f4ea5fbafc86b89"}, + {file = "scipy-1.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6f0cd9c0bd374ef834ee1e0f0999678d49dcc400ea6209113d81528958f97c7"}, + {file = "scipy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3720d0124aced49f6f2198a6900304411dbbeed12f56951d7c66ebef05e3df6"}, + {file = "scipy-1.8.0-cp39-cp39-win32.whl", hash = "sha256:3d573228c10a3a8c32b9037be982e6440e411b443a6267b067cac72f690b8d56"}, + {file = "scipy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:bb7088e89cd751acf66195d2f00cf009a1ea113f3019664032d9075b1e727b6c"}, + {file = "scipy-1.8.0.tar.gz", hash = "sha256:31d4f2d6b724bc9a98e527b5849b8a7e589bf1ea630c33aa563eda912c9ff0bd"}, ] secretstorage = [ {file = "SecretStorage-3.3.1-py3-none-any.whl", hash = "sha256:422d82c36172d88d6a0ed5afdec956514b189ddbfb72fefab0c8a1cee4eaf71f"}, diff --git a/pyproject.toml b/pyproject.toml index a9f4e60..47269d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Deepak Mallubhotla "] [tool.poetry.dependencies] python = "^3.8,<3.10" -pdme = "^0.5.4" +pdme = "^0.6.0" [tool.poetry.dev-dependencies] pytest = ">=6"