deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,689 @@
1
# Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models

import numpy as np
import torch
import torch.nn as nn
from .utils import get_weights_url

# Module-level CUDA availability flag and a matching default float-tensor alias.
# NOTE(review): neither `cuda` nor `Tensor` appears to be referenced elsewhere in
# this file's visible code — presumably kept for backward compatibility; confirm
# before removing.
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
10
+
11
+
12
class DRUNet(nn.Module):
    r"""
    DRUNet denoiser network.

    The network architecture is based on the paper
    `Learning deep CNN denoiser prior for image restoration <https://arxiv.org/abs/1704.03264>`_,
    and has a U-Net like structure, with convolutional blocks in the encoder and decoder parts.

    The network takes into account the noise level of the input image, which is encoded as an additional input channel.

    A pretrained network for (in_channels=out_channels=1 or in_channels=out_channels=3)
    can be downloaded via setting ``pretrained='download'``.

    :param int in_channels: number of channels of the input.
    :param int out_channels: number of channels of the output.
    :param list nc: number of channels per convolutional scale (one entry per scale).
    :param int nb: number of convolutional blocks per layer.
    :param str act_mode: activation mode, "R" for ReLU, "L" for LeakyReLU "E" for ELU and "S" for Softplus.
    :param str downsample_mode: Downsampling mode, "avgpool" for average pooling, "maxpool" for max pooling, and
        "strideconv" for convolution with stride 2.
    :param str upsample_mode: Upsampling mode, "convtranspose" for convolution transpose, "pixelsuffle" for pixel
        shuffling, and "upconv" for nearest neighbour upsampling with additional convolution.
    :param str, None pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized at random
        using Pytorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from an
        online repository (only available for the default architecture with 3 or 1 input/output channels).
        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param bool train: training or testing mode.
    :param str device: gpu or cpu.

    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        nc=[64, 128, 256, 512],
        nb=4,
        act_mode="R",
        downsample_mode="strideconv",
        upsample_mode="convtranspose",
        pretrained="download",
        train=False,
        device=None,
    ):
        super(DRUNet, self).__init__()
        in_channels = in_channels + 1  # accounts for the input noise channel
        self.m_head = conv(in_channels, nc[0], bias=False, mode="C")

        # downsample
        if downsample_mode == "avgpool":
            downsample_block = downsample_avgpool
        elif downsample_mode == "maxpool":
            downsample_block = downsample_maxpool
        elif downsample_mode == "strideconv":
            downsample_block = downsample_strideconv
        else:
            raise NotImplementedError(
                "downsample mode [{:s}] is not found".format(downsample_mode)
            )

        self.m_down1 = sequential(
            *[
                ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ],
            downsample_block(nc[0], nc[1], bias=False, mode="2"),
        )
        self.m_down2 = sequential(
            *[
                ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ],
            downsample_block(nc[1], nc[2], bias=False, mode="2"),
        )
        self.m_down3 = sequential(
            *[
                ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ],
            downsample_block(nc[2], nc[3], bias=False, mode="2"),
        )

        self.m_body = sequential(
            *[
                ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ]
        )

        # upsample
        if upsample_mode == "upconv":
            upsample_block = upsample_upconv
        elif upsample_mode == "pixelshuffle":
            upsample_block = upsample_pixelshuffle
        elif upsample_mode == "convtranspose":
            upsample_block = upsample_convtranspose
        else:
            raise NotImplementedError(
                "upsample mode [{:s}] is not found".format(upsample_mode)
            )

        self.m_up3 = sequential(
            upsample_block(nc[3], nc[2], bias=False, mode="2"),
            *[
                ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ],
        )
        self.m_up2 = sequential(
            upsample_block(nc[2], nc[1], bias=False, mode="2"),
            *[
                ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ],
        )
        self.m_up1 = sequential(
            upsample_block(nc[1], nc[0], bias=False, mode="2"),
            *[
                ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ],
        )

        self.m_tail = conv(nc[0], out_channels, bias=False, mode="C")
        if pretrained is not None:
            if pretrained == "download":
                # in_channels was already incremented by the noise channel, so
                # 4 corresponds to a color (3-channel) model and 2 to a gray
                # (1-channel) model.
                if in_channels == 4:
                    name = "drunet_deepinv_color.pth"
                elif in_channels == 2:
                    name = "drunet_deepinv_gray.pth"
                else:
                    # BUG FIX: previously `name` was left undefined here, so an
                    # unsupported channel count raised an opaque NameError.
                    raise ValueError(
                        "pretrained='download' is only available for 1 or 3 input channels"
                    )
                url = get_weights_url(model_name="drunet", file_name=name)
                ckpt_drunet = torch.hub.load_state_dict_from_url(
                    url, map_location=lambda storage, loc: storage, file_name=name
                )
            else:
                ckpt_drunet = torch.load(
                    pretrained, map_location=lambda storage, loc: storage
                )

            self.load_state_dict(ckpt_drunet, strict=True)

            if not train:
                # Freeze the network when used as a fixed (plug-and-play) denoiser.
                self.eval()
                for _, v in self.named_parameters():
                    v.requires_grad = False
        else:
            self.apply(weights_init_drunet)

        if device is not None:
            self.to(device)

    def forward_unet(self, x0):
        # Standard U-Net pass with additive (residual) skip connections.
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)
        return x

    def forward(self, x, sigma):
        r"""
        Run the denoiser on image with noise level :math:`\sigma`.

        :param torch.Tensor x: noisy image
        :param float, torch.Tensor sigma: noise level. If ``sigma`` is a float, it is used for all images in the batch.
            If ``sigma`` is a tensor, it must be of shape ``(batch_size,)``.
        """
        if isinstance(sigma, torch.Tensor):
            if len(sigma.size()) > 0:
                # Batched noise levels. Under nn.DataParallel each replica only
                # sees its shard of the batch, so slice sigma accordingly.
                if x.get_device() > -1:
                    sigma = sigma[
                        int(x.get_device() * x.shape[0]) : int(
                            (x.get_device() + 1) * x.shape[0]
                        )
                    ]
                    noise_level_map = sigma.to(x.device)
                else:
                    noise_level_map = sigma.view(x.size(0), 1, 1, 1).to(x.device)
                # NOTE(review): on the DataParallel path the map is still 1-D
                # here; expand assumes a (B, 1, 1, 1) tensor — confirm upstream.
                noise_level_map = noise_level_map.expand(-1, 1, x.size(2), x.size(3))
            else:
                # 0-dim tensor: treat as one scalar shared by the whole batch.
                sigma = sigma.item()
                noise_level_map = (
                    torch.FloatTensor(x.size(0), 1, x.size(2), x.size(3))
                    .fill_(sigma)
                    .to(x.device)
                )
        else:
            noise_level_map = (
                torch.FloatTensor(x.size(0), 1, x.size(2), x.size(3))
                .fill_(sigma)
                .to(x.device)
            )
        x = torch.cat((x, noise_level_map), 1)
        # The U-Net needs spatial sizes that are multiples of 8 and >= 32;
        # otherwise pad (small images) or split into quadrants (unaligned ones).
        if self.training or (
            x.size(2) % 8 == 0
            and x.size(3) % 8 == 0
            and x.size(2) > 31
            and x.size(3) > 31
        ):
            x = self.forward_unet(x)
        elif x.size(2) < 32 or x.size(3) < 32:
            x = test_pad(self.forward_unet, x, modulo=16)
        else:
            x = test_onesplit(self.forward_unet, x, refield=64)
        return x
223
+
224
+
225
+ """
226
+ Functional blocks below
227
+ """
228
+ from collections import OrderedDict
229
+ import torch
230
+ import torch.nn as nn
231
+
232
+
233
+ """
234
+ # --------------------------------------------
235
+ # Advanced nn.Sequential
236
+ # https://github.com/xinntao/BasicSR
237
+ # --------------------------------------------
238
+ """
239
+
240
+
241
def sequential(*args):
    """Flatten the given modules into a single ``nn.Sequential``.

    Any ``nn.Sequential`` argument is unpacked into its children; plain
    ``nn.Module`` arguments are kept as-is. A single non-Sequential argument
    is returned unchanged (no wrapping needed).

    Raises:
        NotImplementedError: if a single ``OrderedDict`` is passed.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return only  # No sequential is needed.
    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)
