deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/unet.py ADDED
@@ -0,0 +1,337 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .drunet import test_pad
4
+
5
+
6
+ class BFBatchNorm2d(nn.BatchNorm2d):
7
+ r"""
8
+ From Mohan et al.
9
+
10
+ "Robust And Interpretable Blind Image Denoising Via Bias-Free Convolutional Neural Networks"
11
+ S. Mohan, Z. Kadkhodaie, E. P. Simoncelli, C. Fernandez-Granda
12
+ Int'l. Conf. on Learning Representations (ICLR), Apr 2020.
13
+ """
14
+
15
+ def __init__(
16
+ self, num_features, eps=1e-5, momentum=0.1, use_bias=False, affine=True
17
+ ):
18
+ super(BFBatchNorm2d, self).__init__(num_features, eps, momentum)
19
+ self.use_bias = use_bias
20
+ self.affine = affine
21
+
22
+ def forward(self, x):
23
+ self._check_input_dim(x)
24
+ y = x.transpose(0, 1)
25
+ return_shape = y.shape
26
+ y = y.contiguous().view(x.size(1), -1)
27
+ if self.use_bias:
28
+ mu = y.mean(dim=1)
29
+ sigma2 = y.var(dim=1)
30
+ if self.training is not True:
31
+ if self.use_bias:
32
+ y = y - self.running_mean.view(-1, 1)
33
+ y = y / (self.running_var.view(-1, 1) ** 0.5 + self.eps)
34
+ else:
35
+ if self.track_running_stats is True:
36
+ with torch.no_grad():
37
+ if self.use_bias:
38
+ self.running_mean = (
39
+ 1 - self.momentum
40
+ ) * self.running_mean + self.momentum * mu
41
+ self.running_var = (
42
+ 1 - self.momentum
43
+ ) * self.running_var + self.momentum * sigma2
44
+ if self.use_bias:
45
+ y = y - mu.view(-1, 1)
46
+ y = y / (sigma2.view(-1, 1) ** 0.5 + self.eps)
47
+ if self.affine:
48
+ y = self.weight.view(-1, 1) * y
49
+ if self.use_bias:
50
+ y += self.bias.view(-1, 1)
51
+
52
+ return y.view(return_shape).transpose(0, 1)
53
+
54
+
55
class UNet(nn.Module):
    r"""
    U-Net convolutional denoiser.

    This network is a fully convolutional denoiser based on the U-Net architecture. The number of downsample steps
    can be controlled with the ``scales`` parameter. The number of trainable parameters increases with the number of
    scales.

    :param int in_channels: input image channels
    :param int out_channels: output image channels
    :param bool residual: use a skip-connection between input and output.
    :param bool circular_padding: circular padding for the convolutional layers.
    :param bool cat: use skip-connections between intermediate levels.
    :param bool bias: use learnable biases.
    :param bool batch_norm: use bias-free batch normalization (:class:`BFBatchNorm2d`)
        after each convolution.
    :param int scales: Number of downsampling steps used in the U-Net. The options are 2,3,4 and 5.
        The number of trainable parameters increases with the scale.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        residual=True,
        circular_padding=False,
        cat=True,
        bias=True,
        batch_norm=True,
        scales=4,
    ):
        super(UNet, self).__init__()
        self.name = "unet"

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.residual = residual
        self.cat = cat
        self.compact = scales
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        def conv_block(ch_in, ch_out):
            # Two 3x3 conv + ReLU layers, optionally with bias-free batch norm.
            # NOTE(review): circular padding is only applied to the first conv
            # of each block — confirm this asymmetry is intended.
            if batch_norm:
                return nn.Sequential(
                    nn.Conv2d(
                        ch_in,
                        ch_out,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=bias,
                        padding_mode="circular" if circular_padding else "zeros",
                    ),
                    BFBatchNorm2d(ch_out, use_bias=bias),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
                    ),
                    BFBatchNorm2d(ch_out, use_bias=bias),
                    nn.ReLU(inplace=True),
                )
            else:
                return nn.Sequential(
                    nn.Conv2d(
                        ch_in,
                        ch_out,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=bias,
                        padding_mode="circular" if circular_padding else "zeros",
                    ),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
                    ),
                    nn.ReLU(inplace=True),
                )

        def up_conv(ch_in, ch_out):
            # 2x nearest-neighbor upsampling followed by a 3x3 conv + ReLU.
            if batch_norm:
                return nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(
                        ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
                    ),
                    BFBatchNorm2d(ch_out, use_bias=bias),
                    nn.ReLU(inplace=True),
                )
            else:
                return nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(
                        ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
                    ),
                    nn.ReLU(inplace=True),
                )

        # Encoder/decoder stages; stages beyond the requested number of
        # scales are left as None so state_dicts stay scale-specific.
        self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = (
            conv_block(ch_in=128, ch_out=256) if self.compact in [3, 4, 5] else None
        )
        self.Conv4 = (
            conv_block(ch_in=256, ch_out=512) if self.compact in [4, 5] else None
        )
        self.Conv5 = conv_block(ch_in=512, ch_out=1024) if self.compact in [5] else None

        self.Up5 = up_conv(ch_in=1024, ch_out=512) if self.compact in [5] else None
        self.Up_conv5 = (
            conv_block(ch_in=1024, ch_out=512) if self.compact in [5] else None
        )

        self.Up4 = up_conv(ch_in=512, ch_out=256) if self.compact in [4, 5] else None
        self.Up_conv4 = (
            conv_block(ch_in=512, ch_out=256) if self.compact in [4, 5] else None
        )

        self.Up3 = up_conv(ch_in=256, ch_out=128) if self.compact in [3, 4, 5] else None
        self.Up_conv3 = (
            conv_block(ch_in=256, ch_out=128) if self.compact in [3, 4, 5] else None
        )

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(
            in_channels=64,
            out_channels=out_channels,
            bias=bias,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        # Select the forward pass matching the requested number of scales, and
        # fail fast on unsupported values (previously an invalid `scales` only
        # surfaced later as an AttributeError on self._forward).
        if self.compact == 5:
            self._forward = self.forward_standard
        elif self.compact == 4:
            self._forward = self.forward_compact4
        elif self.compact == 3:
            self._forward = self.forward_compact3
        elif self.compact == 2:
            self._forward = self.forward_compact2
        else:
            raise ValueError("scales must be one of 2, 3, 4 or 5")

    def forward(self, x, sigma=None):
        r"""
        Run the denoiser on noisy image. The noise level is not used in this denoiser.

        :param torch.Tensor x: noisy image.
        :param float sigma: noise level (not used).
        """
        # The network downsamples (scales - 1) times, so both spatial sizes
        # must be divisible by this factor; otherwise pad, run and crop.
        factor = 2 ** (self.compact - 1)
        if x.size(2) % factor == 0 and x.size(3) % factor == 0:
            return self._forward(x)
        else:
            return test_pad(self._forward, x, modulo=factor)

    def forward_standard(self, x):
        r"""Forward pass with 5 scales (4 down/up-sampling steps)."""
        # encoding path
        cat_dim = 1
        input = x
        x1 = self.Conv1(input)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        if self.cat:
            d5 = torch.cat((x4, d5), dim=cat_dim)
            d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        if self.cat:
            d4 = torch.cat((x3, d4), dim=cat_dim)
            d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        if self.cat:
            d3 = torch.cat((x2, d3), dim=cat_dim)
            d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        # Residual skip only when the channel counts allow the addition.
        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out

    def forward_compact4(self, x):
        r"""Forward pass with 4 scales (3 down/up-sampling steps)."""
        # encoding path
        cat_dim = 1
        input = x

        x1 = self.Conv1(input)  # 1->64

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)  # 64->128

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)  # 128->256

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)  # 256->512

        d4 = self.Up4(x4)  # 512->256
        if self.cat:
            d4 = torch.cat((x3, d4), dim=cat_dim)
            d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)  # 256->128
        if self.cat:
            d3 = torch.cat((x2, d3), dim=cat_dim)
            d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)  # 128->64
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out

    def forward_compact3(self, x):
        r"""Forward pass with 3 scales (2 down/up-sampling steps)."""
        # encoding path
        cat_dim = 1
        input = x
        x1 = self.Conv1(input)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        d3 = self.Up3(x3)
        if self.cat:
            d3 = torch.cat((x2, d3), dim=cat_dim)
            d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out

    def forward_compact2(self, x):
        r"""Forward pass with 2 scales (1 down/up-sampling step)."""
        # encoding path
        cat_dim = 1
        input = x
        x1 = self.Conv1(input)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        d2 = self.Up2(x2)
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def tensor2array(img):
    r"""Convert a ``(C, H, W)`` torch tensor into an ``(H, W, C)`` numpy array.

    :param torch.Tensor img: image tensor with channels first.
    :return: numpy array with channels last.
    """
    data = img.cpu().detach().numpy()
    # Move the channel axis from first to last position.
    return data.transpose(1, 2, 0)
9
+
10
+
11
def array2tensor(img):
    r"""Convert an ``(H, W, C)`` numpy array into a ``(C, H, W)`` torch tensor.

    :param numpy.ndarray img: image array with channels last.
    :return: torch tensor with channels first (shares memory with the array).
    """
    tensor = torch.from_numpy(img)
    # Move the channel axis from last to first position.
    return tensor.permute(2, 0, 1)
13
+
14
+
15
def get_weights_url(model_name, file_name):
    r"""Build the Hugging Face download URL for a pretrained model file.

    :param str model_name: name of the model repository under ``deepinv/``.
    :param str file_name: name of the weights file inside the repository.
    :return: direct-download URL string.
    """
    return (
        f"https://huggingface.co/deepinv/{model_name}"
        f"/resolve/main/{file_name}?download=true"
    )
@@ -0,0 +1,231 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
# Optional dependency: the wavelet transforms need the pytorch_wavelets
# package. If it is missing, store the ImportError instance so users get an
# actionable error at instantiation time rather than at import time.
try:
    import pytorch_wavelets
except ImportError:  # narrowed from a bare `except:`, which would also swallow SystemExit/KeyboardInterrupt
    pytorch_wavelets = ImportError("The pytorch_wavelets package is not installed.")
8
+
9
+
10
class WaveletPrior(nn.Module):
    r"""
    Wavelet denoising with the :math:`\ell_1` norm.


    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \; \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an orthonormal wavelet transform, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``) or
    the :math:`\ell_0` norm (``non_linearity="hard"``). A variant of the :math:`\ell_0` norm is also available
    (``non_linearity="topk"``), where the thresholding is done by keeping the :math:`k` largest coefficients
    in each wavelet subband and setting the others to zero.

    The solution is available in closed-form, thus the denoiser is cheap to compute.

    :param int level: decomposition level of the wavelet transform
    :param str wv: mother wavelet (follows the `PyWavelets convention
        <https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html>`_) (default: "db8")
    :param str device: cpu or gpu
    :param str non_linearity: ``"soft"``, ``"hard"`` or ``"topk"`` thresholding (default: ``"soft"``).
        If ``"topk"``, only the top-k wavelet coefficients are kept.
    """

    def __init__(self, level=3, wv="db8", device="cpu", non_linearity="soft"):
        # Fail early with an actionable message when the optional dependency
        # is missing (the module-level import stored the ImportError instance).
        if isinstance(pytorch_wavelets, ImportError):
            raise ImportError(
                "pytorch_wavelets is needed to use the WaveletPrior class. "
                "It should be installed with `pip install "
                "git+https://github.com/fbcotter/pytorch_wavelets.git`"
            ) from pytorch_wavelets
        super().__init__()
        self.level = level
        # Forward and inverse discrete wavelet transforms.
        self.dwt = pytorch_wavelets.DWTForward(J=self.level, wave=wv).to(device)
        self.iwt = pytorch_wavelets.DWTInverse(wave=wv).to(device)
        self.device = device
        self.non_linearity = non_linearity

    def get_ths_map(self, ths):
        # Normalize the threshold into a form broadcastable against a wavelet
        # coefficient tensor: python scalars pass through unchanged, scalar /
        # 1-element tensors are just moved to the device, and longer tensors
        # get singleton dimensions added on both sides.
        if isinstance(ths, float) or isinstance(ths, int):
            ths_map = ths
        elif len(ths.shape) == 0 or ths.shape[0] == 1:
            ths_map = ths.to(self.device)
        else:
            ths_map = (
                ths.unsqueeze(0)
                .unsqueeze(0)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .to(self.device)
            )
        return ths_map

    def prox_l1(self, x, ths=0.1):
        r"""
        Soft thresholding of the wavelet coefficients.

        :param torch.Tensor x: wavelet coefficients.
        :param float, torch.Tensor ths: threshold.
        """
        ths_map = self.get_ths_map(ths)
        # Soft-thresholding written as max(0, x - t) + min(0, x + t): shrinks
        # every coefficient towards zero by t and zeroes those within [-t, t].
        return torch.maximum(
            torch.tensor([0], device=x.device).type(x.dtype), x - ths_map
        ) + torch.minimum(torch.tensor([0], device=x.device).type(x.dtype), x + ths_map)

    def prox_l0(self, x, ths=0.1):
        r"""
        Hard thresholding of the wavelet coefficients.

        :param torch.Tensor x: wavelet coefficients.
        :param float, torch.Tensor ths: threshold.
        """
        if isinstance(ths, float):
            ths_map = ths
        else:
            ths_map = self.get_ths_map(ths)
            ths_map = ths_map.repeat(
                1, 1, 1, x.shape[-2], x.shape[-1]
            )  # Reshaping to image wavelet shape
        out = x.clone()
        # Zero every coefficient whose magnitude falls below the threshold.
        out[abs(out) < ths_map] = 0
        return out

    def hard_threshold_topk(self, x, ths=0.1):
        r"""
        Hard thresholding of the wavelet coefficients by keeping only the top-k coefficients and setting the others to
        0.

        :param torch.Tensor x: wavelet coefficients.
        :param float, int ths: top k coefficients to keep. If ``float``, it is interpreted as a proportion of the total
            number of coefficients. If ``int``, it is interpreted as the number of coefficients to keep.
        """
        # A float threshold is a fraction of the spatial coefficient count.
        if isinstance(ths, float):
            k = int(ths * x.shape[-2] * x.shape[-1])
        else:
            k = int(ths)

        # Reshape arrays to 2D and initialize output to 0
        x_flat = x.view(x.shape[0], -1)
        out = torch.zeros_like(x_flat)

        # Indices of the k largest-magnitude coefficients per batch element.
        topk_indices_flat = torch.topk(abs(x_flat), k, dim=-1)[1]

        # Convert the flattened indices to the original indices of x
        batch_indices = (
            torch.arange(x.shape[0], device=x.device).unsqueeze(1).repeat(1, k)
        )
        topk_indices = torch.stack([batch_indices, topk_indices_flat], dim=-1)

        # Set output's top-k elements to values from original x
        out[tuple(topk_indices.view(-1, 2).t())] = x_flat[
            tuple(topk_indices.view(-1, 2).t())
        ]
        return torch.reshape(out, x.shape)

    def forward(self, x, ths=0.1):
        r"""
        Run the model on a noisy image.

        :param torch.Tensor x: noisy image.
        :param int, float, torch.Tensor ths: thresholding parameter.
            If ``non_linearity`` equals ``"soft"`` or ``"hard"``, ``ths`` serves as a (soft or hard)
            thresholding parameter for the wavelet coefficients. If ``non_linearity`` equals ``"topk"``,
            ``ths`` can indicate the number of wavelet coefficients
            that are kept (if ``int``) or the proportion of coefficients that are kept (if ``float``).

        """
        # Pad to even spatial sizes (replication padding) so the DWT/IDWT
        # round-trip is exact; the padding is cropped away before returning.
        h, w = x.size()[-2:]
        padding_bottom = h % 2
        padding_right = w % 2
        x = torch.nn.ReplicationPad2d((0, padding_right, 0, padding_bottom))(x)

        # coeffs[0] is the low-pass band (left untouched); coeffs[1] is the
        # list of detail bands, one entry per decomposition level.
        coeffs = self.dwt(x)
        for l in range(self.level):
            # Scalars / 1-element tensors apply one threshold to all levels;
            # a longer tensor supplies one threshold per level.
            ths_cur = (
                ths
                if (
                    isinstance(ths, int)
                    or isinstance(ths, float)
                    or len(ths.shape) == 0
                    or ths.shape[0] == 1
                )
                else ths[l]
            )
            if self.non_linearity == "soft":
                coeffs[1][l] = self.prox_l1(coeffs[1][l], ths_cur)
            elif self.non_linearity == "hard":
                coeffs[1][l] = self.prox_l0(coeffs[1][l], ths_cur)
            elif self.non_linearity == "topk":
                coeffs[1][l] = self.hard_threshold_topk(coeffs[1][l], ths_cur)
        y = self.iwt(coeffs)

        # Crop back to the original spatial size.
        y = y[..., :h, :w]
        return y
167
+
168
+
169
class WaveletDict(nn.Module):
    r"""
    Overcomplete Wavelet denoising with the :math:`\ell_1` norm.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \; \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or more wavelets, i.e.,
    :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``),
    the :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the :math:`\ell_0` norm
    (``non_linearity="topk"``) where only the top-k coefficients are kept; see :meth:`deepinv.models.WaveletPrior` for
    more details.

    The solution is not available in closed-form, thus the denoiser runs an optimization algorithm for each test image.

    :param int level: decomposition level of the wavelet transform.
    :param list[str] list_wv: list of mother wavelets. The names of the wavelets can be found in `here
        <https://wavelets.pybytes.com/>`_. (default: ["db8", "db4"]).
    :param int max_iter: number of iterations of the optimization algorithm (default: 10).
    :param str non_linearity: "soft", "hard" or "topk" thresholding (default: "soft")
    """

    def __init__(
        self, level=3, list_wv=None, max_iter=10, non_linearity="soft"
    ):
        super().__init__()
        # Avoid a mutable default argument; None selects the documented default.
        if list_wv is None:
            list_wv = ["db8", "db4"]
        self.level = level
        # One single-wavelet prox operator per mother wavelet.
        self.list_prox = nn.ModuleList(
            [
                WaveletPrior(level=level, wv=wv, non_linearity=non_linearity)
                for wv in list_wv
            ]
        )
        self.max_iter = max_iter

    def forward(self, y, ths=0.1):
        r"""
        Run the model on a noisy image.

        :param torch.Tensor y: noisy image.
        :param float, torch.Tensor ths: thresholding parameter.
        """
        # Parallel proximal (consensus) splitting: one copy of the image per
        # wavelet basis, averaged at each iteration.
        z_p = y.repeat(len(self.list_prox), 1, 1, 1, 1)
        p_p = torch.zeros_like(z_p)
        x = p_p.clone()
        for it in range(self.max_iter):
            x_prev = x.clone()
            for p in range(len(self.list_prox)):
                p_p[p, ...] = self.list_prox[p](z_p[p, ...], ths)
            x = torch.mean(p_p.clone(), axis=0)
            for p in range(len(self.list_prox)):
                z_p[p, ...] = x + z_p[p, ...].clone() - p_p[p, ...]
            # Relative-change stopping criterion. The 1e-6 guard belongs in the
            # denominator; the previous code added it elementwise inside the
            # norm (norm(x.flatten() + 1e-6)), biasing the criterion.
            rel_crit = torch.linalg.norm((x - x_prev).flatten()) / (
                torch.linalg.norm(x.flatten()) + 1e-6
            )
            if rel_crit < 1e-3:
                break
        return x
@@ -0,0 +1,5 @@
1
+ from .data_fidelity import DataFidelity, L2, L1, IndicatorL2, PoissonLikelihood
2
+ from .optimizers import BaseOptim, optim_builder
3
+ from .fixed_point import FixedPoint
4
+ from .prior import Prior, ScorePrior, Tikhonov, PnP, RED, L1Prior
5
+ from .optim_iterators.optim_iterator import OptimIterator