CUQIpy 1.1.1.post0.dev36__tar.gz → 1.1.1.post0.dev57__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic; see the release's page on the registry for more details.

Files changed (118)
  1. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/CUQIpy.egg-info/PKG-INFO +2 -1
  2. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/CUQIpy.egg-info/requires.txt +1 -0
  3. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/PKG-INFO +2 -1
  4. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/_version.py +3 -3
  5. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_gibbs.py +78 -18
  6. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/requirements.txt +1 -0
  7. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/CUQIpy.egg-info/SOURCES.txt +0 -0
  8. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/CUQIpy.egg-info/dependency_links.txt +0 -0
  9. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/CUQIpy.egg-info/top_level.txt +0 -0
  10. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/LICENSE +0 -0
  11. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/README.md +0 -0
  12. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/__init__.py +0 -0
  13. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/_messages.py +0 -0
  14. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/array/__init__.py +0 -0
  15. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/array/_array.py +0 -0
  16. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/config.py +0 -0
  17. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/data/__init__.py +0 -0
  18. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/data/_data.py +0 -0
  19. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/data/astronaut.npz +0 -0
  20. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/data/camera.npz +0 -0
  21. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/data/cat.npz +0 -0
  22. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/data/cookie.png +0 -0
  23. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/data/satellite.mat +0 -0
  24. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/density/__init__.py +0 -0
  25. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/density/_density.py +0 -0
  26. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/diagnostics.py +0 -0
  27. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/__init__.py +0 -0
  28. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_beta.py +0 -0
  29. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_cauchy.py +0 -0
  30. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_cmrf.py +0 -0
  31. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_custom.py +0 -0
  32. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_distribution.py +0 -0
  33. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_gamma.py +0 -0
  34. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_gaussian.py +0 -0
  35. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_gmrf.py +0 -0
  36. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_inverse_gamma.py +0 -0
  37. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_joint_distribution.py +0 -0
  38. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_laplace.py +0 -0
  39. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_lmrf.py +0 -0
  40. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_lognormal.py +0 -0
  41. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
  42. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_normal.py +0 -0
  43. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_posterior.py +0 -0
  44. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_smoothed_laplace.py +0 -0
  45. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/distribution/_uniform.py +0 -0
  46. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/__init__.py +0 -0
  47. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/__init__.py +0 -0
  48. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_conjugate.py +0 -0
  49. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
  50. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_cwmh.py +0 -0
  51. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_direct.py +0 -0
  52. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_hmc.py +0 -0
  53. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
  54. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
  55. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_mh.py +0 -0
  56. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_pcn.py +0 -0
  57. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_rto.py +0 -0
  58. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_sampler.py +0 -0
  59. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/experimental/mcmc/_utilities.py +0 -0
  60. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/geometry/__init__.py +0 -0
  61. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/geometry/_geometry.py +0 -0
  62. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/implicitprior/__init__.py +0 -0
  63. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  64. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  65. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
  66. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/likelihood/__init__.py +0 -0
  67. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/likelihood/_likelihood.py +0 -0
  68. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/model/__init__.py +0 -0
  69. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/model/_model.py +0 -0
  70. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/operator/__init__.py +0 -0
  71. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/operator/_operator.py +0 -0
  72. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/pde/__init__.py +0 -0
  73. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/pde/_pde.py +0 -0
  74. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/problem/__init__.py +0 -0
  75. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/problem/_problem.py +0 -0
  76. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/__init__.py +0 -0
  77. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_conjugate.py +0 -0
  78. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_conjugate_approx.py +0 -0
  79. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_cwmh.py +0 -0
  80. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_gibbs.py +0 -0
  81. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_hmc.py +0 -0
  82. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_langevin_algorithm.py +0 -0
  83. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_laplace_approximation.py +0 -0
  84. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_mh.py +0 -0
  85. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_pcn.py +0 -0
  86. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_rto.py +0 -0
  87. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/sampler/_sampler.py +0 -0
  88. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/samples/__init__.py +0 -0
  89. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/samples/_samples.py +0 -0
  90. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/solver/__init__.py +0 -0
  91. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/solver/_solver.py +0 -0
  92. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/testproblem/__init__.py +0 -0
  93. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/testproblem/_testproblem.py +0 -0
  94. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/utilities/__init__.py +0 -0
  95. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/utilities/_get_python_variable_name.py +0 -0
  96. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/cuqi/utilities/_utilities.py +0 -0
  97. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/pyproject.toml +0 -0
  98. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/setup.cfg +0 -0
  99. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/setup.py +0 -0
  100. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_MRFs.py +0 -0
  101. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_abstract_distribution_density.py +0 -0
  102. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_bayesian_inversion.py +0 -0
  103. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_density.py +0 -0
  104. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_distribution.py +0 -0
  105. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_distributions_shape.py +0 -0
  106. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_geometry.py +0 -0
  107. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_implicit_priors.py +0 -0
  108. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_joint_distribution.py +0 -0
  109. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_likelihood.py +0 -0
  110. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_model.py +0 -0
  111. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_pde.py +0 -0
  112. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_posterior.py +0 -0
  113. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_problem.py +0 -0
  114. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_sampler.py +0 -0
  115. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_samples.py +0 -0
  116. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_solver.py +0 -0
  117. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_testproblem.py +0 -0
  118. {cuqipy-1.1.1.post0.dev36 → cuqipy-1.1.1.post0.dev57}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.1.1.post0.dev36
