CUQIpy 1.3.0__py3-none-any.whl → 1.4.0.post0.dev61__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (72)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/__init__.py +1 -1
  5. cuqi/distribution/_beta.py +1 -1
  6. cuqi/distribution/_cauchy.py +2 -2
  7. cuqi/distribution/_distribution.py +24 -15
  8. cuqi/distribution/_joint_distribution.py +97 -12
  9. cuqi/distribution/_posterior.py +9 -0
  10. cuqi/distribution/_truncated_normal.py +3 -3
  11. cuqi/distribution/_uniform.py +36 -2
  12. cuqi/experimental/__init__.py +1 -1
  13. cuqi/experimental/_recommender.py +216 -0
  14. cuqi/experimental/geometry/_productgeometry.py +3 -3
  15. cuqi/geometry/_geometry.py +12 -1
  16. cuqi/implicitprior/__init__.py +1 -1
  17. cuqi/implicitprior/_regularizedGaussian.py +40 -4
  18. cuqi/implicitprior/_restorator.py +35 -1
  19. cuqi/legacy/__init__.py +2 -0
  20. cuqi/legacy/sampler/__init__.py +11 -0
  21. cuqi/legacy/sampler/_conjugate.py +55 -0
  22. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  23. cuqi/legacy/sampler/_cwmh.py +196 -0
  24. cuqi/legacy/sampler/_gibbs.py +231 -0
  25. cuqi/legacy/sampler/_hmc.py +335 -0
  26. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  27. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  28. cuqi/legacy/sampler/_mh.py +190 -0
  29. cuqi/legacy/sampler/_pcn.py +244 -0
  30. cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +134 -152
  31. cuqi/legacy/sampler/_sampler.py +182 -0
  32. cuqi/likelihood/_likelihood.py +1 -1
  33. cuqi/model/_model.py +1248 -357
  34. cuqi/pde/__init__.py +4 -0
  35. cuqi/pde/_observation_map.py +36 -0
  36. cuqi/pde/_pde.py +133 -32
  37. cuqi/problem/_problem.py +88 -82
  38. cuqi/sampler/__init__.py +120 -8
  39. cuqi/sampler/_conjugate.py +376 -35
  40. cuqi/sampler/_conjugate_approx.py +40 -16
  41. cuqi/sampler/_cwmh.py +132 -138
  42. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  43. cuqi/sampler/_gibbs.py +269 -130
  44. cuqi/sampler/_hmc.py +328 -201
  45. cuqi/sampler/_langevin_algorithm.py +282 -98
  46. cuqi/sampler/_laplace_approximation.py +87 -117
  47. cuqi/sampler/_mh.py +47 -157
  48. cuqi/sampler/_pcn.py +56 -211
  49. cuqi/sampler/_rto.py +206 -140
  50. cuqi/sampler/_sampler.py +540 -135
  51. cuqi/solver/_solver.py +6 -2
  52. cuqi/testproblem/_testproblem.py +2 -3
  53. cuqi/utilities/__init__.py +3 -1
  54. cuqi/utilities/_utilities.py +94 -12
  55. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/METADATA +6 -4
  56. cuqipy-1.4.0.post0.dev61.dist-info/RECORD +102 -0
  57. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/WHEEL +1 -1
  58. CUQIpy-1.3.0.dist-info/RECORD +0 -100
  59. cuqi/experimental/mcmc/__init__.py +0 -123
  60. cuqi/experimental/mcmc/_conjugate.py +0 -345
  61. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  62. cuqi/experimental/mcmc/_cwmh.py +0 -193
  63. cuqi/experimental/mcmc/_gibbs.py +0 -318
  64. cuqi/experimental/mcmc/_hmc.py +0 -464
  65. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -392
  66. cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
  67. cuqi/experimental/mcmc/_mh.py +0 -80
  68. cuqi/experimental/mcmc/_pcn.py +0 -89
  69. cuqi/experimental/mcmc/_sampler.py +0 -566
  70. cuqi/experimental/mcmc/_utilities.py +0 -17
  71. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info/licenses}/LICENSE +0 -0
  72. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,216 @@
+ import cuqi
+ import inspect
+ import numpy as np
+
+ # This import makes suggest_sampler easier to read
+ import cuqi.sampler as samplers
+
+
+ class SamplerRecommender(object):
+     """
+     This class can be used to automatically choose a sampler.
+
+     Parameters
+     ----------
+     target: Density or JointDistribution
+         Distribution to get sampler recommendations for.
+
+     exceptions: list[cuqi.sampler.Sampler], *optional*
+         Samplers not to be recommended.
+
+     Example
+     -------
+     .. code-block:: python
+         import numpy as np
+         from cuqi.distribution import Gamma, Gaussian, JointDistribution
+         from cuqi.experimental import SamplerRecommender
+
+         x = Gamma(1, 1)
+         y = Gaussian(np.zeros(2), cov=lambda x: 1 / x)
+         target = JointDistribution(y, x)(y=1)
+
+         recommender = SamplerRecommender(target)
+         valid_samplers = recommender.valid_samplers()
+         recommended_sampler = recommender.recommend()
+         print("Valid samplers:", valid_samplers)
+         print("Recommended sampler:\n", recommended_sampler)
+
+     """
+
+     def __init__(self, target: cuqi.density.Density, exceptions = []):
+         self._target = target
+         self._exceptions = exceptions
+         self._create_ordering()
+
+     @property
+     def target(self) -> cuqi.density.Density:
+         """ Return the target Distribution. """
+         return self._target
+
+     @target.setter
+     def target(self, value: cuqi.density.Density):
+         """ Set the target Distribution. Runs validation of the target. """
+         if value is None:
+             raise ValueError("Target needs to be of type cuqi.density.Density.")
+         self._target = value
+
+     def _create_ordering(self):
+         """
+         Every element in the ordering consists of a tuple:
+         (
+             Sampler: Class
+             boolean: additional conditions on the target
+             parameters: additional parameters to be passed to the sampler once initialized
+         )
+         """
+         number_of_components = np.sum(self._target.dim)
+
+         self._ordering = [
+             # Direct and Conjugate samplers
+             (samplers.Direct, True, {}),
+             (samplers.Conjugate, True, {}),
+             (samplers.ConjugateApprox, True, {}),
+             # Specialized samplers
+             (samplers.LinearRTO, True, {}),
+             (samplers.RegularizedLinearRTO, True, {}),
+             (samplers.UGLA, True, {}),
+             # Gradient-based samplers (Hamiltonian and Langevin)
+             (samplers.NUTS, True, {}),
+             (samplers.MALA, True, {}),
+             (samplers.ULA, True, {}),
+             # Gibbs and componentwise samplers
+             (samplers.HybridGibbs, True, {"sampling_strategy" : self.recommend_HybridGibbs_sampling_strategy(as_string = False)}),
+             (samplers.CWMH, number_of_components <= 100, {"scale" : 0.05*np.ones(number_of_components),
+                                                           "initial_point" : 0.5*np.ones(number_of_components)}),
+             # Proposal-based samplers
+             (samplers.PCN, True, {"scale" : 0.02}),
+             (samplers.MH, number_of_components <= 1000, {}),
+         ]
+
+     @property
+     def ordering(self):
+         """ Returns the ordered list of recommendation rules used by the recommender. """
+         return self._ordering
+
+     def valid_samplers(self, as_string = True):
+         """
+         Finds all possible samplers that can be used for sampling from the target distribution.
+
+         Parameters
+         ----------
+
+         as_string : boolean
+             Whether to return the name of the sampler as a string instead of instantiating a sampler. *Optional*
+
+         """
+
+         all_samplers = [(name, cls) for name, cls in inspect.getmembers(cuqi.sampler, inspect.isclass) if issubclass(cls, cuqi.sampler.Sampler)]
+         valid_samplers = []
+
+         for name, sampler in all_samplers:
+             try:
+                 sampler(self.target)
+                 valid_samplers += [name if as_string else sampler]
+             except:
+                 pass
+
+         # Need a separate case for HybridGibbs
+         if self.valid_HybridGibbs_sampling_strategy() is not None:
+             valid_samplers += [cuqi.sampler.HybridGibbs.__name__ if as_string else cuqi.sampler.HybridGibbs]
+
+         return valid_samplers
+
+     def valid_HybridGibbs_sampling_strategy(self, as_string = True):
+         """
+         Finds all possible sampling strategies to be used with the HybridGibbs sampler.
+         Returns None if no sampler could be suggested for at least one conditional distribution.
+
+         Parameters
+         ----------
+
+         as_string : boolean
+             Whether to return the names of the samplers in the sampling strategy as strings instead of instantiating samplers. *Optional*
+
+
+         """
+
+         if not isinstance(self.target, cuqi.distribution.JointDistribution):
+             return None
+
+         par_names = self.target.get_parameter_names()
+
+         valid_samplers = dict()
+         for par_name in par_names:
+             conditional_params = {par_name_: np.ones(self.target.dim[i]) for i, par_name_ in enumerate(par_names) if par_name_ != par_name}
+             conditional = self.target(**conditional_params)
+
+             recommender = SamplerRecommender(conditional)
+             samplers = recommender.valid_samplers(as_string)
+             if len(samplers) == 0:
+                 return None
+
+             valid_samplers[par_name] = samplers
+
+         return valid_samplers
+
+     def recommend(self, as_string = False):
+         """
+         Suggests a possible sampler that can be used for sampling from the target distribution.
+         Raises a ValueError if no sampler could be suggested.
+
+         Parameters
+         ----------
+
+         as_string : boolean
+             Whether to return the name of the sampler as a string instead of instantiating a sampler. *Optional*
+
+         """
+
+         valid_samplers = self.valid_samplers(as_string = False)
+
+         for suggestion, flag, values in self._ordering:
+             if flag and (suggestion in valid_samplers) and (suggestion not in self._exceptions):
+                 # Sampler found
+                 if as_string:
+                     return suggestion.__name__
+                 else:
+                     return suggestion(self.target, **values)
+
+         # No sampler can be suggested
+         raise ValueError("Cannot suggest any sampler. Either the provided distribution is incorrectly defined or there are too many exceptions provided.")
+
+     def recommend_HybridGibbs_sampling_strategy(self, as_string = False):
+         """
+         Suggests a possible sampling strategy to be used with the HybridGibbs sampler.
+         Returns None if no sampler could be suggested for at least one conditional distribution.
+
+         Parameters
+         ----------
+
+         target : `cuqi.distribution.JointDistribution`
+             The target distribution to get a sampling strategy for.
+
+         as_string : boolean
+             Whether to return the names of the samplers in the sampling strategy as strings instead of instantiating samplers. *Optional*
+
+         """
+
+         if not isinstance(self.target, cuqi.distribution.JointDistribution):
+             return None
+
+         par_names = self.target.get_parameter_names()
+
+         suggested_samplers = dict()
+         for par_name in par_names:
+             conditional_params = {par_name_: np.ones(self.target.dim[i]) for i, par_name_ in enumerate(par_names) if par_name_ != par_name}
+             conditional = self.target(**conditional_params)
+
+             recommender = SamplerRecommender(conditional, exceptions = self._exceptions.copy())
+             sampler = recommender.recommend(as_string = as_string)
+
+             if sampler is None:
+                 return None
+
+             suggested_samplers[par_name] = sampler
+
+         return suggested_samplers
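Besides recommending a single sampler, the new SamplerRecommender can propose a per-parameter sampling strategy for HybridGibbs when the target is a JointDistribution with more than one free parameter. A minimal sketch of that usage (the two-parameter model mirrors the docstring example and is illustrative only; the printed strategy depends on which samplers accept each conditional):

    import numpy as np
    from cuqi.distribution import Gamma, Gaussian, JointDistribution
    from cuqi.experimental import SamplerRecommender

    # Two-parameter hierarchical model (illustrative values)
    x = Gamma(1, 1)
    y = Gaussian(np.zeros(2), cov=lambda x: 1 / x)
    target = JointDistribution(y, x)        # keep both parameters free

    recommender = SamplerRecommender(target)

    # Dict mapping each parameter name to a recommended sampler name
    strategy = recommender.recommend_HybridGibbs_sampling_strategy(as_string=True)
    print(strategy)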
@@ -172,10 +172,10 @@ class _ProductGeometry(Geometry):
          return tuple(funvecs)
 
 
-     def __repr__(self) -> str:
+     def __repr__(self, pad="") -> str:
          """Representation of the product geometry."""
          string = "{}(".format(self.__class__.__name__) + "\n"
          for g in self.geometries:
-             string += "\t{}\n".format(g.__repr__())
-         string += ")"
+             string += pad + " {}\n".format(g.__repr__())
+         string += pad + ")"
          return string
@@ -225,7 +225,18 @@ class Geometry(ABC):
          return self._all_values_equal(obj)
 
      def __repr__(self) -> str:
-         return "{}{}".format(self.__class__.__name__,self.par_shape)
+         if self.par_shape == self.fun_shape:
+             return "{}[{}]".format(self.__class__.__name__,
+                 self.par_shape if len(self.par_shape) != 1 else self.par_shape[0])
+         return "{}[{}: {}]".format(
+             self.__class__.__name__,
+             self.par_shape if len(self.par_shape) != 1 else self.par_shape[0],
+             (
+                 self.fun_shape
+                 if (self.fun_shape is None or len(self.fun_shape) != 1)
+                 else self.fun_shape[0]
+             ),
+         )
 
      def _all_values_equal(self, obj):
          """Returns true of all values of the object and self are equal"""
@@ -1,5 +1,5 @@
  from ._regularizedGaussian import RegularizedGaussian, ConstrainedGaussian, NonnegativeGaussian
  from ._regularizedGMRF import RegularizedGMRF, ConstrainedGMRF, NonnegativeGMRF
  from ._regularizedUnboundedUniform import RegularizedUnboundedUniform
- from ._restorator import RestorationPrior, MoreauYoshidaPrior
+ from ._restorator import RestorationPrior, MoreauYoshidaPrior, TweediePrior
 
@@ -2,10 +2,12 @@ from cuqi.utilities import get_non_default_args
  from cuqi.distribution import Distribution, Gaussian
  from cuqi.solver import ProjectNonnegative, ProjectBox, ProximalL1
  from cuqi.geometry import Continuous1D, Continuous2D, Image2D
- from cuqi.operator import FirstOrderFiniteDifference, Operator
+ from cuqi.operator import FirstOrderFiniteDifference, SecondOrderFiniteDifference, Operator
 
  import numpy as np
  import scipy.sparse as sparse
+ import scipy.optimize as spoptimize
+
  from copy import copy
 
 
@@ -60,12 +62,18 @@ class RegularizedGaussian(Distribution):
              min_(z in C) 0.5||x-z||_2^2.
 
      constraint : string or None
-         Preset constraints that generate the corresponding proximal parameter. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
-         For "box", the following additional parameters can be passed:
+         Preset constraints that generate the corresponding proximal parameter. Required for use in Gibbs. For any geometry the following can be chosen:
+         - "nonnegativity"
+         - "box", for which the following additional parameters can be passed:
              lower_bound : array_like or None
                  Lower bound of box, defaults to zero
              upper_bound : array_like
                  Upper bound of box, defaults to one
+         Additionally, for a Continuous1D geometry the following options can be chosen:
+         - "increasing"
+         - "decreasing"
+         - "convex"
+         - "concave"
 
      regularization : string or None
          Preset regularization that generate the corresponding proximal parameter. Can be set to "l1" or 'tv'. Required for use in Gibbs in future update.
@@ -178,6 +186,34 @@ class RegularizedGaussian(Distribution):
              self._proximal = lambda z, _: ProjectBox(z, _box_lower, _box_upper)
              self._box_bounds = (np.ones(self.dim)*_box_lower, np.ones(self.dim)*_box_upper)
              self._preset["constraint"] = "box"
+         elif c_lower == "increasing":
+             if not isinstance(self.geometry, Continuous1D):
+                 raise ValueError("Geometry not supported for " + c_lower)
+             if hasattr(spoptimize, 'isotonic_regression'):
+                 self._constraint_prox = lambda z, _: spoptimize.isotonic_regression(z, increasing=True).x
+             else:
+                 raise AttributeError(f"The function 'isotonic_regression' does not exist in scipy.optimize. Installed scipy version: {spoptimize.__version__}. You need to install scipy >= 1.12.0.")
+             self._preset["constraint"] = "increasing"
+         elif c_lower == "decreasing":
+             if not isinstance(self.geometry, Continuous1D):
+                 raise ValueError("Geometry not supported for " + c_lower)
+             if hasattr(spoptimize, 'isotonic_regression'):
+                 self._constraint_prox = lambda z, _: spoptimize.isotonic_regression(z, increasing=False).x
+             else:
+                 raise AttributeError(f"The function 'isotonic_regression' does not exist in scipy.optimize. Installed scipy version: {spoptimize.__version__}. You need to install scipy >= 1.12.0.")
+             self._preset["constraint"] = "decreasing"
+         elif c_lower == "convex":
+             if not isinstance(self.geometry, Continuous1D):
+                 raise ValueError("Geometry not supported for " + c_lower)
+             self._constraint_prox = lambda z, _: -ProjectNonnegative(-z)
+             self._constraint_oper = SecondOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
+             self._preset["constraint"] = "convex"
+         elif c_lower == "concave":
+             if not isinstance(self.geometry, Continuous1D):
+                 raise ValueError("Geometry not supported for " + c_lower)
+             self._constraint_prox = lambda z, _: ProjectNonnegative(z)
+             self._constraint_oper = SecondOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
+             self._preset["constraint"] = "concave"
          else:
              raise ValueError("Constraint not supported.")
 
@@ -289,7 +325,7 @@ class RegularizedGaussian(Distribution):
 
      @staticmethod
      def constraint_options():
-         return ["nonnegativity", "box"]
+         return ["nonnegativity", "box", "increasing", "decreasing", "convex", "concave"]
 
      @staticmethod
      def regularization_options():
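A minimal sketch of selecting one of the new constraint presets (assuming the Gaussian-style mean/cov keywords and the geometry keyword of the base Distribution; parameter values are illustrative). The "increasing"/"decreasing" options require scipy >= 1.12.0 for scipy.optimize.isotonic_regression, and all four new options require a Continuous1D geometry:

    import numpy as np
    from cuqi.geometry import Continuous1D
    from cuqi.implicitprior import RegularizedGaussian

    n = 50
    x = RegularizedGaussian(mean=np.zeros(n),
                            cov=0.1,
                            geometry=Continuous1D(n),
                            constraint="increasing")   # or "decreasing", "convex", "concave"

    # The full list of presets, including the new ones
    print(RegularizedGaussian.constraint_options())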
@@ -232,4 +232,38 @@ class MoreauYoshidaPrior(Distribution):
          """ Returns the conditioning variables of the distribution. """
          # Currently conditioning variables are not supported for user-defined
          # distributions.
-         return []
+         return []
+
+ class TweediePrior(MoreauYoshidaPrior):
+     """
+     Alias for MoreauYoshidaPrior following Tweedie's formula framework. TweediePrior
+     defines priors where gradients are computed based on Tweedie's identity, which links
+     MMSE (Minimum Mean Square Error) denoisers with the underlying smoothed prior, see:
+     - Laumont et al. https://arxiv.org/abs/2103.04715 or https://doi.org/10.1137/21M1406349
+
+     Tweedie's Formula
+     -----------------
+     In the context of denoising, Tweedie's identity states that for a signal x
+     corrupted by Gaussian noise:
+
+         ∇_x log p_e(x) = (D_e(x) - x) / e
+
+     where D_e(x) is the MMSE denoiser output and e is the noise variance.
+     This enables us to perform gradient-based sampling with algorithms like ULA.
+
+     At the implementation level, TweediePrior shares identical functionality with MoreauYoshidaPrior.
+     Thus, it is implemented as an alias of MoreauYoshidaPrior, meaning all methods,
+     properties, and behavior are identical. The separate name provides clarity when
+     working specifically with Tweedie's formula-based approaches.
+
+     Parameters
+     ----------
+     prior : RestorationPrior
+         Prior of the RestorationPrior type containing a denoiser/restorator.
+
+     smoothing_strength : float, default=0.1
+         Corresponds to the noise variance e in Tweedie's formula context.
+
+     See MoreauYoshidaPrior for the underlying implementation with complete documentation.
+     """
+     pass
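Since TweediePrior is a pure alias, any code written for MoreauYoshidaPrior accepts it unchanged. As a standalone sanity check of the identity quoted in the docstring (plain NumPy; no CUQIpy prior construction is assumed here): for a standard Gaussian prior N(0, 1) the MMSE denoiser of y = x + sqrt(e) * noise is D_e(y) = y / (1 + e), and Tweedie's formula reproduces the score of the smoothed prior N(0, 1 + e):

    import numpy as np
    from cuqi.implicitprior import MoreauYoshidaPrior, TweediePrior

    # The alias relation stated above
    assert issubclass(TweediePrior, MoreauYoshidaPrior)

    e = 0.1                              # noise variance / smoothing strength
    y = np.linspace(-2.0, 2.0, 5)        # noisy observations
    D_e = y / (1.0 + e)                  # MMSE denoiser for a N(0, 1) prior
    tweedie_score = (D_e - y) / e        # right-hand side of Tweedie's identity
    analytic_score = -y / (1.0 + e)      # d/dy log N(y; 0, 1 + e)
    assert np.allclose(tweedie_score, analytic_score)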
@@ -0,0 +1,2 @@
+ """ Legacy module for functionalities that are no longer supported or developed. """
+ from . import sampler
@@ -0,0 +1,11 @@
+ from ._sampler import Sampler, ProposalBasedSampler
+ from ._conjugate import Conjugate
+ from ._conjugate_approx import ConjugateApprox
+ from ._cwmh import CWMH
+ from ._gibbs import Gibbs
+ from ._hmc import NUTS
+ from ._langevin_algorithm import ULA, MALA
+ from ._laplace_approximation import UGLA
+ from ._mh import MH
+ from ._pcn import pCN
+ from ._rto import LinearRTO, RegularizedLinearRTO
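The legacy package re-exports the pre-1.4 sampler interface under a new import path, so existing scripts only need to update their imports. A minimal sketch (only the import is shown; the legacy sampler APIs themselves are not part of this hunk):

    # Old-style samplers remain importable from the legacy namespace
    from cuqi.legacy.sampler import MH, pCN, NUTS, LinearRTO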
@@ -0,0 +1,55 @@
+ from cuqi.distribution import Posterior, Gaussian, Gamma, GMRF
+ from cuqi.implicitprior import RegularizedGaussian, RegularizedGMRF
+ import numpy as np
+
+ class Conjugate: # TODO: Subclass from Sampler once updated
+     """ Conjugate sampler
+
+     Sampler for sampling a posterior distribution where the likelihood and prior are conjugate.
+
+     Currently supported conjugate pairs are:
+     - (Gaussian, Gamma)
+     - (GMRF, Gamma)
+     - (RegularizedGaussian, Gamma) with nonnegativity constraints only
+
+     For more information on conjugate pairs, see https://en.wikipedia.org/wiki/Conjugate_prior.
+
+     For implicit regularized Gaussians see:
+
+     [1] Everink, Jasper M., Yiqiu Dong, and Martin S. Andersen. "Bayesian inference with projected densities." SIAM/ASA Journal on Uncertainty Quantification 11.3 (2023): 1025-1043.
+
+     """
+
+     def __init__(self, target: Posterior):
+         if not isinstance(target.likelihood.distribution, (Gaussian, GMRF, RegularizedGaussian, RegularizedGMRF)):
+             raise ValueError("Conjugate sampler only works with a Gaussian-type likelihood function")
+         if not isinstance(target.prior, Gamma):
+             raise ValueError("Conjugate sampler only works with Gamma prior")
+         if not target.prior.dim == 1:
+             raise ValueError("Conjugate sampler only works with univariate Gamma prior")
+
+         if isinstance(target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)) and (target.likelihood.distribution.preset["constraint"] not in ["nonnegativity"] or target.likelihood.distribution.preset["regularization"] is not None):
+             raise ValueError("Conjugate sampler only works with implicit regularized Gaussian likelihoods with nonnegativity constraints")
+
+         self.target = target
+
+     def step(self, x=None):
+         # Extract variables
+         b = self.target.likelihood.data                                    # mu
+         m = self._calc_m_for_Gaussians(b)                                  # n
+         Ax = self.target.likelihood.distribution.mean                      # x_i
+         L = self.target.likelihood.distribution(np.array([1])).sqrtprec    # L
+         alpha = self.target.prior.shape                                    # alpha
+         beta = self.target.prior.rate                                      # beta
+
+         # Create Gamma distribution and sample
+         dist = Gamma(shape=m/2+alpha, rate=.5*np.linalg.norm(L@(Ax-b))**2+beta)
+
+         return dist.sample()
+
+     def _calc_m_for_Gaussians(self, b):
+         """ Helper method to calculate the m parameter for the Gaussian-Gamma conjugate pair. """
+         if isinstance(self.target.likelihood.distribution, (Gaussian, GMRF)):
+             return len(b)
+         elif isinstance(self.target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)):
+             return np.count_nonzero(b) # See
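A minimal sketch of driving the legacy Conjugate sampler directly via step() (model and data values are illustrative; conditioning the JointDistribution on the data is assumed to reduce it to a Posterior, as in the SamplerRecommender docstring example above):

    import numpy as np
    from cuqi.distribution import Gamma, Gaussian, JointDistribution
    from cuqi.legacy.sampler import Conjugate

    # Hierarchical model: x ~ Gamma(1, 1), y | x ~ Gaussian(0, (1/x) I)
    x = Gamma(1, 1)
    y = Gaussian(np.zeros(2), cov=lambda x: 1 / x)
    posterior = JointDistribution(y, x)(y=np.array([0.5, -0.3]))

    sampler = Conjugate(posterior)
    print(sampler.step())   # one draw from the Gamma(m/2 + alpha, ...) conditional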
@@ -0,0 +1,52 @@
+ from cuqi.distribution import Posterior, LMRF, Gamma
+ import numpy as np
+ import scipy as sp
+
+ class ConjugateApprox: # TODO: Subclass from Sampler once updated
+     """ Approximate Conjugate sampler
+
+     Sampler for sampling a posterior distribution where the likelihood and prior can be approximated
+     by a conjugate pair.
+
+     Currently supported pairs are:
+     - (LMRF, Gamma): Approximated by (Gaussian, Gamma)
+
+     For more information on conjugate pairs, see https://en.wikipedia.org/wiki/Conjugate_prior.
+
+     """
+
+
+     def __init__(self, target: Posterior):
+         if not isinstance(target.likelihood.distribution, LMRF):
+             raise ValueError("Conjugate sampler only works with Laplace diff likelihood function")
+         if not isinstance(target.prior, Gamma):
+             raise ValueError("Conjugate sampler only works with Gamma prior")
+         self.target = target
+
+     def step(self, x=None):
+         # Extract variables
+         # Here we approximate the Laplace diff with a Gaussian
+
+         # Extract diff_op from target likelihood
+         D = self.target.likelihood.distribution._diff_op
+         n = D.shape[0]
+
+         # Gaussian approximation of LMRF prior as function of x_k
+         # See Uribe et al. (2022) for details
+         # Currently assumes a zero-mean likelihood! TODO
+         beta = 1e-5
+         def Lk_fun(x_k):
+             dd = 1/np.sqrt((D @ x_k)**2 + beta*np.ones(n))
+             W = sp.sparse.diags(dd)
+             return W.sqrt() @ D
+
+         x = self.target.likelihood.data        # x
+         d = len(x)                             # d
+         Lx = Lk_fun(x)@x                       # Lx
+         alpha = self.target.prior.shape        # alpha
+         beta = self.target.prior.rate          # beta
+
+         # Create Gamma distribution and sample
+         dist = Gamma(shape=d+alpha, rate=np.linalg.norm(Lx)**2+beta)
+
+         return dist.sample()
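The approximate pair is used the same way as the exact one, with an LMRF likelihood in place of the Gaussian. A minimal sketch (illustrative model only; the LMRF scale is parametrized by the Gamma variable d, and conditioning on x is assumed to yield a Posterior as above):

    import numpy as np
    from cuqi.distribution import Gamma, LMRF, JointDistribution
    from cuqi.legacy.sampler import ConjugateApprox

    # Hierarchical model: d ~ Gamma(1, 1e-2), x | d ~ LMRF(0, 1/d)
    d = Gamma(1, 1e-2)
    x = LMRF(np.zeros(16), scale=lambda d: 1 / d)
    posterior = JointDistribution(x, d)(x=np.sin(np.linspace(0, np.pi, 16)))

    sampler = ConjugateApprox(posterior)
    print(sampler.step())   # one draw from the approximate Gamma conditional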