CUQIpy 1.3.0.post0.dev401__py3-none-any.whl → 1.4.0.post0.dev41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuqi/__init__.py +1 -0
- cuqi/_version.py +3 -3
- cuqi/density/_density.py +9 -1
- cuqi/distribution/_joint_distribution.py +96 -11
- cuqi/experimental/__init__.py +1 -2
- cuqi/experimental/_recommender.py +4 -4
- cuqi/legacy/__init__.py +2 -0
- cuqi/legacy/sampler/__init__.py +11 -0
- cuqi/legacy/sampler/_conjugate.py +55 -0
- cuqi/legacy/sampler/_conjugate_approx.py +52 -0
- cuqi/legacy/sampler/_cwmh.py +196 -0
- cuqi/legacy/sampler/_gibbs.py +231 -0
- cuqi/legacy/sampler/_hmc.py +335 -0
- cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
- cuqi/legacy/sampler/_laplace_approximation.py +184 -0
- cuqi/legacy/sampler/_mh.py +190 -0
- cuqi/legacy/sampler/_pcn.py +244 -0
- cuqi/legacy/sampler/_rto.py +284 -0
- cuqi/legacy/sampler/_sampler.py +182 -0
- cuqi/problem/_problem.py +87 -80
- cuqi/sampler/__init__.py +120 -8
- cuqi/sampler/_conjugate.py +376 -35
- cuqi/sampler/_conjugate_approx.py +40 -16
- cuqi/sampler/_cwmh.py +132 -138
- cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
- cuqi/sampler/_gibbs.py +269 -130
- cuqi/sampler/_hmc.py +328 -201
- cuqi/sampler/_langevin_algorithm.py +282 -98
- cuqi/sampler/_laplace_approximation.py +87 -117
- cuqi/sampler/_mh.py +47 -157
- cuqi/sampler/_pcn.py +56 -211
- cuqi/sampler/_rto.py +206 -140
- cuqi/sampler/_sampler.py +540 -135
- {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/METADATA +1 -1
- {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/RECORD +38 -37
- cuqi/experimental/mcmc/__init__.py +0 -122
- cuqi/experimental/mcmc/_conjugate.py +0 -396
- cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
- cuqi/experimental/mcmc/_cwmh.py +0 -190
- cuqi/experimental/mcmc/_gibbs.py +0 -366
- cuqi/experimental/mcmc/_hmc.py +0 -462
- cuqi/experimental/mcmc/_langevin_algorithm.py +0 -382
- cuqi/experimental/mcmc/_laplace_approximation.py +0 -154
- cuqi/experimental/mcmc/_mh.py +0 -80
- cuqi/experimental/mcmc/_pcn.py +0 -89
- cuqi/experimental/mcmc/_rto.py +0 -350
- cuqi/experimental/mcmc/_sampler.py +0 -582
- {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/WHEEL +0 -0
- {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/licenses/LICENSE +0 -0
- {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/top_level.txt +0 -0
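The file list shows the former cuqi/experimental/mcmc samplers being promoted to cuqi/sampler, while the previous cuqi/sampler implementations are preserved under cuqi/legacy/sampler. A minimal, hypothetical migration sketch based only on these moves (the exact exported names are listed in the cuqi/sampler/__init__.py diff below):

# CUQIpy 1.3.x layout (hypothetical usage, for comparison)
# from cuqi.experimental.mcmc import NUTS, HybridGibbs   # new-style samplers lived under experimental
# from cuqi.sampler import MH                             # old-style sampler

# CUQIpy 1.4.x layout
from cuqi.sampler import NUTS, HybridGibbs, MH            # new-style samplers, now the default module
from cuqi.legacy.sampler import MH as LegacyMH            # old-style samplers kept under cuqi.legacy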
cuqi/sampler/__init__.py
CHANGED
@@ -1,11 +1,123 @@
+"""
+The sampler module of CUQIpy. It has been re-implemented to improve design, flexibility,
+and extensibility. The old sampler module can be found in :py:mod:`cuqi.legacy.sampler`.
+
+Main changes for users in this implementation
+---------------------------------------------
+
+1. Sampling API
+^^^^^^^^^^^^^^^
+
+Previously one would call the `.sample` or `sample_adapt` methods of a sampler instance at :py:mod:`cuqi.legacy.sampler` to sample from a target distribution and store the samples as the output as follows:
+
+.. code-block:: python
+
+    from cuqi.legacy.sampler import MH
+    from cuqi.distribution import DistributionGallery
+
+    # Target distribution
+    target = DistributionGallery("donut")
+
+    # Set up sampler
+    sampler = MH(target)
+
+    # Sample from the target distribution (Alternatively calling sample with explicit scale parameter set in sampler)
+    samples = sampler.sample_adapt(Ns=100, Nb=100) # Burn-in (Nb) removed by default
+
+This has now changed to a more object-oriented API which provides more flexibility and control over the sampling process.
+
+For example one can now more explicitly control when the sampler is tuned (warmup) and when it is sampling with fixed parameters.
+
+.. code-block:: python
+
+    from cuqi.sampler import MH
+    from cuqi.distribution import DistributionGallery
+
+    # Target distribution
+    target = DistributionGallery("donut")
+
+    # Set up sampler
+    sampler = MH(target)
+
+    # Sample from the target distribution
+    sampler.warmup(Nb=100) # Explicit warmup (tuning) of sampler
+    sampler.sample(Ns=100) # Sampling with fixed parameters
+    samples = sampler.get_samples().burnthin(Nb=100) # Getting samples and removing burn-in from warmup
+
+Importantly, the removal of burn-in from e.g. warmup is now a separate step that is done after the sampling process is complete.
+
+2. Sampling API for BayesianProblem
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:py:class:`cuqi.problem.BayesianProblem` continues to have the same API for `sample_posterior` and the `UQ` method.
+
+There is a flag `legacy` that can be set to `True` to use the legacy MCMC samplers.
+
+By default, the flag is set to `False` and the samplers in `cuqi.sampler` are used.
+
+For this more high-level interface, burn-in is automatically removed from the samples as was the case before.
+
+
+3. More options for Gibbs sampling
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+There are now more options for Gibbs sampling. Previously it was only possible to sample with Gibbs for samplers :py:class:`cuqi.legacy.sampler.LinearRTO`, :py:class:`cuqi.legacy.sampler.RegularizedLinearRTO`, :py:class:`cuqi.legacy.sampler.Conjugate`, and :py:class:`cuqi.legacy.sampler.ConjugateApprox`.
+
+Now, it is possible to define a Gibbs sampling scheme using any sampler from the :py:mod:`cuqi.sampler` module.
+
+**Example using a NUTS-within-Gibbs scheme for a 1D deconvolution problem:**
+
+.. code-block:: python
+
+    import cuqi
+    import numpy as np
+    from cuqi.distribution import Gamma, Gaussian, GMRF, JointDistribution
+    from cuqi.sampler import NUTS, HybridGibbs, Conjugate
+    from cuqi.testproblem import Deconvolution1D
+
+    # Forward problem
+    A, y_data, info = Deconvolution1D(dim=128, phantom='sinc', noise_std=0.001).get_components()
+
+    # Bayesian Inverse Problem
+    s = Gamma(1, 1e-4)
+    x = GMRF(np.zeros(A.domain_dim), 50)
+    y = Gaussian(A @ x, lambda s: 1 / s)
+
+    # Posterior
+    target = JointDistribution(y, x, s)(y=y_data)
+
+    # Gibbs sampling strategy. Note we can define initial_points and various parameters for each sampler
+    sampling_strategy = {
+        "x": NUTS(max_depth=10, initial_point=np.zeros(A.domain_dim)),
+        "s": Conjugate()
+    }
+
+    # Here we do 10 internal steps with NUTS for each Gibbs step
+    num_sampling_steps = {
+        "x": 10,
+        "s": 1
+    }
+
+    sampler = HybridGibbs(target, sampling_strategy, num_sampling_steps)
+
+    sampler.warmup(50)
+    sampler.sample(200)
+    samples = sampler.get_samples().burnthin(Nb=50)
+
+    samples["x"].plot_ci(exact=info.exactSolution)
+"""
+
+
+
 from ._sampler import Sampler, ProposalBasedSampler
-from .
-from ._conjugate_approx import ConjugateApprox
-from ._cwmh import CWMH
-from ._gibbs import Gibbs
-from ._hmc import NUTS
-from ._langevin_algorithm import ULA, MALA
-from ._laplace_approximation import UGLA
+from ._langevin_algorithm import ULA, MALA, MYULA, PnPULA
 from ._mh import MH
-from ._pcn import
+from ._pcn import PCN
 from ._rto import LinearRTO, RegularizedLinearRTO
+from ._cwmh import CWMH
+from ._laplace_approximation import UGLA
+from ._hmc import NUTS
+from ._gibbs import HybridGibbs
+from ._conjugate import Conjugate
+from ._conjugate_approx import ConjugateApprox
+from ._direct import Direct
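The docstring above states that :py:class:`cuqi.problem.BayesianProblem` keeps its `sample_posterior` and `UQ` API and gains a `legacy` flag that defaults to `False`. A short, hedged sketch of that workflow; where exactly the flag is passed is an assumption, since this diff only states that the flag exists:

import numpy as np
from cuqi.testproblem import Deconvolution1D
from cuqi.distribution import Gaussian, GMRF
from cuqi.problem import BayesianProblem

A, y_data, info = Deconvolution1D(dim=128, phantom='sinc').get_components()

x = GMRF(np.zeros(A.domain_dim), 50)
y = Gaussian(A @ x, 0.001**2)

BP = BayesianProblem(y, x).set_data(y=y_data)

samples = BP.sample_posterior(200)                  # default: new samplers from cuqi.sampler, burn-in removed automatically
# samples = BP.sample_posterior(200, legacy=True)   # assumption: the legacy flag would select cuqi.legacy.sampler here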
cuqi/sampler/_conjugate.py
CHANGED
@@ -1,55 +1,396 @@
-from cuqi.distribution import Posterior, Gaussian, Gamma, GMRF
-from cuqi.implicitprior import RegularizedGaussian, RegularizedGMRF
 import numpy as np
+from abc import ABC, abstractmethod
+import math
+from cuqi.sampler import Sampler
+from cuqi.distribution import Posterior, Gaussian, Gamma, GMRF, ModifiedHalfNormal
+from cuqi.implicitprior import RegularizedGaussian, RegularizedGMRF, RegularizedUnboundedUniform
+from cuqi.utilities import get_non_default_args, count_nonzero, count_within_bounds, count_constant_components_1D, count_constant_components_2D, piecewise_linear_1D_DoF
+from cuqi.geometry import Continuous1D, Continuous2D, Image2D
 
-class Conjugate:
+class Conjugate(Sampler):
     """ Conjugate sampler
 
-    Sampler for sampling a posterior distribution where the likelihood and prior are conjugate.
+    Sampler for sampling a posterior distribution which is a so-called "conjugate" distribution, i.e., where the likelihood and prior are conjugate to each other - denoted as a conjugate pair.
 
     Currently supported conjugate pairs are:
-    - (Gaussian, Gamma)
-    - (GMRF, Gamma)
-    - (RegularizedGaussian, Gamma) with
+    - (Gaussian, Gamma) where Gamma is defined on the precision parameter of the Gaussian
+    - (GMRF, Gamma) where Gamma is defined on the precision parameter of the GMRF
+    - (RegularizedGaussian, Gamma) with preset constraints only and Gamma is defined on the precision parameter of the RegularizedGaussian
+    - (RegularizedGMRF, Gamma) with preset constraints only and Gamma is defined on the precision parameter of the RegularizedGMRF
+    - (RegularizedGaussian, ModifiedHalfNormal) with most of the preset constraints and regularization
+    - (RegularizedGMRF, ModifiedHalfNormal) with most of the preset constraints and regularization
 
-
+    Currently the Gamma and ModifiedHalfNormal distribution must be univariate.
 
-
+    A conjugate pair defines implicitly a so-called conjugate distribution which can be sampled from directly.
+
+    The conjugate parameter is the parameter that both the likelihood and prior PDF depend on.
+
+    For more information on conjugacy and conjugate distributions see https://en.wikipedia.org/wiki/Conjugate_prior.
+
+    For implicit regularized Gaussians and the corresponding conjugacy relations, see:
 
-    [1] Everink, Jasper M., Yiqiu Dong, and Martin S. Andersen. "Bayesian inference with projected densities." SIAM/ASA Journal on Uncertainty Quantification 11.3 (2023): 1025-1043.
+    Section 3.3 from [1] Everink, Jasper M., Yiqiu Dong, and Martin S. Andersen. "Bayesian inference with projected densities." SIAM/ASA Journal on Uncertainty Quantification 11.3 (2023): 1025-1043.
+    Section 4 from [2] Everink, Jasper M., Yiqiu Dong, and Martin S. Andersen. "Sparse Bayesian inference with regularized Gaussian distributions." Inverse Problems 39.11 (2023): 115004.
 
     """
 
-    def
-
-
-
-
-
-
-
-
-
+    def _initialize(self):
+        pass
+
+    @Sampler.target.setter # Overwrite the target setter to set the conjugate pair
+    def target(self, value):
+        """ Set the target density. Runs validation of the target. """
+        self._target = value
+        if self._target is not None:
+            self._set_conjugatepair()
+            self.validate_target()
+
+    def validate_target(self):
+        self._ensure_target_is_posterior()
+        self._conjugatepair.validate_target()
+
+    def step(self):
+        self.current_point = self._conjugatepair.sample()
+        return 1 # Returns acceptance rate of 1
+
+    def tune(self, skip_len, update_count):
+        pass # No tuning required for conjugate sampler
+
+    def _ensure_target_is_posterior(self):
+        """ Ensure that the target is a Posterior distribution. """
+        if not isinstance(self.target, Posterior):
+            raise TypeError("Conjugate sampler requires a target of type Posterior")
+
+    def _set_conjugatepair(self):
+        """ Set the conjugate pair based on the likelihood and prior. This requires target to be set. """
+        self._ensure_target_is_posterior()
+        if isinstance(self.target.likelihood.distribution, (Gaussian, GMRF)) and isinstance(self.target.prior, Gamma):
+            self._conjugatepair = _GaussianGammaPair(self.target)
+        elif isinstance(self.target.likelihood.distribution, RegularizedUnboundedUniform) and isinstance(self.target.prior, Gamma):
+            # Check RegularizedUnboundedUniform before RegularizedGaussian and RegularizedGMRF due to the first inheriting from the second.
+            self._conjugatepair = _RegularizedUnboundedUniformGammaPair(self.target)
+        elif isinstance(self.target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)) and isinstance(self.target.prior, Gamma):
+            self._conjugatepair = _RegularizedGaussianGammaPair(self.target)
+        elif isinstance(self.target.likelihood.distribution, (RegularizedGaussian, RegularizedGMRF)) and isinstance(self.target.prior, ModifiedHalfNormal):
+            self._conjugatepair = _RegularizedGaussianModifiedHalfNormalPair(self.target)
+        else:
+            raise ValueError(f"Conjugacy is not defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
+
+    def conjugate_distribution(self):
+        return self._conjugatepair.conjugate_distribution()
+
+    def __repr__(self):
+        msg = super().__repr__()
+        if hasattr(self, "_conjugatepair"):
+            msg += f"\n Conjugate pair:\n\t {type(self._conjugatepair).__name__.removeprefix('_')}"
+        return msg
 
+class _ConjugatePair(ABC):
+    """ Abstract base class for conjugate pairs (likelihood, prior) used in the Conjugate sampler. """
+
+    def __init__(self, target):
         self.target = target
 
-
+    @abstractmethod
+    def validate_target(self):
+        """ Validate the target distribution for the conjugate pair. """
+        pass
+
+    @abstractmethod
+    def conjugate_distribution(self):
+        """ Returns the posterior distribution in the form of a CUQIpy distribution """
+        pass
+
+    def sample(self):
+        """ Sample from the conjugate distribution. """
+        return self.conjugate_distribution().sample()
+
+
+class _GaussianGammaPair(_ConjugatePair):
+    """ Implementation for the Gaussian-Gamma conjugate pair."""
+
+    def validate_target(self):
+        if self.target.prior.dim != 1:
+            raise ValueError("Gaussian-Gamma conjugacy only works with univariate Gamma prior")
+
+        key_value_pairs = _get_conjugate_parameter(self.target)
+        if len(key_value_pairs) != 1:
+            raise ValueError(f"Multiple references to conjugate parameter {self.target.prior.name} found in likelihood. Only one occurance is supported.")
+        for key, value in key_value_pairs:
+            if key == "cov":
+                if not _check_conjugate_parameter_is_scalar_linear_reciprocal(value):
+                    raise ValueError("Gaussian-Gamma conjugate pair defined via covariance requires cov: lambda x : s/x for the conjugate parameter")
+            elif key == "prec":
+                if not _check_conjugate_parameter_is_scalar_linear(value):
+                    raise ValueError("Gaussian-Gamma conjugate pair defined via precision requires prec: lambda x : s*x for the conjugate parameter")
+            else:
+                raise ValueError(f"RegularizedGaussian-ModifiedHalfNormal conjugacy does not support the conjugate parameter {self.target.prior.name} in the {key} attribute. Only cov and prec")
+
+    def conjugate_distribution(self):
         # Extract variables
-        b = self.target.likelihood.data #mu
-        m =
-        Ax = self.target.likelihood.distribution.mean #x_i
-        L = self.target.likelihood.distribution(np.array([1])).sqrtprec #L
-        alpha = self.target.prior.shape #alpha
-        beta = self.target.prior.rate #beta
+        b = self.target.likelihood.data # mu
+        m = len(b) # n
+        Ax = self.target.likelihood.distribution.mean # x_i
+        L = self.target.likelihood.distribution(np.array([1])).sqrtprec # L
+        alpha = self.target.prior.shape # alpha
+        beta = self.target.prior.rate # beta
 
         # Create Gamma distribution and sample
-
+        return Gamma(shape=m/2 + alpha, rate=.5 * np.linalg.norm(L @ (Ax - b))**2 + beta)
+
+
+class _RegularizedGaussianGammaPair(_ConjugatePair):
+    """Implementation for the Regularized Gaussian-Gamma conjugate pair using the conjugacy rules from [1], Section 3.3."""
+
+    def validate_target(self):
+        if self.target.prior.dim != 1:
+            raise ValueError("RegularizedGaussian-Gamma conjugacy only works with univariate ModifiedHalfNormal prior")
+
+        # Raises error if preset is not supported
+        _compute_sparsity_level(self.target)
+
+        key_value_pairs = _get_conjugate_parameter(self.target)
+        if len(key_value_pairs) != 1:
+            raise ValueError(f"Multiple references to conjugate parameter {self.target.prior.name} found in likelihood. Only one occurance is supported.")
+        for key, value in key_value_pairs:
+            if key == "cov":
+                if not _check_conjugate_parameter_is_scalar_linear_reciprocal(value):
+                    raise ValueError("Regularized Gaussian-Gamma conjugacy defined via covariance requires cov: lambda x : s/x for the conjugate parameter")
+            elif key == "prec":
+                if not _check_conjugate_parameter_is_scalar_linear(value):
+                    raise ValueError("Regularized Gaussian-Gamma conjugacy defined via precision requires prec: lambda x : s*x for the conjugate parameter")
+            else:
+                raise ValueError(f"RegularizedGaussian-ModifiedHalfNormal conjugacy does not support the conjugate parameter {self.target.prior.name} in the {key} attribute. Only cov and prec")
+
+    def conjugate_distribution(self):
+        # Extract variables
+        b = self.target.likelihood.data # mu
+        m = _compute_sparsity_level(self.target)
+        Ax = self.target.likelihood.distribution.mean # x_i
+        L = self.target.likelihood.distribution(np.array([1])).sqrtprec # L
+        alpha = self.target.prior.shape # alpha
+        beta = self.target.prior.rate # beta
+
+        # Create Gamma distribution and sample
+        return Gamma(shape=m/2 + alpha, rate=.5 * np.linalg.norm(L @ (Ax - b))**2 + beta)
+
+
+class _RegularizedUnboundedUniformGammaPair(_ConjugatePair):
+    """Implementation for the RegularizedUnboundedUniform-ModifiedHalfNormal conjugate pair using the conjugacy rules from [2], Section 4."""
+
+    def validate_target(self):
+        if self.target.prior.dim != 1:
+            raise ValueError("RegularizedUnboundedUniform-Gamma conjugacy only works with univariate Gamma prior")
+
+        # Raises error if preset is not supported
+        _compute_sparsity_level(self.target)
+
+        key_value_pairs = _get_conjugate_parameter(self.target)
+        if len(key_value_pairs) != 1:
+            raise ValueError(f"Multiple references to conjugate parameter {self.target.prior.name} found in likelihood. Only one occurance is supported.")
+        for key, value in key_value_pairs:
+            if key == "strength":
+                if not _check_conjugate_parameter_is_scalar_linear(value):
+                    raise ValueError("RegularizedUnboundedUniform-Gamma conjugacy defined via strength requires strength: lambda x : s*x for the conjugate parameter")
+            else:
+                raise ValueError(f"RegularizedUnboundedUniform-Gamma conjugacy does not support the conjugate parameter {self.target.prior.name} in the {key} attribute. Only strength is supported")
+
+    def conjugate_distribution(self):
+        # Extract prior variables
+        alpha = self.target.prior.shape
+        beta = self.target.prior.rate
+
+        # Compute likelihood quantities
+        x = self.target.likelihood.data
+        m = _compute_sparsity_level(self.target)
 
-
+        reg_op = self.target.likelihood.distribution._regularization_oper
+        reg_strength = self.target.likelihood.distribution(np.array([1])).strength
+        fx = reg_strength*np.linalg.norm(reg_op@x, ord = 1)
 
-
-
-
-
-
-
+        # Create Gamma distribution
+        return Gamma(shape=m/2 + alpha, rate=fx + beta)
+
+class _RegularizedGaussianModifiedHalfNormalPair(_ConjugatePair):
+    """Implementation for the Regularized Gaussian-ModifiedHalfNormal conjugate pair using the conjugacy rules from [2], Section 4."""
+
+    def validate_target(self):
+        if self.target.prior.dim != 1:
+            raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy only works with univariate ModifiedHalfNormal prior")
+
+        # Raises error if preset is not supported
+        _compute_sparsity_level(self.target)
+
+        key_value_pairs = _get_conjugate_parameter(self.target)
+        if len(key_value_pairs) != 2:
+            raise ValueError(f"Incorrect number of references to conjugate parameter {self.target.prior.name} found in likelihood. Found {len(key_value_pairs)} times, but needs to occur in prec or cov, and in strength")
+        for key, value in key_value_pairs:
+            if key == "strength":
+                if not _check_conjugate_parameter_is_scalar_linear(value):
+                    raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy defined via strength requires strength: lambda x : s*x for the conjugate parameter")
+            elif key == "prec":
+                if not _check_conjugate_parameter_is_scalar_quadratic(value):
+                    raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy defined via precision requires prec: lambda x : s*x for the conjugate parameter")
+            elif key == "cov":
+                if not _check_conjugate_parameter_is_scalar_quadratic_reciprocal(value):
+                    raise ValueError("RegularizedGaussian-ModifiedHalfNormal conjugacy defined via covariance requires cov: lambda x : s/x for the conjugate parameter")
+            else:
+                raise ValueError(f"RegularizedGaussian-ModifiedHalfNormal conjugacy does not support the conjugate parameter {self.target.prior.name} in the {key} attribute. Only cov, prec and strength are supported")
+
+
+    def conjugate_distribution(self):
+        # Extract prior variables
+        alpha = self.target.prior.alpha
+        beta = self.target.prior.beta
+        gamma = self.target.prior.gamma
+
+        # Compute likelihood variables
+        x = self.target.likelihood.data
+        mu = self.target.likelihood.distribution.mean
+        L = self.target.likelihood.distribution(np.array([1])).sqrtprec
+
+        m = _compute_sparsity_level(self.target)
+
+        reg_op = self.target.likelihood.distribution._regularization_oper
+        reg_strength = self.target.likelihood.distribution(np.array([1])).strength
+        fx = reg_strength*np.linalg.norm(reg_op@x, ord = 1)
+
+        # Compute parameters of conjugate distribution
+        conj_alpha = m + alpha
+        conj_beta = 0.5*np.linalg.norm(L @ (mu - x))**2 + beta
+        conj_gamma = -fx + gamma
+
+        # Create conjugate distribution
+        return ModifiedHalfNormal(conj_alpha, conj_beta, conj_gamma)
+
+
+def _compute_sparsity_level(target):
+    """Computes the sparsity level in accordance with Section 4 from [2],
+    this can be interpreted as the number of degrees of freedom, that is,
+    the number of components n minus the dimension the of the subdifferential of the regularized.
+    """
+    x = target.likelihood.data
+
+    constraint = target.likelihood.distribution.preset["constraint"]
+    regularization = target.likelihood.distribution.preset["regularization"]
+
+    # There is no reference for some of these conjugacy rules
+    if constraint == "nonnegativity":
+        if regularization in [None, "l1"]:
+            # Number of non-zero components in x
+            return count_nonzero(x)
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
+            # Number of non-zero constant components in x
+            return count_constant_components_1D(x, lower = 0.0)
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
+            # Number of non-zero constant components in x
+            return count_constant_components_2D(target.likelihood.distribution.geometry.par2fun(x), lower = 0.0)
+    elif constraint == "box":
+        bounds = target.likelihood.distribution._box_bounds
+        if regularization is None:
+            # Number of components in x that are strictly between the lower and upper bound
+            return count_within_bounds(x, bounds[0], bounds[1])
+        elif regularization == "l1":
+            # Number of components in x that are strictly between the lower and upper bound and are not zero
+            return count_within_bounds(x, bounds[0], bounds[1], exception = 0.0)
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
+            # Number of constant components in x between are strictly between the lower and upper bound
+            return count_constant_components_1D(x, lower = bounds[0], upper = bounds[1])
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
+            # Number of constant components in x between are strictly between the lower and upper bound
+            return count_constant_components_2D(target.likelihood.distribution.geometry.par2fun(x), lower = bounds[0], upper = bounds[1])
+    elif constraint in ["increasing", "decreasing"]:
+        if regularization is None:
+            # Number of constant components in x
+            return count_constant_components_1D(x)
+        elif regularization == "l1":
+            # Number of constant components in x that are not zero
+            return count_constant_components_1D(x, exception = 0.0)
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
+            # Number of constant components in x
+            return count_constant_components_1D(x)
+        # Increasing and decreasing cannot be done in 2D
+    elif constraint in ["convex", "concave"]:
+        if regularization is None:
+            # Number of piecewise linear components in x
+            return piecewise_linear_1D_DoF(x)
+        elif regularization == "l1":
+            # Number of piecewise linear components in x that are not zero
+            return piecewise_linear_1D_DoF(x, exception_zero = True)
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
+            # Number of piecewise linear components in x that are not flat
+            return piecewise_linear_1D_DoF(x, exception_flat = True)
+        # convex and concave has only been implemented in 1D
+    elif constraint == None:
+        if regularization == "l1":
+            # Number of non-zero components in x
+            return count_nonzero(x)
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, Continuous1D):
+            # Number of non-zero constant components in x
+            return count_constant_components_1D(x)
+        elif regularization == "tv" and isinstance(target.likelihood.distribution.geometry, (Continuous2D, Image2D)):
+            # Number of non-zero constant components in x
+            return count_constant_components_2D(target.likelihood.distribution.geometry.par2fun(x))
+
+    raise ValueError("RegularizedGaussian preset constraint and regularization choice is currently not supported with conjugacy.")
+
+
+def _get_conjugate_parameter(target):
+    """Extract the conjugate parameter name (e.g. d), and returns the mutable variable that is defined by the conjugate parameter, e.g. cov and its value e.g. lambda d:1/d"""
+    par_name = target.prior.name
+    mutable_likelihood_vars = target.likelihood.distribution.get_mutable_variables()
+
+    found_parameter_pairs = []
+
+    for var_key in mutable_likelihood_vars:
+        attr = getattr(target.likelihood.distribution, var_key)
+        if callable(attr) and par_name in get_non_default_args(attr):
+            found_parameter_pairs.append((var_key, attr))
+    if len(found_parameter_pairs) == 0:
+        raise ValueError(f"Unable to find conjugate parameter {par_name} in likelihood function for conjugate sampler with target {target}")
+    return found_parameter_pairs
+
+def _check_conjugate_parameter_is_scalar_identity(f):
+    """Tests whether a function (scalar to scalar) is the identity (lambda x: x)."""
+    test_values = [1.0, 10.0, 100.0]
+    return all(np.allclose(f(x), x) for x in test_values)
+
+def _check_conjugate_parameter_is_scalar_reciprocal(f):
+    """Tests whether a function (scalar to scalar) is the reciprocal (lambda x : 1.0/x)."""
+    return all(math.isclose(f(x), 1.0 / x) for x in [1.0, 10.0, 100.0])
+
+def _check_conjugate_parameter_is_scalar_linear(f):
+    """
+    Tests whether a function (scalar to scalar) is linear (lambda x: s*x for some s).
+    The tests checks whether the function is zero and some finite differences are constant.
+    """
+    test_values = [1.0, 10.0, 100.0]
+    h = 1e-2
+    finite_diffs = [(f(x + h*x)-f(x))/(h*x) for x in test_values]
+    return np.isclose(f(0.0), 0.0) and all(np.allclose(c, finite_diffs[0]) for c in finite_diffs[1:])
+
+def _check_conjugate_parameter_is_scalar_linear_reciprocal(f):
+    """
+    Tests whether a function (scalar to scalar) is a constant times the inverse of the input (lambda x: s/x for some s).
+    The tests checks whether the the reciprocal of the function has constant finite differences.
+    """
+    g = lambda x : 1.0/f(x)
+    test_values = [1.0, 10.0, 100.0]
+    h = 1e-2
+    finite_diffs = [(g(x + h*x)-g(x))/(h*x) for x in test_values]
+    return all(np.allclose(c, finite_diffs[0]) for c in finite_diffs[1:])
+
+def _check_conjugate_parameter_is_scalar_quadratic(f):
+    """
+    Tests whether a function (scalar to scalar) is linear (lambda x: s*x**2 for some s).
+    The tests checks whether the function divided by the parameter is linear
+    """
+    return _check_conjugate_parameter_is_scalar_linear(lambda x: f(x)/x if x != 0.0 else f(0.0))
+
+def _check_conjugate_parameter_is_scalar_quadratic_reciprocal(f):
+    """
+    Tests whether a function (scalar to scalar) is linear (lambda x: s*x**-2 for some s).
+    The tests checks whether the function divided by the parameter is the reciprical of a linear function.
+    """
+    return _check_conjugate_parameter_is_scalar_linear_reciprocal(lambda x: f(x)/x)
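The new Conjugate sampler draws exact samples from the conjugate distribution implied by the (likelihood, prior) pair, as listed in the class docstring above. A minimal sketch for the (Gaussian, Gamma) pair, assuming the usual CUQIpy pattern of a conditional Gaussian with a callable `prec` and a sampler that takes the target in its constructor:

import numpy as np
from cuqi.distribution import Gaussian, Gamma, JointDistribution
from cuqi.sampler import Conjugate

# Gamma prior on the precision s of a Gaussian; prec must be linear in s (see validate_target above)
s = Gamma(1, 1e-2)
y = Gaussian(np.zeros(3), prec=lambda s: s)

# Posterior over s given some observed data
y_obs = np.array([0.1, -0.2, 0.05])
posterior = JointDistribution(y, s)(y=y_obs)

sampler = Conjugate(posterior)
print(sampler.conjugate_distribution())  # Gamma(shape=m/2 + alpha, rate=0.5*||L(Ax-b)||^2 + beta), per _GaussianGammaPair
sampler.sample(100)                      # every step is an exact draw (acceptance rate 1)
samples = sampler.get_samples()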
cuqi/sampler/_conjugate_approx.py
CHANGED
@@ -1,31 +1,57 @@
-from cuqi.distribution import Posterior, LMRF, Gamma
 import numpy as np
+from cuqi.sampler import Conjugate
+from cuqi.sampler._conjugate import _ConjugatePair, _get_conjugate_parameter, _check_conjugate_parameter_is_scalar_reciprocal
+from cuqi.distribution import LMRF, Gamma
 import scipy as sp
 
-class ConjugateApprox:
+class ConjugateApprox(Conjugate):
     """ Approximate Conjugate sampler
 
     Sampler for sampling a posterior distribution where the likelihood and prior can be approximated
     by a conjugate pair.
 
     Currently supported pairs are:
-    - (LMRF, Gamma): Approximated by (Gaussian, Gamma)
+    - (LMRF, Gamma): Approximated by (Gaussian, Gamma) where Gamma is defined on the inverse of the scale parameter of the LMRF distribution.
 
-
+    Gamma distribution must be univariate.
+
+    LMRF likelihood must have zero mean.
+
+    For more details on conjugacy see :class:`Conjugate`.
 
     """
 
-
-
-    if
-
-
-    raise ValueError("
-
+    def _set_conjugatepair(self):
+        """ Set the conjugate pair based on the likelihood and prior. This requires target to be set. """
+        if isinstance(self.target.likelihood.distribution, LMRF) and isinstance(self.target.prior, Gamma):
+            self._conjugatepair = _LMRFGammaPair(self.target)
+        else:
+            raise ValueError(f"Conjugacy is not defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
+
 
-
+class _LMRFGammaPair(_ConjugatePair):
+    """ Implementation of the conjugate pair (LMRF, Gamma) """
+
+    def validate_target(self):
+        if not self.target.prior.dim == 1:
+            raise ValueError("Approximate conjugate sampler only works with univariate Gamma prior")
+
+        if np.sum(self.target.likelihood.distribution.location) != 0:
+            raise ValueError("Approximate conjugate sampler only works with zero mean LMRF likelihood")
+
+        key_value_pairs = _get_conjugate_parameter(self.target)
+        if len(key_value_pairs) != 1:
+            raise ValueError(f"Multiple references to conjugate parameter {self.target.prior.name} found in likelihood. Only one occurance is supported.")
+        for key, value in key_value_pairs:
+            if key == "scale":
+                if not _check_conjugate_parameter_is_scalar_reciprocal(value):
+                    raise ValueError("Approximate conjugate sampler only works with Gamma prior on the inverse of the scale parameter of the LMRF likelihood")
+            else:
+                raise ValueError(f"No approximate conjugacy defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
+
+    def conjugate_distribution(self):
         # Extract variables
-        # Here we approximate the
+        # Here we approximate the LMRF with a Gaussian
 
         # Extract diff_op from target likelihood
         D = self.target.likelihood.distribution._diff_op
@@ -47,6 +73,4 @@ class ConjugateApprox: # TODO: Subclass from Sampler once updated
         beta = self.target.prior.rate #beta
 
         # Create Gamma distribution and sample
-
-
-        return dist.sample()
+        return Gamma(shape=d+alpha, rate=np.linalg.norm(Lx)**2+beta)