CUQIpy 1.0.0.post0.dev384__py3-none-any.whl → 1.0.0.post0.dev420__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.0.0.post0.dev384
3
+ Version: 1.0.0.post0.dev420
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
2
2
  cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
3
- cuqi/_version.py,sha256=IKVgkfhXqLynqXrlVfTXOFBq5r4w6rJcCBTk1PNkuKA,510
3
+ cuqi/_version.py,sha256=w5RJCnySMypI_9tEG8VNibbbBzdnFB4_1O9-0uJHbdE,510
4
4
  cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
5
5
  cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
6
6
  cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
@@ -18,7 +18,7 @@ cuqi/distribution/__init__.py,sha256=4vVLArg6NVzBj67vVioK8BY6wISJKb5cOxdoHMuUb_s
18
18
  cuqi/distribution/_beta.py,sha256=hdAc6Tbuz9Yqf76NSHxpaUgN7s6Z2lNV7YSRD3JhyCU,2997
19
19
  cuqi/distribution/_cauchy.py,sha256=UsVXYz8HhagXN5fIWSAIyELqhsJAX_-wk9kkRGgRmA8,3296
20
20
  cuqi/distribution/_cmrf.py,sha256=tCbEulM_O7FB3C_W-3IqZp9zGHkTofCdFF0ybHc9UZI,3745
21
- cuqi/distribution/_custom.py,sha256=uUJwGlGjcMY89mIyu9nFI3OafLOMgn8uAEMfCbTDzi0,10661
21
+ cuqi/distribution/_custom.py,sha256=3toVV_ntnWAGlI7pNQLS7-7m5i-34sTWQrxL-SMweq0,10681
22
22
  cuqi/distribution/_distribution.py,sha256=G7BCpVueK4QLoLa_hu9h-Euh58Yp9SrgUKuudUlg-pw,18351
23
23
  cuqi/distribution/_gamma.py,sha256=9vljt5iaBDCHRhrVCMLc2RWDuBchZRQcv9buJMDYPlM,3434
24
24
  cuqi/distribution/_gaussian.py,sha256=DmmgVxKp4iEiEYWDdDcRoh35y14Oepn-zDHex0WVaYo,33316
@@ -35,10 +35,10 @@ cuqi/distribution/_smoothed_laplace.py,sha256=bfvOI4YOYC6QPds1UHNHar_q-vOZEjBcPx
35
35
  cuqi/distribution/_uniform.py,sha256=7xJmCZH_LPhuGkwEDGh-_CTtzcWKrXMOxtTJUFb7Ydo,1607
36
36
  cuqi/experimental/__init__.py,sha256=vhZvyMX6rl8Y0haqCzGLPz6PSUKyu75XMQbeDHqTTrw,83
37
37
  cuqi/experimental/mcmc/__init__.py,sha256=0Vk_MzfE_9tvqQRgR6_3nkjSe_D3vgFqVM9pFrXN2iQ,581
38
- cuqi/experimental/mcmc/_conjugate.py,sha256=qYrBvZ9wNK4oBz0c0RRUtQkbpPIHI3BvBYSLRw8ok5k,3757
39
- cuqi/experimental/mcmc/_conjugate_approx.py,sha256=JQe9gmnNespCxSP6vaZWfizFvUWUh8Jn-jRqsJYyNeM,2839
38
+ cuqi/experimental/mcmc/_conjugate.py,sha256=r3cEXFXjNucDUAzj9mq4nZtgc6B3lJM21c77d1tB8gw,9936
39
+ cuqi/experimental/mcmc/_conjugate_approx.py,sha256=FLhNN0O6DvohLjmekMA6iVn8yXMEvUGcx1s8w3Wu8cA,3665
40
40
  cuqi/experimental/mcmc/_cwmh.py,sha256=-TM_S_UtD5ljEfXGEUpYImxNx3JXppIKTSpoWen7kP8,7142
41
- cuqi/experimental/mcmc/_direct.py,sha256=E3UevdJ_DLk2wL0lid1TTKkdmgnIMJ5Ihr7iM1jU8KI,738
41
+ cuqi/experimental/mcmc/_direct.py,sha256=pAtxqoSQGhLdukLi8gcoY2qXJmKluttX5RWOk5fAbOY,786
42
42
  cuqi/experimental/mcmc/_gibbs.py,sha256=z6YOCiBM1YuZbQHfdmsArR-pT61dsS14F_O4kUxsNYM,10638