260
+
261
+
262
+ """
263
+ # --------------------------------------------
264
+ # Useful blocks
265
+ # https://github.com/xinntao/BasicSR
266
+ # --------------------------------
267
+ # conv + normaliation + relu (conv)
268
+ # (PixelUnShuffle)
269
+ # (ConditionalBatchNorm2d)
270
+ # concat (ConcatBlock)
271
+ # sum (ShortcutBlock)
272
+ # resblock (ResBlock)
273
+ # Channel Attention (CA) Layer (CALayer)
274
+ # Residual Channel Attention Block (RCABlock)
275
+ # Residual Channel Attention Group (RCAGroup)
276
+ # Residual Dense Block (ResidualDenseBlock_5C)
277
+ # Residual in Residual Dense Block (RRDB)
278
+ # --------------------------------------------
279
+ """
280
+
281
+
282
+ # --------------------------------------------
283
+ # return nn.Sequantial of (Conv + BN + ReLU)
284
+ # --------------------------------------------
285
def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
    negative_slope=0.2,
):
    """Build an ``nn.Sequential`` from a ``mode`` string, one layer per character.

    Supported characters:
        ``C`` Conv2d, ``T`` ConvTranspose2d, ``B`` BatchNorm2d, ``I`` InstanceNorm2d,
        ``R``/``r`` ReLU (inplace / not), ``L``/``l`` LeakyReLU (inplace / not),
        ``E`` ELU, ``s`` Softplus, ``2``/``3``/``4`` PixelShuffle (x2/x3/x4),
        ``U``/``u``/``v`` nearest Upsample (x2/x3/x4), ``M`` MaxPool2d, ``A`` AvgPool2d.

    :raises NotImplementedError: on an unknown mode character.
    """
    L = []
    for t in mode:
        if t == "C":
            L.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "T":
            L.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "B":
            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == "I":
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == "R":
            L.append(nn.ReLU(inplace=True))
        elif t == "r":
            L.append(nn.ReLU(inplace=False))
        elif t == "L":
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        elif t == "l":
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
        elif t == "E":
            L.append(nn.ELU(inplace=False))
        elif t == "s":
            L.append(nn.Softplus())
        elif t == "2":
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == "3":
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == "4":
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == "U":
            L.append(nn.Upsample(scale_factor=2, mode="nearest"))
        elif t == "u":
            L.append(nn.Upsample(scale_factor=3, mode="nearest"))
        elif t == "v":
            L.append(nn.Upsample(scale_factor=4, mode="nearest"))
        elif t == "M":
            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == "A":
            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            # BUG FIX: the message previously had no "{}" placeholder, so the
            # offending character was silently dropped from the error text.
            raise NotImplementedError("Undefined type: {}".format(t))
    return sequential(*L)
