deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,607 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from deepinv.optim.utils import gradient_descent
5
+
6
+
7
class DataFidelity(nn.Module):
    r"""
    Data fidelity term :math:`\datafid{x}{y}=\distance{Ax}{y}`.

    This is the base class for the data fidelity term :math:`\datafid{x}{y} = \distance{A(x)}{y}` where :math:`A` is a
    linear or nonlinear operator, :math:`x\in\xset` is a variable, :math:`y\in\yset` is the observation and
    :math:`\distancename` is a distance function.

    ::

        # define a loss function
        data_fidelity = L2()

        # Create a measurement operator
        A = torch.Tensor([[2, 0], [0, 0.5]])
        A_forward = lambda v: A @ v
        A_adjoint = lambda v: A.transpose(0, 1) @ v

        # Define the physics model associated to this operator
        physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

        # Define two points
        x = torch.Tensor([[1], [4]]).unsqueeze(0)
        y = torch.Tensor([[1], [1]]).unsqueeze(0)

        # Compute the loss :math:`f(x) = \datafid{x}{y}`
        f_x = data_fidelity(x, y, physics)  # print(f_x) gives tensor([1.0000])

        # Compute the gradient of :math:`f`
        grad = data_fidelity.grad(x, y, physics)  # print(grad) gives tensor([[[2.0000], [0.5000]]])

        # Compute the proximity operator of :math:`f`
        prox = data_fidelity.prox(x, y, physics, gamma=1.0)  # print(prox) gives tensor([[[0.6000], [3.6000]]])


    .. warning::
        All variables have a batch dimension as first dimension.

    :param callable d: data fidelity distance function :math:`\distance{u}{y}`. Outputs a tensor of size `B`, the size of the batch. Default: None.
    """

    def __init__(self, d=None):
        super().__init__()
        self._d = d

    def d(self, u, y, *args, **kwargs):
        r"""
        Computes the data fidelity distance :math:`\distance{u}{y}`.

        :param torch.tensor u: Variable :math:`u` at which the distance function is computed.
        :param torch.tensor y: Data :math:`y`.
        :return: (torch.tensor) data fidelity :math:`\distance{u}{y}`.
        :raises NotImplementedError: if no distance function ``d`` was given to the constructor
            and the method is not overridden by a subclass.
        """
        if self._d is None:
            raise NotImplementedError(
                "No distance function was provided: pass `d` to the constructor "
                "or override `d` in a subclass."
            )
        return self._d(u, y, *args, **kwargs)

    def grad_d(self, u, y, *args, **kwargs):
        r"""
        Computes the gradient :math:`\nabla_u\distance{u}{y}`, computed in :math:`u`. Note that this is the gradient of
        :math:`\distancename` and not :math:`\datafidname`. By default, the gradient is computed using automatic differentiation.

        :param torch.tensor u: Variable :math:`u` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :return: (torch.tensor) gradient of :math:`d` in :math:`u`, i.e. :math:`\nabla_u\distance{u}{y}`.
        """
        with torch.enable_grad():
            # Do not flip ``requires_grad`` on the caller's tensor: differentiate w.r.t. a
            # detached leaf when ``u`` is not already tracked. If ``u`` already requires
            # grad it is used as is, so ``create_graph=True`` keeps its upstream graph.
            if not u.requires_grad:
                u = u.detach().requires_grad_()
            # ``d`` returns one value per batch element. The batch elements are independent,
            # so the gradient of the sum gives every per-element gradient at once, and it
            # satisfies autograd's requirement of a scalar output when B > 1.
            grad = torch.autograd.grad(
                self.d(u, y, *args, **kwargs).sum(), u, create_graph=True
            )[0]
        return grad

    def prox_d(
        self,
        u,
        y,
        *args,
        gamma=1.0,
        stepsize_inter=1.0,
        max_iter_inter=50,
        tol_inter=1e-3,
        **kwargs,
    ):
        r"""
        Computes the proximity operator :math:`\operatorname{prox}_{\gamma\distance{\cdot}{y}}(u)`, computed in :math:`u`. Note
        that this is the proximity operator of :math:`\distancename` and not :math:`\datafidname`. By default, the proximity operator is computed using internal gradient descent.

        :param torch.tensor u: Variable :math:`u` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :param float gamma: stepsize of the proximity operator.
        :param float stepsize_inter: stepsize used for internal gradient descent
        :param int max_iter_inter: maximal number of iterations for internal gradient descent.
        :param float tol_inter: internal gradient descent has converged when the L2 distance between two consecutive iterates is smaller than tol_inter.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma\distance{\cdot}{y}}(u)`.
        """
        # gradient of z -> gamma * d(z, y) + 1/2 ||z - u||^2
        grad = lambda z: gamma * self.grad_d(z, y, *args, **kwargs) + (z - u)
        return gradient_descent(
            grad, u, step_size=stepsize_inter, max_iter=max_iter_inter, tol=tol_inter
        )

    def forward(self, x, y, physics, *args, **kwargs):
        r"""
        Computes the data fidelity term :math:`\datafid{x}{y} = \distance{Ax}{y}`.

        :param torch.tensor x: Variable :math:`x` at which the data fidelity is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :return: (torch.tensor) data fidelity :math:`\datafid{x}{y}`.
        """
        return self.d(physics.A(x), y, *args, **kwargs)

    def grad(self, x, y, physics, *args, **kwargs):
        r"""
        Calculates the gradient of the data fidelity term :math:`\datafidname` at :math:`x`.

        Uses the chain rule :math:`\nabla_x\datafid{x}{y} = A^{\top}\nabla_u\distance{Ax}{y}`.

        :param torch.tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :return: (torch.tensor) gradient :math:`\nabla_x\datafid{x}{y}`, computed in :math:`x`.
        """
        return physics.A_adjoint(self.grad_d(physics.A(x), y, *args, **kwargs))

    def prox(
        self,
        x,
        y,
        physics,
        *args,
        gamma=1.0,
        stepsize_inter=1.0,
        max_iter_inter=50,
        tol_inter=1e-3,
        **kwargs,
    ):
        r"""
        Calculates the proximity operator of :math:`\datafidname` at :math:`x`.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize of the proximity operator.
        :param float stepsize_inter: stepsize used for internal gradient descent
        :param int max_iter_inter: maximal number of iterations for internal gradient descent.
        :param float tol_inter: internal gradient descent has converged when the L2 distance between two consecutive iterates is smaller than tol_inter.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`, computed in :math:`x`.
        """
        # gradient of z -> gamma * f(z) + 1/2 ||z - x||^2
        grad = lambda z: gamma * self.grad(z, y, physics, *args, **kwargs) + (z - x)
        return gradient_descent(
            grad, x, step_size=stepsize_inter, max_iter=max_iter_inter, tol=tol_inter
        )

    def prox_conjugate(self, x, y, physics, *args, gamma=1.0, lamb=1.0, **kwargs):
        r"""
        Calculates the proximity operator of the convex conjugate :math:`(\lambda \datafidname)^*` at :math:`x`,
        using the Moreau formula.

        .. warning::

            This function is only valid for convex :math:`\datafidname`.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize of the proximity operator.
        :param float lamb: :math:`\lambda` parameter in front of :math:`\datafidname`.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma (\lambda \datafidname)^*}(x)`,
            computed in :math:`x`.
        """
        # Moreau: prox_{gamma g^*}(x) = x - gamma * prox_{g / gamma}(x / gamma), g = lamb * f
        return x - gamma * self.prox(
            x / gamma, y, physics, *args, gamma=lamb / gamma, **kwargs
        )

    def prox_d_conjugate(self, u, y, *args, gamma=1.0, lamb=1.0, **kwargs):
        r"""
        Calculates the proximity operator of the convex conjugate :math:`(\lambda \distancename)^*` at :math:`u`,
        using the Moreau formula.

        .. warning::

            This function is only valid for convex :math:`\distancename`.

        :param torch.tensor u: Variable :math:`u` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param float gamma: stepsize of the proximity operator.
        :param float lamb: :math:`\lambda` parameter in front of :math:`\distancename`.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma (\lambda \distancename)^*}(u)`,
            computed in :math:`u`.
        """
        # Moreau: prox_{gamma g^*}(u) = u - gamma * prox_{g / gamma}(u / gamma), g = lamb * d
        return u - gamma * self.prox_d(
            u / gamma, y, *args, gamma=lamb / gamma, **kwargs
        )
