fmt: auto format changes
All checks were successful
gitea-physics/deepdog/pipeline/head This commit looks good
All checks were successful
gitea-physics/deepdog/pipeline/head This commit looks good
This commit is contained in:
parent
1741807be4
commit
b10586bf55
@ -1,3 +1,6 @@
|
|||||||
from deepdog.direct_monte_carlo.direct_mc import DirectMonteCarloRun, DirectMonteCarloConfig
|
from deepdog.direct_monte_carlo.direct_mc import (
|
||||||
|
DirectMonteCarloRun,
|
||||||
|
DirectMonteCarloConfig,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["DirectMonteCarloRun", "DirectMonteCarloConfig"]
|
__all__ = ["DirectMonteCarloRun", "DirectMonteCarloConfig"]
|
||||||
|
@ -18,6 +18,7 @@ class DirectMonteCarloResult:
|
|||||||
monte_carlo_count: int
|
monte_carlo_count: int
|
||||||
likelihood: float
|
likelihood: float
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DirectMonteCarloConfig:
|
class DirectMonteCarloConfig:
|
||||||
monte_carlo_count_per_cycle: int = 10000
|
monte_carlo_count_per_cycle: int = 10000
|
||||||
@ -28,6 +29,7 @@ class DirectMonteCarloConfig:
|
|||||||
write_successes_to_file: bool = False
|
write_successes_to_file: bool = False
|
||||||
tag: str = ""
|
tag: str = ""
|
||||||
|
|
||||||
|
|
||||||
class DirectMonteCarloRun:
|
class DirectMonteCarloRun:
|
||||||
"""
|
"""
|
||||||
A single model Direct Monte Carlo run, currently implemented only using single threading.
|
A single model Direct Monte Carlo run, currently implemented only using single threading.
|
||||||
@ -82,6 +84,7 @@ class DirectMonteCarloRun:
|
|||||||
) = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(
|
) = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(
|
||||||
self.measurements
|
self.measurements
|
||||||
)
|
)
|
||||||
|
|
||||||
def _single_run(self, seed) -> numpy.ndarray:
|
def _single_run(self, seed) -> numpy.ndarray:
|
||||||
rng = numpy.random.default_rng(seed)
|
rng = numpy.random.default_rng(seed)
|
||||||
|
|
||||||
@ -98,7 +101,9 @@ class DirectMonteCarloRun:
|
|||||||
numpy.array([di]), current_sample
|
numpy.array([di]), current_sample
|
||||||
)
|
)
|
||||||
|
|
||||||
current_sample = current_sample[numpy.all((vals > low) & (vals < high), axis=1)]
|
current_sample = current_sample[
|
||||||
|
numpy.all((vals > low) & (vals < high), axis=1)
|
||||||
|
]
|
||||||
return current_sample
|
return current_sample
|
||||||
|
|
||||||
def execute(self) -> DirectMonteCarloResult:
|
def execute(self) -> DirectMonteCarloResult:
|
||||||
@ -106,23 +111,30 @@ class DirectMonteCarloRun:
|
|||||||
total_success = 0
|
total_success = 0
|
||||||
total_count = 0
|
total_count = 0
|
||||||
|
|
||||||
count_per_step = self.config.monte_carlo_count_per_cycle * self.config.monte_carlo_cycles
|
count_per_step = (
|
||||||
|
self.config.monte_carlo_count_per_cycle * self.config.monte_carlo_cycles
|
||||||
|
)
|
||||||
seed_sequence = numpy.random.SeedSequence(self.config.monte_carlo_seed)
|
seed_sequence = numpy.random.SeedSequence(self.config.monte_carlo_seed)
|
||||||
while (
|
while (step_count < self.config.max_monte_carlo_cycles_steps) and (
|
||||||
(step_count < self.config.max_monte_carlo_cycles_steps) and
|
total_success < self.config.target_success
|
||||||
(total_success < self.config.target_success)
|
|
||||||
):
|
):
|
||||||
_logger.debug(f"Executing step {step_count}")
|
_logger.debug(f"Executing step {step_count}")
|
||||||
for cycle_i, seed in enumerate(seed_sequence.spawn(self.config.monte_carlo_cycles)):
|
for cycle_i, seed in enumerate(
|
||||||
|
seed_sequence.spawn(self.config.monte_carlo_cycles)
|
||||||
|
):
|
||||||
cycle_success_configs = self._single_run(seed)
|
cycle_success_configs = self._single_run(seed)
|
||||||
cycle_success_count = len(cycle_success_configs)
|
cycle_success_count = len(cycle_success_configs)
|
||||||
if cycle_success_count > 0:
|
if cycle_success_count > 0:
|
||||||
_logger.debug(f"For cycle {cycle_i} received {cycle_success_count} successes")
|
_logger.debug(
|
||||||
|
f"For cycle {cycle_i} received {cycle_success_count} successes"
|
||||||
|
)
|
||||||
_logger.debug(cycle_success_configs)
|
_logger.debug(cycle_success_configs)
|
||||||
if self.config.write_successes_to_file:
|
if self.config.write_successes_to_file:
|
||||||
sorted_by_freq = numpy.array(
|
sorted_by_freq = numpy.array(
|
||||||
[
|
[
|
||||||
pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(dipole_config)
|
pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
|
||||||
|
dipole_config
|
||||||
|
)
|
||||||
for dipole_config in cycle_success_configs
|
for dipole_config in cycle_success_configs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -138,9 +150,8 @@ class DirectMonteCarloRun:
|
|||||||
step_count += 1
|
step_count += 1
|
||||||
total_count += count_per_step
|
total_count += count_per_step
|
||||||
|
|
||||||
|
|
||||||
return DirectMonteCarloResult(
|
return DirectMonteCarloResult(
|
||||||
successes=total_success,
|
successes=total_success,
|
||||||
monte_carlo_count=total_count,
|
monte_carlo_count=total_count,
|
||||||
likelihood=total_success/total_count
|
likelihood=total_success / total_count,
|
||||||
)
|
)
|
||||||
|
@ -101,11 +101,17 @@ class SubsetSimulation:
|
|||||||
# _logger.debug(sample_dipoles.shape)
|
# _logger.debug(sample_dipoles.shape)
|
||||||
|
|
||||||
raw_costs = []
|
raw_costs = []
|
||||||
_logger.debug(f"Using iterated cost function thing with chunk size {self.initial_cost_chunk_size}")
|
_logger.debug(
|
||||||
|
f"Using iterated cost function thing with chunk size {self.initial_cost_chunk_size}"
|
||||||
|
)
|
||||||
|
|
||||||
for x in range(0, len(sample_dipoles), self.initial_cost_chunk_size):
|
for x in range(0, len(sample_dipoles), self.initial_cost_chunk_size):
|
||||||
_logger.debug(f"doing chunk {x}")
|
_logger.debug(f"doing chunk {x}")
|
||||||
raw_costs.extend(self.cost_function_to_use(sample_dipoles[x: x + self.initial_cost_chunk_size]))
|
raw_costs.extend(
|
||||||
|
self.cost_function_to_use(
|
||||||
|
sample_dipoles[x : x + self.initial_cost_chunk_size]
|
||||||
|
)
|
||||||
|
)
|
||||||
costs = numpy.array(raw_costs)
|
costs = numpy.array(raw_costs)
|
||||||
|
|
||||||
_logger.debug(f"costs: {costs}")
|
_logger.debug(f"costs: {costs}")
|
||||||
@ -147,13 +153,12 @@ class SubsetSimulation:
|
|||||||
stdevs = self.get_stdevs_from_arrays(next_seeds_as_array)
|
stdevs = self.get_stdevs_from_arrays(next_seeds_as_array)
|
||||||
_logger.info(f"got stdevs: {stdevs.stdevs}")
|
_logger.info(f"got stdevs: {stdevs.stdevs}")
|
||||||
all_long_chains = []
|
all_long_chains = []
|
||||||
for seed_index, (c, s) in enumerate(next_seeds[::len(next_seeds) // 20]):
|
for seed_index, (c, s) in enumerate(
|
||||||
|
next_seeds[:: len(next_seeds) // 20]
|
||||||
|
):
|
||||||
# chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
|
# chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
|
||||||
# until new version gotta do
|
# until new version gotta do
|
||||||
_logger.debug(
|
_logger.debug(f"\t{seed_index}: doing long chain on the next seed")
|
||||||
f"\t{seed_index}: doing long chain on the next seed"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
long_chain = self.model.get_mcmc_chain(
|
long_chain = self.model.get_mcmc_chain(
|
||||||
s,
|
s,
|
||||||
@ -175,7 +180,6 @@ class SubsetSimulation:
|
|||||||
delimiter=",",
|
delimiter=",",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if self.keep_probs_list:
|
if self.keep_probs_list:
|
||||||
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
||||||
probs_list.append(
|
probs_list.append(
|
||||||
|
@ -151,7 +151,7 @@ def test_bayesss_with_tighter_cost(snapshot):
|
|||||||
ss_default_upper_w_log_step=4,
|
ss_default_upper_w_log_step=4,
|
||||||
ss_dump_last_generation=False,
|
ss_dump_last_generation=False,
|
||||||
write_output_to_bayesruncsv=False,
|
write_output_to_bayesruncsv=False,
|
||||||
ss_initial_costs_chunk_size=1
|
ss_initial_costs_chunk_size=1,
|
||||||
)
|
)
|
||||||
result = square_run.go()
|
result = square_run.go()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user