fix: fixes some of the shape mangling of our mcmc code
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
All checks were successful
gitea-physics/pdme/pipeline/head This commit looks good
This commit is contained in:
parent
dfaf8abed5
commit
e01d0e14a9
@ -129,12 +129,14 @@ class LogSpacedRandomCountMultipleDipoleFixedMagnitudeFixedOrientationModel(
|
|||||||
|
|
||||||
p_mask = rng.binomial(1, self.prob_occupancy, shape)
|
p_mask = rng.binomial(1, self.prob_occupancy, shape)
|
||||||
|
|
||||||
dipoles = numpy.einsum("ij,k->ijk", p_mask, self.moment_fixed)
|
# dipoles = numpy.einsum("ij,k->ijk", p_mask, self.moment_fixed)
|
||||||
# Is there a better way to create the final array? probably! can create a flatter guy then reshape.
|
# Is there a better way to create the final array? probably! can create a flatter guy then reshape.
|
||||||
# this is easier to reason about.
|
# this is easier to reason about.
|
||||||
px = dipoles[:, :, 0]
|
p_magnitude = self.pfixed * p_mask
|
||||||
py = dipoles[:, :, 1]
|
|
||||||
pz = dipoles[:, :, 2]
|
px = p_magnitude * numpy.sin(self.thetafixed) * numpy.cos(self.phifixed)
|
||||||
|
py = p_magnitude * numpy.sin(self.thetafixed) * numpy.sin(self.phifixed)
|
||||||
|
pz = p_magnitude * numpy.cos(self.thetafixed)
|
||||||
|
|
||||||
sx = rng.uniform(self.xmin, self.xmax, shape)
|
sx = rng.uniform(self.xmin, self.xmax, shape)
|
||||||
sy = rng.uniform(self.ymin, self.ymax, shape)
|
sy = rng.uniform(self.ymin, self.ymax, shape)
|
||||||
|
@ -72,24 +72,24 @@ class DipoleModel:
|
|||||||
f"Starting Markov Chain Monte Carlo with seed: {seed} for chain length {chain_length} and provided stdevs {stdevs}"
|
f"Starting Markov Chain Monte Carlo with seed: {seed} for chain length {chain_length} and provided stdevs {stdevs}"
|
||||||
)
|
)
|
||||||
chain: List[Tuple[float, numpy.ndarray]] = []
|
chain: List[Tuple[float, numpy.ndarray]] = []
|
||||||
current = seed
|
|
||||||
if initial_cost is None:
|
if initial_cost is None:
|
||||||
current_cost = cost_function(current)
|
current_cost = cost_function(numpy.array([seed]))
|
||||||
else:
|
else:
|
||||||
current_cost = initial_cost
|
current_cost = initial_cost
|
||||||
|
current = seed
|
||||||
for i in range(chain_length):
|
for i in range(chain_length):
|
||||||
dips = []
|
dips = []
|
||||||
for dipole_index, dipole in enumerate(current):
|
for dipole_index, dipole in enumerate(current):
|
||||||
|
_logger.debug(dipole_index)
|
||||||
|
_logger.debug(dipole)
|
||||||
stdev = stdevs[dipole_index]
|
stdev = stdevs[dipole_index]
|
||||||
tentative_dip = self.markov_chain_monte_carlo_proposal(
|
tentative_dip = self.markov_chain_monte_carlo_proposal(
|
||||||
dipole, stdev, rng_arg
|
dipole, stdev, rng_arg
|
||||||
)
|
)
|
||||||
|
|
||||||
dips.append(tentative_dip)
|
dips.append(tentative_dip)
|
||||||
dips_array = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
|
dips_array = pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(dips)
|
||||||
dips
|
tentative_cost = cost_function(numpy.array([dips_array]))
|
||||||
)
|
|
||||||
tentative_cost = cost_function(dips_array)
|
|
||||||
if tentative_cost < threshold_cost:
|
if tentative_cost < threshold_cost:
|
||||||
chain.append((tentative_cost, dips_array))
|
chain.append((tentative_cost, dips_array))
|
||||||
current = dips_array
|
current = dips_array
|
||||||
|
@ -15,6 +15,6 @@ def proportional_costs_vs_actual_measurement(
|
|||||||
dipoles_to_test: numpy.ndarray,
|
dipoles_to_test: numpy.ndarray,
|
||||||
) -> numpy.ndarray:
|
) -> numpy.ndarray:
|
||||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
vals = pdme.util.fast_v_calc.fast_vs_for_dipoleses(
|
||||||
dot_inputs_array, numpy.array([dipoles_to_test])
|
dot_inputs_array, dipoles_to_test
|
||||||
)
|
)
|
||||||
return proportional_cost(actual_measurement_array, vals)
|
return proportional_cost(actual_measurement_array, vals)
|
||||||
|
@ -28,6 +28,8 @@ build-backend = "poetry.core.masonry.api"
|
|||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
addopts = "--junitxml pytest.xml --cov pdme --cov-report=xml:coverage.xml --cov-fail-under=50 --cov-report=html"
|
addopts = "--junitxml pytest.xml --cov pdme --cov-report=xml:coverage.xml --cov-fail-under=50 --cov-report=html"
|
||||||
junit_family = "xunit1"
|
junit_family = "xunit1"
|
||||||
|
log_format = "%(asctime)s | %(levelname)s | %(pathname)s:%(lineno)d | %(message)s"
|
||||||
|
log_level = "DEBUG"
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
plugins = "numpy.typing.mypy_plugin"
|
plugins = "numpy.typing.mypy_plugin"
|
||||||
|
@ -5,6 +5,9 @@ import pdme.inputs
|
|||||||
import pdme.measurement.input_types
|
import pdme.measurement.input_types
|
||||||
import pdme.subspace_simulation
|
import pdme.subspace_simulation
|
||||||
import numpy
|
import numpy
|
||||||
|
import logging
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SEED_TO_USE = 42
|
SEED_TO_USE = 42
|
||||||
|
|
||||||
@ -43,9 +46,9 @@ def get_cost_function():
|
|||||||
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
||||||
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
||||||
|
|
||||||
def cost_to_use(sample_dipoles: numpy.ndarray) -> numpy.ndarray:
|
def cost_to_use(sample_dipoleses: numpy.ndarray) -> numpy.ndarray:
|
||||||
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
||||||
dot_input_array, actual_measurements_array, sample_dipoles
|
dot_input_array, actual_measurements_array, sample_dipoleses
|
||||||
)
|
)
|
||||||
|
|
||||||
return cost_to_use
|
return cost_to_use
|
||||||
@ -77,14 +80,19 @@ def test_log_spaced_fixedxy_orientation_mcmc_basic(snapshot):
|
|||||||
)
|
)
|
||||||
model.rng = numpy.random.default_rng(1234)
|
model.rng = numpy.random.default_rng(1234)
|
||||||
|
|
||||||
seed = model.get_monte_carlo_dipole_inputs(1, -1)[0]
|
seed = model.get_monte_carlo_dipole_inputs(1, -1)
|
||||||
|
|
||||||
cost_function = get_cost_function()
|
cost_function = get_cost_function()
|
||||||
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
||||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||||
|
|
||||||
chain = model.get_mcmc_chain(
|
chain = model.get_mcmc_chain(
|
||||||
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
|
seed[0],
|
||||||
|
cost_function,
|
||||||
|
10,
|
||||||
|
cost_function(seed)[0],
|
||||||
|
stdevs,
|
||||||
|
rng_arg=numpy.random.default_rng(1515),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chain == snapshot
|
assert chain == snapshot
|
||||||
|
@ -5,6 +5,9 @@ import pdme.inputs
|
|||||||
import pdme.measurement.input_types
|
import pdme.measurement.input_types
|
||||||
import pdme.subspace_simulation
|
import pdme.subspace_simulation
|
||||||
import numpy
|
import numpy
|
||||||
|
import logging
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SEED_TO_USE = 42
|
SEED_TO_USE = 42
|
||||||
|
|
||||||
@ -43,9 +46,9 @@ def get_cost_function():
|
|||||||
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
||||||
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
||||||
|
|
||||||
def cost_to_use(sample_dipoles: numpy.ndarray) -> numpy.ndarray:
|
def cost_to_use(sample_dipoleses: numpy.ndarray) -> numpy.ndarray:
|
||||||
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
||||||
dot_input_array, actual_measurements_array, sample_dipoles
|
dot_input_array, actual_measurements_array, sample_dipoleses
|
||||||
)
|
)
|
||||||
|
|
||||||
return cost_to_use
|
return cost_to_use
|
||||||
@ -77,14 +80,19 @@ def test_log_spaced_free_orientation_mcmc_basic(snapshot):
|
|||||||
)
|
)
|
||||||
model.rng = numpy.random.default_rng(1234)
|
model.rng = numpy.random.default_rng(1234)
|
||||||
|
|
||||||
seed = model.get_monte_carlo_dipole_inputs(1, -1)[0]
|
seed = model.get_monte_carlo_dipole_inputs(1, -1)
|
||||||
|
|
||||||
cost_function = get_cost_function()
|
cost_function = get_cost_function()
|
||||||
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
||||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||||
|
|
||||||
chain = model.get_mcmc_chain(
|
chain = model.get_mcmc_chain(
|
||||||
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
|
seed[0],
|
||||||
|
cost_function,
|
||||||
|
10,
|
||||||
|
cost_function(seed)[0],
|
||||||
|
stdevs,
|
||||||
|
rng_arg=numpy.random.default_rng(1515),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chain == snapshot
|
assert chain == snapshot
|
||||||
|
@ -5,6 +5,9 @@ import pdme.inputs
|
|||||||
import pdme.measurement.input_types
|
import pdme.measurement.input_types
|
||||||
import pdme.subspace_simulation
|
import pdme.subspace_simulation
|
||||||
import numpy
|
import numpy
|
||||||
|
import logging
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SEED_TO_USE = 42
|
SEED_TO_USE = 42
|
||||||
|
|
||||||
@ -47,9 +50,9 @@ def get_cost_function():
|
|||||||
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
actual_measurements = actual_dipoles.get_dot_measurements(dot_inputs)
|
||||||
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
actual_measurements_array = numpy.array([m.v for m in actual_measurements])
|
||||||
|
|
||||||
def cost_to_use(sample_dipoles: numpy.ndarray) -> numpy.ndarray:
|
def cost_to_use(sample_dipoleses: numpy.ndarray) -> numpy.ndarray:
|
||||||
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
return pdme.subspace_simulation.proportional_costs_vs_actual_measurement(
|
||||||
dot_input_array, actual_measurements_array, sample_dipoles
|
dot_input_array, actual_measurements_array, sample_dipoleses
|
||||||
)
|
)
|
||||||
|
|
||||||
return cost_to_use
|
return cost_to_use
|
||||||
@ -85,14 +88,19 @@ def test_log_spaced_fixed_orientation_mcmc_basic(snapshot):
|
|||||||
)
|
)
|
||||||
model.rng = numpy.random.default_rng(1234)
|
model.rng = numpy.random.default_rng(1234)
|
||||||
|
|
||||||
seed = model.get_monte_carlo_dipole_inputs(1, -1)[0]
|
seed = model.get_monte_carlo_dipole_inputs(1, -1)
|
||||||
|
|
||||||
cost_function = get_cost_function()
|
cost_function = get_cost_function()
|
||||||
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
stdev = pdme.subspace_simulation.DipoleStandardDeviation(2, 2, 1, 0.25, 0.5, 1)
|
||||||
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
stdevs = pdme.subspace_simulation.MCMCStandardDeviation([stdev])
|
||||||
|
|
||||||
chain = model.get_mcmc_chain(
|
chain = model.get_mcmc_chain(
|
||||||
seed, cost_function, 10, cost_function(seed)[0], stdevs, rng_arg=numpy.random.default_rng(1515)
|
seed[0],
|
||||||
|
cost_function,
|
||||||
|
10,
|
||||||
|
cost_function(seed)[0],
|
||||||
|
stdevs,
|
||||||
|
rng_arg=numpy.random.default_rng(1515),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert chain == snapshot
|
assert chain == snapshot
|
||||||
|
@ -116,7 +116,6 @@ def test_random_count_multiple_dipole_fixed_mag_model_get_dipoles_invariant():
|
|||||||
|
|
||||||
|
|
||||||
def test_random_count_multiple_dipole_fixed_or_fixed_mag_model_get_n_dipoles(snapshot):
|
def test_random_count_multiple_dipole_fixed_or_fixed_mag_model_get_n_dipoles(snapshot):
|
||||||
# TODO: this test is a bit garbage just calls things without testing.
|
|
||||||
x_min = -10
|
x_min = -10
|
||||||
x_max = 10
|
x_max = 10
|
||||||
y_min = -5
|
y_min = -5
|
||||||
|
Loading…
x
Reference in New Issue
Block a user