deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,544 @@
1
+ from torchvision.transforms.functional import rotate
2
+ import torchvision
3
+ import torch.nn.functional as F
4
+ import torch
5
+ import numpy as np
6
+ import torch.fft as fft
7
+ from deepinv.physics.forward import Physics, LinearPhysics, DecomposablePhysics
8
+ from deepinv.utils import TensorList
9
+
10
+
11
def filter_fft(filter, img_size, real_fft=True):
    r"""
    Compute the 2D Fourier transform of a filter zero-padded to the image size.

    The filter is copied into the top-left corner of a zero tensor whose last two dimensions are
    ``img_size[-2:]``, then circularly shifted by ``-int((k - 1) / 2)`` along each spatial axis
    (``k`` being the filter size along that axis) before the transform is taken.

    :param torch.Tensor filter: filter of size (B, C, h, w).
    :param tuple, list, torch.Size img_size: image size; only the last two entries are used.
    :param bool real_fft: if ``True`` returns ``torch.fft.rfft2`` of the padded filter, otherwise
        ``torch.fft.fft2``.
    :return: (torch.Tensor) complex tensor containing the transform of the padded filter.
    """
    kernel_h, kernel_w = filter.shape[2], filter.shape[3]
    ph = int((kernel_h - 1) / 2)
    pw = int((kernel_w - 1) / 2)

    # Normalise both halves to plain tuples so that a list ``img_size`` can be concatenated too.
    padded_shape = tuple(filter.shape[:2]) + tuple(img_size[-2:])
    filt2 = torch.zeros(padded_shape, device=filter.device)

    filt2[:, : filter.shape[1], :kernel_h, :kernel_w] = filter
    filt2 = torch.roll(filt2, shifts=(-ph, -pw), dims=(2, 3))

    return fft.rfft2(filt2) if real_fft else fft.fft2(filt2)
21
+
22
+
23
def gaussian_blur(sigma=(1, 1), angle=0):
    r"""
    Build a normalised (optionally rotated) Gaussian blur filter.

    The support is ``2 * c + 1`` pixels per side with ``c = int(max(sigma) / 0.3 + 1)``.
    The kernel is rotated with bilinear interpolation and then rescaled so that its entries sum to one.

    :param float, tuple[float] sigma: standard deviation of the gaussian filter. A scalar gives an
        isotropic filter; a pair gives an anisotropic one, where ``sigma[0]`` scales the first
        (row) axis of the kernel and ``sigma[1]`` the second (column) axis.
    :param float angle: rotation angle of the filter in degrees (only useful for anisotropic filters)
    :return: (torch.Tensor) filter of size (1, 1, k, k).
    """
    if isinstance(sigma, (int, float)):
        sigma = (sigma, sigma)

    half_width = int(max(sigma) / 0.3 + 1)
    taps = torch.arange(2 * half_width + 1)

    rows, cols = torch.meshgrid(taps, taps, indexing="ij")
    rows = rows - half_width
    cols = cols - half_width

    exponent = (rows / sigma[0]).pow(2)
    exponent += (cols / sigma[1]).pow(2)
    kernel = torch.exp(-exponent / 2.0)

    # torchvision's rotate expects a batched image, hence the temporary (1, 1, k, k) view.
    rotated = rotate(
        kernel.unsqueeze(0).unsqueeze(0),
        angle,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    )
    kernel = rotated.squeeze(0).squeeze(0)

    kernel = kernel / kernel.flatten().sum()

    return kernel.unsqueeze(0).unsqueeze(0)
60
+
61
+
62
def bilinear_filter(factor=2):
    r"""
    Build a normalised separable triangular (bilinear) filter.

    The 1D profile is ``1 - |t|`` sampled at ``2 * factor`` half-pixel offsets; the 2D filter is
    its outer product with itself, rescaled so that the entries sum to one.

    :param int factor: downsampling factor the filter is designed for.
    :return: (torch.Tensor) filter of size (1, 1, 2 * factor, 2 * factor).
    """
    offsets = np.arange(start=-factor + 0.5, stop=factor, step=1) / factor
    profile = 1 - np.abs(offsets)
    kernel = np.outer(profile, profile)
    kernel = kernel / kernel.sum()
    return torch.Tensor(kernel)[None, None]
