deepinv 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. deepinv/__about__.py +17 -0
  2. deepinv/__init__.py +71 -0
  3. deepinv/datasets/__init__.py +1 -0
  4. deepinv/datasets/datagenerator.py +238 -0
  5. deepinv/loss/__init__.py +10 -0
  6. deepinv/loss/ei.py +76 -0
  7. deepinv/loss/mc.py +39 -0
  8. deepinv/loss/measplit.py +219 -0
  9. deepinv/loss/metric.py +125 -0
  10. deepinv/loss/moi.py +64 -0
  11. deepinv/loss/regularisers.py +155 -0
  12. deepinv/loss/score.py +41 -0
  13. deepinv/loss/sup.py +37 -0
  14. deepinv/loss/sure.py +338 -0
  15. deepinv/loss/tv.py +39 -0
  16. deepinv/models/GSPnP.py +129 -0
  17. deepinv/models/PDNet.py +109 -0
  18. deepinv/models/__init__.py +17 -0
  19. deepinv/models/ae.py +43 -0
  20. deepinv/models/artifactremoval.py +56 -0
  21. deepinv/models/bm3d.py +57 -0
  22. deepinv/models/diffunet.py +997 -0
  23. deepinv/models/dip.py +214 -0
  24. deepinv/models/dncnn.py +131 -0
  25. deepinv/models/drunet.py +689 -0
  26. deepinv/models/equivariant.py +135 -0
  27. deepinv/models/median.py +51 -0
  28. deepinv/models/scunet.py +490 -0
  29. deepinv/models/swinir.py +1140 -0
  30. deepinv/models/tgv.py +232 -0
  31. deepinv/models/tv.py +146 -0
  32. deepinv/models/unet.py +337 -0
  33. deepinv/models/utils.py +22 -0
  34. deepinv/models/wavdict.py +231 -0
  35. deepinv/optim/__init__.py +5 -0
  36. deepinv/optim/data_fidelity.py +607 -0
  37. deepinv/optim/fixed_point.py +289 -0
  38. deepinv/optim/optim_iterators/__init__.py +9 -0
  39. deepinv/optim/optim_iterators/admm.py +117 -0
  40. deepinv/optim/optim_iterators/drs.py +115 -0
  41. deepinv/optim/optim_iterators/gradient_descent.py +90 -0
  42. deepinv/optim/optim_iterators/hqs.py +74 -0
  43. deepinv/optim/optim_iterators/optim_iterator.py +141 -0
  44. deepinv/optim/optim_iterators/pgd.py +91 -0
  45. deepinv/optim/optim_iterators/primal_dual_CP.py +145 -0
  46. deepinv/optim/optim_iterators/utils.py +17 -0
  47. deepinv/optim/optimizers.py +563 -0
  48. deepinv/optim/prior.py +288 -0
  49. deepinv/optim/utils.py +80 -0
  50. deepinv/physics/__init__.py +18 -0
  51. deepinv/physics/blur.py +544 -0
  52. deepinv/physics/compressed_sensing.py +197 -0
  53. deepinv/physics/forward.py +547 -0
  54. deepinv/physics/haze.py +65 -0
  55. deepinv/physics/inpainting.py +48 -0
  56. deepinv/physics/lidar.py +123 -0
  57. deepinv/physics/mri.py +329 -0
  58. deepinv/physics/noise.py +180 -0
  59. deepinv/physics/range.py +53 -0
  60. deepinv/physics/remote_sensing.py +123 -0
  61. deepinv/physics/singlepixel.py +218 -0
  62. deepinv/physics/tomography.py +321 -0
  63. deepinv/sampling/__init__.py +2 -0
  64. deepinv/sampling/diffusion.py +676 -0
  65. deepinv/sampling/langevin.py +512 -0
  66. deepinv/sampling/utils.py +35 -0
  67. deepinv/tests/conftest.py +39 -0
  68. deepinv/tests/dummy_datasets/datasets.py +57 -0
  69. deepinv/tests/test_loss.py +269 -0
  70. deepinv/tests/test_loss_train.py +179 -0
  71. deepinv/tests/test_models.py +377 -0
  72. deepinv/tests/test_optim.py +647 -0
  73. deepinv/tests/test_physics.py +316 -0
  74. deepinv/tests/test_sampling.py +158 -0
  75. deepinv/tests/test_unfolded.py +158 -0
  76. deepinv/tests/test_utils.py +68 -0
  77. deepinv/training_utils.py +529 -0
  78. deepinv/transform/__init__.py +2 -0
  79. deepinv/transform/rotate.py +41 -0
  80. deepinv/transform/shift.py +26 -0
  81. deepinv/unfolded/__init__.py +2 -0
  82. deepinv/unfolded/deep_equilibrium.py +163 -0
  83. deepinv/unfolded/unfolded.py +87 -0
  84. deepinv/utils/__init__.py +17 -0
  85. deepinv/utils/demo.py +171 -0
  86. deepinv/utils/logger.py +93 -0
  87. deepinv/utils/metric.py +87 -0
  88. deepinv/utils/nn.py +213 -0
  89. deepinv/utils/optimization.py +108 -0
  90. deepinv/utils/parameters.py +43 -0
  91. deepinv/utils/phantoms.py +115 -0
  92. deepinv/utils/plotting.py +312 -0
  93. deepinv-0.1.0.dev0.dist-info/LICENSE +28 -0
  94. deepinv-0.1.0.dev0.dist-info/METADATA +159 -0
  95. deepinv-0.1.0.dev0.dist-info/RECORD +97 -0
  96. deepinv-0.1.0.dev0.dist-info/WHEEL +5 -0
  97. deepinv-0.1.0.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,135 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+
