CUQIpy 1.2.0.post0.dev314__py3-none-any.whl → 1.2.0.post0.dev342__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.2.0.post0.dev314
3
+ Version: 1.2.0.post0.dev342
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
- License: Apache License
6
+ License: Apache License
7
7
  Version 2.0, January 2004
8
8
  http://www.apache.org/licenses/
9
9
 
@@ -1,6 +1,6 @@
1
1
  cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
2
2
  cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
3
- cuqi/_version.py,sha256=3is9GWwImF4ErxQUumozNosursu2a0z-JOwSEfvceWY,510
3
+ cuqi/_version.py,sha256=3ljcYFr3ZS-sk3bvctKBFOKTMiiZYLbPdoHP03jbCfk,510
4
4
  cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
5
5
  cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
6
6
  cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
@@ -34,7 +34,9 @@ cuqi/distribution/_posterior.py,sha256=zAfL0GECxekZ2lBt1W6_LN0U_xskMwK4VNce5xAF7
34
34
  cuqi/distribution/_smoothed_laplace.py,sha256=p-1Y23mYA9omwiHGkEuv3T2mwcPAAoNlCr7T8osNkjE,2925
35
35
  cuqi/distribution/_truncated_normal.py,sha256=sZkLYgnkGOyS_3ZxY7iw6L62t-Jh6shzsweRsRepN2k,4240
36
36
  cuqi/distribution/_uniform.py,sha256=KA8yQ6ZS3nQGS4PYJ4hpDg6Eq8EQKQvPsIpYfR8fj2w,1967
37
- cuqi/experimental/__init__.py,sha256=vhZvyMX6rl8Y0haqCzGLPz6PSUKyu75XMQbeDHqTTrw,83
37
+ cuqi/experimental/__init__.py,sha256=iStrmEy4ZMnGpEyd6QNlC6RK83lrS9iRkxQS0u-s8cU,105
38
+ cuqi/experimental/algebra/__init__.py,sha256=3d4Bfx1upcHhEubNn6-Sa3WFpuksPQJif4OptcNDe_s,31
39
+ cuqi/experimental/algebra/_ast.py,sha256=SAlqqQkW_559fyiM66S3dWgfACR7jdSZkomJm1mKix0,8698
38
40
  cuqi/experimental/mcmc/__init__.py,sha256=zSqLZmxOqQ-F94C9-gPv7g89TX1XxlrlNm071Eb167I,4487
39
41
  cuqi/experimental/mcmc/_conjugate.py,sha256=VNPQkGity0mposcqxrx4UIeXm35EvJvZED4p2stffvA,9924
40
42
  cuqi/experimental/mcmc/_conjugate_approx.py,sha256=uEnY2ea9su5ivcNagyRAwpQP2gBY98sXU7N0y5hTADo,3653
@@ -46,14 +48,14 @@ cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=NIoCLKL5x89Bxm-JLDLR_NTunRE
46
48
  cuqi/experimental/mcmc/_laplace_approximation.py,sha256=XcGIa2wl9nCSTtAFurejYYOKkDVAJ22q75xQKsyu2nI,5803
47
49
  cuqi/experimental/mcmc/_mh.py,sha256=MXo0ahXP4KGFkaY4HtvcBE-TMQzsMlTmLKzSvpz7drU,2941
48
50
  cuqi/experimental/mcmc/_pcn.py,sha256=wqJBZLuRFSwxihaI53tumAg6AWVuceLMOmXssTetd1A,3374
49
- cuqi/experimental/mcmc/_rto.py,sha256=j3PD3ZfOuGifbBu51Z7GdMCM47HjH0luhpWFDPXNtxc,10477
51
+ cuqi/experimental/mcmc/_rto.py,sha256=lzfeUuV8jUiWG-80KQ4if6toVcX7bMv-a0chBZq0vZ4,12021
50
52
  cuqi/experimental/mcmc/_sampler.py,sha256=BZHnpB6s-YSddd46wQSds0vNF61RA58Nc9ZU05WngdU,20184
51
53
  cuqi/experimental/mcmc/_utilities.py,sha256=kUzHbhIS3HYZRbneNBK41IogUYX5dS_bJxqEGm7TQBI,525
52
54
  cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
53
55
  cuqi/geometry/_geometry.py,sha256=5ZNrw6LivxEEw0vrk1eCxKIw-8mkAh7930voRVywDbY,47089
54
56
  cuqi/implicitprior/__init__.py,sha256=6z3lvw-tWDyjZSpB3pYzvijSMK9Zlf1IYqOVTtMD2h4,309
