CUQIpy 1.3.0.post0.dev362.tar.gz → 1.3.0.post0.dev383.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic; see the package's release page on the registry for more details.

Files changed (126):
  1. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/CUQIpy.egg-info/PKG-INFO +1 -1
  2. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/PKG-INFO +1 -1
  3. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/_version.py +3 -3
  4. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_rto.py +53 -9
  5. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_sampler.py +10 -0
  6. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/implicitprior/__init__.py +1 -1
  7. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/implicitprior/_restorator.py +35 -1
  8. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/CUQIpy.egg-info/SOURCES.txt +0 -0
  9. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/CUQIpy.egg-info/dependency_links.txt +0 -0
  10. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/CUQIpy.egg-info/requires.txt +0 -0
  11. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/CUQIpy.egg-info/top_level.txt +0 -0
  12. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/LICENSE +0 -0
  13. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/README.md +0 -0
  14. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/__init__.py +0 -0
  15. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/_messages.py +0 -0
  16. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/array/__init__.py +0 -0
  17. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/array/_array.py +0 -0
  18. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/config.py +0 -0
  19. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/data/__init__.py +0 -0
  20. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/data/_data.py +0 -0
  21. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/data/astronaut.npz +0 -0
  22. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/data/camera.npz +0 -0
  23. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/data/cat.npz +0 -0
  24. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/data/cookie.png +0 -0
  25. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/data/satellite.mat +0 -0
  26. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/density/__init__.py +0 -0
  27. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/density/_density.py +0 -0
  28. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/diagnostics.py +0 -0
  29. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/__init__.py +0 -0
  30. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_beta.py +0 -0
  31. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_cauchy.py +0 -0
  32. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_cmrf.py +0 -0
  33. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_custom.py +0 -0
  34. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_distribution.py +0 -0
  35. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_gamma.py +0 -0
  36. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_gaussian.py +0 -0
  37. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_gmrf.py +0 -0
  38. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_inverse_gamma.py +0 -0
  39. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_joint_distribution.py +0 -0
  40. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_laplace.py +0 -0
  41. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_lmrf.py +0 -0
  42. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_lognormal.py +0 -0
  43. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
  44. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_normal.py +0 -0
  45. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_posterior.py +0 -0
  46. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_smoothed_laplace.py +0 -0
  47. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_truncated_normal.py +0 -0
  48. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/distribution/_uniform.py +0 -0
  49. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/__init__.py +0 -0
  50. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/_recommender.py +0 -0
  51. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/algebra/__init__.py +0 -0
  52. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/algebra/_ast.py +0 -0
  53. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/algebra/_orderedset.py +0 -0
  54. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/algebra/_randomvariable.py +0 -0
  55. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/geometry/__init__.py +0 -0
  56. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/geometry/_productgeometry.py +0 -0
  57. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/__init__.py +0 -0
  58. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_conjugate.py +0 -0
  59. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
  60. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_cwmh.py +0 -0
  61. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_direct.py +0 -0
  62. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_gibbs.py +0 -0
  63. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_hmc.py +0 -0
  64. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
  65. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
  66. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_mh.py +0 -0
  67. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/experimental/mcmc/_pcn.py +0 -0
  68. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/geometry/__init__.py +0 -0
  69. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/geometry/_geometry.py +0 -0
  70. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  71. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  72. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
  73. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/likelihood/__init__.py +0 -0
  74. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/likelihood/_likelihood.py +0 -0
  75. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/model/__init__.py +0 -0
  76. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/model/_model.py +0 -0
  77. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/operator/__init__.py +0 -0
  78. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/operator/_operator.py +0 -0
  79. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/pde/__init__.py +0 -0
  80. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/pde/_pde.py +0 -0
  81. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/problem/__init__.py +0 -0
  82. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/problem/_problem.py +0 -0
  83. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/__init__.py +0 -0
  84. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_conjugate.py +0 -0
  85. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_conjugate_approx.py +0 -0
  86. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_cwmh.py +0 -0
  87. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_gibbs.py +0 -0
  88. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_hmc.py +0 -0
  89. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_langevin_algorithm.py +0 -0
  90. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_laplace_approximation.py +0 -0
  91. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_mh.py +0 -0
  92. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_pcn.py +0 -0
  93. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_rto.py +0 -0
  94. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/sampler/_sampler.py +0 -0
  95. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/samples/__init__.py +0 -0
  96. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/samples/_samples.py +0 -0
  97. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/solver/__init__.py +0 -0
  98. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/solver/_solver.py +0 -0
  99. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/testproblem/__init__.py +0 -0
  100. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/testproblem/_testproblem.py +0 -0
  101. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/utilities/__init__.py +0 -0
  102. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/utilities/_get_python_variable_name.py +0 -0
  103. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/cuqi/utilities/_utilities.py +0 -0
  104. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/pyproject.toml +0 -0
  105. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/requirements.txt +0 -0
  106. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/setup.cfg +0 -0
  107. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/setup.py +0 -0
  108. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_MRFs.py +0 -0
  109. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_abstract_distribution_density.py +0 -0
  110. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_bayesian_inversion.py +0 -0
  111. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_density.py +0 -0
  112. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_distribution.py +0 -0
  113. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_distributions_shape.py +0 -0
  114. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_geometry.py +0 -0
  115. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_implicit_priors.py +0 -0
  116. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_joint_distribution.py +0 -0
  117. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_likelihood.py +0 -0
  118. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_model.py +0 -0
  119. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_pde.py +0 -0
  120. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_posterior.py +0 -0
  121. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_problem.py +0 -0
  122. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_sampler.py +0 -0
  123. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_samples.py +0 -0
  124. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_solver.py +0 -0
  125. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_testproblem.py +0 -0
  126. {cuqipy-1.3.0.post0.dev362 → cuqipy-1.3.0.post0.dev383}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: CUQIpy