354
+
355
+
356
+ # --------------------------------------------
357
+ # Res Block: x + conv(relu(conv(x)))
358
+ # --------------------------------------------
359
class ResBlock(nn.Module):
    """Residual block ``x + res(x)``, where ``res`` is built from ``mode``
    (default ``"CRC"``: conv -> ReLU -> conv)."""

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        mode="CRC",
        negative_slope=0.2,
    ):
        super(ResBlock, self).__init__()

        assert in_channels == out_channels, "Only support in_channels==out_channels."
        # A leading in-place activation would overwrite the residual input
        # before the skip connection, so demote it to the non-inplace variant.
        if mode[0] in ["R", "L"]:
            mode = mode[0].lower() + mode[1:]

        self.res = conv(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias,
            mode,
            negative_slope,
        )

    def forward(self, x):
        return x + self.res(x)
391
+
392
+
393
+ """
394
+ # --------------------------------------------
395
+ # Upsampler
396
+ # Kai Zhang, https://github.com/cszn/KAIR
397
+ # --------------------------------------------
398
+ # upsample_pixelshuffle
399
+ # upsample_upconv
400
+ # upsample_convtranspose
401
+ # --------------------------------------------
402
+ """
403
+
404
+
405
+ # --------------------------------------------
406
+ # conv + subp (+ relu)
407
+ # --------------------------------------------
408
def upsample_pixelshuffle(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Upsampling via conv + PixelShuffle (mode[0] gives the factor: 2/3/4)."""
    assert len(mode) < 4 and mode[0] in (
        "2",
        "3",
        "4",
    ), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    # The conv produces factor**2 times the channels; PixelShuffle ("2"/"3"/"4"
    # in the mode string) then trades them for spatial resolution.
    return conv(
        in_channels,
        out_channels * factor**2,
        kernel_size,
        stride,
        padding,
        bias,
        mode="C" + mode,
        negative_slope=negative_slope,
    )
434
+
435
+
436
+ # --------------------------------------------
437
+ # nearest_upsample + conv (+ R)
438
+ # --------------------------------------------
439
def upsample_upconv(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Upsampling via nearest-neighbour Upsample followed by a conv."""
    assert len(mode) < 4 and mode[0] in (
        "2",
        "3",
        "4",
    ), "mode examples: 2, 2R, 2BR, 3, ..., 4BR"
    # Map the scale digit to the matching Upsample+Conv mode characters.
    head = {"2": "UC", "3": "uC", "4": "vC"}[mode[0]]
    mode = mode.replace(mode[0], head)
    return conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=mode,
        negative_slope=negative_slope,
    )