43
43
  cuqi/experimental/mcmc/_hmc.py,sha256=0sZMHtnNFGGtQdzpx-cgqA0xyfvGy7r4K62RH3AQNa4,19285
44
44
  cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=n6WRQooKuUDjmqF-CtpcSNFDvaHCgLKhWxX-hi7h_ZA,8224
@@ -63,7 +63,7 @@ cuqi/operator/_operator.py,sha256=yNwPTh7jR07AiKMbMQQ5_54EgirlKFsbq9JN1EODaQI,88
63
63
  cuqi/pde/__init__.py,sha256=NyS_ZYruCvy-Yg24qKlwm3ZIX058kLNQX9bqs-xg4ZM,99
64
64
  cuqi/pde/_pde.py,sha256=WRkOYyIdT_T3aZepRh0aS9C5nBbUZUcHaA80iSRvgoo,12572
65
65
  cuqi/problem/__init__.py,sha256=JxJty4JqHTOqSG6NeTGiXRQ7OLxiRK9jvVq3lXLeIRw,38
66
- cuqi/problem/_problem.py,sha256=Irk4OlTAhZAX81excesi8ANokz2GSAS3z_mcvx8Wqdc,32018
66
+ cuqi/problem/_problem.py,sha256=XvNbo7BXcnDZvj3n36f879QknTYg3_-jnKhkVvqUQto,31944
67
67
  cuqi/sampler/__init__.py,sha256=D-dYa0gFgIwQukP8_VKhPGmlGKXbvVo7YqaET4SdAeQ,382
68
68
  cuqi/sampler/_conjugate.py,sha256=ztmUR3V3qZk9zelKx48ULnmMs_zKTDUfohc256VOIe8,2753
69
69
  cuqi/sampler/_conjugate_approx.py,sha256=xX-X71EgxGnZooOY6CIBhuJTs3dhcKfoLnoFxX3CO2g,1938
@@ -85,8 +85,8 @@ cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7s
85
85
  cuqi/utilities/__init__.py,sha256=T4tLsC215MknBCsw_C0Qeeg_ox26aDUrCA5hbWvNQkU,387
86
86
  cuqi/utilities/_get_python_variable_name.py,sha256=QwlBVj2koJRA8s8pWd554p7-ElcI7HUwY32HknaR92E,1827
87
87
  cuqi/utilities/_utilities.py,sha256=MWAqV6L5btMpWwlUzrZYuV2VeSpfTbOaLRMRkuw2WIA,8509
88
- CUQIpy-1.0.0.post0.dev384.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
89
- CUQIpy-1.0.0.post0.dev384.dist-info/METADATA,sha256=5xT68suvERdJf0DPgCPA23chSdG2yRnQQJhW_zSnjps,18393
90
- CUQIpy-1.0.0.post0.dev384.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
91
- CUQIpy-1.0.0.post0.dev384.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
92
- CUQIpy-1.0.0.post0.dev384.dist-info/RECORD,,
88
+ CUQIpy-1.0.0.post0.dev420.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
89
+ CUQIpy-1.0.0.post0.dev420.dist-info/METADATA,sha256=3KFA90vbSnjyOmJtY3j8wtqt7wrzvmqDZ8NV9pWoI3Q,18393
90
+ CUQIpy-1.0.0.post0.dev420.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
91
+ CUQIpy-1.0.0.post0.dev420.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
92
+ CUQIpy-1.0.0.post0.dev420.dist-info/RECORD,,
cuqi/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-07-05T09:29:07+0200",
11
+ "date": "2024-07-05T09:58:51+0200",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "3e3d762fab019f39257708bb9cdf5abc3681f196",
15
- "version": "1.0.0.post0.dev384"
14
+ "full-revisionid": "c77d3deee70affd3a987b7a639bd576c4fe3896a",
15
+ "version": "1.0.0.post0.dev420"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -50,13 +50,13 @@ class UserDefinedDistribution(Distribution):
50
50
  if self.logpdf_func is not None:
51
51
  return self.logpdf_func(x)
52
52
  else:
53
- raise Exception("logpdf_func is not defined.")
53
+ raise NotImplementedError("logpdf_func is not defined.")
54
54
 
55
55
  def _gradient(self, x):
56
56
  if self.gradient_func is not None:
57
57
  return self.gradient_func(x)
58
58
  else:
59
- raise Exception("gradient_func is not defined.")
59
+ raise NotImplementedError("gradient_func is not defined.")
60
60
 
61
61
  def _sample(self, N=1, rng=None):