197
+
198
+
199
class L2(DataFidelity):
    r"""
    Implementation of :math:`\distancename` as the normalized :math:`\ell_2` norm

    .. math::

        f(x) = \frac{1}{2\sigma^2}\|Ax-y\|^2

    It can be used to define a log-likelihood function associated with additive Gaussian noise
    by setting an appropriate noise level :math:`\sigma`.

    :param float sigma: Standard deviation of the noise to be used as a normalisation factor.


    ::

        # define a loss function
        data_fidelity = L2()

        # create a measurement operator
        A = torch.Tensor([[2, 0], [0, 0.5]])
        A_forward = lambda v: A @ v
        A_adjoint = lambda v: A.transpose(0, 1) @ v

        # Define the physics model associated to this operator
        physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

        # Define two points (the first dimension is the batch dimension)
        x = torch.Tensor([[1], [4]]).unsqueeze(0)
        y = torch.Tensor([[1], [1]]).unsqueeze(0)

        # Compute the loss f(x)
        f = data_fidelity(x, y, physics)  # print(f) gives tensor([1.0000])

        # Compute the gradient of f
        grad = data_fidelity.grad(x, y, physics)  # print(grad) gives tensor([[[2.0000], [0.5000]]])

        # Compute the proximity operator of f
        prox = data_fidelity.prox(x, y, physics, gamma=1.0)  # print(prox) gives tensor([[[0.6000], [3.6000]]])
    """

    def __init__(self, sigma=1.0):
        super().__init__()

        # 1 / sigma^2, shared by d, grad_d, prox_d and prox
        self.norm = 1 / (sigma**2)

    def d(self, u, y):
        r"""
        Computes the data fidelity distance :math:`\distance{u}{y}`, i.e.

        .. math::

            \distance{u}{y} = \frac{1}{2\sigma^2}\|u-y\|^2


        :param torch.tensor u: Variable :math:`u` at which the data fidelity is computed.
        :param torch.tensor y: Data :math:`y`.
        :return: (torch.tensor) data fidelity :math:`\distance{u}{y}` of size `B` with `B` the size of the batch.
        """
        diff = u - y
        # reshape (rather than view) also accepts non-contiguous tensors
        sq_norm = torch.norm(diff.reshape(diff.shape[0], -1), p=2, dim=-1) ** 2
        # the 1/sigma^2 factor keeps d consistent with grad_d, prox_d and prox
        return 0.5 * self.norm * sq_norm

    def grad_d(self, u, y):
        r"""
        Computes the gradient of :math:`\distancename`, that is :math:`\nabla_{u}\distance{u}{y}`, i.e.

        .. math::

            \nabla_{u}\distance{u}{y} = \frac{1}{\sigma^2}(u-y)


        :param torch.tensor u: Variable :math:`u` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y`.
        :return: (torch.tensor) gradient of the distance function :math:`\nabla_{u}\distance{u}{y}`.
        """
        return self.norm * (u - y)

    def prox_d(self, x, y, gamma=1.0):
        r"""
        Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2\sigma^2}\|x-y\|^2`.

        Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.

        .. math::

           \operatorname{prox}_{\gamma \distancename} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|u-y\|_2^2+\frac{1}{2}\|u-x\|_2^2


        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param float gamma: thresholding parameter.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
        """
        gamma_ = self.norm * gamma
        # closed form of the quadratic minimisation above
        return (x + gamma_ * y) / (1 + gamma_)

    def prox(self, x, y, physics, gamma=1.0):
        r"""
        Proximal operator of :math:`\gamma \datafid{Ax}{y} = \frac{\gamma}{2\sigma^2}\|Ax-y\|^2`.

        Computes :math:`\operatorname{prox}_{\gamma \datafidname}`, i.e.

        .. math::

           \operatorname{prox}_{\gamma \datafidname} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|Au-y\|_2^2+\frac{1}{2}\|u-x\|_2^2


        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize of the proximity operator.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`.
        """
        # delegated to the physics model, which knows how to solve with A
        return physics.prox_l2(x, y, self.norm * gamma)
