deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
deepinv/models/drunet.py
ADDED
|
@@ -0,0 +1,689 @@
|
|
|
1
|
+
# Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from .utils import get_weights_url
|
|
7
|
+
|
|
8
|
+
# Module-level device flag and the legacy float tensor type matching it.
cuda = bool(torch.cuda.is_available())
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DRUNet(nn.Module):
    r"""
    DRUNet denoiser network.

    The network architecture is based on the paper
    `Learning deep CNN denoiser prior for image restoration <https://arxiv.org/abs/1704.03264>`_,
    and has a U-Net like structure, with convolutional blocks in the encoder and decoder parts.

    The network takes into account the noise level of the input image, which is encoded as an additional input channel.

    A pretrained network for (in_channels=out_channels=1 or in_channels=out_channels=3)
    can be downloaded via setting ``pretrained='download'``.

    :param int in_channels: number of channels of the input.
    :param int out_channels: number of channels of the output.
    :param list, tuple nc: number of feature channels at each of the four scales of the U-Net
        (``nc[0]`` at full resolution, ``nc[3]`` at the bottleneck).
    :param int nb: number of residual blocks per scale.
    :param str act_mode: activation used inside the residual blocks, "R" for ReLU, "L" for LeakyReLU,
        "E" for ELU and "s" for Softplus.
    :param str downsample_mode: Downsampling mode, "avgpool" for average pooling, "maxpool" for max pooling, and
        "strideconv" for convolution with stride 2.
    :param str upsample_mode: Upsampling mode, "convtranspose" for convolution transpose, "pixelshuffle" for pixel
        shuffling, and "upconv" for nearest neighbour upsampling with additional convolution.
    :param str, None pretrained: use a pretrained network. If ``pretrained=None``, the weights are initialized with
        ``weights_init_drunet`` (orthogonal initialization of the convolutions). If ``pretrained='download'``, the
        weights will be downloaded from an online repository (only available for the default architecture with
        3 or 1 input/output channels). Finally, ``pretrained`` can also be set as a path to the user's own
        pretrained weights. See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param bool train: training or testing mode. When pretrained weights are loaded and ``train=False``, the
        network is put in evaluation mode and its parameters are frozen (``requires_grad=False``).
    :param str device: gpu or cpu.

    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        nc=(64, 128, 256, 512),
        nb=4,
        act_mode="R",
        downsample_mode="strideconv",
        upsample_mode="convtranspose",
        pretrained="download",
        train=False,
        device=None,
    ):
        super().__init__()
        in_channels = in_channels + 1  # accounts for the input noise channel
        self.m_head = conv(in_channels, nc[0], bias=False, mode="C")

        def res_blocks(channels):
            # ``nb`` residual blocks (conv - activation - conv) at a fixed number of channels
            return [
                ResBlock(channels, channels, bias=False, mode="C" + act_mode + "C")
                for _ in range(nb)
            ]

        # downsample
        if downsample_mode == "avgpool":
            downsample_block = downsample_avgpool
        elif downsample_mode == "maxpool":
            # downsample_maxpool defaults to padding=0, so its 3x3 convolution would shrink the
            # feature maps by 2 pixels and break the U-Net skip connections: pad by 1 instead.
            def downsample_block(in_ch, out_ch, **kwargs):
                return downsample_maxpool(in_ch, out_ch, padding=1, **kwargs)

        elif downsample_mode == "strideconv":
            downsample_block = downsample_strideconv
        else:
            raise NotImplementedError(
                "downsample mode [{:s}] is not found".format(downsample_mode)
            )

        self.m_down1 = sequential(
            *res_blocks(nc[0]), downsample_block(nc[0], nc[1], bias=False, mode="2")
        )
        self.m_down2 = sequential(
            *res_blocks(nc[1]), downsample_block(nc[1], nc[2], bias=False, mode="2")
        )
        self.m_down3 = sequential(
            *res_blocks(nc[2]), downsample_block(nc[2], nc[3], bias=False, mode="2")
        )

        self.m_body = sequential(*res_blocks(nc[3]))

        # upsample
        if upsample_mode == "upconv":
            upsample_block = upsample_upconv
        elif upsample_mode == "pixelshuffle":
            upsample_block = upsample_pixelshuffle
        elif upsample_mode == "convtranspose":
            upsample_block = upsample_convtranspose
        else:
            raise NotImplementedError(
                "upsample mode [{:s}] is not found".format(upsample_mode)
            )

        self.m_up3 = sequential(
            upsample_block(nc[3], nc[2], bias=False, mode="2"), *res_blocks(nc[2])
        )
        self.m_up2 = sequential(
            upsample_block(nc[2], nc[1], bias=False, mode="2"), *res_blocks(nc[1])
        )
        self.m_up1 = sequential(
            upsample_block(nc[1], nc[0], bias=False, mode="2"), *res_blocks(nc[0])
        )

        self.m_tail = conv(nc[0], out_channels, bias=False, mode="C")

        if pretrained is not None:
            if pretrained == "download":
                # published checkpoints only exist for gray (1 + noise channel)
                # and color (3 + noise channel) inputs
                if in_channels == 4:
                    name = "drunet_deepinv_color.pth"
                elif in_channels == 2:
                    name = "drunet_deepinv_gray.pth"
                else:
                    raise ValueError(
                        "pretrained='download' is only available for 1 or 3 input "
                        "channels, got in_channels={}.".format(in_channels - 1)
                    )
                url = get_weights_url(model_name="drunet", file_name=name)
                ckpt_drunet = torch.hub.load_state_dict_from_url(
                    url, map_location=lambda storage, loc: storage, file_name=name
                )
            else:
                ckpt_drunet = torch.load(
                    pretrained, map_location=lambda storage, loc: storage
                )

            self.load_state_dict(ckpt_drunet, strict=True)

            if not train:
                self.eval()
                for param in self.parameters():
                    param.requires_grad = False
        else:
            self.apply(weights_init_drunet)

        if device is not None:
            self.to(device)

    def forward_unet(self, x0):
        """
        Run the U-Net on an input that already contains the noise-level channel.

        The input is passed through the network as is: no padding or cropping is done here.
        """
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        # additive skip connections between encoder and decoder levels
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)
        return x

    def forward(self, x, sigma):
        r"""
        Run the denoiser on image with noise level :math:`\sigma`.

        :param torch.Tensor x: noisy image of shape ``(batch_size, channels, height, width)``.
        :param float, torch.Tensor sigma: noise level. If ``sigma`` is a float (or a 0-dim tensor), it is used for
            all images in the batch. If ``sigma`` is a tensor, it must hold one value per image (``batch_size``
            elements, e.g. shape ``(batch_size,)``) or a single value shared by the whole batch.
        :return: (torch.Tensor) the denoised image.
        """
        batch, height, width = x.size(0), x.size(2), x.size(3)
        if isinstance(sigma, torch.Tensor) and sigma.dim() > 0:
            if x.get_device() > -1 and sigma.numel() > batch:
                # sigma holds more noise levels than ``x`` has images: presumably the full
                # batch while ``x`` is the chunk processed on this GPU (DataParallel-style),
                # so keep the slice that matches this device.
                start = x.get_device() * batch
                sigma = sigma[start : start + batch]
            # one noise level per image (or one shared value), broadcast over the image plane
            noise_level_map = sigma.reshape(-1, 1, 1, 1).to(
                device=x.device, dtype=x.dtype
            )
            noise_level_map = noise_level_map.expand(batch, 1, height, width)
        else:
            if isinstance(sigma, torch.Tensor):
                sigma = sigma.item()
            noise_level_map = torch.full(
                (batch, 1, height, width), sigma, dtype=x.dtype, device=x.device
            )
        x = torch.cat((x, noise_level_map), 1)

        def splittable(size):
            # test_onesplit(refield=64) feeds the U-Net crops of 64 * k pixels along a
            # dimension of at least 64 pixels; along a shorter dimension the crops are (part
            # of) the dimension itself, which the 3-level U-Net needs to be a multiple of 8.
            return size >= 64 or size % 8 == 0

        if self.training or (
            height % 8 == 0 and width % 8 == 0 and height > 31 and width > 31
        ):
            x = self.forward_unet(x)
        elif (
            height < 32
            or width < 32
            or not (splittable(height) and splittable(width))
        ):
            x = test_pad(self.forward_unet, x, modulo=16)
        else:
            x = test_onesplit(self.forward_unet, x, refield=64)
        return x
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
"""
|
|
226
|
+
Functional blocks below
|
|
227
|
+
"""
|
|
228
|
+
from collections import OrderedDict
|
|
229
|
+
import torch
|
|
230
|
+
import torch.nn as nn
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
"""
|
|
234
|
+
# --------------------------------------------
|
|
235
|
+
# Advanced nn.Sequential
|
|
236
|
+
# https://github.com/xinntao/BasicSR
|
|
237
|
+
# --------------------------------------------
|
|
238
|
+
"""
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def sequential(*args):
    """
    Build an ``nn.Sequential`` from modules, flattening any ``nn.Sequential`` arguments.

    :param args: modules to chain; children of ``nn.Sequential`` arguments are inlined,
        other ``nn.Module`` arguments are appended, and anything else is ignored.
    :return: the argument itself when exactly one argument is given, otherwise a new
        ``nn.Sequential``.
    :raises NotImplementedError: if the single argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return only  # No sequential is needed.

    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
