feat: adds option to cap core count for temp aware run

This commit is contained in:
2023-04-13 20:16:33 -05:00
parent 959b9af378
commit 12903b2540

View File

@@ -90,6 +90,7 @@ class TempAwareRealSpectrumRun:
max_monte_carlo_cycles_steps: int = 10, max_monte_carlo_cycles_steps: int = 10,
chunksize: int = CHUNKSIZE, chunksize: int = CHUNKSIZE,
initial_seed: int = 12345, initial_seed: int = 12345,
cap_core_count: int = 0,
) -> None: ) -> None:
self.measurements_dict = measurements_dict self.measurements_dict = measurements_dict
self.dot_inputs_dict = { self.dot_inputs_dict = {
@@ -126,6 +127,8 @@ class TempAwareRealSpectrumRun:
self.filename = f"{timestamp}-{filename_slug}.realdata.{ff_string}.bayesrun.csv" self.filename = f"{timestamp}-{filename_slug}.realdata.{ff_string}.bayesrun.csv"
self.initial_seed = initial_seed self.initial_seed = initial_seed
self.cap_core_count = cap_core_count
def go(self) -> None: def go(self) -> None:
with open(self.filename, "a", newline="") as outfile: with open(self.filename, "a", newline="") as outfile:
writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix") writer = csv.DictWriter(outfile, fieldnames=self.csv_fields, dialect="unix")
@@ -151,6 +154,8 @@ class TempAwareRealSpectrumRun:
): ):
_logger.debug(f"Doing model #{model_count}: {model_name}") _logger.debug(f"Doing model #{model_count}: {model_name}")
core_count = multiprocessing.cpu_count() - 1 or 1 core_count = multiprocessing.cpu_count() - 1 or 1
if (self.cap_core_count >= 1) and (self.cap_core_count < core_count):
core_count = self.cap_core_count
with multiprocessing.Pool(core_count) as pool: with multiprocessing.Pool(core_count) as pool:
cycle_count = 0 cycle_count = 0
cycle_success = 0 cycle_success = 0