CUQIpy 1.3.0.post0.dev371__tar.gz → 1.3.0.post0.dev395__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

Files changed (126)
  1. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/CUQIpy.egg-info/PKG-INFO +1 -1
  2. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/PKG-INFO +1 -1
  3. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/_version.py +3 -3
  4. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_posterior.py +9 -0
  5. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_gibbs.py +11 -19
  6. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_hmc.py +3 -1
  7. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/implicitprior/__init__.py +1 -1
  8. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/implicitprior/_restorator.py +35 -1
  9. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/CUQIpy.egg-info/SOURCES.txt +0 -0
  10. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/CUQIpy.egg-info/dependency_links.txt +0 -0
  11. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/CUQIpy.egg-info/requires.txt +0 -0
  12. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/CUQIpy.egg-info/top_level.txt +0 -0
  13. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/LICENSE +0 -0
  14. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/README.md +0 -0
  15. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/__init__.py +0 -0
  16. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/_messages.py +0 -0
  17. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/array/__init__.py +0 -0
  18. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/array/_array.py +0 -0
  19. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/config.py +0 -0
  20. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/data/__init__.py +0 -0
  21. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/data/_data.py +0 -0
  22. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/data/astronaut.npz +0 -0
  23. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/data/camera.npz +0 -0
  24. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/data/cat.npz +0 -0
  25. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/data/cookie.png +0 -0
  26. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/data/satellite.mat +0 -0
  27. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/density/__init__.py +0 -0
  28. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/density/_density.py +0 -0
  29. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/diagnostics.py +0 -0
  30. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/__init__.py +0 -0
  31. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_beta.py +0 -0
  32. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_cauchy.py +0 -0
  33. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_cmrf.py +0 -0
  34. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_custom.py +0 -0
  35. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_distribution.py +0 -0
  36. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_gamma.py +0 -0
  37. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_gaussian.py +0 -0
  38. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_gmrf.py +0 -0
  39. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_inverse_gamma.py +0 -0
  40. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_joint_distribution.py +0 -0
  41. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_laplace.py +0 -0
  42. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_lmrf.py +0 -0
  43. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_lognormal.py +0 -0
  44. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
  45. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_normal.py +0 -0
  46. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_smoothed_laplace.py +0 -0
  47. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_truncated_normal.py +0 -0
  48. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/distribution/_uniform.py +0 -0
  49. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/__init__.py +0 -0
  50. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/_recommender.py +0 -0
  51. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/algebra/__init__.py +0 -0
  52. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/algebra/_ast.py +0 -0
  53. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/algebra/_orderedset.py +0 -0
  54. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/algebra/_randomvariable.py +0 -0
  55. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/geometry/__init__.py +0 -0
  56. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/geometry/_productgeometry.py +0 -0
  57. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/__init__.py +0 -0
  58. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_conjugate.py +0 -0
  59. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
  60. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_cwmh.py +0 -0
  61. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_direct.py +0 -0
  62. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
  63. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
  64. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_mh.py +0 -0
  65. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_pcn.py +0 -0
  66. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_rto.py +0 -0
  67. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/experimental/mcmc/_sampler.py +0 -0
  68. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/geometry/__init__.py +0 -0
  69. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/geometry/_geometry.py +0 -0
  70. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  71. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  72. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
  73. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/likelihood/__init__.py +0 -0
  74. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/likelihood/_likelihood.py +0 -0
  75. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/model/__init__.py +0 -0
  76. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/model/_model.py +0 -0
  77. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/operator/__init__.py +0 -0
  78. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/operator/_operator.py +0 -0
  79. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/pde/__init__.py +0 -0
  80. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/pde/_pde.py +0 -0
  81. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/problem/__init__.py +0 -0
  82. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/problem/_problem.py +0 -0
  83. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/__init__.py +0 -0
  84. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_conjugate.py +0 -0
  85. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_conjugate_approx.py +0 -0
  86. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_cwmh.py +0 -0
  87. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_gibbs.py +0 -0
  88. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_hmc.py +0 -0
  89. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_langevin_algorithm.py +0 -0
  90. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_laplace_approximation.py +0 -0
  91. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_mh.py +0 -0
  92. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_pcn.py +0 -0
  93. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_rto.py +0 -0
  94. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/sampler/_sampler.py +0 -0
  95. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/samples/__init__.py +0 -0
  96. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/samples/_samples.py +0 -0
  97. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/solver/__init__.py +0 -0
  98. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/solver/_solver.py +0 -0
  99. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/testproblem/__init__.py +0 -0
  100. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/testproblem/_testproblem.py +0 -0
  101. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/utilities/__init__.py +0 -0
  102. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/utilities/_get_python_variable_name.py +0 -0
  103. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/cuqi/utilities/_utilities.py +0 -0
  104. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/pyproject.toml +0 -0
  105. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/requirements.txt +0 -0
  106. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/setup.cfg +0 -0
  107. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/setup.py +0 -0
  108. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_MRFs.py +0 -0
  109. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_abstract_distribution_density.py +0 -0
  110. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_bayesian_inversion.py +0 -0
  111. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_density.py +0 -0
  112. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_distribution.py +0 -0
  113. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_distributions_shape.py +0 -0
  114. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_geometry.py +0 -0
  115. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_implicit_priors.py +0 -0
  116. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_joint_distribution.py +0 -0
  117. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_likelihood.py +0 -0
  118. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_model.py +0 -0
  119. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_pde.py +0 -0
  120. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_posterior.py +0 -0
  121. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_problem.py +0 -0
  122. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_sampler.py +0 -0
  123. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_samples.py +0 -0
  124. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_solver.py +0 -0
  125. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_testproblem.py +0 -0
  126. {cuqipy-1.3.0.post0.dev371 → cuqipy-1.3.0.post0.dev395}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: CUQIpy
