CUQIpy 1.2.0.post0.dev42__py3-none-any.whl → 1.2.0.post0.dev109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {CUQIpy-1.2.0.post0.dev42.dist-info → CUQIpy-1.2.0.post0.dev109.dist-info}/METADATA +1 -1
- {CUQIpy-1.2.0.post0.dev42.dist-info → CUQIpy-1.2.0.post0.dev109.dist-info}/RECORD +13 -12
- {CUQIpy-1.2.0.post0.dev42.dist-info → CUQIpy-1.2.0.post0.dev109.dist-info}/WHEEL +1 -1
- cuqi/_version.py +3 -3
- cuqi/distribution/__init__.py +1 -0
- cuqi/distribution/_normal.py +34 -0
- cuqi/distribution/_truncated_normal.py +129 -0
- cuqi/experimental/mcmc/_rto.py +2 -2
- cuqi/sampler/_rto.py +2 -2
- cuqi/solver/__init__.py +1 -0
- cuqi/solver/_solver.py +169 -4
- {CUQIpy-1.2.0.post0.dev42.dist-info → CUQIpy-1.2.0.post0.dev109.dist-info}/LICENSE +0 -0
- {CUQIpy-1.2.0.post0.dev42.dist-info → CUQIpy-1.2.0.post0.dev109.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: CUQIpy
|
|
3
|
-
Version: 1.2.0.post0.dev42
|
|
3
|
+
Version: 1.2.0.post0.dev109
|
|
4
4
|
Summary: Computational Uncertainty Quantification for Inverse problems in Python
|
|
5
5
|
Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
|
|
6
6
|
License: Apache License
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
|
|
2
2
|
cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
|
|
3
|
-
cuqi/_version.py,sha256=
|
|
3
|
+
cuqi/_version.py,sha256=W8GFj1jnTSt-WhxIkFksEkvysldGlIeV0DKfFYcd6TU,510
|
|
4
4
|
cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
|
|
5
5
|
cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
|
|
6
6
|
cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
|
|
@@ -14,7 +14,7 @@ cuqi/data/cookie.png,sha256=mr6wUeoIUc5VC2qYj8vafOmTbcRwz0fHz4IIPK9_PnE,984680
|
|
|
14
14
|
cuqi/data/satellite.mat,sha256=a0Nz_Ak-Y0m360dH74pa_rpk-MhaQ91ftGTKhQX7I8g,16373
|
|
15
15
|
cuqi/density/__init__.py,sha256=0zfVcPgqdqiPkss5n_WP_PUt-G3ovHXjokhqEKIlLwA,48
|
|
16
16
|
cuqi/density/_density.py,sha256=BG7gtP0cbFYLVgjYQGkNAhM95PR5ocBVLKRlOVX2PyM,7253
|
|
17
|
-
cuqi/distribution/__init__.py,sha256=
|
|
17
|
+
cuqi/distribution/__init__.py,sha256=Vvw-ge5HAF1now9n4rcwDicCsEUN9_jbbxlKxyzeUuY,761
|
|
18
18
|
cuqi/distribution/_beta.py,sha256=lgN6PGoF9RXQtrMGqSaSBV0hw-LEsOfRTD2Q2L3-Ok4,2903
|
|
19
19
|
cuqi/distribution/_cauchy.py,sha256=UsVXYz8HhagXN5fIWSAIyELqhsJAX_-wk9kkRGgRmA8,3296
|
|
20
20
|
cuqi/distribution/_cmrf.py,sha256=tCbEulM_O7FB3C_W-3IqZp9zGHkTofCdFF0ybHc9UZI,3745
|
|
@@ -29,9 +29,10 @@ cuqi/distribution/_laplace.py,sha256=5exLvlzJm2AgfvZ3KUSkjfwlGwwbsktBxP8z0iLMik8
|
|
|
29
29
|
cuqi/distribution/_lmrf.py,sha256=rdGoQ-fPe1oW6Z29P-l3woq0NX3_RxUQ2rzm1VzemNM,3290
|
|
30
30
|
cuqi/distribution/_lognormal.py,sha256=8_hOFQ3iu88ujX8vxmfVEZ0fdmlhTY98PlG5PasPjEg,2612
|
|
31
31
|
cuqi/distribution/_modifiedhalfnormal.py,sha256=eCg9YhH-zyX25V5WqdBwQykwG_90lm5Qc2901z7jFUE,7390
|
|
32
|
-
cuqi/distribution/_normal.py,sha256=
|
|
32
|
+
cuqi/distribution/_normal.py,sha256=vhIiAseW09IKh1uy0KUq7RP1IuY7hH5aNM1W_R8Gd_Q,2912
|
|
33
33
|
cuqi/distribution/_posterior.py,sha256=zAfL0GECxekZ2lBt1W6_LN0U_xskMwK4VNce5xAF7ig,5018
|
|
34
34
|
cuqi/distribution/_smoothed_laplace.py,sha256=p-1Y23mYA9omwiHGkEuv3T2mwcPAAoNlCr7T8osNkjE,2925
|
|
35
|
+
cuqi/distribution/_truncated_normal.py,sha256=sZkLYgnkGOyS_3ZxY7iw6L62t-Jh6shzsweRsRepN2k,4240
|
|
35
36
|
cuqi/distribution/_uniform.py,sha256=KA8yQ6ZS3nQGS4PYJ4hpDg6Eq8EQKQvPsIpYfR8fj2w,1967
|
|
36
37
|
cuqi/experimental/__init__.py,sha256=vhZvyMX6rl8Y0haqCzGLPz6PSUKyu75XMQbeDHqTTrw,83
|
|
37
38
|
cuqi/experimental/mcmc/__init__.py,sha256=1sn0U6Ep0x5zv2602og2DkV3Bs8hNFOiq7C3VcMimVw,4472
|
|
@@ -45,7 +46,7 @@ cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=yNO7ABxmkixzcLG-lv57GOTyeTr
|
|
|
45
46
|
cuqi/experimental/mcmc/_laplace_approximation.py,sha256=rdiE3cMQFq6FLQcOQwPpuGIxrTAp3aoGPxMDSdeopV0,5688
|
|
46
47
|
cuqi/experimental/mcmc/_mh.py,sha256=MXo0ahXP4KGFkaY4HtvcBE-TMQzsMlTmLKzSvpz7drU,2941
|
|
47
48
|
cuqi/experimental/mcmc/_pcn.py,sha256=wqJBZLuRFSwxihaI53tumAg6AWVuceLMOmXssTetd1A,3374
|
|
48
|
-
cuqi/experimental/mcmc/_rto.py,sha256=
|
|
49
|
+
cuqi/experimental/mcmc/_rto.py,sha256=Ub5rDe_yfkzxqcnimEArXWVb3twuGUJmvxEQNPKQWfU,10061
|
|
49
50
|
cuqi/experimental/mcmc/_sampler.py,sha256=xtoT70T8xe3Ye7yYdIFQD_kivjXlqUImyV3bMt406nk,20106
|
|
50
51
|
cuqi/experimental/mcmc/_utilities.py,sha256=kUzHbhIS3HYZRbneNBK41IogUYX5dS_bJxqEGm7TQBI,525
|
|
51
52
|
cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
|
|
@@ -74,19 +75,19 @@ cuqi/sampler/_langevin_algorithm.py,sha256=o5EyvaR6QGAD7LKwXVRC3WwAP5IYJf5GoMVWl
|
|
|
74
75
|
cuqi/sampler/_laplace_approximation.py,sha256=u018Z5eqlcq_cIwD9yNOaA15dLQE_vUWaee5Xp8bcjg,6454
|
|
75
76
|
cuqi/sampler/_mh.py,sha256=V5tIdn-KdfWo4J_Nbf-AH6XwKWblWUyc4BeuSikUHsE,7062
|
|
76
77
|
cuqi/sampler/_pcn.py,sha256=F0h9-nUFtkqn-o-1s8BCsmr8V7u6R7ycoCOeeV1uhj0,8601
|
|
77
|
-
cuqi/sampler/_rto.py,sha256
|
|
78
|
+
cuqi/sampler/_rto.py,sha256=eJe7_gN_1NpHHc_okKmFtLcOrvoe6cBoVLdf9ULuB_w,11518
|
|
78
79
|
cuqi/sampler/_sampler.py,sha256=TkZ_WAS-5Q43oICa-Elc2gftsRTBd7PEDUMDZ9tTGmU,5712
|
|
79
80
|
cuqi/samples/__init__.py,sha256=vCs6lVk-pi8RBqa6cIN5wyn6u-K9oEf1Na4k1ZMrYv8,44
|
|
80
81
|
cuqi/samples/_samples.py,sha256=hUc8OnCF9CTCuDTrGHwwzv3wp8mG_6vsJAFvuQ-x0uA,35832
|
|
81
|
-
cuqi/solver/__init__.py,sha256=
|
|
82
|
-
cuqi/solver/_solver.py,sha256=
|
|
82
|
+
cuqi/solver/__init__.py,sha256=3eoTTgBHe3M6ygrbgUVG3GlqaZVe5lGajNV9rolXZJ8,179
|
|
83
|
+
cuqi/solver/_solver.py,sha256=4LdfxLaU-fUHltZw7Sq-Xohyxd_6RvKy03xxtIMW6Zs,29488
|
|
83
84
|
cuqi/testproblem/__init__.py,sha256=DWTOcyuNHMbhEuuWlY5CkYkNDSAqhvsKmJXBLivyblU,202
|
|
84
85
|
cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7sS0Q,52540
|
|
85
86
|
cuqi/utilities/__init__.py,sha256=H7xpJe2UinjZftKvE2JuXtTi4DqtkR6uIezStAXwfGg,428
|
|
86
87
|
cuqi/utilities/_get_python_variable_name.py,sha256=QwlBVj2koJRA8s8pWd554p7-ElcI7HUwY32HknaR92E,1827
|
|
87
88
|
cuqi/utilities/_utilities.py,sha256=Jc4knn80vLoA7kgw9FzXwKVFGaNBOXiA9kgvltZU3Ao,11777
|
|
88
|
-
CUQIpy-1.2.0.post0.
|
|
89
|
-
CUQIpy-1.2.0.post0.
|
|
90
|
-
CUQIpy-1.2.0.post0.
|
|
91
|
-
CUQIpy-1.2.0.post0.
|
|
92
|
-
CUQIpy-1.2.0.post0.
|
|
89
|
+
CUQIpy-1.2.0.post0.dev109.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
|
|
90
|
+
CUQIpy-1.2.0.post0.dev109.dist-info/METADATA,sha256=8LneS_GWSYI--t-LZsilX6fC2N8CB4yfhlirg9lXpVE,18496
|
|
91
|
+
CUQIpy-1.2.0.post0.dev109.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
92
|
+
CUQIpy-1.2.0.post0.dev109.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
|
|
93
|
+
CUQIpy-1.2.0.post0.dev109.dist-info/RECORD,,
|
cuqi/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-
|
|
11
|
+
"date": "2024-11-08T11:08:22+0100",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "1.2.0.post0.
|
|
14
|
+
"full-revisionid": "17570092caf729244ade9c6d647cfa1d2b9ef5f0",
|
|
15
|
+
"version": "1.2.0.post0.dev109"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
cuqi/distribution/__init__.py
CHANGED
|
@@ -12,6 +12,7 @@ from ._laplace import Laplace
|
|
|
12
12
|
from ._smoothed_laplace import SmoothedLaplace
|
|
13
13
|
from ._lognormal import Lognormal
|
|
14
14
|
from ._normal import Normal
|
|
15
|
+
from ._truncated_normal import TruncatedNormal
|
|
15
16
|
from ._posterior import Posterior
|
|
16
17
|
from ._uniform import Uniform
|
|
17
18
|
from ._custom import UserDefinedDistribution, DistributionGallery
|
cuqi/distribution/_normal.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
+
import numbers
|
|
2
3
|
from scipy.special import erf
|
|
4
|
+
from cuqi.geometry import _get_identity_geometries
|
|
5
|
+
from cuqi.utilities import force_ndarray
|
|
3
6
|
from cuqi.distribution import Distribution
|
|
4
7
|
|
|
5
8
|
class Normal(Distribution):
|
|
@@ -27,6 +30,24 @@ class Normal(Distribution):
|
|
|
27
30
|
self.mean = mean
|
|
28
31
|
self.std = std
|
|
29
32
|
|
|
33
|
+
@property
|
|
34
|
+
def mean(self):
|
|
35
|
+
""" Mean of the distribution """
|
|
36
|
+
return self._mean
|
|
37
|
+
|
|
38
|
+
@mean.setter
|
|
39
|
+
def mean(self, value):
|
|
40
|
+
self._mean = force_ndarray(value, flatten=True)
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
def std(self):
|
|
44
|
+
""" Std of the distribution """
|
|
45
|
+
return self._std
|
|
46
|
+
|
|
47
|
+
@std.setter
|
|
48
|
+
def std(self, value):
|
|
49
|
+
self._std = force_ndarray(value, flatten=True)
|
|
50
|
+
|
|
30
51
|
def pdf(self, x):
|
|
31
52
|
return np.prod(1/(self.std*np.sqrt(2*np.pi))*np.exp(-0.5*((x-self.mean)/self.std)**2))
|
|
32
53
|
|
|
@@ -36,6 +57,19 @@ class Normal(Distribution):
|
|
|
36
57
|
def cdf(self, x):
|
|
37
58
|
return np.prod(0.5*(1 + erf((x-self.mean)/(self.std*np.sqrt(2)))))
|
|
38
59
|
|
|
60
|
+
def _gradient(self, val, *args, **kwargs):
|
|
61
|
+
if not type(self.geometry) in _get_identity_geometries():
|
|
62
|
+
raise NotImplementedError("Gradient not implemented for distribution {} with geometry {}".format(self,self.geometry))
|
|
63
|
+
if not callable(self.mean):
|
|
64
|
+
return -(val-self.mean)/(self.std**2)
|
|
65
|
+
elif hasattr(self.mean, "gradient"): # for likelihood
|
|
66
|
+
model = self.mean
|
|
67
|
+
dev = val - model.forward(*args, **kwargs)
|
|
68
|
+
print(dev)
|
|
69
|
+
return model.gradient(1.0/(np.array(self.std)) @ dev, *args, **kwargs)
|
|
70
|
+
else:
|
|
71
|
+
raise NotImplementedError("Gradient not implemented for distribution {} with location {}".format(self,self.mean))
|
|
72
|
+
|
|
39
73
|
def _sample(self,N=1, rng=None):
|
|
40
74
|
|
|
41
75
|
"""
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.special import erf
|
|
3
|
+
from cuqi.utilities import force_ndarray
|
|
4
|
+
from cuqi.distribution import Distribution
|
|
5
|
+
from cuqi.distribution import Normal
|
|
6
|
+
|
|
7
|
+
class TruncatedNormal(Distribution):
|
|
8
|
+
"""
|
|
9
|
+
Truncated Normal probability distribution.
|
|
10
|
+
|
|
11
|
+
Generates instance of cuqi.distribution.TruncatedNormal.
|
|
12
|
+
It allows the user to specify upper and lower bounds on random variables
|
|
13
|
+
represented by a Normal distribution. This distribution is suitable for a
|
|
14
|
+
small dimension setup (e.g. `dim`=3 or 4). Using TruncatedNormal
|
|
15
|
+
Distribution with a larger dimension can lead to a high rejection rate when
|
|
16
|
+
used within MCMC samplers.
|
|
17
|
+
|
|
18
|
+
The variables of this distribution are iid.
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
------------
|
|
23
|
+
mean : float or array_like of floats
|
|
24
|
+
mean of distribution
|
|
25
|
+
std : float or array_like of floats
|
|
26
|
+
standard deviation
|
|
27
|
+
low : float or array_like of floats
|
|
28
|
+
lower bound of the distribution
|
|
29
|
+
high : float or array_like of floats
|
|
30
|
+
upper bound of the distribution
|
|
31
|
+
|
|
32
|
+
Example
|
|
33
|
+
-----------
|
|
34
|
+
.. code-block:: python
|
|
35
|
+
|
|
36
|
+
#Generate Normal with mean 0, standard deviation 1 and bounds [-2,2]
|
|
37
|
+
p = cuqi.distribution.TruncatedNormal(mean=0, std=1, low=-2, high=2)
|
|
38
|
+
samples = p.sample(5000)
|
|
39
|
+
"""
|
|
40
|
+
def __init__(self, mean=None, std=None, low=-np.Inf, high=np.Inf, is_symmetric=False, **kwargs):
|
|
41
|
+
# Init from abstract distribution class
|
|
42
|
+
super().__init__(is_symmetric=is_symmetric, **kwargs)
|
|
43
|
+
|
|
44
|
+
# Init specific to this distribution
|
|
45
|
+
self.mean = mean
|
|
46
|
+
self.std = std
|
|
47
|
+
self.low = low
|
|
48
|
+
self.high = high
|
|
49
|
+
|
|
50
|
+
# Init underlying normal distribution
|
|
51
|
+
self._normal = Normal(self.mean, self.std, is_symmetric=True, **kwargs)
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def mean(self):
|
|
55
|
+
""" Mean of the distribution """
|
|
56
|
+
return self._mean
|
|
57
|
+
|
|
58
|
+
@mean.setter
|
|
59
|
+
def mean(self, value):
|
|
60
|
+
self._mean = force_ndarray(value, flatten=True)
|
|
61
|
+
if hasattr(self, '_normal'):
|
|
62
|
+
self._normal.mean = self._mean
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def std(self):
|
|
66
|
+
""" Std of the distribution """
|
|
67
|
+
return self._std
|
|
68
|
+
|
|
69
|
+
@std.setter
|
|
70
|
+
def std(self, value):
|
|
71
|
+
self._std = force_ndarray(value, flatten=True)
|
|
72
|
+
if hasattr(self, '_normal'):
|
|
73
|
+
self._normal.std = self._std
|
|
74
|
+
|
|
75
|
+
@property
|
|
76
|
+
def low(self):
|
|
77
|
+
""" Lower bound of the distribution """
|
|
78
|
+
return self._low
|
|
79
|
+
|
|
80
|
+
@low.setter
|
|
81
|
+
def low(self, value):
|
|
82
|
+
self._low = force_ndarray(value, flatten=True)
|
|
83
|
+
|
|
84
|
+
@property
|
|
85
|
+
def high(self):
|
|
86
|
+
""" Higher bound of the distribution """
|
|
87
|
+
return self._high
|
|
88
|
+
|
|
89
|
+
@high.setter
|
|
90
|
+
def high(self, value):
|
|
91
|
+
self._high = force_ndarray(value, flatten=True)
|
|
92
|
+
|
|
93
|
+
def logpdf(self, x):
|
|
94
|
+
"""
|
|
95
|
+
Computes the unnormalized logpdf at the given values of x.
|
|
96
|
+
"""
|
|
97
|
+
# the unnormalized logpdf
|
|
98
|
+
# check if every entry of x lies within the bounds [self.low, self.high]
|
|
99
|
+
if np.any(x < self.low) or np.any(x > self.high):
|
|
100
|
+
return -np.Inf
|
|
101
|
+
else:
|
|
102
|
+
return self._normal.logpdf(x)
|
|
103
|
+
|
|
104
|
+
def _gradient(self, x, *args, **kwargs):
|
|
105
|
+
"""
|
|
106
|
+
Computes the gradient of the unnormalized logpdf at the given values of x.
|
|
107
|
+
"""
|
|
108
|
+
# check if every entry of x lies within the bounds [self.low, self.high]
|
|
109
|
+
if np.any(x < self.low) or np.any(x > self.high):
|
|
110
|
+
return np.NaN*np.ones_like(x)
|
|
111
|
+
else:
|
|
112
|
+
return self._normal.gradient(x, *args, **kwargs)
|
|
113
|
+
|
|
114
|
+
def _sample(self, N=1, rng=None):
|
|
115
|
+
"""
|
|
116
|
+
Generates random samples from the distribution.
|
|
117
|
+
"""
|
|
118
|
+
max_iter = 1e9 # maximum number of trials to avoid infinite loop
|
|
119
|
+
samples = []
|
|
120
|
+
for i in range(int(max_iter)):
|
|
121
|
+
if len(samples) == N:
|
|
122
|
+
break
|
|
123
|
+
sample = self._normal.sample(1,rng)
|
|
124
|
+
if np.all(sample >= self.low) and np.all(sample <= self.high):
|
|
125
|
+
samples.append(sample)
|
|
126
|
+
# raise an error if the number of iterations exceeds max_iter
|
|
127
|
+
if i == max_iter-1:
|
|
128
|
+
raise RuntimeError("Failed to generate {} samples within {} iterations".format(N, max_iter))
|
|
129
|
+
return np.array(samples).T.reshape(-1,N)
|
cuqi/experimental/mcmc/_rto.py
CHANGED
|
@@ -235,8 +235,8 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
235
235
|
|
|
236
236
|
def step(self):
|
|
237
237
|
y = self.b_tild + np.random.randn(len(self.b_tild))
|
|
238
|
-
sim = FISTA(self.M, y, self.
|
|
239
|
-
maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
|
|
238
|
+
sim = FISTA(self.M, y, self.proximal,
|
|
239
|
+
self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
|
|
240
240
|
self.current_point, _ = sim.solve()
|
|
241
241
|
acc = 1
|
|
242
242
|
return acc
|
cuqi/sampler/_rto.py
CHANGED
|
@@ -267,8 +267,8 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
267
267
|
samples[:, 0] = self.x0
|
|
268
268
|
for s in range(Ns-1):
|
|
269
269
|
y = self.b_tild + np.random.randn(len(self.b_tild))
|
|
270
|
-
sim = FISTA(self.M, y,
|
|
271
|
-
maxit = self.maxit, stepsize = _stepsize, abstol = self.abstol, adaptive = self.adaptive)
|
|
270
|
+
sim = FISTA(self.M, y, self.proximal,
|
|
271
|
+
samples[:, s], maxit = self.maxit, stepsize = _stepsize, abstol = self.abstol, adaptive = self.adaptive)
|
|
272
272
|
samples[:, s+1], _ = sim.solve()
|
|
273
273
|
|
|
274
274
|
self._print_progress(s+2,Ns) #s+2 is the sample number, s+1 is index assuming x0 is the first sample
|
cuqi/solver/__init__.py
CHANGED
cuqi/solver/_solver.py
CHANGED
|
@@ -584,8 +584,8 @@ class FISTA(object):
|
|
|
584
584
|
----------
|
|
585
585
|
A : ndarray or callable f(x,*args).
|
|
586
586
|
b : ndarray.
|
|
587
|
-
x0 : ndarray. Initial guess.
|
|
588
587
|
proximal : callable f(x, gamma) for proximal mapping.
|
|
588
|
+
x0 : ndarray. Initial guess.
|
|
589
589
|
maxit : The maximum number of iterations.
|
|
590
590
|
stepsize : The stepsize of the gradient step.
|
|
591
591
|
abstol : The numerical tolerance for convergence checks.
|
|
@@ -606,11 +606,11 @@ class FISTA(object):
|
|
|
606
606
|
b = rng.standard_normal(m)
|
|
607
607
|
stepsize = 0.99/(sp.linalg.interpolative.estimate_spectral_norm(A)**2)
|
|
608
608
|
x0 = np.zeros(n)
|
|
609
|
-
fista = FISTA(A, b,
|
|
609
|
+
fista = FISTA(A, b, proximal = ProximalL1, x0 = x0, stepsize = stepsize, maxit = 100, abstol=1e-12, adaptive = True)
|
|
610
610
|
sol, _ = fista.solve()
|
|
611
611
|
|
|
612
612
|
"""
|
|
613
|
-
def __init__(self, A, b,
|
|
613
|
+
def __init__(self, A, b, proximal, x0, maxit=100, stepsize=1e0, abstol=1e-14, adaptive = True):
|
|
614
614
|
|
|
615
615
|
self.A = A
|
|
616
616
|
self.b = b
|
|
@@ -650,8 +650,157 @@ class FISTA(object):
|
|
|
650
650
|
x_new = x_new + ((k-1)/(k+2))*(x_new - x_old)
|
|
651
651
|
|
|
652
652
|
x = x_new.copy()
|
|
653
|
+
|
|
654
|
+
class ADMM(object):
|
|
655
|
+
"""Alternating Direction Method of Multipliers for solving regularized linear least squares problems of the form:
|
|
656
|
+
Minimize ||Ax-b||^2 + sum_i f_i(L_i x),
|
|
657
|
+
where the sum ranges from 1 to an arbitrary n. See definition of the parameter `penalty_terms` below for more details about f_i and L_i
|
|
658
|
+
|
|
659
|
+
Reference:
|
|
660
|
+
[1] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers." Foundations and Trends® in Machine Learning, 2011.
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
Parameters
|
|
664
|
+
----------
|
|
665
|
+
A : ndarray or callable
|
|
666
|
+
Represents a matrix or a function that performs matrix-vector multiplications.
|
|
667
|
+
When A is a callable, it accepts arguments (x, flag) where:
|
|
668
|
+
- flag=1 indicates multiplication of A with vector x, that is A @ x.
|
|
669
|
+
- flag=2 indicates multiplication of the transpose of A with vector x, that is A.T @ x.
|
|
670
|
+
b : ndarray.
|
|
671
|
+
penalty_terms : List of tuples (callable proximal operator of f_i, linear operator L_i)
|
|
672
|
+
Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of f_i(z) + 1/(2p)||x-z||^2 over z (the convention followed by ProximalL1; solve() calls it with p = 1/penalty_parameter).
|
|
673
|
+
x0 : ndarray. Initial guess.
|
|
674
|
+
penalty_parameter : Trade-off between linear least squares and regularization term in the solver iterates. Denoted as "rho" in [1].
|
|
675
|
+
maxit : The maximum number of iterations.
|
|
676
|
+
adaptive : Whether to adaptively update the penalty_parameter each iteration such that the primal and dual residual norms are of the same order of magnitude. Based on [1], Subsection 3.4.1
|
|
677
|
+
|
|
678
|
+
Example
|
|
679
|
+
-----------
|
|
680
|
+
.. code-block:: python
|
|
653
681
|
|
|
682
|
+
from cuqi.solver import ADMM, ProximalL1, ProjectNonnegative
|
|
683
|
+
import numpy as np
|
|
684
|
+
|
|
685
|
+
rng = np.random.default_rng()
|
|
686
|
+
|
|
687
|
+
m, n, k = 10, 5, 4
|
|
688
|
+
A = rng.standard_normal((m, n))
|
|
689
|
+
b = rng.standard_normal(m)
|
|
690
|
+
L = rng.standard_normal((k, n))
|
|
691
|
+
|
|
692
|
+
x0 = np.zeros(n)
|
|
693
|
+
admm = ADMM(A, b, x0 = x0, penalty_terms = [(ProximalL1, L), (lambda z, _ : ProjectNonnegative(z), np.eye(n))], penalty_parameter = 10)
|
|
694
|
+
sol, _ = admm.solve()
|
|
695
|
+
|
|
696
|
+
"""
|
|
697
|
+
|
|
698
|
+
def __init__(self, A, b, penalty_terms, x0, penalty_parameter = 10, maxit = 100, inner_max_it = 10, adaptive = True):
|
|
699
|
+
|
|
700
|
+
self.A = A
|
|
701
|
+
self.b = b
|
|
702
|
+
self.x_cur = x0
|
|
703
|
+
|
|
704
|
+
dual_len = [penalty[1].shape[0] for penalty in penalty_terms]
|
|
705
|
+
self.z_cur = [np.zeros(l) for l in dual_len]
|
|
706
|
+
self.u_cur = [np.zeros(l) for l in dual_len]
|
|
707
|
+
self.n = penalty_terms[0][1].shape[1]
|
|
708
|
+
|
|
709
|
+
self.rho = penalty_parameter
|
|
710
|
+
self.maxit = maxit
|
|
711
|
+
self.inner_max_it = inner_max_it
|
|
712
|
+
self.adaptive = adaptive
|
|
713
|
+
|
|
714
|
+
self.penalty_terms = penalty_terms
|
|
715
|
+
|
|
716
|
+
self.p = len(self.penalty_terms)
|
|
717
|
+
self._big_matrix = None
|
|
718
|
+
self._big_vector = None
|
|
719
|
+
|
|
720
|
+
def solve(self):
|
|
721
|
+
"""
|
|
722
|
+
Solves the regularized linear least squares problem using ADMM in scaled form. Based on [1], Subsection 3.1.1
|
|
723
|
+
"""
|
|
724
|
+
z_new = self.p*[0]
|
|
725
|
+
u_new = self.p*[0]
|
|
726
|
+
|
|
727
|
+
# Iterating
|
|
728
|
+
for i in range(self.maxit):
|
|
729
|
+
self._iteration_pre_processing()
|
|
730
|
+
|
|
731
|
+
# Main update (Least Squares)
|
|
732
|
+
solver = CGLS(self._big_matrix, self._big_vector, self.x_cur, self.inner_max_it)
|
|
733
|
+
x_new, _ = solver.solve()
|
|
734
|
+
|
|
735
|
+
# Regularization update
|
|
736
|
+
for j, penalty in enumerate(self.penalty_terms):
|
|
737
|
+
z_new[j] = penalty[0](penalty[1]@x_new + self.u_cur[j], 1.0/self.rho)
|
|
738
|
+
|
|
739
|
+
res_primal = 0.0
|
|
740
|
+
# Dual update
|
|
741
|
+
for j, penalty in enumerate(self.penalty_terms):
|
|
742
|
+
r_partial = penalty[1]@x_new - z_new[j]
|
|
743
|
+
res_primal += LA.norm(r_partial)**2
|
|
744
|
+
|
|
745
|
+
u_new[j] = self.u_cur[j] + r_partial
|
|
746
|
+
|
|
747
|
+
res_dual = 0.0
|
|
748
|
+
for j, penalty in enumerate(self.penalty_terms):
|
|
749
|
+
res_dual += LA.norm(penalty[1].T@(z_new[j] - self.z_cur[j]))**2
|
|
750
|
+
|
|
751
|
+
# Adaptive approach based on [1], Subsection 3.4.1
|
|
752
|
+
if self.adaptive:
|
|
753
|
+
if res_dual > 1e2*res_primal:
|
|
754
|
+
self.rho *= 0.5 # More regularization
|
|
755
|
+
elif res_primal > 1e2*res_dual:
|
|
756
|
+
self.rho *= 2.0 # More data fidelity
|
|
757
|
+
|
|
758
|
+
self.x_cur, self.z_cur, self.u_cur = x_new, z_new.copy(), u_new
|
|
759
|
+
|
|
760
|
+
return self.x_cur, i
|
|
654
761
|
|
|
762
|
+
def _iteration_pre_processing(self):
|
|
763
|
+
""" Preprocessing
|
|
764
|
+
Every iteration of ADMM requires solving a linear least squares system of the form
|
|
765
|
+
minimize 1/(rho) \|Ax-b\|_2^2 + sum_{i=1}^{p} \|L_i x - (z_i - u_i)\|_2^2
|
|
766
|
+
To solve this, all linear least squares terms are combined into a single big term
|
|
767
|
+
with matrix big_matrix and data big_vector.
|
|
768
|
+
|
|
769
|
+
The matrix only needs to be updated when rho changes, i.e., when the adaptive option is used.
|
|
770
|
+
The data vector needs to be updated every iteration.
|
|
771
|
+
"""
|
|
772
|
+
|
|
773
|
+
self._big_vector = np.hstack([np.sqrt(1/self.rho)*self.b] + [self.z_cur[i] - self.u_cur[i] for i in range(self.p)])
|
|
774
|
+
|
|
775
|
+
# Check whether matrix needs to be updated
|
|
776
|
+
if self._big_matrix is not None and not self.adaptive:
|
|
777
|
+
return
|
|
778
|
+
|
|
779
|
+
# Update big_matrix
|
|
780
|
+
if callable(self.A):
|
|
781
|
+
def matrix_eval(x, flag):
|
|
782
|
+
if flag == 1:
|
|
783
|
+
out1 = np.sqrt(1/self.rho)*self.A(x, 1)
|
|
784
|
+
out2 = [penalty[1]@x for penalty in self.penalty_terms]
|
|
785
|
+
out = np.hstack([out1] + out2)
|
|
786
|
+
elif flag == 2:
|
|
787
|
+
idx_start = len(x)
|
|
788
|
+
idx_end = len(x)
|
|
789
|
+
out1 = np.zeros(self.n)
|
|
790
|
+
for _, t in reversed(self.penalty_terms):
|
|
791
|
+
idx_start -= t.shape[0]
|
|
792
|
+
out1 += t.T@x[idx_start:idx_end]
|
|
793
|
+
idx_end = idx_start
|
|
794
|
+
out2 = np.sqrt(1/self.rho)*self.A(x[:idx_end], 2)
|
|
795
|
+
out = out1 + out2
|
|
796
|
+
return out
|
|
797
|
+
self._big_matrix = matrix_eval
|
|
798
|
+
else:
|
|
799
|
+
self._big_matrix = np.vstack([np.sqrt(1/self.rho)*self.A] + [penalty[1] for penalty in self.penalty_terms])
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
|
|
655
804
|
def ProjectNonnegative(x):
|
|
656
805
|
"""(Euclidean) projection onto the nonnegative orthant.
|
|
657
806
|
|
|
@@ -678,6 +827,22 @@ def ProjectBox(x, lower = None, upper = None):
|
|
|
678
827
|
|
|
679
828
|
return np.minimum(np.maximum(x, lower), upper)
|
|
680
829
|
|
|
830
|
+
def ProjectHalfspace(x, a, b):
|
|
831
|
+
"""(Euclidean) projection onto the halfspace defined {z|<a,z> <= b}.
|
|
832
|
+
|
|
833
|
+
Parameters
|
|
834
|
+
----------
|
|
835
|
+
x : array_like.
|
|
836
|
+
a : array_like.
|
|
837
|
+
b : array_like.
|
|
838
|
+
"""
|
|
839
|
+
|
|
840
|
+
ax_b = np.inner(a,x) - b
|
|
841
|
+
if ax_b <= 0:
|
|
842
|
+
return x
|
|
843
|
+
else:
|
|
844
|
+
return x - (ax_b/np.inner(a,a))*a
|
|
845
|
+
|
|
681
846
|
def ProximalL1(x, gamma):
|
|
682
847
|
"""(Euclidean) proximal operator of the \|x\|_1 norm.
|
|
683
848
|
Also known as the shrinkage or soft thresholding operator.
|
|
@@ -687,4 +852,4 @@ def ProximalL1(x, gamma):
|
|
|
687
852
|
x : array_like.
|
|
688
853
|
gamma : scale parameter.
|
|
689
854
|
"""
|
|
690
|
-
return np.multiply(np.sign(x), np.maximum(np.abs(x)-gamma, 0))
|
|
855
|
+
return np.multiply(np.sign(x), np.maximum(np.abs(x)-gamma, 0))
|
|
File without changes
|
|
File without changes
|