Compare commits
2 Commits
c36a836e74
...
c05953bb20
Author | SHA1 | Date | |
---|---|---|---|
c05953bb20 | |||
b4e5f53726 |
@ -143,6 +143,39 @@ class SubsetSimulation:
|
|||||||
delimiter=",",
|
delimiter=",",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
next_seeds_as_array = numpy.array([s for _, s in next_seeds])
|
||||||
|
stdevs = self.get_stdevs_from_arrays(next_seeds_as_array)
|
||||||
|
_logger.info(f"got stdevs: {stdevs.stdevs}")
|
||||||
|
all_long_chains = []
|
||||||
|
for seed_index, (c, s) in enumerate(next_seeds[::len(next_seeds) // 20]):
|
||||||
|
# chain = mcmc(s, threshold_cost, n_s, model, dot_inputs_array, actual_measurement_array, mcmc_rng, curr_cost=c, stdevs=stdevs)
|
||||||
|
# until new version gotta do
|
||||||
|
_logger.debug(
|
||||||
|
f"\t{seed_index}: doing long chain on the next seed"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
long_chain = self.model.get_mcmc_chain(
|
||||||
|
s,
|
||||||
|
self.cost_function_to_use,
|
||||||
|
1000,
|
||||||
|
threshold_cost,
|
||||||
|
stdevs,
|
||||||
|
initial_cost=c,
|
||||||
|
rng_arg=mcmc_rng,
|
||||||
|
)
|
||||||
|
for _, chained in long_chain:
|
||||||
|
all_long_chains.append(chained)
|
||||||
|
all_long_chains_array = numpy.array(all_long_chains)
|
||||||
|
for n in range(self.model.n):
|
||||||
|
_logger.info(f"{all_long_chains_array[:, n].shape}")
|
||||||
|
numpy.savetxt(
|
||||||
|
f"long_chain_generation_{self.n_c}_{self.n_s}_{i}_dipole_{n}.csv",
|
||||||
|
all_long_chains_array[:, n],
|
||||||
|
delimiter=",",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if self.keep_probs_list:
|
if self.keep_probs_list:
|
||||||
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
for cost_index, cost_chain in enumerate(all_chains[: -self.n_c]):
|
||||||
probs_list.append(
|
probs_list.append(
|
||||||
|
@ -14,7 +14,7 @@ scipy = "1.10"
|
|||||||
pytest = ">=6"
|
pytest = ">=6"
|
||||||
flake8 = "^4.0.1"
|
flake8 = "^4.0.1"
|
||||||
pytest-cov = "^4.1.0"
|
pytest-cov = "^4.1.0"
|
||||||
mypy = "^0.971"
|
mypy = "^1.6"
|
||||||
python-semantic-release = "^7.24.0"
|
python-semantic-release = "^7.24.0"
|
||||||
black = "^22.3.0"
|
black = "^22.3.0"
|
||||||
syrupy = "^4.0.8"
|
syrupy = "^4.0.8"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user