deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,547 @@
1
+ import torch
2
+ from deepinv.optim.utils import conjugate_gradient
3
+ from .noise import GaussianNoise
4
+ from deepinv.utils import randn_like, TensorList
5
+
6
+
7
class Physics(torch.nn.Module):  # parent class for forward models
    r"""
    Parent class for forward operators

    It describes the general forward measurement process

    .. math::

        y = N(A(x))

    where :math:`x` is an image of :math:`n` pixels, :math:`y` is the measurements of size :math:`m`,
    :math:`A:\xset\mapsto \yset` is a deterministic mapping capturing the physics of the acquisition
    and :math:`N:\yset\mapsto \yset` is a stochastic mapping which characterizes the noise affecting
    the measurements.

    :param callable A: forward operator function which maps an image to the observed measurements :math:`x\mapsto y`.
    :param callable noise_model: function that adds noise to the measurements :math:`N(z)`.
        See the noise module for some predefined functions.
    :param callable sensor_model: function that incorporates any sensor non-linearities to the sensing process,
        such as quantization or saturation, defined as a function :math:`\eta(z)`, such that
        :math:`y=\eta\left(N(A(x))\right)`. By default, the sensor_model is set to the identity :math:`\eta(z)=z`.
    :param int max_iter: If the operator does not have a closed form pseudoinverse, the gradient descent algorithm
        is used for computing it, and this parameter fixes the maximum number of gradient descent iterations.
    :param float tol: If the operator does not have a closed form pseudoinverse, the gradient descent algorithm
        is used for computing it, and this parameter fixes the absolute tolerance of the gradient descent algorithm.

    """

    def __init__(
        self,
        A=lambda x: x,
        noise_model=lambda x: x,
        sensor_model=lambda x: x,
        max_iter=50,
        tol=1e-3,
    ):
        super().__init__()
        self.noise_model = noise_model
        self.sensor_model = sensor_model
        self.forw = A
        self.SVD = False  # flag indicating SVD available
        self.max_iter = max_iter
        self.tol = tol

    def __mul__(self, other):  # physics3 = physics1 \circ physics2
        r"""
        Concatenates two forward operators :math:`A = A_1\circ A_2` via the mul operation

        The resulting operator keeps the noise and sensor models of :math:`A_1`.

        :param deepinv.physics.Physics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.Physics) concatenated operator

        """
        A = lambda x: self.A(other.A(x))  # (A' = A_1 A_2)
        noise = self.noise_model
        sensor = self.sensor_model
        return Physics(
            A=A,
            noise_model=noise,
            sensor_model=sensor,
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def __add__(self, other):
        r"""
        Stacks two linear forward operators :math:`A(x) = \begin{bmatrix} A_1(x) \\ A_2(x) \end{bmatrix}`
        via the add operation.

        The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects, where
        each entry corresponds to the measurements of the corresponding operator.

        :param deepinv.physics.Physics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.Physics) stacked operator

        """
        # Measurements of both operators are concatenated into one TensorList.
        A = lambda x: TensorList(self.A(x)).append(TensorList(other.A(x)))

        # Wrapper module applying noise1 to the leading entries (operator 1's
        # measurements) and noise2 to the last entry (operator 2's measurements).
        class noise(torch.nn.Module):
            def __init__(self, noise1, noise2):
                super().__init__()
                self.noise1 = noise1
                self.noise2 = noise2

            def forward(self, x):
                return TensorList(self.noise1(x[:-1])).append(self.noise2(x[-1]))

        # Same split as `noise` above, but for the sensor non-linearities.
        class sensor(torch.nn.Module):
            def __init__(self, sensor1, sensor2):
                super().__init__()
                self.sensor1 = sensor1
                self.sensor2 = sensor2

            def forward(self, x):
                return TensorList(self.sensor1(x[:-1])).append(self.sensor2(x[-1]))

        return Physics(
            A=A,
            noise_model=noise(self.noise_model, other.noise_model),
            sensor_model=sensor(self.sensor_model, other.sensor_model),
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def reset(self, **kwargs):
        # Re-runs the noise model's __init__ with the given kwargs, e.g. to
        # resample/replace its parameters. NOTE(review): calling __init__ on a
        # live module is unusual — presumably relies on the noise model's
        # __init__ being safe to call twice; confirm for custom noise models.
        if isinstance(self.noise_model, torch.nn.Module):
            self.noise_model.__init__(**kwargs)

    def forward(self, x):
        r"""
        Computes forward operator :math:`y = N(A(x))` (with noise and/or sensor non-linearities)

        :param torch.Tensor,list[torch.Tensor] x: signal/image
        :return: (torch.Tensor) noisy measurements

        """
        return self.sensor(self.noise(self.A(x)))

    def A(self, x):
        r"""
        Computes forward operator :math:`y = A(x)` (without noise and/or sensor non-linearities)

        :param torch.Tensor,list[torch.Tensor] x: signal/image
        :return: (torch.Tensor) clean measurements

        """
        return self.forw(x)

    def sensor(self, x):
        r"""
        Computes sensor non-linearities :math:`y = \eta(y)`

        :param torch.Tensor,list[torch.Tensor] x: signal/image
        :return: (torch.Tensor) clean measurements
        """
        return self.sensor_model(x)

    def noise(self, x):
        r"""
        Incorporates noise into the measurements :math:`\tilde{y} = N(y)`

        :param torch.Tensor x: clean measurements
        :return torch.Tensor: noisy measurements

        """
        return self.noise_model(x)

    def A_dagger(self, y, x_init=None):
        r"""
        Computes an inverse of :math:`y = Ax` via gradient descent.

        This function can be overwritten by a more efficient pseudoinverse in cases where closed form formulas exist.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :param torch.Tensor x_init: initial guess for the reconstruction.
        :return: (torch.Tensor) The reconstructed image :math:`x`.

        """

        # NOTE(review): the base Physics class does not define A_adjoint; this
        # default initialization only works for subclasses that provide one
        # (e.g. LinearPhysics) or when x_init is given explicitly — confirm.
        if x_init is None:
            x_init = self.A_adjoint(y)

        # Optimize x directly to minimize ||A(x) - y||^2.
        x = torch.nn.Parameter(x_init, requires_grad=True)

        optimizer = torch.optim.SGD([x], lr=1e-1)
        loss = torch.nn.MSELoss()
        for i in range(self.max_iter):
            err = loss(self.A(x), y)
            optimizer.zero_grad()
            # retain_graph=True: the graph through A may be reused across steps.
            err.backward(retain_graph=True)
            optimizer.step()
            if err < self.tol:
                break

        return x.clone()
183
+
184
+
185
class LinearPhysics(Physics):
    r"""
    Parent class for linear operators.

    It describes the linear forward measurement process of the form

    .. math::

        y = N(A(x))

    where :math:`x` is an image of :math:`n` pixels, :math:`y` is the measurements of size :math:`m`,
    :math:`A:\xset\mapsto \yset` is a deterministic linear mapping capturing the physics of the acquisition
    and :math:`N:\yset\mapsto \yset` is a stochastic mapping which characterizes the noise affecting
    the measurements.

    :param callable A: forward operator function which maps an image to the observed measurements :math:`x\mapsto y`.
        It is recommended to normalize it to have unit norm.
    :param callable A_adjoint: transpose of the forward operator, which should verify the adjointness test.
    :param callable noise_model: function that adds noise to the measurements :math:`N(z)`.
        See the noise module for some predefined functions.
    :param callable sensor_model: function that incorporates any sensor non-linearities to the sensing process,
        such as quantization or saturation, defined as a function :math:`\eta(z)`, such that
        :math:`y=\eta\left(N(A(x))\right)`. By default, the sensor_model is set to the identity :math:`\eta(z)=z`.
    :param int max_iter: If the operator does not have a closed form pseudoinverse, the conjugate gradient algorithm
        is used for computing it, and this parameter fixes the maximum number of conjugate gradient iterations.
    :param float tol: If the operator does not have a closed form pseudoinverse, the conjugate gradient algorithm
        is used for computing it, and this parameter fixes the absolute tolerance of the conjugate gradient algorithm.

    """

    def __init__(
        self,
        A=lambda x: x,
        A_adjoint=lambda x: x,
        noise_model=lambda x: x,
        sensor_model=lambda x: x,
        max_iter=50,
        tol=1e-3,
        **kwargs,
    ):
        super().__init__(
            A=A,
            noise_model=noise_model,
            sensor_model=sensor_model,
            max_iter=max_iter,
            tol=tol,
        )

        self.adjoint = A_adjoint

    def A_adjoint(self, y):
        r"""
        Computes transpose of the forward operator :math:`\tilde{x} = A^{\top}y`.
        If :math:`A` is linear, it should be the exact transpose of the forward matrix.

        .. note:

            If problem is non-linear, there is not a well-defined transpose operation,
            but defining one can be useful for some reconstruction networks, such as ``deepinv.models.ArtifactRemoval``.

        :param torch.Tensor y: measurements.
        :return: (torch.Tensor) linear reconstruction :math:`\tilde{x} = A^{\top}y`.

        """
        return self.adjoint(y)

    def __mul__(self, other):
        r"""
        Concatenates two linear forward operators :math:`A = A_1\circ A_2` via the * operation

        The resulting linear operator keeps the noise and sensor models of :math:`A_1`.

        :param deepinv.physics.LinearPhysics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.LinearPhysics) concatenated operator

        """
        A = lambda x: self.A(other.A(x))  # (A' = A_1 A_2)
        # Adjoint of a composition reverses the order: (A_1 A_2)^T = A_2^T A_1^T.
        A_adjoint = lambda x: other.A_adjoint(self.A_adjoint(x))
        noise = self.noise_model
        sensor = self.sensor_model
        return LinearPhysics(
            A=A,
            A_adjoint=A_adjoint,
            noise_model=noise,
            sensor_model=sensor,
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def __add__(self, other):
        r"""
        Stacks two linear forward operators :math:`A = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix}` via the add operation.

        The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects, where
        each entry corresponds to the measurements of the corresponding operator.

        :param deepinv.physics.LinearPhysics other: Physics operator :math:`A_2`
        :return: (deepinv.physics.LinearPhysics) stacked operator

        """
        A = lambda x: TensorList(self.A(x)).append(TensorList(other.A(x)))

        # Adjoint of the stacked operator: A^T y = A_1^T y_1 + A_2^T y_2, where
        # the last entry of y belongs to operator 2 and the rest to operator 1.
        def A_adjoint(y):
            at1 = self.A_adjoint(y[:-1]) if len(y) > 2 else self.A_adjoint(y[0])
            return at1 + other.A_adjoint(y[-1])

        # Apply noise1 to operator 1's entries and noise2 to operator 2's entry.
        class noise(torch.nn.Module):
            def __init__(self, noise1, noise2):
                super().__init__()
                self.noise1 = noise1
                self.noise2 = noise2

            def forward(self, x):
                return TensorList(self.noise1(x[:-1])).append(self.noise2(x[-1]))

        # Same split as `noise`, but for the sensor non-linearities.
        class sensor(torch.nn.Module):
            def __init__(self, sensor1, sensor2):
                super().__init__()
                self.sensor1 = sensor1
                self.sensor2 = sensor2

            def forward(self, x):
                return TensorList(self.sensor1(x[:-1])).append(self.sensor2(x[-1]))

        return LinearPhysics(
            A=A,
            A_adjoint=A_adjoint,
            noise_model=noise(self.noise_model, other.noise_model),
            sensor_model=sensor(self.sensor_model, other.sensor_model),
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def compute_norm(self, x0, max_iter=100, tol=1e-3, verbose=True):
        r"""
        Computes the spectral :math:`\ell_2` norm (Lipschitz constant) of the operator

        :math:`A^{\top}A`, i.e., :math:`\|A^{\top}A\|`.

        using the `power method <https://en.wikipedia.org/wiki/Power_iteration>`_.

        :param torch.Tensor x0: initialisation point of the algorithm
        :param int max_iter: maximum number of iterations
        :param float tol: relative variation criterion for convergence
        :param bool verbose: print information

        :returns z: (float) spectral norm of :math:`A^{\top}A`, i.e., :math:`\|A^{\top}A\|`.
        """
        x = torch.randn_like(x0)
        x /= torch.norm(x)
        zold = torch.zeros_like(x)
        for it in range(max_iter):
            y = self.A(x)
            y = self.A_adjoint(y)
            # Rayleigh quotient estimate of the top eigenvalue of A^T A.
            z = torch.matmul(x.reshape(-1), y.reshape(-1)) / torch.norm(x) ** 2

            # NOTE(review): zold has the shape of x, so (z - zold) broadcasts and
            # this norm scales the variation by sqrt(numel(x)); kept as-is to
            # preserve the effective tolerance — confirm whether intended.
            rel_var = torch.norm(z - zold)
            # BUG FIX: the convergence break was previously guarded by
            # `rel_var < tol and verbose`, so with verbose=False the iteration
            # never stopped early and always ran for max_iter iterations.
            # Convergence must not depend on the verbosity flag.
            if rel_var < tol:
                if verbose:
                    print(
                        f"Power iteration converged at iteration {it}, value={z.item():.2f}"
                    )
                break
            zold = z
            x = y / torch.norm(y)

        return z

    def adjointness_test(self, u):
        r"""
        Numerically check that :math:`A^{\top}` is indeed the adjoint of :math:`A`.

        The test computes :math:`\langle Au, v\rangle - \langle u, A^{\top}v\rangle`
        for a random :math:`v`, which is exactly 0 for a true adjoint pair.

        :param torch.Tensor u: initialisation point of the adjointness test method

        :return: (float) a quantity that should be theoretically 0. In practice, it should be of the order of the chosen dtype precision (i.e. single or double).

        """
        u_in = u  # .type(self.dtype)
        Au = self.A(u_in)

        if isinstance(Au, tuple) or isinstance(Au, list):
            # Stacked operators: accumulate <Au_i, v_i> over every output entry.
            V = [randn_like(au) for au in Au]
            Atv = self.A_adjoint(V)
            s1 = 0
            for au, v in zip(Au, V):
                s1 += (v * au).flatten().sum()

        else:
            v = randn_like(Au)
            Atv = self.A_adjoint(v)

            s1 = (v * Au).flatten().sum()

        s2 = (Atv * u_in).flatten().sum()

        return s1 - s2

    def prox_l2(self, z, y, gamma):
        r"""
        Computes proximal operator of :math:`f(x) = \frac{1}{2}\|Ax-y\|^2`, i.e.,

        .. math::

            \underset{x}{\arg\min} \; \frac{\gamma}{2}\|Ax-y\|^2 + \frac{1}{2}\|x-z\|^2

        :param torch.Tensor y: measurements tensor
        :param torch.Tensor z: signal tensor
        :param float gamma: hyperparameter of the proximal operator
        :return: (torch.Tensor) estimated signal tensor

        """
        # Normal equations (A^T A + I/gamma) x = A^T y + z/gamma, solved by CG.
        b = self.A_adjoint(y) + 1 / gamma * z
        H = lambda x: self.A_adjoint(self.A(x)) + 1 / gamma * x
        x = conjugate_gradient(H, b, self.max_iter, self.tol)
        return x

    def A_dagger(self, y):
        r"""
        Computes the solution in :math:`x` to :math:`y = Ax` using the
        `conjugate gradient method <https://en.wikipedia.org/wiki/Conjugate_gradient_method>`_.

        If the size of :math:`y` is larger than :math:`x` (overcomplete problem), it computes :math:`(A^{\top} A)^{-1} A^{\top} y`,
        otherwise (incomplete problem) it computes :math:`A^{\top} (A A^{\top})^{-1} y`.

        This function can be overwritten by a more efficient pseudoinverse in cases where closed form formulas exist.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :return: (torch.Tensor) The reconstructed image :math:`x`.

        """
        Aty = self.A_adjoint(y)

        overcomplete = Aty.flatten().shape[0] < y.flatten().shape[0]

        if not overcomplete:
            # Incomplete: solve (A A^T) u = y, then x = A^T u.
            A = lambda x: self.A(self.A_adjoint(x))
            b = y
        else:
            # Overcomplete: solve (A^T A) x = A^T y directly.
            A = lambda x: self.A_adjoint(self.A(x))
            b = Aty

        x = conjugate_gradient(A=A, b=b, max_iter=self.max_iter, tol=self.tol)

        if not overcomplete:
            x = self.A_adjoint(x)

        return x
431
+
432
+
433
class DecomposablePhysics(LinearPhysics):
    r"""
    Parent class for linear operators with SVD decomposition.


    The singular value decomposition is expressed as

    .. math::

        A = U\text{diag}(s)V^{\top} \in \mathbb{R}^{m\times n}

    where :math:`U\in\mathbb{C}^{n\times n}` and :math:`V\in\mathbb{C}^{m\times m}`
    are orthonormal linear transformations and :math:`s\in\mathbb{R}_{+}^{n}` are the singular values.

    :param callable U: orthonormal transformation
    :param callable U_adjoint: transpose of U
    :param callable V: orthonormal transformation
    :param callable V_adjoint: transpose of V
    :param torch.Tensor, float mask: Singular values of the transform

    """

    def __init__(
        self,
        U=lambda x: x,
        U_adjoint=lambda x: x,
        V=lambda x: x,
        V_adjoint=lambda x: x,
        mask=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._V = V
        self._U = U
        self._U_adjoint = U_adjoint
        self._V_adjoint = V_adjoint
        self.mask = mask

    def A(self, x):
        # A x = U diag(s) V^T x
        return self.U(self.mask * self.V_adjoint(x))

    def U(self, x):
        return self._U(x)

    def V(self, x):
        # BUG FIX: this previously returned self._U(x), applying the wrong
        # orthonormal transform whenever U and V differ (A_adjoint, prox_l2
        # and A_dagger all rely on V).
        return self._V(x)

    def U_adjoint(self, x):
        return self._U_adjoint(x)

    def V_adjoint(self, x):
        return self._V_adjoint(x)

    def A_adjoint(self, y):
        r"""
        Computes the adjoint :math:`A^{\top}y = V \text{diag}(\bar{s}) U^{\top} y`
        using the SVD factors.

        :param torch.Tensor y: measurements.
        :return: (torch.Tensor) adjoint applied to the measurements.
        """
        if isinstance(self.mask, float):
            mask = self.mask
        else:
            # Complex singular values must be conjugated for the adjoint.
            mask = torch.conj(self.mask)

        return self.V(mask * self.U_adjoint(y))

    def prox_l2(self, z, y, gamma):
        r"""
        Computes proximal operator of :math:`f(x)=\frac{\gamma}{2}\|Ax-y\|^2`
        in an efficient manner leveraging the singular vector decomposition.

        :param torch.Tensor y: measurements tensor
        :param torch.Tensor, float z: signal tensor
        :param float gamma: hyperparameter :math:`\gamma` of the proximal operator
        :return: (torch.Tensor) estimated signal tensor

        """
        b = self.A_adjoint(y) + 1 / gamma * z
        # In the SVD basis, (A^T A + I/gamma) is diagonal with entries |s|^2 + 1/gamma.
        if isinstance(self.mask, float):
            scaling = self.mask**2 + 1 / gamma
        else:
            scaling = torch.conj(self.mask) * self.mask + 1 / gamma
        x = self.V(self.V_adjoint(b) / scaling)
        return x

    def A_dagger(self, y):
        r"""
        Computes :math:`A^{\dagger}y = x` in an efficient manner leveraging the singular vector decomposition.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :return: (torch.Tensor) The reconstructed image :math:`x`.

        """

        # avoid division by singular value = 0

        if not isinstance(self.mask, float):
            mask = torch.zeros_like(self.mask)
            # Pseudoinverse: invert only singular values above the threshold,
            # leaving (near-)zero singular values at zero.
            mask[self.mask > 1e-5] = 1 / self.mask[self.mask > 1e-5]
        else:
            mask = 1 / self.mask

        return self.V(self.U_adjoint(y) * mask)
531
+
532
+
533
class Denoising(DecomposablePhysics):
    r"""

    Forward operator for denoising problems.

    The linear operator is just the identity mapping :math:`A(x)=x`

    :param torch.nn.Module noise: noise distribution, e.g., ``deepinv.physics.GaussianNoise``, or a user-defined torch.nn.Module.
        Defaults to Gaussian noise with ``sigma=0.1``; passing ``None`` explicitly
        yields noiseless measurements (``sigma=0.0``).
    """

    # Sentinel for "argument not provided". BUG FIX: the default was previously
    # `noise=GaussianNoise(sigma=0.1)`, a single module instance created at class
    # definition time and SHARED by every Denoising built with the default
    # (Physics.reset() mutates the noise model in place, so state leaked across
    # instances). A fresh instance is now created per object.
    _UNSET = object()

    def __init__(self, noise=_UNSET, **kwargs):
        super().__init__(**kwargs)
        if noise is self._UNSET:
            # Default: fresh noise model per instance (same sigma as before).
            noise = GaussianNoise(sigma=0.1)
        elif noise is None:
            # Preserved legacy behavior: explicit None means no noise.
            noise = GaussianNoise(sigma=0.0)
        self.noise_model = noise
@@ -0,0 +1,65 @@
1
+ import torch
2
+ from deepinv.physics.forward import Physics
3
+ from deepinv.utils import TensorList
4
+
5
+
6
class Haze(Physics):
    r"""
    Standard haze model

    The operator is defined as https://ieeexplore.ieee.org/abstract/document/5567108/

    .. math::

        y = t \odot I + a (1-t)

    where :math:`t = \exp(-\beta d - o)` is the medium transmission, :math:`I` is the intensity (possibly RGB) image,
    :math:`\odot` denotes element-wise multiplication, :math:`a>0` is the atmospheric light,
    :math:`d` is the scene depth, and :math:`\beta>0` and :math:`o` are constants.

    This is a non-linear inverse problems, whose unknown parameters are :math:`I`, :math:`d`, :math:`a`.

    :param float beta: constant :math:`\beta>0`
    :param float offset: constant :math:`o`

    """

    def __init__(self, beta=0.1, offset=0, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta
        self.offset = offset

    def A(self, x):
        r"""
        Applies the haze model to a clean image, depth map and atmospheric light.

        :param list, tuple x: The input x should be a tuple/list such that x[0] = image torch.tensor :math:`I`,
            x[1] = depth torch.tensor :math:`d`, x[2] = scalar or torch.tensor of one element :math:`a`.
        :return: (torch.tensor) hazy image.

        """
        image, depth, airlight = x[0], x[1], x[2]

        # Medium transmission t = exp(-beta * (d + o)); the hazy image blends
        # the clean image with the atmospheric light according to t.
        transmission = torch.exp(-self.beta * (depth + self.offset))
        return transmission * image + (1 - transmission) * airlight

    def A_dagger(self, y):
        r"""

        Returns the trivial inverse where x[0] = y (trivial estimate of the image :math:`I`),
        x[1] = tensor of depth :math:`d` equal to one, x[2] = 1 for :math:`a`.

        .. note:

            This trivial inverse can be useful for some reconstruction networks, such as ``deepinv.models.ArtifactRemoval``.


        :param torch.tensor y: Hazy image.
        :return: (deepinv.utils.ListTensor) trivial inverse.

        """
        batch, _, height, width = y.shape
        # Unit depth map (single channel) and unit atmospheric light on the
        # same device as the measurements.
        unit_depth = torch.ones((batch, 1, height, width), device=y.device)
        unit_airlight = torch.ones(1, device=y.device)
        return TensorList([y, unit_depth, unit_airlight])
@@ -0,0 +1,48 @@
1
+ from deepinv.physics.forward import DecomposablePhysics
2
+ import torch
3
+
4
+
5
class Inpainting(DecomposablePhysics):
    r"""

    Inpainting forward operator, keeps a subset of entries.

    The operator is described by the diagonal matrix

    .. math::

        A = \text{diag}(m) \in \mathbb{R}^{n\times n}

    where :math:`m` is a binary mask with n entries.

    This operator is linear and has a trivial SVD decomposition, which allows for fast computation
    of the pseudo-inverse and proximal operator.

    An existing operator can be loaded from a saved ``.pth`` file via ``self.load_state_dict(save_path)``,
    in a similar fashion to ``torch.nn.Module``.

    :param tuple tensor_size: size of the input images, e.g., (C, H, W).
    :param torch.tensor, float mask: If the input is a float, the entries of the mask will be sampled from a bernoulli
        distribution with probability equal to ``mask``. If the input is a ``torch.tensor`` matching tensor_size,
        the mask will be set to this tensor.
    :param torch.device device: gpu or cpu
    :param bool pixelwise: Apply the mask in a pixelwise fashion, i.e., zero all channels in a given pixel simultaneously.

    """

    def __init__(self, tensor_size, mask=0.3, pixelwise=True, device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.tensor_size = tensor_size

        if isinstance(mask, torch.Tensor):
            # User supplied the mask directly; use it verbatim.
            self.mask = mask
        else:
            # Otherwise `mask` is a keep-probability: draw one uniform sample
            # per entry and zero out entries whose sample exceeds it.
            keep_prob = mask
            self.mask = torch.ones(tensor_size, device=device)
            samples = torch.rand_like(self.mask)
            if pixelwise:
                # One decision per pixel, applied across all channels
                # (uses the first channel's samples).
                self.mask[:, samples[0, :, :] > keep_prob] = 0
            else:
                # Independent decision per entry.
                self.mask[samples > keep_prob] = 0

        # Register as a non-trainable parameter with a leading batch dimension,
        # so it is saved/loaded with state_dict.
        self.mask = torch.nn.Parameter(self.mask.unsqueeze(0), requires_grad=False)