55
- cuqi/implicitprior/_regularizedGMRF.py,sha256=IR9tKzNMoz-b0RKu6ahVgMx_lDNB3jZHVWFMQm6QqZk,6259
56
- cuqi/implicitprior/_regularizedGaussian.py,sha256=cQtrgzyJU2pwoK4ORGl1erKLE9VY5NqwZTiqiViDswA,12371
57
+ cuqi/implicitprior/_regularizedGMRF.py,sha256=rr3R2C1aheuu_KD35MureZKfOwY8O1pkVDHvuaFnFFU,6300
58
+ cuqi/implicitprior/_regularizedGaussian.py,sha256=mzaAHq0yz73FZo-OB2iqFMd2i2NNzVv4mjd9-ger8a0,15435
57
59
  cuqi/implicitprior/_regularizedUnboundedUniform.py,sha256=H2fTOSqYTlDiLxQ7Ya6wnpCUIkpO4qKrkTOsOPnBBeU,3483
58
60
  cuqi/implicitprior/_restorator.py,sha256=ixnH8RGcLpqlaIUdR5Dwjx72sO9f3BeotNFRC7Z7qZo,9198
59
61
  cuqi/likelihood/__init__.py,sha256=QXif382iwZ5bT3ZUqmMs_n70JVbbjxbqMrlQYbMn4Zo,1776
@@ -81,14 +83,14 @@ cuqi/sampler/_sampler.py,sha256=TkZ_WAS-5Q43oICa-Elc2gftsRTBd7PEDUMDZ9tTGmU,5712
81
83
  cuqi/samples/__init__.py,sha256=vCs6lVk-pi8RBqa6cIN5wyn6u-K9oEf1Na4k1ZMrYv8,44
82
84
  cuqi/samples/_samples.py,sha256=hUc8OnCF9CTCuDTrGHwwzv3wp8mG_6vsJAFvuQ-x0uA,35832
83
85
  cuqi/solver/__init__.py,sha256=3eoTTgBHe3M6ygrbgUVG3GlqaZVe5lGajNV9rolXZJ8,179
84
- cuqi/solver/_solver.py,sha256=4LdfxLaU-fUHltZw7Sq-Xohyxd_6RvKy03xxtIMW6Zs,29488
86
+ cuqi/solver/_solver.py,sha256=GquU_rj-9yfPQnBVE_gXo4wdF84xw_pLks3bJarzR58,29491
85
87
  cuqi/testproblem/__init__.py,sha256=DWTOcyuNHMbhEuuWlY5CkYkNDSAqhvsKmJXBLivyblU,202
86
88
  cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7sS0Q,52540
87
89
  cuqi/utilities/__init__.py,sha256=H7xpJe2UinjZftKvE2JuXtTi4DqtkR6uIezStAXwfGg,428
88
90
  cuqi/utilities/_get_python_variable_name.py,sha256=QwlBVj2koJRA8s8pWd554p7-ElcI7HUwY32HknaR92E,1827
89
91
  cuqi/utilities/_utilities.py,sha256=Jc4knn80vLoA7kgw9FzXwKVFGaNBOXiA9kgvltZU3Ao,11777
