deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,547 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from deepinv.optim.utils import conjugate_gradient
|
|
3
|
+
from .noise import GaussianNoise
|
|
4
|
+
from deepinv.utils import randn_like, TensorList
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Physics(torch.nn.Module):  # parent class for forward models
    r"""
    Parent class for forward operators

    It describes the general forward measurement process

    .. math::

        y = N(A(x))

    where :math:`x` is an image of :math:`n` pixels, :math:`y` is the measurements of size :math:`m`,
    :math:`A:\xset\mapsto \yset` is a deterministic mapping capturing the physics of the acquisition
    and :math:`N:\yset\mapsto \yset` is a stochastic mapping which characterizes the noise affecting
    the measurements.

    :param callable A: forward operator function which maps an image to the observed measurements :math:`x\mapsto y`.
    :param callable noise_model: function that adds noise to the measurements :math:`N(z)`.
        See the noise module for some predefined functions.
    :param callable sensor_model: function that incorporates any sensor non-linearities to the sensing process,
        such as quantization or saturation, defined as a function :math:`\eta(z)`, such that
        :math:`y=\eta\left(N(A(x))\right)`. By default, the sensor_model is set to the identity :math:`\eta(z)=z`.
    :param int max_iter: If the operator does not have a closed form pseudoinverse, the gradient descent algorithm
        is used for computing it, and this parameter fixes the maximum number of gradient descent iterations.
    :param float tol: If the operator does not have a closed form pseudoinverse, the gradient descent algorithm
        is used for computing it, and this parameter fixes the absolute tolerance of the gradient descent algorithm.

    """

    def __init__(
        self,
        A=lambda x: x,
        noise_model=lambda x: x,
        sensor_model=lambda x: x,
        max_iter=50,
        tol=1e-3,
    ):
        super().__init__()
        self.noise_model = noise_model
        self.sensor_model = sensor_model
        self.forw = A
        self.SVD = False  # flag indicating SVD available
        self.max_iter = max_iter
        self.tol = tol

    def __mul__(self, other):  # physics3 = physics1 \circ physics2
        r"""
        Concatenates two forward operators :math:`A = A_1\circ A_2` via the mul operation

        The resulting operator keeps the noise and sensor models of :math:`A_1`.

        :param deepinv.physics.Physics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.Physics) concatenated operator

        """
        A = lambda x: self.A(other.A(x))  # (A' = A_1 A_2)
        noise = self.noise_model
        sensor = self.sensor_model
        return Physics(
            A=A,
            noise_model=noise,
            sensor_model=sensor,
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def __add__(self, other):
        r"""
        Stacks two forward operators :math:`A(x) = \begin{bmatrix} A_1(x) \\ A_2(x) \end{bmatrix}`
        via the add operation.

        The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects, where
        each entry corresponds to the measurements of the corresponding operator.

        :param deepinv.physics.Physics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.Physics) stacked operator

        """
        A = lambda x: TensorList(self.A(x)).append(TensorList(other.A(x)))

        class noise(torch.nn.Module):
            # Applies noise1 to all entries but the last one, and noise2 to the last entry.
            def __init__(self, noise1, noise2):
                super().__init__()
                self.noise1 = noise1
                self.noise2 = noise2

            def forward(self, x):
                return TensorList(self.noise1(x[:-1])).append(self.noise2(x[-1]))

        class sensor(torch.nn.Module):
            # Applies sensor1 to all entries but the last one, and sensor2 to the last entry.
            def __init__(self, sensor1, sensor2):
                super().__init__()
                self.sensor1 = sensor1
                self.sensor2 = sensor2

            def forward(self, x):
                return TensorList(self.sensor1(x[:-1])).append(self.sensor2(x[-1]))

        return Physics(
            A=A,
            noise_model=noise(self.noise_model, other.noise_model),
            sensor_model=sensor(self.sensor_model, other.sensor_model),
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def reset(self, **kwargs):
        r"""
        Re-initialises the noise model in place by calling its ``__init__`` with ``kwargs``.

        Nothing happens if the noise model is not a ``torch.nn.Module`` (e.g. a plain function).

        :param kwargs: keyword arguments forwarded to the noise model's ``__init__``.
        """
        if isinstance(self.noise_model, torch.nn.Module):
            self.noise_model.__init__(**kwargs)

    def forward(self, x):
        r"""
        Computes forward operator :math:`y = \eta(N(A(x)))` (with noise and/or sensor non-linearities)

        :param torch.Tensor,list[torch.Tensor] x: signal/image
        :return: (torch.Tensor) noisy measurements

        """
        return self.sensor(self.noise(self.A(x)))

    def A(self, x):
        r"""
        Computes forward operator :math:`y = A(x)` (without noise and/or sensor non-linearities)

        :param torch.Tensor,list[torch.Tensor] x: signal/image
        :return: (torch.Tensor) clean measurements

        """
        return self.forw(x)

    def sensor(self, x):
        r"""
        Computes sensor non-linearities :math:`y = \eta(x)`

        :param torch.Tensor,list[torch.Tensor] x: (noisy) measurements
        :return: (torch.Tensor) measurements after the sensor model
        """
        return self.sensor_model(x)

    def noise(self, x):
        r"""
        Incorporates noise into the measurements :math:`\tilde{y} = N(y)`

        :param torch.Tensor x: clean measurements
        :return torch.Tensor: noisy measurements

        """
        return self.noise_model(x)

    def A_dagger(self, y, x_init=None):
        r"""
        Computes an inverse of :math:`y = A(x)` via gradient descent.

        Minimises the mean squared error between :math:`A(x)` and :math:`y` with SGD (step size ``0.1``)
        for at most ``self.max_iter`` iterations, stopping once the error falls below ``self.tol``.

        This function can be overwritten by a more efficient pseudoinverse in cases where closed form formulas exist.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :param torch.Tensor x_init: initial guess for the reconstruction. It is copied and never modified
            in place. If ``None``, :math:`A^{\top}y` is used, which requires the operator to define ``A_adjoint``.
        :return: (torch.Tensor) The reconstructed image :math:`x`.
        :raises ValueError: if ``x_init`` is ``None`` and the operator does not define ``A_adjoint``.

        """
        if x_init is None:
            if not hasattr(self, "A_adjoint"):
                raise ValueError(
                    "x_init must be provided: this operator does not define A_adjoint "
                    "to build an initial guess."
                )
            x_init = self.A_adjoint(y)

        # torch.nn.Parameter shares storage with the tensor it wraps, so the in-place
        # optimizer steps would overwrite the caller's x_init (or y itself when A_adjoint
        # returns its input). Work on a private copy instead.
        x = torch.nn.Parameter(x_init.detach().clone(), requires_grad=True)

        optimizer = torch.optim.SGD([x], lr=1e-1)
        loss = torch.nn.MSELoss()
        for _ in range(self.max_iter):
            err = loss(self.A(x), y)
            optimizer.zero_grad()
            err.backward(retain_graph=True)
            optimizer.step()
            if err < self.tol:
                break

        return x.clone()
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class LinearPhysics(Physics):
    r"""
    Parent class for linear operators.

    It describes the linear forward measurement process of the form

    .. math::

        y = N(A(x))

    where :math:`x` is an image of :math:`n` pixels, :math:`y` is the measurements of size :math:`m`,
    :math:`A:\xset\mapsto \yset` is a deterministic linear mapping capturing the physics of the acquisition
    and :math:`N:\yset\mapsto \yset` is a stochastic mapping which characterizes the noise affecting
    the measurements.

    :param callable A: forward operator function which maps an image to the observed measurements :math:`x\mapsto y`.
        It is recommended to normalize it to have unit norm.
    :param callable A_adjoint: transpose of the forward operator, which should verify the adjointness test.
    :param callable noise_model: function that adds noise to the measurements :math:`N(z)`.
        See the noise module for some predefined functions.
    :param callable sensor_model: function that incorporates any sensor non-linearities to the sensing process,
        such as quantization or saturation, defined as a function :math:`\eta(z)`, such that
        :math:`y=\eta\left(N(A(x))\right)`. By default, the sensor_model is set to the identity :math:`\eta(z)=z`.
    :param int max_iter: If the operator does not have a closed form pseudoinverse, the conjugate gradient algorithm
        is used for computing it, and this parameter fixes the maximum number of conjugate gradient iterations.
    :param float tol: If the operator does not have a closed form pseudoinverse, the conjugate gradient algorithm
        is used for computing it, and this parameter fixes the absolute tolerance of the conjugate gradient algorithm.

    """

    def __init__(
        self,
        A=lambda x: x,
        A_adjoint=lambda x: x,
        noise_model=lambda x: x,
        sensor_model=lambda x: x,
        max_iter=50,
        tol=1e-3,
        **kwargs,
    ):
        # NOTE(review): extra keyword arguments are accepted but not used here.
        super().__init__(
            A=A,
            noise_model=noise_model,
            sensor_model=sensor_model,
            max_iter=max_iter,
            tol=tol,
        )

        self.adjoint = A_adjoint

    def A_adjoint(self, y):
        r"""
        Computes transpose of the forward operator :math:`\tilde{x} = A^{\top}y`.
        If :math:`A` is linear, it should be the exact transpose of the forward matrix.

        .. note::

            If problem is non-linear, there is not a well-defined transpose operation,
            but defining one can be useful for some reconstruction networks, such as ``deepinv.models.ArtifactRemoval``.

        :param torch.Tensor y: measurements.
        :return: (torch.Tensor) linear reconstruction :math:`\tilde{x} = A^{\top}y`.

        """
        return self.adjoint(y)

    def __mul__(self, other):
        r"""
        Concatenates two linear forward operators :math:`A = A_1\circ A_2` via the * operation

        The resulting linear operator keeps the noise and sensor models of :math:`A_1`.

        :param deepinv.physics.LinearPhysics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.LinearPhysics) concatenated operator

        """
        A = lambda x: self.A(other.A(x))  # (A' = A_1 A_2)
        A_adjoint = lambda x: other.A_adjoint(self.A_adjoint(x))  # (A'^T = A_2^T A_1^T)
        noise = self.noise_model
        sensor = self.sensor_model
        return LinearPhysics(
            A=A,
            A_adjoint=A_adjoint,
            noise_model=noise,
            sensor_model=sensor,
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def __add__(self, other):
        r"""
        Stacks two linear forward operators :math:`A = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix}` via the add operation.

        The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects, where
        each entry corresponds to the measurements of the corresponding operator.

        :param deepinv.physics.LinearPhysics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.LinearPhysics) stacked operator

        """
        A = lambda x: TensorList(self.A(x)).append(TensorList(other.A(x)))

        def A_adjoint(y):
            # The last entry of y belongs to `other`. When y has more than two entries,
            # `self` is itself a stacked operator and receives all the leading entries.
            at1 = self.A_adjoint(y[:-1]) if len(y) > 2 else self.A_adjoint(y[0])
            return at1 + other.A_adjoint(y[-1])

        class noise(torch.nn.Module):
            # Applies noise1 to all entries but the last one, and noise2 to the last entry.
            def __init__(self, noise1, noise2):
                super().__init__()
                self.noise1 = noise1
                self.noise2 = noise2

            def forward(self, x):
                return TensorList(self.noise1(x[:-1])).append(self.noise2(x[-1]))

        class sensor(torch.nn.Module):
            # Applies sensor1 to all entries but the last one, and sensor2 to the last entry.
            def __init__(self, sensor1, sensor2):
                super().__init__()
                self.sensor1 = sensor1
                self.sensor2 = sensor2

            def forward(self, x):
                return TensorList(self.sensor1(x[:-1])).append(self.sensor2(x[-1]))

        return LinearPhysics(
            A=A,
            A_adjoint=A_adjoint,
            noise_model=noise(self.noise_model, other.noise_model),
            sensor_model=sensor(self.sensor_model, other.sensor_model),
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def compute_norm(self, x0, max_iter=100, tol=1e-3, verbose=True):
        r"""
        Computes the spectral :math:`\ell_2` norm (Lipschitz constant) of the operator
        :math:`A^{\top}A`, i.e., :math:`\|A^{\top}A\|`,
        using the `power method <https://en.wikipedia.org/wiki/Power_iteration>`_.

        :param torch.Tensor x0: tensor whose shape (and dtype/device) is used for the random
            initialisation point of the algorithm
        :param int max_iter: maximum number of iterations
        :param float tol: the iterations stop once the absolute change of the estimate between two
            consecutive iterations is below ``tol``
        :param bool verbose: print information when the iterations converge

        :return: (torch.Tensor) scalar tensor with the estimated spectral norm of :math:`A^{\top}A`.
        """
        x = torch.randn_like(x0)
        x /= torch.norm(x)
        zold = torch.zeros_like(x)
        for it in range(max_iter):
            y = self.A(x)
            y = self.A_adjoint(y)
            # Rayleigh quotient <x, A^T A x> / <x, x>
            z = torch.matmul(x.reshape(-1), y.reshape(-1)) / torch.norm(x) ** 2

            rel_var = torch.norm(z - zold)
            if rel_var < tol:
                # Stopping must not depend on `verbose`; only the message does.
                if verbose:
                    print(
                        f"Power iteration converged at iteration {it}, value={z.item():.2f}"
                    )
                break
            zold = z
            x = y / torch.norm(y)

        return z

    def adjointness_test(self, u):
        r"""
        Numerically check that :math:`A^{\top}` is indeed the adjoint of :math:`A`.

        Computes :math:`\langle v, Au\rangle - \langle A^{\top}v, u\rangle` for a random :math:`v`.

        :param torch.Tensor u: initialisation point of the adjointness test method

        :return: (torch.Tensor) a scalar that should be theoretically 0. In practice, it should be of the order
            of the chosen dtype precision (i.e. single or double).

        """
        u_in = u  # .type(self.dtype)
        Au = self.A(u_in)

        if isinstance(Au, tuple) or isinstance(Au, list):
            V = [randn_like(au) for au in Au]
            Atv = self.A_adjoint(V)
            s1 = 0
            for au, v in zip(Au, V):
                s1 += (v * au).flatten().sum()

        else:
            v = randn_like(Au)
            Atv = self.A_adjoint(v)

            s1 = (v * Au).flatten().sum()

        s2 = (Atv * u_in).flatten().sum()

        return s1 - s2

    def prox_l2(self, z, y, gamma):
        r"""
        Computes proximal operator of :math:`f(x) = \frac{1}{2}\|Ax-y\|^2`, i.e.,

        .. math::

            \underset{x}{\arg\min} \; \frac{\gamma}{2}\|Ax-y\|^2 + \frac{1}{2}\|x-z\|^2

        by solving :math:`(A^{\top}A + \frac{1}{\gamma}I)x = A^{\top}y + \frac{1}{\gamma}z` with conjugate gradient.

        :param torch.Tensor y: measurements tensor
        :param torch.Tensor z: signal tensor
        :param float gamma: hyperparameter of the proximal operator
        :return: (torch.Tensor) estimated signal tensor

        """
        b = self.A_adjoint(y) + 1 / gamma * z
        H = lambda x: self.A_adjoint(self.A(x)) + 1 / gamma * x
        x = conjugate_gradient(H, b, self.max_iter, self.tol)
        return x

    def A_dagger(self, y):
        r"""
        Computes the solution in :math:`x` to :math:`y = Ax` using the
        `conjugate gradient method <https://en.wikipedia.org/wiki/Conjugate_gradient_method>`_.

        If the size of :math:`y` is larger than :math:`x` (overcomplete problem), it computes :math:`(A^{\top} A)^{-1} A^{\top} y`,
        otherwise (incomplete problem) it computes :math:`A^{\top} (A A^{\top})^{-1} y`.

        This function can be overwritten by a more efficient pseudoinverse in cases where closed form formulas exist.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :return: (torch.Tensor) The reconstructed image :math:`x`.

        """
        Aty = self.A_adjoint(y)

        overcomplete = Aty.flatten().shape[0] < y.flatten().shape[0]

        if not overcomplete:
            A = lambda x: self.A(self.A_adjoint(x))
            b = y
        else:
            A = lambda x: self.A_adjoint(self.A(x))
            b = Aty

        x = conjugate_gradient(A=A, b=b, max_iter=self.max_iter, tol=self.tol)

        if not overcomplete:
            x = self.A_adjoint(x)

        return x
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
class DecomposablePhysics(LinearPhysics):
    r"""
    Parent class for linear operators with SVD decomposition.

    The singular value decomposition is expressed as

    .. math::

        A = U\text{diag}(s)V^{\top} \in \mathbb{R}^{m\times n}

    where :math:`U\in\mathbb{C}^{m\times m}` and :math:`V\in\mathbb{C}^{n\times n}`
    are orthonormal linear transformations and :math:`s` are the singular values (stored in ``mask``).

    :param callable U: orthonormal transformation
    :param callable U_adjoint: transpose of U
    :param callable V: orthonormal transformation
    :param callable V_adjoint: transpose of V
    :param torch.Tensor, float mask: Singular values of the transform. A Python scalar (``int`` or ``float``)
        applies the same singular value to every entry.

    """

    def __init__(
        self,
        U=lambda x: x,
        U_adjoint=lambda x: x,
        V=lambda x: x,
        V_adjoint=lambda x: x,
        mask=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._V = V
        self._U = U
        self._U_adjoint = U_adjoint
        self._V_adjoint = V_adjoint
        self.mask = mask

    def _mask_is_scalar(self):
        # True when the singular values are a single Python number rather than a tensor.
        return isinstance(self.mask, (int, float))

    def A(self, x):
        r"""
        Applies the forward operator :math:`A x = U(s \odot V^{\top} x)`.

        :param torch.Tensor x: signal/image
        :return: (torch.Tensor) clean measurements
        """
        return self.U(self.mask * self.V_adjoint(x))

    def U(self, x):
        return self._U(x)

    def V(self, x):
        return self._V(x)

    def U_adjoint(self, x):
        return self._U_adjoint(x)

    def V_adjoint(self, x):
        return self._V_adjoint(x)

    def A_adjoint(self, y):
        r"""
        Applies the adjoint operator :math:`A^{\top} y = V(\bar{s} \odot U^{\top} y)`.

        :param torch.Tensor y: measurements
        :return: (torch.Tensor) linear reconstruction
        """
        if self._mask_is_scalar():
            mask = self.mask
        else:
            mask = torch.conj(self.mask)

        return self.V(mask * self.U_adjoint(y))

    def prox_l2(self, z, y, gamma):
        r"""
        Computes the proximal operator of :math:`f(x)=\frac{1}{2}\|Ax-y\|^2` with parameter :math:`\gamma`, i.e.,

        .. math::

            \underset{x}{\arg\min} \; \frac{\gamma}{2}\|Ax-y\|^2 + \frac{1}{2}\|x-z\|^2

        in closed form, leveraging the singular value decomposition.

        :param torch.Tensor y: measurements tensor
        :param torch.Tensor, float z: signal tensor
        :param float gamma: hyperparameter :math:`\gamma` of the proximal operator
        :return: (torch.Tensor) estimated signal tensor

        """
        b = self.A_adjoint(y) + 1 / gamma * z
        if self._mask_is_scalar():
            scaling = self.mask**2 + 1 / gamma
        else:
            scaling = torch.conj(self.mask) * self.mask + 1 / gamma
        x = self.V(self.V_adjoint(b) / scaling)
        return x

    def A_dagger(self, y):
        r"""
        Computes :math:`A^{\dagger}y = x` in an efficient manner leveraging the singular value decomposition.

        Singular values with magnitude at most ``1e-5`` are treated as zero (their inverse is set to 0).

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :return: (torch.Tensor) The reconstructed image :math:`x`.

        """
        # Avoid division by (near-)zero singular values. The magnitude is used so that
        # complex-valued masks, for which `>` is not defined, are supported as well.
        if self._mask_is_scalar():
            mask = 1 / self.mask if abs(self.mask) > 1e-5 else 0.0
        else:
            nonzero = torch.abs(self.mask) > 1e-5
            mask = torch.zeros_like(self.mask)
            mask[nonzero] = 1 / self.mask[nonzero]

        return self.V(self.U_adjoint(y) * mask)
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
class Denoising(DecomposablePhysics):
    r"""

    Forward operator for denoising problems.

    The linear operator is just the identity mapping :math:`A(x)=x`

    :param torch.nn.Module noise: noise distribution, e.g., ``deepinv.physics.GaussianNoise``, or a user-defined
        torch.nn.Module. If omitted, each instance gets its own ``GaussianNoise(sigma=0.1)``; if ``None``,
        ``GaussianNoise(sigma=0.0)`` is used.
    """

    # Sentinel meaning "noise argument not given". A module instance as default value would be
    # built once at class-definition time and shared by every instance, so e.g. `reset()` or
    # `.to(device)` on one operator would silently affect all the others.
    _DEFAULT_NOISE = object()

    def __init__(self, noise=_DEFAULT_NOISE, **kwargs):
        super().__init__(**kwargs)
        if noise is self._DEFAULT_NOISE:
            noise = GaussianNoise(sigma=0.1)
        elif noise is None:
            noise = GaussianNoise(sigma=0.0)
        self.noise_model = noise
|
deepinv/physics/haze.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from deepinv.physics.forward import Physics
|
|
3
|
+
from deepinv.utils import TensorList
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Haze(Physics):
    r"""
    Standard haze model

    The operator is defined as https://ieeexplore.ieee.org/abstract/document/5567108/

    .. math::

        y = t \odot I + a (1-t)

    where :math:`t = \exp(-\beta (d + o))` is the medium transmission, :math:`I` is the intensity (possibly RGB) image,
    :math:`\odot` denotes element-wise multiplication, :math:`a>0` is the atmospheric light,
    :math:`d` is the scene depth, and :math:`\beta>0` and :math:`o` are constants.

    This is a non-linear inverse problem, whose unknown parameters are :math:`I`, :math:`d`, :math:`a`.

    :param float beta: constant :math:`\beta>0`
    :param float offset: constant :math:`o`, added to the depth before multiplication by :math:`\beta`.

    """

    def __init__(self, beta=0.1, offset=0, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta
        self.offset = offset

    def A(self, x):
        r"""
        Applies the haze model :math:`y = t \odot I + a (1-t)` with :math:`t = \exp(-\beta (d + o))`.

        :param list, tuple x: The input x should be a tuple/list such that x[0] = image torch.tensor :math:`I`,
            x[1] = depth torch.tensor :math:`d`, x[2] = scalar or torch.tensor of one element :math:`a`.
        :return: (torch.tensor) hazy image.

        """
        im = x[0]
        d = x[1]
        a = x[2]  # atmospheric light (named `a` to avoid shadowing the method `A`)

        t = torch.exp(-self.beta * (d + self.offset))
        y = t * im + (1 - t) * a
        return y

    def A_dagger(self, y):
        r"""

        Returns the trivial inverse where x[0] = y (trivial estimate of the image :math:`I`),
        x[1] = tensor of depth :math:`d` equal to one, x[2] = 1 for :math:`a`.

        .. note::

            This trivial inverse can be useful for some reconstruction networks, such as ``deepinv.models.ArtifactRemoval``.

        :param torch.tensor y: Hazy image, unpacked as a 4-D tensor of shape (B, C, H, W).
        :return: (deepinv.utils.TensorList) trivial inverse, with tensors on the device and dtype of ``y``.

        """
        b, c, h, w = y.shape
        # Match y's dtype so that A(A_dagger(y)) keeps the precision of the input.
        d = torch.ones((b, 1, h, w), device=y.device, dtype=y.dtype)
        a = torch.ones(1, device=y.device, dtype=y.dtype)
        return TensorList([y, d, a])
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from deepinv.physics.forward import DecomposablePhysics
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Inpainting(DecomposablePhysics):
    r"""

    Inpainting forward operator: only a subset of the entries is observed.

    The operator is the diagonal matrix

    .. math::

        A = \text{diag}(m) \in \mathbb{R}^{n\times n}

    where :math:`m` is a binary mask with :math:`n` entries.

    Being linear and diagonal, this operator has a trivial SVD decomposition, so its pseudo-inverse
    and proximal operator are cheap to compute.

    Since the mask is stored as a (non-trainable) parameter, a saved operator can be restored with
    ``self.load_state_dict(torch.load(save_path))``, exactly as for any ``torch.nn.Module``.

    :param tuple tensor_size: size of the input images, e.g., (C, H, W).
    :param torch.tensor, float mask: If a float, each entry (each pixel when ``pixelwise`` is True) is kept
        independently with probability ``mask``. If a ``torch.tensor`` matching ``tensor_size``, it is used
        directly as the mask (``device`` is not applied to it).
    :param torch.device device: gpu or cpu, used when the mask is randomly generated.
    :param bool pixelwise: Apply the mask in a pixelwise fashion, i.e., zero all channels in a given pixel
        simultaneously. In this mode ``tensor_size`` is expected to be (C, H, W).

    """

    def __init__(self, tensor_size, mask=0.3, pixelwise=True, device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.tensor_size = tensor_size

        if isinstance(mask, torch.Tensor):
            # user-supplied mask, kept as-is
            keep = mask
        else:
            # random mask: an entry is dropped when its uniform draw exceeds the keep rate
            keep = torch.ones(tensor_size, device=device)
            draws = torch.rand_like(keep)
            if pixelwise:
                # decide per pixel from the first channel's draws, zero every channel there
                keep[:, draws[0, :, :] > mask] = 0
            else:
                keep[draws > mask] = 0

        # Leading batch dimension; requires_grad=False so the mask is saved but never trained.
        self.mask = torch.nn.Parameter(keep.unsqueeze(0), requires_grad=False)
|