deepinv 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepinv/__about__.py +17 -0
- deepinv/__init__.py +71 -0
- deepinv/datasets/__init__.py +1 -0
- deepinv/datasets/datagenerator.py +238 -0
- deepinv/loss/__init__.py +10 -0
- deepinv/loss/ei.py +76 -0
- deepinv/loss/mc.py +39 -0
- deepinv/loss/measplit.py +219 -0
- deepinv/loss/metric.py +125 -0
- deepinv/loss/moi.py +64 -0
- deepinv/loss/regularisers.py +155 -0
- deepinv/loss/score.py +41 -0
- deepinv/loss/sup.py +37 -0
- deepinv/loss/sure.py +338 -0
- deepinv/loss/tv.py +39 -0
- deepinv/models/GSPnP.py +129 -0
- deepinv/models/PDNet.py +109 -0
- deepinv/models/__init__.py +17 -0
- deepinv/models/ae.py +43 -0
- deepinv/models/artifactremoval.py +56 -0
- deepinv/models/bm3d.py +57 -0
- deepinv/models/diffunet.py +997 -0
- deepinv/models/dip.py +214 -0
- deepinv/models/dncnn.py +131 -0
- deepinv/models/drunet.py +689 -0
- deepinv/models/equivariant.py +135 -0
- deepinv/models/median.py +51 -0
- deepinv/models/scunet.py +490 -0
- deepinv/models/swinir.py +1140 -0
- deepinv/models/tgv.py +232 -0
- deepinv/models/tv.py +146 -0
- deepinv/models/unet.py +337 -0
- deepinv/models/utils.py +22 -0
- deepinv/models/wavdict.py +231 -0
- deepinv/optim/__init__.py +5 -0
- deepinv/optim/data_fidelity.py +607 -0
- deepinv/optim/fixed_point.py +289 -0
- deepinv/optim/optim_iterators/__init__.py +9 -0
- deepinv/optim/optim_iterators/admm.py +117 -0
- deepinv/optim/optim_iterators/drs.py +115 -0
- deepinv/optim/optim_iterators/gradient_descent.py +90 -0
- deepinv/optim/optim_iterators/hqs.py +74 -0
- deepinv/optim/optim_iterators/optim_iterator.py +141 -0
- deepinv/optim/optim_iterators/pgd.py +91 -0
- deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
- deepinv/optim/optim_iterators/utils.py +17 -0
- deepinv/optim/optimizers.py +563 -0
- deepinv/optim/prior.py +288 -0
- deepinv/optim/utils.py +80 -0
- deepinv/physics/__init__.py +18 -0
- deepinv/physics/blur.py +544 -0
- deepinv/physics/compressed_sensing.py +197 -0
- deepinv/physics/forward.py +547 -0
- deepinv/physics/haze.py +65 -0
- deepinv/physics/inpainting.py +48 -0
- deepinv/physics/lidar.py +123 -0
- deepinv/physics/mri.py +329 -0
- deepinv/physics/noise.py +180 -0
- deepinv/physics/range.py +53 -0
- deepinv/physics/remote_sensing.py +123 -0
- deepinv/physics/singlepixel.py +218 -0
- deepinv/physics/tomography.py +321 -0
- deepinv/sampling/__init__.py +2 -0
- deepinv/sampling/diffusion.py +676 -0
- deepinv/sampling/langevin.py +512 -0
- deepinv/sampling/utils.py +35 -0
- deepinv/tests/conftest.py +39 -0
- deepinv/tests/dummy_datasets/datasets.py +57 -0
- deepinv/tests/test_loss.py +269 -0
- deepinv/tests/test_loss_train.py +179 -0
- deepinv/tests/test_models.py +377 -0
- deepinv/tests/test_optim.py +647 -0
- deepinv/tests/test_physics.py +316 -0
- deepinv/tests/test_sampling.py +158 -0
- deepinv/tests/test_unfolded.py +158 -0
- deepinv/tests/test_utils.py +68 -0
- deepinv/training_utils.py +529 -0
- deepinv/transform/__init__.py +2 -0
- deepinv/transform/rotate.py +41 -0
- deepinv/transform/shift.py +26 -0
- deepinv/unfolded/__init__.py +2 -0
- deepinv/unfolded/deep_equilibrium.py +163 -0
- deepinv/unfolded/unfolded.py +87 -0
- deepinv/utils/__init__.py +17 -0
- deepinv/utils/demo.py +171 -0
- deepinv/utils/logger.py +93 -0
- deepinv/utils/metric.py +87 -0
- deepinv/utils/nn.py +213 -0
- deepinv/utils/optimization.py +108 -0
- deepinv/utils/parameters.py +43 -0
- deepinv/utils/phantoms.py +115 -0
- deepinv/utils/plotting.py +312 -0
- deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
- deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
- deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
- deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
- deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class EquivariantDenoiser(torch.nn.Module):
    r"""
    Turns the input denoiser into an equivariant denoiser with respect to geometric transforms.

    Recall that a denoiser is equivariant with respect to a group of transformations if it commutes with the action of
    the group. More precisely, let :math:`\mathcal{G}` be a group of transformations :math:`\{T_g\}_{g\in \mathcal{G}}`
    and :math:`\denoisername` a denoiser. Then, :math:`\denoisername` is equivariant with respect to :math:`\mathcal{G}`
    if :math:`\denoisername(T_g(x)) = T_g(\denoisername(x))` for any image :math:`x` and any :math:`g\in \mathcal{G}`.

    The denoiser can be turned into an equivariant denoiser by averaging over the group of transforms, i.e.

    .. math::
        \operatorname{D}^{\text{eq}}_{\sigma}(x) = \frac{1}{|\mathcal{G}|}\sum_{g\in \mathcal{G}} T_g^{-1}(\operatorname{D}_{\sigma}(T_g(x))).

    Otherwise, as proposed in `this paper <https://arxiv.org/abs/2312.01831>`_, a Monte-Carlo approximation can be
    obtained by sampling :math:`g \sim \mathcal{G}` at random and applying

    .. math::
        \operatorname{D}^{\text{MC}}_{\sigma}(x) = T_g^{-1}(\operatorname{D}_{\sigma}(T_g(x))).


    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`.
    :param str transform: type of geometric transformation. Can be either 'rotations', 'flips' or 'rotoflips'.
        If 'rotations', the group of transformations contains the 4 rotations by multiples of 90 degrees; if 'flips',
        the group of transformations contains the 2 horizontal and vertical flips; if 'rotoflips', the group of
        transformations contains the 8 rotations and flips.
    :param bool random: if True, the denoiser is applied to a single randomly transformed version of the input image,
        i.e. a Monte-Carlo approximation of the equivariant denoiser. If False, the output is averaged over all the
        transforms of the group, turning the denoiser into an exactly equivariant denoiser with respect to the
        chosen group of transformations.
    """

    def __init__(self, denoiser, transform="rotations", random=True):
        super().__init__()
        self.denoiser = denoiser
        # Substring tests: 'rotations' and 'rotoflips' both contain 'rot';
        # 'flips' and 'rotoflips' both contain 'flip'.
        self.rotations = True if "rot" in transform else False
        self.flips = True if "flip" in transform else False
        self.random = random

    def forward(self, x, sigma):
        r"""
        Applies the denoiser to the input image with the appropriate transformation.

        :param torch.Tensor x: input image.
        :param float sigma: noise level.
        :return: denoised image.
        """
        return denoise_rotate(
            self.denoiser,
            x,
            sigma,
            rotations=self.rotations,
            flips=self.flips,
            random=self.random,
        )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def denoise_rotate(
    denoiser,
    image,
    sigma,
    rotations=True,
    flips=False,
    random=True,
):
    r"""
    Applies a geometric transform (rotations and/or flips) to the input image, denoises it with the denoiser and
    transforms back the result. The output is either the average over all the transformed images (if random=False)
    or the back-transformed denoising of a single randomly transformed version of the image (if random=True).

    If both ``rotations`` and ``flips`` are False, the denoiser is applied directly (identity transform); previously
    this case raised a ``NameError`` because no transform index was selected.

    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`.
    :param torch.Tensor image: input image.
    :param float sigma: noise level.
    :param bool rotations: if True, rotations are applied to the input image.
    :param bool flips: if True, flips are applied to the input image.
    :param bool random: if True, the denoiser is applied to a randomly transformed version of the input image.
    :return: denoised image.
    """
    if random:
        if rotations:
            # 8 elements of the dihedral group if flips are allowed, 4 rotations otherwise.
            idx = np.random.randint(8) if flips else np.random.randint(4)
        elif flips:
            # indices 4 and 6 correspond to the horizontal and vertical flips.
            idx = np.random.choice([4, 6])
        else:
            # no transform requested: fall back to plain denoising (identity).
            idx = 0
        denoised = denoise_rotate_flip_fn(denoiser, image, sigma, idx)
    else:
        if rotations:
            list_idx = list(range(8)) if flips else list(range(4))
        elif flips:
            list_idx = [4, 6]
        else:
            # no transform requested: average reduces to a single identity pass.
            list_idx = [0]
        denoised = torch.zeros_like(image)
        for idx in list_idx:
            denoised = denoised + denoise_rotate_flip_fn(denoiser, image, sigma, idx)
        denoised = denoised / len(list_idx)
    return denoised
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def denoise_rotate_flip_fn(denoiser, x, sigma_den, idx):
    r"""
    Applies transform :math:`T_{idx}` to ``x``, denoises, and applies :math:`T_{idx}^{-1}` to the result.

    Indices 0-3 are counter-clockwise rotations by ``idx * 90`` degrees; indices 4-7 additionally apply a
    horizontal flip before the rotation (and after the inverse rotation on the way back), covering the
    8 elements of the dihedral group.

    :param callable denoiser: denoiser ``D(x, sigma)``.
    :param torch.Tensor x: input image of shape (..., H, W).
    :param float sigma_den: noise level passed to the denoiser.
    :param int idx: transform index in ``range(8)``.
    :return: denoised image mapped back to the original orientation.
    :raises ValueError: if ``idx`` is not in ``[0, 7]`` (previously this left ``out`` unbound and
        raised an unhelpful ``NameError``).
    """
    if idx not in range(8):
        raise ValueError(f"transform index must be in [0, 7], got {idx}")
    k = idx % 4  # number of 90-degree counter-clockwise rotations
    flip = idx >= 4  # whether to pre/post-compose with a horizontal flip
    t = x
    if flip:
        t = torch.flip(t, dims=[-1])
    t = torch.rot90(t, k=k, dims=[-2, -1])
    out = denoiser(t, sigma_den)
    # invert: rotate back by -k, then undo the flip (flip is its own inverse).
    out = torch.rot90(out, k=-k, dims=[-2, -1])
    if flip:
        out = torch.flip(out, dims=[-1])
    return out
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def hflip(x):
    """Mirror the image horizontally, i.e. flip along the last (width) axis."""
    return x.flip(-1)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def rot1(x):
    """Rotate the image by 90 degrees counter-clockwise in the (H, W) plane."""
    return x.rot90(1, dims=(-2, -1))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def rot2(x):
    """Rotate the image by 180 degrees in the (H, W) plane."""
    return x.rot90(2, dims=(-2, -1))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def rot3(x):
    """Rotate the image by 270 degrees counter-clockwise (90 clockwise) in the (H, W) plane."""
    return x.rot90(3, dims=(-2, -1))