3
+ Version: 1.1.1.post0.dev57
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -204,6 +204,7 @@ Requires-Dist: matplotlib
204
204
  Requires-Dist: numpy>=1.17.0
205
205
  Requires-Dist: scipy<1.13
206
206
  Requires-Dist: arviz
207
+ Requires-Dist: tqdm
207
208
 
208
209
  <div align="center">
209
210
  <img src="https://cuqi-dtu.github.io/CUQIpy/_static/logo.png" alt="CUQIpy logo" width="250"/>
@@ -2,3 +2,4 @@ matplotlib
2
2
  numpy>=1.17.0
3
3
  scipy<1.13
4
4
  arviz
5
+ tqdm
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.1.1.post0.dev36
3
+ Version: 1.1.1.post0.dev57
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -204,6 +204,7 @@ Requires-Dist: matplotlib
204
204
  Requires-Dist: numpy>=1.17.0
205
205
  Requires-Dist: scipy<1.13
206
206
  Requires-Dist: arviz
207
+ Requires-Dist: tqdm
207
208
 
208
209
  <div align="center">
209
210
  <img src="https://cuqi-dtu.github.io/CUQIpy/_static/logo.png" alt="CUQIpy logo" width="250"/>
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-09-04T13:31:10+0200",
11
+ "date": "2024-09-10T22:51:38+0200",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "7852699550f6a8cb8bd6ad61a579e2eecb7cc964",
15
- "version": "1.1.1.post0.dev36"
14
+ "full-revisionid": "528481d0e961831e7e64f9dd6e48ddced1e10ae8",
15
+ "version": "1.1.1.post0.dev57"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -1,6 +1,7 @@
1
1
  from cuqi.distribution import JointDistribution
2
2
  from cuqi.experimental.mcmc import Sampler
3
3
  from cuqi.samples import Samples
4
+ from cuqi.experimental.mcmc import NUTS
4
5
  from typing import Dict
5
6
  import numpy as np
6
7
  import warnings
@@ -151,18 +152,50 @@ class HybridGibbs:
151
152
  sampler.validate_target()
152
153
 
153
154
  def sample(self, Ns) -> 'HybridGibbs':
154
- """ Sample from the joint distribution using Gibbs sampling """
155
+ """ Sample from the joint distribution using Gibbs sampling
156
+
157
+ Parameters
158
+ ----------
159
+ Ns : int
160
+ The number of samples to draw.
161
+
162
+ """
163
+
155
164
  for _ in tqdm(range(Ns)):
165
+
156
166
  self.step()
167
+
157
168
  self._store_samples()
158
169
 
159
- def warmup(self, Nb) -> 'HybridGibbs':
160
- """ Warmup (tune) the Gibbs sampler """
170
+ return self
171
+
172
+ def warmup(self, Nb, tune_freq=0.1) -> 'HybridGibbs':
173
+ """ Warmup (tune) the samplers in the Gibbs sampling scheme
174
+
175
+ Parameters
176
+ ----------
177
+ Nb : int
178
+ The number of samples to draw during warmup.
179
+
180
+ tune_freq : float, optional
181
+ Frequency of tuning the samplers. Tuning is performed every tune_freq*Nb steps.
182
+
183
+ """
184
+
185
+ tune_interval = max(int(tune_freq * Nb), 1)
186
+
161
187
  for idx in tqdm(range(Nb)):
188
+
162
189
  self.step()