5
+
6
class EquivariantDenoiser(torch.nn.Module):
    r"""
    Wraps a denoiser so that it becomes (approximately) equivariant to a group of
    geometric transforms.

    A denoiser :math:`\operatorname{D}_{\sigma}` is equivariant with respect to a group
    :math:`\mathcal{G}` of transformations :math:`\{T_g\}_{g\in \mathcal{G}}` when it
    commutes with the group action, i.e.
    :math:`\operatorname{D}_{\sigma}(T_g(x)) = T_g(\operatorname{D}_{\sigma}(x))` for every
    image :math:`x` and every :math:`g\in \mathcal{G}`.

    An equivariant denoiser is obtained by averaging over the group:

    .. math::
        \operatorname{D}^{\text{eq}}_{\sigma}(x) = \frac{1}{|\mathcal{G}|}\sum_{g\in \mathcal{G}} T_g^{-1}(\operatorname{D}_{\sigma}(T_g(x))).

    Alternatively, as proposed in `<https://arxiv.org/abs/2312.01831>`_, a Monte-Carlo
    approximation samples a single :math:`g \sim \mathcal{G}` at random and applies

    .. math::
        \operatorname{D}^{\text{MC}}_{\sigma}(x) = T_g^{-1}(\operatorname{D}_{\sigma}(T_g(x))).

    :param callable denoiser: denoiser :math:`\operatorname{D}_{\sigma}`.
    :param str transform: type of geometric transformation; one of ``'rotations'``
        (the 4 rotations by multiples of 90 degrees), ``'flips'`` (horizontal and
        vertical flips) or ``'rotoflips'`` (all 8 rotations and flips).
    :param bool random: if ``True``, a single randomly transformed version of the input
        is denoised (Monte-Carlo approximation); if ``False``, the output is averaged
        over every transform in the group, yielding an exactly equivariant denoiser.
    """

    def __init__(self, denoiser, transform="rotations", random=True):
        super().__init__()
        self.denoiser = denoiser
        # Parse the transform spec: "rotations" -> rotations only, "flips" ->
        # flips only, "rotoflips" -> both (the string contains both substrings).
        self.rotations = "rot" in transform
        self.flips = "flip" in transform
        self.random = random

    def forward(self, x, sigma):
        r"""
        Denoises the input image through the (randomly or fully) transformed denoiser.

        :param torch.Tensor x: input image.
        :param float sigma: noise level.
        :return: denoised image.
        """
        return denoise_rotate(
            denoiser=self.denoiser,
            image=x,
            sigma=sigma,
            rotations=self.rotations,
            flips=self.flips,
            random=self.random,
        )