314
+
315
+
316
class IndicatorL2(DataFidelity):
    r"""
    Indicator of :math:`\ell_2` ball with radius :math:`r`.

    The indicator function of the :math:`\ell_2` ball with radius :math:`r`, denoted as :math:`\iota_{\mathcal{B}_2(y,r)}(u)`,
    is defined as

    .. math::

          \iota_{\mathcal{B}_2(y,r)}(u)= \left.
              \begin{cases}
                0, & \text{if } \|u-y\|_2\leq r \\
                +\infty & \text{else.}
              \end{cases}
              \right.


    :param float radius: radius of the ball. Default: None.

    """

    def __init__(self, radius=None):
        super().__init__()
        self.radius = radius

    def d(self, u, y, radius=None):
        r"""
        Computes the batched indicator of :math:`\ell_2` ball with radius `radius`, i.e. :math:`\iota_{\mathcal{B}(y,r)}(u)`.

        :param torch.tensor u: Variable :math:`u` at which the indicator is computed. :math:`u` is assumed to be of shape (B, ...) where B is the batch size.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :param float radius: radius of the :math:`\ell_2` ball. If `radius` is None, the radius of the ball is set to `self.radius`. Default: None.
        :return: (torch.tensor) indicator of :math:`\ell_2` ball with radius `radius`. If the point is inside the ball, the output is 0, else it is 1e16.
        """
        diff = u - y
        dist = torch.norm(diff.reshape(diff.shape[0], -1), p=2, dim=-1)
        radius = self.radius if radius is None else radius
        # 1e16 stands in for +infinity outside the ball
        loss = (dist > radius) * 1e16
        return loss

    def prox_d(self, x, y, radius=None, gamma=None):
        r"""
        Proximal operator of the indicator of :math:`\ell_2` ball with radius `radius`, i.e.

        .. math::

            \operatorname{prox}_{\iota_{\mathcal{B}_2(y,r)}}(x) = \operatorname{proj}_{\mathcal{B}_2(y, r)}(x)


        where :math:`\operatorname{proj}_{C}(x)` denotes the projection on the closed convex set :math:`C`.


        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed. Any shape
            ``(B, ...)`` is accepted.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`x`.
        :param float gamma: step-size. Note that this parameter is not used in this function.
        :param float radius: radius of the :math:`\ell_2` ball (a scalar, or one radius per batch element).
        :return: (torch.tensor) projection on the :math:`\ell_2` ball of radius `radius` and centered in `y`.
        """
        radius = self.radius if radius is None else radius
        diff = x - y
        dist = torch.norm(diff.reshape(diff.shape[0], -1), p=2, dim=-1)
        radius = torch.as_tensor(radius, dtype=dist.dtype, device=dist.device)
        # per-sample shrink factor: 1 inside the ball, r / ||x - y|| outside
        scale = torch.minimum(radius, dist) / (dist + 1e-12)
        # shape (B, 1, ..., 1) so the factor broadcasts over inputs of any rank
        return y + diff * scale.view(-1, *([1] * (diff.dim() - 1)))

    def prox(
        self,
        x,
        y,
        physics,
        radius=None,
        stepsize=None,
        crit_conv=1e-5,
        max_iter=100,
        gamma=None,
    ):
        r"""
        Proximal operator of the indicator of :math:`\ell_2` ball with radius `radius`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \iota_{\mathcal{B}_2(y, r)}(A\cdot)}(x) = \underset{u}{\text{argmin}} \,\, \iota_{\mathcal{B}_2(y, r)}(Au)+\frac{1}{2}\|u-x\|_2^2

        Since no closed form is available for general measurement operators, we use a dual forward-backward algorithm,
        as suggested in `Proximal Splitting Methods in Signal Processing <https://arxiv.org/pdf/0912.3522.pdf>`_.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`A(x)`.
        :param torch.tensor radius: radius of the :math:`\ell_2` ball.
        :param float stepsize: step-size of the dual-forward-backward algorithm.
        :param float crit_conv: convergence criterion of the dual-forward-backward algorithm.
        :param int max_iter: maximum number of iterations of the dual-forward-backward algorithm.
        :param float gamma: factor in front of the indicator function. Notice that this does not affect the proximity
            operator since the indicator is scale invariant. Accepted so that generic callers such as
            ``prox_conjugate`` can pass it. Default: None.
        :return: (torch.tensor) projection on the :math:`\ell_2` ball of radius `radius` and centered in `y`.
        """
        radius = self.radius if radius is None else radius

        Ax = physics.A(x)
        # Heuristic identity detection: A leaves this particular x unchanged,
        # so the prox reduces to a plain projection.
        if Ax.shape == x.shape and (Ax == x).all():
            return self.prox_d(x, y, gamma=None, radius=radius)

        norm_AtA = physics.compute_norm(x, verbose=False)
        stepsize = 1.0 / norm_AtA if stepsize is None else stepsize
        u = Ax  # dual variable, lives in the measurement space
        for _ in range(max_iter):
            u_prev = u.clone()

            t = x - physics.A_adjoint(u)  # current primal estimate
            u_ = u + stepsize * physics.A(t)
            # Moreau identity: prox of the conjugate of the indicator via the projection
            u = u_ - stepsize * self.prox_d(
                u_ / stepsize, y, radius=radius, gamma=None
            )
            rel_crit = (u - u_prev).norm() / (u.norm() + 1e-12)
            if rel_crit < crit_conv:
                break
        return t