3
- Version: 1.3.0.post0.dev362
3
+ Version: 1.3.0.post0.dev383
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: CUQIpy
3
- Version: 1.3.0.post0.dev362
3
+ Version: 1.3.0.post0.dev383
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-09-12T10:23:05+0300",
11
+ "date": "2025-09-12T11:00:39+0200",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "f6a73b0b32186614fe5451781567c1abcd48452a",
15
- "version": "1.3.0.post0.dev362"
14
+ "full-revisionid": "37e1d4431766233eccce97c09eb773486ed25032",
15
+ "version": "1.3.0.post0.dev383"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -36,21 +36,48 @@ class LinearRTO(Sampler):
36
36
  tol : float
37
37
  Tolerance of the inner CGLS solver. *Optional*.
38
38
 
39
+ inner_initial_point : string or np.ndarray or cuqi.array.CUQIArray
40
+ Initial point for the inner optimization problem. Can be "previous_sample" (default), "MAP", or a specific numpy or cuqi array. *Optional*.
41
+
39
42
  callback : callable, optional
40
43
  A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
41
44
  The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
42
45
 
43
46
  """
44
- def __init__(self, target=None, initial_point=None, maxit=10, tol=1e-6, **kwargs):
47
+ def __init__(self, target=None, initial_point=None, maxit=10, tol=1e-6, inner_initial_point="previous_sample", **kwargs):
45
48
 
46
49
  super().__init__(target=target, initial_point=initial_point, **kwargs)
47
50
 
48
51
  # Other parameters
49
52
  self.maxit = maxit
50
53
  self.tol = tol
54
+ self.inner_initial_point = inner_initial_point
51
55
 
52
56
  def _initialize(self):
53
57
  self._precompute()
58
+ self._compute_map()
59
+
60
+ @property
61
+ def inner_initial_point(self):
62
+ if isinstance(self._inner_initial_point, str):
63
+ if self._inner_initial_point == "previous_sample":
64
+ return self.current_point
65
+ elif self._inner_initial_point == "map":
66
+ return self._map
67
+ else:
68
+ return self._inner_initial_point
69
+
70
+ @inner_initial_point.setter
71
+ def inner_initial_point(self, value):
72
+ is_correct_string = (isinstance(value, str) and
73
+ (value.lower() == "previous_sample" or
74
+ value.lower() == "map"))
75
+ if is_correct_string:
76
+ self._inner_initial_point = value.lower()
77
+ elif isinstance(value, (np.ndarray, cuqi.array.CUQIarray)):
78
+ self._inner_initial_point = value
79
+ else:
80
+ raise ValueError("Invalid value for inner_initial_point. Choose either 'previous_sample', 'MAP', or provide a numpy array/cuqi array.")
54
81
 
55
82
  @property
56
83
  def prior(self):
@@ -78,6 +105,10 @@ class LinearRTO(Sampler):
78
105
  elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
79
106
  return self.target.models
80
107
 
108
+ def _compute_map(self):
109
+ sim = CGLS(self.M, self.b_tild, self.current_point, self.maxit, self.tol)
110
+ self._map, _ = sim.solve()
111
+
81
112
  def _precompute(self):
82
113
  L1 = [likelihood.distribution.sqrtprec for likelihood in self.likelihoods]
83
114
  L2 = self.prior.sqrtprec
@@ -114,7 +145,7 @@ class LinearRTO(Sampler):
114
145
 
115
146
  def step(self):
116
147
  y = self.b_tild + np.random.randn(len(self.b_tild))
117
- sim = CGLS(self.M, y, self.current_point, self.maxit, self.tol)
148
+ sim = CGLS(self.M, y, self.inner_initial_point, self.maxit, self.tol)
118
149
  self.current_point, _ = sim.solve()
119
150
  acc = 1
120
151
  return acc
@@ -203,12 +234,15 @@ class RegularizedLinearRTO(LinearRTO):
203
234
  solver : string
204
235
  Options are "FISTA" (default for a single constraint or regularization), "ADMM" (default and the only option for multiple constraints or regularizations), "ScipyLinearLSQ" and "ScipyMinimizer". Note "ScipyLinearLSQ" and "ScipyMinimizer" can only be used with `RegularizedGaussian` of a single `box` or `nonnegativity` constraint. *Optional*.
205
236
 
237
+ inner_initial_point : string or np.ndarray or cuqi.array.CUQIArray
238
+ Initial point for the inner optimization problem. Can be "previous_sample" (default), "MAP", or a specific numpy or cuqi array. *Optional*.
239
+
206
240
  callback : callable, optional
207
241
  A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
208
242
  The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
209
243
 
210
244
  """
