Compare commits
1 Commits
910f19d42f
...
82aa9d9ddf
Author | SHA1 | Date | |
---|---|---|---|
82aa9d9ddf |
25
CHANGELOG.md
25
CHANGELOG.md
@ -2,31 +2,6 @@
|
|||||||
|
|
||||||
All notable changes to this project will be documented in this file. See [standard-version](https://github.com/conventional-changelog/standard-version) for commit guidelines.
|
All notable changes to this project will be documented in this file. See [standard-version](https://github.com/conventional-changelog/standard-version) for commit guidelines.
|
||||||
|
|
||||||
### [0.9.2](https://gitea.deepak.science:2222/physics/pdme/compare/0.9.1...0.9.2) (2023-07-24)
|
|
||||||
|
|
||||||
|
|
||||||
### Bug Fixes
|
|
||||||
|
|
||||||
* update tests but for git also don't wrap costs ([50f98ed](https://gitea.deepak.science:2222/physics/pdme/commit/50f98ed89b2a05cd47c41958036dd50bc872e07c))
|
|
||||||
|
|
||||||
### [0.9.1](https://gitea.deepak.science:2222/physics/pdme/compare/0.9.0...0.9.1) (2023-07-24)
|
|
||||||
|
|
||||||
|
|
||||||
### Bug Fixes
|
|
||||||
|
|
||||||
* fixes some of the shape mangling of our mcmc code ([e01d0e1](https://gitea.deepak.science:2222/physics/pdme/commit/e01d0e14a9bcd6d7e8fe9449ce562dbf1b8fd25c))
|
|
||||||
|
|
||||||
## [0.9.0](https://gitea.deepak.science:2222/physics/pdme/compare/0.8.9...0.9.0) (2023-07-24)
|
|
||||||
|
|
||||||
|
|
||||||
### ⚠ BREAKING CHANGES
|
|
||||||
|
|
||||||
* separates threshold cost and the seed_cost in mcmc
|
|
||||||
|
|
||||||
### Features
|
|
||||||
|
|
||||||
* separates threshold cost and the seed_cost in mcmc ([ca710e3](https://gitea.deepak.science:2222/physics/pdme/commit/ca710e359fd0cfbb620a3574a2fa4fab1be2b52a))
|
|
||||||
|
|
||||||
### [0.8.9](https://gitea.deepak.science:2222/physics/pdme/compare/0.8.8...0.8.9) (2023-07-23)
|
### [0.8.9](https://gitea.deepak.science:2222/physics/pdme/compare/0.8.8...0.8.9) (2023-07-23)
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,14 +129,12 @@ class LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel(
|
|||||||
|
|
||||||
p_mask = rng.binomial(1, self.prob_occupancy, shape)
|
p_mask = rng.binomial(1, self.prob_occupancy, shape)
|
||||||
|
|
||||||
# dipoles = numpy.einsum("ij,k->ijk", p_mask, self.moment_fixed)
|
dipoles = numpy.einsum("ij,k->ijk", p_mask, self.moment_fixed)
|
||||||
# Is there a better way to create the final array? probably! can create a flatter guy then reshape.
|
# Is there a better way to create the final array? probably! can create a flatter guy then reshape.
|
||||||
# this is easier to reason about.
|
# this is easier to reason about.
|
||||||
p_magnitude = self.pfixed * p_mask
|
px = dipoles[:, :, 0]
|
||||||
|
py = dipoles[:, :, 1]
|
||||||
px = p_magnitude * numpy.sin(self.thetafixed) * numpy.cos(self.phifixed)
|
pz = dipoles[:, :, 2]
|
||||||
py = p_magnitude * numpy.sin(self.thetafixed) * numpy.sin(self.phifixed)
|
|
||||||
pz = p_magnitude * numpy.cos(self.thetafixed)
|
|
||||||
|
|
||||||
sx = rng.uniform(self.xmin, self.xmax, shape)
|
sx = rng.uniform(self.xmin, self.xmax, shape)
|
||||||
sy = rng.uniform(self.ymin, self.ymax, shape)
|
sy = rng.uniform(self.ymin, self.ymax, shape)
|
||||||
|
@ -47,7 +47,6 @@ class DipoleModel:
|
|||||||
seed,
|
seed,
|
||||||
cost_function,
|
cost_function,
|
||||||
chain_length,
|
chain_length,
|
||||||
threshold_cost: float,
|
|
||||||
stdevs: pdme.subspace_simulation.MCMCStandardDeviation,
|
stdevs: pdme.subspace_simulation.MCMCStandardDeviation,
|
||||||
initial_cost: Optional[float] = None,
|
initial_cost: Optional[float] = None,
|
||||||
rng_arg: Optional[numpy.random.Generator] = None,
|
rng_arg: Optional[numpy.random.Generator] = None,
|
||||||
@ -72,16 +71,15 @@ class DipoleModel:
|
|||||||
f"Starting Markov Chain Monte Carlo with seed: {seed} for chain length {chain_length} and provided stdevs {stdevs}"
|
f"Starting Markov Chain Monte Carlo with seed: {seed} for chain length {chain_length} and provided stdevs {stdevs}"
|
||||||
)
|
)
|
||||||
chain: List[Tuple[float, numpy.ndarray]] = []
|
chain: List[Tuple[float, numpy.ndarray]] = []
|
||||||
if initial_cost is None:
|
|
||||||
current_cost = cost_function(numpy.array([seed]))
|
|
||||||
else:
|
|
||||||
current_cost = initial_cost
|
|
||||||
current = seed
|
current = seed
|
||||||
|
if initial_cost is None:
|
||||||
|
cost_to_compare = cost_function(current)
|
||||||
|
else:
|
||||||
|
cost_to_compare = initial_cost
|
||||||
|
current_cost = cost_to_compare
|
||||||
for i in range(chain_length):
|
for i in range(chain_length):
|
||||||
dips = []
|
dips = []
|
||||||
for dipole_index, dipole in enumerate(current):
|
for dipole_index, dipole in enumerate(current):
|
||||||
_logger.debug(dipole_index)
|
|
||||||
_logger.debug(dipole)
|
|
||||||
stdev = stdevs[dipole_index]
|
stdev = stdevs[dipole_index]
|
||||||
tentative_dip = self.markov_chain_monte_carlo_proposal(
|
tentative_dip = self.markov_chain_monte_carlo_proposal(
|
||||||
dipole, stdev, rng_arg
|
dipole, stdev, rng_arg
|
||||||
@ -91,11 +89,11 @@ class DipoleModel:
|
|||||||
dips_array = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
|
dips_array = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
|
||||||
dips
|
dips
|
||||||
)
|
)
|
||||||
tentative_cost = cost_function(numpy.array([dips_array]))[0]
|
tentative_cost = cost_function(dips_array)
|
||||||
if tentative_cost < threshold_cost:
|
if tentative_cost < cost_to_compare:
|
||||||
chain.append((numpy.squeeze(tentative_cost).item(), dips_array))
|
chain.append((tentative_cost, dips_array))
|
||||||
current = dips_array
|
current = dips_array
|
||||||
current_cost = tentative_cost
|
current_cost = tentative_cost
|
||||||
else:
|
else:
|
||||||
chain.append((numpy.squeeze(current_cost).item(), current))
|
chain.append((current_cost, current))
|
||||||
return chain
|
return chain
|
||||||
|
@ -15,6 +15,6 @@ def proportional_costs_vs_actual_measurement(
|
|||||||
dipoles_to_test: numpy.ndarray,
|
dipoles_to_test: numpy.ndarray,
|
||||||
) -> numpy.ndarray:
|
) -> numpy.ndarray:
|
||||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||||
dot_inputs_array, dipoles_to_test
|
dot_inputs_array, numpy.array([dipoles_to_test])
|
||||||
)
|
)
|
||||||
return proportional_cost(actual_measurement_array, vals)
|
return proportional_cost(actual_measurement_array, vals)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "pdme"
|
name = "pdme"
|
||||||
version = "0.9.2"
|
version = "0.8.9"
|
||||||
description = "Python dipole model evaluator"
|
description = "Python dipole model evaluator"
|
||||||
authors = ["Deepak <dmallubhotla+github@gmail.com>"]
|
authors = ["Deepak <dmallubhotla+github@gmail.com>"]
|
||||||
license = "GPL-3.0-only"
|
license = "GPL-3.0-only"
|
||||||
@ -15,7 +15,7 @@ scipy = "~1.10"
|
|||||||
pytest = ">=6"
|
pytest = ">=6"
|
||||||
flake8 = "^4.0.0"
|
flake8 = "^4.0.0"
|
||||||
pytest-cov = "^4.1.0"
|
pytest-cov = "^4.1.0"
|
||||||
mypy = "^1.5"
|
mypy = "^1.4"
|
||||||
ipython = "^8.2.0"
|
ipython = "^8.2.0"
|
||||||
black = "^22.3.0"
|
black = "^22.3.0"
|
||||||
syrupy = "^4.0.8"
|
syrupy = "^4.0.8"
|
||||||
@ -28,8 +28,6 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
addopts = "--junitxml pytest.xml --cov pdme --cov-report=xml:coverage.xml --cov-fail-under=50 --cov-report=html"
|
addopts = "--junitxml pytest.xml --cov pdme --cov-report=xml:coverage.xml --cov-fail-under=50 --cov-report=html"
|
||||||
junit_family = "xunit1"
|
junit_family = "xunit1"
|
||||||
log_format = "%(asctime)s | %(levelname)s | %(pathname)s:%(lineno)d | %(message)s"
|
|
||||||
log_level = "WARNING"
|
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
plugins = "numpy.typing.mypy_plugin"
|
plugins = "numpy.typing.mypy_plugin"
|
||||||
|
@ -2,52 +2,52 @@
|
|||||||
# name: test_log_spaced_fixedxy_orientation_mcmc_basic
|
# name: test_log_spaced_fixedxy_orientation_mcmc_basic
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
3984.461796565,
|
array([3984.46179656]),
|
||||||
array([[ 9.55610128, 2.94634152, 0. , 9.21529051, -2.46576127,
|
array([[ 9.55610128, 2.94634152, 0. , 9.21529051, -2.46576127,
|
||||||
2.42481096, 9.19034554]]),
|
2.42481096, 9.19034554]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
8583.9908787152,
|
array([8583.99087872]),
|
||||||
array([[ 9.99991539, 0.04113671, 0. , 8.71258954, -2.26599865,
|
array([[ 9.99991539, 0.04113671, 0. , 8.71258954, -2.26599865,
|
||||||
2.60452102, 6.37042214]]),
|
2.60452102, 6.37042214]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
6215.6376616016,
|
array([6215.6376616]),
|
||||||
array([[ 9.81950685, -1.89137124, 0. , 8.90637055, -2.48043039,
|
array([[ 9.81950685, -1.89137124, 0. , 8.90637055, -2.48043039,
|
||||||
2.28444435, 8.84239221]]),
|
2.28444435, 8.84239221]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
424.7332846598,
|
array([424.73328466]),
|
||||||
array([[ 1.00028483, 9.94984574, 0. , 8.53064898, -2.59230757,
|
array([[ 1.00028483, 9.94984574, 0. , 8.53064898, -2.59230757,
|
||||||
2.33774773, 8.6714416 ]]),
|
2.33774773, 8.6714416 ]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
300.9220380849,
|
array([300.92203808]),
|
||||||
array([[ 1.4003442 , 9.90146636, 0. , 8.05557992, -2.6753126 ,
|
array([[ 1.4003442 , 9.90146636, 0. , 8.05557992, -2.6753126 ,
|
||||||
2.65915755, 13.02021385]]),
|
2.65915755, 13.02021385]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
2400.0107277085,
|
array([2400.01072771]),
|
||||||
array([[ 9.97761813, 0.66868263, 0. , 8.69171028, -2.73145011,
|
array([[ 9.97761813, 0.66868263, 0. , 8.69171028, -2.73145011,
|
||||||
2.90140456, 19.94999593]]),
|
2.90140456, 19.94999593]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
5001.4620511303,
|
array([5001.46205113]),
|
||||||
array([[ 9.93976109, -1.09596962, 0. , 8.95245025, -2.59409162,
|
array([[ 9.93976109, -1.09596962, 0. , 8.95245025, -2.59409162,
|
||||||
2.90140456, 9.75535945]]),
|
2.90140456, 9.75535945]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
195.2198074488,
|
array([195.21980745]),
|
||||||
array([[ 0.20690762, 9.99785923, 0. , 9.59636585, -2.83240984,
|
array([[ 0.20690762, 9.99785923, 0. , 9.59636585, -2.83240984,
|
||||||
2.90140456, 16.14771567]]),
|
2.90140456, 16.14771567]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
2698.258844498,
|
array([2698.2588445]),
|
||||||
array([[-9.68130127, -2.50447712, 0. , 8.94823619, -2.92889659,
|
array([[-9.68130127, -2.50447712, 0. , 8.94823619, -2.92889659,
|
||||||
2.77065328, 13.63173263]]),
|
2.77065328, 13.63173263]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
1193.6985473944,
|
array([1193.69854739]),
|
||||||
array([[-6.16597091, -7.87278875, 0. , 9.62210721, -2.75993924,
|
array([[-6.16597091, -7.87278875, 0. , 9.62210721, -2.75993924,
|
||||||
2.77065328, 5.64553534]]),
|
2.77065328, 5.64553534]]),
|
||||||
),
|
),
|
||||||
|
@ -2,52 +2,52 @@
|
|||||||
# name: test_log_spaced_free_orientation_mcmc_basic
|
# name: test_log_spaced_free_orientation_mcmc_basic
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
3167.6711268743,
|
array([3167.67112687]),
|
||||||
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
|
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
|
||||||
2.11809123, 16.17452242]]),
|
2.11809123, 16.17452242]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
3167.6711268743,
|
array([3167.67112687]),
|
||||||
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
|
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
|
||||||
2.11809123, 16.17452242]]),
|
2.11809123, 16.17452242]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
3167.6711268743,
|
array([3167.67112687]),
|
||||||
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
|
array([[ 9.60483896, -1.41627817, -2.3960853 , -4.76615152, -1.80902942,
|
||||||
2.11809123, 16.17452242]]),
|
2.11809123, 16.17452242]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
736.0306527136,
|
array([736.03065271]),
|
||||||
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
|
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
|
||||||
2.43615641, 12.92143144]]),
|
2.43615641, 12.92143144]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
736.0306527136,
|
array([736.03065271]),
|
||||||
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
|
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
|
||||||
2.43615641, 12.92143144]]),
|
2.43615641, 12.92143144]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
736.0306527136,
|
array([736.03065271]),
|
||||||
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
|
array([[ 4.1660069 , -8.11557337, 4.0965663 , -4.35968351, -1.97945216,
|
||||||
2.43615641, 12.92143144]]),
|
2.43615641, 12.92143144]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
2248.0779986277,
|
array([2248.07799863]),
|
||||||
array([[-1.71755535, -5.59925137, 8.10545419, -4.03306318, -1.81098441,
|
array([[-1.71755535, -5.59925137, 8.10545419, -4.03306318, -1.81098441,
|
||||||
2.77407111, 32.28020575]]),
|
2.77407111, 32.28020575]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
1663.310672736,
|
array([1663.31067274]),
|
||||||
array([[-5.16785855, 2.7558756 , 8.10545419, -3.34620897, -1.74763642,
|
array([[-5.16785855, 2.7558756 , 8.10545419, -3.34620897, -1.74763642,
|
||||||
2.42770463, 52.98214008]]),
|
2.42770463, 52.98214008]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
1329.2704143918,
|
array([1329.27041439]),
|
||||||
array([[ -1.39600464, 9.69718343, -2.00394725, -2.59147366,
|
array([[ -1.39600464, 9.69718343, -2.00394725, -2.59147366,
|
||||||
-1.91246681, 2.07361175, 123.01833742]]),
|
-1.91246681, 2.07361175, 123.01833742]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
355.7695591897,
|
array([355.76955919]),
|
||||||
array([[ 9.76047401, 0.84696075, -2.00394725, -3.04310053,
|
array([[ 9.76047401, 0.84696075, -2.00394725, -3.04310053,
|
||||||
-1.99338573, 2.1185589 , 271.35743739]]),
|
-1.99338573, 2.1185589 , 271.35743739]]),
|
||||||
),
|
),
|
||||||
|
@ -2,52 +2,52 @@
|
|||||||
# name: test_log_spaced_fixed_orientation_mcmc_basic
|
# name: test_log_spaced_fixed_orientation_mcmc_basic
|
||||||
list([
|
list([
|
||||||
tuple(
|
tuple(
|
||||||
50.5683119299,
|
array([50.56831193]),
|
||||||
array([[ 0. , 0. , 10. , -2.3960853 , 4.23246234,
|
array([[ 0. , 0. , 10. , -2.3960853 , 4.23246234,
|
||||||
2.26169242, 39.39900844]]),
|
2.26169242, 39.39900844]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
50.5683119299,
|
array([50.56831193]),
|
||||||
array([[ 0. , 0. , 10. , -2.3960853 , 4.23246234,
|
array([[ 0. , 0. , 10. , -2.3960853 , 4.23246234,
|
||||||
2.26169242, 39.39900844]]),
|
2.26169242, 39.39900844]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
47.408654554,
|
array([47.40865455]),
|
||||||
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
||||||
2.21309317, 47.82371559]]),
|
2.21309317, 47.82371559]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
47.408654554,
|
array([47.40865455]),
|
||||||
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
||||||
2.21309317, 47.82371559]]),
|
2.21309317, 47.82371559]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
47.408654554,
|
array([47.40865455]),
|
||||||
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
||||||
2.21309317, 47.82371559]]),
|
2.21309317, 47.82371559]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
47.408654554,
|
array([47.40865455]),
|
||||||
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
array([[ 0. , 0. , 10. , -2.03666518, 4.14084039,
|
||||||
2.21309317, 47.82371559]]),
|
2.21309317, 47.82371559]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
22.9327902847,
|
array([22.93279028]),
|
||||||
array([[ 0. , 0. , 10. , -1.63019717, 3.97041764,
|
array([[ 0. , 0. , 10. , -1.63019717, 3.97041764,
|
||||||
2.53115835, 38.2051999 ]]),
|
2.53115835, 38.2051999 ]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
28.8119773322,
|
array([28.81197733]),
|
||||||
array([[ 0. , 0. , 10. , -1.14570315, 4.07709911,
|
array([[ 0. , 0. , 10. , -1.14570315, 4.07709911,
|
||||||
2.48697441, 49.58615195]]),
|
2.48697441, 49.58615195]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
28.8119773322,
|
array([28.81197733]),
|
||||||
array([[ 0. , 0. , 10. , -1.14570315, 4.07709911,
|
array([[ 0. , 0. , 10. , -1.14570315, 4.07709911,
|
||||||
2.48697441, 49.58615195]]),
|
2.48697441, 49.58615195]]),
|
||||||
),
|
),
|
||||||
tuple(
|
tuple(
|
||||||
40.9740600543,
|
array([40.97406005]),
|
||||||
array([[ 0. , 0. , 10. , -0.50178755, 3.83878089,
|
array([[ 0. , 0. , 10. , -0.50178755, 3.83878089,
|
||||||
2.93560796, 82.07827571]]),
|
2.93560796, 82.07827571]]),
|
||||||
),
|
),
|
||||||
|
@ -5,9 +5,6 @@ import pdme.inputs
|
|||||||
import pdme.measurement.input_types
|
import pdme.measurement.input_types
|
||||||
import pdme.subspace_simulation
|
import pdme.subspace_simulation
|
||||||
import numpy
|
import numpy
|
||||||
import logging
|
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SEED_TO_USE = 42
|
SEED_TO_USE = 42
|
||||||
|
|
||||||
@ -46,9 +43,9 @@ def get_cost_function():
|
|||||||
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
||||||
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
||||||
|
|
||||||
def cost_to_use(sample_dipoleses: numpy.ndarray) -> numpy.ndarray:
|
def cost_to_use(sample_dipoles: numpy.ndarray) -> numpy.ndarray:
|
||||||
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
||||||
dot_input_array, actual_measurements_array, sample_dipoleses
|
dot_input_array, actual_measurements_array, sample_dipoles
|
||||||
)
|
)
|
||||||
|
|
||||||
return cost_to_use
|
return cost_to_use
|
||||||
@ -80,21 +77,14 @@ def test_log_spaced_fixedxy_orientation_mcmc_basic(snapshot):
|
|||||||
)
|
)
|
||||||
model.rng = numpy.random.default_rng(1234)
|
model.rng = numpy.random.default_rng(1234)
|
||||||
|
|
||||||
seed = model.get_monte_carlo_dipole_inputs(1, -1)
|
seed = model.get_monte_carlo_dipole_inputs(1, -1)[0]
|
||||||
|
|
||||||
cost_function = get_cost_function()
|
cost_function = get_cost_function()
|
||||||
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
||||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||||
|
|
||||||
chain = model.get_mcmc_chain(
|
chain = model.get_mcmc_chain(
|
||||||
seed[0],
|
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||||
cost_function,
|
|
||||||
10,
|
|
||||||
cost_function(seed)[0],
|
|
||||||
stdevs,
|
|
||||||
rng_arg=numpy.random.default_rng(1515),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_rounded = [(round(cost, 10), dipoles) for (cost, dipoles) in chain]
|
assert chain == snapshot
|
||||||
|
|
||||||
assert chain_rounded == snapshot
|
|
||||||
|
@ -5,9 +5,6 @@ import pdme.inputs
|
|||||||
import pdme.measurement.input_types
|
import pdme.measurement.input_types
|
||||||
import pdme.subspace_simulation
|
import pdme.subspace_simulation
|
||||||
import numpy
|
import numpy
|
||||||
import logging
|
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SEED_TO_USE = 42
|
SEED_TO_USE = 42
|
||||||
|
|
||||||
@ -46,9 +43,9 @@ def get_cost_function():
|
|||||||
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
||||||
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
||||||
|
|
||||||
def cost_to_use(sample_dipoleses: numpy.ndarray) -> numpy.ndarray:
|
def cost_to_use(sample_dipoles: numpy.ndarray) -> numpy.ndarray:
|
||||||
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
||||||
dot_input_array, actual_measurements_array, sample_dipoleses
|
dot_input_array, actual_measurements_array, sample_dipoles
|
||||||
)
|
)
|
||||||
|
|
||||||
return cost_to_use
|
return cost_to_use
|
||||||
@ -80,21 +77,14 @@ def test_log_spaced_free_orientation_mcmc_basic(snapshot):
|
|||||||
)
|
)
|
||||||
model.rng = numpy.random.default_rng(1234)
|
model.rng = numpy.random.default_rng(1234)
|
||||||
|
|
||||||
seed = model.get_monte_carlo_dipole_inputs(1, -1)
|
seed = model.get_monte_carlo_dipole_inputs(1, -1)[0]
|
||||||
|
|
||||||
cost_function = get_cost_function()
|
cost_function = get_cost_function()
|
||||||
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
||||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||||
|
|
||||||
chain = model.get_mcmc_chain(
|
chain = model.get_mcmc_chain(
|
||||||
seed[0],
|
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||||
cost_function,
|
|
||||||
10,
|
|
||||||
cost_function(seed)[0],
|
|
||||||
stdevs,
|
|
||||||
rng_arg=numpy.random.default_rng(1515),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_rounded = [(round(cost, 10), dipoles) for (cost, dipoles) in chain]
|
assert chain == snapshot
|
||||||
|
|
||||||
assert chain_rounded == snapshot
|
|
||||||
|
@ -5,9 +5,6 @@ import pdme.inputs
|
|||||||
import pdme.measurement.input_types
|
import pdme.measurement.input_types
|
||||||
import pdme.subspace_simulation
|
import pdme.subspace_simulation
|
||||||
import numpy
|
import numpy
|
||||||
import logging
|
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
SEED_TO_USE = 42
|
SEED_TO_USE = 42
|
||||||
|
|
||||||
@ -50,9 +47,9 @@ def get_cost_function():
|
|||||||
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
||||||
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
||||||
|
|
||||||
def cost_to_use(sample_dipoleses: numpy.ndarray) -> numpy.ndarray:
|
def cost_to_use(sample_dipoles: numpy.ndarray) -> numpy.ndarray:
|
||||||
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
||||||
dot_input_array, actual_measurements_array, sample_dipoleses
|
dot_input_array, actual_measurements_array, sample_dipoles
|
||||||
)
|
)
|
||||||
|
|
||||||
return cost_to_use
|
return cost_to_use
|
||||||
@ -88,21 +85,14 @@ def test_log_spaced_fixed_orientation_mcmc_basic(snapshot):
|
|||||||
)
|
)
|
||||||
model.rng = numpy.random.default_rng(1234)
|
model.rng = numpy.random.default_rng(1234)
|
||||||
|
|
||||||
seed = model.get_monte_carlo_dipole_inputs(1, -1)
|
seed = model.get_monte_carlo_dipole_inputs(1, -1)[0]
|
||||||
|
|
||||||
cost_function = get_cost_function()
|
cost_function = get_cost_function()
|
||||||
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
||||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||||
|
|
||||||
chain = model.get_mcmc_chain(
|
chain = model.get_mcmc_chain(
|
||||||
seed[0],
|
seed, cost_function, 10, stdevs, rng_arg=numpy.random.default_rng(1515)
|
||||||
cost_function,
|
|
||||||
10,
|
|
||||||
cost_function(seed)[0],
|
|
||||||
stdevs,
|
|
||||||
rng_arg=numpy.random.default_rng(1515),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_rounded = [(round(cost, 10), dipoles) for (cost, dipoles) in chain]
|
assert chain == snapshot
|
||||||
|
|
||||||
assert chain_rounded == snapshot
|
|
||||||
|
@ -116,6 +116,7 @@ def test_random_count_multiple_dipole_fixed_mag_model_get_dipoles_invariant():
|
|||||||
|
|
||||||
|
|
||||||
def test_random_count_multiple_dipole_fixed_or_fixed_mag_model_get_n_dipoles(snapshot):
|
def test_random_count_multiple_dipole_fixed_or_fixed_mag_model_get_n_dipoles(snapshot):
|
||||||
|
# TODO: this test is a bit garbage just calls things without testing.
|
||||||
x_min = -10
|
x_min = -10
|
||||||
x_max = 10
|
x_max = 10
|
||||||
y_min = -5
|
y_min = -5
|
||||||
|
Loading…
x
Reference in New Issue
Block a user