90
- CUQIpy-1.2.0.post0.dev314.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
91
- CUQIpy-1.2.0.post0.dev314.dist-info/METADATA,sha256=HyWTczhv-qzQEzxuVcHGaTAJx2XRXxVjmM5tihB8g9Q,18496
92
- CUQIpy-1.2.0.post0.dev314.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
93
- CUQIpy-1.2.0.post0.dev314.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
94
- CUQIpy-1.2.0.post0.dev314.dist-info/RECORD,,
92
+ CUQIpy-1.2.0.post0.dev342.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
93
+ CUQIpy-1.2.0.post0.dev342.dist-info/METADATA,sha256=gzSVkU5NkQHKRqldZB3uQ9m1ef9_2ZFmAIt1Q0wDeYo,18529
94
+ CUQIpy-1.2.0.post0.dev342.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
95
+ CUQIpy-1.2.0.post0.dev342.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
96
+ CUQIpy-1.2.0.post0.dev342.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.5.0)
2
+ Generator: setuptools (75.6.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
cuqi/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-11-20T10:51:18+0100",
11
+ "date": "2024-11-25T09:18:51+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "42e096c7066472ac4d5701fc5f2de53636246f5d",
15
- "version": "1.2.0.post0.dev314"
14
+ "full-revisionid": "63395d14f6c3f964633b20200ada13b8a213da20",
15
+ "version": "1.2.0.post0.dev342"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -1,2 +1,3 @@
1
1
  """ Experimental module for testing new features and ideas. """
2
2
  from . import mcmc
3
+ from . import algebra
@@ -0,0 +1 @@
1
+ from ._ast import VariableNode
@@ -0,0 +1,325 @@
1
+ """
2
+ CUQIpy specific implementation of an abstract syntax tree (AST) for algebra on variables.
3
+
4
+ The AST is used to record the operations applied to variables allowing a delayed evaluation
5
+ of said operations when needed by traversing the tree with the __call__ method.
6
+
7
+ For example, the following code
8
+
9
+ x = VariableNode('x')
10
+ y = VariableNode('y')
11
+ z = 2*x + 3*y
12
+
13
+ will create the following AST:
14
+
15
+ z = AddNode(
16
+ MultiplyNode(
17
+ ValueNode(2),
18
+ VariableNode('x')
19
+ ),
20
+ MultiplyNode(
21
+ ValueNode(3),
22
+ VariableNode('y')
23
+ )
24
+ )
25
+
26
+ which can be evaluated by calling the __call__ method:
27
+
28
+ z(x=1, y=2) # returns 8
29
+
30
+ """
31
+
32
+ from abc import ABC, abstractmethod
33
+
34
+ convert_to_node = lambda x: x if isinstance(x, Node) else ValueNode(x)
35
+ """ Converts any non-Node object to a ValueNode object. """
36
+
37
+ # ====== Base classes for the nodes ======
38
+
39
+
40
+ class Node(ABC):
41
+ """Base class for all nodes in the abstract syntax tree.
42
+
43
+ Responsible for building the AST by creating nodes that represent the operations applied to variables.
44
+
45
+ Each subclass must implement the __call__ method that will evaluate the node given the input parameters.
46
+
47
+ """
48
+
49
+ @abstractmethod
50
+ def __call__(self, **kwargs):
51
+ """Evaluate node at a given parameter value. This will traverse the sub-tree originated at this node and evaluate it given the recorded operations."""
52
+ pass
53
+
54
+ @abstractmethod
55
+ def __repr__(self):
56
+ """String representation of the node. Used for printing the AST."""
57
+ pass
58
+
59
+ def __add__(self, other):
60
+ return AddNode(self, convert_to_node(other))
61
+
62
+ def __radd__(self, other):
63
+ return AddNode(convert_to_node(other), self)
64
+
65
+ def __sub__(self, other):
66
+ return SubtractNode(self, convert_to_node(other))
67
+
68
+ def __rsub__(self, other):
69
+ return SubtractNode(convert_to_node(other), self)
70
+
71
+ def __mul__(self, other):
72
+ return MultiplyNode(self, convert_to_node(other))
73
+
74
+ def __rmul__(self, other):
75
+ return MultiplyNode(convert_to_node(other), self)
76
+
77
+ def __truediv__(self, other):
78
+ return DivideNode(self, convert_to_node(other))
79
+
80
+ def __rtruediv__(self, other):
81
+ return DivideNode(convert_to_node(other), self)
82
+
83
+ def __pow__(self, other):
84
+ return PowerNode(self, convert_to_node(other))
85
+
86
+ def __rpow__(self, other):
87
+ return PowerNode(convert_to_node(other), self)
88
+
89
+ def __neg__(self):
90
+ return NegateNode(self)
91
+
92
+ def __abs__(self):
93
+ return AbsNode(self)
94
+
95
+ def __getitem__(self, i):
96
+ return GetItemNode(self, convert_to_node(i))
97
+
98
+ def __matmul__(self, other):
99
+ return MatMulNode(self, convert_to_node(other))
100
+
101
+ def __rmatmul__(self, other):
102
+ return MatMulNode(convert_to_node(other), self)
103
+
104
+
105
+ class UnaryNode(Node, ABC):
106
+ """Base class for all unary nodes in the abstract syntax tree.
107
+
108
+ Parameters
109
+ ----------
110
+ child : Node
111
+ The direct child node on which the unary operation is performed.
112
+
113
+ """
114
+
115
+ def __init__(self, child: Node):
116
+ self.child = child
117
+
118
+
119
+ class BinaryNode(Node, ABC):
120
+ """Base class for all binary nodes in the abstract syntax tree.
121
+
122
+ The op_symbol attribute is used for printing the operation in the __repr__ method.
123
+
124
+ Parameters
125
+ ----------
126
+ left : Node
127
+ Left child node to the binary operation.
128
+
129
+ right : Node
130
+ Right child node to the binary operation.
131
+
132
+ """
133
+
134
+ @property
135
+ @abstractmethod
136
+ def op_symbol(self):
137
+ """Symbol used to represent the operation in the __repr__ method."""
138
+ pass
139
+
140
+ def __init__(self, left: Node, right: Node):
141
+ self.left = left
142
+ self.right = right
143
+
144
+ def __repr__(self):
145
+ return f"{self.left} {self.op_symbol} {self.right}"
146
+
147
+
148
+ class BinaryNodeWithParenthesis(BinaryNode, ABC):
149
+ """Base class for all binary nodes in the abstract syntax tree that should be printed with parenthesis."""
150
+
151
+ def __repr__(self):
152
+ left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left)
153
+ right = (
154
+ f"({self.right})" if isinstance(self.right, BinaryNode) else str(self.right)
155
+ )
156
+ return f"{left} {self.op_symbol} {right}"
157
+
158
+ class BinaryNodeWithParenthesisNoSpace(BinaryNode, ABC):
159
+ """Base class for all binary nodes in the abstract syntax tree that should be printed with parenthesis but no space."""
160
+
161
+ def __repr__(self):
162
+ left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left)
163
+ right = (
164
+ f"({self.right})" if isinstance(self.right, BinaryNode) else str(self.right)
165
+ )
166
+ return f"{left}{self.op_symbol}{right}"
167
+
168
+
169
+ # ====== Specific implementations of the "leaf" nodes ======
170
+
171
+
172
+ class VariableNode(Node):
173
+ """Node that represents a generic variable, e.g. "x" or "y".
174
+
175
+ Parameters
176
+ ----------
177
+ name : str
178
+ Name of the variable. Used for printing and to retrieve the given input value
179
+ of the variable in the kwargs dictionary when evaluating the tree.
180
+
181
+ """
182
+
183
+ def __init__(self, name):
184
+ self.name = name
185
+
186
+ def __call__(self, **kwargs):
187
+ """Retrieves the value of the variable from the passed kwargs. If no value is found, it raises a KeyError."""
188
+ if not self.name in kwargs:
189
+ raise KeyError(
190
+ f"Variable '{self.name}' not found in the given input parameters. Unable to evaluate the expression."
191
+ )
192
+ return kwargs[self.name]
193
+
194
+ def __repr__(self):
195
+ return self.name
196
+
197
+
198
+ class ValueNode(Node):
199
+ """Node that represents a constant value. The value can be any python object that is not a Node.
200
+
201
+ Parameters
202
+ ----------
203
+ value : object
204
+ The python object that represents the value of the node.
205
+
206
+ """
207
+
208
+ def __init__(self, value):
209
+ self.value = value
210
+
211
+ def __call__(self, **kwargs):
212
+ """Returns the value of the node."""
213
+ return self.value
214
+
215
+ def __repr__(self):
216
+ return str(self.value)
217
+
218
+
219
+ # ====== Specific implementations of the "internal" nodes ======
220
+
221
+
222
+ class AddNode(BinaryNode):
223
+ """Node that represents the addition operation."""
224
+
225
+ @property
226
+ def op_symbol(self):
227
+ return "+"
228
+
229
+ def __call__(self, **kwargs):
230
+ return self.left(**kwargs) + self.right(**kwargs)
231
+
232
+
233
+ class SubtractNode(BinaryNode):
234
+ """Node that represents the subtraction operation."""
235
+
236
+ @property
237
+ def op_symbol(self):
238
+ return "-"
239
+
240
+ def __call__(self, **kwargs):
241
+ return self.left(**kwargs) - self.right(**kwargs)
242
+
243
+
244
+ class MultiplyNode(BinaryNodeWithParenthesis):
245
+ """Node that represents the multiplication operation."""
246
+
247
+ @property
248
+ def op_symbol(self):
249
+ return "*"
250
+
251
+ def __call__(self, **kwargs):
252
+ return self.left(**kwargs) * self.right(**kwargs)
253
+
254
+
255
+ class DivideNode(BinaryNodeWithParenthesis):
256
+ """Node that represents the division operation."""
257
+
258
+ @property
259
+ def op_symbol(self):
260
+ return "/"
261
+
262
+ def __call__(self, **kwargs):
263
+ return self.left(**kwargs) / self.right(**kwargs)
264
+
265
+
266
+ class PowerNode(BinaryNodeWithParenthesisNoSpace):
267
+ """Node that represents the power operation."""
268
+
269
+ @property
270
+ def op_symbol(self):
271
+ return "^"
272
+
273
+ def __call__(self, **kwargs):
274
+ return self.left(**kwargs) ** self.right(**kwargs)
275
+
276
+
277
+ class GetItemNode(BinaryNode):
278
+ """Node that represents the get item operation. Here the left node is the object and the right node is the index."""
279
+
280
+ def __call__(self, **kwargs):
281
+ return self.left(**kwargs)[self.right(**kwargs)]
282
+
283
+ def __repr__(self):
284
+ left = f"({self.left})" if isinstance(self.left, BinaryNode) else str(self.left)
285
+ return f"{left}[{self.right}]"
286
+
287
+ @property
288
+ def op_symbol(self):
289
+ pass
290
+
291
+
292
+ class NegateNode(UnaryNode):
293
+ """Node that represents the arithmetic negation operation."""
294
+
295
+ def __call__(self, **kwargs):
296
+ return -self.child(**kwargs)
297
+
298
+ def __repr__(self):
299
+ child = (
300
+ f"({self.child})"
301
+ if isinstance(self.child, (BinaryNode, UnaryNode))
302
+ else str(self.child)
303
+ )
304
+ return f"-{child}"
305
+
306
+
307
+ class AbsNode(UnaryNode):
308
+ """Node that represents the absolute value operation."""
309
+
310
+ def __call__(self, **kwargs):
311
+ return abs(self.child(**kwargs))
312
+
313
+ def __repr__(self):
314
+ return f"abs({self.child})"
315
+
316
+
317
+ class MatMulNode(BinaryNodeWithParenthesis):
318
+ """Node that represents the matrix multiplication operation."""
319
+
320
+ @property
321
+ def op_symbol(self):
322
+ return "@"
323
+
324
+ def __call__(self, **kwargs):
325
+ return self.left(**kwargs) @ self.right(**kwargs)
@@ -3,7 +3,7 @@ from scipy.linalg.interpolative import estimate_spectral_norm
3
3
  from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