68
+
69
+
70
def bicubic_filter(factor=2):
    r"""
    Build a normalised separable bicubic filter (cubic convolution kernel with ``a = -0.5``).

    The 1D profile is the piecewise cubic evaluated at ``4 * factor`` half-pixel offsets; the 2D
    filter is its outer product with itself, rescaled so that the entries sum to one.

    :param int factor: downsampling factor the filter is designed for.
    :return: (torch.Tensor) filter of size (1, 1, 4 * factor, 4 * factor).
    """
    a = -0.5
    t = np.abs(np.arange(start=-2 * factor + 0.5, stop=2 * factor, step=1) / factor)

    # Piece used for |t| <= 1.
    near = ((a + 2) * np.power(t, 3) - (a + 3) * np.power(t, 2) + 1) * (t <= 1)
    # Piece used for 1 < |t| < 2.
    far = (
        (a * np.power(t, 3) - 5 * a * np.power(t, 2) + 8 * a * t - 4 * a)
        * (t > 1)
        * (t < 2)
    )
    profile = near + far

    kernel = np.outer(profile, profile)
    kernel = kernel / kernel.sum()
    return torch.Tensor(kernel)[None, None]
83
+
84
+
85
class Downsampling(LinearPhysics):
    r"""
    Downsampling operator for super-resolution problems.

    It is defined as

    .. math::

        y = S (h*x)

    where :math:`h` is a low-pass filter and :math:`S` is a subsampling operator.

    :param tuple[int] img_size: size of the input image
    :param int factor: downsampling factor
    :param torch.Tensor, str, NoneType filter: Downsampling filter. It can be 'gaussian', 'bilinear' or 'bicubic' or a
        custom ``torch.Tensor`` filter. If ``None``, no filtering is applied.
    :param str device: device on which the filter and its Fourier transforms are stored.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding)
        otherwise the blurred output has the same size as the image.

    """

    def __init__(
        self,
        img_size,
        factor=2,
        filter="gaussian",
        device="cpu",
        padding="circular",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.factor = factor
        assert isinstance(factor, int), "downsampling factor should be an integer"
        self.imsize = img_size
        self.padding = padding
        if isinstance(filter, torch.Tensor):
            self.filter = filter.to(device)
        elif filter is None:
            self.filter = filter
        elif filter == "gaussian":
            self.filter = (
                gaussian_blur(sigma=(factor, factor)).requires_grad_(False).to(device)
            )
        elif filter == "bilinear":
            self.filter = bilinear_filter(self.factor).requires_grad_(False).to(device)
        elif filter == "bicubic":
            self.filter = bicubic_filter(self.factor).requires_grad_(False).to(device)
        else:
            raise Exception("The chosen downsampling filter doesn't exist")

        if self.filter is not None:
            # Fourier-domain quantities used by the closed-form proximal operator in prox_l2.
            # img_size is converted to a tuple so that a list is accepted as well.
            self.Fh = filter_fft(self.filter, tuple(img_size), real_fft=False).to(
                device
            )
            self.Fhc = torch.conj(self.Fh)
            self.Fh2 = self.Fhc * self.Fh
            self.filter = torch.nn.Parameter(self.filter, requires_grad=False)
            self.Fhc = torch.nn.Parameter(self.Fhc, requires_grad=False)
            self.Fh2 = torch.nn.Parameter(self.Fh2, requires_grad=False)

    def A(self, x):
        r"""
        Applies the (optional) filter and keeps one pixel out of ``factor`` along each spatial axis.

        :param torch.Tensor x: input image.
        :return: (torch.Tensor) downsampled measurement.
        """
        if self.filter is not None:
            x = conv(x, self.filter, padding=self.padding)
        x = x[:, :, :: self.factor, :: self.factor]  # downsample
        return x

    def A_adjoint(self, y):
        r"""
        Places the measurements on a zero grid of size ``img_size`` (one every ``factor`` pixels)
        and applies the transposed filter, if any.

        :param torch.Tensor y: downsampled measurement.
        :return: (torch.Tensor) tensor with the size of the input image.
        """
        # tuple(...) allows img_size to be given as a list or torch.Size.
        x = torch.zeros((y.shape[0],) + tuple(self.imsize), device=y.device)
        x[:, :, :: self.factor, :: self.factor] = y  # upsample
        if self.filter is not None:
            x = conv_transpose(x, self.filter, padding=self.padding)
        return x

    def prox_l2(self, z, y, gamma, use_fft=True):
        r"""
        If the padding is circular and a filter is set, it computes the proximal operator with the
        closed-formula of https://arxiv.org/abs/1510.00143.

        Otherwise, it falls back to ``LinearPhysics.prox_l2``, which can be slow if applied many times.

        :param torch.Tensor z: point at which the proximal operator is evaluated.
        :param torch.Tensor y: measurements.
        :param float gamma: stepsize of the proximal operator.
        :param bool use_fft: if ``False`` the generic ``LinearPhysics.prox_l2`` is always used.
        :return: (torch.Tensor) result of the proximal operator.
        """
        # The Fourier quantities (Fh, Fhc, Fh2) only exist when a filter was provided, so the
        # closed form is restricted to that case.
        closed_form = (
            use_fft and self.padding == "circular" and self.filter is not None
        )
        if not closed_form:
            return LinearPhysics.prox_l2(self, z, y, gamma)

        # Formula from (Zhao, 2016)
        z_hat = self.A_adjoint(y) + 1 / gamma * z
        Fz_hat = fft.fft2(z_hat)

        def splits(a, sf):
            """split a into sfxsf distinct blocks
            Args:
                a: NxCxWxH
                sf: split factor
            Returns:
                b: NxCx(W/sf)x(H/sf)x(sf^2)
            """
            b = torch.stack(torch.chunk(a, sf, dim=2), dim=4)
            b = torch.cat(torch.chunk(b, sf, dim=3), dim=4)
            return b

        top = torch.mean(splits(self.Fh * Fz_hat, self.factor), dim=-1)
        below = torch.mean(splits(self.Fh2, self.factor), dim=-1) + 1 / gamma
        rc = self.Fhc * (top / below).repeat(1, 1, self.factor, self.factor)
        r = torch.real(fft.ifft2(rc))
        return (z_hat - r) * gamma
188
+
189
+
190
def extend_filter(filter):
    r"""
    Zero-pad a filter so that both spatial dimensions are odd and at least 3.

    A dimension of size 1 becomes 3 with the original entry in the middle; an even dimension gains
    one trailing zero; other odd dimensions are left unchanged. The result is always a new tensor
    created with ``torch.zeros`` on the filter's device.

    :param torch.Tensor filter: filter of size (B, C, h, w).
    :return: (torch.Tensor) padded filter of size (B, C, h', w').
    """
    b, c, h, w = filter.shape

    def _padded(size):
        # Returns (new size, offset of the original data) for one spatial dimension.
        if size == 1:
            return 3, 1
        if size % 2 == 0:
            return size + 1, 0
        return size, 0

    h_new, offset_h = _padded(h)
    w_new, offset_w = _padded(w)

    out = torch.zeros((b, c, h_new, w_new), device=filter.device)
    out[:, :, offset_h : offset_h + h, offset_w : offset_w + w] = filter
    return out
213
+
214
+
215
def conv(x, filter, padding):
    r"""
    Convolution of x and filter. The transposed of this operation is conv_transpose(x, filter, padding)

    :param torch.Tensor x: Image of size (B, C, H, W).
    :param torch.Tensor filter: Filter of size (1, 1, h, w) for filtering each channel with the same filter,
        (B, 1, h, w) for using a different filter per batch element, or (1, C, h, w) for colour filtering.
    :param str padding: options are ``'valid'``, ``'zero'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding),
        otherwise the blurred output has the same size as the image.
    :return: (torch.Tensor) filtered image.
    """
    b, c, h, w = x.shape

    # In order to perform convolution and not correlation like Pytorch native conv
    filter = filter.flip(-1).flip(-2)

    # After this call both spatial sizes are odd, so the half-widths below are exact integers.
    filter = extend_filter(filter)

    ph = (filter.shape[2] - 1) // 2
    pw = (filter.shape[3] - 1) // 2

    if padding == "valid":
        h_out = h - 2 * ph
        w_out = w - 2 * pw
    else:
        h_out = h
        w_out = w
        # conv_transpose names zero padding 'zero'; F.pad calls it 'constant'.
        pad_mode = "constant" if padding == "zero" else padding
        x = F.pad(x, (pw, pw, ph, ph), mode=pad_mode, value=0)

    if filter.shape[1] == 1:
        y = torch.zeros((b, c, h_out, w_out), device=x.device)
        for i in range(b):
            # Per-sample filters are selected exactly as in conv_transpose.
            if filter.shape[0] > 1:
                f = filter[i, :, :, :].unsqueeze(0)
            else:
                f = filter

            for j in range(c):
                y[i, j, :, :] = F.conv2d(
                    x[i, j, :, :].unsqueeze(0).unsqueeze(1), f, padding="valid"
                )[0, 0]
    else:
        y = F.conv2d(x, filter, padding="valid")

    return y
256
+
257
+
258
def conv_transpose(y, filter, padding):
    r"""
    Transposed convolution of y and filter. The transposed of this operation is conv(x, filter, padding)

    :param torch.Tensor y: Image of size (B, C, H, W).
    :param torch.Tensor filter: Filter of size (1, 1, h, w) for filtering each channel with the same filter,
        (B, 1, h, w) for using a different filter per batch element, or (1, C, h, w) for colour filtering.
    :param str padding: options are ``'valid'``, ``'zero'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the output is larger than the input by the filter size minus one,
        otherwise the output has the same size as the input.
    :return: (torch.Tensor) result of the transposed convolution.
    :raises ValueError: if ``padding`` is not one of the supported options.
    """
    if padding not in ("valid", "zero", "circular", "reflect", "replicate"):
        # Fail early with a clear message instead of an UnboundLocalError at the return statement.
        raise ValueError(
            f"Unknown padding mode '{padding}'; expected 'valid', 'zero', 'circular', "
            "'reflect' or 'replicate'"
        )

    b, c, h, w = y.shape

    # In order to perform convolution and not correlation like Pytorch native conv
    filter = filter.flip(-1).flip(-2)

    # After this call both spatial sizes are odd and at least 3, so ph, pw >= 1.
    filter = extend_filter(filter)

    ph = (filter.shape[2] - 1) // 2
    pw = (filter.shape[3] - 1) // 2

    h_out = h + 2 * ph
    w_out = w + 2 * pw

    x = torch.zeros((b, c, h_out, w_out), device=y.device)
    if filter.shape[1] == 1:
        for i in range(b):
            if filter.shape[0] > 1:
                f = filter[i, :, :, :].unsqueeze(0)
            else:
                f = filter

            for j in range(c):
                x[i, j, :, :] = F.conv_transpose2d(
                    y[i, j, :, :].unsqueeze(0).unsqueeze(1), f
                )
    else:
        x = F.conv_transpose2d(y, filter)

    if padding == "valid":
        return x

    # ``out`` is a view on the interior of ``x``; the border contributions of ``x`` are folded
    # back into it in place according to the padding mode.
    out = x[:, :, ph:-ph, pw:-pw]

    if padding == "circular":
        # sides
        out[:, :, :ph, :] += x[:, :, -ph:, pw:-pw]
        out[:, :, -ph:, :] += x[:, :, :ph, pw:-pw]
        out[:, :, :, :pw] += x[:, :, ph:-ph, -pw:]
        out[:, :, :, -pw:] += x[:, :, ph:-ph, :pw]
        # corners
        out[:, :, :ph, :pw] += x[:, :, -ph:, -pw:]
        out[:, :, -ph:, -pw:] += x[:, :, :ph, :pw]
        out[:, :, :ph, -pw:] += x[:, :, -ph:, :pw]
        out[:, :, -ph:, :pw] += x[:, :, :ph, -pw:]

    elif padding == "reflect":
        # sides
        out[:, :, 1 : 1 + ph, :] += x[:, :, :ph, pw:-pw].flip(dims=(2,))
        out[:, :, -ph - 1 : -1, :] += x[:, :, -ph:, pw:-pw].flip(dims=(2,))
        out[:, :, :, 1 : 1 + pw] += x[:, :, ph:-ph, :pw].flip(dims=(3,))
        out[:, :, :, -pw - 1 : -1] += x[:, :, ph:-ph, -pw:].flip(dims=(3,))
        # corners
        out[:, :, 1 : 1 + ph, 1 : 1 + pw] += x[:, :, :ph, :pw].flip(dims=(2, 3))
        out[:, :, -ph - 1 : -1, -pw - 1 : -1] += x[:, :, -ph:, -pw:].flip(dims=(2, 3))
        out[:, :, -ph - 1 : -1, 1 : 1 + pw] += x[:, :, -ph:, :pw].flip(dims=(2, 3))
        out[:, :, 1 : 1 + ph, -pw - 1 : -1] += x[:, :, :ph, -pw:].flip(dims=(2, 3))

    elif padding == "replicate":
        # sides
        out[:, :, 0, :] += x[:, :, :ph, pw:-pw].sum(2)
        out[:, :, -1, :] += x[:, :, -ph:, pw:-pw].sum(2)
        out[:, :, :, 0] += x[:, :, ph:-ph, :pw].sum(3)
        out[:, :, :, -1] += x[:, :, ph:-ph, -pw:].sum(3)
        # corners
        out[:, :, 0, 0] += x[:, :, :ph, :pw].sum(3).sum(2)
        out[:, :, -1, -1] += x[:, :, -ph:, -pw:].sum(3).sum(2)
        out[:, :, -1, 0] += x[:, :, -ph:, :pw].sum(3).sum(2)
        out[:, :, 0, -1] += x[:, :, :ph, -pw:].sum(3).sum(2)

    # padding == "zero": the border contributions are simply discarded.
    return out
343
+
344
+
345
class BlindBlur(Physics):
    r"""
    Blind blur operator.

    It performs

    .. math::

        y = w*x

    where :math:`*` denotes convolution and :math:`w` is an unknown filter.
    This class uses ``torch.conv2d`` for performing the convolutions.

    The signal is described by a tuple (x,w) where the first element is the clean image, and the second element
    is the blurring kernel. The measurements y are a tensor representing the convolution of x and w.

    :param int, list[int], tuple[int] kernel_size: maximum support size of the (unknown) blurring kernels.
        A single integer ``k`` is interpreted as a ``k x k`` support.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``.
        If ``padding='valid'`` the blurred output is smaller than the image (no padding)
        otherwise the blurred output has the same size as the image.

    """

    def __init__(self, kernel_size=3, padding="circular", **kwargs):
        super().__init__(**kwargs)
        self.padding = padding

        # Always store the support as [height, width]; a (h, w) pair is kept as given
        # and a scalar is duplicated.
        if isinstance(kernel_size, (list, tuple)):
            self.kernel_size = list(kernel_size)
        else:
            self.kernel_size = [kernel_size, kernel_size]

    def A(self, s):
        r"""

        :param tuple, list, deepinv.utils.TensorList s: List containing two torch.tensor, s[0] with the image and s[1] with the filter.
        :return: (torch.tensor) blurred measurement.
        """
        x = s[0]
        w = s[1]
        return conv(x, w, self.padding)

    def A_dagger(self, y):
        r"""
        Returns the trivial inverse where x[0] = blurry input and x[1] with a delta filter, such that
        the convolution of x[0] and x[1] is y.

        .. note::

            This trivial inverse can be useful for some reconstruction networks, such as ``deepinv.models.ArtifactRemoval``.

        :param torch.tensor y: blurred measurement.
        :return: ``TensorList`` containing the trivial inverse.
        """
        x = y.clone()
        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]
        mid_h = int(kernel_h / 2)
        mid_w = int(kernel_w / 2)
        # Create the delta kernel on the same device as the measurements so that the pair
        # can be fed back to A without a device mismatch.
        w = torch.zeros((y.shape[0], 1, kernel_h, kernel_w), device=y.device)
        w[:, :, mid_h, mid_w] = 1.0

        return TensorList([x, w])