431
+
432
+
433
class PoissonLikelihood(DataFidelity):
    r"""

    Poisson negative log-likelihood.

    .. math::

        \datafid{z}{y} =  -y^{\top} \log(z+\beta)+1^{\top}z

    where :math:`y` are the measurements, :math:`z` is the estimated (positive) density and :math:`\beta\geq 0` is
    an optional background level.

    .. note::

        The function is not Lipschitz smooth w.r.t. :math:`z` in the absence of background (:math:`\beta=0`).

    :param float gain: gain of the Poisson model. The formulas below reduce to the one above when ``gain=1``.
    :param float bkg: background level :math:`\beta`.
    :param bool normalize: if ``True``, the measurements ``y`` are multiplied by ``gain`` before use.
    """

    def __init__(self, gain=1.0, bkg=0, normalize=True):
        super().__init__()
        self.bkg = bkg
        self.gain = gain
        self.normalize = normalize

    def d(self, x, y):
        r"""
        Poisson negative log-likelihood
        :math:`-y^{\top}\log(\text{gain}\, x + \beta) + \text{gain}\, 1^{\top} x`, summed over all entries.

        .. note::

            The result is a single scalar summed over the whole batch, not one value per batch element.

        :param torch.tensor x: estimated density.
        :param torch.tensor y: measurements.
        :return: (torch.tensor) scalar negative log-likelihood.
        """
        # NOTE(review): the derivative of this expression is gain**2 * grad_d(x, y). The two agree
        # only when gain == 1 — confirm which scaling is intended before relying on d's values.
        if self.normalize:
            y = y * self.gain
        return (-y * torch.log(self.gain * x + self.bkg)).flatten().sum() + (
            self.gain * x
        ).flatten().sum()

    def grad_d(self, x, y):
        r"""
        Gradient :math:`\frac{1}{\text{gain}}\left(1 - \frac{y}{\text{gain}\, x + \beta}\right)`.

        :param torch.tensor x: estimated density.
        :param torch.tensor y: measurements.
        :return: (torch.tensor) gradient, same shape as ``x``.
        """
        if self.normalize:
            y = y * self.gain
        return (1 / self.gain) * (torch.ones_like(x) - y / (self.gain * x + self.bkg))

    def prox_d(self, x, y, gamma=1.0):
        r"""
        Proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`, i.e. the solution :math:`z` of
        :math:`z - x + \gamma \nabla \distance{z}{y} = 0` with the gradient given by :meth:`grad_d`.

        Setting :math:`w = \text{gain}\, z + \beta`, this condition is the quadratic
        :math:`w^2 - (\text{gain}\, x + \beta - \gamma) w - \gamma y = 0`. Its positive root is kept so that
        :math:`\text{gain}\, z + \beta > 0` (the domain of the log). For ``gain=1`` and :math:`\beta=0` this is the
        classical :math:`z = \frac{1}{2}\left(x - \gamma + \sqrt{(x-\gamma)^2 + 4\gamma y}\right)`.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: measurements (assumed non-negative).
        :param float gamma: stepsize of the proximity operator.
        :return: (torch.tensor) proximity operator, same shape as ``x``.
        """
        if self.normalize:
            y = y * self.gain
        b = self.gain * x + self.bkg - gamma
        w = 0.5 * (b + torch.sqrt(b.pow(2) + 4 * gamma * y))
        return (w - self.bkg) / self.gain