62
62
  #TODO(nabr) allow sampling more than 1 sample and potentially rng?
@@ -1,12 +1,15 @@
1
1
  import numpy as np
2
+ from abc import ABC, abstractmethod
3
+ import math
2
4
  from cuqi.experimental.mcmc import SamplerNew
3
5
  from cuqi.distribution import Posterior, Gaussian, Gamma, GMRF
4
6
  from cuqi.implicitprior import RegularizedGaussian, RegularizedGMRF
7
+ from cuqi.utilities import get_non_default_args
5
8
 
6
9
  class ConjugateNew(SamplerNew):
7
10
  """ Conjugate sampler
8
11
 
9
- Sampler for sampling a posterior distribution where the likelihood and prior are conjugate.
12
+ Sampler for sampling a posterior distribution which is a so-called "conjugate" distribution, i.e., where the likelihood and prior are conjugate to each other - denoted as a conjugate pair.
10
13
 
11
14
  Currently supported conjugate pairs are:
12
15
  - (Gaussian, Gamma) where Gamma is defined on the precision parameter of the Gaussian
@@ -14,64 +17,181 @@ class ConjugateNew(SamplerNew):
14
17
  - (RegularizedGaussian, Gamma) with nonnegativity constraints only and Gamma is defined on the precision parameter of the RegularizedGaussian
15
18
  - (RegularizedGMRF, Gamma) with nonnegativity constraints only and Gamma is defined on the precision parameter of the RegularizedGMRF
16
19
 
17
- Gamma distribution must be univariate.
20
+ Currently the Gamma distribution must be univariate.
18
21
 
19
- Currently, the sampler does NOT automatically check that the conjugate distributions are defined on the correct parameters.
22
+ A conjugate pair defines implicitly a so-called conjugate distribution which can be sampled from directly.
20
23
 
21
- For more information on conjugate pairs, see https://en.wikipedia.org/wiki/Conjugate_prior.
24
+ The conjugate parameter is the parameter that both the likelihood and prior PDF depend on.
25
+
26
+ For more information on conjugacy and conjugate distributions see https://en.wikipedia.org/wiki/Conjugate_prior.
22
27
 
23
28
  For implicit regularized Gaussians see:
24
29
 
25
30
  [1] Everink, Jasper M., Yiqiu Dong, and Martin S. Andersen. "Bayesian inference with projected densities." SIAM/ASA Journal on Uncertainty Quantification 11.3 (2023): 1025-1043.
26
31
 
27
32
  """
33
+
28
34
  def _initialize(self):
29
35
  pass
30
36
 
37
+ @SamplerNew.target.setter # Overwrite the target setter to set the conjugate pair
38
+ def target(self, value):
39
+ """ Set the target density. Runs validation of the target. """
40
+ self._target = value
41
+ if self._target is not None:
42
+ self._set_conjugatepair()
43
+ self.validate_target()
44
+
31
45
  def validate_target(self):
46
+ self._ensure_target_is_posterior()
47
+ self._conjugatepair.validate_target()
48
+
49
+ def step(self):
50
+ self.current_point = self._conjugatepair.sample()
51
+ return 1 # Returns acceptance rate of 1
32
52
 
53
+ def tune(self, skip_len, update_count):
54
+ pass # No tuning required for conjugate sampler
55
+
56
+ def _ensure_target_is_posterior(self):
57
+ """ Ensure that the target is a Posterior distribution. """
33
58
  if not isinstance(self.target, Posterior):
34
59
  raise TypeError("Conjugate sampler requires a target of type Posterior")
35
60
 