61
+
62
+
63
def denoise_rotate(
    denoiser,
    image,
    sigma,
    rotations=True,
    flips=False,
    random=True,
):
    r"""
    Applies a geometric transform (rotations and/or flips) to the input image, denoises
    it with the denoiser and transforms back the result. The output is either the
    average over all the transformed images (if ``random=False``) or the result of a
    single randomly chosen transform (if ``random=True``).

    :param callable denoiser: Denoiser :math:`\operatorname{D}_{\sigma}`.
    :param torch.Tensor image: input image.
    :param float sigma: noise level.
    :param bool rotations: if True, rotations are applied to the input image.
    :param bool flips: if True, flips are applied to the input image.
    :param bool random: if True, the denoiser is applied to a randomly transformed
        version of the input image.
    :return: denoised image.
    """
    if random:
        if rotations:
            # Indices 0-3 are the 4 rotations; 4-7 add a horizontal flip.
            idx = np.random.randint(8) if flips else np.random.randint(4)
        elif flips:
            # Flip-only group: horizontal (4) or vertical (6) flip.
            idx = np.random.choice([4, 6])
        else:
            # Bugfix: previously `idx` was left unassigned when both
            # `rotations` and `flips` were False, raising UnboundLocalError.
            # Fall back to the identity transform (plain denoising).
            idx = 0
        denoised = denoise_rotate_flip_fn(denoiser, image, sigma, idx)
    else:
        if rotations:
            list_idx = list(range(8)) if flips else list(range(4))
        elif flips:
            list_idx = [4, 6]
        else:
            # Bugfix: same unassigned-variable issue as above; average over
            # the identity only, i.e. apply the denoiser directly.
            list_idx = [0]
        denoised = torch.zeros_like(image)
        for idx in list_idx:
            denoised = denoised + denoise_rotate_flip_fn(denoiser, image, sigma, idx)
        denoised = denoised / len(list_idx)
    return denoised
100
+
101
+
102
def denoise_rotate_flip_fn(denoiser, x, sigma_den, idx):
    """Denoise a transformed copy of ``x`` and undo the transform afterwards.

    Index meaning: 0-3 are rotations by 0/90/180/270 degrees; 4-7 are the same
    rotations preceded by a horizontal flip (and inverted in reverse order).
    """
    # (forward transform, inverse transform) pairs, indexed by ``idx``.
    pairs = {
        0: (lambda t: t, lambda t: t),
        1: (rot1, rot3),
        2: (rot2, rot2),
        3: (rot3, rot1),
        4: (hflip, hflip),
        5: (lambda t: rot1(hflip(t)), lambda t: hflip(rot3(t))),
        6: (lambda t: rot2(hflip(t)), lambda t: hflip(rot2(t))),
        7: (lambda t: rot3(hflip(t)), lambda t: hflip(rot1(t))),
    }
    fwd, inv = pairs[idx]
    return inv(denoiser(fwd(x), sigma_den))


def hflip(x):
    """Flip the image along its last (width) dimension."""
    return torch.flip(x, dims=[-1])


def rot1(x):
    """Rotate the image by 90 degrees in the spatial (H, W) plane."""
    return torch.rot90(x, k=1, dims=[-2, -1])


def rot2(x):
    """Rotate the image by 180 degrees in the spatial (H, W) plane."""
    return torch.rot90(x, k=2, dims=[-2, -1])


def rot3(x):
    """Rotate the image by 270 degrees in the spatial (H, W) plane."""
    return torch.rot90(x, k=3, dims=[-2, -1])
