monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/nets/autoencoderkl.py (new file, item 38 above)
@@ -0,0 +1,702 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from collections.abc import Sequence
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample
+ from monai.utils import ensure_tuple_rep, optional_import
+
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+ __all__ = ["AutoencoderKL"]
+
+
+ class AsymmetricPad(nn.Module):
+     """
+     Pad the input tensor asymmetrically along every spatial dimension.
+
+     Args:
+         spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+     """
+
+     def __init__(self, spatial_dims: int) -> None:
+         super().__init__()
+         self.pad = (0, 1) * spatial_dims
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = nn.functional.pad(x, self.pad, mode="constant", value=0.0)
+         return x
+
+
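For illustration, a minimal sketch of what this padding does (tensor sizes here are assumptions, not taken from the diff; the import path follows the file listed above): torch.nn.functional.pad reads the pad tuple from the last dimension backwards, so (0, 1) * spatial_dims appends one zero-valued element at the far end of each spatial axis.

import torch
from monai.networks.nets.autoencoderkl import AsymmetricPad  # module added in this release

pad = AsymmetricPad(spatial_dims=2)
x = torch.zeros(1, 8, 15, 15)  # N, C, H, W
print(pad(x).shape)  # torch.Size([1, 8, 16, 16]): one extra element at the end of H and W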
+ class AEKLDownsample(nn.Module):
+     """
+     Convolution-based downsampling layer.
+
+     Args:
+         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+         in_channels: number of input channels.
+     """
+
+     def __init__(self, spatial_dims: int, in_channels: int) -> None:
+         super().__init__()
+         self.pad = AsymmetricPad(spatial_dims=spatial_dims)
+
+         self.conv = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             out_channels=in_channels,
+             strides=2,
+             kernel_size=3,
+             padding=0,
+             conv_only=True,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.pad(x)
+         x = self.conv(x)
+         return x
+
+
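A small illustrative sketch (sizes are assumptions) of how the asymmetric pad plus the stride-2, padding-0 convolution halves each spatial axis exactly:

import torch
from monai.networks.nets.autoencoderkl import AEKLDownsample

down = AEKLDownsample(spatial_dims=2, in_channels=8)
x = torch.randn(1, 8, 16, 16)
print(down(x).shape)  # torch.Size([1, 8, 8, 8]): 16 -> padded to 17 -> (17 - 3) // 2 + 1 = 8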
+ class AEKLResBlock(nn.Module):
+     """
+     Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
+     residual connection between input and output.
+
+     Args:
+         spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+         in_channels: input channels to the layer.
+         norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
+             channels is divisible by this number.
+         norm_eps: epsilon for the normalisation.
+         out_channels: number of output channels.
+     """
+
+     def __init__(
+         self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int
+     ) -> None:
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = in_channels if out_channels is None else out_channels
+
+         self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
+         self.conv1 = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.in_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+         )
+         self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True)
+         self.conv2 = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=self.out_channels,
+             out_channels=self.out_channels,
+             strides=1,
+             kernel_size=3,
+             padding=1,
+             conv_only=True,
+         )
+
+         self.nin_shortcut: nn.Module
+         if self.in_channels != self.out_channels:
+             self.nin_shortcut = Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=self.in_channels,
+                 out_channels=self.out_channels,
+                 strides=1,
+                 kernel_size=1,
+                 padding=0,
+                 conv_only=True,
+             )
+         else:
+             self.nin_shortcut = nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         h = x
+         h = self.norm1(h)
+         h = F.silu(h)
+         h = self.conv1(h)
+
+         h = self.norm2(h)
+         h = F.silu(h)
+         h = self.conv2(h)
+
+         x = self.nin_shortcut(x)
+
+         return x + h
+
+
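Illustrative sketch (the values are assumptions): when in_channels differs from out_channels, the 1x1 nin_shortcut projects the skip path, so the block can change the channel count while keeping the spatial size; norm_num_groups must divide both channel counts.

import torch
from monai.networks.nets.autoencoderkl import AEKLResBlock

block = AEKLResBlock(spatial_dims=2, in_channels=32, norm_num_groups=8, norm_eps=1e-6, out_channels=64)
x = torch.randn(2, 32, 24, 24)
print(block(x).shape)  # torch.Size([2, 64, 24, 24]): channels change, spatial size preserved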
+ class Encoder(nn.Module):
+     """
+     Convolutional cascade that downsamples the image into a spatial latent space.
+
+     Args:
+         spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+         in_channels: number of input channels.
+         channels: sequence of block output channels.
+         out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
+         num_res_blocks: number of residual blocks (see _ResBlock) per level.
+         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+         norm_eps: epsilon for the normalization.
+         attention_levels: indicate which level from num_channels contain an attention block.
+         with_nonlocal_attn: if True use non-local attention block.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int,
+         channels: Sequence[int],
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         norm_num_groups: int,
+         norm_eps: float,
+         attention_levels: Sequence[bool],
+         with_nonlocal_attn: bool = True,
+     ) -> None:
+         super().__init__()
+         self.spatial_dims = spatial_dims
+         self.in_channels = in_channels
+         self.channels = channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.norm_num_groups = norm_num_groups
+         self.norm_eps = norm_eps
+         self.attention_levels = attention_levels
+
+         blocks: List[nn.Module] = []
+         # Initial convolution
+         blocks.append(
+             Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         # Residual and downsampling blocks
+         output_channel = channels[0]
+         for i in range(len(channels)):
+             input_channel = output_channel
+             output_channel = channels[i]
+             is_final_block = i == len(channels) - 1
+
+             for _ in range(self.num_res_blocks[i]):
+                 blocks.append(
+                     AEKLResBlock(
+                         spatial_dims=spatial_dims,
+                         in_channels=input_channel,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         out_channels=output_channel,
+                     )
+                 )
+                 input_channel = output_channel
+                 if attention_levels[i]:
+                     blocks.append(
+                         SpatialAttentionBlock(
+                             spatial_dims=spatial_dims,
+                             num_channels=input_channel,
+                             norm_num_groups=norm_num_groups,
+                             norm_eps=norm_eps,
+                         )
+                     )
+
+             if not is_final_block:
+                 blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel))
+         # Non-local attention block
+         if with_nonlocal_attn is True:
+             blocks.append(
+                 AEKLResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=channels[-1],
+                 )
+             )
+
+             blocks.append(
+                 SpatialAttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+             blocks.append(
+                 AEKLResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=channels[-1],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=channels[-1],
+                 )
+             )
+         # Normalise and convert to latent size
+         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True))
+         blocks.append(
+             Convolution(
+                 spatial_dims=self.spatial_dims,
+                 in_channels=channels[-1],
+                 out_channels=out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+         return x
+
+
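Illustrative sketch of the overall compression (configuration values are assumptions): with N entries in `channels` the encoder inserts N - 1 downsampling steps, so each spatial axis is reduced by a factor of 2**(N - 1) and the result is projected to `out_channels` latent channels.

import torch
from monai.networks.nets.autoencoderkl import Encoder

enc = Encoder(
    spatial_dims=2,
    in_channels=1,
    channels=(32, 64, 64),
    out_channels=8,  # latent channels
    num_res_blocks=(1, 1, 1),
    norm_num_groups=16,
    norm_eps=1e-6,
    attention_levels=(False, False, False),
    with_nonlocal_attn=False,
)
x = torch.randn(1, 1, 64, 64)
print(enc(x).shape)  # torch.Size([1, 8, 16, 16]): two downsampling steps, 64 / 2**2 = 16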
+ class Decoder(nn.Module):
+     """
+     Convolutional cascade upsampling from a spatial latent space into an image space.
+
+     Args:
+         spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+         channels: sequence of block output channels.
+         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
+         out_channels: number of output channels.
+         num_res_blocks: number of residual blocks (see _ResBlock) per level.
+         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+         norm_eps: epsilon for the normalization.
+         attention_levels: indicate which level from num_channels contain an attention block.
+         with_nonlocal_attn: if True use non-local attention block.
+         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         channels: Sequence[int],
+         in_channels: int,
+         out_channels: int,
+         num_res_blocks: Sequence[int],
+         norm_num_groups: int,
+         norm_eps: float,
+         attention_levels: Sequence[bool],
+         with_nonlocal_attn: bool = True,
+         use_convtranspose: bool = False,
+     ) -> None:
+         super().__init__()
+         self.spatial_dims = spatial_dims
+         self.channels = channels
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.num_res_blocks = num_res_blocks
+         self.norm_num_groups = norm_num_groups
+         self.norm_eps = norm_eps
+         self.attention_levels = attention_levels
+
+         reversed_block_out_channels = list(reversed(channels))
+
+         blocks: List[nn.Module] = []
+
+         # Initial convolution
+         blocks.append(
+             Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=in_channels,
+                 out_channels=reversed_block_out_channels[0],
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         # Non-local attention block
+         if with_nonlocal_attn is True:
+             blocks.append(
+                 AEKLResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                 )
+             )
+             blocks.append(
+                 SpatialAttentionBlock(
+                     spatial_dims=spatial_dims,
+                     num_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                 )
+             )
+             blocks.append(
+                 AEKLResBlock(
+                     spatial_dims=spatial_dims,
+                     in_channels=reversed_block_out_channels[0],
+                     norm_num_groups=norm_num_groups,
+                     norm_eps=norm_eps,
+                     out_channels=reversed_block_out_channels[0],
+                 )
+             )
+
+         reversed_attention_levels = list(reversed(attention_levels))
+         reversed_num_res_blocks = list(reversed(num_res_blocks))
+         block_out_ch = reversed_block_out_channels[0]
+         for i in range(len(reversed_block_out_channels)):
+             block_in_ch = block_out_ch
+             block_out_ch = reversed_block_out_channels[i]
+             is_final_block = i == len(channels) - 1
+
+             for _ in range(reversed_num_res_blocks[i]):
+                 blocks.append(
+                     AEKLResBlock(
+                         spatial_dims=spatial_dims,
+                         in_channels=block_in_ch,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                         out_channels=block_out_ch,
+                     )
+                 )
+                 block_in_ch = block_out_ch
+
+             if reversed_attention_levels[i]:
+                 blocks.append(
+                     SpatialAttentionBlock(
+                         spatial_dims=spatial_dims,
+                         num_channels=block_in_ch,
+                         norm_num_groups=norm_num_groups,
+                         norm_eps=norm_eps,
+                     )
+                 )
+
+             if not is_final_block:
+                 if use_convtranspose:
+                     blocks.append(
+                         Upsample(
+                             spatial_dims=spatial_dims, mode="deconv", in_channels=block_in_ch, out_channels=block_in_ch
+                         )
+                     )
+                 else:
+                     post_conv = Convolution(
+                         spatial_dims=spatial_dims,
+                         in_channels=block_in_ch,
+                         out_channels=block_in_ch,
+                         strides=1,
+                         kernel_size=3,
+                         padding=1,
+                         conv_only=True,
+                     )
+                     blocks.append(
+                         Upsample(
+                             spatial_dims=spatial_dims,
+                             mode="nontrainable",
+                             in_channels=block_in_ch,
+                             out_channels=block_in_ch,
+                             interp_mode="nearest",
+                             scale_factor=2.0,
+                             post_conv=post_conv,
+                             align_corners=None,
+                         )
+                     )
+
+         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+         blocks.append(
+             Convolution(
+                 spatial_dims=spatial_dims,
+                 in_channels=block_in_ch,
+                 out_channels=out_channels,
+                 strides=1,
+                 kernel_size=3,
+                 padding=1,
+                 conv_only=True,
+             )
+         )
+
+         self.blocks = nn.ModuleList(blocks)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for block in self.blocks:
+             x = block(x)
+         return x
+
+
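A matching sketch for the decoder (again with assumed values): it mirrors the encoder, upsampling the latent back to image resolution with nearest-neighbour interpolation followed by a 3x3 post-convolution, or with a transposed convolution when use_convtranspose=True.

import torch
from monai.networks.nets.autoencoderkl import Decoder

dec = Decoder(
    spatial_dims=2,
    channels=(32, 64, 64),
    in_channels=8,  # latent channels
    out_channels=1,
    num_res_blocks=(1, 1, 1),
    norm_num_groups=16,
    norm_eps=1e-6,
    attention_levels=(False, False, False),
    with_nonlocal_attn=False,
)
z = torch.randn(1, 8, 16, 16)
print(dec(z).shape)  # torch.Size([1, 1, 64, 64]): two upsampling steps restore the input resolution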
+ class AutoencoderKL(nn.Module):
+     """
+     Autoencoder model with KL-regularized latent space based on
+     Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
+     and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
+
+     Args:
+         spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+         in_channels: number of input channels.
+         out_channels: number of output channels.
+         num_res_blocks: number of residual blocks (see _ResBlock) per level.
+         channels: number of output channels for each block.
+         attention_levels: sequence of levels to add attention.
+         latent_channels: latent embedding dimension.
+         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+         norm_eps: epsilon for the normalization.
+         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
+         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
+         use_checkpoint: if True, use activation checkpoint to save memory.
+         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+     """
+
+     def __init__(
+         self,
+         spatial_dims: int,
+         in_channels: int = 1,
+         out_channels: int = 1,
+         num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+         channels: Sequence[int] = (32, 64, 64, 64),
+         attention_levels: Sequence[bool] = (False, False, True, True),
+         latent_channels: int = 3,
+         norm_num_groups: int = 32,
+         norm_eps: float = 1e-6,
+         with_encoder_nonlocal_attn: bool = True,
+         with_decoder_nonlocal_attn: bool = True,
+         use_checkpoint: bool = False,
+         use_convtranspose: bool = False,
+     ) -> None:
+         super().__init__()
+
+         # All number of channels should be multiple of num_groups
+         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+             raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups")
+
+         if len(channels) != len(attention_levels):
+             raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")
+
+         if isinstance(num_res_blocks, int):
+             num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+         if len(num_res_blocks) != len(channels):
+             raise ValueError(
+                 "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+                 "`num_channels`."
+             )
+
+         self.encoder = Encoder(
+             spatial_dims=spatial_dims,
+             in_channels=in_channels,
+             channels=channels,
+             out_channels=latent_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             with_nonlocal_attn=with_encoder_nonlocal_attn,
+         )
+         self.decoder = Decoder(
+             spatial_dims=spatial_dims,
+             channels=channels,
+             in_channels=latent_channels,
+             out_channels=out_channels,
+             num_res_blocks=num_res_blocks,
+             norm_num_groups=norm_num_groups,
+             norm_eps=norm_eps,
+             attention_levels=attention_levels,
+             with_nonlocal_attn=with_decoder_nonlocal_attn,
+             use_convtranspose=use_convtranspose,
+         )
+         self.quant_conv_mu = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=latent_channels,
+             out_channels=latent_channels,
+             strides=1,
+             kernel_size=1,
+             padding=0,
+             conv_only=True,
+         )
+         self.quant_conv_log_sigma = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=latent_channels,
+             out_channels=latent_channels,
+             strides=1,
+             kernel_size=1,
+             padding=0,
+             conv_only=True,
+         )
+         self.post_quant_conv = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=latent_channels,
+             out_channels=latent_channels,
+             strides=1,
+             kernel_size=1,
+             padding=0,
+             conv_only=True,
+         )
+         self.latent_channels = latent_channels
+         self.use_checkpoint = use_checkpoint
+
+     def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.
+
+         Args:
+             x: BxCx[SPATIAL DIMS] tensor
+
+         """
+         if self.use_checkpoint:
+             h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
+         else:
+             h = self.encoder(x)
+
+         z_mu = self.quant_conv_mu(h)
+         z_log_var = self.quant_conv_log_sigma(h)
+         z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
+         z_sigma = torch.exp(z_log_var / 2)
+
+         return z_mu, z_sigma
+
+     def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
+         """
+         From the mean and sigma representations resulting of encoding an image through the latent space,
+         obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
+         adding the mean.
+
+         Args:
+             z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
+             z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image
+
+         Returns:
+             sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
+         """
+         eps = torch.randn_like(z_sigma)
+         z_vae = z_mu + eps * z_sigma
+         return z_vae
+
+     def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Encodes and decodes an input image.
+
+         Args:
+             x: BxCx[SPATIAL DIMENSIONS] tensor.
+
+         Returns:
+             reconstructed image, of the same shape as input
+         """
+         z_mu, _ = self.encode(x)
+         reconstruction = self.decode(z_mu)
+         return reconstruction
+
+     def decode(self, z: torch.Tensor) -> torch.Tensor:
+         """
+         Based on a latent space sample, forwards it through the Decoder.
+
+         Args:
+             z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]
+
+         Returns:
+             decoded image tensor
+         """
+         z = self.post_quant_conv(z)
+         dec: torch.Tensor
+         if self.use_checkpoint:
+             dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
+         else:
+             dec = self.decoder(z)
+         return dec
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         z_mu, z_sigma = self.encode(x)
+         z = self.sampling(z_mu, z_sigma)
+         reconstruction = self.decode(z)
+         return reconstruction, z_mu, z_sigma
+
+     def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+         z_mu, z_sigma = self.encode(x)
+         z = self.sampling(z_mu, z_sigma)
+         return z
+
+     def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
+         image = self.decode(z)
+         return image
+
+     def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
+         """
+         Load a state dict from an AutoencoderKL trained with [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+         Args:
+             old_state_dict: state dict from the old AutoencoderKL model.
+         """
+
+         new_state_dict = self.state_dict()
+         # if all keys match, just load the state dict
+         if all(k in new_state_dict for k in old_state_dict):
+             print("All keys match, loading state dict.")
+             self.load_state_dict(old_state_dict)
+             return
+
+         if verbose:
+             # print all new_state_dict keys that are not in old_state_dict
+             for k in new_state_dict:
+                 if k not in old_state_dict:
+                     print(f"key {k} not found in old state dict")
+             # and vice versa
+             print("----------------------------------------------")
+             for k in old_state_dict:
+                 if k not in new_state_dict:
+                     print(f"key {k} not found in new state dict")
+
+         # copy over all matching keys
+         for k in new_state_dict:
+             if k in old_state_dict:
+                 new_state_dict[k] = old_state_dict[k]
+
+         # fix the attention blocks
+         attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
+         for block in attention_blocks:
+             new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
+                 [
+                     old_state_dict[f"{block}.to_q.weight"],
+                     old_state_dict[f"{block}.to_k.weight"],
+                     old_state_dict[f"{block}.to_v.weight"],
+                 ],
+                 dim=0,
+             )
+             new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat(
+                 [
+                     old_state_dict[f"{block}.to_q.bias"],
+                     old_state_dict[f"{block}.to_k.bias"],
+                     old_state_dict[f"{block}.to_v.bias"],
+                 ],
+                 dim=0,
+             )
+             # old version did not have a projection so set these to the identity
+             new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
+                 new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
+             )
+             new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
+                 new_state_dict[f"{block}.attn.out_proj.bias"].shape
+             )
+
+         # fix the upsample conv blocks which were renamed postconv
+         for k in new_state_dict:
+             if "postconv" in k:
+                 old_name = k.replace("postconv", "conv")
+                 new_state_dict[k] = old_state_dict[old_name]
+         self.load_state_dict(new_state_dict)
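To show how the pieces fit together, a hedged end-to-end sketch (configuration values are assumptions; the class is exported from monai.networks.nets per the updated __init__.py in the file list): forward returns the reconstruction together with the latent mean and standard deviation, sampling applies the reparameterisation trick z = z_mu + z_sigma * eps, and encode_stage_2_inputs / decode_stage_2_outputs expose the latent space to a second-stage model such as a latent diffusion UNet.

import torch
from monai.networks.nets import AutoencoderKL

model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    num_res_blocks=1,  # an int is expanded to one value per level
    attention_levels=(False, False, False),
    latent_channels=4,
    norm_num_groups=16,
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
)
x = torch.randn(2, 1, 64, 64)

reconstruction, z_mu, z_sigma = model(x)  # the KL term is computed from z_mu / z_sigma in the training loop
print(reconstruction.shape, z_mu.shape)   # torch.Size([2, 1, 64, 64]) torch.Size([2, 4, 16, 16])

z = model.encode_stage_2_inputs(x)        # sampled latent for a second-stage (e.g. diffusion) model
x_hat = model.decode_stage_2_outputs(z)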