472
+
473
+
474
+ # --------------------------------------------
475
+ # convTranspose (+ relu)
476
+ # --------------------------------------------
477
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Upsampling via ConvTranspose2d; mode[0] sets both kernel size and stride
    (the ``kernel_size``/``stride`` arguments are intentionally overridden)."""
    assert len(mode) < 4 and mode[0] in (
        "2",
        "3",
        "4",
    ), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    mode = mode.replace(mode[0], "T")
    return conv(
        in_channels,
        out_channels,
        factor,
        factor,
        padding,
        bias,
        mode,
        negative_slope,
    )
506
+
507
+
508
+ """
509
+ # --------------------------------------------
510
+ # Downsampler
511
+ # Kai Zhang, https://github.com/cszn/KAIR
512
+ # --------------------------------------------
513
+ # downsample_strideconv
514
+ # downsample_maxpool
515
+ # downsample_avgpool
516
+ # --------------------------------------------
517
+ """
518
+
519
+
520
+ # --------------------------------------------
521
+ # strideconv (+ relu)
522
+ # --------------------------------------------
523
def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Downsampling via a strided conv; mode[0] sets both kernel size and stride
    (the ``kernel_size``/``stride`` arguments are intentionally overridden)."""
    assert len(mode) < 4 and mode[0] in (
        "2",
        "3",
        "4",
    ), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    mode = mode.replace(mode[0], "C")
    return conv(
        in_channels,
        out_channels,
        factor,
        factor,
        padding,
        bias,
        mode,
        negative_slope,
    )
