style: run doo fmt
All checks were successful
gitea-physics/deepdog/pipeline/head This commit looks good
All checks were successful
gitea-physics/deepdog/pipeline/head This commit looks good
This commit is contained in:
parent
492a5e6681
commit
f00b29391c
@ -10,7 +10,13 @@ def get_version():
|
|||||||
return __version__
|
return __version__
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["get_version", "BayesRun", "AltBayesRun", "AltBayesRunSimulPairs", "Diagnostic"]
|
__all__ = [
|
||||||
|
"get_version",
|
||||||
|
"BayesRun",
|
||||||
|
"AltBayesRun",
|
||||||
|
"AltBayesRunSimulPairs",
|
||||||
|
"Diagnostic",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
||||||
|
@ -24,24 +24,42 @@ _logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def get_a_result(input) -> int:
|
def get_a_result(input) -> int:
|
||||||
discretisation, dot_inputs, lows, highs, monte_carlo_count, max_frequency = input
|
discretisation, dot_inputs, lows, highs, monte_carlo_count, max_frequency = input
|
||||||
sample_dipoles = discretisation.get_model().get_n_single_dipoles(monte_carlo_count, max_frequency)
|
sample_dipoles = discretisation.get_model().get_n_single_dipoles(
|
||||||
|
monte_carlo_count, max_frequency
|
||||||
|
)
|
||||||
vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
||||||
return numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs))
|
return numpy.count_nonzero(pdme.util.fast_v_calc.between(vals, lows, highs))
|
||||||
|
|
||||||
|
|
||||||
def get_a_result_using_pairs(input) -> int:
|
def get_a_result_using_pairs(input) -> int:
|
||||||
discretisation, dot_inputs, pair_inputs, local_lows, local_highs, nonlocal_lows, nonlocal_highs, monte_carlo_count, max_frequency = input
|
(
|
||||||
sample_dipoles = discretisation.get_model().get_n_single_dipoles(monte_carlo_count, max_frequency)
|
discretisation,
|
||||||
|
dot_inputs,
|
||||||
|
pair_inputs,
|
||||||
|
local_lows,
|
||||||
|
local_highs,
|
||||||
|
nonlocal_lows,
|
||||||
|
nonlocal_highs,
|
||||||
|
monte_carlo_count,
|
||||||
|
max_frequency,
|
||||||
|
) = input
|
||||||
|
sample_dipoles = discretisation.get_model().get_n_single_dipoles(
|
||||||
|
monte_carlo_count, max_frequency
|
||||||
|
)
|
||||||
local_vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
local_vals = pdme.util.fast_v_calc.fast_vs_for_dipoles(dot_inputs, sample_dipoles)
|
||||||
local_matches = pdme.util.fast_v_calc.between(local_vals, local_lows, local_highs)
|
local_matches = pdme.util.fast_v_calc.between(local_vals, local_lows, local_highs)
|
||||||
nonlocal_vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal(pair_inputs, sample_dipoles)
|
nonlocal_vals = pdme.util.fast_nonlocal_spectrum.fast_s_nonlocal(
|
||||||
nonlocal_matches = pdme.util.fast_v_calc.between(nonlocal_vals, nonlocal_lows, nonlocal_highs)
|
pair_inputs, sample_dipoles
|
||||||
|
)
|
||||||
|
nonlocal_matches = pdme.util.fast_v_calc.between(
|
||||||
|
nonlocal_vals, nonlocal_lows, nonlocal_highs
|
||||||
|
)
|
||||||
combined_matches = numpy.logical_and(local_matches, nonlocal_matches)
|
combined_matches = numpy.logical_and(local_matches, nonlocal_matches)
|
||||||
return numpy.count_nonzero(combined_matches)
|
return numpy.count_nonzero(combined_matches)
|
||||||
|
|
||||||
|
|
||||||
class AltBayesRun():
|
class AltBayesRun:
|
||||||
'''
|
"""
|
||||||
A single Bayes run for a given set of dots.
|
A single Bayes run for a given set of dots.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -56,15 +74,44 @@ class AltBayesRun():
|
|||||||
The filename slug to include.
|
The filename slug to include.
|
||||||
run_count: int
|
run_count: int
|
||||||
The number of runs to do.
|
The number of runs to do.
|
||||||
'''
|
"""
|
||||||
def __init__(self, dot_positions: Sequence[numpy.typing.ArrayLike], frequency_range: Sequence[float], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int = 100, low_error: float = 0.9, high_error: float = 1.1, pairs_high_error=None, pairs_low_error=None, monte_carlo_count: int = 10000, monte_carlo_cycles: int = 10, target_success: int = 100, max_monte_carlo_cycles_steps: int = 10, max_frequency: float = 20, end_threshold: float = None, chunksize: int = CHUNKSIZE, use_pairs: bool = False) -> None:
|
|
||||||
self.dot_inputs = pdme.inputs.inputs_with_frequency_range(dot_positions, frequency_range)
|
def __init__(
|
||||||
self.dot_inputs_array = pdme.measurement.input_types.dot_inputs_to_array(self.dot_inputs)
|
self,
|
||||||
|
dot_positions: Sequence[numpy.typing.ArrayLike],
|
||||||
|
frequency_range: Sequence[float],
|
||||||
|
discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]],
|
||||||
|
actual_model: pdme.model.Model,
|
||||||
|
filename_slug: str,
|
||||||
|
run_count: int = 100,
|
||||||
|
low_error: float = 0.9,
|
||||||
|
high_error: float = 1.1,
|
||||||
|
pairs_high_error=None,
|
||||||
|
pairs_low_error=None,
|
||||||
|
monte_carlo_count: int = 10000,
|
||||||
|
monte_carlo_cycles: int = 10,
|
||||||
|
target_success: int = 100,
|
||||||
|
max_monte_carlo_cycles_steps: int = 10,
|
||||||
|
max_frequency: float = 20,
|
||||||
|
end_threshold: float = None,
|
||||||
|
chunksize: int = CHUNKSIZE,
|
||||||
|
use_pairs: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.dot_inputs = pdme.inputs.inputs_with_frequency_range(
|
||||||
|
dot_positions, frequency_range
|
||||||
|
)
|
||||||
|
self.dot_inputs_array = pdme.measurement.input_types.dot_inputs_to_array(
|
||||||
|
self.dot_inputs
|
||||||
|
)
|
||||||
|
|
||||||
self.use_pairs = use_pairs
|
self.use_pairs = use_pairs
|
||||||
|
|
||||||
self.dot_pair_inputs = pdme.inputs.input_pairs_with_frequency_range(dot_positions, frequency_range)
|
self.dot_pair_inputs = pdme.inputs.input_pairs_with_frequency_range(
|
||||||
self.dot_pair_inputs_array = pdme.measurement.input_types.dot_pair_inputs_to_array(self.dot_pair_inputs)
|
dot_positions, frequency_range
|
||||||
|
)
|
||||||
|
self.dot_pair_inputs_array = (
|
||||||
|
pdme.measurement.input_types.dot_pair_inputs_to_array(self.dot_pair_inputs)
|
||||||
|
)
|
||||||
|
|
||||||
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
||||||
self.model_names = [name for (name, _) in discretisations_with_names]
|
self.model_names = [name for (name, _) in discretisations_with_names]
|
||||||
@ -106,7 +153,9 @@ class AltBayesRun():
|
|||||||
self.use_end_threshold = True
|
self.use_end_threshold = True
|
||||||
_logger.info(f"Will abort early, at {self.end_threshold}.")
|
_logger.info(f"Will abort early, at {self.end_threshold}.")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"end_threshold should be between 0 and 1, but is actually {end_threshold}")
|
raise ValueError(
|
||||||
|
f"end_threshold should be between 0 and 1, but is actually {end_threshold}"
|
||||||
|
)
|
||||||
|
|
||||||
def go(self) -> None:
|
def go(self) -> None:
|
||||||
with open(self.filename, "a", newline="") as outfile:
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
@ -121,13 +170,31 @@ class AltBayesRun():
|
|||||||
# Generate the actual dipoles
|
# Generate the actual dipoles
|
||||||
actual_dipoles = self.actual_model.get_dipoles(frequency)
|
actual_dipoles = self.actual_model.get_dipoles(frequency)
|
||||||
|
|
||||||
dots = actual_dipoles.get_percent_range_dot_measurements(self.dot_inputs, self.low_error, self.high_error)
|
dots = actual_dipoles.get_percent_range_dot_measurements(
|
||||||
lows, highs = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(dots)
|
self.dot_inputs, self.low_error, self.high_error
|
||||||
|
)
|
||||||
|
(
|
||||||
|
lows,
|
||||||
|
highs,
|
||||||
|
) = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(
|
||||||
|
dots
|
||||||
|
)
|
||||||
|
|
||||||
pair_lows, pair_highs = (None, None)
|
pair_lows, pair_highs = (None, None)
|
||||||
if self.use_pairs:
|
if self.use_pairs:
|
||||||
pair_measurements = actual_dipoles.get_percent_range_dot_pair_measurements(self.dot_pair_inputs, self.pairs_low_error, self.pairs_high_error)
|
pair_measurements = (
|
||||||
pair_lows, pair_highs = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(pair_measurements)
|
actual_dipoles.get_percent_range_dot_pair_measurements(
|
||||||
|
self.dot_pair_inputs,
|
||||||
|
self.pairs_low_error,
|
||||||
|
self.pairs_high_error,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
(
|
||||||
|
pair_lows,
|
||||||
|
pair_highs,
|
||||||
|
) = pdme.measurement.input_types.dot_range_measurements_low_high_arrays(
|
||||||
|
pair_measurements
|
||||||
|
)
|
||||||
|
|
||||||
_logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}")
|
_logger.info(f"Going to work on dipole at {actual_dipoles.dipoles}")
|
||||||
|
|
||||||
@ -139,18 +206,51 @@ class AltBayesRun():
|
|||||||
cycle_count = 0
|
cycle_count = 0
|
||||||
cycle_success = 0
|
cycle_success = 0
|
||||||
cycles = 0
|
cycles = 0
|
||||||
while (cycles < self.max_monte_carlo_cycles_steps) and (cycle_success <= self.target_success):
|
while (cycles < self.max_monte_carlo_cycles_steps) and (
|
||||||
|
cycle_success <= self.target_success
|
||||||
|
):
|
||||||
_logger.debug(f"Starting cycle {cycles}")
|
_logger.debug(f"Starting cycle {cycles}")
|
||||||
cycles += 1
|
cycles += 1
|
||||||
current_success = 0
|
current_success = 0
|
||||||
cycle_count += self.monte_carlo_count * self.monte_carlo_cycles
|
cycle_count += self.monte_carlo_count * self.monte_carlo_cycles
|
||||||
if self.use_pairs:
|
if self.use_pairs:
|
||||||
current_success = sum(
|
current_success = sum(
|
||||||
pool.imap_unordered(get_a_result_using_pairs, [(discretisation, self.dot_inputs_array, self.dot_pair_inputs_array, lows, highs, pair_lows, pair_highs, self.monte_carlo_count, self.max_frequency)] * self.monte_carlo_cycles, self.chunksize)
|
pool.imap_unordered(
|
||||||
|
get_a_result_using_pairs,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
discretisation,
|
||||||
|
self.dot_inputs_array,
|
||||||
|
self.dot_pair_inputs_array,
|
||||||
|
lows,
|
||||||
|
highs,
|
||||||
|
pair_lows,
|
||||||
|
pair_highs,
|
||||||
|
self.monte_carlo_count,
|
||||||
|
self.max_frequency,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
* self.monte_carlo_cycles,
|
||||||
|
self.chunksize,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
current_success = sum(
|
current_success = sum(
|
||||||
pool.imap_unordered(get_a_result, [(discretisation, self.dot_inputs_array, lows, highs, self.monte_carlo_count, self.max_frequency)] * self.monte_carlo_cycles, self.chunksize)
|
pool.imap_unordered(
|
||||||
|
get_a_result,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
discretisation,
|
||||||
|
self.dot_inputs_array,
|
||||||
|
lows,
|
||||||
|
highs,
|
||||||
|
self.monte_carlo_count,
|
||||||
|
self.max_frequency,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
* self.monte_carlo_cycles,
|
||||||
|
self.chunksize,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
cycle_success += current_success
|
cycle_success += current_success
|
||||||
@ -160,30 +260,44 @@ class AltBayesRun():
|
|||||||
row = {
|
row = {
|
||||||
"dipole_moment": actual_dipoles.dipoles[0].p,
|
"dipole_moment": actual_dipoles.dipoles[0].p,
|
||||||
"dipole_location": actual_dipoles.dipoles[0].s,
|
"dipole_location": actual_dipoles.dipoles[0].s,
|
||||||
"dipole_frequency": actual_dipoles.dipoles[0].w
|
"dipole_frequency": actual_dipoles.dipoles[0].w,
|
||||||
}
|
}
|
||||||
successes: List[float] = []
|
successes: List[float] = []
|
||||||
counts: List[int] = []
|
counts: List[int] = []
|
||||||
for model_index, (name, (count, result)) in enumerate(zip(self.model_names, results)):
|
for model_index, (name, (count, result)) in enumerate(
|
||||||
|
zip(self.model_names, results)
|
||||||
|
):
|
||||||
|
|
||||||
row[f"{name}_success"] = result
|
row[f"{name}_success"] = result
|
||||||
row[f"{name}_count"] = count
|
row[f"{name}_count"] = count
|
||||||
successes.append(max(result, 0.5))
|
successes.append(max(result, 0.5))
|
||||||
counts.append(count)
|
counts.append(count)
|
||||||
|
|
||||||
success_weight = sum([(succ / count) * prob for succ, count, prob in zip(successes, counts, self.probabilities)])
|
success_weight = sum(
|
||||||
new_probabilities = [(succ / count) * old_prob / success_weight for succ, count, old_prob in zip(successes, counts, self.probabilities)]
|
[
|
||||||
|
(succ / count) * prob
|
||||||
|
for succ, count, prob in zip(successes, counts, self.probabilities)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
new_probabilities = [
|
||||||
|
(succ / count) * old_prob / success_weight
|
||||||
|
for succ, count, old_prob in zip(successes, counts, self.probabilities)
|
||||||
|
]
|
||||||
self.probabilities = new_probabilities
|
self.probabilities = new_probabilities
|
||||||
for name, probability in zip(self.model_names, self.probabilities):
|
for name, probability in zip(self.model_names, self.probabilities):
|
||||||
row[f"{name}_prob"] = probability
|
row[f"{name}_prob"] = probability
|
||||||
_logger.info(row)
|
_logger.info(row)
|
||||||
|
|
||||||
with open(self.filename, "a", newline="") as outfile:
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
writer = csv.DictWriter(
|
||||||
|
outfile, fieldnames=self.csv_fields, dialect="unix"
|
||||||
|
)
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
if self.use_end_threshold:
|
if self.use_end_threshold:
|
||||||
max_prob = max(self.probabilities)
|
max_prob = max(self.probabilities)
|
||||||
if max_prob > self.end_threshold:
|
if max_prob > self.end_threshold:
|
||||||
_logger.info(f"Aborting early, because {max_prob} is greater than {self.end_threshold}")
|
_logger.info(
|
||||||
|
f"Aborting early, because {max_prob} is greater than {self.end_threshold}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
@ -206,7 +206,8 @@ class AltBayesRunSimulPairs:
|
|||||||
current_success_no_pairs = 0
|
current_success_no_pairs = 0
|
||||||
cycle_count += self.monte_carlo_count * self.monte_carlo_cycles
|
cycle_count += self.monte_carlo_count * self.monte_carlo_cycles
|
||||||
|
|
||||||
current_success_both = numpy.array(sum(
|
current_success_both = numpy.array(
|
||||||
|
sum(
|
||||||
pool.imap_unordered(
|
pool.imap_unordered(
|
||||||
get_a_simul_result_using_pairs,
|
get_a_simul_result_using_pairs,
|
||||||
[
|
[
|
||||||
@ -225,7 +226,8 @@ class AltBayesRunSimulPairs:
|
|||||||
* self.monte_carlo_cycles,
|
* self.monte_carlo_cycles,
|
||||||
self.chunksize,
|
self.chunksize,
|
||||||
)
|
)
|
||||||
))
|
)
|
||||||
|
)
|
||||||
current_success_no_pairs = current_success_both[0]
|
current_success_no_pairs = current_success_both[0]
|
||||||
current_success_pairs = current_success_both[1]
|
current_success_pairs = current_success_both[1]
|
||||||
|
|
||||||
|
@ -20,12 +20,14 @@ DotInput = Tuple[numpy.typing.ArrayLike, float]
|
|||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_a_result(discretisation, dots, index) -> Tuple[Tuple[int, ...], scipy.optimize.OptimizeResult]:
|
def get_a_result(
|
||||||
|
discretisation, dots, index
|
||||||
|
) -> Tuple[Tuple[int, ...], scipy.optimize.OptimizeResult]:
|
||||||
return (index, discretisation.solve_for_index(dots, index))
|
return (index, discretisation.solve_for_index(dots, index))
|
||||||
|
|
||||||
|
|
||||||
class BayesRun():
|
class BayesRun:
|
||||||
'''
|
"""
|
||||||
A single Bayes run for a given set of dots.
|
A single Bayes run for a given set of dots.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -40,8 +42,18 @@ class BayesRun():
|
|||||||
The filename slug to include.
|
The filename slug to include.
|
||||||
run_count: int
|
run_count: int
|
||||||
The number of runs to do.
|
The number of runs to do.
|
||||||
'''
|
"""
|
||||||
def __init__(self, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], actual_model: pdme.model.Model, filename_slug: str, run_count: int, max_frequency: float = None, end_threshold: float = None) -> None:
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dot_inputs: Sequence[DotInput],
|
||||||
|
discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]],
|
||||||
|
actual_model: pdme.model.Model,
|
||||||
|
filename_slug: str,
|
||||||
|
run_count: int,
|
||||||
|
max_frequency: float = None,
|
||||||
|
end_threshold: float = None,
|
||||||
|
) -> None:
|
||||||
self.dot_inputs = dot_inputs
|
self.dot_inputs = dot_inputs
|
||||||
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
self.discretisations = [disc for (_, disc) in discretisations_with_names]
|
||||||
self.model_names = [name for (name, _) in discretisations_with_names]
|
self.model_names = [name for (name, _) in discretisations_with_names]
|
||||||
@ -65,7 +77,9 @@ class BayesRun():
|
|||||||
self.use_end_threshold = True
|
self.use_end_threshold = True
|
||||||
_logger.info(f"Will abort early, at {self.end_threshold}.")
|
_logger.info(f"Will abort early, at {self.end_threshold}.")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"end_threshold should be between 0 and 1, but is actually {end_threshold}")
|
raise ValueError(
|
||||||
|
f"end_threshold should be between 0 and 1, but is actually {end_threshold}"
|
||||||
|
)
|
||||||
|
|
||||||
def go(self) -> None:
|
def go(self) -> None:
|
||||||
with open(self.filename, "a", newline="") as outfile:
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
@ -87,17 +101,28 @@ class BayesRun():
|
|||||||
for disc_count, discretisation in enumerate(self.discretisations):
|
for disc_count, discretisation in enumerate(self.discretisations):
|
||||||
_logger.debug(f"Doing discretisation #{disc_count}")
|
_logger.debug(f"Doing discretisation #{disc_count}")
|
||||||
with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool:
|
with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool:
|
||||||
results.append(pool.starmap(get_a_result, zip(itertools.repeat(discretisation), itertools.repeat(dots), discretisation.all_indices())))
|
results.append(
|
||||||
|
pool.starmap(
|
||||||
|
get_a_result,
|
||||||
|
zip(
|
||||||
|
itertools.repeat(discretisation),
|
||||||
|
itertools.repeat(dots),
|
||||||
|
discretisation.all_indices(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
_logger.debug("Done, constructing output now")
|
_logger.debug("Done, constructing output now")
|
||||||
row = {
|
row = {
|
||||||
"dipole_moment": dipoles.dipoles[0].p,
|
"dipole_moment": dipoles.dipoles[0].p,
|
||||||
"dipole_location": dipoles.dipoles[0].s,
|
"dipole_location": dipoles.dipoles[0].s,
|
||||||
"dipole_frequency": dipoles.dipoles[0].w
|
"dipole_frequency": dipoles.dipoles[0].w,
|
||||||
}
|
}
|
||||||
successes: List[float] = []
|
successes: List[float] = []
|
||||||
counts: List[int] = []
|
counts: List[int] = []
|
||||||
for model_index, (name, result) in enumerate(zip(self.model_names, results)):
|
for model_index, (name, result) in enumerate(
|
||||||
|
zip(self.model_names, results)
|
||||||
|
):
|
||||||
count = 0
|
count = 0
|
||||||
success = 0
|
success = 0
|
||||||
for idx, val in result:
|
for idx, val in result:
|
||||||
@ -110,19 +135,31 @@ class BayesRun():
|
|||||||
successes.append(max(success, 0.5))
|
successes.append(max(success, 0.5))
|
||||||
counts.append(count)
|
counts.append(count)
|
||||||
|
|
||||||
success_weight = sum([(succ / count) * prob for succ, count, prob in zip(successes, counts, self.probabilities)])
|
success_weight = sum(
|
||||||
new_probabilities = [(succ / count) * old_prob / success_weight for succ, count, old_prob in zip(successes, counts, self.probabilities)]
|
[
|
||||||
|
(succ / count) * prob
|
||||||
|
for succ, count, prob in zip(successes, counts, self.probabilities)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
new_probabilities = [
|
||||||
|
(succ / count) * old_prob / success_weight
|
||||||
|
for succ, count, old_prob in zip(successes, counts, self.probabilities)
|
||||||
|
]
|
||||||
self.probabilities = new_probabilities
|
self.probabilities = new_probabilities
|
||||||
for name, probability in zip(self.model_names, self.probabilities):
|
for name, probability in zip(self.model_names, self.probabilities):
|
||||||
row[f"{name}_prob"] = probability
|
row[f"{name}_prob"] = probability
|
||||||
_logger.info(row)
|
_logger.info(row)
|
||||||
|
|
||||||
with open(self.filename, "a", newline="") as outfile:
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
writer = csv.DictWriter(
|
||||||
|
outfile, fieldnames=self.csv_fields, dialect="unix"
|
||||||
|
)
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
if self.use_end_threshold:
|
if self.use_end_threshold:
|
||||||
max_prob = max(self.probabilities)
|
max_prob = max(self.probabilities)
|
||||||
if max_prob > self.end_threshold:
|
if max_prob > self.end_threshold:
|
||||||
_logger.info(f"Aborting early, because {max_prob} is greater than {self.end_threshold}")
|
_logger.info(
|
||||||
|
f"Aborting early, because {max_prob} is greater than {self.end_threshold}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
@ -18,7 +18,7 @@ def get_a_result(discretisation, dots, index):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SingleDipoleDiagnostic():
|
class SingleDipoleDiagnostic:
|
||||||
model: str
|
model: str
|
||||||
index: Tuple
|
index: Tuple
|
||||||
bounds: Tuple
|
bounds: Tuple
|
||||||
@ -43,8 +43,8 @@ class SingleDipoleDiagnostic():
|
|||||||
self.w_result = self.result_dipole.w
|
self.w_result = self.result_dipole.w
|
||||||
|
|
||||||
|
|
||||||
class Diagnostic():
|
class Diagnostic:
|
||||||
'''
|
"""
|
||||||
Represents a diagnostic for a single dipole moment given a set of discretisations.
|
Represents a diagnostic for a single dipole moment given a set of discretisations.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -59,15 +59,51 @@ class Diagnostic():
|
|||||||
The filename slug to include.
|
The filename slug to include.
|
||||||
run_count: int
|
run_count: int
|
||||||
The number of runs to do.
|
The number of runs to do.
|
||||||
'''
|
"""
|
||||||
def __init__(self, actual_dipole_moment: numpy.ndarray, actual_dipole_position: numpy.ndarray, actual_dipole_frequency: float, dot_inputs: Sequence[DotInput], discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]], filename_slug: str) -> None:
|
|
||||||
self.dipoles = OscillatingDipoleArrangement([OscillatingDipole(actual_dipole_moment, actual_dipole_position, actual_dipole_frequency)])
|
def __init__(
|
||||||
|
self,
|
||||||
|
actual_dipole_moment: numpy.ndarray,
|
||||||
|
actual_dipole_position: numpy.ndarray,
|
||||||
|
actual_dipole_frequency: float,
|
||||||
|
dot_inputs: Sequence[DotInput],
|
||||||
|
discretisations_with_names: Sequence[Tuple[str, pdme.model.Discretisation]],
|
||||||
|
filename_slug: str,
|
||||||
|
) -> None:
|
||||||
|
self.dipoles = OscillatingDipoleArrangement(
|
||||||
|
[
|
||||||
|
OscillatingDipole(
|
||||||
|
actual_dipole_moment,
|
||||||
|
actual_dipole_position,
|
||||||
|
actual_dipole_frequency,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
self.dots = self.dipoles.get_dot_measurements(dot_inputs)
|
self.dots = self.dipoles.get_dot_measurements(dot_inputs)
|
||||||
|
|
||||||
self.discretisations_with_names = discretisations_with_names
|
self.discretisations_with_names = discretisations_with_names
|
||||||
self.model_count = len(self.discretisations_with_names)
|
self.model_count = len(self.discretisations_with_names)
|
||||||
|
|
||||||
self.csv_fields = ["model", "index", "bounds", "p_actual_x", "p_actual_y", "p_actual_z", "s_actual_x", "s_actual_y", "s_actual_z", "w_actual", "success", "p_result_x", "p_result_y", "p_result_z", "s_result_x", "s_result_y", "s_result_z", "w_result"]
|
self.csv_fields = [
|
||||||
|
"model",
|
||||||
|
"index",
|
||||||
|
"bounds",
|
||||||
|
"p_actual_x",
|
||||||
|
"p_actual_y",
|
||||||
|
"p_actual_z",
|
||||||
|
"s_actual_x",
|
||||||
|
"s_actual_y",
|
||||||
|
"s_actual_z",
|
||||||
|
"w_actual",
|
||||||
|
"success",
|
||||||
|
"p_result_x",
|
||||||
|
"p_result_y",
|
||||||
|
"p_result_z",
|
||||||
|
"s_result_x",
|
||||||
|
"s_result_y",
|
||||||
|
"s_result_z",
|
||||||
|
"w_result",
|
||||||
|
]
|
||||||
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
self.filename = f"{timestamp}-{filename_slug}.diag.csv"
|
self.filename = f"{timestamp}-{filename_slug}.diag.csv"
|
||||||
@ -75,7 +111,7 @@ class Diagnostic():
|
|||||||
def go(self):
|
def go(self):
|
||||||
with open(self.filename, "a", newline="") as outfile:
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
# csv fields
|
# csv fields
|
||||||
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect='unix')
|
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
|
|
||||||
for (name, discretisation) in self.discretisations_with_names:
|
for (name, discretisation) in self.discretisations_with_names:
|
||||||
@ -83,17 +119,38 @@ class Diagnostic():
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool:
|
with multiprocessing.Pool(multiprocessing.cpu_count() - 1 or 1) as pool:
|
||||||
results = pool.starmap(get_a_result, zip(itertools.repeat(discretisation), itertools.repeat(self.dots), discretisation.all_indices()))
|
results = pool.starmap(
|
||||||
|
get_a_result,
|
||||||
|
zip(
|
||||||
|
itertools.repeat(discretisation),
|
||||||
|
itertools.repeat(self.dots),
|
||||||
|
discretisation.all_indices(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
with open(self.filename, "a", newline='') as outfile:
|
with open(self.filename, "a", newline="") as outfile:
|
||||||
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect='unix', extrasaction="ignore")
|
writer = csv.DictWriter(
|
||||||
|
outfile,
|
||||||
|
fieldnames=self.csv_fields,
|
||||||
|
dialect="unix",
|
||||||
|
extrasaction="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
for idx, result in results:
|
for idx, result in results:
|
||||||
|
|
||||||
bounds = discretisation.bounds(idx)
|
bounds = discretisation.bounds(idx)
|
||||||
|
|
||||||
actual_success = result.success and result.cost <= 1e-10
|
actual_success = result.success and result.cost <= 1e-10
|
||||||
diag_row = SingleDipoleDiagnostic(name, idx, bounds, self.dipoles.dipoles[0], discretisation.model.solution_as_dipoles(result.normalised_x)[0], actual_success)
|
diag_row = SingleDipoleDiagnostic(
|
||||||
|
name,
|
||||||
|
idx,
|
||||||
|
bounds,
|
||||||
|
self.dipoles.dipoles[0],
|
||||||
|
discretisation.model.solution_as_dipoles(result.normalised_x)[
|
||||||
|
0
|
||||||
|
],
|
||||||
|
actual_success,
|
||||||
|
)
|
||||||
row = vars(diag_row)
|
row = vars(diag_row)
|
||||||
_logger.debug(f"Writing result {row}")
|
_logger.debug(f"Writing result {row}")
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
|
||||||
__version__ = version('deepdog')
|
__version__ = version("deepdog")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user