deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/unet.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from .drunet import test_pad
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BFBatchNorm2d(nn.BatchNorm2d):
    r"""
    Bias-free 2D batch normalization.

    From Mohan et al.

    "Robust And Interpretable Blind Image Denoising Via Bias-Free Convolutional Neural Networks"
    S. Mohan, Z. Kadkhodaie, E. P. Simoncelli, C. Fernandez-Granda
    Int'l. Conf. on Learning Representations (ICLR), Apr 2020.

    When ``use_bias=False`` the channel mean is neither subtracted nor tracked,
    so the layer reduces to a scale-only normalization (the "bias-free" variant
    of the paper).

    :param int num_features: number of input channels.
    :param float eps: small constant added to the standard deviation for numerical stability.
    :param float momentum: momentum used when updating the running statistics.
    :param bool use_bias: if ``True``, behaves like standard batch norm (mean subtraction
        and, with ``affine``, a learnable shift); if ``False``, all additive terms are dropped.
    :param bool affine: if ``True``, applies the learnable per-channel scale (and shift
        when ``use_bias=True``).
    """

    def __init__(
        self, num_features, eps=1e-5, momentum=0.1, use_bias=False, affine=True
    ):
        # NOTE(review): the parent is initialized with its default affine=True, so
        # `weight`/`bias` parameters are always allocated; `self.affine` (set below)
        # only controls whether they are applied in forward().
        super(BFBatchNorm2d, self).__init__(num_features, eps, momentum)
        self.use_bias = use_bias
        self.affine = affine

    def forward(self, x):
        # Expects a 4D (N, C, H, W) input, enforced by the parent's check.
        self._check_input_dim(x)
        # Flatten to (channels, batch*spatial) so the per-channel statistics
        # become a single reduction over dim=1.
        y = x.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)
        if self.use_bias:
            mu = y.mean(dim=1)
        sigma2 = y.var(dim=1)
        if self.training is not True:
            # Evaluation mode: normalize with the tracked running statistics.
            if self.use_bias:
                y = y - self.running_mean.view(-1, 1)
            # NOTE(review): divides by sqrt(var) + eps rather than the usual
            # sqrt(var + eps); kept as-is since trained weights depend on it.
            y = y / (self.running_var.view(-1, 1) ** 0.5 + self.eps)
        else:
            # Training mode: normalize with batch statistics and update the
            # running buffers (no gradients through the buffer update).
            if self.track_running_stats is True:
                with torch.no_grad():
                    if self.use_bias:
                        self.running_mean = (
                            1 - self.momentum
                        ) * self.running_mean + self.momentum * mu
                    self.running_var = (
                        1 - self.momentum
                    ) * self.running_var + self.momentum * sigma2
            if self.use_bias:
                y = y - mu.view(-1, 1)
            y = y / (sigma2.view(-1, 1) ** 0.5 + self.eps)
        if self.affine:
            y = self.weight.view(-1, 1) * y
            if self.use_bias:
                y += self.bias.view(-1, 1)

        # Restore the original (N, C, H, W) layout.
        return y.view(return_shape).transpose(0, 1)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class UNet(nn.Module):
    r"""
    U-Net convolutional denoiser.

    This network is a fully convolutional denoiser based on the U-Net architecture. The number of downsample steps
    can be controlled with the `scales` parameter. The number of trainable parameters increases with the number of
    scales.

    :param int in_channels: input image channels
    :param int out_channels: output image channels
    :param bool residual: use a skip-connection between input and output.
    :param bool circular_padding: circular padding for the convolutional layers.
    :param bool cat: use skip-connections between intermediate levels.
    :param bool bias: use learnable biases.
    :param bool batch_norm: insert a (bias-free) batch normalization layer after each convolution.
    :param int scales: Number of downsampling steps used in the U-Net. The options are 2,3,4 and 5.
        The number of trainable parameters increases with the scale.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        residual=True,
        circular_padding=False,
        cat=True,
        bias=True,
        batch_norm=True,
        scales=4,
    ):
        super(UNet, self).__init__()
        self.name = "unet"

        # Fail early with a clear message instead of leaving `_forward`
        # undefined and crashing with an AttributeError at call time.
        if scales not in (2, 3, 4, 5):
            raise ValueError(f"scales must be one of 2, 3, 4 or 5 (got {scales})")

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.residual = residual
        self.cat = cat
        self.compact = scales
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        def conv_block(ch_in, ch_out):
            # Two 3x3 convolutions, each followed (optionally) by bias-free
            # batch norm and a ReLU. Layers are appended in the same order as
            # the original implementation so state_dict keys stay compatible.
            layers = [
                nn.Conv2d(
                    ch_in,
                    ch_out,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=bias,
                    padding_mode="circular" if circular_padding else "zeros",
                )
            ]
            if batch_norm:
                layers.append(BFBatchNorm2d(ch_out, use_bias=bias))
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=bias)
            )
            if batch_norm:
                layers.append(BFBatchNorm2d(ch_out, use_bias=bias))
            layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        def up_conv(ch_in, ch_out):
            # 2x upsampling followed by a 3x3 convolution, optional bias-free
            # batch norm, and a ReLU.
            layers = [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(
                    ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=bias
                ),
            ]
            if batch_norm:
                layers.append(BFBatchNorm2d(ch_out, use_bias=bias))
            layers.append(nn.ReLU(inplace=True))
            return nn.Sequential(*layers)

        # Encoder blocks; deeper blocks only exist for larger `scales`.
        self.Conv1 = conv_block(ch_in=in_channels, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = (
            conv_block(ch_in=128, ch_out=256) if self.compact in [3, 4, 5] else None
        )
        self.Conv4 = (
            conv_block(ch_in=256, ch_out=512) if self.compact in [4, 5] else None
        )
        self.Conv5 = conv_block(ch_in=512, ch_out=1024) if self.compact in [5] else None

        # Decoder blocks; input channels are doubled where skip-connections
        # are concatenated.
        self.Up5 = up_conv(ch_in=1024, ch_out=512) if self.compact in [5] else None
        self.Up_conv5 = (
            conv_block(ch_in=1024, ch_out=512) if self.compact in [5] else None
        )

        self.Up4 = up_conv(ch_in=512, ch_out=256) if self.compact in [4, 5] else None
        self.Up_conv4 = (
            conv_block(ch_in=512, ch_out=256) if self.compact in [4, 5] else None
        )

        self.Up3 = up_conv(ch_in=256, ch_out=128) if self.compact in [3, 4, 5] else None
        self.Up_conv3 = (
            conv_block(ch_in=256, ch_out=128) if self.compact in [3, 4, 5] else None
        )

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        # Final 1x1 projection to the requested number of output channels.
        self.Conv_1x1 = nn.Conv2d(
            in_channels=64,
            out_channels=out_channels,
            bias=bias,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        # Dispatch table replaces the original chain of `if`s; `scales` was
        # validated above, so the lookup cannot fail.
        self._forward = {
            5: self.forward_standard,
            4: self.forward_compact4,
            3: self.forward_compact3,
            2: self.forward_compact2,
        }[self.compact]

    def forward(self, x, sigma=None):
        r"""
        Run the denoiser on noisy image. The noise level is not used in this denoiser.

        :param torch.Tensor x: noisy image.
        :param float sigma: noise level (not used).
        """
        # The network halves the spatial size (scales - 1) times, so the input
        # must be divisible by `factor`; otherwise pad, run, and crop.
        factor = 2 ** (self.compact - 1)
        if x.size(2) % factor == 0 and x.size(3) % factor == 0:
            return self._forward(x)
        else:
            return test_pad(self._forward, x, modulo=factor)

    def forward_standard(self, x):
        r"""Forward pass for ``scales=5`` (four downsampling steps)."""
        # encoding path
        cat_dim = 1
        input = x
        x1 = self.Conv1(input)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        if self.cat:
            d5 = torch.cat((x4, d5), dim=cat_dim)
            d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        if self.cat:
            d4 = torch.cat((x3, d4), dim=cat_dim)
            d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        if self.cat:
            d3 = torch.cat((x2, d3), dim=cat_dim)
            d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        # Global residual connection is only valid when shapes match.
        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out

    def forward_compact4(self, x):
        r"""Forward pass for ``scales=4`` (three downsampling steps)."""
        # encoding path
        cat_dim = 1
        input = x

        x1 = self.Conv1(input)  # 1->64

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)  # 64->128

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)  # 128->256

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)  # 256->512

        d4 = self.Up4(x4)  # 512->256
        if self.cat:
            d4 = torch.cat((x3, d4), dim=cat_dim)
            d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)  # 256->128
        if self.cat:
            d3 = torch.cat((x2, d3), dim=cat_dim)
            d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)  # 128->64
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out

    def forward_compact3(self, x):
        r"""Forward pass for ``scales=3`` (two downsampling steps)."""
        # encoding path
        cat_dim = 1
        input = x
        x1 = self.Conv1(input)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        d3 = self.Up3(x3)
        if self.cat:
            d3 = torch.cat((x2, d3), dim=cat_dim)
            d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out

    def forward_compact2(self, x):
        r"""Forward pass for ``scales=2`` (one downsampling step)."""
        # encoding path
        cat_dim = 1
        input = x
        x1 = self.Conv1(input)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        d2 = self.Up2(x2)
        if self.cat:
            d2 = torch.cat((x1, d2), dim=cat_dim)
            d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        out = d1 + x if self.residual and self.in_channels == self.out_channels else d1
        return out
|
deepinv/models/utils.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def tensor2array(img):
    """Convert a CHW torch tensor to an HWC numpy array on the host."""
    data = img.detach().cpu().numpy()
    return np.transpose(data, (1, 2, 0))
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def array2tensor(img):
    """Convert an HWC numpy array to a CHW torch tensor (shares memory)."""
    tensor = torch.from_numpy(img)
    return tensor.permute(2, 0, 1)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_weights_url(model_name, file_name):
    """Build the Hugging Face download URL for a pretrained weight file."""
    return (
        f"https://huggingface.co/deepinv/{model_name}"
        f"/resolve/main/{file_name}?download=true"
    )
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
try:
    import pytorch_wavelets
except ImportError:
    # Optional dependency: defer the failure. WaveletPrior raises a helpful
    # ImportError at construction time when the package is missing.
    # (Narrowed from a bare `except:` so unrelated errors are not swallowed.)
    pytorch_wavelets = ImportError("The pytorch_wavelets package is not installed.")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class WaveletPrior(nn.Module):
    r"""
    Wavelet denoising with the :math:`\ell_1` norm.


    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \;  \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an orthonormal wavelet transform, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``) or
    the :math:`\ell_0` norm (``non_linearity="hard"``). A variant of the :math:`\ell_0` norm is also available
    (``non_linearity="topk"``), where the thresholding is done by keeping the :math:`k` largest coefficients
    in each wavelet subband and setting the others to zero.

    The solution is available in closed-form, thus the denoiser is cheap to compute.

    :param int level: decomposition level of the wavelet transform
    :param str wv: mother wavelet (follows the `PyWavelets convention
        <https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html>`_) (default: "db8")
    :param str device: cpu or gpu
    :param str non_linearity: ``"soft"``, ``"hard"`` or ``"topk"`` thresholding (default: ``"soft"``).
        If ``"topk"``, only the top-k wavelet coefficients are kept.
    """

    def __init__(self, level=3, wv="db8", device="cpu", non_linearity="soft"):
        # pytorch_wavelets is an optional dependency; the module-level import
        # leaves an ImportError instance behind when it is missing.
        if isinstance(pytorch_wavelets, ImportError):
            raise ImportError(
                "pytorch_wavelets is needed to use the WaveletPrior class. "
                "It should be installed with `pip install "
                "git+https://github.com/fbcotter/pytorch_wavelets.git`"
            ) from pytorch_wavelets
        super().__init__()
        self.level = level
        self.dwt = pytorch_wavelets.DWTForward(J=self.level, wave=wv).to(device)
        self.iwt = pytorch_wavelets.DWTInverse(wave=wv).to(device)
        self.device = device
        self.non_linearity = non_linearity

    def get_ths_map(self, ths):
        r"""
        Normalize a threshold to a form broadcastable against wavelet coefficients.

        Scalars (and 0-d / single-element tensors) are returned essentially as-is;
        a per-level 1D tensor is reshaped to ``(1, 1, n, 1, 1)``.

        :param float, int, torch.Tensor ths: threshold.
        """
        if isinstance(ths, (float, int)):
            ths_map = ths
        elif len(ths.shape) == 0 or ths.shape[0] == 1:
            ths_map = ths.to(self.device)
        else:
            ths_map = (
                ths.unsqueeze(0)
                .unsqueeze(0)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .to(self.device)
            )
        return ths_map

    def prox_l1(self, x, ths=0.1):
        r"""
        Soft thresholding of the wavelet coefficients.

        :param torch.Tensor x: wavelet coefficients.
        :param float, torch.Tensor ths: threshold.
        """
        ths_map = self.get_ths_map(ths)
        # Soft-threshold: shrink towards zero by ths, i.e. max(0, x-t) + min(0, x+t).
        return torch.maximum(
            torch.tensor([0], device=x.device).type(x.dtype), x - ths_map
        ) + torch.minimum(torch.tensor([0], device=x.device).type(x.dtype), x + ths_map)

    def prox_l0(self, x, ths=0.1):
        r"""
        Hard thresholding of the wavelet coefficients: entries with magnitude
        below the threshold are set to zero.

        :param torch.Tensor x: wavelet coefficients.
        :param float, int, torch.Tensor ths: threshold.
        """
        # Bug fix: accept int thresholds on the scalar path. Previously an int
        # fell through to get_ths_map (which returns the plain int) and the
        # subsequent `.repeat(...)` raised an AttributeError.
        if isinstance(ths, (float, int)):
            ths_map = ths
        else:
            ths_map = self.get_ths_map(ths)
            ths_map = ths_map.repeat(
                1, 1, 1, x.shape[-2], x.shape[-1]
            )  # Reshaping to image wavelet shape
        out = x.clone()
        out[abs(out) < ths_map] = 0
        return out

    def hard_threshold_topk(self, x, ths=0.1):
        r"""
        Hard thresholding of the wavelet coefficients by keeping only the top-k coefficients and setting the others to
        0.

        :param torch.Tensor x: wavelet coefficients.
        :param float, int ths: top k coefficients to keep. If ``float``, it is interpreted as a proportion of the total
            number of coefficients. If ``int``, it is interpreted as the number of coefficients to keep.
        """
        if isinstance(ths, float):
            # Interpret a float as a proportion of the spatial coefficients.
            k = int(ths * x.shape[-2] * x.shape[-1])
        else:
            k = int(ths)

        # Reshape arrays to 2D and initialize output to 0
        x_flat = x.view(x.shape[0], -1)
        out = torch.zeros_like(x_flat)

        topk_indices_flat = torch.topk(abs(x_flat), k, dim=-1)[1]

        # Convert the flattened indices to the original indices of x
        batch_indices = (
            torch.arange(x.shape[0], device=x.device).unsqueeze(1).repeat(1, k)
        )
        topk_indices = torch.stack([batch_indices, topk_indices_flat], dim=-1)

        # Set output's top-k elements to values from original x
        out[tuple(topk_indices.view(-1, 2).t())] = x_flat[
            tuple(topk_indices.view(-1, 2).t())
        ]
        return torch.reshape(out, x.shape)

    def forward(self, x, ths=0.1):
        r"""
        Run the model on a noisy image.

        :param torch.Tensor x: noisy image.
        :param int, float, torch.Tensor ths: thresholding parameter.
            If ``non_linearity`` equals ``"soft"`` or ``"hard"``, ``ths`` serves as a (soft or hard)
            thresholding parameter for the wavelet coefficients. If ``non_linearity`` equals ``"topk"``,
            ``ths`` can indicate the number of wavelet coefficients
            that are kept (if ``int``) or the proportion of coefficients that are kept (if ``float``).

        """
        # The DWT requires even spatial sizes: pad odd inputs, crop afterwards.
        h, w = x.size()[-2:]
        padding_bottom = h % 2
        padding_right = w % 2
        x = torch.nn.ReplicationPad2d((0, padding_right, 0, padding_bottom))(x)

        coeffs = self.dwt(x)
        for lev in range(self.level):
            # Per-level threshold: scalars apply to every level, a 1D tensor
            # supplies one threshold per decomposition level.
            ths_cur = (
                ths
                if (
                    isinstance(ths, (int, float))
                    or len(ths.shape) == 0
                    or ths.shape[0] == 1
                )
                else ths[lev]
            )
            if self.non_linearity == "soft":
                coeffs[1][lev] = self.prox_l1(coeffs[1][lev], ths_cur)
            elif self.non_linearity == "hard":
                coeffs[1][lev] = self.prox_l0(coeffs[1][lev], ths_cur)
            elif self.non_linearity == "topk":
                coeffs[1][lev] = self.hard_threshold_topk(coeffs[1][lev], ths_cur)
        y = self.iwt(coeffs)

        y = y[..., :h, :w]
        return y
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class WaveletDict(nn.Module):
    r"""
    Overcomplete Wavelet denoising with the :math:`\ell_1` norm.

    This denoiser is defined as the solution to the optimization problem:

    .. math::

        \underset{x}{\arg\min} \;  \|x-y\|^2 + \lambda \|\Psi x\|_n

    where :math:`\Psi` is an overcomplete wavelet transform, composed of 2 or more wavelets, i.e.,
    :math:`\Psi=[\Psi_1,\Psi_2,\dots,\Psi_L]`, :math:`\lambda>0` is a hyperparameter, and where
    :math:`\|\cdot\|_n` is either the :math:`\ell_1` norm (``non_linearity="soft"``),
    the :math:`\ell_0` norm (``non_linearity="hard"``) or a variant of the :math:`\ell_0` norm
    (``non_linearity="topk"``) where only the top-k coefficients are kept; see :meth:`deepinv.models.WaveletPrior` for
    more details.

    The solution is not available in closed-form, thus the denoiser runs an optimization algorithm for each test image.

    :param int level: decomposition level of the wavelet transform.
    :param list[str] list_wv: list of mother wavelets. The names of the wavelets can be found in `here
        <https://wavelets.pybytes.com/>`_. (default: ``["db8", "db4"]``).
    :param int max_iter: number of iterations of the optimization algorithm (default: 10).
    :param str non_linearity: "soft", "hard" or "topk" thresholding (default: "soft")
    """

    def __init__(self, level=3, list_wv=None, max_iter=10, non_linearity="soft"):
        super().__init__()
        # Use None as the default to avoid a shared mutable default argument.
        if list_wv is None:
            list_wv = ["db8", "db4"]
        self.level = level
        self.list_prox = nn.ModuleList(
            [
                WaveletPrior(level=level, wv=wv, non_linearity=non_linearity)
                for wv in list_wv
            ]
        )
        self.max_iter = max_iter

    def forward(self, y, ths=0.1):
        r"""
        Run the model on a noisy image.

        :param torch.Tensor y: noisy image.
        :param float, torch.Tensor ths: noise level.
        """
        # Consensus-style splitting: one copy of the image per wavelet basis.
        z_p = y.repeat(len(self.list_prox), 1, 1, 1, 1)
        p_p = torch.zeros_like(z_p)
        x = p_p.clone()
        for it in range(self.max_iter):
            x_prev = x.clone()
            # Denoise each copy with its own wavelet prior.
            for p in range(len(self.list_prox)):
                p_p[p, ...] = self.list_prox[p](z_p[p, ...], ths)
            # Consensus step: average the per-basis estimates.
            x = torch.mean(p_p.clone(), axis=0)
            for p in range(len(self.list_prox)):
                z_p[p, ...] = x + z_p[p, ...].clone() - p_p[p, ...]
            # NOTE(review): the 1e-6 is added elementwise before taking the
            # norm, not to the norm itself; kept as-is to preserve the
            # original stopping behavior.
            rel_crit = torch.linalg.norm((x - x_prev).flatten()) / torch.linalg.norm(
                x.flatten() + 1e-6
            )
            # Stop early once the relative change is negligible.
            if rel_crit < 1e-3:
                break
        return x
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from .data_fidelity import DataFidelity, L2, L1, IndicatorL2, PoissonLikelihood
|
|
2
|
+
from .optimizers import BaseOptim, optim_builder
|
|
3
|
+
from .fixed_point import FixedPoint
|
|
4
|
+
from .prior import Prior, ScorePrior, Tikhonov, PnP, RED, L1Prior
|
|
5
|
+
from .optim_iterators.optim_iterator import OptimIterator
|