4
4
  import numpy as np
5
5
  import cuqi
6
- from cuqi.solver import CGLS, FISTA
6
+ from cuqi.solver import CGLS, FISTA, ADMM
7
7
  from cuqi.experimental.mcmc import Sampler
8
8
 
9
9
 
@@ -161,6 +161,13 @@ class RegularizedLinearRTO(LinearRTO):
161
161
  Regularized Linear RTO (Randomize-Then-Optimize) sampler.
162
162
 
163
163
  Samples posterior related to the inverse problem with Gaussian likelihood and implicit Gaussian prior, and where the forward model is Linear.
164
+ The sampler works by repeatedly solving regularized linear least squares problems for perturbed data.
165
+ The solver for these optimization problems is chosen based on how the regularized is provided in the implicit Gaussian prior.
166
+ Currently we use the following solvers:
167
+ FISTA: [1] Beck, Amir, and Marc Teboulle. "A fast iterative shrinkage-thresholding algorithm for linear inverse problems." SIAM journal on imaging sciences 2.1 (2009): 183-202.
168
+ Used when prior.proximal is callable.
169
+ ADMM: [2] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers."Foundations and Trends® in Machine learning, 2011.
170
+ Used when prior.proximal is a list of penalty terms.
164
171
 
165
172
  Parameters
