diff --git a/deepdog/direct_monte_carlo/direct_mc.py b/deepdog/direct_monte_carlo/direct_mc.py index 50ada69..f77a0f4 100644 --- a/deepdog/direct_monte_carlo/direct_mc.py +++ b/deepdog/direct_monte_carlo/direct_mc.py @@ -14,6 +14,8 @@ import multiprocessing _logger = logging.getLogger(__name__) +ANTI_ZERO_SUCCESS_THRES = 0.1 + @dataclass class DirectMonteCarloResult: @@ -292,7 +294,11 @@ class DirectMonteCarloRun: num_models = len(self.model_name_pairs) success_weight = sum( [ - (res.successes / res.monte_carlo_count) / num_models + ( + max(ANTI_ZERO_SUCCESS_THRES, res.successes) + / res.monte_carlo_count + ) + / num_models for res in results ] ) @@ -303,7 +309,8 @@ class DirectMonteCarloRun: f"{res.model_name}_success": res.successes, f"{res.model_name}_count": res.monte_carlo_count, f"{res.model_name}_prob": ( - res.successes / res.monte_carlo_count + max(ANTI_ZERO_SUCCESS_THRES, res.successes) + / res.monte_carlo_count ) / (num_models * success_weight), }