211
- def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, **kwargs):
245
+ def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, inner_initial_point="previous_sample", **kwargs):
212
246
 
213
247
  super().__init__(target=target, initial_point=initial_point, **kwargs)
214
248
 
@@ -221,6 +255,7 @@ class RegularizedLinearRTO(LinearRTO):
221
255
  self.inner_max_it = inner_max_it
222
256
  self.penalty_parameter = penalty_parameter
223
257
  self.solver = solver
258
+ self.inner_initial_point = inner_initial_point
224
259
 
225
260
  def _initialize(self):
226
261
  super()._initialize()
@@ -228,6 +263,7 @@ class RegularizedLinearRTO(LinearRTO):
228
263
  self.solver = "FISTA" if callable(self.proximal) else "ADMM"
229
264
  if self.solver == "FISTA":
230
265
  self._stepsize = self._choose_stepsize()
266
+ self._compute_map_regularized()
231
267
 
232
268
  @property
233
269
  def solver(self):
@@ -272,15 +308,16 @@ class RegularizedLinearRTO(LinearRTO):
272
308
  def prior(self):
273
309
  return self.target.prior.gaussian
274
310
 
275
- def step(self):
276
- y = self.b_tild + np.random.randn(len(self.b_tild))
311
+ def _compute_map_regularized(self):
312
+ self._map = self._customized_step(self.b_tild, self.initial_point)
277
313
 
314
+ def _customized_step(self, y, x0):
278
315
  if self.solver == "FISTA":