166
173
  ------------
@@ -171,12 +178,19 @@ class RegularizedLinearRTO(LinearRTO):
171
178
  Initial point for the sampler. *Optional*.
172
179
 
173
180
  maxit : int
174
- Maximum number of iterations of the inner FISTA solver. *Optional*.
181
+ Maximum number of iterations of the inner FISTA/ADMM solver. *Optional*.
182
+
183
+ inner_max_it : int
184
+ Maximum number of iterations of the CGLS solver used within the ADMM solver. *Optional*.
175
185
 
176
186
  stepsize : string or float
177
187
  If stepsize is a string and equals either "automatic", then the stepsize is automatically estimated based on the spectral norm.
178
188
  If stepsize is a float, then this stepsize is used.
179
189
 
190
+ penalty_parameter : int
191
+ Penalty parameter of the inner ADMM solver. *Optional*.
192
+ See [2] or `cuqi.solver.ADMM`
193
+
180
194
  abstol : float
181
195
  Absolute tolerance of the inner FISTA solver. *Optional*.
182
196
 
@@ -190,7 +204,7 @@ class RegularizedLinearRTO(LinearRTO):
190
204
  An example is shown in demos/demo31_callback.py.
191
205
 
192
206
  """
193
- def __init__(self, target=None, initial_point=None, maxit=100, stepsize="automatic", abstol=1e-10, adaptive=True, **kwargs):
207
+ def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, **kwargs):
194
208
 
195
209
  super().__init__(target=target, initial_point=initial_point, **kwargs)
196
210
 
@@ -199,10 +213,13 @@ class RegularizedLinearRTO(LinearRTO):
199
213
  self.abstol = abstol
200
214
  self.adaptive = adaptive
201
215
  self.maxit = maxit
216
+ self.inner_max_it = inner_max_it
217
+ self.penalty_parameter = penalty_parameter
202
218
 
203
219
  def _initialize(self):
204
220
  super()._initialize()
205
- self._stepsize = self._choose_stepsize()
221
+ if self._inner_solver == "FISTA":
222
+ self._stepsize = self._choose_stepsize()
206
223
 
207
224
  @property
208
225
  def proximal(self):
@@ -212,8 +229,7 @@ class RegularizedLinearRTO(LinearRTO):
212
229
  super().validate_target()
213
230
  if not isinstance(self.target.prior, (cuqi.implicitprior.RegularizedGaussian, cuqi.implicitprior.RegularizedGMRF)):
214
231
  raise TypeError("Prior needs to be RegularizedGaussian or RegularizedGMRF")
215
- if not callable(self.proximal):
216
- raise TypeError("Proximal needs to be callable")
232
+ self._inner_solver = "FISTA" if callable(self.proximal) else "ADMM"
217
233
 
218
234
  def _choose_stepsize(self):
219
235
  if isinstance(self.stepsize, str):
@@ -237,8 +253,16 @@ class RegularizedLinearRTO(LinearRTO):
237
253
 
238
254
  def step(self):
239
255
  y = self.b_tild + np.random.randn(len(self.b_tild))
240
- sim = FISTA(self.M, y, self.proximal,
241
- self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
256
+
257
+ if self._inner_solver == "FISTA":
258
+ sim = FISTA(self.M, y, self.proximal,
259
+ self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
260
+ elif self._inner_solver == "ADMM":
261
+ sim = ADMM(self.M, y, self.proximal,
262
+ self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
263
+ else:
264
+ raise ValueError("Choice of solver not supported.")
265
+
242
266
  self.current_point, _ = sim.solve()
243
267
  acc = 1
244
268
  return acc
@@ -63,6 +63,7 @@ class RegularizedGMRF(RegularizedGaussian):
63
63
 
64
64
  # Underlying explicit GMRF
65
65
  self._gaussian = GMRF(mean, prec, bc_type=bc_type, order=order, **kwargs)
66
+ kwargs.pop("geometry", None)
66
67
 
67
68
  # Init from abstract distribution class
68
69
  super(Distribution, self).__init__(**kwargs)
@@ -1,6 +1,8 @@
1
1
  from cuqi.utilities import get_non_default_args
2
2
  from cuqi.distribution import Distribution, Gaussian
3
3
  from cuqi.solver import ProjectNonnegative, ProjectBox, ProximalL1
4
+ from cuqi.geometry import Continuous1D, Continuous2D, Image2D
5
+ from cuqi.operator import FirstOrderFiniteDifference
4
6
 
5
7
  import numpy as np
6
8
 
@@ -39,17 +41,22 @@ class RegularizedGaussian(Distribution):
39
41
  sqrtprec
40
42
  See :class:`~cuqi.distribution.Gaussian` for details.
41
43
 
42
- proximal : callable f(x, scale) or None
43
- Euclidean proximal operator f of the regularization function g, that is, a solver for the optimization problem
44
- min_z 0.5||x-z||_2^2+scale*g(x).
45
-
44
+ proximal : callable f(x, scale), list of tuples (callable proximal operator of f_i, linear operator L_i) or None
45
+ If callable:
46
+ Euclidean proximal operator f of the regularization function g, that is, a solver for the optimization problem
47
+ min_z 0.5||x-z||_2^2+scale*g(x).
48
+ If list of tuples (callable proximal operator of f_i, linear operator L_i):
49
+ Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
50
+ The corresponding regularization takes the form
51
+ sum_i f_i(L_i x),
52
+ where the sum ranges from 1 to an arbitrary n.
46
53
 
47
54
  projector : callable f(x) or None
48
55
  Euclidean projection onto the constraint C, that is, a solver for the optimization problem
49
56
  min_(z in C) 0.5||x-z||_2^2.
50
57
 
51
58
  constraint : string or None
52
- Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
59
+ Preset constraints that generate the corresponding proximal parameter. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
53
60
  For "box", the following additional parameters can be passed:
54
61
  lower_bound : array_like or None
55
62
  Lower bound of box, defaults to zero
@@ -57,10 +64,10 @@ class RegularizedGaussian(Distribution):
57
64
  Upper bound of box, defaults to one
58
65
 
59
66
  regularization : string or None
60
- Preset regularization. Can be set to "l1". Required for use in Gibbs in future update.
61
- For "l1", the following additional parameters can be passed:
67
+ Preset regularization that generate the corresponding proximal parameter. Can be set to "l1" or 'tv'. Required for use in Gibbs in future update.
68
+ For "l1" or "tv", the following additional parameters can be passed:
62
69
  strength : scalar
63
- Regularization parameter, i.e., strength*||x||_1 , defaults to one
70
+ Regularization parameter, i.e., strength*||Lx||_1, defaults to one
64
71
 
65
72
  """