3
- Version: 1.3.0.post0.dev371
3
+ Version: 1.3.0.post0.dev395
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: CUQIpy
3
- Version: 1.3.0.post0.dev371
3
+ Version: 1.3.0.post0.dev395
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-09-12T10:46:36+0200",
11
+ "date": "2025-09-19T16:37:46+0300",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "1bd165981a9e7200ec642e2845ffde05a7e8a168",
15
- "version": "1.3.0.post0.dev371"
14
+ "full-revisionid": "2cf72ec9af9af17dad4bb3870ee20d303376de24",
15
+ "version": "1.3.0.post0.dev395"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -1,5 +1,6 @@
1
1
  from cuqi.geometry import _DefaultGeometry, _get_identity_geometries
2
2
  from cuqi.distribution import Distribution
3
+ from cuqi.density import Density
3
4
 
4
5
  # ========================================================================
5
6
  class Posterior(Distribution):
@@ -25,6 +26,14 @@ class Posterior(Distribution):
25
26
  self.prior = prior
26
27
  super().__init__(**kwargs)
27
28
 
29
+ def get_density(self, name) -> Density:
30
+ """ Return a density with the given name. """
31
+ if name == self.likelihood.name:
32
+ return self.likelihood
33
+ if name == self.prior.name:
34
+ return self.prior
35
+ raise ValueError(f"No density with name {name}.")
36
+
28
37
  @property
29
38
  def data(self):
30
39
  return self.likelihood.data
@@ -1,7 +1,6 @@
1
- from cuqi.distribution import JointDistribution
1
+ from cuqi.distribution import JointDistribution, Posterior
2
2
  from cuqi.experimental.mcmc import Sampler
3
3
  from cuqi.samples import Samples, JointSamples
4
- from cuqi.experimental.mcmc import NUTS
5
4
  from typing import Dict
6
5
  import numpy as np
7
6
  import warnings
@@ -36,11 +35,10 @@ class HybridGibbs:
36
35
  Gelman et al. "Bayesian Data Analysis" (2014), Third Edition
37
36
  for more details.
38
37
 
39
- In each Gibbs step, the corresponding sampler has the initial_point
40
- and initial_scale (if applicable) set to the value of the previous step
41
- and the sampler is reinitialized. This means that the sampling is not
42
- fully stateful at this point. This means samplers like NUTS will lose
43
- their internal state between Gibbs steps.
38
+ In each Gibbs step, the corresponding sampler state and history are stored,
39
+ then the sampler is reinitialized. After reinitialization, the sampler state
40
+ and history are set back to the stored values. This ensures preserving the
41
+ statefulness of the samplers.
44
42
 
45
43
  The order in which the conditionals are sampled is the order of the
46
44
  variables in the sampling strategy, unless a different sampling order
@@ -177,8 +175,8 @@ class HybridGibbs:
177
175
  # ------------ Public methods ------------
178
176
  def validate_targets(self):
179
177
  """ Validate each of the conditional targets used in the Gibbs steps """
180
- if not isinstance(self.target, JointDistribution):
181
- raise ValueError('Target distribution must be a JointDistribution.')
178
+ if not isinstance(self.target, (JointDistribution, Posterior)):
179
+ raise ValueError('Target distribution must be a JointDistribution or Posterior.')
182
180
  for sampler in self.samplers.values():
183
181
  sampler.validate_target()
184
182
 
@@ -257,19 +255,15 @@ class HybridGibbs:
257
255
  # before reinitializing the sampler and then set the state and history back to the sampler
258
256
 
259
257
  # Extract state and history from sampler
