feat: adds longchain logging if logging last generation
Some checks failed
gitea-physics/deepdog/pipeline/head There was a failure building this commit
Some checks failed
gitea-physics/deepdog/pipeline/head There was a failure building this commit
This commit is contained in:
parent
f7559b2c4f
commit
b4e5f53726
@ -143,6 +143,39 @@ class SubsetSimulation:
|
|||||||
delimiter=",",
|
delimiter=",",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
next_seeds_as_array = numpy.array([s for _, s in next_seeds])
|
||||||
|
stdevs = self.get_stdevs_from_arrays(next_seeds_as_array)
|
||||||
|
_logger.info(f"got stdevs: {stdevs.stdevs}")
|
||||||
|
all_long_chains = []
|
||||||
|
for seed_index, (c, s) in enumerate(next_seeds[::len(next_seeds) // 20]):
|
||||||
|
# chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
|
||||||
|
# until new version gotta do
|
||||||
|
_logger.debug(
|
||||||
|
f"\t{seed_index}: doing long chain on the next seed"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
long_chain = self.model.get_mcmc_chain(
|
||||||
|
s,
|
||||||
|
self.cost_function_to_use,
|
||||||
|
1000,
|
||||||
|
threshold_cost,
|
||||||
|
stdevs,
|
||||||
|
initial_cost=c,
|
||||||
|
rng_arg=mcmc_rng,
|
||||||
|
)
|
||||||
|
for _, chained in long_chain:
|
||||||
|
all_long_chains.append(chained)
|
||||||
|
all_long_chains_array = numpy.array(all_long_chains)
|
||||||
|
for n in range(self.model.n):
|
||||||
|
_logger.info(f"{all_long_chains_array[:, n].shape}")
|
||||||
|
numpy.savetxt(
|
||||||
|
f"long_chain_generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv",
|
||||||
|
all_long_chains_array[:, n],
|
||||||
|
delimiter=",",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if self.keep_probs_list:
|
if self.keep_probs_list:
|
||||||
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
||||||
probs_list.append(
|
probs_list.append(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user