36
- if not isinstance(self.target.likelihood.distribution, (Gaussian, GMRF, RegularizedGaussian, RegularizedGMRF)):
37
- raise ValueError("Conjugate sampler only works with a Gaussian-type likelihood function")
61
+ def _set_conjugatepair(self):
62
+ """ Set the conjugate pair based on the likelihood and prior. This requires target to be set. """
63
+ self._ensure_target_is_posterior()
64
+ if isinstance(self.target.likelihood.distribution, (Gaussian, GMRF)) and isinstance(self.target.prior, Gamma):
65
+ self._conjugatepair = _GaussianGammaPair(self.target)
66
+ elif isinstance(self.target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)) and isinstance(self.target.prior, Gamma):
67
+ self._conjugatepair = _RegularizedGaussianGammaPair(self.target)
68
+ else:
69
+ raise ValueError(f"Conjugacy is not defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
70
+
71
+ def __repr__(self):
72
+ msg = super().__repr__()
73
+ if hasattr(self, "_conjugatepair"):
74
+ msg += f"\n Conjugate pair:\n\t {type(self._conjugatepair).__name__.removeprefix('_')}"
75
+ return msg
38
76
 
77
+ class _ConjugatePair(ABC):
78
+ """ Abstract base class for conjugate pairs (likelihood, prior) used in the Conjugate sampler. """
79
+
80
+ def __init__(self, target):
81
+ self.target = target
82
+
83
+ @abstractmethod
84
+ def validate_target(self):
85
+ """ Validate the target distribution for the conjugate pair. """
86
+ pass
87
+
88
+ @abstractmethod
89
+ def sample(self):
90
+ """ Sample from the conjugate distribution. """
91
+ pass
92
+
93
+
94
+ class _GaussianGammaPair(_ConjugatePair):
95
+ """ Implementation for the Gaussian-Gamma conjugate pair."""
96
+
97
+ def validate_target(self):
98
+ if not isinstance(self.target.likelihood.distribution, (Gaussian, GMRF)):
99
+ raise ValueError("Conjugate sampler only works with a Gaussian likelihood function")
100
+
39
101
  if not isinstance(self.target.prior, Gamma):
40
102
  raise ValueError("Conjugate sampler only works with Gamma prior")
41
-
42
- if not self.target.prior.dim == 1:
103
+
104
+ if self.target.prior.dim != 1:
43
105
  raise ValueError("Conjugate sampler only works with univariate Gamma prior")
44
-
45
- if isinstance(self.target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)) and self.target.likelihood.distribution.preset not in ["nonnegativity"]:
46
- raise ValueError("Conjugate sampler only works with implicit regularized Gaussian likelihood with nonnegativity constraints")
47
106
 
48
- def step(self):
107
+ key, value = _get_conjugate_parameter(self.target)
108
+ if key == "cov":
109
+ if not _check_conjugate_parameter_is_scalar_reciprocal(value):
110
+ raise ValueError("Gaussian-Gamma conjugate pair defined via covariance requires `cov` for the `Gaussian` to be: lambda x : 1.0/x for the conjugate parameter")
111
+ elif key == "prec":
112
+ if not _check_conjugate_parameter_is_scalar_identity(value):
113
+ raise ValueError("Gaussian-Gamma conjugate pair defined via precision requires `prec` for the `Gaussian` to be: lambda x : x for the conjugate parameter")
114
+ else:
115
+ raise ValueError("Conjugate sampler for Gaussian likelihood functions only works when conjugate parameter is defined via covariance or precision")
116
+
117
+ def sample(self):
49
118
  # Extract variables
50
- b = self.target.likelihood.data #mu
51
- m = self._calc_m_for_Gaussians(b) #n
52
- Ax = self.target.likelihood.distribution.mean #x_i
53
- L = self.target.likelihood.distribution(np.array([1])).sqrtprec #L
54
- alpha = self.target.prior.shape #alpha
55
- beta = self.target.prior.rate #beta
119
+ b = self.target.likelihood.data # mu
120
+ m = len(b) # n
121
+ Ax = self.target.likelihood.distribution.mean # x_i
122
+ L = self.target.likelihood.distribution(np.array([1])).sqrtprec # L
123
+ alpha = self.target.prior.shape # alpha
124
+ beta = self.target.prior.rate # beta
56
125
 
57
126
  # Create Gamma distribution and sample
58
- dist = Gamma(shape=m/2+alpha,rate=.5*np.linalg.norm(L@(Ax-b))**2+beta)
127
+ dist = Gamma(shape=m/2 + alpha, rate=.5 * np.linalg.norm(L @ (Ax - b))**2 + beta)
59
128
 
60
- self.current_point = dist.sample()
129
+ return dist.sample()
61
130
 
62
- def tune(self, skip_len, update_count):
63
- pass
64
131
 
65
- def _calc_m_for_Gaussians(self, b):
66
- """ Helper method to calculate m parameter for Gaussian-Gamma conjugate pair.
67
-
68
- Classically m defines the number of observations in the Gaussian likelihood function.
132
+ class _RegularizedGaussianGammaPair(_ConjugatePair):
133
+ """Implementation for the Regularized Gaussian-Gamma conjugate pair."""
134
+
135
+ def validate_target(self):
136
+ if not isinstance(self.target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)):
137
+ raise ValueError("Conjugate sampler only works with a Regularized Gaussian likelihood function")
138
+
139
+ if not isinstance(self.target.prior, Gamma):
140
+ raise ValueError("Conjugate sampler only works with Gamma prior")
69
141
 
70
- However, for implicit regularized Gaussians, m is the number of non-zero elements in the data vector b see [1].
71
-
72
- """
142
+ if self.target.prior.dim != 1:
143
+ raise ValueError("Conjugate sampler only works with univariate Gamma prior")
73
144
 