479
+
480
+
481
class L1(DataFidelity):
    r"""
    :math:`\ell_1` data fidelity term.

    In this case, the data fidelity term is defined as

    .. math::

        f(x) = \|Ax-y\|_1.

    """

    def __init__(self):
        super().__init__()

    def d(self, x, y):
        r"""
        Batched :math:`\ell_1` distance :math:`\|x-y\|_1`.

        :param torch.tensor x: Variable :math:`x`, of shape ``(B, ...)``.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`x`.
        :return: (torch.tensor) one distance per batch element (size `B`).
        """
        diff = x - y
        # reshape (rather than view) also accepts non-contiguous tensors
        return torch.norm(diff.reshape(diff.shape[0], -1), p=1, dim=-1)

    def grad_d(self, x, y):
        r"""
        (Sub)gradient of the :math:`\ell_1` distance, i.e.

        .. math::

            \partial \distance{x}{y} = \operatorname{sign}(x-y)


        .. note::

            The gradient is not defined at :math:`x=y`; ``torch.sign`` returns 0 there.


        :param torch.tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`x`.
        :return: (torch.tensor) gradient of the :math:`\ell_1` norm at `x`.
        """
        return torch.sign(x - y)

    def prox_d(self, u, y, gamma=1.0):
        r"""
        Proximal operator of the :math:`\ell_1` norm, i.e.

        .. math::

            \operatorname{prox}_{\gamma \ell_1}(x) = \underset{z}{\text{argmin}} \,\, \gamma \|z-y\|_1+\frac{1}{2}\|z-x\|_2^2


        also known as the soft-thresholding operator.

        :param torch.tensor u: Variable :math:`u` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :param float gamma: stepsize (or soft-thresholding parameter).
        :return: (torch.tensor) soft-thresholding of `u` with parameter `gamma`.
        """
        d = u - y
        # soft-thresholding of u - y, shifted back around y
        aux = torch.sign(d) * torch.clamp(d.abs() - gamma, min=0)
        return aux + y

    def prox(
        self, x, y, physics, gamma=1.0, stepsize=None, crit_conv=1e-5, max_iter=100
    ):
        r"""
        Proximal operator of the :math:`\ell_1` norm composed with A, i.e.

        .. math::

            \operatorname{prox}_{\gamma \ell_1}(x) = \underset{u}{\text{argmin}} \,\, \gamma \|Au-y\|_1+\frac{1}{2}\|u-x\|_2^2.



        Since no closed form is available for general measurement operators, we use a dual forward-backward algorithm.


        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`A(x)`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: factor in front of the :math:`\ell_1` term.
        :param float stepsize: step-size of the dual-forward-backward algorithm.
        :param float crit_conv: convergence criterion of the dual-forward-backward algorithm.
        :param int max_iter: maximum number of iterations of the dual-forward-backward algorithm.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma \|A\cdot-y\|_1}(x)`.
        """
        norm_AtA = physics.compute_norm(x, verbose=False)
        stepsize = 1.0 / norm_AtA if stepsize is None else stepsize
        # The dual variable lives in the measurement space (same shape as A(x)), which differs
        # from x whenever A is not square.
        u = physics.A(x)
        for it in range(max_iter):
            u_prev = u.clone()

            t = x - physics.A_adjoint(u)  # current primal estimate
            u_ = u + stepsize * physics.A(t)
            # Moreau identity: prox of the conjugate of gamma * ||. - y||_1
            u = u_ - stepsize * self.prox_d(u_ / stepsize, y, gamma / stepsize)
            rel_crit = (u - u_prev).norm() / (u.norm() + 1e-12)
            if rel_crit < crit_conv and it > 2:
                break
        return t