|
deepinv/models/median.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
from torch.nn.modules.utils import _pair, _quadruple
|
|
4
|
+
|
|
5
|
+
# code adapted from https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MedianFilter(nn.Module):
    r"""
    Median filter.

    Replaces each pixel by the median of a sliding window around it. Accepts an unused
    ``sigma`` argument in ``forward`` so it can be used as a drop-in (blind) denoiser.

    :param int kernel_size: size of pooling kernel, int or 2-tuple
    :param padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
    :param same: override padding and enforce same padding, boolean
    """

    def __init__(self, kernel_size=9, padding=0, same=True):
        super(MedianFilter, self).__init__()
        self.k = _pair(kernel_size)
        # stride is fixed to 1 so the sliding window visits every pixel.
        self.stride = _pair(1)
        self.padding = _quadruple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        # Compute (left, right, top, bottom) padding. With ``same=True`` the padding is
        # derived from kernel/stride so the output spatial size equals the input size;
        # otherwise the user-supplied padding is returned unchanged.
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            # split total padding as evenly as possible (extra pixel on right/bottom).
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x, sigma=None):
        # ``sigma`` is accepted for denoiser-interface compatibility but unused (blind filter).
        # using existing pytorch functions and tensor ops so that we get autograd,
        # would likely be more efficient to implement from scratch at C/Cuda level
        x = F.pad(x, self._padding(x), mode="reflect")
        # unfold H then W: adds two window dims, giving (B, C, H, W, kh, kw).
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        # flatten each window and take its median; [0] selects values (drops indices).
        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
        return x