"""
|
|
263
|
+
# --------------------------------------------
|
|
264
|
+
# Useful blocks
|
|
265
|
+
# https://github.com/xinntao/BasicSR
|
|
266
|
+
# --------------------------------
|
|
267
|
+
# conv + normalization + relu (conv)
|
|
268
|
+
# (PixelUnShuffle)
|
|
269
|
+
# (ConditionalBatchNorm2d)
|
|
270
|
+
# concat (ConcatBlock)
|
|
271
|
+
# sum (ShortcutBlock)
|
|
272
|
+
# resblock (ResBlock)
|
|
273
|
+
# Channel Attention (CA) Layer (CALayer)
|
|
274
|
+
# Residual Channel Attention Block (RCABlock)
|
|
275
|
+
# Residual Channel Attention Group (RCAGroup)
|
|
276
|
+
# Residual Dense Block (ResidualDenseBlock_5C)
|
|
277
|
+
# Residual in Residual Dense Block (RRDB)
|
|
278
|
+
# --------------------------------------------
|
|
279
|
+
"""
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
# --------------------------------------------
|
|
283
|
+
# return nn.Sequential of (Conv + BN + ReLU)
|
|
284
|
+
# --------------------------------------------
|
|
285
|
+
def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
    negative_slope=0.2,
):
    """
    Build a stack of layers from a mode string, one layer per character (in order).

    Supported characters:

    - ``"C"``: ``nn.Conv2d``; ``"T"``: ``nn.ConvTranspose2d`` (both use ``in_channels``,
      ``out_channels``, ``kernel_size``, ``stride``, ``padding`` and ``bias``)
    - ``"B"``: ``nn.BatchNorm2d``; ``"I"``: ``nn.InstanceNorm2d`` (over ``out_channels``)
    - ``"R"`` / ``"r"``: ReLU in-place / not in-place; ``"L"`` / ``"l"``: LeakyReLU in-place /
      not in-place (slope ``negative_slope``); ``"E"``: ELU; ``"s"``: Softplus
    - ``"2"``, ``"3"``, ``"4"``: ``nn.PixelShuffle`` with that upscale factor
    - ``"U"``, ``"u"``, ``"v"``: nearest-neighbour ``nn.Upsample`` by 2, 3, 4
    - ``"M"``, ``"A"``: max / average pooling with ``kernel_size`` and ``stride`` (no padding)

    :return: the layer itself if ``mode`` has a single character, otherwise an
        ``nn.Sequential`` built by ``sequential``.
    :raises NotImplementedError: if ``mode`` contains an unsupported character.
    """
    L = []
    for t in mode:
        if t == "C":
            L.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "T":
            L.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "B":
            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == "I":
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == "R":
            L.append(nn.ReLU(inplace=True))
        elif t == "r":
            L.append(nn.ReLU(inplace=False))
        elif t == "L":
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        elif t == "l":
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
        elif t == "E":
            L.append(nn.ELU(inplace=False))
        elif t == "s":
            L.append(nn.Softplus())
        elif t == "2":
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == "3":
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == "4":
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == "U":
            L.append(nn.Upsample(scale_factor=2, mode="nearest"))
        elif t == "u":
            L.append(nn.Upsample(scale_factor=3, mode="nearest"))
        elif t == "v":
            L.append(nn.Upsample(scale_factor=4, mode="nearest"))
        elif t == "M":
            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == "A":
            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            # the original message had no placeholder, so the offending code was never shown
            raise NotImplementedError("Undefined type: {}".format(t))
    return sequential(*L)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
