feat: allows some better matching for single_dipole runs

This commit is contained in:
2024-08-26 03:31:15 -05:00
parent 6a5c5931d4
commit 5425ce1362
2 changed files with 53 additions and 12 deletions

View File

@@ -36,8 +36,8 @@ class DirectMonteCarloConfig:
tag: str = "" tag: str = ""
cap_core_count: int = 0 # 0 means cap at num cores - 1 cap_core_count: int = 0 # 0 means cap at num cores - 1
chunk_size: int = 50 chunk_size: int = 50
write_bayesrun_file = True write_bayesrun_file: bool = True
bayesrun_file_timestamp = True bayesrun_file_timestamp: bool = True
# chunk size of some kind # chunk size of some kind
@@ -145,15 +145,21 @@ class DirectMonteCarloRun:
single run wrapped up for multiprocessing call. single run wrapped up for multiprocessing call.
takes in a tuple of arguments corresponding to takes in a tuple of arguments corresponding to
(model_name_pair, seed) (model_name_pair, seed, return_configs)
return_configs is a boolean, if true then will return tuple of (count, [matching configs])
if false, return (count, [])
""" """
# here's where we do our work # here's where we do our work
model_name_pair, seed = args model_name_pair, seed, return_configs = args
cycle_success_configs = self._single_run(model_name_pair, seed) cycle_success_configs = self._single_run(model_name_pair, seed)
cycle_success_count = len(cycle_success_configs) cycle_success_count = len(cycle_success_configs)
return cycle_success_count if return_configs:
return (cycle_success_count, cycle_success_configs)
else:
return (cycle_success_count, [])
def execute_no_multiprocessing(self) -> Sequence[DirectMonteCarloResult]: def execute_no_multiprocessing(self) -> Sequence[DirectMonteCarloResult]:
@@ -198,9 +204,11 @@ class DirectMonteCarloRun:
) )
dipole_count = numpy.array(cycle_success_configs).shape[1] dipole_count = numpy.array(cycle_success_configs).shape[1]
for n in range(dipole_count): for n in range(dipole_count):
number_dipoles_to_write = self.config.target_success * 5
_logger.info(f"Limiting to {number_dipoles_to_write=}")
numpy.savetxt( numpy.savetxt(
f"{self.config.tag}_{step_count}_{cycle_i}_dipole_{n}.csv", f"{self.config.tag}_{step_count}_{cycle_i}_dipole_{n}.csv",
sorted_by_freq[:, n], sorted_by_freq[:number_dipoles_to_write, n],
delimiter=",", delimiter=",",
) )
total_success += cycle_success_count total_success += cycle_success_count
@@ -259,13 +267,45 @@ class DirectMonteCarloRun:
seeds = seed_sequence.spawn(self.config.monte_carlo_cycles) seeds = seed_sequence.spawn(self.config.monte_carlo_cycles)
pool_results = sum( raw_pool_results = list(pool.imap_unordered(
pool.imap_unordered(
self._wrapped_single_run, self._wrapped_single_run,
[(model_name_pair, seed) for seed in seeds], [
(model_name_pair, seed, self.config.write_successes_to_file)
for seed in seeds
],
self.config.chunk_size, self.config.chunk_size,
))
pool_results = sum(result[0] for result in raw_pool_results)
if self.config.write_successes_to_file:
cycle_success_configs = numpy.concatenate(
[result[1] for result in raw_pool_results]
) )
if len(cycle_success_configs):
sorted_by_freq = numpy.array(
[
pdme.subspace_simulation.sort_array_of_dipoles_by_frequency(
dipole_config
) )
for dipole_config in cycle_success_configs
]
)
dipole_count = numpy.array(cycle_success_configs).shape[1]
number_dipoles_to_write = self.config.target_success * 5
_logger.info(f"Limiting to {number_dipoles_to_write=}")
for n in range(dipole_count):
numpy.savetxt(
f"{self.config.tag}_{step_count}_dipole_{n}.csv",
sorted_by_freq[:: number_dipoles_to_write, n],
delimiter=",",
)
else:
_logger.debug("Instructed to write results, but none obtained")
_logger.debug(f"Pool results: {pool_results}") _logger.debug(f"Pool results: {pool_results}")
total_success += pool_results total_success += pool_results

View File

@@ -8,6 +8,7 @@ FILE_SLUG_REGEXES = [
r"(?P<tag>\w+)-(?P<job_index>\d+)", r"(?P<tag>\w+)-(?P<job_index>\d+)",
r"mock_tarucha-(?P<job_index>\d+)", r"mock_tarucha-(?P<job_index>\d+)",
r"(?:(?P<mock>mock)_)?tarucha(?:_(?P<tarucha_run_id>\d+))?-(?P<job_index>\d+)", r"(?:(?P<mock>mock)_)?tarucha(?:_(?P<tarucha_run_id>\d+))?-(?P<job_index>\d+)",
r"(?P<tag>\w+)-(?P<included_dots>[\w,]+)-(?P<target_cost>\d*\.?\d+)-(?P<job_index>\d+)",
] ]
] ]