feat!: bayes run now handles multidipoles with changes to output file format etc.
All checks were successful
gitea-physics/deepdog/pipeline/pr-master This commit looks good
All checks were successful
gitea-physics/deepdog/pipeline/pr-master This commit looks good
This commit is contained in:
parent
67a9721c31
commit
5d0a7a4be0
@ -23,9 +23,11 @@ _logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def get_a_result(input) -> int:
|
def get_a_result(input) -> int:
|
||||||
model, dot_inputs, lows, highs, monte_carlo_count, max_frequency = input
|
model, dot_inputs, lows, highs, monte_carlo_count, max_frequency, seed = input
|
||||||
sample_dipoles = model.get_model().get_n_single_dipoles(
|
|
||||||
monte_carlo_count, max_frequency
|
rng = numpy.random.default_rng(seed)
|
||||||
|
sample_dipoles = model.get_monte_carlo_dipole_inputs(
|
||||||
|
monte_carlo_count, max_frequency, rng_to_use=rng
|
||||||
)
|
)
|
||||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
||||||
return numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs))
|
return numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs))
|
||||||
@ -88,8 +90,6 @@ class BayesRun:
|
|||||||
run_count: int = 100,
|
run_count: int = 100,
|
||||||
low_error: float = 0.9,
|
low_error: float = 0.9,
|
||||||
high_error: float = 1.1,
|
high_error: float = 1.1,
|
||||||
pairs_high_error=None,
|
|
||||||
pairs_low_error=None,
|
|
||||||
monte_carlo_count: int = 10000,
|
monte_carlo_count: int = 10000,
|
||||||
monte_carlo_cycles: int = 10,
|
monte_carlo_cycles: int = 10,
|
||||||
target_success: int = 100,
|
target_success: int = 100,
|
||||||
@ -97,7 +97,6 @@ class BayesRun:
|
|||||||
max_frequency: float = 20,
|
max_frequency: float = 20,
|
||||||
end_threshold: float = None,
|
end_threshold: float = None,
|
||||||
chunksize: int = CHUNKSIZE,
|
chunksize: int = CHUNKSIZE,
|
||||||
use_pairs: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.dot_inputs = pdme.inputs.inputs_with_frequency_range(
|
self.dot_inputs = pdme.inputs.inputs_with_frequency_range(
|
||||||
dot_positions, frequency_range
|
dot_positions, frequency_range
|
||||||
@ -106,18 +105,16 @@ class BayesRun:
|
|||||||
self.dot_inputs
|
self.dot_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_pairs = use_pairs
|
|
||||||
|
|
||||||
self.dot_pair_inputs = pdme.inputs.input_pairs_with_frequency_range(
|
|
||||||
dot_positions, frequency_range
|
|
||||||
)
|
|
||||||
self.dot_pair_inputs_array = (
|
|
||||||
pdme.measurement.input_types.dot_pair_inputs_to_array(self.dot_pair_inputs)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.models = [model for (_, model) in models_with_names]
|
self.models = [model for (_, model) in models_with_names]
|
||||||
self.model_names = [name for (name, _) in models_with_names]
|
self.model_names = [name for (name, _) in models_with_names]
|
||||||
self.actual_model = actual_model
|
self.actual_model = actual_model
|
||||||
|
|
||||||
|
self.n: int
|
||||||
|
try:
|
||||||
|
self.n = self.actual_model.n # type: ignore
|
||||||
|
except AttributeError:
|
||||||
|
self.n = 1
|
||||||
|
|
||||||
self.model_count = len(self.models)
|
self.model_count = len(self.models)
|
||||||
self.monte_carlo_count = monte_carlo_count
|
self.monte_carlo_count = monte_carlo_count
|
||||||
self.monte_carlo_cycles = monte_carlo_cycles
|
self.monte_carlo_cycles = monte_carlo_cycles
|
||||||
@ -126,15 +123,16 @@ class BayesRun:
|
|||||||
self.run_count = run_count
|
self.run_count = run_count
|
||||||
self.low_error = low_error
|
self.low_error = low_error
|
||||||
self.high_error = high_error
|
self.high_error = high_error
|
||||||
if pairs_low_error is None:
|
|
||||||
self.pairs_low_error = self.low_error
|
self.csv_fields = []
|
||||||
else:
|
for i in range(self.n):
|
||||||
self.pairs_low_error = pairs_low_error
|
self.csv_fields.extend(
|
||||||
if pairs_high_error is None:
|
[
|
||||||
self.pairs_high_error = self.high_error
|
f"dipole_moment_{i+1}",
|
||||||
else:
|
f"dipole_location_{i+1}",
|
||||||
self.pairs_high_error = pairs_high_error
|
f"dipole_frequency_{i+1}",
|
||||||
self.csv_fields = ["dipole_moment", "dipole_location", "dipole_frequency"]
|
]
|
||||||
|
)
|
||||||
self.compensate_zeros = True
|
self.compensate_zeros = True
|
||||||
self.chunksize = chunksize
|
self.chunksize = chunksize
|
||||||
for name in self.model_names:
|
for name in self.model_names:
|
||||||
@ -143,10 +141,7 @@ class BayesRun:
|
|||||||
self.probabilities = [1 / self.model_count] * self.model_count
|
self.probabilities = [1 / self.model_count] * self.model_count
|
||||||
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
if self.use_pairs:
|
self.filename = f"{timestamp}-{filename_slug}.bayesrun.csv"
|
||||||
self.filename = f"{timestamp}-{filename_slug}.altbayes.pairs.csv"
|
|
||||||
else:
|
|
||||||
self.filename = f"{timestamp}-{filename_slug}.altbayes.csv"
|
|
||||||
self.max_frequency = max_frequency
|
self.max_frequency = max_frequency
|
||||||
|
|
||||||
if end_threshold is not None:
|
if end_threshold is not None:
|
||||||
@ -179,29 +174,17 @@ class BayesRun:
|
|||||||
dots
|
dots
|
||||||
)
|
)
|
||||||
|
|
||||||
pair_lows, pair_highs = (None, None)
|
|
||||||
if self.use_pairs:
|
|
||||||
pair_measurements = (
|
|
||||||
actual_dipoles.get_percent_range_dot_pair_measurements(
|
|
||||||
self.dot_pair_inputs,
|
|
||||||
self.pairs_low_error,
|
|
||||||
self.pairs_high_error,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
(
|
|
||||||
pair_lows,
|
|
||||||
pair_highs,
|
|
||||||
) = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(
|
|
||||||
pair_measurements
|
|
||||||
)
|
|
||||||
|
|
||||||
_logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}")
|
_logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}")
|
||||||
|
|
||||||
|
# define a new seed sequence for each run
|
||||||
|
seed_sequence = numpy.random.SeedSequence(run)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
_logger.debug("Going to iterate over models now")
|
_logger.debug("Going to iterate over models now")
|
||||||
for model_count, model in enumerate(self.models):
|
for model_count, model in enumerate(self.models):
|
||||||
_logger.debug(f"Doing model #{model_count}")
|
_logger.debug(f"Doing model #{model_count}")
|
||||||
with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool:
|
core_count = multiprocessing.cpu_count() - 1 or 1
|
||||||
|
with multiprocessing.Pool(core_count) as pool:
|
||||||
cycle_count = 0
|
cycle_count = 0
|
||||||
cycle_success = 0
|
cycle_success = 0
|
||||||
cycles = 0
|
cycles = 0
|
||||||
@ -212,55 +195,50 @@ class BayesRun:
|
|||||||
cycles += 1
|
cycles += 1
|
||||||
current_success = 0
|
current_success = 0
|
||||||
cycle_count += self.monte_carlo_count * self.monte_carlo_cycles
|
cycle_count += self.monte_carlo_count * self.monte_carlo_cycles
|
||||||
if self.use_pairs:
|
|
||||||
current_success = sum(
|
# generate a seed from the sequence for each core.
|
||||||
pool.imap_unordered(
|
# note this needs to be inside the loop for monte carlo cycle steps!
|
||||||
get_a_result_using_pairs,
|
# that way we get more stuff.
|
||||||
[
|
seeds = seed_sequence.spawn(self.monte_carlo_cycles)
|
||||||
(
|
|
||||||
model,
|
current_success = sum(
|
||||||
self.dot_inputs_array,
|
pool.imap_unordered(
|
||||||
self.dot_pair_inputs_array,
|
get_a_result,
|
||||||
lows,
|
[
|
||||||
highs,
|
(
|
||||||
pair_lows,
|
model,
|
||||||
pair_highs,
|
self.dot_inputs_array,
|
||||||
self.monte_carlo_count,
|
lows,
|
||||||
self.max_frequency,
|
highs,
|
||||||
)
|
self.monte_carlo_count,
|
||||||
]
|
self.max_frequency,
|
||||||
* self.monte_carlo_cycles,
|
seed,
|
||||||
self.chunksize,
|
)
|
||||||
)
|
for seed in seeds
|
||||||
)
|
],
|
||||||
else:
|
self.chunksize,
|
||||||
current_success = sum(
|
|
||||||
pool.imap_unordered(
|
|
||||||
get_a_result,
|
|
||||||
[
|
|
||||||
(
|
|
||||||
model,
|
|
||||||
self.dot_inputs_array,
|
|
||||||
lows,
|
|
||||||
highs,
|
|
||||||
self.monte_carlo_count,
|
|
||||||
self.max_frequency,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
* self.monte_carlo_cycles,
|
|
||||||
self.chunksize,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
cycle_success += current_success
|
cycle_success += current_success
|
||||||
results.append((cycle_count, cycle_success))
|
results.append((cycle_count, cycle_success))
|
||||||
|
|
||||||
_logger.debug("Done, constructing output now")
|
_logger.debug("Done, constructing output now")
|
||||||
row = {
|
row = {
|
||||||
"dipole_moment": actual_dipoles.dipoles[0].p,
|
"dipole_moment_1": actual_dipoles.dipoles[0].p,
|
||||||
"dipole_location": actual_dipoles.dipoles[0].s,
|
"dipole_location_1": actual_dipoles.dipoles[0].s,
|
||||||
"dipole_frequency": actual_dipoles.dipoles[0].w,
|
"dipole_frequency_1": actual_dipoles.dipoles[0].w,
|
||||||
}
|
}
|
||||||
|
for i in range(1, self.n):
|
||||||
|
try:
|
||||||
|
current_dipoles = actual_dipoles.dipoles[i]
|
||||||
|
row[f"dipole_moment_{i+1}"] = current_dipoles.p
|
||||||
|
row[f"dipole_location_{i+1}"] = current_dipoles.s
|
||||||
|
row[f"dipole_frequency_{i+1}"] = current_dipoles.w
|
||||||
|
except IndexError:
|
||||||
|
_logger.info(f"Not writing anymore, saw end after {i}")
|
||||||
|
break
|
||||||
|
|
||||||
successes: List[float] = []
|
successes: List[float] = []
|
||||||
counts: List[int] = []
|
counts: List[int] = []
|
||||||
for model_index, (name, (count, result)) in enumerate(
|
for model_index, (name, (count, result)) in enumerate(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user