|
deepinv/models/scunet.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
1
|
+
# Code taken from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import numpy as np
|
|
5
|
+
from einops import rearrange
|
|
6
|
+
from einops.layers.torch import Rearrange
|
|
7
|
+
from .utils import get_weights_url
|
|
8
|
+
|
|
9
|
+
# Compatibility with optional dependency on timm
|
|
10
|
+
try:
|
|
11
|
+
import timm
|
|
12
|
+
from timm.models.layers import trunc_normal_, DropPath
|
|
13
|
+
except ImportError as e:
|
|
14
|
+
timm = e
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class WMSA(nn.Module):
    """Self-attention module in Swin Transformer"""

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        # ``type`` is 'W' (plain window attention) or 'SW' (shifted window); the name
        # shadows the builtin but is kept for checkpoint/API compatibility.
        if isinstance(timm, ImportError):
            # timm failed to import at module load time; re-raise with guidance.
            raise ImportError(
                "timm is needed to use the SCUNet class. Please install it with `pip install timm`"
            ) from timm
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        # standard attention scaling 1/sqrt(d).
        self.scale = self.head_dim**-0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type
        # single projection producing q, k, v concatenated along the last dim.
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)

        # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
        # relative position bias table, one entry per (dy, dx) offset per head.
        self.relative_position_params = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
        )

        self.linear = nn.Linear(self.input_dim, self.output_dim)

        trunc_normal_(self.relative_position_params, std=0.02)
        # reshape the flat table to (n_heads, 2w-1, 2w-1) for indexing in relative_embedding().
        self.relative_position_params = torch.nn.Parameter(
            self.relative_position_params.view(
                2 * window_size - 1, 2 * window_size - 1, self.n_heads
            )
            .transpose(1, 2)
            .transpose(0, 1)
        )

    def generate_mask(self, h, w, p, shift):
        """generating the mask of SW-MSA
        Args:
            shift: shift parameters in CyclicShift.
        Returns:
            attn_mask: should be (1 1 w p p),
        """
        # supporting square windows only.
        attn_mask = torch.zeros(
            h,
            w,
            p,
            p,
            p,
            p,
            dtype=torch.bool,
            device=self.relative_position_params.device,
        )
        if self.type == "W":
            # plain window attention needs no mask.
            return attn_mask
        # mask out attention between pixels that wrapped around due to the cyclic
        # shift (last row/column of windows mixes content from opposite borders).
        s = p - shift
        attn_mask[-1, :, :s, :, s:, :] = True
        attn_mask[-1, :, s:, :, :s, :] = True
        attn_mask[:, -1, :, :s, :, s:] = True
        attn_mask[:, -1, :, s:, :, :s] = True
        attn_mask = rearrange(
            attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
        )
        return attn_mask

    def forward(self, x):
        """Forward pass of Window Multi-head Self-attention module.
        Args:
            x: input tensor with shape of [b h w c];
            attn_mask: attention mask, fill -inf where the value is True;
        Returns:
            output: tensor shape [b h w c]
        """
        if self.type != "W":
            # shifted-window variant: cyclically shift so windows straddle the
            # original window boundaries; undone after attention below.
            x = torch.roll(
                x,
                shifts=(-(self.window_size // 2), -(self.window_size // 2)),
                dims=(1, 2),
            )
        # partition into non-overlapping (window_size x window_size) windows.
        x = rearrange(
            x,
            "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
            p1=self.window_size,
            p2=self.window_size,
        )
        h_windows = x.size(1)
        w_windows = x.size(2)
        # square validation
        # assert h_windows == w_windows

        x = rearrange(
            x,
            "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
            p1=self.window_size,
            p2=self.window_size,
        )
        qkv = self.embedding_layer(x)
        # split the concatenated projection into per-head q, k, v.
        q, k, v = rearrange(
            qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
        ).chunk(3, dim=0)
        sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
        # Adding learnable relative embedding
        sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
        # Using Attn Mask to distinguish different subwindows.
        if self.type != "W":
            attn_mask = self.generate_mask(
                h_windows, w_windows, self.window_size, shift=self.window_size // 2
            )
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
        output = rearrange(output, "h b w p c -> b w p (h c)")
        output = self.linear(output)
        # merge windows back into the full spatial layout.
        output = rearrange(
            output,
            "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
            w1=h_windows,
            p1=self.window_size,
        )

        if self.type != "W":
            # undo the cyclic shift applied at the top of forward().
            output = torch.roll(
                output,
                shifts=(self.window_size // 2, self.window_size // 2),
                dims=(1, 2),
            )
        return output

    def relative_embedding(self):
        # all (i, j) coordinates inside one window.
        cord = torch.tensor(
            np.array(
                [
                    [i, j]
                    for i in range(self.window_size)
                    for j in range(self.window_size)
                ]
            )
        )
        # pairwise offsets, shifted to be non-negative indices into the bias table.
        relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
        # negative is allowed
        return self.relative_position_params[
            :, relation[:, :, 0].long(), relation[:, :, 1].long()
        ]
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class Block(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        """SwinTransformer Block: window attention + MLP, each with a residual
        connection, LayerNorm pre-normalization and optional stochastic depth."""
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ["W", "SW"]
        self.type = type
        # shifted windows are pointless when a single window covers the whole
        # feature map, so fall back to plain window attention.
        if input_resolution <= window_size:
            self.type = "W"

        # print(
        #     "Block Initial Type: {}, drop_path_rate:{:.6f}".format(self.type, drop_path)
        # )
        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        # pre-norm residual attention, then pre-norm residual MLP.
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class ConvTransBlock(nn.Module):
    def __init__(
        self,
        conv_dim,
        trans_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        """SwinTransformer and Conv Block: splits channels into a convolutional
        branch and a Swin transformer branch, then fuses them with a 1x1 conv
        and a residual connection."""
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution

        assert self.type in ["W", "SW"]
        # same fallback as in Block: no shifting when one window spans the input.
        if self.input_resolution <= self.window_size:
            self.type = "W"

        self.trans_block = Block(
            self.trans_dim,
            self.trans_dim,
            self.head_dim,
            self.window_size,
            self.drop_path,
            self.type,
            self.input_resolution,
        )
        # 1x1 conv mixing channels before the branch split.
        self.conv1_1 = nn.Conv2d(
            self.conv_dim + self.trans_dim,
            self.conv_dim + self.trans_dim,
            1,
            1,
            0,
            bias=True,
        )
        # 1x1 conv fusing the two branches after processing.
        self.conv1_2 = nn.Conv2d(
            self.conv_dim + self.trans_dim,
            self.conv_dim + self.trans_dim,
            1,
            1,
            0,
            bias=True,
        )

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
        )

    def forward(self, x):
        # mix channels, then split into conv branch and transformer branch.
        conv_x, trans_x = torch.split(
            self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
        )
        # conv branch has its own inner residual connection.
        conv_x = self.conv_block(conv_x) + conv_x
        # transformer branch works in channels-last layout.
        trans_x = Rearrange("b c h w -> b h w c")(trans_x)
        trans_x = self.trans_block(trans_x)
        trans_x = Rearrange("b h w c -> b c h w")(trans_x)
        # fuse branches and add the outer residual.
        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        x = x + res

        return x
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class SCUNet(nn.Module):
    r"""
    SCUNet denoising network.

    The Swin-Conv-UNet (SCUNet) denoising was introduced in `Practical Blind Denoising via Swin-Conv-UNet and
    Data Synthesis <https://arxiv.org/abs/2203.13278>`_.

    :param int in_nc: number of input channels. Default: 3.
    :param list config: number of layers in each stage. Default: [4, 4, 4, 4, 4, 4, 4].
    :param int dim: number of channels in each layer. Default: 64.
    :param float drop_path_rate: drop path per sample rate (stochastic depth) for each layer. Default: 0.0.
    :param int input_resolution: input resolution. Default: 256.
    :param bool pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized at random
        using Pytorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from an
        online repository (only available for the default architecture).
        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights. Default: 'download'.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param bool train: training or testing mode. Default: False.
    :param str device: gpu or cpu. Default: 'cpu'.
    """

    def __init__(
        self,
        in_nc=3,
        # NOTE(review): mutable default list; safe only as long as callers never mutate it.
        config=[4, 4, 4, 4, 4, 4, 4],
        dim=64,
        drop_path_rate=0.0,
        input_resolution=256,
        pretrained="download",
        train=False,
        device="cpu",
    ):
        super(SCUNet, self).__init__()
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        # drop path rate for each layer, increasing linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]

        # lists are built first, then wrapped in nn.Sequential below; attribute
        # names and ordering determine the state_dict keys of the pretrained weights.
        self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]

        # ``begin`` tracks the global block index so each ConvTransBlock gets its
        # own drop-path rate; type alternates W / SW between consecutive blocks.
        begin = 0
        self.m_down1 = [
            ConvTransBlock(
                dim // 2,
                dim // 2,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution,
            )
            for i in range(config[0])
        ] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]

        begin += config[0]
        self.m_down2 = [
            ConvTransBlock(
                dim,
                dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 2,
            )
            for i in range(config[1])
        ] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]

        begin += config[1]
        self.m_down3 = [
            ConvTransBlock(
                2 * dim,
                2 * dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 4,
            )
            for i in range(config[2])
        ] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]

        begin += config[2]
        self.m_body = [
            ConvTransBlock(
                4 * dim,
                4 * dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 8,
            )
            for i in range(config[3])
        ]

        begin += config[3]
        self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False)] + [
            ConvTransBlock(
                2 * dim,
                2 * dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 4,
            )
            for i in range(config[4])
        ]

        begin += config[4]
        self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False)] + [
            ConvTransBlock(
                dim,
                dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 2,
            )
            for i in range(config[5])
        ]

        begin += config[5]
        self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False)] + [
            ConvTransBlock(
                dim // 2,
                dim // 2,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution,
            )
            for i in range(config[6])
        ]

        self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]

        self.m_head = nn.Sequential(*self.m_head)
        self.m_down1 = nn.Sequential(*self.m_down1)
        self.m_down2 = nn.Sequential(*self.m_down2)
        self.m_down3 = nn.Sequential(*self.m_down3)
        self.m_body = nn.Sequential(*self.m_body)
        self.m_up3 = nn.Sequential(*self.m_up3)
        self.m_up2 = nn.Sequential(*self.m_up2)
        self.m_up1 = nn.Sequential(*self.m_up1)
        self.m_tail = nn.Sequential(*self.m_tail)
        # self.apply(self._init_weights)

        if pretrained is not None:
            if pretrained == "download":
                name = "scunet_color_real_psnr.pth"
                url = get_weights_url(model_name="scunet", file_name=name)
                ckpt_drunet = torch.hub.load_state_dict_from_url(
                    url, map_location=lambda storage, loc: storage, file_name=name
                )
            else:
                # NOTE(review): torch.load on a user-supplied path; only load
                # checkpoints from trusted sources (pickle-based deserialization).
                ckpt_drunet = torch.load(
                    pretrained, map_location=lambda storage, loc: storage
                )

            self.load_state_dict(ckpt_drunet, strict=True)

        if not train:
            # freeze all parameters for inference-only usage.
            self.eval()
            for _, v in self.named_parameters():
                v.requires_grad = False

        if device is not None:
            self.to(device)

    def forward_scunet(self, x0):
        # pad H and W up to the next multiple of 64 (required by the 3 stride-2
        # downsamplings plus the window partitioning), then crop back at the end.
        h, w = x0.size()[-2:]
        paddingBottom = int(np.ceil(h / 64) * 64 - h)
        paddingRight = int(np.ceil(w / 64) * 64 - w)
        x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)

        # U-Net encoder / bottleneck / decoder with additive skip connections.
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        # crop away the padding added above.
        x = x[..., :h, :w]

        return x

    def forward(self, x, sigma):  # This is a blind model: sigma is not used
        den = self.forward_scunet(x)
        return den

    def _init_weights(self, m):
        # standard Swin-style initialization; currently unused (see the
        # commented-out self.apply call in __init__).
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
# if __name__ == '__main__':
|
|
485
|
+
# # torch.cuda.empty_cache()
|
|
486
|
+
# net = SCUNet(pretrained='download', device='cpu', train=False)
|
|
487
|
+
#
|
|
488
|
+
# x = torch.randn((2, 3, 64, 128))
|
|
489
|
+
# x = net(x)
|
|
490
|
+
# print(x.shape)
|