74
- if isinstance(self.target.likelihood.distribution, (Gaussian, GMRF)):
75
- return len(b)
76
- elif isinstance(self.target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)):
77
- return np.count_nonzero(b)
145
+ if self.target.likelihood.distribution.preset not in ["nonnegativity"]:
146
+ raise ValueError("Conjugate sampler only works with implicit regularized Gaussian likelihood with nonnegativity constraints")
147
+
148
+ key, value = _get_conjugate_parameter(self.target)
149
+ if key == "cov":
150
+ if not _check_conjugate_parameter_is_scalar_reciprocal(value):
151
+ raise ValueError("Regularized Gaussian-Gamma conjugate pair defined via covariance requires cov: lambda x : 1.0/x for the conjugate parameter")
152
+ elif key == "prec":
153
+ if not _check_conjugate_parameter_is_scalar_identity(value):
154
+ raise ValueError("Regularized Gaussian-Gamma conjugate pair defined via precision requires prec: lambda x : x for the conjugate parameter")
155
+ else:
156
+ raise ValueError("Conjugate sampler for a Regularized Gaussian likelihood functions only works when conjugate parameter is defined via covariance or precision")
157
+
158
+ def sample(self):
159
+ # Extract variables
160
+ b = self.target.likelihood.data # mu
161
+ m = np.count_nonzero(b) # n
162
+ Ax = self.target.likelihood.distribution.mean # x_i
163
+ L = self.target.likelihood.distribution(np.array([1])).sqrtprec # L
164
+ alpha = self.target.prior.shape # alpha
165
+ beta = self.target.prior.rate # beta
166
+
167
+ # Create Gamma distribution and sample
168
+ dist = Gamma(shape=m/2 + alpha, rate=.5 * np.linalg.norm(L @ (Ax - b))**2 + beta)
169
+
170
+ return dist.sample()
171
+
172
+ def _get_conjugate_parameter(target):
173
+ """Extract the conjugate parameter name (e.g. d), and returns the mutable variable that is defined by the conjugate parameter, e.g. cov and its value e.g. lambda d:1/d"""
174
+ par_name = target.prior.name
175
+ mutable_likelihood_vars = target.likelihood.distribution.get_mutable_variables()
176
+
177
+ found_parameter_pairs = []
178
+
179
+ for var_key in mutable_likelihood_vars:
180
+ attr = getattr(target.likelihood.distribution, var_key)
181
+ if callable(attr) and par_name in get_non_default_args(attr):
182
+ found_parameter_pairs.append((var_key, attr))
183
+ if len(found_parameter_pairs) == 1:
184
+ return found_parameter_pairs[0]
185
+ elif len(found_parameter_pairs) > 1:
186
+ raise ValueError(f"Multiple references of parameter {par_name} found in likelihood function for conjugate sampler with target {target}. This is not supported.")
187
+ else:
188
+ raise ValueError(f"Unable to find conjugate parameter {par_name} in likelihood function for conjugate sampler with target {target}")
189
+
190
+ def _check_conjugate_parameter_is_scalar_identity(f):
191
+ """Tests whether a function (scalar to scalar) is the identity (lambda x: x)."""
192
+ test_values = [1.0, 10.0, 100.0]
193
+ return all(np.allclose(f(x), x) for x in test_values)
194
+
195
+ def _check_conjugate_parameter_is_scalar_reciprocal(f):
196
+ """Tests whether a function (scalar to scalar) is the reciprocal (lambda x : 1.0/x)."""
197
+ return all(math.isclose(f(x), 1.0 / x) for x in [1.0, 10.0, 100.0])
@@ -1,9 +1,10 @@
1
1
  import numpy as np
2
- from cuqi.experimental.mcmc import SamplerNew
3
- from cuqi.distribution import Posterior, LMRF, Gamma
2
+ from cuqi.experimental.mcmc import ConjugateNew
3
+ from cuqi.experimental.mcmc._conjugate import _ConjugatePair, _get_conjugate_parameter, _check_conjugate_parameter_is_scalar_reciprocal
4
+ from cuqi.distribution import LMRF, Gamma
4
5
  import scipy as sp
