monai-weekly 1.5.dev2444__py3-none-any.whl → 1.5.dev2446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,354 @@
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Portions of this code are derived from the original repository at:
13
+ # https://github.com/MIC-DKFZ/MedNeXt
14
+ # and are used under the terms of the Apache License, Version 2.0.
15
+
16
+ from __future__ import annotations
17
+
18
+ from collections.abc import Sequence
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
24
+
25
# Public names exported by this module. Besides the `MedNeXt` class itself, every
# size variant (S/B/M/L) is exported under several spellings ("MedNeXt*" and
# "MedNext*", with short and long suffixes); the assignments at the bottom of this
# file bind all spellings of one variant to the same factory callable.
__all__ = [
    "MedNeXt",
    "MedNeXtSmall",
    "MedNeXtBase",
    "MedNeXtMedium",
    "MedNeXtLarge",
    "MedNext",
    "MedNextS",
    "MedNeXtS",
    "MedNextSmall",
    "MedNextB",
    "MedNeXtB",
    "MedNextBase",
    "MedNextM",
    "MedNeXtM",
    "MedNextMedium",
    "MedNextL",
    "MedNeXtL",
    "MedNextLarge",
]
45
+
46
+
47
class MedNeXt(nn.Module):
    """
    MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975

    The network is built from a stem convolution, a sequence of encoder stages each followed by a
    down block, a bottleneck, and a sequence of up blocks each followed by a decoder stage. The output
    of every encoder stage is added to the output of the matching up block (additive skip connection).

    Args:
        spatial_dims: spatial dimension of the input data. Must be 2 or 3. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        encoder_expansion_ratio: expansion ratio for encoder blocks, either one value shared by all
            encoder stages or one value per stage. Defaults to 2.
        decoder_expansion_ratio: expansion ratio for decoder blocks, either one value shared by all
            decoder stages or one value per stage. Defaults to 2.
        bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2.
        kernel_size: kernel size for convolutions. Defaults to 7.
        deep_supervision: whether to use deep supervision. Defaults to False.
        use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False.
        blocks_down: number of blocks in each encoder stage. Defaults to (2, 2, 2, 2).
        blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.
        blocks_up: number of blocks in each decoder stage. Defaults to (2, 2, 2, 2).
        norm_type: type of normalization layer. Defaults to 'group'.
        global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808

    Raises:
        AssertionError: if ``spatial_dims`` is neither 2 nor 3.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 32,
        in_channels: int = 1,
        out_channels: int = 2,
        encoder_expansion_ratio: Sequence[int] | int = 2,
        decoder_expansion_ratio: Sequence[int] | int = 2,
        bottleneck_expansion_ratio: int = 2,
        kernel_size: int = 7,
        deep_supervision: bool = False,
        use_residual_connection: bool = False,
        blocks_down: Sequence[int] = (2, 2, 2, 2),
        blocks_bottleneck: int = 2,
        blocks_up: Sequence[int] = (2, 2, 2, 2),
        norm_type: str = "group",
        global_resp_norm: bool = False,
    ):
        """
        Initialize the MedNeXt model.

        This method sets up the architecture of the model, including:
        - Stem convolution
        - Encoder stages and downsampling blocks
        - Bottleneck blocks
        - Decoder stages and upsampling blocks
        - Output blocks for deep supervision (if enabled)
        """
        super().__init__()

        self.do_ds = deep_supervision

        # Raised explicitly instead of via `assert` so the check is not stripped under `python -O`
        # (an unchecked value such as 1 would otherwise silently select `nn.Conv3d` below).
        # The exception type and message are kept unchanged for backward compatibility.
        if spatial_dims not in (2, 3):
            raise AssertionError("`spatial_dims` can only be 2 or 3.")
        spatial_dims_str = f"{spatial_dims}d"
        enc_kernel_size = dec_kernel_size = kernel_size

        # A scalar expansion ratio is broadcast to one entry per stage.
        if isinstance(encoder_expansion_ratio, int):
            encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down)

        if isinstance(decoder_expansion_ratio, int):
            decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up)

        conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d

        # Pointwise (kernel_size=1) projection from the input channels to `init_filters`.
        self.stem = conv(in_channels, init_filters, kernel_size=1)

        enc_stages = []
        down_blocks = []

        # Encoder: stage `i` works on `init_filters * 2**i` channels; the down block that follows
        # doubles the channel count for the next stage.
        for i, num_blocks in enumerate(blocks_down):
            enc_stages.append(
                nn.Sequential(
                    *[
                        MedNeXtBlock(
                            in_channels=init_filters * (2**i),
                            out_channels=init_filters * (2**i),
                            expansion_ratio=encoder_expansion_ratio[i],
                            kernel_size=enc_kernel_size,
                            use_residual_connection=use_residual_connection,
                            norm_type=norm_type,
                            dim=spatial_dims_str,
                            global_resp_norm=global_resp_norm,
                        )
                        for _ in range(num_blocks)
                    ]
                )
            )

            # NOTE(review): unlike every other block built here, the down block is not given
            # `global_resp_norm` -- confirm against `MedNeXtDownBlock` whether this is intentional.
            down_blocks.append(
                MedNeXtDownBlock(
                    in_channels=init_filters * (2**i),
                    out_channels=init_filters * (2 ** (i + 1)),
                    expansion_ratio=encoder_expansion_ratio[i],
                    kernel_size=enc_kernel_size,
                    use_residual_connection=use_residual_connection,
                    norm_type=norm_type,
                    dim=spatial_dims_str,
                )
            )

        self.enc_stages = nn.ModuleList(enc_stages)
        self.down_blocks = nn.ModuleList(down_blocks)

        # Bottleneck operates at the channel count produced by the last down block.
        self.bottleneck = nn.Sequential(
            *[
                MedNeXtBlock(
                    in_channels=init_filters * (2 ** len(blocks_down)),
                    out_channels=init_filters * (2 ** len(blocks_down)),
                    expansion_ratio=bottleneck_expansion_ratio,
                    kernel_size=dec_kernel_size,
                    use_residual_connection=use_residual_connection,
                    norm_type=norm_type,
                    dim=spatial_dims_str,
                    global_resp_norm=global_resp_norm,
                )
                for _ in range(blocks_bottleneck)
            ]
        )

        up_blocks = []
        dec_stages = []

        # Decoder: up block `i` halves the channel count, then decoder stage `i` keeps it.
        for i, num_blocks in enumerate(blocks_up):
            up_blocks.append(
                MedNeXtUpBlock(
                    in_channels=init_filters * (2 ** (len(blocks_up) - i)),
                    out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
                    expansion_ratio=decoder_expansion_ratio[i],
                    kernel_size=dec_kernel_size,
                    use_residual_connection=use_residual_connection,
                    norm_type=norm_type,
                    dim=spatial_dims_str,
                    global_resp_norm=global_resp_norm,
                )
            )

            dec_stages.append(
                nn.Sequential(
                    *[
                        MedNeXtBlock(
                            in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
                            out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
                            expansion_ratio=decoder_expansion_ratio[i],
                            kernel_size=dec_kernel_size,
                            use_residual_connection=use_residual_connection,
                            norm_type=norm_type,
                            dim=spatial_dims_str,
                            global_resp_norm=global_resp_norm,
                        )
                        for _ in range(num_blocks)
                    ]
                )
            )

        self.up_blocks = nn.ModuleList(up_blocks)
        self.dec_stages = nn.ModuleList(dec_stages)

        # Head producing the main prediction from the last decoder stage.
        self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str)

        if deep_supervision:
            out_blocks = [
                MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str)
                for i in range(1, len(blocks_up) + 1)
            ]

            # Reverse so that `out_blocks[i]` matches the channel count fed into `up_blocks[i]`
            # (widest first, starting at the bottleneck output).
            out_blocks.reverse()
            self.out_blocks = nn.ModuleList(out_blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:
        """
        Forward pass of the MedNeXt model.

        This method performs the forward pass through the model, including:
        - Stem convolution
        - Encoder stages and downsampling
        - Bottleneck blocks
        - Decoder stages and upsampling with skip connections
        - Output blocks for deep supervision (if enabled)

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor or Sequence[torch.Tensor]: the main prediction alone, or, when deep supervision
            is enabled and the module is in training mode, a tuple whose first element is the main
            prediction followed by the auxiliary predictions ordered from the input of the last up
            block back to the bottleneck output.
        """
        # Apply stem convolution
        x = self.stem(x)

        # Encoder forward pass; keep each stage output for the additive skip connections
        enc_outputs = []
        for enc_stage, down_block in zip(self.enc_stages, self.down_blocks):
            x = enc_stage(x)
            enc_outputs.append(x)
            x = down_block(x)

        # Bottleneck forward pass
        x = self.bottleneck(x)

        # Auxiliary outputs are only returned in training mode, so the deep supervision heads are
        # only evaluated then; in eval mode their results would be discarded.
        collect_ds = self.do_ds and self.training
        ds_outputs: list[torch.Tensor] = []

        # Decoder forward pass with skip connections
        for i, (up_block, dec_stage) in enumerate(zip(self.up_blocks, self.dec_stages)):
            if collect_ds and i < len(self.out_blocks):
                ds_outputs.append(self.out_blocks[i](x))

            x = up_block(x)
            x = x + enc_outputs[-(i + 1)]  # skip connection from the mirrored encoder stage
            x = dec_stage(x)

        # Final output block
        x = self.out_0(x)

        # Return output(s)
        if collect_ds:
            return (x, *ds_outputs[::-1])
        return x
266
+
267
+
268
+ # Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975
269
def create_mednext(
    variant: str,
    spatial_dims: int = 3,
    in_channels: int = 1,
    out_channels: int = 2,
    kernel_size: int = 3,
    deep_supervision: bool = False,
) -> MedNeXt:
    """
    Build one of the predefined MedNeXt configurations.

    Every variant is created with residual connections enabled, group normalization,
    Global Response Normalization disabled and 32 initial filters; the variants differ
    only in their expansion ratios and in the number of blocks per stage.

    Args:
        variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L'), case-insensitive.
        spatial_dims (int): Number of spatial dimensions. Defaults to 3.
        in_channels (int): Number of input channels. Defaults to 1.
        out_channels (int): Number of output channels. Defaults to 2.
        kernel_size (int): Kernel size for convolutions. Defaults to 3.
        deep_supervision (bool): Whether to use deep supervision. Defaults to False.

    Returns:
        MedNeXt: The specified MedNeXt variant.

    Raises:
        ValueError: If an invalid variant is specified.
    """
    # Per-variant architecture settings, keyed by the upper-cased variant letter.
    variant_presets = {
        "S": {
            "encoder_expansion_ratio": 2,
            "decoder_expansion_ratio": 2,
            "bottleneck_expansion_ratio": 2,
            "blocks_down": (2, 2, 2, 2),
            "blocks_bottleneck": 2,
            "blocks_up": (2, 2, 2, 2),
        },
        "B": {
            "encoder_expansion_ratio": (2, 3, 4, 4),
            "decoder_expansion_ratio": (4, 4, 3, 2),
            "bottleneck_expansion_ratio": 4,
            "blocks_down": (2, 2, 2, 2),
            "blocks_bottleneck": 2,
            "blocks_up": (2, 2, 2, 2),
        },
        "M": {
            "encoder_expansion_ratio": (2, 3, 4, 4),
            "decoder_expansion_ratio": (4, 4, 3, 2),
            "bottleneck_expansion_ratio": 4,
            "blocks_down": (3, 4, 4, 4),
            "blocks_bottleneck": 4,
            "blocks_up": (4, 4, 4, 3),
        },
        "L": {
            "encoder_expansion_ratio": (3, 4, 8, 8),
            "decoder_expansion_ratio": (8, 8, 4, 3),
            "bottleneck_expansion_ratio": 8,
            "blocks_down": (3, 4, 8, 8),
            "blocks_bottleneck": 8,
            "blocks_up": (8, 8, 4, 3),
        },
    }

    preset = variant_presets.get(variant.upper())
    if preset is None:
        raise ValueError(f"Invalid MedNeXt variant: {variant}")

    return MedNeXt(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        deep_supervision=deep_supervision,
        use_residual_connection=True,
        norm_type="group",
        global_resp_norm=False,
        init_filters=32,
        **preset,  # type: ignore
    )
348
+
349
+
350
# Alternative capitalisation of the model class.
MedNext = MedNeXt


# The variant factories are named functions rather than lambdas bound to names so that
# they carry a meaningful `__name__`, a docstring, and can be pickled by reference.
def MedNeXtSmall(**kwargs) -> MedNeXt:
    """Create the MedNeXt "S" variant; keyword arguments are forwarded to `create_mednext`."""
    return create_mednext("S", **kwargs)


def MedNeXtBase(**kwargs) -> MedNeXt:
    """Create the MedNeXt "B" variant; keyword arguments are forwarded to `create_mednext`."""
    return create_mednext("B", **kwargs)


def MedNeXtMedium(**kwargs) -> MedNeXt:
    """Create the MedNeXt "M" variant; keyword arguments are forwarded to `create_mednext`."""
    return create_mednext("M", **kwargs)


def MedNeXtLarge(**kwargs) -> MedNeXt:
    """Create the MedNeXt "L" variant; keyword arguments are forwarded to `create_mednext`."""
    return create_mednext("L", **kwargs)


# Every spelling of a variant is bound to the same factory object.
MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall
MedNextB = MedNeXtB = MedNextBase = MedNeXtBase
MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium
MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge
@@ -641,7 +641,6 @@ class ClassMappingClassify(nn.Module):
641
641
  # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension.
642
642
  masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
643
643
  masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)
644
-
645
644
  return masks_embedding, class_embedding
646
645
 
647
646