feat: adds longchain logging if logging last generation
Some checks failed
gitea-physics/deepdog/pipeline/head There was a failure building this commit

This commit is contained in:
Deepak Mallubhotla 2023-08-12 19:48:30 -05:00
parent f7559b2c4f
commit b4e5f53726
Signed by: deepak
GPG Key ID: BEBAEBF28083E022

View File

@ -143,6 +143,39 @@ class SubsetSimulation:
delimiter=",", delimiter=",",
) )
next_seeds_as_array = numpy.array([s for _, s in next_seeds])
stdevs = self.get_stdevs_from_arrays(next_seeds_as_array)
_logger.info(f"got stdevs: {stdevs.stdevs}")
all_long_chains = []
for seed_index, (c, s) in enumerate(next_seeds[::len(next_seeds) // 20]):
# chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
# until new version gotta do
_logger.debug(
f"\t{seed_index}: doing long chain on the next seed"
)
long_chain = self.model.get_mcmc_chain(
s,
self.cost_function_to_use,
1000,
threshold_cost,
stdevs,
initial_cost=c,
rng_arg=mcmc_rng,
)
for _, chained in long_chain:
all_long_chains.append(chained)
all_long_chains_array = numpy.array(all_long_chains)
for n in range(self.model.n):
_logger.info(f"{all_long_chains_array[:, n].shape}")
numpy.savetxt(
f"long_chain_generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv",
all_long_chains_array[:, n],
delimiter=",",
)
if self.keep_probs_list: if self.keep_probs_list:
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]): for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
probs_list.append( probs_list.append(