5
6
 
6
- class ConjugateApproxNew(SamplerNew):
7
+ class ConjugateApproxNew(ConjugateNew):
7
8
  """ Approximate Conjugate sampler
8
9
 
9
10
  Sampler for sampling a posterior distribution where the likelihood and prior can be approximated
@@ -16,26 +17,27 @@ class ConjugateApproxNew(SamplerNew):
16
17
 
17
18
  LMRF likelihood must have zero mean.
18
19
 
19
- Currently, the sampler does NOT automatically check that the conjugate distributions are defined on the correct parameters.
20
-
21
-
22
- For more information on conjugate pairs, see https://en.wikipedia.org/wiki/Conjugate_prior.
20
+ For more details on conjugacy see :class:`ConjugateNew`.
23
21
 
24
22
  """
25
23
 
26
- def _initialize(self):
27
- pass
24
+ def _set_conjugatepair(self):
25
+ """ Set the conjugate pair based on the likelihood and prior. This requires target to be set. """
26
+ if isinstance(self.target.likelihood.distribution, LMRF) and isinstance(self.target.prior, Gamma):
27
+ self._conjugatepair = _LMRFGammaPair(self.target)
28
+ else:
29
+ raise ValueError(f"Conjugacy is not defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
28
30
 
29
- def validate_target(self):
30
31
 
31
- if not isinstance(self.target, Posterior):
32
- raise TypeError("Approximate conjugate sampler requires a target of type Posterior")
32
+ class _LMRFGammaPair(_ConjugatePair):
33
+ """ Implementation of the conjugate pair (LMRF, Gamma) """
33
34
 
35
+ def validate_target(self):
34
36
  if not isinstance(self.target.likelihood.distribution, LMRF):
35
37
  raise ValueError("Approximate conjugate sampler only works with LMRF likelihood function")
36
38
 
37
39
  if not isinstance(self.target.prior, Gamma):
38
- raise ValueError("Approximate conjugate sampler only works with Gamma prior")
40
+ raise ValueError("Approximate conjugate sampler with LMRF likelihood only works with Gamma prior")
39
41
 
40
42
  if not self.target.prior.dim == 1:
41
43
  raise ValueError("Approximate conjugate sampler only works with univariate Gamma prior")
@@ -43,7 +45,14 @@ class ConjugateApproxNew(SamplerNew):
43
45
  if np.sum(self.target.likelihood.distribution.location) != 0:
44
46
  raise ValueError("Approximate conjugate sampler only works with zero mean LMRF likelihood")
45
47
 
46
- def step(self):
48
+ key, value = _get_conjugate_parameter(self.target)
49
+ if key == "scale":
50
+ if not _check_conjugate_parameter_is_scalar_reciprocal(value):
51
+ raise ValueError("Approximate conjugate sampler only works with Gamma prior on the inverse of the scale parameter of the LMRF likelihood")
52
+ else:
53
+ raise ValueError(f"No approximate conjugacy defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
54
+
55
+ def sample(self):
47
56
  # Extract variables
48
57
  # Here we approximate the LMRF with a Gaussian
49
58
 
@@ -69,7 +78,4 @@ class ConjugateApproxNew(SamplerNew):
69
78
  # Create Gamma distribution and sample
70
79
  dist = Gamma(shape=d+alpha, rate=np.linalg.norm(Lx)**2+beta)
71
80
 
72
- self.current_point = dist.sample()
73
-
74
- def tune(self, skip_len, update_count):
75
- pass
81
+ return dist.sample()
@@ -23,6 +23,7 @@ class DirectNew(SamplerNew):
23
23
 
24
24
  def step(self):
25
25
  self.current_point = self.target.sample()
26
+ return 1 # Returns acceptance rate of 1
26
27
 
27
28
  def tune(self, skip_len, update_count):
28
29
  pass
cuqi/problem/_problem.py CHANGED
@@ -710,8 +710,7 @@ class BayesianProblem(object):
710
710
  # Require gradient?
711
711
  if must_have_gradient:
712
712
  try:
713
- posterior.prior.gradient(np.zeros(posterior.prior.dim))
714
- posterior.likelihood.gradient(np.zeros(posterior.likelihood.dim))
713
+ posterior.posterior.gradient(np.zeros(posterior.posterior.dim))
715
714
  G = True
716
715
  except (NotImplementedError, AttributeError):
717
716
  G = False