Compare commits

...

2 Commits

Author SHA1 Message Date
c05953bb20 chore(deps): update dependency mypy to v1
Some checks failed
renovate/artifacts Artifact file update failure
gitea-physics/deepdog/pipeline/pr-master There was a failure building this commit
2023-10-11 01:30:56 +00:00
b4e5f53726
feat: adds longchain logging if logging last generation
Some checks failed
gitea-physics/deepdog/pipeline/head There was a failure building this commit
2023-08-12 19:48:30 -05:00
2 changed files with 34 additions and 1 deletions

View File

@ -143,6 +143,39 @@ class SubsetSimulation:
delimiter=",",
)
next_seeds_as_array = numpy.array([s for _, s in next_seeds])
stdevs = self.get_stdevs_from_arrays(next_seeds_as_array)
_logger.info(f"got stdevs: {stdevs.stdevs}")
all_long_chains = []
for seed_index, (c, s) in enumerate(next_seeds[::len(next_seeds) // 20]):
# chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
# until new version gotta do
_logger.debug(
f"\t{seed_index}: doing long chain on the next seed"
)
long_chain = self.model.get_mcmc_chain(
s,
self.cost_function_to_use,
1000,
threshold_cost,
stdevs,
initial_cost=c,
rng_arg=mcmc_rng,
)
for _, chained in long_chain:
all_long_chains.append(chained)
all_long_chains_array = numpy.array(all_long_chains)
for n in range(self.model.n):
_logger.info(f"{all_long_chains_array[:, n].shape}")
numpy.savetxt(
f"long_chain_generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv",
all_long_chains_array[:, n],
delimiter=",",
)
if self.keep_probs_list:
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
probs_list.append(

View File

@ -14,7 +14,7 @@ scipy = "1.10"
pytest = ">=6"
flake8 = "^4.0.1"
pytest-cov = "^4.1.0"
mypy = "^0.971"
mypy = "^1.6"
python-semantic-release = "^7.24.0"
black = "^22.3.0"
syrupy = "^4.0.8"