monai-weekly 1.4.dev2426__py3-none-any.whl → 1.4.dev2428__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their registry. It is provided for informational purposes only.
@@ -0,0 +1,973 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import gc
+ import logging
+ from typing import TYPE_CHECKING, Sequence, cast
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from monai.networks.blocks import Convolution
+ from monai.utils import optional_import
+ from monai.utils.type_conversion import convert_to_tensor
+
+ AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
+ AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
+ ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
+
+ if TYPE_CHECKING:
+     from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
+ else:
+     AutoencoderKLType = cast(type, AutoencoderKL)
+
+ # Module-level logger used for the optional print_info output
+ logger = logging.getLogger(__name__)
+
+
+ def _empty_cuda_cache(save_mem: bool) -> None:
+     """Free cached CUDA memory when `save_mem` is enabled and a GPU is available."""
+     if torch.cuda.is_available() and save_mem:
+         torch.cuda.empty_cache()
+
+
+ class MaisiGroupNorm3D(nn.GroupNorm):
+     """
+     Custom 3D group normalization with optional print_info output; it normalizes one
+     group at a time to reduce peak GPU memory on large volumes.
+
+     Args:
+         num_groups: Number of groups for the group norm.
+         num_channels: Number of channels for the group norm.
+         eps: Epsilon value for numerical stability.
+         affine: Whether to use learnable affine parameters, default to `True`.
+         norm_float16: If True, convert the output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         num_groups: int,
+         num_channels: int,
+         eps: float = 1e-5,
+         affine: bool = True,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+     ):
+         super().__init__(num_groups, num_channels, eps, affine)
+         self.norm_float16 = norm_float16
+         self.print_info = print_info
+         self.save_mem = save_mem
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         if self.print_info:
+             logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
+
+         if len(input.shape) != 5:
+             raise ValueError("Expected a 5D tensor")
+
+         param_n, param_c, param_d, param_h, param_w = input.shape
+         input = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w)
+
+         # Normalize each group separately in float32 to limit peak memory use.
+         inputs = []
+         for i in range(input.size(1)):
+             array = input[:, i : i + 1, ...].to(dtype=torch.float32)
+             mean = array.mean([2, 3, 4, 5], keepdim=True)
+             std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
+             if self.norm_float16:
+                 inputs.append(((array - mean) / std).to(dtype=torch.float16))
+             else:
+                 inputs.append((array - mean) / std)
+
+         del input
+         _empty_cuda_cache(self.save_mem)
+
+         input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs)
+
+         input = input.view(param_n, param_c, param_d, param_h, param_w)
+         if self.affine:
+             input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1))
+
+         if self.print_info:
+             logger.info(f"MaisiGroupNorm3D with output size: {input.size()}")
+
+         return input
+
+     def _cat_inputs(self, inputs):
+         # Concatenate large groups on CPU to avoid holding all of them on the GPU at once.
+         input_type = inputs[0].device.type
+         input = inputs[0].clone().to("cpu", non_blocking=True) if input_type == "cuda" else inputs[0].clone()
+         inputs[0] = 0
+         _empty_cuda_cache(self.save_mem)
+
+         for k in range(len(inputs) - 1):
+             input = torch.cat((input, inputs[k + 1].cpu()), dim=1)
+             inputs[k + 1] = 0
+             _empty_cuda_cache(self.save_mem)
+             gc.collect()
+
+             if self.print_info:
+                 logger.info(f"MaisiGroupNorm3D concat progress: {k + 1}/{len(inputs) - 1}.")
+
+         return input.to("cuda", non_blocking=True) if input_type == "cuda" else input
+
+
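+ # Illustrative sketch (not part of the module): MaisiGroupNorm3D is a drop-in
+ # replacement for nn.GroupNorm on 5D (N, C, D, H, W) tensors, e.g.:
+ #
+ #     norm = MaisiGroupNorm3D(num_groups=4, num_channels=8)
+ #     y = norm(torch.randn(2, 8, 16, 16, 16))  # output shape matches the input
+ #
+ # Unlike nn.GroupNorm, it normalizes one group at a time, optionally spilling the
+ # concatenation to CPU, so very large volumes fit in GPU memory.
+
+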
+ class MaisiConvolution(nn.Module):
+     """
+     Convolutional layer with optional print_info output and a splitting mechanism that
+     processes the input in overlapping chunks along one spatial dimension to reduce
+     peak GPU memory.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         out_channels: Number of output channels.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         print_info: Whether to print information.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+
+     The remaining arguments are forwarded to the underlying convolution; see
+     https://docs.monai.io/en/stable/networks.html#convolution
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         num_splits: int,
+         dim_split: int,
+         print_info: bool,
+         save_mem: bool = True,
+         strides: Sequence[int] | int = 1,
+         kernel_size: Sequence[int] | int = 3,
+         adn_ordering: str = "NDA",
+         act: tuple | str | None = "PRELU",
+         norm: tuple | str | None = "INSTANCE",
+         dropout: tuple | str | float | None = None,
+         dropout_dim: int = 1,
+         dilation: Sequence[int] | int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         conv_only: bool = False,
+         is_transposed: bool = False,
+         padding: Sequence[int] | int | None = None,
+         output_padding: Sequence[int] | int | None = None,
+     ) -> None:
+         super().__init__()
+         self.conv = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             strides=strides,
+             kernel_size=kernel_size,
+             adn_ordering=adn_ordering,
+             act=act,
+             norm=norm,
+             dropout=dropout,
+             dropout_dim=dropout_dim,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+             conv_only=conv_only,
+             is_transposed=is_transposed,
+             padding=padding,
+             output_padding=output_padding,
+         )
+
+         self.dim_split = dim_split
+         self.stride = strides[self.dim_split] if isinstance(strides, list) else strides
+         self.num_splits = num_splits
+         self.print_info = print_info
+         self.save_mem = save_mem
+
+     def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]:
+         # Slice the input into `num_splits` chunks along `dim_split`, extending each chunk
+         # by `padding` voxels of overlap so the convolution sees valid context at borders.
+         overlaps = [0] + [padding] * (self.num_splits - 1)
+         last_padding = x.size(self.dim_split + 2) % split_size
+
+         slices = [slice(None)] * 5
+         splits: list[torch.Tensor] = []
+         for i in range(self.num_splits):
+             slices[self.dim_split + 2] = slice(
+                 i * split_size - overlaps[i],
+                 (i + 1) * split_size + (padding if i != self.num_splits - 1 else last_padding),
+             )
+             splits.append(x[tuple(slices)])
+
+         if self.print_info:
+             for j in range(len(splits)):
+                 logger.info(f"Split {j + 1}/{len(splits)} size: {splits[j].size()}")
+
+         return splits
+
+     def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, padding: int) -> torch.Tensor:
+         # Trim the overlap from every chunk, then stitch the chunks back together.
+         slices = [slice(None)] * 5
+         for i in range(self.num_splits):
+             slices[self.dim_split + 2] = slice(None, split_size) if i == 0 else slice(padding, padding + split_size)
+             outputs[i] = outputs[i][tuple(slices)]
+
+         if self.print_info:
+             for i in range(self.num_splits):
+                 logger.info(f"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}")
+
+         if max(outputs[0].size()) < 500:
+             x = torch.cat(outputs, dim=self.dim_split + 2)
+         else:
+             # Concatenate large outputs on CPU, then move the result back to the original device.
+             device_type = outputs[0].device.type
+             x = outputs[0].clone().to("cpu", non_blocking=True)
+             outputs[0] = torch.Tensor(0)
+             _empty_cuda_cache(self.save_mem)
+             for k in range(len(outputs) - 1):
+                 x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2)
+                 outputs[k + 1] = torch.Tensor(0)
+                 _empty_cuda_cache(self.save_mem)
+                 gc.collect()
+                 if self.print_info:
+                     logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.")
+
+             if device_type == "cuda":
+                 x = x.to("cuda", non_blocking=True)
+         return x
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.print_info:
+             logger.info(f"Number of splits: {self.num_splits}")
+
+         # compute the size of each split
+         length = x.size(self.dim_split + 2)
+         split_size = length // self.num_splits
+
+         # round the overlap padding up to a multiple of the stride if necessary
+         padding = 3
+         if padding % self.stride > 0:
+             padding = (padding // self.stride + 1) * self.stride
+         if self.print_info:
+             logger.info(f"Padding size: {padding}")
+
+         # split the tensor into a list of tensors
+         splits = self._split_tensor(x, split_size, padding)
+
+         del x
+         _empty_cuda_cache(self.save_mem)
+
+         # convolve each split independently
+         outputs = [self.conv(split) for split in splits]
+         if self.print_info:
+             for j in range(len(outputs)):
+                 logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}")
+
+         # rescale split size and padding when the convolution up- or downsampled by a
+         # factor of 2; a non-split spatial dim is probed to detect the scaling factor
+         split_size_out = split_size
+         padding_s = padding
+         non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0
+         if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2:
+             split_size_out *= 2
+             padding_s *= 2
+         elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2:
+             split_size_out //= 2
+             padding_s //= 2
+
+         # concatenate the list of tensors
+         x = self._concatenate_tensors(outputs, split_size_out, padding_s)
+
+         del outputs
+         _empty_cuda_cache(self.save_mem)
+
+         return x
+
+
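+ # Illustrative sketch (not part of the module): MaisiConvolution computes the same
+ # result as a plain MONAI Convolution, but in `num_splits` overlapping chunks along
+ # spatial dim `dim_split` (0 -> D, 1 -> H, 2 -> W), e.g.:
+ #
+ #     conv = MaisiConvolution(
+ #         spatial_dims=3, in_channels=1, out_channels=8,
+ #         num_splits=4, dim_split=0, print_info=False,
+ #         strides=1, kernel_size=3, padding=1, conv_only=True,
+ #     )
+ #     y = conv(torch.randn(1, 1, 64, 64, 64))  # -> (1, 8, 64, 64, 64)
+
+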
+ class MaisiUpsample(nn.Module):
+     """
+     Convolution-based upsampling layer.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels to the layer.
+         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         print_info: Whether to print information.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         use_convtranspose: bool,
+         num_splits: int,
+         dim_split: int,
+         print_info: bool,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__()
+         self.conv = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             strides=2 if use_convtranspose else 1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+             is_transposed=use_convtranspose,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.use_convtranspose = use_convtranspose
+         self.save_mem = save_mem
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.use_convtranspose:
+             x = self.conv(x)
+             x_tensor: torch.Tensor = convert_to_tensor(x)
+             return x_tensor
+
+         # Default path: trilinear interpolation followed by a stride-1 convolution.
+         x = F.interpolate(x, scale_factor=2.0, mode="trilinear")
+         _empty_cuda_cache(self.save_mem)
+         x = self.conv(x)
+         _empty_cuda_cache(self.save_mem)
+
+         out_tensor: torch.Tensor = convert_to_tensor(x)
+         return out_tensor
+
+
+ class MaisiDownsample(nn.Module):
+     """
+     Convolution-based downsampling layer.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         print_info: Whether to print information.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         num_splits: int,
+         dim_split: int,
+         print_info: bool,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__()
+         # Pad one voxel at the trailing edge of every spatial dim, since the
+         # stride-2 convolution below uses padding=0.
+         self.pad = (0, 1) * spatial_dims
+         self.conv = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             strides=2,
+             kernel_size=3,
+             padding=0,
+             conv_only=True,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = F.pad(x, self.pad, mode="constant", value=0.0)
+         x = self.conv(x)
+         return x
+
+
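+ # Illustrative sketch (not part of the module): MaisiDownsample halves and
+ # MaisiUpsample doubles each spatial dimension, e.g. with splits along D:
+ #
+ #     down = MaisiDownsample(spatial_dims=3, in_channels=8, num_splits=2, dim_split=0, print_info=False)
+ #     up = MaisiUpsample(spatial_dims=3, in_channels=8, use_convtranspose=False,
+ #                        num_splits=2, dim_split=0, print_info=False)
+ #     x = torch.randn(1, 8, 32, 32, 32)
+ #     assert down(x).shape == (1, 8, 16, 16, 16)
+ #     assert up(x).shape == (1, 8, 64, 64, 64)
+
+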
+ class MaisiResBlock(nn.Module):
+     """
+     Residual block consisting of a cascade of two (normalization, SiLU activation,
+     convolution) stages, with a residual connection between input and output.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Input channels to the layer.
+         norm_num_groups: Number of groups for the group norm layer.
+         norm_eps: Epsilon for the normalization.
+         out_channels: Number of output channels; defaults to `in_channels` when `None`.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert the output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         norm_num_groups: int,
+         norm_eps: float,
+         out_channels: int,
+         num_splits: int,
+         dim_split: int,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = in_channels if out_channels is None else out_channels
+         self.save_mem = save_mem
+
+         self.norm1 = MaisiGroupNorm3D(
+             num_groups=norm_num_groups,
+             num_channels=in_channels,
+             eps=norm_eps,
+             affine=True,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.conv1 = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.in_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.norm2 = MaisiGroupNorm3D(
+             num_groups=norm_num_groups,
+             num_channels=self.out_channels,
+             eps=norm_eps,
+             affine=True,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+         self.conv2 = MaisiConvolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.out_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+
+         # 1x1 convolution to match channel counts on the residual path when needed.
+         self.nin_shortcut = (
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 strides=1,
+                 kernel_size=1,
+                 padding=0,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+             if self.in_channels != self.out_channels
+             else nn.Identity()
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         h = self.norm1(x)
+         _empty_cuda_cache(self.save_mem)
+
+         h = F.silu(h)
+         _empty_cuda_cache(self.save_mem)
+         h = self.conv1(h)
+         _empty_cuda_cache(self.save_mem)
+
+         h = self.norm2(h)
+         _empty_cuda_cache(self.save_mem)
+
+         h = F.silu(h)
+         _empty_cuda_cache(self.save_mem)
+         h = self.conv2(h)
+         _empty_cuda_cache(self.save_mem)
+
+         if self.in_channels != self.out_channels:
+             x = self.nin_shortcut(x)
+             _empty_cuda_cache(self.save_mem)
+
+         out = x + h
+         out_tensor: torch.Tensor = convert_to_tensor(out)
+         return out_tensor
+
+
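+ # Illustrative sketch (not part of the module): a MaisiResBlock maps
+ # (N, in_channels, D, H, W) -> (N, out_channels, D, H, W), adding a 1x1-conv
+ # shortcut when the channel counts differ, e.g.:
+ #
+ #     block = MaisiResBlock(spatial_dims=3, in_channels=8, norm_num_groups=4,
+ #                           norm_eps=1e-6, out_channels=16, num_splits=2, dim_split=0)
+ #     y = block(torch.randn(1, 8, 32, 32, 32))  # -> (1, 16, 32, 32, 32)
+
+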
+ class MaisiEncoder(nn.Module):
+     """
+     Convolutional cascade that downsamples the image into a spatial latent space.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         num_channels: Sequence of block output channels.
+         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
+         num_res_blocks: Number of residual blocks (see ResBlock) per level.
+         norm_num_groups: Number of groups for the group norm layers.
+         norm_eps: Epsilon for the normalization.
+         attention_levels: Indicate which levels from num_channels contain an attention block.
+         with_nonlocal_attn: If True, use non-local attention block.
+         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert the output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         num_channels: Sequence[int],
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         norm_num_groups: int,
+         norm_eps: float,
+         attention_levels: Sequence[bool],
+         num_splits: int,
+         dim_split: int,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+         with_nonlocal_attn: bool = True,
+         use_flash_attention: bool = False,
+     ) -> None:
+         super().__init__()
+
+         # Check that attention_levels and num_channels have the same size
+         if len(attention_levels) != len(num_channels):
+             raise ValueError("attention_levels and num_channels must have the same size")
+
+         # Check that num_res_blocks and num_channels have the same size
+         if len(num_res_blocks) != len(num_channels):
+             raise ValueError("num_res_blocks and num_channels must have the same size")
+
+         self.save_mem = save_mem
+
+         blocks: list[nn.Module] = []
+
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=num_channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         output_channel = num_channels[0]
+         for i in range(len(num_channels)):
+             input_channel = output_channel
+             output_channel = num_channels[i]
+             is_final_block = i == len(num_channels) - 1
+
+             for _ in range(num_res_blocks[i]):
+                 blocks.append(
+                     MaisiResBlock(
+                         spatial_dims=spatial_dims,
+                         in_channels=input_channel,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         out_channels=output_channel,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         norm_float16=norm_float16,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+                 input_channel = output_channel
+                 if attention_levels[i]:
+                     blocks.append(
+                         AttentionBlock(
+                             spatial_dims=spatial_dims,
+                             num_channels=input_channel,
+                             norm_num_groups=norm_num_groups,
+                             norm_eps=norm_eps,
+                             use_flash_attention=use_flash_attention,
+                         )
+                     )
+
+             if not is_final_block:
+                 blocks.append(
+                     MaisiDownsample(
+                         spatial_dims=spatial_dims,
+                         in_channels=input_channel,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+
+         if with_nonlocal_attn:
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=num_channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=num_channels[-1],
+                 )
+             )
+
+             blocks.append(
+                 AttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=num_channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     use_flash_attention=use_flash_attention,
+                 )
+             )
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=num_channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=num_channels[-1],
+                 )
+             )
+
+         blocks.append(
+             MaisiGroupNorm3D(
+                 num_groups=norm_num_groups,
+                 num_channels=num_channels[-1],
+                 eps=norm_eps,
+                 affine=True,
+                 norm_float16=norm_float16,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=num_channels[-1],
+                 out_channels=out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+             _empty_cuda_cache(self.save_mem)
+         return x
+
+
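+ # Illustrative note (not part of the module): with three levels the encoder
+ # downsamples twice (once per non-final level), so a (1, 1, 64, 64, 64) image with
+ # num_channels=(32, 64, 64) yields a (1, out_channels, 16, 16, 16) latent map;
+ # MaisiDecoder below mirrors this cascade to return to image space.
+
+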
+ class MaisiDecoder(nn.Module):
+     """
+     Convolutional cascade upsampling from a spatial latent space into an image space.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         num_channels: Sequence of block output channels.
+         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
+         out_channels: Number of output channels.
+         num_res_blocks: Number of residual blocks (see ResBlock) per level.
+         norm_num_groups: Number of groups for the group norm layers.
+         norm_eps: Epsilon for the normalization.
+         attention_levels: Indicate which levels from num_channels contain an attention block.
+         with_nonlocal_attn: If True, use non-local attention block.
+         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert the output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         num_channels: Sequence[int],
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         norm_num_groups: int,
+         norm_eps: float,
+         attention_levels: Sequence[bool],
+         num_splits: int,
+         dim_split: int,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+         with_nonlocal_attn: bool = True,
+         use_flash_attention: bool = False,
+         use_convtranspose: bool = False,
+     ) -> None:
+         super().__init__()
+         self.print_info = print_info
+         self.save_mem = save_mem
+
+         reversed_block_out_channels = list(reversed(num_channels))
+
+         blocks: list[nn.Module] = []
+
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=reversed_block_out_channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         if with_nonlocal_attn:
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                 )
+             )
+             blocks.append(
+                 AttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     use_flash_attention=use_flash_attention,
+                 )
+             )
+             blocks.append(
+                 ResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                 )
+             )
+
+         reversed_attention_levels = list(reversed(attention_levels))
+         reversed_num_res_blocks = list(reversed(num_res_blocks))
+         block_out_ch = reversed_block_out_channels[0]
+         for i in range(len(reversed_block_out_channels)):
+             block_in_ch = block_out_ch
+             block_out_ch = reversed_block_out_channels[i]
+             is_final_block = i == len(num_channels) - 1
+
+             for _ in range(reversed_num_res_blocks[i]):
+                 blocks.append(
+                     MaisiResBlock(
+                         spatial_dims=spatial_dims,
+                         in_channels=block_in_ch,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         out_channels=block_out_ch,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         norm_float16=norm_float16,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+                 block_in_ch = block_out_ch
+
+                 if reversed_attention_levels[i]:
+                     blocks.append(
+                         AttentionBlock(
+                             spatial_dims=spatial_dims,
+                             num_channels=block_in_ch,
+                             norm_num_groups=norm_num_groups,
+                             norm_eps=norm_eps,
+                             use_flash_attention=use_flash_attention,
+                         )
+                     )
+
+             if not is_final_block:
+                 blocks.append(
+                     MaisiUpsample(
+                         spatial_dims=spatial_dims,
+                         in_channels=block_in_ch,
+                         use_convtranspose=use_convtranspose,
+                         num_splits=num_splits,
+                         dim_split=dim_split,
+                         print_info=print_info,
+                         save_mem=save_mem,
+                     )
+                 )
+
+         blocks.append(
+             MaisiGroupNorm3D(
+                 num_groups=norm_num_groups,
+                 num_channels=block_in_ch,
+                 eps=norm_eps,
+                 affine=True,
+                 norm_float16=norm_float16,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+         blocks.append(
+             MaisiConvolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=block_in_ch,
+                 out_channels=out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+                 num_splits=num_splits,
+                 dim_split=dim_split,
+                 print_info=print_info,
+                 save_mem=save_mem,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+             _empty_cuda_cache(self.save_mem)
+         return x
+
+
+ class AutoencoderKlMaisi(AutoencoderKLType):
+     """
+     AutoencoderKL with the custom, memory-efficient MaisiEncoder and MaisiDecoder.
+
+     Args:
+         spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+         in_channels: Number of input channels.
+         out_channels: Number of output channels.
+         num_res_blocks: Number of residual blocks per level.
+         num_channels: Sequence of block output channels.
+         attention_levels: Indicate which levels from num_channels contain an attention block.
+         latent_channels: Number of channels in the latent space.
+         norm_num_groups: Number of groups for the group norm layers.
+         norm_eps: Epsilon for the normalization.
+         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
+         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+         use_checkpointing: If True, use activation checkpointing.
+         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+         num_splits: Number of splits for the input tensor.
+         dim_split: Dimension of splitting for the input tensor.
+         norm_float16: If True, convert the output of MaisiGroupNorm3D to float16 format, default to `False`.
+         print_info: Whether to print information, default to `False`.
+         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         num_channels: Sequence[int],
+         attention_levels: Sequence[bool],
+         latent_channels: int = 3,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         with_encoder_nonlocal_attn: bool = False,
+         with_decoder_nonlocal_attn: bool = False,
+         use_flash_attention: bool = False,
+         use_checkpointing: bool = False,
+         use_convtranspose: bool = False,
+         num_splits: int = 16,
+         dim_split: int = 0,
+         norm_float16: bool = False,
+         print_info: bool = False,
+         save_mem: bool = True,
+     ) -> None:
+         super().__init__(
+             spatial_dims,
+             in_channels,
+             out_channels,
+             num_res_blocks,
+             num_channels,
+             attention_levels,
+             latent_channels,
+             norm_num_groups,
+             norm_eps,
+             with_encoder_nonlocal_attn,
+             with_decoder_nonlocal_attn,
+             use_flash_attention,
+             use_checkpointing,
+             use_convtranspose,
+         )
+
+         # Replace the stock encoder and decoder with the memory-efficient MAISI variants.
+         self.encoder = MaisiEncoder(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             num_channels=num_channels,
+             out_channels=latent_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             with_nonlocal_attn=with_encoder_nonlocal_attn,
+             use_flash_attention=use_flash_attention,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
+
+         self.decoder = MaisiDecoder(
+             spatial_dims=spatial_dims,
+             num_channels=num_channels,
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             with_nonlocal_attn=with_decoder_nonlocal_attn,
+             use_flash_attention=use_flash_attention,
+             use_convtranspose=use_convtranspose,
+             num_splits=num_splits,
+             dim_split=dim_split,
+             norm_float16=norm_float16,
+             print_info=print_info,
+             save_mem=save_mem,
+         )
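+
+
+ # Illustrative usage sketch (assumed configuration, not part of the module); the
+ # values below are placeholders, not MAISI's released settings:
+ #
+ #     model = AutoencoderKlMaisi(
+ #         spatial_dims=3, in_channels=1, out_channels=1,
+ #         num_res_blocks=(2, 2, 2), num_channels=(32, 64, 64),
+ #         attention_levels=(False, False, False), latent_channels=4,
+ #         norm_num_groups=32, num_splits=8, dim_split=0,
+ #     )
+ #     x = torch.randn(1, 1, 128, 128, 128)
+ #     recon, z_mu, z_sigma = model(x)  # reconstruction plus posterior statistics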