66
73
 
@@ -75,6 +82,7 @@ class RegularizedGaussian(Distribution):
75
82
 
76
83
  # We init the underlying Gaussian first for geometry and dimensionality handling
77
84
  self._gaussian = Gaussian(mean=mean, cov=cov, prec=prec, sqrtcov=sqrtcov, sqrtprec=sqrtprec, **kwargs)
85
+ kwargs.pop("geometry", None)
78
86
 
79
87
  # Init from abstract distribution class
80
88
  super().__init__(**kwargs)
@@ -88,12 +96,6 @@ class RegularizedGaussian(Distribution):
88
96
  if (proximal is not None) + (projector is not None) + (constraint is not None) + (regularization is not None) != 1:
89
97
  raise ValueError("Precisely one of proximal, projector, constraint or regularization needs to be provided.")
90
98
 
91
- if proximal is not None:
92
- if not callable(proximal):
93
- raise ValueError("Proximal needs to be callable.")
94
- if len(get_non_default_args(proximal)) != 2:
95
- raise ValueError("Proximal should take 2 arguments.")
96
-
97
99
  if projector is not None:
98
100
  if not callable(projector):
99
101
  raise ValueError("Projector needs to be callable.")
@@ -104,7 +106,8 @@ class RegularizedGaussian(Distribution):
104
106
  self._preset = None