404
+
405
+
406
class Blur(LinearPhysics):
    r"""

    Blur operator.

    This forward operator performs

    .. math:: y = w*x

    where :math:`*` denotes convolution and :math:`w` is a filter.

    This class uses :meth:`torch.nn.functional.conv2d` for performing the convolutions.

    :param torch.Tensor filter: Tensor of size (1, 1, H, W) or (1, C, H, W) containing the blur filter, e.g., :meth:`deepinv.physics.blur.gaussian_blur`.
    :param str padding: options are ``'valid'``, ``'circular'``, ``'replicate'`` and ``'reflect'``. If ``padding='valid'`` the blurred output is smaller than the image (no padding)
        otherwise the blurred output has the same size as the image.
    :param str device: cpu or cuda.

    """

    def __init__(self, filter, padding="circular", device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.padding = padding
        self.device = device
        # Move the tensor first and wrap it afterwards: calling ``.to(device)`` on a Parameter
        # returns a plain tensor whenever a copy is made, which would leave the filter
        # unregistered on non-default devices.
        self.filter = torch.nn.Parameter(filter.to(device), requires_grad=False)

    def A(self, x):
        r"""
        Convolves the image with the filter.

        :param torch.Tensor x: input image.
        :return: (torch.Tensor) blurred measurement.
        """
        return conv(x, self.filter, self.padding)

    def A_adjoint(self, y):
        r"""
        Applies the transposed convolution with the filter.

        :param torch.Tensor y: blurred measurement.
        :return: (torch.Tensor) result of the transposed operator.
        """
        return conv_transpose(y, self.filter, self.padding)
437
+
438
+
439
class BlurFFT(DecomposablePhysics):
    r"""

    FFT-based blur operator.

    It performs the operation

    .. math:: y = w*x

    where :math:`*` denotes convolution and :math:`w` is a filter.

    Blur operator based on ``torch.fft`` operations, which assumes a circular padding of the input, and allows for
    the singular value decomposition via ``deepinv.physics.DecomposablePhysics`` and has fast pseudo-inverse and prox operators.



    :param tuple img_size: Input image size in the form (C, H, W).
    :param torch.tensor filter: torch.Tensor of size (1, 1, H, W) or (1, C, H, W) containing the blur filter, e.g.,
        :meth:`deepinv.physics.blur.gaussian_blur`.
    :param str device: cpu or cuda

    """

    def __init__(self, img_size, filter, device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size

        # Replicate the filter along the channel axis when the image has more channels.
        if img_size[0] > filter.shape[1]:
            filter = filter.repeat(1, img_size[0], 1, 1)

        # tuple(...) allows img_size to be given as a list.
        spectrum = filter_fft(filter, tuple(img_size)).to("cpu")

        # Unit-modulus phase factor of the filter spectrum, used by U and U_adjoint.
        # NOTE(review): the sign convention (exp(-1j * phase)) is kept exactly as in the original code.
        self.angle = torch.exp(-1j * torch.angle(spectrum)).to(device)

        # Singular values: the spectrum magnitude, duplicated along a trailing axis of size 2 so
        # that it matches the (real, imag) layout produced by torch.view_as_real.
        magnitude = torch.abs(spectrum).unsqueeze(-1)
        magnitude = torch.cat([magnitude, magnitude], dim=-1)

        # Move the tensor before wrapping it: ``Parameter(...).to(device)`` returns a plain tensor
        # whenever a copy is made, which would leave the mask unregistered on non-default devices.
        self.mask = torch.nn.Parameter(magnitude.to(device), requires_grad=False)

    def V_adjoint(self, x):
        # Orthonormal real FFT, returned as a real tensor with a trailing (real, imag) axis.
        return torch.view_as_real(
            fft.rfft2(x, norm="ortho")
        )  # make it a true SVD (see J. Romberg notes)

    def U(self, x):
        # Applies the phase factor and goes back to the image domain.
        return fft.irfft2(
            torch.view_as_complex(x) * self.angle, norm="ortho", s=self.img_size[-2:]
        )

    def U_adjoint(self, x):
        # Orthonormal real FFT followed by the conjugate phase factor.
        return torch.view_as_real(
            fft.rfft2(x, norm="ortho") * torch.conj(self.angle)
        )  # make it a true SVD (see J. Romberg notes)

    def V(self, x):
        # Inverse orthonormal real FFT of the (real, imag) representation.
        return fft.irfft2(torch.view_as_complex(x), norm="ortho", s=self.img_size[-2:])
495
+
496
+
497
+ # # test code
498
+ # if __name__ == "__main__":
499
+ # device = "cuda:0"
500
+ #
501
+ # import matplotlib.pyplot as plt
502
+ #
503
+ # device = "cuda:0"
504
+ # x = torchvision.io.read_image("../../datasets/celeba/img_align_celeba/085307.jpg")
505
+ # x = x.unsqueeze(0).float().to(device) / 255
506
+ # x = torchvision.transforms.Resize((160, 180))(x)
507
+ #
508
+ # sigma_noise = 0.0
509
+ # kernel = torch.zeros((1, 1, 15, 15), device=device)
510
+ # kernel[:, :, 7, :] = 1 / 15
511
+ # physics = Downsampling(img_size=x.shape[1:], filter="bilinear", device=device)
512
+ # physics2 = Blur(img_size=x.shape[1:], filter=kernel, device=device)
513
+ #
514
+ # y = physics(x)
515
+ # y2 = physics2(x)
516
+ #
517
+ # xhat = physics.V(physics.U_adjoint(y) / physics.mask)
518
+ # xhat2 = physics2.A_dagger(y2)
519
+ #
520
+ # print(xhat.shape)
521
+ # # print(physics.adjointness_test(x))
522
+ # print(torch.sum((y - y2).pow(2)))
523
+ # print(torch.sum((xhat - xhat2).pow(2)))
524
+ #
525
+ # print(torch.sum((x - xhat).pow(2)))
526
+ # print(torch.sum((x - xhat2).pow(2)))
527
+ #
528
+ # print(physics.compute_norm(x))
529
+ # print(physics.adjointness_test(x))
530
+ # xhat = physics.prox_l2(y, y, gamma=1.0)
531
+ #
532
+ # xhat = physics.A_dagger(y)
533
+ #
534
+ # plt.imshow(x.squeeze(0).permute(1, 2, 0).cpu().numpy())
535
+ # plt.show()
536
+ # plt.imshow(y.squeeze(0).permute(1, 2, 0).cpu().numpy())
537
+ # plt.show()
538
+ # plt.imshow(xhat.squeeze(0).permute(1, 2, 0).cpu().numpy())
539
+ # plt.show()
540
+ # plt.imshow(xhat2.squeeze(0).permute(1, 2, 0).cpu().numpy())
541
+ # plt.show()
542
+ #
543
+ # plt.imshow(physics.A(xhat).squeeze(0).permute(1, 2, 0).cpu().numpy())
544
+ # plt.show()