579
+
580
+
581
if __name__ == "__main__":
    import deepinv as dinv

    # define a loss function
    data_fidelity = L2()

    # create a d x d measurement operator (d = 2)
    A = torch.Tensor([[2, 0], [0, 0.5]])
    A_forward = lambda v: torch.matmul(A, v)
    A_adjoint = lambda v: torch.matmul(A.transpose(0, 1), v)

    # Define the physics model associated to this operator
    physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

    # Define two batches of B = 4 points, each of shape (B, d, 1)
    x = torch.Tensor([1, 4]).unsqueeze(0).repeat(4, 1).unsqueeze(-1)
    y = torch.Tensor([1, 1]).unsqueeze(0).repeat(4, 1).unsqueeze(-1)

    # Loss f(x) = datafid(x, y): one value per batch element, each expected to be 1.0
    f = data_fidelity(x, y, physics)
    print("f(x) =", f)

    # Gradient of f: each batch element expected to be [2.0, 0.5]
    grad = data_fidelity.grad(x, y, physics)
    print("grad f(x) =", grad.squeeze(-1))

    # Proximity operator of f: each batch element expected to be [0.6, 3.6]
    prox = data_fidelity.prox(x, y, physics, gamma=1.0)
    print("prox f(x) =", prox.squeeze(-1))