105
107
 
106
108
  if proximal is not None:
107
- self._proximal = proximal
109
+ # No need to generate the proximal and associated information
110
+ self.proximal = proximal
108
111
  elif projector is not None:
109
112
  self._proximal = lambda z, gamma: projector(z)
110
113
  elif (isinstance(constraint, str) and constraint.lower() == "nonnegativity"):
@@ -113,15 +116,48 @@ class RegularizedGaussian(Distribution):
113
116
  elif (isinstance(constraint, str) and constraint.lower() == "box"):
114
117
  lower = optional_regularization_parameters["lower_bound"]
115
118
  upper = optional_regularization_parameters["upper_bound"]
116
- self._proximal = lambda z, gamma: ProjectBox(z, lower, upper)
119
+ self._proximal = lambda z, _: ProjectBox(z, lower, upper)
117
120
  self._preset = "box" # Not supported in Gibbs
118
121
  elif (isinstance(regularization, str) and regularization.lower() in ["l1"]):
119
- strength = optional_regularization_parameters["strength"]
120
- self._proximal = lambda z, gamma: ProximalL1(z, gamma*strength)
122
+ self._strength = optional_regularization_parameters["strength"]
123
+ self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
121
124
  self._preset = "l1"
125
+ elif (isinstance(regularization, str) and regularization.lower() in ["tv"]):
126
+ self._strength = optional_regularization_parameters["strength"]
127
+ if isinstance(self.geometry, (Continuous1D, Continuous2D, Image2D)):
128
+ self._transformation = FirstOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
129
+ else:
130
+ raise ValueError("Geometry not supported for total variation")
131
+
132
+ self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
133
+ self._regularization_oper = self._transformation
134
+
135
+ self._proximal = [(self._regularization_prox, self._regularization_oper)]
136
+ self._preset = "tv"
122
137
  else:
123
138
  raise ValueError("Regularization not supported")