@@ -0,0 +1,51 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ from torch.nn.modules.utils import _pair, _quadruple
4
+
5
+ # code adapted from https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598
6
+
7
+
8
class MedianFilter(nn.Module):
    r"""
    Median filter.

    Replaces every pixel by the median of the ``kernel_size`` window centred on it.

    :param int kernel_size: size of pooling kernel, int or 2-tuple
    :param padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
    :param same: override padding and enforce same padding, boolean
    """

    def __init__(self, kernel_size=9, padding=0, same=True):
        super(MedianFilter, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(1)
        # Explicit padding as (left, right, top, bottom); ignored when same=True.
        self.padding = _quadruple(padding)
        self.same = same

    def _padding(self, x):
        """Compute the F.pad tuple; 'same' padding keeps the output size equal
        to the input size for the unit stride used here."""
        if not self.same:
            return self.padding
        ih, iw = x.size()[2:]
        rem_h = ih % self.stride[0]
        ph = max(self.k[0] - (self.stride[0] if rem_h == 0 else rem_h), 0)
        rem_w = iw % self.stride[1]
        pw = max(self.k[1] - (self.stride[1] if rem_w == 0 else rem_w), 0)
        pl, pt = pw // 2, ph // 2
        return (pl, pw - pl, pt, ph - pt)

    def forward(self, x, sigma=None):
        # Built from existing pytorch ops so autograd works; a dedicated
        # C/CUDA kernel would likely be faster.
        padded = F.pad(x, self._padding(x), mode="reflect")
        # Extract all k x k patches, then take the median over each patch.
        patches = padded.unfold(2, self.k[0], self.stride[0]).unfold(
            3, self.k[1], self.stride[1]
        )
        flat = patches.contiguous().view(patches.size()[:4] + (-1,))
        return flat.median(dim=-1)[0]
@@ -0,0 +1,490 @@
1
+ # Code taken from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from .utils import get_weights_url
8
+
9
+ # Compatibility with optional dependency on timm
10
+ try:
11
+ import timm
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ except ImportError as e:
14
+ timm = e
15
+
16
+
17
class WMSA(nn.Module):
    """Window-based multi-head self-attention module from the Swin Transformer.

    Supports plain window attention (type 'W') and the shifted-window variant
    (type 'SW'), where the feature map is cyclically shifted by half a window
    before attention and shifted back afterwards.

    :param int input_dim: number of input channels.
    :param int output_dim: number of output channels.
    :param int head_dim: number of channels per attention head.
    :param int window_size: side length of the (square) attention window.
    :param str type: 'W' or 'SW'.
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        # timm is optional; the ImportError captured at module level is
        # re-raised here with an actionable message.
        if isinstance(timm, ImportError):
            raise ImportError(
                "timm is needed to use the SCUNet class. Please install it with `pip install timm`"
            ) from timm
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        # 1/sqrt(head_dim) scaling applied to attention logits.
        self.scale = self.head_dim**-0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type
        # Single projection producing concatenated q, k, v.
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)

        # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
        # Learnable relative-position bias table: one entry per head for each
        # relative offset in [-(ws-1), ws-1] along both spatial axes.
        self.relative_position_params = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
        )

        self.linear = nn.Linear(self.input_dim, self.output_dim)

        trunc_normal_(self.relative_position_params, std=0.02)
        # Re-wrap the table as (n_heads, 2*ws-1, 2*ws-1) so that
        # relative_embedding() can index it by head and 2-D offset.
        self.relative_position_params = torch.nn.Parameter(
            self.relative_position_params.view(
                2 * window_size - 1, 2 * window_size - 1, self.n_heads
            )
            .transpose(1, 2)
            .transpose(0, 1)
        )

    def generate_mask(self, h, w, p, shift):
        """generating the mask of SW-MSA
        Args:
            h, w: number of windows along each spatial axis.
            p: window size.
            shift: shift parameters in CyclicShift.
        Returns:
            attn_mask: should be (1 1 w p p); True entries are positions that
            must not attend to each other (filled with -inf in forward()).
        """
        # supporting square windows only.
        attn_mask = torch.zeros(
            h,
            w,
            p,
            p,
            p,
            p,
            dtype=torch.bool,
            device=self.relative_position_params.device,
        )
        # Plain window attention needs no masking.
        if self.type == "W":
            return attn_mask

        s = p - shift
        # After the cyclic shift, the last row/column of windows contains
        # pixels that wrapped around from the opposite border; mask the
        # cross-boundary pairs so they do not attend to each other.
        attn_mask[-1, :, :s, :, s:, :] = True
        attn_mask[-1, :, s:, :, :s, :] = True
        attn_mask[:, -1, :, :s, :, s:] = True
        attn_mask[:, -1, :, s:, :, :s] = True
        attn_mask = rearrange(
            attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
        )
        return attn_mask

    def forward(self, x):
        """Forward pass of Window Multi-head Self-attention module.
        Args:
            x: input tensor with shape of [b h w c];
            attn_mask: attention mask, fill -inf where the value is True;
        Returns:
            output: tensor shape [b h w c]
        """
        # SW-MSA: cyclically shift the feature map by half a window first.
        if self.type != "W":
            x = torch.roll(
                x,
                shifts=(-(self.window_size // 2), -(self.window_size // 2)),
                dims=(1, 2),
            )
        # Partition the feature map into non-overlapping windows.
        x = rearrange(
            x,
            "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
            p1=self.window_size,
            p2=self.window_size,
        )
        h_windows = x.size(1)
        w_windows = x.size(2)
        # square validation
        # assert h_windows == w_windows

        # Flatten windows and pixels-within-window for batched attention.
        x = rearrange(
            x,
            "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
            p1=self.window_size,
            p2=self.window_size,
        )
        qkv = self.embedding_layer(x)
        # Split the concatenated projection into per-head q, k, v.
        q, k, v = rearrange(
            qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
        ).chunk(3, dim=0)
        sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
        # Adding learnable relative embedding
        sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
        # Using Attn Mask to distinguish different subwindows.
        if self.type != "W":
            attn_mask = self.generate_mask(
                h_windows, w_windows, self.window_size, shift=self.window_size // 2
            )
            # Masked positions get -inf so softmax assigns them zero weight.
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
        # Merge heads back into the channel dimension.
        output = rearrange(output, "h b w p c -> b w p (h c)")
        output = self.linear(output)
        # Undo the window partitioning.
        output = rearrange(
            output,
            "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
            w1=h_windows,
            p1=self.window_size,
        )

        # SW-MSA: shift the feature map back to its original position.
        if self.type != "W":
            output = torch.roll(
                output,
                shifts=(self.window_size // 2, self.window_size // 2),
                dims=(1, 2),
            )
        return output

    def relative_embedding(self):
        """Gather the per-head relative-position bias for every pair of pixels
        in a window; returns a tensor of shape (n_heads, ws*ws, ws*ws)."""
        # All (row, col) coordinates within one window.
        cord = torch.tensor(
            np.array(
                [
                    [i, j]
                    for i in range(self.window_size)
                    for j in range(self.window_size)
                ]
            )
        )
        # Pairwise offsets shifted into [0, 2*ws-2] to index the bias table.
        relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
        # negative is allowed
        return self.relative_position_params[
            :, relation[:, :, 0].long(), relation[:, :, 1].long()
        ]
161
+
162
+
163
class Block(nn.Module):
    """Swin Transformer block: window attention followed by an MLP, each with a
    pre-LayerNorm and a residual (optionally drop-path) connection.

    :param int input_dim: number of input channels.
    :param int output_dim: number of output channels of the MLP.
    :param int head_dim: channels per attention head.
    :param int window_size: attention window side length.
    :param float drop_path: stochastic-depth rate; 0 disables drop path.
    :param str type: 'W' (window) or 'SW' (shifted window) attention.
    :param int input_resolution: spatial resolution of the incoming features.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        super(Block, self).__init__()
        assert type in ["W", "SW"]
        self.input_dim = input_dim
        self.output_dim = output_dim
        # When a single window covers the whole feature map, shifting is
        # pointless: degrade to plain window attention.
        self.type = "W" if input_resolution <= window_size else type

        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = nn.Identity() if drop_path <= 0.0 else DropPath(drop_path)
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        # Attention and MLP sub-blocks, each residual with pre-norm.
        x = x + self.drop_path(self.msa(self.ln1(x)))
        return x + self.drop_path(self.mlp(self.ln2(x)))
200
+
201
+
202
class ConvTransBlock(nn.Module):
    """Parallel convolution + Swin Transformer block.

    The input channels are split into a convolutional branch and a transformer
    branch by a 1x1 conv, processed independently, fused by a second 1x1 conv,
    and added back to the input (outer residual connection).

    :param int conv_dim: channels routed to the convolutional branch.
    :param int trans_dim: channels routed to the transformer branch.
    :param int head_dim: channels per attention head.
    :param int window_size: attention window side length.
    :param float drop_path: stochastic-depth rate for the transformer block.
    :param str type: 'W' or 'SW' attention.
    :param int input_resolution: spatial resolution of the incoming features.
    """

    def __init__(
        self,
        conv_dim,
        trans_dim,
        head_dim,
        window_size,
        drop_path,
        type="W",
        input_resolution=None,
    ):
        super(ConvTransBlock, self).__init__()
        assert type in ["W", "SW"]
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.input_resolution = input_resolution
        # One window covering the whole map makes shifting pointless.
        self.type = "W" if input_resolution <= window_size else type

        self.trans_block = Block(
            self.trans_dim,
            self.trans_dim,
            self.head_dim,
            self.window_size,
            self.drop_path,
            self.type,
            self.input_resolution,
        )
        fused = self.conv_dim + self.trans_dim
        # 1x1 convs that split and re-fuse the two branches.
        self.conv1_1 = nn.Conv2d(fused, fused, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(fused, fused, 1, 1, 0, bias=True)

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
        )

    def forward(self, x):
        conv_x, trans_x = torch.split(
            self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
        )
        # Convolutional branch with its own inner residual.
        conv_x = conv_x + self.conv_block(conv_x)
        # Transformer branch operates on channels-last layout.
        trans_x = Rearrange("b c h w -> b h w c")(trans_x)
        trans_x = self.trans_block(trans_x)
        trans_x = Rearrange("b h w c -> b c h w")(trans_x)
        # Fuse the branches and add the outer residual.
        return x + self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
271
+
272
+
273
class SCUNet(nn.Module):
    r"""
    SCUNet denoising network.

    The Swin-Conv-UNet (SCUNet) denoiser was introduced in `Practical Blind Denoising
    via Swin-Conv-UNet and Data Synthesis <https://arxiv.org/abs/2203.13278>`_.
    It is a U-Net whose encoder/decoder stages are built from ConvTransBlock
    units (parallel convolution + Swin Transformer branches).

    :param int in_nc: number of input channels. Default: 3.
    :param list config: number of layers in each stage. Default: [4, 4, 4, 4, 4, 4, 4].
    :param int dim: number of channels in each layer. Default: 64.
    :param float drop_path_rate: drop path per sample rate (stochastic depth) for each layer. Default: 0.0.
    :param int input_resolution: input resolution. Default: 256.
    :param bool pretrained: use a pretrained network. If ``pretrained=None``, the weights will be initialized at random
        using Pytorch's default initialization. If ``pretrained='download'``, the weights will be downloaded from an
        online repository (only available for the default architecture).
        Finally, ``pretrained`` can also be set as a path to the user's own pretrained weights. Default: 'download'.
        See :ref:`pretrained-weights <pretrained-weights>` for more details.
    :param bool train: training or testing mode. Default: False.
    :param str device: gpu or cpu. Default: 'cpu'.
    """

    def __init__(
        self,
        # NOTE(review): mutable default argument `config` — never mutated here,
        # but a tuple default would be safer; confirm before changing the API.
        in_nc=3,
        config=[4, 4, 4, 4, 4, 4, 4],
        dim=64,
        drop_path_rate=0.0,
        input_resolution=256,
        pretrained="download",
        train=False,
        device="cpu",
    ):
        super(SCUNet, self).__init__()
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        # drop path rate for each layer: linearly increasing from 0 to
        # drop_path_rate over all sum(config) transformer blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]

        # Stem: 3x3 conv lifting the input to `dim` channels.
        self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]

        # `begin` tracks the global block index into `dpr`; each stage
        # alternates plain ('W') and shifted ('SW') window attention, and each
        # down stage ends with a stride-2 conv that halves resolution and
        # doubles channels.
        begin = 0
        self.m_down1 = [
            ConvTransBlock(
                dim // 2,
                dim // 2,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution,
            )
            for i in range(config[0])
        ] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]

        begin += config[0]
        self.m_down2 = [
            ConvTransBlock(
                dim,
                dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 2,
            )
            for i in range(config[1])
        ] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]

        begin += config[1]
        self.m_down3 = [
            ConvTransBlock(
                2 * dim,
                2 * dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 4,
            )
            for i in range(config[2])
        ] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]

        # Bottleneck at 1/8 resolution.
        begin += config[2]
        self.m_body = [
            ConvTransBlock(
                4 * dim,
                4 * dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 8,
            )
            for i in range(config[3])
        ]

        # Decoder: each up stage starts with a stride-2 transposed conv that
        # doubles resolution and halves channels, mirroring the encoder.
        begin += config[3]
        self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False)] + [
            ConvTransBlock(
                2 * dim,
                2 * dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 4,
            )
            for i in range(config[4])
        ]

        begin += config[4]
        self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False)] + [
            ConvTransBlock(
                dim,
                dim,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution // 2,
            )
            for i in range(config[5])
        ]

        begin += config[5]
        self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False)] + [
            ConvTransBlock(
                dim // 2,
                dim // 2,
                self.head_dim,
                self.window_size,
                dpr[i + begin],
                "W" if not i % 2 else "SW",
                input_resolution,
            )
            for i in range(config[6])
        ]

        # Tail: 3x3 conv mapping back to the input channel count.
        self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]

        self.m_head = nn.Sequential(*self.m_head)
        self.m_down1 = nn.Sequential(*self.m_down1)
        self.m_down2 = nn.Sequential(*self.m_down2)
        self.m_down3 = nn.Sequential(*self.m_down3)
        self.m_body = nn.Sequential(*self.m_body)
        self.m_up3 = nn.Sequential(*self.m_up3)
        self.m_up2 = nn.Sequential(*self.m_up2)
        self.m_up1 = nn.Sequential(*self.m_up1)
        self.m_tail = nn.Sequential(*self.m_tail)
        # self.apply(self._init_weights)

        if pretrained is not None:
            if pretrained == "download":
                # Named checkpoint fetched from the project's weights repository.
                name = "scunet_color_real_psnr.pth"
                url = get_weights_url(model_name="scunet", file_name=name)
                ckpt_drunet = torch.hub.load_state_dict_from_url(
                    url, map_location=lambda storage, loc: storage, file_name=name
                )
            else:
                # `pretrained` is interpreted as a path to a local checkpoint.
                ckpt_drunet = torch.load(
                    pretrained, map_location=lambda storage, loc: storage
                )

            self.load_state_dict(ckpt_drunet, strict=True)

        if not train:
            # Freeze the network for inference-only use.
            self.eval()
            for _, v in self.named_parameters():
                v.requires_grad = False

        if device is not None:
            self.to(device)

    def forward_scunet(self, x0):
        """Run the U-Net on `x0`, padding so both spatial sizes are multiples
        of 64 (required by the 3 stride-2 downsamplings + window attention),
        then cropping the output back to the original size."""
        h, w = x0.size()[-2:]
        paddingBottom = int(np.ceil(h / 64) * 64 - h)
        paddingRight = int(np.ceil(w / 64) * 64 - w)
        x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)

        # Encoder with skip connections (added, not concatenated).
        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        # Crop away the padding.
        x = x[..., :h, :w]

        return x

    def forward(self, x, sigma):  # This is a blind model: sigma is not used
        den = self.forward_scunet(x)
        return den

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers and unit-scale init for
        LayerNorm (currently unused; see the commented self.apply above)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
483
+
484
+ # if __name__ == '__main__':
485
+ # # torch.cuda.empty_cache()
486
+ # net = SCUNet(pretrained='download', device='cpu', train=False)
487
+ #
488
+ # x = torch.randn((2, 3, 64, 128))
489
+ # x = net(x)
490
+ # print(x.shape)