monai-weekly 1.5.dev2444__py3-none-any.whl → 1.5.dev2446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/bundle/scripts.py +2 -0
- monai/networks/blocks/__init__.py +1 -0
- monai/networks/blocks/mednext_block.py +309 -0
- monai/networks/nets/__init__.py +19 -0
- monai/networks/nets/mednext.py +354 -0
- monai/networks/nets/vista3d.py +0 -1
- monai/networks/trt_compiler.py +161 -55
- monai/networks/utils.py +11 -5
- monai/transforms/utility/array.py +2 -2
- monai/utils/__init__.py +1 -0
- monai/utils/module.py +41 -0
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/METADATA +1 -1
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/RECORD +18 -16
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/WHEEL +1 -1
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,354 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
# Portions of this code are derived from the original repository at:
|
13
|
+
# https://github.com/MIC-DKFZ/MedNeXt
|
14
|
+
# and are used under the terms of the Apache License, Version 2.0.
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from collections.abc import Sequence
|
19
|
+
|
20
|
+
import torch
|
21
|
+
import torch.nn as nn
|
22
|
+
|
23
|
+
from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
|
24
|
+
|
25
|
+
__all__ = [
    # Primary model class followed by the factory names of the four published variants.
    "MedNeXt",
    "MedNeXtSmall",
    "MedNeXtBase",
    "MedNeXtMedium",
    "MedNeXtLarge",
    # Alternative spellings of the class name.
    "MedNext",
    # Small variant aliases.
    "MedNextS",
    "MedNeXtS",
    "MedNextSmall",
    # Base variant aliases.
    "MedNextB",
    "MedNeXtB",
    "MedNextBase",
    # Medium variant aliases.
    "MedNextM",
    "MedNeXtM",
    "MedNextMedium",
    # Large variant aliases.
    "MedNextL",
    "MedNeXtL",
    "MedNextLarge",
]
|
45
|
+
|
46
|
+
|
47
|
+
class MedNeXt(nn.Module):
    """
    MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975

    Args:
        spatial_dims: spatial dimension of the input data, either 2 or 3. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 32.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        encoder_expansion_ratio: expansion ratio for encoder blocks. Either a single int shared by every
            encoder stage, or one value per stage (at least ``len(blocks_down)`` values; extra trailing
            values are ignored). Defaults to 2.
        decoder_expansion_ratio: expansion ratio for decoder blocks. Either a single int shared by every
            decoder stage, or one value per stage (at least ``len(blocks_up)`` values; extra trailing
            values are ignored). Defaults to 2.
        bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2.
        kernel_size: kernel size for convolutions. Defaults to 7.
        deep_supervision: whether to use deep supervision. Defaults to False.
        use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False.
        blocks_down: number of blocks in each encoder stage. Defaults to (2, 2, 2, 2).
        blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.
        blocks_up: number of blocks in each decoder stage; must have the same length as ``blocks_down``.
            Defaults to (2, 2, 2, 2).
        norm_type: type of normalization layer. Defaults to 'group'.
        global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808

    Raises:
        AssertionError: if ``spatial_dims`` is not 2 or 3 (checked with an ``assert`` statement).
        ValueError: if ``blocks_up`` and ``blocks_down`` differ in length, or if a per-stage expansion ratio
            sequence has fewer entries than the number of stages it configures.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        init_filters: int = 32,
        in_channels: int = 1,
        out_channels: int = 2,
        encoder_expansion_ratio: Sequence[int] | int = 2,
        decoder_expansion_ratio: Sequence[int] | int = 2,
        bottleneck_expansion_ratio: int = 2,
        kernel_size: int = 7,
        deep_supervision: bool = False,
        use_residual_connection: bool = False,
        blocks_down: Sequence[int] = (2, 2, 2, 2),
        blocks_bottleneck: int = 2,
        blocks_up: Sequence[int] = (2, 2, 2, 2),
        norm_type: str = "group",
        global_resp_norm: bool = False,
    ):
        """
        Initialize the MedNeXt model.

        This method validates the configuration and then sets up the architecture of the model, including:
        - Stem convolution
        - Encoder stages and downsampling blocks
        - Bottleneck blocks
        - Decoder stages and upsampling blocks
        - Output blocks for deep supervision (if enabled)
        """
        super().__init__()

        self.do_ds = deep_supervision
        # Kept as an assertion so callers that expect AssertionError for a bad `spatial_dims` keep working.
        assert spatial_dims in [2, 3], "`spatial_dims` can only be 2 or 3."
        spatial_dims_str = f"{spatial_dims}d"
        enc_kernel_size = dec_kernel_size = kernel_size

        # The decoder mirrors the encoder. The first up block consumes the bottleneck output, which has
        # init_filters * 2**len(blocks_down) channels. Every decoder stage is summed with the matching
        # encoder skip. Both only line up when the stage counts agree, so reject a mismatch here rather
        # than with a shape error during the forward pass.
        if len(blocks_up) != len(blocks_down):
            raise ValueError(
                "`blocks_up` and `blocks_down` must have the same length, "
                f"got {len(blocks_up)} and {len(blocks_down)}."
            )

        # A single int is broadcast to every stage.
        if isinstance(encoder_expansion_ratio, int):
            encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down)

        if isinstance(decoder_expansion_ratio, int):
            decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up)

        # Stage i reads ratio[i]. Fail early with a clear message instead of an IndexError halfway through
        # construction. Extra trailing entries are still ignored, as before.
        if len(encoder_expansion_ratio) < len(blocks_down):
            raise ValueError(
                "`encoder_expansion_ratio` needs one value per encoder stage, "
                f"got {len(encoder_expansion_ratio)} for {len(blocks_down)} stages."
            )
        if len(decoder_expansion_ratio) < len(blocks_up):
            raise ValueError(
                "`decoder_expansion_ratio` needs one value per decoder stage, "
                f"got {len(decoder_expansion_ratio)} for {len(blocks_up)} stages."
            )

        conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d

        self.stem = conv(in_channels, init_filters, kernel_size=1)

        enc_stages = []
        down_blocks = []

        # Encoder stage i works at init_filters * 2**i channels. Its down block doubles the channel count.
        for i, num_blocks in enumerate(blocks_down):
            enc_stages.append(
                nn.Sequential(
                    *[
                        MedNeXtBlock(
                            in_channels=init_filters * (2**i),
                            out_channels=init_filters * (2**i),
                            expansion_ratio=encoder_expansion_ratio[i],
                            kernel_size=enc_kernel_size,
                            use_residual_connection=use_residual_connection,
                            norm_type=norm_type,
                            dim=spatial_dims_str,
                            global_resp_norm=global_resp_norm,
                        )
                        for _ in range(num_blocks)
                    ]
                )
            )

            down_blocks.append(
                MedNeXtDownBlock(
                    in_channels=init_filters * (2**i),
                    out_channels=init_filters * (2 ** (i + 1)),
                    expansion_ratio=encoder_expansion_ratio[i],
                    kernel_size=enc_kernel_size,
                    use_residual_connection=use_residual_connection,
                    norm_type=norm_type,
                    dim=spatial_dims_str,
                )
            )

        self.enc_stages = nn.ModuleList(enc_stages)
        self.down_blocks = nn.ModuleList(down_blocks)

        self.bottleneck = nn.Sequential(
            *[
                MedNeXtBlock(
                    in_channels=init_filters * (2 ** len(blocks_down)),
                    out_channels=init_filters * (2 ** len(blocks_down)),
                    expansion_ratio=bottleneck_expansion_ratio,
                    kernel_size=dec_kernel_size,
                    use_residual_connection=use_residual_connection,
                    norm_type=norm_type,
                    dim=spatial_dims_str,
                    global_resp_norm=global_resp_norm,
                )
                for _ in range(blocks_bottleneck)
            ]
        )

        # Decoder stage i halves the channel count of its input (starting from the bottleneck width)
        # through its up block, then refines at that width.
        up_blocks = []
        dec_stages = []
        for i, num_blocks in enumerate(blocks_up):
            up_blocks.append(
                MedNeXtUpBlock(
                    in_channels=init_filters * (2 ** (len(blocks_up) - i)),
                    out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
                    expansion_ratio=decoder_expansion_ratio[i],
                    kernel_size=dec_kernel_size,
                    use_residual_connection=use_residual_connection,
                    norm_type=norm_type,
                    dim=spatial_dims_str,
                    global_resp_norm=global_resp_norm,
                )
            )

            dec_stages.append(
                nn.Sequential(
                    *[
                        MedNeXtBlock(
                            in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
                            out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
                            expansion_ratio=decoder_expansion_ratio[i],
                            kernel_size=dec_kernel_size,
                            use_residual_connection=use_residual_connection,
                            norm_type=norm_type,
                            dim=spatial_dims_str,
                            global_resp_norm=global_resp_norm,
                        )
                        for _ in range(num_blocks)
                    ]
                )
            )

        self.up_blocks = nn.ModuleList(up_blocks)
        self.dec_stages = nn.ModuleList(dec_stages)

        self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str)

        if deep_supervision:
            out_blocks = [
                MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str)
                for i in range(1, len(blocks_up) + 1)
            ]

            # Reversed so that out_blocks[i] matches the channel count of the tensor entering up_blocks[i]
            # (widest first, starting at the bottleneck width).
            out_blocks.reverse()
            self.out_blocks = nn.ModuleList(out_blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:
        """
        Forward pass of the MedNeXt model.

        This method performs the forward pass through the model, including:
        - Stem convolution
        - Encoder stages and downsampling
        - Bottleneck blocks
        - Decoder stages and upsampling with additive skip connections
        - Output blocks for deep supervision (if enabled)

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor or Sequence[torch.Tensor]: the output of the final output block. When deep supervision
            is enabled and the module is in training mode, a tuple is returned instead. Its first element is
            that final output, followed by the auxiliary outputs ordered from the last decoder level to the
            bottleneck (highest to lowest resolution).
        """
        # Apply stem convolution
        x = self.stem(x)

        # Encoder forward pass; keep each stage output for the skip connections
        enc_outputs = []
        for enc_stage, down_block in zip(self.enc_stages, self.down_blocks):
            x = enc_stage(x)
            enc_outputs.append(x)
            x = down_block(x)

        # Bottleneck forward pass
        x = self.bottleneck(x)

        # Initialize deep supervision outputs if enabled
        if self.do_ds:
            ds_outputs = []

        # Decoder forward pass with skip connections
        for i, (up_block, dec_stage) in enumerate(zip(self.up_blocks, self.dec_stages)):
            # The auxiliary head sees the tensor *before* it is upsampled at this level.
            if self.do_ds and i < len(self.out_blocks):
                ds_outputs.append(self.out_blocks[i](x))

            x = up_block(x)
            # Skips are added (not concatenated), consuming encoder outputs deepest-first.
            x = x + enc_outputs[-(i + 1)]
            x = dec_stage(x)

        # Final output block
        x = self.out_0(x)

        # Return output(s)
        if self.do_ds and self.training:
            return (x, *ds_outputs[::-1])
        else:
            return x
|
266
|
+
|
267
|
+
|
268
|
+
# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975
|
269
|
+
def create_mednext(
    variant: str,
    spatial_dims: int = 3,
    in_channels: int = 1,
    out_channels: int = 2,
    kernel_size: int = 3,
    deep_supervision: bool = False,
) -> MedNeXt:
    """
    Build one of the published MedNeXt configurations (arXiv:2303.09975).

    Every variant uses 32 initial filters, group normalization, residual connections and no Global Response
    Normalization. The variants differ only in their expansion ratios and in the number of blocks per stage.

    Args:
        variant (str): Which configuration to build: 'S', 'B', 'M' or 'L' (case-insensitive).
        spatial_dims (int): Number of spatial dimensions. Defaults to 3.
        in_channels (int): Number of input channels. Defaults to 1.
        out_channels (int): Number of output channels. Defaults to 2.
        kernel_size (int): Kernel size for convolutions. Defaults to 3.
        deep_supervision (bool): Whether to use deep supervision. Defaults to False.

    Returns:
        MedNeXt: The requested MedNeXt variant.

    Raises:
        ValueError: If ``variant`` does not name one of the known configurations.
    """
    # Hyper-parameters that distinguish the variants from one another.
    variant_configs: dict[str, dict] = {
        "S": {
            "encoder_expansion_ratio": 2,
            "decoder_expansion_ratio": 2,
            "bottleneck_expansion_ratio": 2,
            "blocks_down": (2, 2, 2, 2),
            "blocks_bottleneck": 2,
            "blocks_up": (2, 2, 2, 2),
        },
        "B": {
            "encoder_expansion_ratio": (2, 3, 4, 4),
            "decoder_expansion_ratio": (4, 4, 3, 2),
            "bottleneck_expansion_ratio": 4,
            "blocks_down": (2, 2, 2, 2),
            "blocks_bottleneck": 2,
            "blocks_up": (2, 2, 2, 2),
        },
        "M": {
            "encoder_expansion_ratio": (2, 3, 4, 4),
            "decoder_expansion_ratio": (4, 4, 3, 2),
            "bottleneck_expansion_ratio": 4,
            "blocks_down": (3, 4, 4, 4),
            "blocks_bottleneck": 4,
            "blocks_up": (4, 4, 4, 3),
        },
        "L": {
            "encoder_expansion_ratio": (3, 4, 8, 8),
            "decoder_expansion_ratio": (8, 8, 4, 3),
            "bottleneck_expansion_ratio": 8,
            "blocks_down": (3, 4, 8, 8),
            "blocks_bottleneck": 8,
            "blocks_up": (8, 8, 4, 3),
        },
    }

    key = variant.upper()
    if key not in variant_configs:
        raise ValueError(f"Invalid MedNeXt variant: {variant}")

    return MedNeXt(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        deep_supervision=deep_supervision,
        use_residual_connection=True,
        norm_type="group",
        global_resp_norm=False,
        init_filters=32,
        **variant_configs[key],
    )
|
348
|
+
|
349
|
+
|
350
|
+
MedNext = MedNeXt


def _variant_factory(variant: str):
    """
    Return a keyword-only constructor for one MedNeXt variant.

    The returned callable forwards its keyword arguments to ``create_mednext`` with ``variant`` fixed, so
    ``MedNeXtSmall(in_channels=4)`` is equivalent to ``create_mednext("S", in_channels=4)``. A named closure
    is used instead of a lambda so the factory carries a docstring and a readable name in tracebacks.
    """

    def factory(**kwargs) -> MedNeXt:
        return create_mednext(variant, **kwargs)

    factory.__doc__ = (
        f"Create the MedNeXt-{variant} variant; keyword arguments are forwarded to ``create_mednext``."
    )
    return factory


MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = _variant_factory("S")
MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = _variant_factory("B")
MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = _variant_factory("M")
MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = _variant_factory("L")
|
monai/networks/nets/vista3d.py
CHANGED
@@ -641,7 +641,6 @@ class ClassMappingClassify(nn.Module):
|
|
641
641
|
# [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension.
|
642
642
|
masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
|
643
643
|
masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)
|
644
|
-
|
645
644
|
return masks_embedding, class_embedding
|
646
645
|
|
647
646
|
|