monai-weekly 1.4.dev2425__py3-none-any.whl → 1.4.dev2427__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (29)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/deepedit/transforms.py +1 -1
  4. monai/apps/deepgrow/transforms.py +1 -1
  5. monai/apps/generation/__init__.py +10 -0
  6. monai/apps/generation/maisi/__init__.py +10 -0
  7. monai/apps/generation/maisi/networks/__init__.py +10 -0
  8. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +975 -0
  9. monai/apps/generation/maisi/networks/controlnet_maisi.py +178 -0
  10. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +410 -0
  11. monai/apps/generation/maisi/utils/__init__.py +10 -0
  12. monai/apps/generation/maisi/utils/morphological_ops.py +170 -0
  13. monai/apps/nuclick/transforms.py +1 -1
  14. monai/apps/pathology/transforms/post/array.py +1 -1
  15. monai/apps/pathology/utils.py +2 -2
  16. monai/data/torchscript_utils.py +1 -1
  17. monai/data/ultrasound_confidence_map.py +41 -10
  18. monai/losses/dice.py +10 -3
  19. monai/metrics/utils.py +3 -3
  20. monai/optimizers/lr_finder.py +1 -1
  21. monai/transforms/intensity/array.py +25 -2
  22. monai/transforms/signal/array.py +1 -1
  23. monai/utils/misc.py +20 -2
  24. monai/utils/module.py +6 -3
  25. {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/METADATA +6 -3
  26. {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/RECORD +29 -21
  27. {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/WHEEL +1 -1
  28. {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/LICENSE +0 -0
  29. {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/top_level.txt +0 -0
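
The headline change in this release is the new monai.apps.generation.maisi subpackage; its largest file, autoencoderkl_maisi.py, is shown in full below. As a quick orientation (this sketch is not part of the diff), here is one way the new AutoencoderKlMaisi class might be instantiated and run. It assumes the optional generative package (MONAI Generative Models) is installed, since the class builds on its AutoencoderKL, and all sizes and hyperparameters are illustrative only.

# Illustrative sketch, not from the package: exercising the new MAISI autoencoder.
# Assumes `pip install monai-weekly monai-generative`; all sizes are made up.
import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

model = AutoencoderKlMaisi(
    spatial_dims=3,                          # 3D volumes
    in_channels=1,
    out_channels=1,
    num_res_blocks=(1, 1, 1),
    num_channels=(32, 64, 64),
    attention_levels=(False, False, False),  # no attention blocks at any level
    latent_channels=4,
    num_splits=2,                            # feature maps are processed in 2 chunks
    dim_split=0,                             # split along the first spatial dimension
    save_mem=False,                          # skip the torch.cuda.empty_cache() calls
)

with torch.no_grad():
    x = torch.randn(1, 1, 32, 32, 32)    # (batch, channel, D, H, W)
    recon, z_mu, z_sigma = model(x)      # forward() is inherited from AutoencoderKL
print(recon.shape)                       # expected: torch.Size([1, 1, 32, 32, 32])
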
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
@@ -0,0 +1,975 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import gc
+ import logging
+ from typing import TYPE_CHECKING, Sequence, cast
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from monai.networks.blocks import Convolution
+ from monai.utils import optional_import
+ from monai.utils.type_conversion import convert_to_tensor
+
+ AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
+ AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
+ ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
+
+
+ if TYPE_CHECKING:
+     from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
+ else:
+     AutoencoderKLType = cast(type, AutoencoderKL)
+
+
+ # Set up logging configuration
+ logger = logging.getLogger(__name__)
+
+
+ def _empty_cuda_cache(save_mem: bool) -> None:
+     if torch.cuda.is_available() and save_mem:
+         torch.cuda.empty_cache()
+     return
+
+
+ class MaisiGroupNorm3D(nn.GroupNorm):
+     """
+     Custom 3D Group Normalization with optional print_info output.
+
+     Args:
+         num_groups: Number of groups for the group norm.
+         num_channels: Number of channels for the group norm.
+         eps: Epsilon value for numerical stability.
+         affine: Whether to use learnable affine parameters, default to `True`.
+         norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         num_groups: int,
+         num_channels: int,
+         eps: float = 1e-5,
+         affine: bool = True,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+     ):
+         super().__init__(num_groups, num_channels, eps, affine)
+         self.norm_float16 = norm_float16
+         self.print_info = print_info
+         self.save_mem = save_mem
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         if self.print_info:
+             logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
+
+         if len(input.shape) != 5:
+             raise ValueError("Expected a 5D tensor")
+
+         param_n, param_c, param_d, param_h, param_w = input.shape
+         input = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w)
+
+         inputs = []
+         for i in range(input.size(1)):
+             array = input[:, i : i + 1, ...].to(dtype=torch.float32)
+             mean = array.mean([2, 3, 4, 5], keepdim=True)
+             std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
+             if self.norm_float16:
+                 inputs.append(((array - mean) / std).to(dtype=torch.float16))
+             else:
+                 inputs.append((array - mean) / std)
+
+         del input
+         _empty_cuda_cache(self.save_mem)
+
+         input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs)
+
+         input = input.view(param_n, param_c, param_d, param_h, param_w)
+         if self.affine:
+             input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1))
+
+         if self.print_info:
+             logger.info(f"MaisiGroupNorm3D with output size: {input.size()}")
+
+         return input
+
+     def _cat_inputs(self, inputs):
+         input_type = inputs[0].device.type
+         input = inputs[0].clone().to("cpu", non_blocking=True) if input_type == "cuda" else inputs[0].clone()
+         inputs[0] = 0
+         _empty_cuda_cache(self.save_mem)
+
+         for k in range(len(inputs) - 1):
+             input = torch.cat((input, inputs[k + 1].cpu()), dim=1)
+             inputs[k + 1] = 0
+             _empty_cuda_cache(self.save_mem)
+             gc.collect()
+
+             if self.print_info:
+                 logger.info(f"MaisiGroupNorm3D concat progress: {k + 1}/{len(inputs) - 1}.")
+
+         return input.to("cuda", non_blocking=True) if input_type == "cuda" else input
+
+
+ class MaisiConvolution(nn.Module):
+     """
+     Convolutional layer with optional print_info output and custom splitting mechanism.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         out_channels: Number of output channels.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         print_info: Whether to print information.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+         Additional arguments for the convolution operation.
+         https://docs.monai.io/en/stable/networks.html#convolution
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         num_splits: int,
+         dim_split: int,
+         print_info: bool,
+         save_mem: bool = True,
+         strides: Sequence[int] | int = 1,
+         kernel_size: Sequence[int] | int = 3,
+         adn_ordering: str = "NDA",
+         act: tuple | str | None = "PRELU",
+         norm: tuple | str | None = "INSTANCE",
+         dropout: tuple | str | float | None = None,
+         dropout_dim: int = 1,
+         dilation: Sequence[int] | int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         conv_only: bool = False,
+         is_transposed: bool = False,
+         padding: Sequence[int] | int | None = None,
+         output_padding: Sequence[int] | int | None = None,
+     ) -> None:
+         super().__init__()
+         self.conv = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             strides=strides,
+             kernel_size=kernel_size,
+             adn_ordering=adn_ordering,
+             act=act,
+             norm=norm,
+             dropout=dropout,
+             dropout_dim=dropout_dim,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+             conv_only=conv_only,
+             is_transposed=is_transposed,
+             padding=padding,
+             output_padding=output_padding,
+         )
+
+         self.dim_split = dim_split
+         self.stride = strides[self.dim_split] if isinstance(strides, list) else strides
+         self.num_splits = num_splits
+         self.print_info = print_info
+         self.save_mem = save_mem
+
+     def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]:
+         overlaps = [0] + [padding] * (self.num_splits - 1)
+         last_padding = x.size(self.dim_split + 2) % split_size
+
+         slices = [slice(None)] * 5
+         splits: list[torch.Tensor] = []
+         for i in range(self.num_splits):
+             slices[self.dim_split + 2] = slice(
+                 i * split_size - overlaps[i],
+                 (i + 1) * split_size + (padding if i != self.num_splits - 1 else last_padding),
+             )
+             splits.append(x[tuple(slices)])
+
+         if self.print_info:
+             for j in range(len(splits)):
+                 logger.info(f"Split {j + 1}/{len(splits)} size: {splits[j].size()}")
+
+         return splits
+
+     def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, padding: int) -> torch.Tensor:
+         slices = [slice(None)] * 5
+         for i in range(self.num_splits):
+             slices[self.dim_split + 2] = slice(None, split_size) if i == 0 else slice(padding, padding + split_size)
+             outputs[i] = outputs[i][tuple(slices)]
+
+         if self.print_info:
+             for i in range(self.num_splits):
+                 logger.info(f"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}")
+
+         if max(outputs[0].size()) < 500:
+             x = torch.cat(outputs, dim=self.dim_split + 2)
+         else:
+             x = outputs[0].clone().to("cpu", non_blocking=True)
+             outputs[0] = torch.Tensor(0)
+             _empty_cuda_cache(self.save_mem)
+             for k in range(len(outputs) - 1):
+                 x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2)
+                 outputs[k + 1] = torch.Tensor(0)
+                 _empty_cuda_cache(self.save_mem)
+                 gc.collect()
+                 if self.print_info:
+                     logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.")
+
+             x = x.to("cuda", non_blocking=True)
+         return x
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.print_info:
+             logger.info(f"Number of splits: {self.num_splits}")
+
+         # compute size of splits
+         l = x.size(self.dim_split + 2)
+         split_size = l // self.num_splits
+
+         # update padding length if necessary
+         padding = 3
+         if padding % self.stride > 0:
+             padding = (padding // self.stride + 1) * self.stride
+         if self.print_info:
+             logger.info(f"Padding size: {padding}")
+
+         # split tensor into a list of tensors
+         splits = self._split_tensor(x, split_size, padding)
+
+         del x
+         _empty_cuda_cache(self.save_mem)
+
+         # convolution
+         outputs = [self.conv(split) for split in splits]
+         if self.print_info:
+             for j in range(len(outputs)):
+                 logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}")
+
+         # update size of splits and padding length for output
+         split_size_out = split_size
+         padding_s = padding
+         non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0
+         if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2:
+             split_size_out *= 2
+             padding_s *= 2
+         elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2:
+             split_size_out //= 2
+             padding_s //= 2
+
+         # concatenate list of tensors
+         x = self._concatenate_tensors(outputs, split_size_out, padding_s)
+
+         del outputs
+         _empty_cuda_cache(self.save_mem)
+
+         return x
+
+
+ class MaisiUpsample(nn.Module):
+     """
+     Convolution-based upsampling layer.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels to the layer.
+         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         print_info: Whether to print information.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         use_convtranspose: bool,
+         num_splits: int,
+         dim_split: int,
+         print_info: bool,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__()
+         self.conv = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             strides=2 if use_convtranspose else 1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+             is_transposed=use_convtranspose,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.use_convtranspose = use_convtranspose
+         self.save_mem = save_mem
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.use_convtranspose:
+             x = self.conv(x)
+             x_tensor: torch.Tensor = convert_to_tensor(x)
+             return x_tensor
+
+         x = F.interpolate(x, scale_factor=2.0, mode="trilinear")
+         _empty_cuda_cache(self.save_mem)
+         x = self.conv(x)
+         _empty_cuda_cache(self.save_mem)
+
+         out_tensor: torch.Tensor = convert_to_tensor(x)
+         return out_tensor
+
+
+ class MaisiDownsample(nn.Module):
+     """
+     Convolution-based downsampling layer.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         print_info: Whether to print information.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         num_splits: int,
+         dim_split: int,
+         print_info: bool,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__()
+         self.pad = (0, 1) * spatial_dims
+         self.conv = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             strides=2,
+             kernel_size=3,
+             padding=0,
+             conv_only=True,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = F.pad(x, self.pad, mode="constant", value=0.0)
+         x = self.conv(x)
+         return x
+
+
+ class MaisiResBlock(nn.Module):
+     """
+     Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
+     residual connection between input and output.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Input channels to the layer.
+         norm_num_groups: Number of groups for the group norm layer.
+         norm_eps: Epsilon for the normalization.
+         out_channels: Number of output channels.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         norm_num_groups: int,
+         norm_eps: float,
+         out_channels: int,
+         num_splits: int,
+         dim_split: int,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = in_channels if out_channels is None else out_channels
+         self.save_mem = save_mem
+
+         self.norm1 = MaisiGroupNorm3D(
+             num_groups=norm_num_groups,
+             num_channels=in_channels,
+             eps=norm_eps,
+             affine=True,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.conv1 = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.in_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.norm2 = MaisiGroupNorm3D(
+             num_groups=norm_num_groups,
+             num_channels=out_channels,
+             eps=norm_eps,
+             affine=True,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.conv2 = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.out_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+
+         self.nin_shortcut = (
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 strides=1,
+                 kernel_size=1,
+                 padding=0,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+             if self.in_channels != self.out_channels
+             else nn.Identity()
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         h = self.norm1(x)
+         _empty_cuda_cache(self.save_mem)
+
+         h = F.silu(h)
+         _empty_cuda_cache(self.save_mem)
+         h = self.conv1(h)
+         _empty_cuda_cache(self.save_mem)
+
+         h = self.norm2(h)
+         _empty_cuda_cache(self.save_mem)
+
+         h = F.silu(h)
+         _empty_cuda_cache(self.save_mem)
+         h = self.conv2(h)
+         _empty_cuda_cache(self.save_mem)
+
+         if self.in_channels != self.out_channels:
+             x = self.nin_shortcut(x)
+             _empty_cuda_cache(self.save_mem)
+
+         out = x + h
+         out_tensor: torch.Tensor = convert_to_tensor(out)
+         return out_tensor
+
+
+ class MaisiEncoder(nn.Module):
+     """
+     Convolutional cascade that downsamples the image into a spatial latent space.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         num_channels: Sequence of block output channels.
+         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
+         num_res_blocks: Number of residual blocks (see ResBlock) per level.
+         norm_num_groups: Number of groups for the group norm layers.
+         norm_eps: Epsilon for the normalization.
+         attention_levels: Indicate which level from num_channels contain an attention block.
+         with_nonlocal_attn: If True, use non-local attention block.
+         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         num_channels: Sequence[int],
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         norm_num_groups: int,
+         norm_eps: float,
+         attention_levels: Sequence[bool],
+         num_splits: int,
+         dim_split: int,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+         with_nonlocal_attn: bool = True,
+         use_flash_attention: bool = False,
+     ) -> None:
+         super().__init__()
+
+         # Check if attention_levels and num_channels have the same size
+         if len(attention_levels) != len(num_channels):
+             raise ValueError("attention_levels and num_channels must have the same size")
+
+         # Check if num_res_blocks and num_channels have the same size
+         if len(num_res_blocks) != len(num_channels):
+             raise ValueError("num_res_blocks and num_channels must have the same size")
+
+         self.save_mem = save_mem
+
+         blocks: list[nn.Module] = []
+
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=num_channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         output_channel = num_channels[0]
+         for i in range(len(num_channels)):
+             input_channel = output_channel
+             output_channel = num_channels[i]
+             is_final_block = i == len(num_channels) - 1
+
+             for _ in range(num_res_blocks[i]):
+                 blocks.append(
+                     MaisiResBlock(
+                         spatial_dims=spatial_dims,
+                         in_channels=input_channel,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         out_channels=output_channel,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         norm_float16=norm_float16,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+                 input_channel = output_channel
+                 if attention_levels[i]:
+                     blocks.append(
+                         AttentionBlock(
+                             spatial_dims=spatial_dims,
+                             num_channels=input_channel,
+                             norm_num_groups=norm_num_groups,
+                             norm_eps=norm_eps,
+                             use_flash_attention=use_flash_attention,
+                         )
+                     )
+
+             if not is_final_block:
+                 blocks.append(
+                     MaisiDownsample(
+                         spatial_dims=spatial_dims,
+                         in_channels=input_channel,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+
+         if with_nonlocal_attn:
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=num_channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=num_channels[-1],
+                 )
+             )
+
+             blocks.append(
+                 AttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=num_channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     use_flash_attention=use_flash_attention,
+                 )
+             )
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=num_channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=num_channels[-1],
+                 )
+             )
+
+         blocks.append(
+             MaisiGroupNorm3D(
+                 num_groups=norm_num_groups,
+                 num_channels=num_channels[-1],
+                 eps=norm_eps,
+                 affine=True,
+                 norm_float16=norm_float16,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=num_channels[-1],
+                 out_channels=out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+             _empty_cuda_cache(self.save_mem)
+         return x
+
+
+ class MaisiDecoder(nn.Module):
+     """
+     Convolutional cascade upsampling from a spatial latent space into an image space.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         num_channels: Sequence of block output channels.
+         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
+         out_channels: Number of output channels.
+         num_res_blocks: Number of residual blocks (see ResBlock) per level.
+         norm_num_groups: Number of groups for the group norm layers.
+         norm_eps: Epsilon for the normalization.
+         attention_levels: Indicate which level from num_channels contain an attention block.
+         with_nonlocal_attn: If True, use non-local attention block.
+         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         num_channels: Sequence[int],
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         norm_num_groups: int,
+         norm_eps: float,
+         attention_levels: Sequence[bool],
+         num_splits: int,
+         dim_split: int,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+         with_nonlocal_attn: bool = True,
+         use_flash_attention: bool = False,
+         use_convtranspose: bool = False,
+     ) -> None:
+         super().__init__()
+         self.print_info = print_info
+         self.save_mem = save_mem
+
+         reversed_block_out_channels = list(reversed(num_channels))
+
+         blocks: list[nn.Module] = []
+
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=reversed_block_out_channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         if with_nonlocal_attn:
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                 )
+             )
+             blocks.append(
+                 AttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     use_flash_attention=use_flash_attention,
+                 )
+             )
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                 )
+             )
+
+         reversed_attention_levels = list(reversed(attention_levels))
+         reversed_num_res_blocks = list(reversed(num_res_blocks))
+         block_out_ch = reversed_block_out_channels[0]
+         for i in range(len(reversed_block_out_channels)):
+             block_in_ch = block_out_ch
+             block_out_ch = reversed_block_out_channels[i]
+             is_final_block = i == len(num_channels) - 1
+
+             for _ in range(reversed_num_res_blocks[i]):
+                 blocks.append(
+                     MaisiResBlock(
+                         spatial_dims=spatial_dims,
+                         in_channels=block_in_ch,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         out_channels=block_out_ch,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         norm_float16=norm_float16,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+                 block_in_ch = block_out_ch
+
+             if reversed_attention_levels[i]:
+                 blocks.append(
+                     AttentionBlock(
+                         spatial_dims=spatial_dims,
+                         num_channels=block_in_ch,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         use_flash_attention=use_flash_attention,
+                     )
+                 )
+
+             if not is_final_block:
+                 blocks.append(
+                     MaisiUpsample(
+                         spatial_dims=spatial_dims,
+                         in_channels=block_in_ch,
+                         use_convtranspose=use_convtranspose,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+
+         blocks.append(
+             MaisiGroupNorm3D(
+                 num_groups=norm_num_groups,
+                 num_channels=block_in_ch,
+                 eps=norm_eps,
+                 affine=True,
+                 norm_float16=norm_float16,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=block_in_ch,
+                 out_channels=out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+             _empty_cuda_cache(self.save_mem)
+         return x
+
+
+ class AutoencoderKlMaisi(AutoencoderKLType):
+     """
+     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         out_channels: Number of output channels.
+         num_res_blocks: Number of residual blocks per level.
+         num_channels: Sequence of block output channels.
+         attention_levels: Indicate which level from num_channels contain an attention block.
+         latent_channels: Number of channels in the latent space.
+         norm_num_groups: Number of groups for the group norm layers.
+         norm_eps: Epsilon for the normalization.
+         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
+         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+         use_checkpointing: If True, use activation checkpointing.
+         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         num_channels: Sequence[int],
+         attention_levels: Sequence[bool],
+         latent_channels: int = 3,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         with_encoder_nonlocal_attn: bool = False,
+         with_decoder_nonlocal_attn: bool = False,
+         use_flash_attention: bool = False,
+         use_checkpointing: bool = False,
+         use_convtranspose: bool = False,
+         num_splits: int = 16,
+         dim_split: int = 0,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__(
+             spatial_dims,
+             in_channels,
+             out_channels,
+             num_res_blocks,
+             num_channels,
+             attention_levels,
+             latent_channels,
+             norm_num_groups,
+             norm_eps,
+             with_encoder_nonlocal_attn,
+             with_decoder_nonlocal_attn,
+             use_flash_attention,
+             use_checkpointing,
+             use_convtranspose,
+         )
+
+         self.encoder = MaisiEncoder(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             num_channels=num_channels,
+             out_channels=latent_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             with_nonlocal_attn=with_encoder_nonlocal_attn,
+             use_flash_attention=use_flash_attention,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+
+         self.decoder = MaisiDecoder(
+             spatial_dims=spatial_dims,
+             num_channels=num_channels,
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             with_nonlocal_attn=with_decoder_nonlocal_attn,
+             use_flash_attention=use_flash_attention,
+             use_convtranspose=use_convtranspose,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
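
For readers skimming the diff: the memory-saving devices above (the per-group normalization loop in MaisiGroupNorm3D, the overlapped tensor splitting in MaisiConvolution) are meant to be numerically equivalent to the ordinary operations they replace. A quick sanity check, under the same assumptions as the sketch before the diff, is to compare MaisiGroupNorm3D against torch.nn.GroupNorm on a small tensor:

# Sanity-check sketch, not part of the package: MaisiGroupNorm3D should match
# torch.nn.GroupNorm on 5D input, since it only changes how the math is batched.
import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiGroupNorm3D

torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4, 4)  # (N, C, D, H, W)

maisi_gn = MaisiGroupNorm3D(num_groups=4, num_channels=8, save_mem=False)
plain_gn = torch.nn.GroupNorm(num_groups=4, num_channels=8)

# Both modules start with weight=1 and bias=0, so fresh instances are comparable.
with torch.no_grad():
    out_maisi = maisi_gn(x.clone())  # clone defensively; forward reshapes its input
    out_plain = plain_gn(x)
print(torch.allclose(out_maisi, out_plain, atol=1e-5))  # expected: True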