monai-weekly 1.5.dev2511__py3-none-any.whl → 1.5.dev2513__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,8 +16,11 @@ from collections.abc import Sequence
  import torch
  import torch.nn as nn

- from monai.networks.layers.factories import Pool
- from monai.utils import ensure_tuple_rep
+ from monai.networks.layers.factories import Conv, Pool
+ from monai.networks.utils import pixelunshuffle
+ from monai.utils import DownsampleMode, ensure_tuple_rep, look_up_option
+
+ __all__ = ["MaxAvgPool", "DownSample", "Downsample", "SubpixelDownsample", "SubpixelDownSample", "Subpixeldownsample"]


  class MaxAvgPool(nn.Module):
@@ -61,3 +64,239 @@ class MaxAvgPool(nn.Module):
              Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]).
          """
          return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1)
+
+
+ class DownSample(nn.Sequential):
+     """
+     Downsamples data by `scale_factor`.
+
+     Supported modes are:
+
+         - "conv": uses a strided convolution for learnable downsampling.
+         - "convgroup": uses a grouped strided convolution for efficient feature reduction.
+         - "nontrainable": uses :py:class:`torch.nn.Upsample` with inverse scale factor.
+         - "pixelunshuffle": uses :py:class:`monai.networks.blocks.PixelUnshuffle` for channel-space rearrangement.
+
+     This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.
+     Please check the link below for more details:
+     https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
+
+     This module can optionally take a pre-convolution
+     (often used to map the number of features from `in_channels` to `out_channels`).
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int | None = None,
+         out_channels: int | None = None,
+         scale_factor: Sequence[float] | float = 2,
+         kernel_size: Sequence[float] | float | None = None,
+         mode: DownsampleMode | str = DownsampleMode.CONV,
+         pre_conv: nn.Module | str | None = "default",
+         post_conv: nn.Module | None = None,
+         bias: bool = True,
+     ) -> None:
+         """
+         Downsamples data by `scale_factor`.
+         Supported modes are:
+
+             - DownsampleMode.CONV: uses a strided convolution for learnable downsampling.
+             - DownsampleMode.CONVGROUP: uses a grouped strided convolution for efficient feature reduction.
+             - DownsampleMode.MAXPOOL: uses maxpooling for non-learnable downsampling.
+             - DownsampleMode.AVGPOOL: uses average pooling for non-learnable downsampling.
+             - DownsampleMode.PIXELUNSHUFFLE: uses :py:class:`monai.networks.blocks.SubpixelDownsample`.
+
+         This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.
+         Please check the link below for more details:
+         https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
+
+         This module can optionally take a pre-convolution and post-convolution
+         (often used to map the number of features from `in_channels` to `out_channels`).
+
+         Args:
+             spatial_dims: number of spatial dimensions of the input image.
+             in_channels: number of channels of the input image.
+             out_channels: number of channels of the output image. Defaults to `in_channels`.
+             scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2.
+             kernel_size: kernel size used during convolutions. Defaults to `scale_factor`.
+             mode: {``DownsampleMode.CONV``, ``DownsampleMode.CONVGROUP``, ``DownsampleMode.MAXPOOL``, ``DownsampleMode.AVGPOOL``,
+                 ``DownsampleMode.PIXELUNSHUFFLE``}. Defaults to ``DownsampleMode.CONV``.
+             pre_conv: a conv block applied before downsampling. Defaults to "default".
+                 When ``pre_conv`` is ``"default"``, one reserved conv layer will be utilized.
+                 Only used in the "maxpool", "avgpool" or "pixelunshuffle" modes.
+             post_conv: a conv block applied after downsampling. Defaults to None. Only used in the "maxpool" and "avgpool" modes.
+             bias: whether to have a bias term in the default preconv and conv layers. Defaults to True.
+         """
+         super().__init__()
+
+         scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
+         down_mode = look_up_option(mode, DownsampleMode)
+
+         if not kernel_size:
+             kernel_size_ = scale_factor_
+             padding = ensure_tuple_rep(0, spatial_dims)
+         else:
+             kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
+             padding = tuple((k - 1) // 2 for k in kernel_size_)
+
+         if down_mode == DownsampleMode.CONV:
+             if not in_channels:
+                 raise ValueError("in_channels needs to be specified in conv mode")
+             self.add_module(
+                 "conv",
+                 Conv[Conv.CONV, spatial_dims](
+                     in_channels=in_channels,
+                     out_channels=out_channels or in_channels,
+                     kernel_size=kernel_size_,
+                     stride=scale_factor_,
+                     padding=padding,
+                     bias=bias,
+                 ),
+             )
+         elif down_mode == DownsampleMode.CONVGROUP:
+             if not in_channels:
+                 raise ValueError("in_channels needs to be specified")
+             if out_channels is None:
+                 out_channels = in_channels
+             groups = in_channels if out_channels % in_channels == 0 else 1
+             self.add_module(
+                 "convgroup",
+                 Conv[Conv.CONV, spatial_dims](
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     kernel_size=kernel_size_,
+                     stride=scale_factor_,
+                     padding=padding,
+                     groups=groups,
+                     bias=bias,
+                 ),
+             )
+         elif down_mode == DownsampleMode.MAXPOOL:
+             if pre_conv == "default" and (out_channels != in_channels):
+                 if not in_channels:
+                     raise ValueError("in_channels needs to be specified")
+                 self.add_module(
+                     "preconv",
+                     Conv[Conv.CONV, spatial_dims](
+                         in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias
+                     ),
+                 )
+             self.add_module(
+                 "maxpool", Pool[Pool.MAX, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding)
+             )
+             if post_conv:
+                 self.add_module("postconv", post_conv)
+
+         elif down_mode == DownsampleMode.AVGPOOL:
+             if pre_conv == "default" and (out_channels != in_channels):
+                 if not in_channels:
+                     raise ValueError("in_channels needs to be specified")
+                 self.add_module(
+                     "preconv",
+                     Conv[Conv.CONV, spatial_dims](
+                         in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias
+                     ),
+                 )
+             self.add_module(
+                 "avgpool", Pool[Pool.AVG, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding)
+             )
+             if post_conv:
+                 self.add_module("postconv", post_conv)
+
+         elif down_mode == DownsampleMode.PIXELUNSHUFFLE:
+             self.add_module(
+                 "pixelunshuffle",
+                 SubpixelDownsample(
+                     spatial_dims=spatial_dims,
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     scale_factor=scale_factor_[0],
+                     conv_block=pre_conv,
+                     bias=bias,
+                 ),
+             )
+
+
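A quick usage sketch of the DownSample block added above (not taken from the diff; assumes the 1.5.dev2513 wheel is installed, and uses the `monai.networks.blocks.downsample` module path implied by the restormer import later in this diff):

    import torch
    from monai.networks.blocks.downsample import DownSample

    # Learnable strided-convolution downsampling: kernel_size defaults to scale_factor,
    # so a 2x2 stride-2 conv halves each spatial dimension.
    down = DownSample(spatial_dims=2, in_channels=8, out_channels=16, scale_factor=2, mode="conv")
    print(down(torch.rand(1, 8, 32, 32)).shape)  # expected: torch.Size([1, 16, 16, 16])

    # Non-learnable max-pool downsampling; because out_channels != in_channels, the
    # "default" pre_conv inserts a 1x1 convolution to change the channel count first.
    down_mp = DownSample(spatial_dims=2, in_channels=8, out_channels=4, scale_factor=2, mode="maxpool")
    print(down_mp(torch.rand(1, 8, 32, 32)).shape)  # expected: torch.Size([1, 4, 16, 16])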
+ class SubpixelDownsample(nn.Module):
+     """
+     Downsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images.
+     The module consists of two parts. First, a convolutional layer is employed
+     to adjust the number of channels. Secondly, a pixel unshuffle manipulation
+     rearranges the spatial information into channel space, effectively reducing
+     spatial dimensions while increasing channel depth.
+
+     The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions
+     from (B, C, H*r, W*r) to (B, C*r², H, W) for 2D images or from (B, C, H*r, W*r, D*r) to (B, C*r³, H, W, D) in 3D case.
+
+     Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2).
+
+     See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
+     Using an Efficient Sub-Pixel Convolutional Neural Network."
+
+     The pixel unshuffle mechanism is the inverse operation of:
+     https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int | None,
+         out_channels: int | None = None,
+         scale_factor: int = 2,
+         conv_block: nn.Module | str | None = "default",
+         bias: bool = True,
+     ) -> None:
+         """
+         Downsamples data by rearranging spatial information into channel space.
+         This reduces spatial dimensions while increasing channel depth.
+
+         Args:
+             spatial_dims: number of spatial dimensions of the input image.
+             in_channels: number of channels of the input image.
+             out_channels: optional number of channels of the output image.
+             scale_factor: factor to reduce the spatial dimensions by. Defaults to 2.
+             conv_block: a conv block to adjust channels before downsampling. Defaults to "default".
+                 When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
+                 When ``conv_block`` is an ``nn.Module``,
+                 please ensure the input number of channels matches requirements.
+             bias: whether to have a bias term in the default conv_block. Defaults to True.
+         """
+         super().__init__()
+
+         if scale_factor <= 0:
+             raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.")
+
+         self.dimensions = spatial_dims
+         self.scale_factor = scale_factor
+
+         if conv_block == "default":
+             if not in_channels:
+                 raise ValueError("in_channels needs to be specified.")
+             out_channels = out_channels or in_channels
+             self.conv_block = Conv[Conv.CONV, self.dimensions](
+                 in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=bias
+             )
+         elif conv_block is None:
+             self.conv_block = nn.Identity()
+         else:
+             self.conv_block = conv_block
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...]).
+         Returns:
+             Tensor with reduced spatial dimensions and increased channel depth.
+         """
+         x = self.conv_block(x)
+         if not all(d % self.scale_factor == 0 for d in x.shape[2:]):
+             raise ValueError(
+                 f"All spatial dimensions {x.shape[2:]} must be evenly " f"divisible by scale_factor {self.scale_factor}"
+             )
+         x = pixelunshuffle(x, self.dimensions, self.scale_factor)
+         return x
+
+
+ Downsample = DownSample
+ SubpixelDownSample = Subpixeldownsample = SubpixelDownsample
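A corresponding sketch for SubpixelDownsample, assuming it lives in the same module as DownSample:

    import torch
    from monai.networks.blocks.downsample import SubpixelDownsample

    # With conv_block=None the block is a pure pixel unshuffle: spatial sizes shrink by r and
    # channels grow by r**spatial_dims, matching the docstring example (1, 1, 4, 4) -> (1, 4, 2, 2).
    down = SubpixelDownsample(spatial_dims=2, in_channels=1, conv_block=None, scale_factor=2)
    print(down(torch.rand(1, 1, 4, 4)).shape)  # expected: torch.Size([1, 4, 2, 2])

    # With the default conv_block, a 3x3 convolution first maps in_channels to out_channels,
    # so the output has out_channels * r**spatial_dims channels.
    down = SubpixelDownsample(spatial_dims=2, in_channels=8, out_channels=4, scale_factor=2)
    print(down(torch.rand(1, 8, 16, 16)).shape)  # expected: torch.Size([1, 16, 8, 8])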
@@ -0,0 +1,337 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from monai.networks.blocks.cablock import CABlock, FeedForward
+ from monai.networks.blocks.convolutions import Convolution
+ from monai.networks.blocks.downsample import DownSample
+ from monai.networks.blocks.upsample import UpSample
+ from monai.networks.layers.factories import Norm
+ from monai.utils.enums import DownsampleMode, UpsampleMode
+
+
+ class MDTATransformerBlock(nn.Module):
+     """Basic transformer unit combining MDTA and GDFN with skip connections.
+     Unlike standard transformers that use LayerNorm, this block uses Instance Norm
+     for better adaptation to image restoration tasks.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (2D or 3D)
+         dim: Number of input channels
+         num_heads: Number of attention heads
+         ffn_expansion_factor: Expansion factor for feed-forward network
+         bias: Whether to use bias in attention layers
+         layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False.
+         flash_attention: Whether to use flash attention optimization. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         dim: int,
+         num_heads: int,
+         ffn_expansion_factor: float,
+         bias: bool,
+         layer_norm_use_bias: bool = False,
+         flash_attention: bool = False,
+     ):
+         super().__init__()
+         self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)
+         self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention)
+         self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)
+         self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.attn(self.norm1(x))
+         x = x + self.ffn(self.norm2(x))
+         return x
+
+
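The block above is shape-preserving: the attention and feed-forward branches each return a tensor of the input shape and are added residually. A minimal sketch (the module path `monai.networks.nets.restormer` is an assumption, since the new file's path is not shown in this hunk):

    import torch
    from monai.networks.nets.restormer import MDTATransformerBlock  # assumed module path

    block = MDTATransformerBlock(spatial_dims=2, dim=48, num_heads=1, ffn_expansion_factor=2.66, bias=False)
    x = torch.rand(1, 48, 64, 64)
    print(block(x).shape)  # expected: torch.Size([1, 48, 64, 64])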
+ class OverlapPatchEmbed(Convolution):
+     """Initial feature extraction using overlapped convolutions.
+     Unlike standard patch embeddings that use non-overlapping patches,
+     this approach maintains spatial continuity through 3x3 convolutions.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (2D or 3D)
+         in_channels: Number of input channels
+         embed_dim: Dimension of embedded features. Defaults to 48.
+         bias: Whether to use bias in convolution layer. Defaults to False.
+     """
+
+     def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False):
+         super().__init__(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=embed_dim,
+             kernel_size=3,
+             strides=1,
+             padding=1,
+             bias=bias,
+             conv_only=True,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = super().forward(x)
+         return x
+
+
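OverlapPatchEmbed is a plain 3x3, stride-1 convolution, so it changes only the channel count; a quick sketch (same assumed module path as above):

    import torch
    from monai.networks.nets.restormer import OverlapPatchEmbed  # assumed module path

    embed = OverlapPatchEmbed(spatial_dims=2, in_channels=3, embed_dim=48)
    print(embed(torch.rand(1, 3, 64, 64)).shape)  # expected: torch.Size([1, 48, 64, 64])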
+ class Restormer(nn.Module):
+     """Restormer: Efficient Transformer for High-Resolution Image Restoration.
+
+     Implements a U-Net style architecture with transformer blocks, combining:
+         - Multi-scale feature processing through progressive down/upsampling
+         - Efficient attention via MDTA blocks
+         - Local feature mixing through GDFN
+         - Skip connections for preserving spatial details
+
+     Architecture:
+         - Encoder: Progressive feature downsampling with increasing channels
+         - Latent: Deep feature processing at lowest resolution
+         - Decoder: Progressive upsampling with skip connections
+         - Refinement: Final feature enhancement
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int = 2,
+         in_channels: int = 3,
+         out_channels: int = 3,
+         dim: int = 48,
+         num_blocks: tuple[int, ...] = (1, 1, 1, 1),
+         heads: tuple[int, ...] = (1, 1, 1, 1),
+         num_refinement_blocks: int = 4,
+         ffn_expansion_factor: float = 2.66,
+         bias: bool = False,
+         layer_norm_use_bias: bool = True,
+         dual_pixel_task: bool = False,
+         flash_attention: bool = False,
+     ) -> None:
+         super().__init__()
+         """Initialize Restormer model.
+
+         Args:
+             spatial_dims: Number of spatial dimensions (2D or 3D)
+             in_channels: Number of input image channels
+             out_channels: Number of output image channels
+             dim: Base feature dimension. Defaults to 48.
+             num_blocks: Number of transformer blocks at each scale. Defaults to (1,1,1,1).
+             heads: Number of attention heads at each scale. Defaults to (1,1,1,1).
+             num_refinement_blocks: Number of final refinement blocks. Defaults to 4.
+             ffn_expansion_factor: Expansion factor for feed-forward network. Defaults to 2.66.
+             bias: Whether to use bias in convolutions. Defaults to False.
+             layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to True.
+             dual_pixel_task: Enable dual-pixel specific processing. Defaults to False.
+             flash_attention: Use flash attention if available. Defaults to False.
+
+         Note:
+             The number of blocks must be greater than 1
+             The length of num_blocks and heads must be equal
+             All values in num_blocks must be greater than 0
+         """
+         # Check input parameters
+         assert len(num_blocks) > 1, "Number of blocks must be greater than 1"
+         assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal"
+         assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0"
+
+         # Initial feature extraction
+         self.patch_embed = OverlapPatchEmbed(spatial_dims, in_channels, dim)
+         self.encoder_levels = nn.ModuleList()
+         self.downsamples = nn.ModuleList()
+         self.decoder_levels = nn.ModuleList()
+         self.upsamples = nn.ModuleList()
+         self.reduce_channels = nn.ModuleList()
+         num_steps = len(num_blocks) - 1
+         self.num_steps = num_steps
+         self.spatial_dims = spatial_dims
+         spatial_multiplier = 2 ** (spatial_dims - 1)
+
+         # Define encoder levels
+         for n in range(num_steps):
+             current_dim = dim * (2) ** (n)
+             next_dim = current_dim // spatial_multiplier
+             self.encoder_levels.append(
+                 nn.Sequential(
+                     *[
+                         MDTATransformerBlock(
+                             spatial_dims=spatial_dims,
+                             dim=current_dim,
+                             num_heads=heads[n],
+                             ffn_expansion_factor=ffn_expansion_factor,
+                             bias=bias,
+                             layer_norm_use_bias=layer_norm_use_bias,
+                             flash_attention=flash_attention,
+                         )
+                         for _ in range(num_blocks[n])
+                     ]
+                 )
+             )
+
+             self.downsamples.append(
+                 DownSample(
+                     spatial_dims=self.spatial_dims,
+                     in_channels=current_dim,
+                     out_channels=next_dim,
+                     mode=DownsampleMode.PIXELUNSHUFFLE,
+                     scale_factor=2,
+                     bias=bias,
+                 )
+             )
+
+         # Define latent space
+         latent_dim = dim * (2) ** (num_steps)
+         self.latent = nn.Sequential(
+             *[
+                 MDTATransformerBlock(
+                     spatial_dims=spatial_dims,
+                     dim=latent_dim,
+                     num_heads=heads[num_steps],
+                     ffn_expansion_factor=ffn_expansion_factor,
+                     bias=bias,
+                     layer_norm_use_bias=layer_norm_use_bias,
+                     flash_attention=flash_attention,
+                 )
+                 for _ in range(num_blocks[num_steps])
+             ]
+         )
+
+         # Define decoder levels
+         for n in reversed(range(num_steps)):
+             current_dim = dim * (2) ** (n)
+             next_dim = dim * (2) ** (n + 1)
+             self.upsamples.append(
+                 UpSample(
+                     spatial_dims=self.spatial_dims,
+                     in_channels=next_dim,
+                     out_channels=(current_dim),
+                     mode=UpsampleMode.PIXELSHUFFLE,
+                     scale_factor=2,
+                     bias=bias,
+                     apply_pad_pool=False,
+                 )
+             )
+
+             # Reduce channel layers to deal with skip connections
+             if n != 0:
+                 self.reduce_channels.append(
+                     Convolution(
+                         spatial_dims=self.spatial_dims,
+                         in_channels=next_dim,
+                         out_channels=current_dim,
+                         kernel_size=1,
+                         bias=bias,
+                         conv_only=True,
+                     )
+                 )
+                 decoder_dim = current_dim
+             else:
+                 decoder_dim = next_dim
+
+             self.decoder_levels.append(
+                 nn.Sequential(
+                     *[
+                         MDTATransformerBlock(
+                             spatial_dims=spatial_dims,
+                             dim=decoder_dim,
+                             num_heads=heads[n],
+                             ffn_expansion_factor=ffn_expansion_factor,
+                             bias=bias,
+                             layer_norm_use_bias=layer_norm_use_bias,
+                             flash_attention=flash_attention,
+                         )
+                         for _ in range(num_blocks[n])
+                     ]
+                 )
+             )
+
+         # Final refinement and output
+         self.refinement = nn.Sequential(
+             *[
+                 MDTATransformerBlock(
+                     spatial_dims=spatial_dims,
+                     dim=decoder_dim,
+                     num_heads=heads[0],
+                     ffn_expansion_factor=ffn_expansion_factor,
+                     bias=bias,
+                     layer_norm_use_bias=layer_norm_use_bias,
+                     flash_attention=flash_attention,
+                 )
+                 for _ in range(num_refinement_blocks)
+             ]
+         )
+         self.dual_pixel_task = dual_pixel_task
+         if self.dual_pixel_task:
+             self.skip_conv = Convolution(
+                 spatial_dims=self.spatial_dims,
+                 in_channels=dim,
+                 out_channels=dim * 2,
+                 kernel_size=1,
+                 bias=bias,
+                 conv_only=True,
+             )
+         self.output = Convolution(
+             spatial_dims=self.spatial_dims,
+             in_channels=dim * 2,
+             out_channels=out_channels,
+             kernel_size=3,
+             strides=1,
+             padding=1,
+             bias=bias,
+             conv_only=True,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass of Restormer.
+         Processes input through encoder-decoder architecture with skip connections.
+         Args:
+             x: Input image tensor of shape (B, C, H, W, [D])
+
+         Returns:
+             Restored image tensor of shape (B, C, H, W, [D])
+         """
+         assert all(
+             x.shape[-i] > 2**self.num_steps for i in range(1, self.spatial_dims + 1)
+         ), "All spatial dimensions should be larger than 2^number_of_step"
+
+         # Patch embedding
+         x = self.patch_embed(x)
+         skip_connections = []
+
+         # Encoding path
+         for _idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):
+             x = encoder(x)
+             skip_connections.append(x)
+             x = downsample(x)
+
+         # Latent space
+         x = self.latent(x)
+
+         # Decoding path
+         for idx in range(len(self.decoder_levels)):
+             x = self.upsamples[idx](x)
+             x = torch.concat([x, skip_connections[-(idx + 1)]], 1)
+             if idx < len(self.decoder_levels) - 1:
+                 x = self.reduce_channels[idx](x)
+             x = self.decoder_levels[idx](x)
+
+         # Final refinement
+         x = self.refinement(x)
+
+         if self.dual_pixel_task:
+             x = x + self.skip_conv(skip_connections[0])
+             x = self.output(x)
+         else:
+             x = self.output(x)
+
+         return x
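A minimal end-to-end sketch of the network above (same assumed module path as before; each input side must be larger than 2**num_steps and divisible by 2**num_steps so the pixel-unshuffle downsampling stays valid):

    import torch
    from monai.networks.nets.restormer import Restormer  # assumed module path

    model = Restormer(spatial_dims=2, in_channels=1, out_channels=1, dim=48, num_blocks=(1, 1, 1, 1), heads=(1, 1, 1, 1))
    with torch.no_grad():
        y = model(torch.rand(1, 1, 64, 64))
    print(y.shape)  # expected: torch.Size([1, 1, 64, 64])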
monai/networks/utils.py CHANGED
@@ -49,6 +49,7 @@ __all__ = [
      "normal_init",
      "icnr_init",
      "pixelshuffle",
+     "pixelunshuffle",
      "eval_mode",
      "train_mode",
      "get_state_dict",
@@ -376,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch
      See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

      Args:
-         x: Input tensor
+         x: Input tensor with shape BCHW[D]
          spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
          scale_factor: factor to rescale the spatial dimensions by, must be >=1

@@ -411,6 +412,48 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch
      return x


+ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
+     """
+     Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.
+     Inverse operation of pixelshuffle.
+
+     See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
+     Using an Efficient Sub-Pixel Convolutional Neural Network."
+
+     See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
+
+     Args:
+         x: Input tensor with shape BCHW[D]
+         spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
+         scale_factor: factor to reduce the spatial dimensions by, must be >=1
+
+     Returns:
+         Unshuffled version of `x` with shape (B, C*(r**d), H/r, W/r) for 2D
+         or (B, C*(r**d), D/r, H/r, W/r) for 3D, where r is the scale_factor
+         and d is spatial_dims.
+
+     Raises:
+         ValueError: When spatial dimensions are not divisible by scale_factor
+     """
+     dim, factor = spatial_dims, scale_factor
+     input_size = list(x.size())
+     batch_size, channels = input_size[:2]
+     scale_factor_mult = factor**dim
+     new_channels = channels * scale_factor_mult
+
+     if any(d % factor != 0 for d in input_size[2:]):
+         raise ValueError(
+             f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}"
+         )
+     output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]]
+     reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], [])
+
+     permute_indices = [0, 1] + [(2 * i + 3) for i in range(spatial_dims)] + [(2 * i + 2) for i in range(spatial_dims)]
+     x = x.reshape(reshaped_size).permute(permute_indices)
+     x = x.reshape(output_size)
+     return x
+
+
  @contextmanager
  def eval_mode(*nets: nn.Module):
      """
monai/utils/__init__.py CHANGED
@@ -29,6 +29,7 @@ from .enums import (
      CommonKeys,
      CompInitMode,
      DiceCEReduction,
+     DownsampleMode,
      EngineStatsKeys,
      FastMRIKeys,
      ForwardMode,
monai/utils/enums.py CHANGED
@@ -24,6 +24,7 @@ __all__ = [
      "SplineMode",
      "InterpolateMode",
      "UpsampleMode",
+     "DownsampleMode",
      "BlendMode",
      "PytorchPadMode",
      "NdimageMode",
@@ -181,6 +182,18 @@ class UpsampleMode(StrEnum):
      PIXELSHUFFLE = "pixelshuffle"


+ class DownsampleMode(StrEnum):
+     """
+     See also: :py:class:`monai.networks.blocks.DownSample`
+     """
+
+     CONV = "conv"  # e.g. using strided convolution
+     CONVGROUP = "convgroup"  # e.g. using grouped strided convolution
+     PIXELUNSHUFFLE = "pixelunshuffle"
+     MAXPOOL = "maxpool"
+     AVGPOOL = "avgpool"
+
+
  class BlendMode(StrEnum):
      """
      See also: :py:class:`monai.data.utils.compute_importance_map`
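The new enum is what DownSample validates its ``mode`` argument against via `look_up_option`; a short sketch of the values it accepts:

    from monai.utils import DownsampleMode, look_up_option

    print([m.value for m in DownsampleMode])  # ['conv', 'convgroup', 'pixelunshuffle', 'maxpool', 'avgpool']
    mode = look_up_option("maxpool", DownsampleMode)
    print(mode is DownsampleMode.MAXPOOL)     # True: strings are looked up to enum members
    print(DownsampleMode.CONV == "conv")      # True: StrEnum members compare equal to their string values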