deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1140 @@
# This file is taken (with only mild modifications) from the SwinIR repository:
# https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

# Compatibility with optional dependency on timm
try:
    import timm
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_
except ImportError as e:
    timm = e


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(
        B, H // window_size, W // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


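# A minimal round-trip sketch (added for illustration; not part of the upstream
# file): window_partition and window_reverse are exact inverses whenever H and W
# are divisible by window_size, which SwinIR guarantees by padding inputs in
# check_image_size below.
def _example_window_roundtrip():
    x = torch.randn(2, 32, 32, 96)  # B, H, W, C with H, W multiples of 8
    windows = window_partition(x, 8)  # (2 * 16, 8, 8, 96): 16 windows per image
    assert torch.equal(window_reverse(windows, 8, 32, 32), x)

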
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


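# Illustrative sketch (not from the upstream file; requires the optional timm
# dependency, like the classes themselves): window attention operates on
# flattened windows of N = Wh*Ww tokens and preserves the (num_windows*B, N, C)
# shape; dim must be divisible by num_heads.
def _example_window_attention():
    attn = WindowAttention(dim=96, window_size=(8, 8), num_heads=6)
    tokens = torch.randn(4, 64, 96)  # 4 windows of 8 * 8 = 64 tokens each
    out = attn(tokens)  # no mask: plain W-MSA
    assert out.shape == tokens.shape

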
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert (
            0 <= self.shift_size < self.window_size
        ), "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        return attn_mask

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size)
        if self.input_resolution == x_size:
            attn_windows = self.attn(
                x_windows, mask=self.attn_mask
            )  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(
                x_windows, mask=self.calculate_mask(x_size).to(x.device)
            )

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


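# Illustrative sketch (not from the upstream file): a block consumes tokens of
# shape (B, H*W, C) plus the spatial size x_size; a non-zero shift_size turns
# W-MSA into SW-MSA using the precomputed attention mask.
def _example_swin_block():
    block = SwinTransformerBlock(
        dim=96, input_resolution=(32, 32), num_heads=6, window_size=8, shift_size=4
    )
    x = torch.randn(2, 32 * 32, 96)
    out = block(x, x_size=(32, 32))  # same shape as the input tokens
    assert out.shape == x.shape

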
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


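# Illustrative sketch (not from the upstream file): patch merging trades spatial
# resolution for channels, mapping (B, H*W, C) to (B, H/2 * W/2, 2*C). It is
# unused by the default SwinIR configuration (which passes downsample=None) but
# kept for completeness.
def _example_patch_merging():
    merge = PatchMerging(input_resolution=(32, 32), dim=96)
    out = merge(torch.randn(2, 32 * 32, 96))
    assert out.shape == (2, 16 * 16, 192)

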
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size (int): Input image size. Default: 224.
        patch_size (int): Patch size. Default: 4.
        resi_connection (str): The convolutional block before the residual connection, '1conv' or '3conv'.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        img_size=224,
        patch_size=4,
        resi_connection="1conv",
    ):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint,
        )

        if resi_connection == "1conv":
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == "3conv":
            # to save parameters and memory
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1),
            )

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None,
        )

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None,
        )

    def forward(self, x, x_size):
        return (
            self.patch_embed(
                self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
            )
            + x
        )

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops


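# Illustrative sketch (not from the upstream file): an RSTB wraps a BasicLayer
# with a convolution (applied in image layout via PatchUnEmbed/PatchEmbed, both
# defined below) and a residual connection, so the token shape is preserved.
def _example_rstb():
    rstb = RSTB(
        dim=96,
        input_resolution=(32, 32),
        depth=2,
        num_heads=6,
        window_size=8,
        img_size=32,
        patch_size=1,
    )
    x = torch.randn(2, 32 * 32, 96)
    assert rstb(x, x_size=(32, 32)).shape == x.shape

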
class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        H, W = self.img_size
        if self.norm is not None:
            flops += H * W * self.embed_dim
        return flops


class PatchUnEmbed(nn.Module):
    r"""Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C Ph Pw
        return x

    def flops(self):
        flops = 0
        return flops


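# Illustrative sketch (not from the upstream file): with patch_size=1, as used
# by SwinIR, PatchEmbed and PatchUnEmbed simply move between the image layout
# (B, C, H, W) and the token layout (B, H*W, C), and are exact inverses.
def _example_patch_embed_roundtrip():
    embed = PatchEmbed(img_size=32, patch_size=1, embed_dim=96)
    unembed = PatchUnEmbed(img_size=32, patch_size=1, embed_dim=96)
    x = torch.randn(2, 96, 32, 32)
    tokens = embed(x)  # (2, 32 * 32, 96)
    assert torch.equal(unembed(tokens, (32, 32)), x)

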
class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f"scale {scale} is not supported. Supported scales: 2^n and 3."
            )
        super(Upsample, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (unlike Upsample, it always consists of a single
    convolution followed by a single PixelShuffle). Used in lightweight SR to
    save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
        num_out_ch (int): Channel number of the output.
        input_resolution (tuple[int] | None): Input resolution, used only by flops(). Default: None.
    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = []
        m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        super(UpsampleOneStep, self).__init__(*m)

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.num_feat * 3 * 9
        return flops


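# Illustrative sketch (not from the upstream file): x4 upsampling is built from
# two conv + PixelShuffle(2) stages; spatial dimensions quadruple while the
# channel count is preserved.
def _example_upsample():
    up = Upsample(scale=4, num_feat=64)
    out = up(torch.randn(1, 64, 16, 16))
    assert out.shape == (1, 64, 64, 64)

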
class SwinIR(nn.Module):
    r"""SwinIR denoising network.

    The Swin Image Restoration (SwinIR) denoising network was introduced in `SwinIR: Image Restoration Using Swin
    Transformer <https://arxiv.org/abs/2108.10257>`_. This code is adapted from the official implementation by the
    authors.

    :param int|tuple img_size: Input image size. Default: 128.
    :param int|tuple patch_size: Patch size. Default: 1.
    :param int in_chans: Number of input image channels. Default: 3.
    :param int embed_dim: Patch embedding dimension. Default: 180.
    :param tuple depths: Depth of each Swin Transformer layer.
    :param tuple num_heads: Number of attention heads in different layers.
    :param int window_size: Window size. Default: 8.
    :param float mlp_ratio: Ratio of mlp hidden dim to embedding dim. Default: 2.
    :param bool qkv_bias: If True, add a learnable bias to query, key, value. Default: True.
    :param float qk_scale: Override default qk scale of head_dim ** -0.5 if set. Default: None.
    :param float drop_rate: Dropout rate. Default: 0.
    :param float attn_drop_rate: Attention dropout rate. Default: 0.
    :param float drop_path_rate: Stochastic depth rate. Default: 0.1.
    :param nn.Module norm_layer: Normalization layer. Default: nn.LayerNorm.
    :param bool ape: If True, add absolute position embedding to the patch embedding. Default: False.
    :param bool patch_norm: If True, add normalization after patch embedding. Default: True.
    :param bool use_checkpoint: Whether to use checkpointing to save memory. Default: False.
    :param int upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction.
    :param float img_range: Image range. 1. or 255. Default: 1.
    :param str|None upsampler: The reconstruction module. ''/'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None.
        Default: ''.
    :param str resi_connection: The convolutional block before residual connection. Should be either '1conv' or '3conv'.
        Default: '1conv'.
    :param str|None pretrained: Use a pretrained network. If ``pretrained=None``, the weights will be initialized at
        random using PyTorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from
        the authors' online repository https://github.com/JingyunLiang/SwinIR/releases/tag/v0.0 (only available for the
        default architecture). Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
        Default: 'download'.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param int pretrained_noise_level: The noise level of the pretrained model to be downloaded (in 0-255 scale). This
        value is directly concatenated to the download url; should be chosen in the set {15, 25, 50}. Default: 15.
    """

    def __init__(
        self,
        img_size=128,
        patch_size=1,
        in_chans=3,
        embed_dim=180,
        depths=[6, 6, 6, 6, 6, 6],
        num_heads=[6, 6, 6, 6, 6, 6],
        window_size=8,
        mlp_ratio=2,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        upscale=1,
        img_range=1.0,
        upsampler="",
        resi_connection="1conv",
        pretrained="download",
        pretrained_noise_level=15,
        **kwargs,
    ):
        if isinstance(timm, ImportError):
            raise ImportError(
                "timm is needed to use the SwinIR class. Please install it with `pip install timm`"
            ) from timm

        super(SwinIR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
        )

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim)
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[
                    sum(depths[:i_layer]) : sum(depths[: i_layer + 1])
                ],  # no impact on SR results
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection,
            )
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == "1conv":
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == "3conv":
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
            )

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == "pixelshuffle":
            # for classical SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
            )
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == "pixelshuffledirect":
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(
                upscale,
                embed_dim,
                num_out_ch,
                (patches_resolution[0], patches_resolution[1]),
            )
        elif self.upsampler == "nearest+conv":
            # for real-world SR (fewer artifacts)
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
            )
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            if self.upscale == 4:
                self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

        if pretrained is not None:
            if pretrained == "download":
                assert img_size == 128
                assert in_chans in [1, 3]
                assert upscale == 1
                assert window_size == 8
                assert img_range == 1.0
                assert embed_dim == 180
                assert mlp_ratio == 2
                assert upsampler == ""
                assert resi_connection == "1conv"
                assert depths == [6, 6, 6, 6, 6, 6]
                assert num_heads == [6, 6, 6, 6, 6, 6]

                if in_chans == 1:
                    weights_url = (
                        "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/004_grayDN_DFWB_s128w8_SwinIR-M_noise"
                        + str(pretrained_noise_level)
                        + ".pth"
                    )
                elif in_chans == 3:
                    weights_url = (
                        "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/005_colorDN_DFWB_s128w8_SwinIR-M_noise"
                        + str(pretrained_noise_level)
                        + ".pth"
                    )

                pretrained_weights = torch.hub.load_state_dict_from_url(
                    weights_url, map_location=lambda storage, loc: storage
                )
            else:
                pretrained_weights = torch.load(
                    pretrained, map_location=lambda storage, loc: storage
                )
            param_key_g = "params"
            pretrained_weights = (
                pretrained_weights[param_key_g]
                if param_key_g in pretrained_weights.keys()
                else pretrained_weights
            )
            self.load_state_dict(pretrained_weights, strict=True)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"relative_position_bias_table"}

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
        return x

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x, sigma=None):
        r"""
        Run the denoiser on a noisy image. The noise level is not used in this denoiser.

        :param torch.Tensor x: noisy image, of shape (B, C, H, W).
        :param float sigma: noise level (not used).
        """
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == "pixelshuffle":
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == "pixelshuffledirect":
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == "nearest+conv":
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(
                self.conv_up1(
                    torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
                )
            )
            if self.upscale == 4:
                x = self.lrelu(
                    self.conv_up2(
                        torch.nn.functional.interpolate(
                            x, scale_factor=2, mode="nearest"
                        )
                    )
                )
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean

        return x[:, :, : H * self.upscale, : W * self.upscale]

    def flops(self):
        flops = 0
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops
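

# Illustrative usage sketch (not part of the upstream file): build a small,
# randomly initialised SwinIR denoiser (pretrained=None avoids any download,
# since the downloadable weights only exist for the default architecture) and
# run it on a noisy image; with upscale=1 the output matches the input shape.
def _example_swinir_denoise():
    model = SwinIR(
        img_size=64,
        in_chans=3,
        embed_dim=60,
        depths=[2, 2],
        num_heads=[6, 6],
        window_size=8,
        pretrained=None,
    )
    model.eval()  # disable stochastic depth for deterministic inference
    x = torch.rand(1, 3, 64, 64)
    y = x + 0.1 * torch.randn_like(x)  # additive Gaussian noise
    with torch.no_grad():
        x_hat = model(y)
    assert x_hat.shape == x.shape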