monai-weekly 1.5.dev2511__py3-none-any.whl → 1.5.dev2513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/data/utils.py +1 -1
- monai/metrics/meandice.py +132 -76
- monai/networks/blocks/__init__.py +2 -1
- monai/networks/blocks/cablock.py +182 -0
- monai/networks/blocks/downsample.py +241 -2
- monai/networks/nets/restormer.py +337 -0
- monai/networks/utils.py +44 -1
- monai/utils/__init__.py +1 -0
- monai/utils/enums.py +13 -0
- monai/utils/misc.py +1 -1
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/METADATA +3 -2
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/RECORD +22 -17
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/WHEEL +1 -1
- tests/metrics/test_compute_meandice.py +3 -3
- tests/networks/blocks/test_CABlock.py +150 -0
- tests/networks/blocks/test_downsample_block.py +184 -0
- tests/networks/nets/test_restormer.py +147 -0
- tests/networks/utils/test_pixelunshuffle.py +51 -0
- tests/integration/test_downsample_block.py +0 -50
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info/licenses}/LICENSE +0 -0
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/top_level.txt +0 -0
monai/networks/blocks/downsample.py
CHANGED
@@ -16,8 +16,11 @@ from collections.abc import Sequence
 import torch
 import torch.nn as nn
 
-from monai.networks.layers.factories import Pool
-from monai.utils import
+from monai.networks.layers.factories import Conv, Pool
+from monai.networks.utils import pixelunshuffle
+from monai.utils import DownsampleMode, ensure_tuple_rep, look_up_option
+
+__all__ = ["MaxAvgPool", "DownSample", "Downsample", "SubpixelDownsample", "SubpixelDownSample", "Subpixeldownsample"]
 
 
 class MaxAvgPool(nn.Module):
@@ -61,3 +64,239 @@ class MaxAvgPool(nn.Module):
             Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]).
         """
         return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1)
+
+
+class DownSample(nn.Sequential):
+    """
+    Downsamples data by `scale_factor`.
+
+    Supported modes are:
+
+        - "conv": uses a strided convolution for learnable downsampling.
+        - "convgroup": uses a grouped strided convolution for efficient feature reduction.
+        - "nontrainable": uses :py:class:`torch.nn.Upsample` with inverse scale factor.
+        - "pixelunshuffle": uses :py:class:`monai.networks.blocks.PixelUnshuffle` for channel-space rearrangement.
+
+    This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.
+    Please check the link below for more details:
+    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
+
+    This module can optionally take a pre-convolution
+    (often used to map the number of features from `in_channels` to `out_channels`).
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int | None = None,
+        out_channels: int | None = None,
+        scale_factor: Sequence[float] | float = 2,
+        kernel_size: Sequence[float] | float | None = None,
+        mode: DownsampleMode | str = DownsampleMode.CONV,
+        pre_conv: nn.Module | str | None = "default",
+        post_conv: nn.Module | None = None,
+        bias: bool = True,
+    ) -> None:
+        """
+        Downsamples data by `scale_factor`.
+        Supported modes are:
+
+            - DownsampleMode.CONV: uses a strided convolution for learnable downsampling.
+            - DownsampleMode.CONVGROUP: uses a grouped strided convolution for efficient feature reduction.
+            - DownsampleMode.MAXPOOL: uses maxpooling for non-learnable downsampling.
+            - DownsampleMode.AVGPOOL: uses average pooling for non-learnable downsampling.
+            - DownsampleMode.PIXELUNSHUFFLE: uses :py:class:`monai.networks.blocks.SubpixelDownsample`.
+
+        This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.
+        Please check the link below for more details:
+        https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
+
+        This module can optionally take a pre-convolution and post-convolution
+        (often used to map the number of features from `in_channels` to `out_channels`).
+
+        Args:
+            spatial_dims: number of spatial dimensions of the input image.
+            in_channels: number of channels of the input image.
+            out_channels: number of channels of the output image. Defaults to `in_channels`.
+            scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2.
+            kernel_size: kernel size used during convolutions. Defaults to `scale_factor`.
+            mode: {``DownsampleMode.CONV``, ``DownsampleMode.CONVGROUP``, ``DownsampleMode.MAXPOOL``, ``DownsampleMode.AVGPOOL``,
+                ``DownsampleMode.PIXELUNSHUFFLE``}. Defaults to ``DownsampleMode.CONV``.
+            pre_conv: a conv block applied before downsampling. Defaults to "default".
+                When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
+                Only used in the "maxpool", "avgpool" or "pixelunshuffle" modes.
+            post_conv: a conv block applied after downsampling. Defaults to None. Only used in the "maxpool" and "avgpool" modes.
+            bias: whether to have a bias term in the default preconv and conv layers. Defaults to True.
+        """
+        super().__init__()
+
+        scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
+        down_mode = look_up_option(mode, DownsampleMode)
+
+        if not kernel_size:
+            kernel_size_ = scale_factor_
+            padding = ensure_tuple_rep(0, spatial_dims)
+        else:
+            kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
+            padding = tuple((k - 1) // 2 for k in kernel_size_)
+
+        if down_mode == DownsampleMode.CONV:
+            if not in_channels:
+                raise ValueError("in_channels needs to be specified in conv mode")
+            self.add_module(
+                "conv",
+                Conv[Conv.CONV, spatial_dims](
+                    in_channels=in_channels,
+                    out_channels=out_channels or in_channels,
+                    kernel_size=kernel_size_,
+                    stride=scale_factor_,
+                    padding=padding,
+                    bias=bias,
+                ),
+            )
+        elif down_mode == DownsampleMode.CONVGROUP:
+            if not in_channels:
+                raise ValueError("in_channels needs to be specified")
+            if out_channels is None:
+                out_channels = in_channels
+            groups = in_channels if out_channels % in_channels == 0 else 1
+            self.add_module(
+                "convgroup",
+                Conv[Conv.CONV, spatial_dims](
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=kernel_size_,
+                    stride=scale_factor_,
+                    padding=padding,
+                    groups=groups,
+                    bias=bias,
+                ),
+            )
+        elif down_mode == DownsampleMode.MAXPOOL:
+            if pre_conv == "default" and (out_channels != in_channels):
+                if not in_channels:
+                    raise ValueError("in_channels needs to be specified")
+                self.add_module(
+                    "preconv",
+                    Conv[Conv.CONV, spatial_dims](
+                        in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias
+                    ),
+                )
+            self.add_module(
+                "maxpool", Pool[Pool.MAX, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding)
+            )
+            if post_conv:
+                self.add_module("postconv", post_conv)
+
+        elif down_mode == DownsampleMode.AVGPOOL:
+            if pre_conv == "default" and (out_channels != in_channels):
+                if not in_channels:
+                    raise ValueError("in_channels needs to be specified")
+                self.add_module(
+                    "preconv",
+                    Conv[Conv.CONV, spatial_dims](
+                        in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias
+                    ),
+                )
+            self.add_module(
+                "avgpool", Pool[Pool.AVG, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding)
+            )
+            if post_conv:
+                self.add_module("postconv", post_conv)
+
+        elif down_mode == DownsampleMode.PIXELUNSHUFFLE:
+            self.add_module(
+                "pixelunshuffle",
+                SubpixelDownsample(
+                    spatial_dims=spatial_dims,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    scale_factor=scale_factor_[0],
+                    conv_block=pre_conv,
+                    bias=bias,
+                ),
+            )
+
+
+class SubpixelDownsample(nn.Module):
+    """
+    Downsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images.
+    The module consists of two parts. First, a convolutional layer is employed
+    to adjust the number of channels. Secondly, a pixel unshuffle manipulation
+    rearranges the spatial information into channel space, effectively reducing
+    spatial dimensions while increasing channel depth.
+
+    The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions
+    from (B, C, H*r, W*r) to (B, C*r², H, W) for 2D images or from (B, C, H*r, W*r, D*r) to (B, C*r³, H, W, D) in 3D case.
+
+    Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2).
+
+    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
+    Using a nEfficient Sub-Pixel Convolutional Neural Network."
+
+    The pixel unshuffle mechanism is the inverse operation of:
+    https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int | None,
+        out_channels: int | None = None,
+        scale_factor: int = 2,
+        conv_block: nn.Module | str | None = "default",
+        bias: bool = True,
+    ) -> None:
+        """
+        Downsamples data by rearranging spatial information into channel space.
+        This reduces spatial dimensions while increasing channel depth.
+
+        Args:
+            spatial_dims: number of spatial dimensions of the input image.
+            in_channels: number of channels of the input image.
+            out_channels: optional number of channels of the output image.
+            scale_factor: factor to reduce the spatial dimensions by. Defaults to 2.
+            conv_block: a conv block to adjust channels before downsampling. Defaults to None.
+                When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
+                When ``conv_block`` is an ``nn.module``,
+                please ensure the input number of channels matches requirements.
+            bias: whether to have a bias term in the default conv_block. Defaults to True.
+        """
+        super().__init__()
+
+        if scale_factor <= 0:
+            raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.")
+
+        self.dimensions = spatial_dims
+        self.scale_factor = scale_factor
+
+        if conv_block == "default":
+            if not in_channels:
+                raise ValueError("in_channels need to be specified.")
+            out_channels = out_channels or in_channels
+            self.conv_block = Conv[Conv.CONV, self.dimensions](
+                in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=bias
+            )
+        elif conv_block is None:
+            self.conv_block = nn.Identity()
+        else:
+            self.conv_block = conv_block
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...]).
+        Returns:
+            Tensor with reduced spatial dimensions and increased channel depth.
+        """
+        x = self.conv_block(x)
+        if not all(d % self.scale_factor == 0 for d in x.shape[2:]):
+            raise ValueError(
+                f"All spatial dimensions {x.shape[2:]} must be evenly " f"divisible by scale_factor {self.scale_factor}"
+            )
+        x = pixelunshuffle(x, self.dimensions, self.scale_factor)
+        return x
+
+
+Downsample = DownSample
+SubpixelDownSample = Subpixeldownsample = SubpixelDownsample
monai/networks/nets/restormer.py
ADDED
@@ -0,0 +1,337 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.cablock import CABlock, FeedForward
+from monai.networks.blocks.convolutions import Convolution
+from monai.networks.blocks.downsample import DownSample
+from monai.networks.blocks.upsample import UpSample
+from monai.networks.layers.factories import Norm
+from monai.utils.enums import DownsampleMode, UpsampleMode
+
+
+class MDTATransformerBlock(nn.Module):
+    """Basic transformer unit combining MDTA and GDFN with skip connections.
+    Unlike standard transformers that use LayerNorm, this block uses Instance Norm
+    for better adaptation to image restoration tasks.
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        num_heads: Number of attention heads
+        ffn_expansion_factor: Expansion factor for feed-forward network
+        bias: Whether to use bias in attention layers
+        layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False.
+        flash_attention: Whether to use flash attention optimization. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int,
+        dim: int,
+        num_heads: int,
+        ffn_expansion_factor: float,
+        bias: bool,
+        layer_norm_use_bias: bool = False,
+        flash_attention: bool = False,
+    ):
+        super().__init__()
+        self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)
+        self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention)
+        self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)
+        self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x))
+        x = x + self.ffn(self.norm2(x))
+        return x
+
+
+class OverlapPatchEmbed(Convolution):
+    """Initial feature extraction using overlapped convolutions.
+    Unlike standard patch embeddings that use non-overlapping patches,
+    this approach maintains spatial continuity through 3x3 convolutions.
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        in_channels: Number of input channels
+        embed_dim: Dimension of embedded features. Defaults to 48.
+        bias: Whether to use bias in convolution layer. Defaults to False.
+    """
+
+    def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=embed_dim,
+            kernel_size=3,
+            strides=1,
+            padding=1,
+            bias=bias,
+            conv_only=True,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)
+        return x
+
+
+class Restormer(nn.Module):
+    """Restormer: Efficient Transformer for High-Resolution Image Restoration.
+
+    Implements a U-Net style architecture with transformer blocks, combining:
+    - Multi-scale feature processing through progressive down/upsampling
+    - Efficient attention via MDTA blocks
+    - Local feature mixing through GDFN
+    - Skip connections for preserving spatial details
+
+    Architecture:
+        - Encoder: Progressive feature downsampling with increasing channels
+        - Latent: Deep feature processing at lowest resolution
+        - Decoder: Progressive upsampling with skip connections
+        - Refinement: Final feature enhancement
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int = 2,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        dim: int = 48,
+        num_blocks: tuple[int, ...] = (1, 1, 1, 1),
+        heads: tuple[int, ...] = (1, 1, 1, 1),
+        num_refinement_blocks: int = 4,
+        ffn_expansion_factor: float = 2.66,
+        bias: bool = False,
+        layer_norm_use_bias: bool = True,
+        dual_pixel_task: bool = False,
+        flash_attention: bool = False,
+    ) -> None:
+        super().__init__()
+        """Initialize Restormer model.
+
+        Args:
+            spatial_dims: Number of spatial dimensions (2D or 3D)
+            in_channels: Number of input image channels
+            out_channels: Number of output image channels
+            dim: Base feature dimension. Defaults to 48.
+            num_blocks: Number of transformer blocks at each scale. Defaults to (1,1,1,1).
+            heads: Number of attention heads at each scale. Defaults to (1,1,1,1).
+            num_refinement_blocks: Number of final refinement blocks. Defaults to 4.
+            ffn_expansion_factor: Expansion factor for feed-forward network. Defaults to 2.66.
+            bias: Whether to use bias in convolutions. Defaults to False.
+            layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to True.
+            dual_pixel_task: Enable dual-pixel specific processing. Defaults to False.
+            flash_attention: Use flash attention if available. Defaults to False.
+
+        Note:
+            The number of blocks must be greater than 1
+            The length of num_blocks and heads must be equal
+            All values in num_blocks must be greater than 0
+        """
+        # Check input parameters
+        assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
+        assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
+        assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0"
+
+        # Initial feature extraction
+        self.patch_embed = OverlapPatchEmbed(spatial_dims, in_channels, dim)
+        self.encoder_levels = nn.ModuleList()
+        self.downsamples = nn.ModuleList()
+        self.decoder_levels = nn.ModuleList()
+        self.upsamples = nn.ModuleList()
+        self.reduce_channels = nn.ModuleList()
+        num_steps = len(num_blocks) - 1
+        self.num_steps = num_steps
+        self.spatial_dims = spatial_dims
+        spatial_multiplier = 2 ** (spatial_dims - 1)
+
+        # Define encoder levels
+        for n in range(num_steps):
+            current_dim = dim * (2) ** (n)
+            next_dim = current_dim // spatial_multiplier
+            self.encoder_levels.append(
+                nn.Sequential(
+                    *[
+                        MDTATransformerBlock(
+                            spatial_dims=spatial_dims,
+                            dim=current_dim,
+                            num_heads=heads[n],
+                            ffn_expansion_factor=ffn_expansion_factor,
+                            bias=bias,
+                            layer_norm_use_bias=layer_norm_use_bias,
+                            flash_attention=flash_attention,
+                        )
+                        for _ in range(num_blocks[n])
+                    ]
+                )
+            )
+
+            self.downsamples.append(
+                DownSample(
+                    spatial_dims=self.spatial_dims,
+                    in_channels=current_dim,
+                    out_channels=next_dim,
+                    mode=DownsampleMode.PIXELUNSHUFFLE,
+                    scale_factor=2,
+                    bias=bias,
+                )
+            )
+
+        # Define latent space
+        latent_dim = dim * (2) ** (num_steps)
+        self.latent = nn.Sequential(
+            *[
+                MDTATransformerBlock(
+                    spatial_dims=spatial_dims,
+                    dim=latent_dim,
+                    num_heads=heads[num_steps],
+                    ffn_expansion_factor=ffn_expansion_factor,
+                    bias=bias,
+                    layer_norm_use_bias=layer_norm_use_bias,
+                    flash_attention=flash_attention,
+                )
+                for _ in range(num_blocks[num_steps])
+            ]
+        )
+
+        # Define decoder levels
+        for n in reversed(range(num_steps)):
+            current_dim = dim * (2) ** (n)
+            next_dim = dim * (2) ** (n + 1)
+            self.upsamples.append(
+                UpSample(
+                    spatial_dims=self.spatial_dims,
+                    in_channels=next_dim,
+                    out_channels=(current_dim),
+                    mode=UpsampleMode.PIXELSHUFFLE,
+                    scale_factor=2,
+                    bias=bias,
+                    apply_pad_pool=False,
+                )
+            )
+
+            # Reduce channel layers to deal with skip connections
+            if n != 0:
+                self.reduce_channels.append(
+                    Convolution(
+                        spatial_dims=self.spatial_dims,
+                        in_channels=next_dim,
+                        out_channels=current_dim,
+                        kernel_size=1,
+                        bias=bias,
+                        conv_only=True,
+                    )
+                )
+                decoder_dim = current_dim
+            else:
+                decoder_dim = next_dim
+
+            self.decoder_levels.append(
+                nn.Sequential(
+                    *[
+                        MDTATransformerBlock(
+                            spatial_dims=spatial_dims,
+                            dim=decoder_dim,
+                            num_heads=heads[n],
+                            ffn_expansion_factor=ffn_expansion_factor,
+                            bias=bias,
+                            layer_norm_use_bias=layer_norm_use_bias,
+                            flash_attention=flash_attention,
+                        )
+                        for _ in range(num_blocks[n])
+                    ]
+                )
+            )
+
+        # Final refinement and output
+        self.refinement = nn.Sequential(
+            *[
+                MDTATransformerBlock(
+                    spatial_dims=spatial_dims,
+                    dim=decoder_dim,
+                    num_heads=heads[0],
+                    ffn_expansion_factor=ffn_expansion_factor,
+                    bias=bias,
+                    layer_norm_use_bias=layer_norm_use_bias,
+                    flash_attention=flash_attention,
+                )
+                for _ in range(num_refinement_blocks)
+            ]
+        )
+        self.dual_pixel_task = dual_pixel_task
+        if self.dual_pixel_task:
+            self.skip_conv = Convolution(
+                spatial_dims=self.spatial_dims,
+                in_channels=dim,
+                out_channels=dim * 2,
+                kernel_size=1,
+                bias=bias,
+                conv_only=True,
+            )
+        self.output = Convolution(
+            spatial_dims=self.spatial_dims,
+            in_channels=dim * 2,
+            out_channels=out_channels,
+            kernel_size=3,
+            strides=1,
+            padding=1,
+            bias=bias,
+            conv_only=True,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of Restormer.
+        Processes input through encoder-decoder architecture with skip connections.
+        Args:
+            inp_img: Input image tensor of shape (B, C, H, W, [D])
+
+        Returns:
+            Restored image tensor of shape (B, C, H, W, [D])
+        """
+        assert all(
+            x.shape[-i] > 2**self.num_steps for i in range(1, self.spatial_dims + 1)
+        ), "All spatial dimensions should be larger than 2^number_of_step"
+
+        # Patch embedding
+        x = self.patch_embed(x)
+        skip_connections = []
+
+        # Encoding path
+        for _idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
+            x = encoder(x)
+            skip_connections.append(x)
+            x = downsample(x)
+
+        # Latent space
+        x = self.latent(x)
+
+        # Decoding path
+        for idx in range(len(self.decoder_levels)):
+            x = self.upsamples[idx](x)
+            x = torch.concat([x, skip_connections[-(idx + 1)]], 1)
+            if idx < len(self.decoder_levels) - 1:
+                x = self.reduce_channels[idx](x)
+            x = self.decoder_levels[idx](x)
+
+        # Final refinement
+        x = self.refinement(x)
+
+        if self.dual_pixel_task:
+            x = x + self.skip_conv(skip_connections[0])
+            x = self.output(x)
+        else:
+            x = self.output(x)
+
+        return x
monai/networks/utils.py
CHANGED
@@ -49,6 +49,7 @@ __all__ = [
     "normal_init",
     "icnr_init",
     "pixelshuffle",
+    "pixelunshuffle",
     "eval_mode",
     "train_mode",
     "get_state_dict",
@@ -376,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch
     See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
 
     Args:
-        x: Input tensor
+        x: Input tensor with shape BCHW[D]
         spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
         scale_factor: factor to rescale the spatial dimensions by, must be >=1
 
@@ -411,6 +412,48 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch
     return x
 
 
+def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
+    """
+    Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.
+    Inverse operation of pixelshuffle.
+
+    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
+    Using an Efficient Sub-Pixel Convolutional Neural Network."
+
+    See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
+
+    Args:
+        x: Input tensor with shape BCHW[D]
+        spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
+        scale_factor: factor to reduce the spatial dimensions by, must be >=1
+
+    Returns:
+        Unshuffled version of `x` with shape (B, C*(r**d), H/r, W/r) for 2D
+        or (B, C*(r**d), D/r, H/r, W/r) for 3D, where r is the scale_factor
+        and d is spatial_dims.
+
+    Raises:
+        ValueError: When spatial dimensions are not divisible by scale_factor
+    """
+    dim, factor = spatial_dims, scale_factor
+    input_size = list(x.size())
+    batch_size, channels = input_size[:2]
+    scale_factor_mult = factor**dim
+    new_channels = channels * scale_factor_mult
+
+    if any(d % factor != 0 for d in input_size[2:]):
+        raise ValueError(
+            f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}"
+        )
+    output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]]
+    reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], [])
+
+    permute_indices = [0, 1] + [(2 * i + 3) for i in range(spatial_dims)] + [(2 * i + 2) for i in range(spatial_dims)]
+    x = x.reshape(reshaped_size).permute(permute_indices)
+    x = x.reshape(output_size)
+    return x
+
+
 @contextmanager
 def eval_mode(*nets: nn.Module):
     """
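
A small check (not part of the diff) of what the new pixelunshuffle helper does: it splits each spatial axis into (size // r, r), permutes the r-blocks into the channel axis, and reshapes back. For 2D inputs this should agree with torch.nn.functional.pixel_unshuffle, and composing it with the existing pixelshuffle should round-trip.

import torch

from monai.networks.utils import pixelshuffle, pixelunshuffle

x = torch.arange(64.0).reshape(1, 1, 8, 8)
y = pixelunshuffle(x, spatial_dims=2, scale_factor=2)  # shape (1, 4, 4, 4)
print(torch.equal(y, torch.nn.functional.pixel_unshuffle(x, 2)))  # expected: True for the 2D case
print(torch.equal(pixelshuffle(y, spatial_dims=2, scale_factor=2), x))  # expected: True (round trip)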
monai/utils/__init__.py
CHANGED
monai/utils/enums.py
CHANGED
@@ -24,6 +24,7 @@ __all__ = [
     "SplineMode",
     "InterpolateMode",
     "UpsampleMode",
+    "DownsampleMode",
     "BlendMode",
     "PytorchPadMode",
     "NdimageMode",
@@ -181,6 +182,18 @@ class UpsampleMode(StrEnum):
     PIXELSHUFFLE = "pixelshuffle"
 
 
+class DownsampleMode(StrEnum):
+    """
+    See also: :py:class:`monai.networks.blocks.UpSample`
+    """
+
+    CONV = "conv"  # e.g. using strided convolution
+    CONVGROUP = "convgroup"  # e.g. using grouped strided convolution
+    PIXELUNSHUFFLE = "pixelunshuffle"
+    MAXPOOL = "maxpool"
+    AVGPOOL = "avgpool"
+
+
 class BlendMode(StrEnum):
     """
     See also: :py:class:`monai.data.utils.compute_importance_map`