# --------------------------------------------
|
|
357
|
+
# Res Block: x + conv(relu(conv(x)))
|
|
358
|
+
# --------------------------------------------
|
|
359
|
+
class ResBlock(nn.Module):
    """
    Residual block returning ``x + body(x)``, where ``body`` is built by ``conv`` from ``mode``.

    :param int in_channels: number of input channels (must equal ``out_channels``).
    :param int out_channels: number of output channels.
    :param int kernel_size: kernel size of the convolutions in ``body``.
    :param int stride: stride of the convolutions in ``body``.
    :param int padding: padding of the convolutions in ``body``.
    :param bool bias: whether the convolutions use a bias.
    :param str mode: layer codes of ``body`` (see ``conv``), e.g. ``"CRC"``.
    :param float negative_slope: slope used by LeakyReLU layers.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        mode="CRC",
        negative_slope=0.2,
    ):
        super().__init__()

        assert in_channels == out_channels, "Only support in_channels==out_channels."
        leading = mode[0]
        if leading in ("R", "L"):
            # An in-place activation as first layer would overwrite ``x``, which is still
            # needed for the skip connection: switch to the non in-place variant.
            mode = leading.lower() + mode[1:]

        self.res = conv(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias,
            mode,
            negative_slope,
        )

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        return x + self.res(x)
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
"""
|
|
394
|
+
# --------------------------------------------
|
|
395
|
+
# Upsampler
|
|
396
|
+
# Kai Zhang, https://github.com/cszn/KAIR
|
|
397
|
+
# --------------------------------------------
|
|
398
|
+
# upsample_pixelshuffle
|
|
399
|
+
# upsample_upconv
|
|
400
|
+
# upsample_convtranspose
|
|
401
|
+
# --------------------------------------------
|
|
402
|
+
"""
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
# --------------------------------------------
|
|
406
|
+
# conv + subp (+ relu)
|
|
407
|
+
# --------------------------------------------
|
|
408
|
+
def upsample_pixelshuffle(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """
    Sub-pixel upsampler: a convolution to ``out_channels * r**2`` channels followed by
    ``nn.PixelShuffle(r)``, where the scale ``r`` is the first character of ``mode``.

    :param str mode: scale digit (``"2"``, ``"3"`` or ``"4"``) optionally followed by extra
        layer codes of ``conv`` applied after the shuffle (e.g. ``"2R"``, ``"2BR"``).
    :return: the layers built by ``conv`` with mode ``"C" + mode``.
    """
    assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    scale = int(mode[0])
    expanded_channels = out_channels * scale * scale
    return conv(
        in_channels,
        expanded_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode="C" + mode,
        negative_slope=negative_slope,
    )
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
# --------------------------------------------
|
|
437
|
+
# nearest_upsample + conv (+ R)
|
|
438
|
+
# --------------------------------------------
|
|
439
|
+
def upsample_upconv(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """
    Nearest-neighbour upsampling followed by a convolution.

    The scale digit at the start of ``mode`` is rewritten into the ``conv`` codes ``"UC"``
    (x2), ``"uC"`` (x3) or ``"vC"`` (x4); ``str.replace`` rewrites every occurrence of that
    digit, and the other characters are kept as extra layer codes.

    :return: the layers built by ``conv`` from the rewritten mode.
    """
    assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR"
    scale_code = mode[0]
    if scale_code == "2":
        layer_codes = "UC"
    elif scale_code == "3":
        layer_codes = "uC"
    elif scale_code == "4":
        layer_codes = "vC"
    expanded_mode = mode.replace(scale_code, layer_codes)
    return conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=expanded_mode,
        negative_slope=negative_slope,
    )
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
# --------------------------------------------
|
|
475
|
+
# convTranspose (+ relu)
|
|
476
|
+
# --------------------------------------------
|
|
477
|
+
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """
    Transposed-convolution upsampler.

    The scale digit at the start of ``mode`` sets both the kernel size and the stride of
    the transposed convolution (the ``kernel_size`` and ``stride`` arguments are ignored),
    and is rewritten into the ``conv`` code ``"T"``.

    :return: the layers built by ``conv`` from the rewritten mode.
    """
    assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], "T")
    return conv(
        in_channels,
        out_channels,
        factor,
        factor,
        padding,
        bias,
        layer_codes,
        negative_slope,
    )
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
"""
|
|
509
|
+
# --------------------------------------------
|
|
510
|
+
# Downsampler
|
|
511
|
+
# Kai Zhang, https://github.com/cszn/KAIR
|
|
512
|
+
# --------------------------------------------
|
|
513
|
+
# downsample_strideconv
|
|
514
|
+
# downsample_maxpool
|
|
515
|
+
# downsample_avgpool
|
|
516
|
+
# --------------------------------------------
|
|
517
|
+
"""
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
# --------------------------------------------
|
|
521
|
+
# strideconv (+ relu)
|
|
522
|
+
# --------------------------------------------
|
|
523
|
+
def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """
    Strided-convolution downsampler.

    The scale digit at the start of ``mode`` sets both the kernel size and the stride of the
    convolution (the ``kernel_size`` and ``stride`` arguments are ignored), and is rewritten
    into the ``conv`` code ``"C"``.

    :return: the layers built by ``conv`` from the rewritten mode.
    """
    assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], "C")
    return conv(
        in_channels,
        out_channels,
        factor,
        factor,
        padding,
        bias,
        layer_codes,
        negative_slope,
    )
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
# --------------------------------------------
|
|
555
|
+
# maxpooling + conv (+ relu)
|
|
556
|
+
# --------------------------------------------
|
|
557
|
+
def downsample_maxpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """
    Max-pooling downsampler followed by a convolution.

    The scale digit at the start of ``mode`` (``"2"`` or ``"3"``) sets the pooling window and
    stride, and is rewritten into the ``conv`` codes ``"MC"``. The pooling layer is followed by
    a convolution using ``kernel_size``, ``stride``, ``padding`` and ``bias``, then by any
    remaining layer codes. With the defaults (``kernel_size=3``, ``padding=0``) that
    convolution reduces each spatial dimension by 2.

    :return: ``sequential(pool, conv_stack)``.
    """
    assert len(mode) < 4 and mode[0] in ["2", "3"], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    window = int(mode[0])
    layer_codes = mode.replace(mode[0], "MC")
    pool_layer = conv(
        kernel_size=window,
        stride=window,
        mode=layer_codes[0],
        negative_slope=negative_slope,
    )
    conv_layers = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=layer_codes[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool_layer, conv_layers)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
# --------------------------------------------
|
|
594
|
+
# averagepooling + conv (+ relu)
|
|
595
|
+
# --------------------------------------------
|
|
596
|
+
def downsample_avgpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """
    Average-pooling downsampler followed by a convolution.

    The scale digit at the start of ``mode`` (``"2"`` or ``"3"``) sets the pooling window and
    stride, and is rewritten into the ``conv`` codes ``"AC"``. The pooling layer is followed by
    a convolution using ``kernel_size``, ``stride``, ``padding`` and ``bias``, then by any
    remaining layer codes.

    :return: ``sequential(pool, conv_stack)``.
    """
    assert len(mode) < 4 and mode[0] in ["2", "3"], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    window = int(mode[0])
    layer_codes = mode.replace(mode[0], "AC")
    pool_layer = conv(
        kernel_size=window,
        stride=window,
        mode=layer_codes[0],
        negative_slope=negative_slope,
    )
    conv_layers = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=layer_codes[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool_layer, conv_layers)
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
"""
|
|
633
|
+
Helpers for test time
|
|
634
|
+
"""
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def test_onesplit(model, L, refield=32, sf=1):
    """
    Apply ``model`` to four overlapping corner crops of ``L`` and stitch the outputs together.

    Along a spatial dimension of size ``n`` each crop spans
    ``(n // 2 // refield + 1) * refield`` pixels, clipped to the image. Each crop covers one
    quadrant of the image plus some context, and only that quadrant is kept from its output.

    :param model: callable applied to each crop; its output must have the spatial size of its
        input multiplied by ``sf``.
    :param L: input Low-quality image, whose last two dimensions are height and width.
    :param refield: effective receptive field of the network, 32 is enough.
    :param sf: scale factor for super-resolution, otherwise 1.
    :return: tensor of shape ``(b, c, sf * h, sf * w)``, where ``b, c`` are the first two
        dimensions of the model output, converted with ``type_as(L)``.
    """
    h, w = L.size()[-2:]
    crop_h = (h // 2 // refield + 1) * refield
    crop_w = (w // 2 // refield + 1) * refield
    # Clamp the start of the bottom/right crops at 0. Otherwise, when the crop is larger than
    # the image, a negative start is read by Python slicing as an offset from the end. That
    # yields a too-short crop and silently wrong (broadcast) output.
    row_slices = (slice(0, crop_h), slice(max(h - crop_h, 0), h))
    col_slices = (slice(0, crop_w), slice(max(w - crop_w, 0), w))
    # order: top-left, top-right, bottom-left, bottom-right
    Es = [model(L[..., rs, cs]) for rs in row_slices for cs in col_slices]
    b, c = Es[0].size()[:2]
    E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
    # output row / column where the bottom / right quadrants start
    mid_h, mid_w = h // 2 * sf, w // 2 * sf
    E[..., :mid_h, :mid_w] = Es[0][..., :mid_h, :mid_w]
    E[..., :mid_h, mid_w : w * sf] = Es[1][..., :mid_h, (-w + w // 2) * sf :]
    E[..., mid_h : h * sf, :mid_w] = Es[2][..., (-h + h // 2) * sf :, :mid_w]
    E[..., mid_h : h * sf, mid_w : w * sf] = Es[3][
        ..., (-h + h // 2) * sf :, (-w + w // 2) * sf :
    ]
    return E


# Helper, not a test: stop pytest from collecting it when imported into a test module.
test_onesplit.__test__ = False
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
def test_pad(model, L, modulo=16):
    """
    Pads the image to fit the model's expected image size.

    ``L`` is replicate-padded on the bottom and right so that its height and width become
    multiples of ``modulo``. The padded image goes through ``model``, and the output is
    cropped back to the original height and width.

    :param model: callable applied to the padded image.
    :param L: input image, whose last two dimensions are height and width.
    :param int modulo: the padded height and width are multiples of this value.
    :return: ``model`` output cropped to ``[..., :h, :w]``.
    """
    h, w = L.size()[-2:]
    # exact integer arithmetic: distance to the next multiple of modulo (0 if already one)
    padding_bottom = (-h) % modulo
    padding_right = (-w) % modulo
    L = torch.nn.ReplicationPad2d((0, padding_right, 0, padding_bottom))(L)
    E = model(L)
    return E[..., :h, :w]


# Helper, not a test: stop pytest from collecting it when imported into a test module.
test_pad.__test__ = False
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
def weights_init_drunet(m):
    """
    Orthogonally re-initialise the ``weight`` of convolution-like modules (gain 0.2).

    Meant for ``nn.Module.apply``. Any module whose class name contains ``"Conv"`` has its
    weight overwritten in place; other modules are left untouched.

    :param torch.nn.Module m: module to (possibly) initialise.
    """
    if "Conv" in type(m).__name__:
        nn.init.orthogonal_(m.weight.data, gain=0.2)
|