260
- if isinstance(sampler, NUTS): # Special case for NUTS as it is not playing nice with get_state and get_history
261
- sampler.initial_point = sampler.current_point
262
- else:
263
- sampler_state = sampler.get_state()
264
- sampler_history = sampler.get_history()
258
+ sampler_state = sampler.get_state()
259
+ sampler_history = sampler.get_history()
265
260
 
266
261
  # Reinitialize sampler
267
262
  sampler.reinitialize()
268
263
 
269
264
  # Set state and history back to sampler
270
- if not isinstance(sampler, NUTS): # Again, special case for NUTS.
271
- sampler.set_state(sampler_state)
272
- sampler.set_history(sampler_history)
265
+ sampler.set_state(sampler_state)
266
+ sampler.set_history(sampler_history)
273
267
 
274
268
  # Allow for multiple sampling steps in each Gibbs step
275
269
  for _ in range(self.num_sampling_steps[par_name]):
@@ -309,8 +303,6 @@ class HybridGibbs:
309
303
  def _initialize_samplers(self):
310
304
  """ Initialize samplers """
311
305
  for sampler in self.samplers.values():
312
- if isinstance(sampler, NUTS):
313
- print(f'Warning: NUTS sampler is not fully stateful in HybridGibbs. Sampler will be reinitialized in each Gibbs step.')
314
306
  sampler.initialize()
315
307
 
316
308
  def _initialize_num_sampling_steps(self):
@@ -118,8 +118,10 @@ class NUTS(Sampler):
118
118
  # to epsilon_bar for the remaining sampling steps.
119
119
  if self.step_size is None:
120
120
  self._epsilon = self._FindGoodEpsilon()
121
+ self.step_size = self._epsilon
121
122
  else:
122
123
  self._epsilon = self.step_size
124
+
123
125
  self._epsilon_bar = "unset"
124
126
 
125
127
  # Parameter mu, does not change during the run
@@ -127,7 +129,7 @@ class NUTS(Sampler):
127
129
 
128
130
  self._H_bar = 0
129
131
 
130
- # NUTS run diagnostic:
132
+ # NUTS run diagnostics
131
133
  # number of tree nodes created each NUTS iteration
132
134
  self._num_tree_node = 0
133
135
 
@@ -1,5 +1,5 @@
1
1
  from ._regularizedGaussian import RegularizedGaussian, ConstrainedGaussian, NonnegativeGaussian
2
2
  from ._regularizedGMRF import RegularizedGMRF, ConstrainedGMRF, NonnegativeGMRF
3
3
  from ._regularizedUnboundedUniform import RegularizedUnboundedUniform
4
- from ._restorator import RestorationPrior, MoreauYoshidaPrior
4
+ from ._restorator import RestorationPrior, MoreauYoshidaPrior, TweediePrior
5
5
 
@@ -232,4 +232,38 @@ class MoreauYoshidaPrior(Distribution):
232
232
  """ Returns the conditioning variables of the distribution. """
233
233
  # Currently conditioning variables are not supported for user-defined
234
234
  # distributions.
235
- return []
235
+ return []
236
+
237
+ class TweediePrior(MoreauYoshidaPrior):
238
+ """
239
+ Alias for MoreauYoshidaPrior following Tweedie's formula framework. TweediePrior
240
+ defines priors where gradients are computed based on Tweedie's identity that links
241
+ MMSE (Minimum Mean Square Error) denoisers with the underlying smoothed prior, see:
242
+ - Laumont et al. https://arxiv.org/abs/2103.04715 or https://doi.org/10.1137/21M1406349
243
+
244
+ Tweedie's Formula
245
+ -------------------------
246
+ In the context of denoising, Tweedie's identity states that for a signal x
247
+ corrupted by Gaussian noise:
248
+
249
+ ∇_x log p_e(x) = (D_e(x) - x) / e
250
+
251
+ where D_e(x) is the MMSE denoiser output and e is the noise variance.
252
+ This enables us to perform gradient-based sampling with algorithms like ULA.
253
+
254
+ At implementation level, TweediePrior shares identical functionality with MoreauYoshidaPrior.
255
+ Thus, it is implemented as an alias of MoreauYoshidaPrior, meaning all methods,
256
+ properties, and behavior are identical. The separate name provides clarity when
257
+ working specifically with Tweedie's formula-based approaches.
258
+
259
+ Parameters
260
+ ----------
261
+ prior : RestorationPrior
262
+ Prior of the RestorationPrior type containing a denoiser/restorator.
263
+
264
+ smoothing_strength : float, default=0.1
265
+ Corresponds to the noise variance e in Tweedie's formula context.
266
+
267
+ See MoreauYoshidaPrior for the underlying implementation with complete documentation.
268
+ """
269
+ pass