deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/swinir.py
ADDED
|
@@ -0,0 +1,1140 @@
|
|
|
1
|
+
# This file is taken (with only mild modifications) from the SwinIR repository:
|
|
2
|
+
# https://github.com/JingyunLiang/SwinIR/blob/main/models/network_swinir.py
|
|
3
|
+
# -----------------------------------------------------------------------------------
|
|
4
|
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
|
5
|
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
|
6
|
+
# -----------------------------------------------------------------------------------
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
import torch.utils.checkpoint as checkpoint
|
|
13
|
+
|
|
14
|
+
# Compatibility with optional dependency on timm
|
|
15
|
+
try:
|
|
16
|
+
import timm
|
|
17
|
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
|
18
|
+
except ImportError as e:
|
|
19
|
+
timm = e
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Mlp(nn.Module):
    """Two-layer feed-forward network (Linear -> activation -> dropout -> Linear -> dropout).

    Used as the FFN sub-block of a Swin Transformer block.

    Args:
        in_features (int): Input feature dimension.
        hidden_features (int, optional): Hidden dimension; defaults to ``in_features``.
        out_features (int, optional): Output dimension; defaults to ``in_features``.
        act_layer (nn.Module, optional): Activation class. Default: ``nn.GELU``.
        drop (float, optional): Dropout probability applied after each linear layer. Default: 0.0.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        # fall back to in_features when a dimension is not given
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP to *x* (shape ``(..., in_features)``)."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C) tensor; H and W must be divisible by ``window_size``.
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    # expose the window grid as explicit axes, then fold it into the batch axis
    tiled = x.view(B, rows, window_size, cols, window_size, C)
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, C)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by :func:`window_partition` into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    # recover the batch size from the total window count
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    rows, cols = H // window_size, W // window_size
    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Attention is computed independently within each (Wh x Ww) window; a learned
    relative position bias is added to the attention logits before softmax.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # standard scaled-dot-product scaling unless overridden by qk_scale
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias:
        # one learnable bias per head for each possible (dy, dx) offset
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # scale the row offset so (row, col) pairs map to unique flat table indices
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # fixed (non-learnable) lookup table -> register as a buffer, not a Parameter
        self.register_buffer("relative_position_index", relative_position_index)

        # single projection producing q, k and v concatenated along the channel dim
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # add the learned relative position bias, broadcast over the batch dim
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # per-window additive mask (large negative values suppress cross-region
            # attention after softmax); reshape to broadcast over heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    One (shifted-)window attention sub-block followed by an MLP sub-block,
    each wrapped in LayerNorm + residual connection with optional DropPath.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert (
            0 <= self.shift_size < self.window_size
        ), "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # stochastic depth: identity when drop_path is 0
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        # the attention mask only depends on resolution/shift, so it can be
        # precomputed for the nominal input resolution
        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA: label the 9 regions created by the
        # cyclic shift, then forbid attention between tokens from different regions
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # nonzero difference => token pair belongs to different regions
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # -100 acts as -inf after softmax; 0 keeps the logit unchanged
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
            attn_mask == 0, float(0.0)
        )

        return attn_mask

    def forward(self, x, x_size):
        # x: (B, H*W, C) token sequence; x_size: actual (H, W) of the feature map
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            # nominal resolution: reuse the precomputed mask buffer
            attn_windows = self.attn(
                x_windows, mask=self.attn_mask
            )  # nW*B, window_size*window_size, C
        else:
            # different resolution at inference time: rebuild the mask on the fly
            attn_windows = self.attn(
                x_windows, mask=self.calculate_mask(x_size).to(x.device)
            )

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Downsamples a token grid by 2x in each spatial direction: the four pixels
    of every 2x2 neighbourhood are concatenated along channels (C -> 4C), then
    normalized and linearly reduced to 2C.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # the four corners of each 2x2 neighbourhood, stacked along channels
        corners = [grid[:, i::2, j::2, :] for i, j in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        # norm over the full map, plus the 4C -> 2C reduction on the merged grid
        return H * W * self.dim + (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    A stack of ``depth`` SwinTransformerBlocks (alternating W-MSA and SW-MSA)
    optionally followed by a downsampling (patch merging) layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even-indexed blocks use no shift (W-MSA), odd-indexed
        # blocks use a half-window shift (SW-MSA)
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    # drop_path may be a per-block schedule (list) or a scalar
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x, x_size):
        # x: (B, H*W, C) tokens; x_size: actual (H, W) passed through to each block
        for blk in self.blocks:
            if self.use_checkpoint:
                # gradient checkpointing: trade recomputation for activation memory
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    A :class:`BasicLayer` whose output tokens are un-embedded to a feature map,
    passed through a convolution, re-embedded, and added back to the input
    (residual connection).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection
            ("1conv" or "3conv").

    Raises:
        ValueError: If ``resi_connection`` is not ``"1conv"`` or ``"3conv"``.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        img_size=224,
        patch_size=4,
        resi_connection="1conv",
    ):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint,
        )

        if resi_connection == "1conv":
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == "3conv":
            # to save parameters and memory: bottleneck 3x3 -> 1x1 -> 3x3
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1),
            )
        else:
            # Fail fast: previously an unrecognized value silently left
            # self.conv undefined, deferring the error to the first forward().
            raise ValueError(
                f"resi_connection must be '1conv' or '3conv', got {resi_connection!r}"
            )

        # embed/unembed convert between token sequences and 2D feature maps
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None,
        )

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None,
        )

    def forward(self, x, x_size):
        """Apply the residual group, conv, and residual add to tokens *x*."""
        return (
            self.patch_embed(
                self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
            )
            + x
        )

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        # the residual 3x3 convolution (dim -> dim)
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding

    Flattens a (B, C, H, W) feature map into a (B, H*W, C) token sequence,
    optionally followed by a normalization layer.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
    ):
        super().__init__()
        size_hw = to_2tuple(img_size)
        patch_hw = to_2tuple(patch_size)
        self.img_size = size_hw
        self.patch_size = patch_hw
        self.patches_resolution = [
            size_hw[0] // patch_hw[0],
            size_hw[1] // patch_hw[1],
        ]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Flatten spatial dims and move channels last: B C H W -> B H*W C."""
        tokens = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            tokens = self.norm(tokens)
        return tokens

    def flops(self):
        H, W = self.img_size
        # only the optional norm contributes; the flatten/transpose is free
        return H * W * self.embed_dim if self.norm is not None else 0
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
class PatchUnEmbed(nn.Module):
    r"""Image to Patch Unembedding

    Inverse of :class:`PatchEmbed`: reshapes a (B, H*W, C) token sequence back
    into a (B, C, H, W) feature map for a given spatial size.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
    ):
        super().__init__()
        size_hw = to_2tuple(img_size)
        patch_hw = to_2tuple(patch_size)
        self.img_size = size_hw
        self.patch_size = patch_hw
        self.patches_resolution = [
            size_hw[0] // patch_hw[0],
            size_hw[1] // patch_hw[1],
        ]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        """Reshape tokens back to a feature map: B H*W C -> B C x_size[0] x_size[1]."""
        batch = x.shape[0]
        return x.transpose(1, 2).view(batch, self.embed_dim, x_size[0], x_size[1])

    def flops(self):
        # pure reshape: no arithmetic
        return 0
|
|
739
|
+
|
|
740
|
+
|
|
741
|
+
class Upsample(nn.Sequential):
    """Upsample module built from conv + pixel-shuffle stages.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: if ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        if scale >= 1 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # Fix: use exact integer bit_length for the exponent instead of
            # int(math.log(scale, 2)) — the float log can round just below the
            # true exponent and silently drop an upsampling stage.
            for _ in range(scale.bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            # Also catches scale <= 0, which previously surfaced as an opaque
            # "math domain error" from math.log.
            raise ValueError(
                f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
            )
        super(Upsample, self).__init__(*m)
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
class UpsampleOneStep(nn.Sequential):
    """Single-stage upsampler: one conv followed by one pixel-shuffle.

    Unlike :class:`Upsample`, this always uses exactly 1 conv + 1 pixel-shuffle;
    used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        conv = nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)
        shuffle = nn.PixelShuffle(scale)
        super(UpsampleOneStep, self).__init__(conv, shuffle)

    def flops(self):
        """FLOPs of the 3x3 conv at the stored input resolution.

        NOTE(review): requires ``input_resolution`` to have been provided at
        construction (it defaults to None, which makes this raise) — verify
        callers always pass it.
        """
        height, width = self.input_resolution
        return height * width * self.num_feat * 3 * 9
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
class SwinIR(nn.Module):
|
|
790
|
+
r"""SwinIR denoising network.
|
|
791
|
+
|
|
792
|
+
The Swin Image Restoration (SwinIR) denoising network was introduced in `SwinIR: Image Restoration Using Swin
|
|
793
|
+
Transformer <https://arxiv.org/abs/2108.10257>`_. This code is adapted from the official implementation by the
|
|
794
|
+
authors.
|
|
795
|
+
|
|
796
|
+
:param int|tuple img_size: Input image size. Default 128.
|
|
797
|
+
:param int|tuple patch_size: Patch size. Default: 1.
|
|
798
|
+
:param int in_chans: Number of input image channels. Default: 3.
|
|
799
|
+
:param int embed_dim: Patch embedding dimension. Default: 180.
|
|
800
|
+
:param tuple depths: Depth of each Swin Transformer layer.
|
|
801
|
+
:param tuple num_heads: Number of attention heads in different layers.
|
|
802
|
+
:param int window_size: Window size. Default: 8.
|
|
803
|
+
:param float mlp_ratio: Ratio of mlp hidden dim to embedding dim. Default: 2.
|
|
804
|
+
:param bool qkv_bias: If True, add a learnable bias to query, key, value. Default: True.
|
|
805
|
+
:param float qk_scale: Override default qk scale of head_dim ** -0.5 if set. Default: None.
|
|
806
|
+
:param float drop_rate: Dropout rate. Default: 0.
|
|
807
|
+
:param float attn_drop_rate: Attention dropout rate. Default: 0.
|
|
808
|
+
:param float drop_path_rate: Stochastic depth rate. Default: 0.1.
|
|
809
|
+
:param nn.Module norm_layer: Normalization layer. Default: nn.LayerNorm.
|
|
810
|
+
:param bool ape: If True, add absolute position embedding to the patch embedding. Default: False.
|
|
811
|
+
:param bool patch_norm: If True, add normalization after patch embedding. Default: True.
|
|
812
|
+
:param bool use_checkpoint: Whether to use checkpointing to save memory. Default: False.
|
|
813
|
+
:param int upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
|
|
814
|
+
:param float img_range: Image range. 1. or 255. Default: 1.
|
|
815
|
+
:param str|None upsampler: The reconstruction module. ''/'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None.
|
|
816
|
+
Default: ''.
|
|
817
|
+
:param str resi_connection: The convolutional block before residual connection. Should be either '1conv' or '3conv'.
|
|
818
|
+
Default: '1conv'.
|
|
819
|
+
:param str|None pretrained: Use a pretrained network. If ``pretrained=None``, the weights will be initialized at
|
|
820
|
+
random using PyTorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from
|
|
821
|
+
the authors' online repository https://github.com/JingyunLiang/SwinIR/releases/tag/v0.0 (only available for the
|
|
822
|
+
default architecture). Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights.
|
|
823
|
+
Default: 'download'.
|
|
824
|
+
See :ref:`pretrained-weights <pretrained-weights>` for more details.
|
|
825
|
+
:param int pretrained_noise_level: The noise level of the pretrained model to be downloaded (in 0-255 scale). This
|
|
826
|
+
value is directly concatenated to the download url; should be chosen in the set {15, 25, 50}. Default: 15.
|
|
827
|
+
"""
|
|
828
|
+
|
|
829
|
+
def __init__(
    self,
    img_size=128,
    patch_size=1,
    in_chans=3,
    embed_dim=180,
    depths=[6, 6, 6, 6, 6, 6],
    num_heads=[6, 6, 6, 6, 6, 6],
    window_size=8,
    mlp_ratio=2,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.1,
    norm_layer=nn.LayerNorm,
    ape=False,
    patch_norm=True,
    use_checkpoint=False,
    upscale=1,
    img_range=1.0,
    upsampler="",
    resi_connection="1conv",
    pretrained="download",
    pretrained_noise_level=15,
    **kwargs,
):
    """Build the SwinIR network; see the class docstring for parameters."""
    # ``timm`` is imported lazily at module level; on failure the name is
    # bound to the ImportError instead of the module, so re-raise with help.
    if isinstance(timm, ImportError):
        raise ImportError(
            # Bug fix: the message previously named the SCUNet class.
            "timm is needed to use the SwinIR class. Please install it with `pip install timm`"
        ) from timm

    super(SwinIR, self).__init__()
    num_in_ch = in_chans
    num_out_ch = in_chans
    num_feat = 64
    self.img_range = img_range
    if in_chans == 3:
        # Per-channel RGB mean used to center inputs before processing.
        rgb_mean = (0.4488, 0.4371, 0.4040)
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
    else:
        self.mean = torch.zeros(1, 1, 1, 1)
    self.upscale = upscale
    self.upsampler = upsampler
    self.window_size = window_size

    #####################################################################################################
    ################################### 1, shallow feature extraction ###################################
    self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

    #####################################################################################################
    ################################### 2, deep feature extraction ######################################
    self.num_layers = len(depths)
    self.embed_dim = embed_dim
    self.ape = ape
    self.patch_norm = patch_norm
    self.num_features = embed_dim
    self.mlp_ratio = mlp_ratio

    # split image into non-overlapping patches
    self.patch_embed = PatchEmbed(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=embed_dim,
        embed_dim=embed_dim,
        norm_layer=norm_layer if self.patch_norm else None,
    )
    num_patches = self.patch_embed.num_patches
    patches_resolution = self.patch_embed.patches_resolution
    self.patches_resolution = patches_resolution

    # merge non-overlapping patches into image
    self.patch_unembed = PatchUnEmbed(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=embed_dim,
        embed_dim=embed_dim,
        norm_layer=norm_layer if self.patch_norm else None,
    )

    # absolute position embedding (optional; most configurations use the
    # relative bias inside the attention blocks instead)
    if self.ape:
        self.absolute_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim)
        )
        trunc_normal_(self.absolute_pos_embed, std=0.02)

    self.pos_drop = nn.Dropout(p=drop_rate)

    # stochastic depth: linearly increasing drop-path over all blocks
    dpr = [
        x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
    ]  # stochastic depth decay rule

    # build Residual Swin Transformer blocks (RSTB)
    self.layers = nn.ModuleList()
    for i_layer in range(self.num_layers):
        layer = RSTB(
            dim=embed_dim,
            input_resolution=(patches_resolution[0], patches_resolution[1]),
            depth=depths[i_layer],
            num_heads=num_heads[i_layer],
            window_size=window_size,
            mlp_ratio=self.mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr[
                sum(depths[:i_layer]) : sum(depths[: i_layer + 1])
            ],  # no impact on SR results
            norm_layer=norm_layer,
            downsample=None,
            use_checkpoint=use_checkpoint,
            img_size=img_size,
            patch_size=patch_size,
            resi_connection=resi_connection,
        )
        self.layers.append(layer)
    self.norm = norm_layer(self.num_features)

    # build the last conv layer in deep feature extraction
    if resi_connection == "1conv":
        self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
    elif resi_connection == "3conv":
        # to save parameters and memory
        self.conv_after_body = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
        )

    #####################################################################################################
    ################################ 3, high quality image reconstruction ################################
    if self.upsampler == "pixelshuffle":
        # for classical SR
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
        )
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
    elif self.upsampler == "pixelshuffledirect":
        # for lightweight SR (to save parameters)
        self.upsample = UpsampleOneStep(
            upscale,
            embed_dim,
            num_out_ch,
            (patches_resolution[0], patches_resolution[1]),
        )
    elif self.upsampler == "nearest+conv":
        # for real-world SR (less artifacts)
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
        )
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        if self.upscale == 4:
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    else:
        # for image denoising and JPEG compression artifact reduction
        self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

    self.apply(self._init_weights)

    if pretrained is not None:
        if pretrained == "download":
            # Downloaded weights exist only for the default architecture;
            # guard every hyperparameter that the checkpoint fixes.
            assert img_size == 128
            assert in_chans in [1, 3]
            assert upscale == 1
            assert window_size == 8
            assert img_range == 1.0
            assert embed_dim == 180
            assert mlp_ratio == 2
            assert upsampler == ""
            assert resi_connection == "1conv"
            assert depths == [6, 6, 6, 6, 6, 6]
            assert num_heads == [6, 6, 6, 6, 6, 6]

            # NOTE(review): pretrained_noise_level is concatenated into the
            # URL unvalidated; only {15, 25, 50} exist upstream — confirm.
            if in_chans == 1:
                weights_url = (
                    "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/004_grayDN_DFWB_s128w8_SwinIR-M_noise"
                    + str(pretrained_noise_level)
                    + ".pth"
                )
            elif in_chans == 3:
                weights_url = (
                    "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/005_colorDN_DFWB_s128w8_SwinIR-M_noise"
                    + str(pretrained_noise_level)
                    + ".pth"
                )

            pretrained_weights = torch.hub.load_state_dict_from_url(
                weights_url, map_location=lambda storage, loc: storage
            )
        else:
            # ``pretrained`` is treated as a local checkpoint path.
            pretrained_weights = torch.load(
                pretrained, map_location=lambda storage, loc: storage
            )
        # Official checkpoints nest the state dict under the "params" key.
        param_key_g = "params"
        pretrained_weights = (
            pretrained_weights[param_key_g]
            if param_key_g in pretrained_weights.keys()
            else pretrained_weights
        )
        self.load_state_dict(pretrained_weights, strict=True)
|
|
1038
|
+
|
|
1039
|
+
def _init_weights(self, m):
|
|
1040
|
+
if isinstance(m, nn.Linear):
|
|
1041
|
+
trunc_normal_(m.weight, std=0.02)
|
|
1042
|
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
1043
|
+
nn.init.constant_(m.bias, 0)
|
|
1044
|
+
elif isinstance(m, nn.LayerNorm):
|
|
1045
|
+
nn.init.constant_(m.bias, 0)
|
|
1046
|
+
nn.init.constant_(m.weight, 1.0)
|
|
1047
|
+
|
|
1048
|
+
@torch.jit.ignore
def no_weight_decay(self):
    """Parameter names to exclude from weight decay."""
    excluded = {"absolute_pos_embed"}
    return excluded
|
|
1051
|
+
|
|
1052
|
+
@torch.jit.ignore
def no_weight_decay_keywords(self):
    """Substrings of parameter names to exclude from weight decay."""
    keywords = {"relative_position_bias_table"}
    return keywords
|
|
1055
|
+
|
|
1056
|
+
def check_image_size(self, x):
|
|
1057
|
+
_, _, h, w = x.size()
|
|
1058
|
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
|
1059
|
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
|
1060
|
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
|
|
1061
|
+
return x
|
|
1062
|
+
|
|
1063
|
+
def forward_features(self, x):
|
|
1064
|
+
x_size = (x.shape[2], x.shape[3])
|
|
1065
|
+
x = self.patch_embed(x)
|
|
1066
|
+
if self.ape:
|
|
1067
|
+
x = x + self.absolute_pos_embed
|
|
1068
|
+
x = self.pos_drop(x)
|
|
1069
|
+
|
|
1070
|
+
for layer in self.layers:
|
|
1071
|
+
x = layer(x, x_size)
|
|
1072
|
+
|
|
1073
|
+
x = self.norm(x) # B L C
|
|
1074
|
+
x = self.patch_unembed(x, x_size)
|
|
1075
|
+
|
|
1076
|
+
return x
|
|
1077
|
+
|
|
1078
|
+
def forward(self, x, sigma=None):
|
|
1079
|
+
r"""
|
|
1080
|
+
Run the denoiser on noisy image. The noise level is not used in this denoiser.
|
|
1081
|
+
|
|
1082
|
+
:param torch.Tensor x: noisy image, of shape B, C, W, H.
|
|
1083
|
+
:param float sigma: noise level (not used).
|
|
1084
|
+
"""
|
|
1085
|
+
H, W = x.shape[2:]
|
|
1086
|
+
x = self.check_image_size(x)
|
|
1087
|
+
|
|
1088
|
+
self.mean = self.mean.type_as(x)
|
|
1089
|
+
x = (x - self.mean) * self.img_range
|
|
1090
|
+
|
|
1091
|
+
if self.upsampler == "pixelshuffle":
|
|
1092
|
+
# for classical SR
|
|
1093
|
+
x = self.conv_first(x)
|
|
1094
|
+
x = self.conv_after_body(self.forward_features(x)) + x
|
|
1095
|
+
x = self.conv_before_upsample(x)
|
|
1096
|
+
x = self.conv_last(self.upsample(x))
|
|
1097
|
+
elif self.upsampler == "pixelshuffledirect":
|
|
1098
|
+
# for lightweight SR
|
|
1099
|
+
x = self.conv_first(x)
|
|
1100
|
+
x = self.conv_after_body(self.forward_features(x)) + x
|
|
1101
|
+
x = self.upsample(x)
|
|
1102
|
+
elif self.upsampler == "nearest+conv":
|
|
1103
|
+
# for real-world SR
|
|
1104
|
+
x = self.conv_first(x)
|
|
1105
|
+
x = self.conv_after_body(self.forward_features(x)) + x
|
|
1106
|
+
x = self.conv_before_upsample(x)
|
|
1107
|
+
x = self.lrelu(
|
|
1108
|
+
self.conv_up1(
|
|
1109
|
+
torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
|
|
1110
|
+
)
|
|
1111
|
+
)
|
|
1112
|
+
if self.upscale == 4:
|
|
1113
|
+
x = self.lrelu(
|
|
1114
|
+
self.conv_up2(
|
|
1115
|
+
torch.nn.functional.interpolate(
|
|
1116
|
+
x, scale_factor=2, mode="nearest"
|
|
1117
|
+
)
|
|
1118
|
+
)
|
|
1119
|
+
)
|
|
1120
|
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
|
1121
|
+
else:
|
|
1122
|
+
# for image denoising and JPEG compression artifact reduction
|
|
1123
|
+
x_first = self.conv_first(x)
|
|
1124
|
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
|
1125
|
+
x = x + self.conv_last(res)
|
|
1126
|
+
|
|
1127
|
+
x = x / self.img_range + self.mean
|
|
1128
|
+
|
|
1129
|
+
return x[:, :, : H * self.upscale, : W * self.upscale]
|
|
1130
|
+
|
|
1131
|
+
def flops(self):
|
|
1132
|
+
flops = 0
|
|
1133
|
+
H, W = self.patches_resolution
|
|
1134
|
+
flops += H * W * 3 * self.embed_dim * 9
|
|
1135
|
+
flops += self.patch_embed.flops()
|
|
1136
|
+
for i, layer in enumerate(self.layers):
|
|
1137
|
+
flops += layer.flops()
|
|
1138
|
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
|
1139
|
+
flops += self.upsample.flops()
|
|
1140
|
+
return flops
|