279
316
  sim = FISTA(self.M, y, self.proximal,
280
- self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
317
+ x0, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
281
318
  elif self.solver == "ADMM":
282
319
  sim = ADMM(self.M, y, self.proximal,
283
- self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
320
+ x0, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
284
321
  elif self.solver == "ScipyLinearLSQ":
285
322
  A_op = sp.sparse.linalg.LinearOperator((sum([llh.distribution.dim for llh in self.likelihoods])+self.target.prior.dim, self.target.prior.dim),
286
323
  matvec=lambda x: self.M(x, 1),
@@ -297,10 +334,17 @@ class RegularizedLinearRTO(LinearRTO):
297
334
  bounds = [(self.target.prior._box_bounds[0][i], self.target.prior._box_bounds[1][i]) for i in range(self.target.prior.dim)]
298
335
  # Note that the objective function is defined as 0.5*||Mx-y||^2,
299
336
  # and the corresponding gradient (gradfunc) is given by M^T(Mx-y).
300
- sim = ScipyMinimizer(lambda x: 0.5*np.sum((self.M(x, 1)-y)**2), self.current_point, gradfunc=lambda x: self.M(self.M(x, 1) - y, 2), bounds=bounds, tol=self.abstol, options={"maxiter": self.maxit})
337
+ sim = ScipyMinimizer(lambda x: 0.5*np.sum((self.M(x, 1)-y)**2), x0, gradfunc=lambda x: self.M(self.M(x, 1) - y, 2), bounds=bounds, tol=self.abstol, options={"maxiter": self.maxit})
301
338
  else:
302
339
  raise ValueError("Choice of solver not supported.")
340
+
341
+ sol, _ = sim.solve()
342
+ return sol
343
+
344
+ def step(self):
345
+ y = self.b_tild + np.random.randn(len(self.b_tild))
346
+
347
+ self.current_point = self._customized_step(y, self.inner_initial_point)
303
348
 
304
- self.current_point, _ = sim.solve()
305
349
  acc = 1
306
350
  return acc
@@ -148,6 +148,16 @@ class Sampler(ABC):
148
148
  if self._target is not None:
149
149
  self.validate_target()
150
150
 
151
+ @property
152
+ def current_point(self):
153
+ """ The current point of the sampler. """
154
+ return self._current_point
155
+
156
+ @current_point.setter
157
+ def current_point(self, value):
158
+ """ Set the current point of the sampler. """
159
+ self._current_point = value
160
+
151
161
  # ------------ Public methods ------------
152
162
  def get_samples(self) -> Samples:
153
163
  """ Return the samples. The internal data-structure for the samples is a dynamic list so this creates a copy. """
@@ -1,5 +1,5 @@
1
1
  from ._regularizedGaussian import RegularizedGaussian, ConstrainedGaussian, NonnegativeGaussian
2
2
  from ._regularizedGMRF import RegularizedGMRF, ConstrainedGMRF, NonnegativeGMRF
3
3
  from ._regularizedUnboundedUniform import RegularizedUnboundedUniform
4
- from ._restorator import RestorationPrior, MoreauYoshidaPrior
4
+ from ._restorator import RestorationPrior, MoreauYoshidaPrior, TweediePrior
5
5
 
@@ -232,4 +232,38 @@ class MoreauYoshidaPrior(Distribution):
232
232
  """ Returns the conditioning variables of the distribution. """
233
233
  # Currently conditioning variables are not supported for user-defined
234
234
  # distributions.
235
- return []
235
+ return []
236
+
237
+ class TweediePrior(MoreauYoshidaPrior):
238
+ """
239
+ Alias for MoreauYoshidaPrior following Tweedie's formula framework. TweediePrior
240
+ defines priors where gradients are computed based on Tweedie's identity that links
241
+ MMSE (Minimum Mean Square Error) denoisers with the underlying smoothed prior, see:
242
+ - Laumont et al. https://arxiv.org/abs/2103.04715 or https://doi.org/10.1137/21M1406349
243
+
244
+ Tweedie's Formula
245
+ -------------------------
246
+ In the context of denoising, Tweedie's identity states that for a signal x
247
+ corrupted by Gaussian noise:
248
+
249
+ ∇_x log p_e(x) = (D_e(x) - x) / e
250
+
251
+ where D_e(x) is the MMSE denoiser output and e is the noise variance.
252
+ This enables us to perform gradient-based sampling with algorithms like ULA.
253
+
254
+ At implementation level, TweediePrior shares identical functionality with MoreauYoshidaPrior.
255
+ Thus, it is implemented as an alias of MoreauYoshidaPrior, meaning all methods,
256
+ properties, and behavior are identical. The separate name provides clarity when
257
+ working specifically with Tweedie's formula-based approaches.
258
+
259
+ Parameters
260
+ ----------
261
+ prior : RestorationPrior
262
+ Prior of the RestorationPrior type containing a denoiser/restorator.
263
+
264
+ smoothing_strength : float, default=0.1
265
+ Corresponds to the noise variance e in Tweedie's formula context.
266
+
267
+ See MoreauYoshidaPrior for the underlying implementation with complete documentation.
268
+ """
269
+ pass