124
139
 
140
+
141
+ @property
142
+ def transformation(self):
143
+ return self._transformation
144
+
145
+ @property
146
+ def strength(self):
147
+ return self._strength
148
+
149
+ @strength.setter
150
+ def strength(self, value):
151
+ if self._preset not in self.regularization_options():
152
+ raise TypeError("Strength is only used when the regularization is set to l1 or TV.")
153
+
154
+ self._strength = value
155
+ if self._preset == "tv":
156
+ self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
157
+ self._proximal = [(self._regularization_prox, self._regularization_oper)]
158
+ elif self._preset == "l1":
159
+ self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
160
+
125
161
  # This is a getter only attribute for the underlying Gaussian
126
162
  # It also ensures that the name of the underlying Gaussian
127
163
  # matches the name of the implicit regularized Gaussian
@@ -135,6 +171,25 @@ class RegularizedGaussian(Distribution):
135
171
  def proximal(self):
136
172
  return self._proximal
137
173
 
174
+ @proximal.setter
175
+ def proximal(self, value):
176
+ if callable(value):
177
+ if len(get_non_default_args(value)) != 2:
178
+ raise ValueError("Proximal should take 2 arguments.")
179
+ elif isinstance(value, list):
180
+ for (prox, op) in value:
181
+ if len(get_non_default_args(prox)) != 2:
182
+ raise ValueError("Proximal should take 2 arguments.")
183
+ if op.shape[1] != self.geometry.par_dim:
184
+ raise ValueError("Incorrect shape of linear operator in proximal list.")
185
+ else:
186
+ raise ValueError("Proximal needs to be callable or a list. See documentation.")
187
+
188
+ self._proximal = value
189
+
190
+ # For all the presets, self._proximal is set directly,
191
+ self._preset = None
192
+
138
193
  @property
139
194
  def preset(self):
140
195
  return self._preset
@@ -154,7 +209,7 @@ class RegularizedGaussian(Distribution):
154
209
 
155
210
  @staticmethod
156
211
  def regularization_options():
157
- return ["l1"]
212
+ return ["l1", "tv"]
158
213
 
159
214
 
160
215
  # --- Defer behavior of the underlying Gaussian --- #
@@ -206,16 +261,18 @@ class RegularizedGaussian(Distribution):
206
261
  def sqrtcov(self, value):
207
262
  self.gaussian.sqrtcov = value
208
263
 
209
- def get_conditioning_variables(self):
210
- return self.gaussian.get_conditioning_variables()
211
-
212
264
  def get_mutable_variables(self):
213
- return self.gaussian.get_mutable_variables()
265
+ mutable_vars = self.gaussian.get_mutable_variables().copy()
266
+ if self.preset in self.regularization_options():
267
+ mutable_vars += ["strength"]
268
+ return mutable_vars
214
269
 
215
270
  # Overwrite the condition method such that the underlying Gaussian is conditioned in general, except when conditioning on self.name
216
271
  # which means we convert Distribution to Likelihood or EvaluatedDensity.
217
272
  def _condition(self, *args, **kwargs):
218
-
273
+ if self.preset in self.regularization_options():
274
+ return super()._condition(*args, **kwargs)
275
+
219
276
  # Handle positional arguments (similar code as in Distribution._condition)
220
277
  cond_vars = self.get_conditioning_variables()
221
278
  kwargs = self._parse_args_add_to_kwargs(cond_vars, *args, **kwargs)
@@ -275,7 +332,7 @@ class ConstrainedGaussian(RegularizedGaussian):
275
332
  min_(z in C) 0.5||x-z||_2^2.
276
333
 
277
334
  constraint : string or None
278
- Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
335
+ Preset constraints that generate the corresponding proximal parameter. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
279
336
  For "box", the following additional parameters can be passed:
280
337
  lower_bound : array_like or None
281
338
  Lower bound of box, defaults to zero
cuqi/solver/_solver.py CHANGED
@@ -669,7 +669,7 @@ class ADMM(object):
669
669
  - flag=2 indicates multiplication of the transpose of A with vector x, that is A.T @ x.
670
670
  b : ndarray.
671
671
  penalty_terms : List of tuples (callable proximal operator of f_i, linear operator L_i)
672
- Each callable proximal operator f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
672
+ Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
673
673
  x0 : ndarray. Initial guess.
674
674
  penalty_parameter : Trade-off between linear least squares and regularization term in the solver iterates. Denoted as "rho" in [1].
675
675
  maxit : The maximum number of iterations.