552
+
553
+
554
+ # --------------------------------------------
555
+ # maxpooling + conv (+ relu)
556
+ # --------------------------------------------
557
def downsample_maxpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Downsampling via MaxPool2d (factor mode[0]) followed by a conv block."""
    assert len(mode) < 4 and mode[0] in (
        "2",
        "3",
    ), "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    pool_size = int(mode[0])
    # After replacement, mode[0] == "M" drives the pooling layer and mode[1:]
    # ("C" plus any activations) drives the trailing conv block.
    mode = mode.replace(mode[0], "MC")
    pool = conv(
        kernel_size=pool_size,
        stride=pool_size,
        mode=mode[0],
        negative_slope=negative_slope,
    )
    tail = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=mode[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, tail)
591
+
592
+
593
+ # --------------------------------------------
594
+ # averagepooling + conv (+ relu)
595
+ # --------------------------------------------
596
def downsample_avgpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Downsampling via AvgPool2d (factor mode[0]) followed by a conv block."""
    assert len(mode) < 4 and mode[0] in (
        "2",
        "3",
    ), "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    pool_size = int(mode[0])
    # After replacement, mode[0] == "A" drives the pooling layer and mode[1:]
    # ("C" plus any activations) drives the trailing conv block.
    mode = mode.replace(mode[0], "AC")
    pool = conv(
        kernel_size=pool_size,
        stride=pool_size,
        mode=mode[0],
        negative_slope=negative_slope,
    )
    tail = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=mode[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, tail)
630
+
631
+
632
+ """
633
+ Helpers for test time
634
+ """
635
+
636
+
637
+ def test_onesplit(model, L, refield=32, sf=1):
638
+ """
639
+ Changes the size of the image to fit the model's expected image size.
640
+
641
+ :param model: model.
642
+ :param L: input Low-quality image.
643
+ :param refield: effective receptive field of the network, 32 is enough.
644
+ :param sf: scale factor for super-resolution, otherwise 1.
645
+ """
646
+ h, w = L.size()[-2:]
647
+ top = slice(0, (h // 2 // refield + 1) * refield)
648
+ bottom = slice(h - (h // 2 // refield + 1) * refield, h)
649
+ left = slice(0, (w // 2 // refield + 1) * refield)
650
+ right = slice(w - (w // 2 // refield + 1) * refield, w)
651
+ Ls = [
652
+ L[..., top, left],
653
+ L[..., top, right],
654
+ L[..., bottom, left],
655
+ L[..., bottom, right],
656
+ ]
657
+ Es = [model(Ls[i]) for i in range(4)]
658
+ b, c = Es[0].size()[:2]
659
+ E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
660
+ E[..., : h // 2 * sf, : w // 2 * sf] = Es[0][..., : h // 2 * sf, : w // 2 * sf]
661
+ E[..., : h // 2 * sf, w // 2 * sf : w * sf] = Es[1][
662
+ ..., : h // 2 * sf, (-w + w // 2) * sf :
663
+ ]
664
+ E[..., h // 2 * sf : h * sf, : w // 2 * sf] = Es[2][
665
+ ..., (-h + h // 2) * sf :, : w // 2 * sf
666
+ ]
667
+ E[..., h // 2 * sf : h * sf, w // 2 * sf : w * sf] = Es[3][
668
+ ..., (-h + h // 2) * sf :, (-w + w // 2) * sf :
669
+ ]
670
+ return E
671
+
672
+
673
+ def test_pad(model, L, modulo=16):
674
+ """
675
+ Pads the image to fit the model's expected image size.
676
+ """
677
+ h, w = L.size()[-2:]
678
+ padding_bottom = int(np.ceil(h / modulo) * modulo - h)
679
+ padding_right = int(np.ceil(w / modulo) * modulo - w)
680
+ L = torch.nn.ReplicationPad2d((0, padding_right, 0, padding_bottom))(L)
681
+ E = model(L)
682
+ E = E[..., :h, :w]
683
+ return E
684
+
685
+
686
def weights_init_drunet(m):
    """Orthogonal init (gain 0.2) for Conv-like modules; use with ``model.apply``."""
    if "Conv" in m.__class__.__name__:
        nn.init.orthogonal_(m.weight.data, gain=0.2)