163
- self.tune(idx)
190
+
191
+ # Tune the sampler at tuning intervals (matching behavior of Sampler class)
192
+ if (idx + 1) % tune_interval == 0:
193
+ self.tune(tune_interval, idx // tune_interval)
194
+
164
195
  self._store_samples()
165
196
 
197
+ return self
198
+
166
199
  def get_samples(self) -> Dict[str, Samples]:
167
200
  samples_object = {}
168
201
  for par_name in self.par_names:
@@ -182,38 +215,65 @@ class HybridGibbs:
182
215
  # Get sampler
183
216
  sampler = self.samplers[par_name]
184
217
 
185
- # Set initial parameters using current point and scale (subset of state)
186
- # This does not store the full state from e.g. NUTS sampler
187
- # But works on samplers like MH, PCN, ULA, MALA, LinearRTO, UGLA, CWMH
188
- # that only use initial_point and initial_scale
189
- sampler.initial_point = self.current_samples[par_name]
190
- if hasattr(sampler, 'initial_scale'): sampler.initial_scale = sampler.scale
218
+ # Instead of simply changing the target of the sampler, we reinitialize it.
219
+ # This is to ensure that all internal variables are set to match the new target.
220
+ # To return the sampler to the old state and history, we first extract the state and history
221
+ # before reinitializing the sampler and then set the state and history back to the sampler
222
+
223
+ # Extract state and history from sampler
224
+ if isinstance(sampler, NUTS): # Special case for NUTS as it is not playing nice with get_state and get_history
225
+ sampler.initial_point = sampler.current_point
226
+ else:
227
+ sampler_state = sampler.get_state()
228
+ sampler_history = sampler.get_history()
191
229
 
192
230
  # Reinitialize sampler
193
- # This makes the sampler lose all of its state.
194
- # This is only OK because we set the initial values above from the previous state
195
231
  sampler.reinitialize()
196
232
 
233
+ # Set state and history back to sampler
234
+ if not isinstance(sampler, NUTS): # Again, special case for NUTS.
235
+ sampler.set_state(sampler_state)
236
+ sampler.set_history(sampler_history)
237
+
197
238
  # Run pre_warmup and pre_sample methods for sampler
198
239
  # TODO. Some samplers (NUTS) seem to require to run _pre_warmup before _pre_sample
199
240
  self._pre_warmup_and_pre_sample_sampler(sampler)
200
241
 
201
- # Take MCMC steps
242
+ # Allow for multiple sampling steps in each Gibbs step
202
243
  for _ in range(self.num_sampling_steps[par_name]):
203
- sampler.step()
244
+ # Sampling step
245
+ acc = sampler.step()
246
+
247
+ # Store acceptance rate in sampler (matching behavior of Sampler class Sample method)
248
+ sampler._acc.append(acc)
204
249
 
205
250
  # Extract samples (Ensure even 1-dimensional samples are 1D arrays)
206
- self.current_samples[par_name] = sampler.current_point.reshape(-1)
251
+ if isinstance(sampler.current_point, np.ndarray):
252
+ self.current_samples[par_name] = sampler.current_point.reshape(-1)
253
+ else:
254
+ self.current_samples[par_name] = sampler.current_point
255
+
256
+ def tune(self, skip_len, update_count):
257
+ """ Run a single tuning step on each of the samplers in the Gibbs sampling scheme
258
+
259
+ Parameters
260
+ ----------
261
+ skip_len : int
262
+ Defines the number of steps in between tuning (i.e. the tuning interval).
263
+
264
+ update_count : int
265
+ The number of times tuning has been performed. Can be used for internal bookkeeping.
207
266
 
208
- def tune(self, idx):
209
- """ Tune each of the samplers """
267
+ """
210
268
  for par_name in self.par_names:
211
- self.samplers[par_name].tune(skip_len=1, update_count=idx)
269
+ self.samplers[par_name].tune(skip_len=skip_len, update_count=update_count)
212
270
 
213
271
  # ------------ Private methods ------------
214
272
  def _initialize_samplers(self):
215
273
  """ Initialize samplers """
216
274
  for sampler in self.samplers.values():
275
+ if isinstance(sampler, NUTS):
276
+ print(f'Warning: NUTS sampler is not fully stateful in HybridGibbs. Sampler will be reinitialized in each Gibbs step.')
217
277
  sampler.initialize()
218
278
 
219
279
  def _initialize_num_sampling_steps(self):
@@ -2,3 +2,4 @@ matplotlib
2
2
  numpy>=1.17.0
3
3
  scipy<1.13
4
4
  arviz
5
+ tqdm