deepinv 0.1.0.dev0__py3-none-any.whl
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,607 @@

import torch
import torch.nn as nn

from deepinv.optim.utils import gradient_descent


class DataFidelity(nn.Module):
    r"""
    Data fidelity term :math:`\datafid{x}{y}=\distance{Ax}{y}`.

    This is the base class for the data fidelity term :math:`\datafid{x}{y} = \distance{A(x)}{y}` where :math:`A` is a
    linear or nonlinear operator, :math:`x\in\xset` is a variable, :math:`y\in\yset` is the observation and
    :math:`\distancename` is a distance function.

    ::

        # define a data fidelity term
        data_fidelity = L2()

        # create a measurement operator
        A = torch.Tensor([[2, 0], [0, 0.5]])
        A_forward = lambda v: A @ v
        A_adjoint = lambda v: A.transpose(0, 1) @ v

        # define the physics model associated with this operator
        physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

        # define two points
        x = torch.Tensor([[1], [4]]).unsqueeze(0)
        y = torch.Tensor([[1], [1]]).unsqueeze(0)

        # compute the data fidelity f(x) = d(A(x), y)
        f_x = data_fidelity(x, y, physics)  # print(f_x) gives tensor([1.0000])

        # compute the gradient of f
        grad = data_fidelity.grad(x, y, physics)  # print(grad) gives tensor([[[2.0000], [0.5000]]])

        # compute the proximity operator of f
        prox = data_fidelity.prox(x, y, physics, gamma=1.0)  # print(prox) gives tensor([[[0.6000], [3.6000]]])


    .. warning::

        All variables have a batch dimension as first dimension.

    :param callable d: data fidelity distance function :math:`\distance{u}{y}`. Outputs a tensor of size `B`,
        the size of the batch. Default: None.
    """

    def __init__(self, d=None):
        super().__init__()
        self._d = d

    def d(self, u, y, *args, **kwargs):
        r"""
        Computes the data fidelity distance :math:`\distance{u}{y}`.

        :param torch.tensor u: Variable :math:`u` at which the distance function is computed.
        :param torch.tensor y: Data :math:`y`.
        :return: (torch.tensor) data fidelity :math:`\distance{u}{y}`.
        """
        return self._d(u, y, *args, **kwargs)

    def grad_d(self, u, y, *args, **kwargs):
        r"""
        Computes the gradient :math:`\nabla_u\distance{u}{y}`, evaluated at :math:`u`. Note that this is the gradient of
        :math:`\distancename` and not :math:`\datafidname`. By default, the gradient is computed using automatic
        differentiation.

        :param torch.tensor u: Variable :math:`u` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :return: (torch.tensor) gradient of :math:`d` at :math:`u`, i.e. :math:`\nabla_u\distance{u}{y}`.
        """
        with torch.enable_grad():
            u = u.requires_grad_()
            # summing over the batch yields the per-sample gradients, since each
            # batch element contributes independently to the distance
            grad = torch.autograd.grad(
                self.d(u, y, *args, **kwargs).sum(),
                u,
                create_graph=True,
                only_inputs=True,
            )[0]
        return grad

    def prox_d(
        self,
        u,
        y,
        *args,
        gamma=1.0,
        stepsize_inter=1.0,
        max_iter_inter=50,
        tol_inter=1e-3,
        **kwargs,
    ):
        r"""
        Computes the proximity operator :math:`\operatorname{prox}_{\gamma\distance{\cdot}{y}}(u)`, evaluated at :math:`u`.
        Note that this is the proximity operator of :math:`\distancename` and not :math:`\datafidname`. By default, the
        proximity operator is computed using internal gradient descent.

        :param torch.tensor u: Variable :math:`u` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :param float gamma: stepsize of the proximity operator.
        :param float stepsize_inter: stepsize used for internal gradient descent.
        :param int max_iter_inter: maximal number of iterations for internal gradient descent.
        :param float tol_inter: the internal gradient descent is considered to have converged when the L2 distance
            between two consecutive iterates is smaller than `tol_inter`.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma\distance{\cdot}{y}}(u)`.
        """
        grad = lambda z: gamma * self.grad_d(z, y, *args, **kwargs) + (z - u)
        return gradient_descent(
            grad, u, step_size=stepsize_inter, max_iter=max_iter_inter, tol=tol_inter
        )

    def forward(self, x, y, physics, *args, **kwargs):
        r"""
        Computes the data fidelity term :math:`\datafid{x}{y} = \distance{Ax}{y}`.

        :param torch.tensor x: Variable :math:`x` at which the data fidelity is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :return: (torch.tensor) data fidelity :math:`\datafid{x}{y}`.
        """
        return self.d(physics.A(x), y, *args, **kwargs)

    def grad(self, x, y, physics, *args, **kwargs):
        r"""
        Calculates the gradient of the data fidelity term :math:`\datafidname` at :math:`x`. By the chain rule,
        :math:`\nabla_x \datafid{x}{y} = A^{\top} \nabla_u \distance{u}{y}` evaluated at :math:`u = A(x)`.

        :param torch.tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :return: (torch.tensor) gradient :math:`\nabla_x\datafid{x}{y}`, evaluated at :math:`x`.
        """
        return physics.A_adjoint(self.grad_d(physics.A(x), y, *args, **kwargs))

    def prox(
        self,
        x,
        y,
        physics,
        *args,
        gamma=1.0,
        stepsize_inter=1.0,
        max_iter_inter=50,
        tol_inter=1e-3,
        **kwargs,
    ):
        r"""
        Calculates the proximity operator of :math:`\datafidname` at :math:`x`.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize of the proximity operator.
        :param float stepsize_inter: stepsize used for internal gradient descent.
        :param int max_iter_inter: maximal number of iterations for internal gradient descent.
        :param float tol_inter: the internal gradient descent is considered to have converged when the L2 distance
            between two consecutive iterates is smaller than `tol_inter`.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`, evaluated at :math:`x`.
        """
        grad = lambda z: gamma * self.grad(z, y, physics, *args, **kwargs) + (z - x)
        return gradient_descent(
            grad, x, step_size=stepsize_inter, max_iter=max_iter_inter, tol=tol_inter
        )

    def prox_conjugate(self, x, y, physics, *args, gamma=1.0, lamb=1.0, **kwargs):
        r"""
        Calculates the proximity operator of the convex conjugate :math:`(\lambda \datafidname)^*` at :math:`x`,
        using the Moreau formula
        :math:`\operatorname{prox}_{\gamma (\lambda f)^*}(x) = x - \gamma \operatorname{prox}_{\frac{\lambda}{\gamma} f}(x/\gamma)`.

        .. warning::

            This function is only valid for convex :math:`\datafidname`.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize of the proximity operator.
        :param float lamb: :math:`\lambda` parameter in front of :math:`\datafidname`.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma (\lambda \datafidname)^*}(x)`,
            evaluated at :math:`x`.
        """
        return x - gamma * self.prox(
            x / gamma, y, physics, *args, gamma=lamb / gamma, **kwargs
        )

    def prox_d_conjugate(self, u, y, *args, gamma=1.0, lamb=1.0, **kwargs):
        r"""
        Calculates the proximity operator of the convex conjugate :math:`(\lambda \distancename)^*` at :math:`u`,
        using the Moreau formula.

        .. warning::

            This function is only valid for convex :math:`\distancename`.

        :param torch.tensor u: Variable :math:`u` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param float gamma: stepsize of the proximity operator.
        :param float lamb: :math:`\lambda` parameter in front of :math:`\distancename`.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma (\lambda \distancename)^*}(u)`,
            evaluated at :math:`u`.
        """
        return u - gamma * self.prox_d(
            u / gamma, y, *args, gamma=lamb / gamma, **kwargs
        )

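
# A minimal usage sketch (illustrative, not part of the released file; the
# `_example_` name is hypothetical): the base class can be used directly with a
# custom callable distance, in which case `grad_d` falls back to automatic
# differentiation and `prox_d` to the internal gradient descent. The Huber
# distance below is an arbitrary choice for illustration.
def _example_custom_distance():
    # batched Huber distance: one value per batch element, as the docstring requires
    huber = lambda u, y: (
        torch.nn.functional.huber_loss(u, y, reduction="none").flatten(1).sum(-1)
    )
    fidelity = DataFidelity(d=huber)
    u, y = torch.randn(4, 2, 1), torch.randn(4, 2, 1)
    return fidelity.grad_d(u, y)  # gradient obtained by automatic differentiation
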
class L2(DataFidelity):
    r"""
    Implementation of :math:`\distancename` as the normalized :math:`\ell_2` norm

    .. math::

        f(x) = \frac{1}{2\sigma^2}\|Ax-y\|^2

    It can be used to define a log-likelihood function associated with additive Gaussian noise
    by setting an appropriate noise level :math:`\sigma`.

    :param float sigma: Standard deviation of the noise to be used as a normalisation factor.


    ::

        # define a data fidelity term
        data_fidelity = L2()

        # create a measurement operator
        A = torch.Tensor([[2, 0], [0, 0.5]])
        A_forward = lambda v: A @ v
        A_adjoint = lambda v: A.transpose(0, 1) @ v

        # define the physics model associated with this operator
        physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

        # define two points
        x = torch.Tensor([[1], [4]]).unsqueeze(0)
        y = torch.Tensor([[1], [1]]).unsqueeze(0)

        # compute the loss f(Ax, y)
        f = data_fidelity(x, y, physics)  # print(f) gives tensor([1.0000])

        # compute the gradient of f
        grad_dA = data_fidelity.grad(x, y, physics)  # print(grad_dA) gives tensor([[[2.0000], [0.5000]]])

        # compute the proximity operator of f
        prox_dA = data_fidelity.prox(x, y, physics, gamma=1.0)  # print(prox_dA) gives tensor([[[0.6000], [3.6000]]])
    """

    def __init__(self, sigma=1.0):
        super().__init__()

        self.norm = 1 / (sigma**2)

    def d(self, u, y):
        r"""
        Computes the data fidelity distance :math:`\distance{u}{y}`, i.e.

        .. math::

            \distance{u}{y} = \frac{1}{2\sigma^2}\|u-y\|^2


        :param torch.tensor u: Variable :math:`u` at which the data fidelity is computed.
        :param torch.tensor y: Data :math:`y`.
        :return: (torch.tensor) data fidelity :math:`\distance{u}{y}` of size `B` with `B` the size of the batch.
        """
        x = u - y
        # the 1/sigma^2 normalisation keeps d consistent with grad_d and prox_d
        d = 0.5 * self.norm * torch.norm(x.view(x.shape[0], -1), p=2, dim=-1) ** 2
        return d

    def grad_d(self, u, y):
        r"""
        Computes the gradient of :math:`\distancename`, that is :math:`\nabla_{u}\distance{u}{y}`, i.e.

        .. math::

            \nabla_{u}\distance{u}{y} = \frac{1}{\sigma^2}(u-y)


        :param torch.tensor u: Variable :math:`u` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y`.
        :return: (torch.tensor) gradient of the distance function :math:`\nabla_{u}\distance{u}{y}`.
        """
        return self.norm * (u - y)

    def prox_d(self, x, y, gamma=1.0):
        r"""
        Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2\sigma^2}\|x-y\|^2`.

        Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \distancename}(x) = \underset{u}{\text{argmin}} \,\, \frac{\gamma}{2\sigma^2}\|u-y\|_2^2+\frac{1}{2}\|u-x\|_2^2


        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param float gamma: thresholding parameter.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
        """
        gamma_ = self.norm * gamma
        return (x + gamma_ * y) / (1 + gamma_)

    def prox(self, x, y, physics, gamma=1.0):
        r"""
        Proximal operator of :math:`\gamma \datafid{Ax}{y} = \frac{\gamma}{2\sigma^2}\|Ax-y\|^2`.

        Computes :math:`\operatorname{prox}_{\gamma \datafidname}`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \datafidname}(x) = \underset{u}{\text{argmin}} \,\, \frac{\gamma}{2\sigma^2}\|Au-y\|_2^2+\frac{1}{2}\|u-x\|_2^2


        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize of the proximity operator.
        :return: (torch.tensor) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`.
        """
        return physics.prox_l2(x, y, self.norm * gamma)

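
# A quick self-check sketch (illustrative, not part of the released file; the
# `_example_` name is hypothetical): the closed form returned by `L2.prox_d`
# satisfies the first-order optimality condition of the proximal problem,
# gamma/sigma^2 * (p - y) + (p - x) = 0.
def _example_l2_prox_check(sigma=2.0, gamma=0.5):
    fidelity = L2(sigma=sigma)
    x, y = torch.randn(3, 2, 1), torch.randn(3, 2, 1)
    p = fidelity.prox_d(x, y, gamma=gamma)
    residual = gamma / sigma**2 * (p - y) + (p - x)
    assert torch.allclose(residual, torch.zeros_like(residual), atol=1e-6)
    return p
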
class IndicatorL2(DataFidelity):
    r"""
    Indicator of the :math:`\ell_2` ball with radius :math:`r`.

    The indicator function of the :math:`\ell_2` ball with radius :math:`r`, denoted :math:`\iota_{\mathcal{B}_2(y,r)}(u)`,
    is defined as

    .. math::

        \iota_{\mathcal{B}_2(y,r)}(u) =
        \begin{cases}
            0, & \text{if } \|u-y\|_2\leq r \\
            +\infty, & \text{else.}
        \end{cases}


    :param float radius: radius of the ball. Default: None.

    """

    def __init__(self, radius=None):
        super().__init__()
        self.radius = radius

    def d(self, u, y, radius=None):
        r"""
        Computes the batched indicator of the :math:`\ell_2` ball with radius `radius`, i.e. :math:`\iota_{\mathcal{B}_2(y,r)}(u)`.

        :param torch.tensor u: Variable :math:`u` at which the indicator is computed. :math:`u` is assumed to be of shape (B, ...) where B is the batch size.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :param float radius: radius of the :math:`\ell_2` ball. If `radius` is None, the radius of the ball is set to `self.radius`. Default: None.
        :return: (torch.tensor) indicator of the :math:`\ell_2` ball with radius `radius`. If the point is inside the ball, the output is 0, else it is 1e16.
        """
        diff = u - y
        dist = torch.norm(diff.view(diff.shape[0], -1), p=2, dim=-1)
        radius = self.radius if radius is None else radius
        loss = (dist > radius) * 1e16
        return loss

    def prox_d(self, x, y, radius=None, gamma=None):
        r"""
        Proximal operator of the indicator of the :math:`\ell_2` ball with radius `radius`, i.e.

        .. math::

            \operatorname{prox}_{\iota_{\mathcal{B}_2(y,r)}}(x) = \operatorname{proj}_{\mathcal{B}_2(y, r)}(x)


        where :math:`\operatorname{proj}_{C}(x)` denotes the projection on the closed convex set :math:`C`.


        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`x`.
        :param float gamma: step-size. Note that this parameter is not used in this function, since the proximity
            operator of an indicator is scale invariant.
        :param float radius: radius of the :math:`\ell_2` ball.
        :return: (torch.tensor) projection on the :math:`\ell_2` ball of radius `radius` and centered in `y`.
        """
        radius = self.radius if radius is None else radius
        diff = x - y
        dist = torch.norm(diff.view(diff.shape[0], -1), p=2, dim=-1)
        return y + diff * (
            torch.min(torch.tensor([radius]).to(x.device), dist) / (dist + 1e-12)
        ).view(-1, 1, 1, 1)

    def prox(
        self,
        x,
        y,
        physics,
        radius=None,
        stepsize=None,
        crit_conv=1e-5,
        max_iter=100,
    ):
        r"""
        Proximal operator of the indicator of the :math:`\ell_2` ball with radius `radius` composed with :math:`A`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \iota_{\mathcal{B}_2(y, r)}(A\cdot)}(x) = \underset{u}{\text{argmin}} \,\, \iota_{\mathcal{B}_2(y, r)}(Au)+\frac{1}{2}\|u-x\|_2^2

        Since no closed form is available for general measurement operators, we use a dual forward-backward algorithm,
        as suggested in `Proximal Splitting Methods in Signal Processing <https://arxiv.org/pdf/0912.3522.pdf>`_.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`A(x)`.
        :param deepinv.physics.Physics physics: physics model.
        :param float radius: radius of the :math:`\ell_2` ball.
        :param float stepsize: step-size of the dual-forward-backward algorithm.
        :param float crit_conv: convergence criterion of the dual-forward-backward algorithm.
        :param int max_iter: maximum number of iterations of the dual-forward-backward algorithm.
        :return: (torch.tensor) projection on the :math:`\ell_2` ball of radius `radius` and centered in `y`.
        """
        radius = self.radius if radius is None else radius

        if physics.A(x).shape == x.shape and (physics.A(x) == x).all():  # identity case
            return self.prox_d(x, y, gamma=None, radius=radius)
        else:
            norm_AtA = physics.compute_norm(x, verbose=False)
            stepsize = 1.0 / norm_AtA if stepsize is None else stepsize
            u = physics.A(x)
            for it in range(max_iter):
                u_prev = u.clone()

                t = x - physics.A_adjoint(u)
                u_ = u + stepsize * physics.A(t)
                u = u_ - stepsize * self.prox_d(
                    u_ / stepsize, y, radius=radius, gamma=None
                )
                rel_crit = (u - u_prev).norm() / (u.norm() + 1e-12)
                if rel_crit < crit_conv:
                    break
            return t

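
# An illustrative sketch (not part of the released file; the `_example_` name is
# hypothetical): `prox_d` is the projection onto the ball B_2(y, r). A point
# outside the ball is mapped onto the sphere of radius r around y.
def _example_l2_ball_projection(radius=1.0):
    indicator = IndicatorL2(radius=radius)
    y = torch.zeros(1, 2, 1, 1)  # ball centered at the origin, 4D (B, C, H, W) shape
    x = torch.full((1, 2, 1, 1), 3.0)  # point outside the ball
    proj = indicator.prox_d(x, y, radius=radius)
    assert torch.isclose(proj.flatten().norm(), torch.tensor(radius))
    return proj
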
class PoissonLikelihood(DataFidelity):
    r"""

    Poisson negative log-likelihood.

    .. math::

        \datafid{z}{y} = -y^{\top} \log(z+\beta)+1^{\top}z

    where :math:`y` are the measurements, :math:`z` is the estimated (positive) density and :math:`\beta\geq 0` is
    an optional background level.

    .. note::

        The function is not Lipschitz smooth w.r.t. :math:`z` in the absence of background (:math:`\beta=0`).

    :param float gain: gain of the measurement process. Default: 1.0.
    :param float bkg: background level :math:`\beta`. Default: 0.
    :param bool normalize: if ``True``, the measurements :math:`y` are multiplied by the gain. Default: ``True``.
    """

    def __init__(self, gain=1.0, bkg=0, normalize=True):
        super().__init__()
        self.bkg = bkg
        self.gain = gain
        self.normalize = normalize

    def d(self, x, y):
        r"""
        Computes the Poisson negative log-likelihood :math:`\distance{x}{y}`.
        """
        if self.normalize:
            y = y * self.gain
        return (-y * torch.log(self.gain * x + self.bkg)).flatten().sum() + (
            self.gain * x
        ).flatten().sum()

    def grad_d(self, x, y):
        r"""
        Computes the gradient of the Poisson negative log-likelihood.
        """
        if self.normalize:
            y = y * self.gain
        return (1 / self.gain) * (torch.ones_like(x) - y / (self.gain * x + self.bkg))

    def prox_d(self, x, y, gamma=1.0):
        r"""
        Computes the proximity operator of the Poisson negative log-likelihood, as the
        positive root of its quadratic optimality condition.
        """
        if self.normalize:
            y = y * self.gain
        out = (
            x
            - (self.gain / gamma)
            + ((x - self.gain / gamma).pow(2) + 4 * y / gamma).sqrt()
        )
        return out / 2

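
# A gradient self-check sketch (illustrative, not part of the released file; the
# `_example_` name is hypothetical): with gain=1 and a positive background, the
# analytic `grad_d` above matches automatic differentiation of `d`.
def _example_poisson_grad_check(bkg=0.1):
    likelihood = PoissonLikelihood(gain=1.0, bkg=bkg)
    x = (torch.rand(2, 3) + 0.5).requires_grad_()  # positive intensities
    y = torch.poisson(torch.ones(2, 3))
    autograd_grad = torch.autograd.grad(likelihood.d(x, y), x)[0]
    assert torch.allclose(autograd_grad, likelihood.grad_d(x, y), atol=1e-6)
    return autograd_grad
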
class L1(DataFidelity):
    r"""
    :math:`\ell_1` data fidelity term.

    In this case, the data fidelity term is defined as

    .. math::

        f(x) = \|Ax-y\|_1.

    """

    def __init__(self):
        super().__init__()

    def d(self, x, y):
        r"""
        Computes the :math:`\ell_1` distance :math:`\distance{x}{y} = \|x-y\|_1`.
        """
        diff = x - y
        return torch.norm(diff.view(diff.shape[0], -1), p=1, dim=-1)

    def grad_d(self, x, y):
        r"""
        Gradient of the :math:`\ell_1` norm, i.e.

        .. math::

            \partial \datafid(x) = \operatorname{sign}(x-y)


        .. note::

            The gradient is not defined at :math:`x=y`.


        :param torch.tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`x`.
        :return: (torch.tensor) gradient of the :math:`\ell_1` norm at `x`.
        """
        return torch.sign(x - y)

    def prox_d(self, u, y, gamma=1.0):
        r"""
        Proximal operator of the :math:`\ell_1` norm, i.e.

        .. math::

            \operatorname{prox}_{\gamma \ell_1}(u) = \underset{z}{\text{argmin}} \,\, \gamma \|z-y\|_1+\frac{1}{2}\|z-u\|_2^2


        also known as the soft-thresholding operator.

        :param torch.tensor u: Variable :math:`u` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`u`.
        :param float gamma: stepsize (or soft-thresholding parameter).
        :return: (torch.tensor) soft-thresholding of `u` with parameter `gamma`.
        """
        d = u - y
        aux = torch.sign(d) * torch.maximum(
            d.abs() - gamma, torch.tensor([0]).to(d.device)
        )
        return aux + y

    def prox(
        self, x, y, physics, gamma=1.0, stepsize=None, crit_conv=1e-5, max_iter=100
    ):
        r"""
        Proximal operator of the :math:`\ell_1` norm composed with :math:`A`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \ell_1 \circ A}(x) = \underset{u}{\text{argmin}} \,\, \gamma \|Au-y\|_1+\frac{1}{2}\|u-x\|_2^2.


        Since no closed form is available for general measurement operators, we use a dual forward-backward algorithm.

        :param torch.tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.tensor y: Data :math:`y` of the same dimension as :math:`A(x)`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize (or soft-thresholding parameter).
        :param float stepsize: step-size of the dual-forward-backward algorithm.
        :param float crit_conv: convergence criterion of the dual-forward-backward algorithm.
        :param int max_iter: maximum number of iterations of the dual-forward-backward algorithm.
        :return: (torch.tensor) proximity operator of the :math:`\ell_1` norm composed with :math:`A` at `x`.
        """
        norm_AtA = physics.compute_norm(x)
        stepsize = 1.0 / norm_AtA if stepsize is None else stepsize
        u = x.clone()
        for it in range(max_iter):
            u_prev = u.clone()

            t = x - physics.A_adjoint(u)
            u_ = u + stepsize * physics.A(t)
            u = u_ - stepsize * self.prox_d(u_ / stepsize, y, gamma / stepsize)
            rel_crit = (u - u_prev).norm() / (u.norm() + 1e-12)
            if rel_crit < crit_conv and it > 2:
                break
        return t

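
# An illustrative sketch (not part of the released file; the `_example_` name is
# hypothetical): `L1.prox_d` soft-thresholds x - y and shifts the result back by y.
def _example_soft_thresholding(gamma=0.5):
    l1 = L1()
    u = torch.tensor([[2.0, 0.2, -1.0]])
    y = torch.zeros(1, 3)
    return l1.prox_d(u, y, gamma=gamma)  # gives tensor([[1.5000, 0.0000, -0.5000]])
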
if __name__ == "__main__":
    import deepinv as dinv

    # define a data fidelity term
    data_fidelity = L2()

    # create a d x d measurement operator
    A = torch.Tensor([[2, 0], [0, 0.5]])
    A_forward = lambda v: torch.matmul(A, v)
    A_adjoint = lambda v: torch.matmul(A.transpose(0, 1), v)

    # define the physics model associated with this operator
    physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)

    # define two points of size (B, d, 1)
    x = torch.Tensor([1, 4]).unsqueeze(0).repeat(4, 1).unsqueeze(-1)
    y = torch.Tensor([1, 1]).unsqueeze(0).repeat(4, 1).unsqueeze(-1)

    # compute the data fidelity f(x) = d(A(x), y)
    f = data_fidelity(x, y, physics)  # print(f) gives tensor([1., 1., 1., 1.])

    # compute the gradient of f
    grad = data_fidelity.grad(x, y, physics)  # print(grad) gives [2.0000, 0.5000] per batch element

    # compute the proximity operator of f
    prox = data_fidelity.prox(
        x, y, physics, gamma=1.0
    )  # print(prox) gives [0.6000, 3.6000] per batch element