monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/apps/vista3d/inferer.py +177 -0
  7. monai/apps/vista3d/sampler.py +179 -0
  8. monai/apps/vista3d/transforms.py +224 -0
  9. monai/bundle/scripts.py +29 -17
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/inferers/utils.py +1 -0
  13. monai/losses/__init__.py +1 -0
  14. monai/losses/dice.py +10 -1
  15. monai/losses/nacl_loss.py +139 -0
  16. monai/networks/blocks/crossattention.py +48 -26
  17. monai/networks/blocks/mlp.py +16 -4
  18. monai/networks/blocks/selfattention.py +75 -23
  19. monai/networks/blocks/spatialattention.py +16 -1
  20. monai/networks/blocks/transformerblock.py +17 -2
  21. monai/networks/layers/filtering.py +6 -2
  22. monai/networks/nets/__init__.py +2 -1
  23. monai/networks/nets/autoencoderkl.py +55 -22
  24. monai/networks/nets/cell_sam_wrapper.py +92 -0
  25. monai/networks/nets/controlnet.py +24 -22
  26. monai/networks/nets/diffusion_model_unet.py +159 -19
  27. monai/networks/nets/segresnet_ds.py +127 -1
  28. monai/networks/nets/spade_autoencoderkl.py +22 -0
  29. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  30. monai/networks/nets/transformer.py +17 -17
  31. monai/networks/nets/vista3d.py +946 -0
  32. monai/networks/utils.py +4 -4
  33. monai/transforms/__init__.py +13 -2
  34. monai/transforms/io/array.py +59 -3
  35. monai/transforms/io/dictionary.py +29 -2
  36. monai/transforms/spatial/functional.py +1 -1
  37. monai/transforms/transform.py +2 -2
  38. monai/transforms/utility/dictionary.py +4 -0
  39. monai/transforms/utils.py +230 -1
  40. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  41. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  42. monai/utils/enums.py +1 -0
  43. monai/utils/module.py +7 -6
  44. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
  45. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
  46. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
  47. /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
  48. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
  49. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
@@ -66,6 +66,10 @@ class DiffusionUNetTransformerBlock(nn.Module):
  dropout: dropout probability to use.
  cross_attention_dim: size of the context vector for cross attention.
  upcast_attention: if True, upcast attention operations to full precision.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.

  """

@@ -77,6 +81,9 @@ class DiffusionUNetTransformerBlock(nn.Module):
  dropout: float = 0.0,
  cross_attention_dim: int | None = None,
  upcast_attention: bool = False,
+ use_flash_attention: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
  ) -> None:
  super().__init__()
  self.attn1 = SABlock(
@@ -86,6 +93,9 @@ class DiffusionUNetTransformerBlock(nn.Module):
  dim_head=num_head_channels,
  dropout_rate=dropout,
  attention_dtype=torch.float if upcast_attention else None,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
  self.attn2 = CrossAttentionBlock(
@@ -96,6 +106,7 @@ class DiffusionUNetTransformerBlock(nn.Module):
  dim_head=num_head_channels,
  dropout_rate=dropout,
  attention_dtype=torch.float if upcast_attention else None,
+ use_flash_attention=use_flash_attention,
  )
  self.norm1 = nn.LayerNorm(num_channels)
  self.norm2 = nn.LayerNorm(num_channels)
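
Note: the new use_flash_attention flag above routes SABlock / CrossAttentionBlock through PyTorch's fused scaled_dot_product_attention kernel instead of the explicit softmax(q @ k^T) @ v path. A minimal sketch of that underlying call (plain PyTorch, not MONAI code; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    # q, k, v laid out as (batch, heads, sequence, head_dim)
    q = torch.randn(2, 8, 64, 32)
    k = torch.randn(2, 8, 64, 32)
    v = torch.randn(2, 8, 64, 32)

    # fused, memory-efficient attention (requires PyTorch >= 2.0)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    print(out.shape)  # torch.Size([2, 8, 64, 32])
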
@@ -129,6 +140,11 @@ class SpatialTransformer(nn.Module):
  norm_eps: epsilon for the normalization.
  cross_attention_dim: number of context dimensions to use.
  upcast_attention: if True, upcast attention operations to full precision.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+
  """

  def __init__(
@@ -143,6 +159,9 @@ class SpatialTransformer(nn.Module):
  norm_eps: float = 1e-6,
  cross_attention_dim: int | None = None,
  upcast_attention: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.spatial_dims = spatial_dims
@@ -170,6 +189,9 @@ class SpatialTransformer(nn.Module):
  dropout=dropout,
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  for _ in range(num_layers)
  ]
@@ -524,6 +546,10 @@ class AttnDownBlock(nn.Module):
  resblock_updown: if True use residual blocks for downsampling.
  downsample_padding: padding used in the downsampling block.
  num_head_channels: number of channels in each attention head.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -539,6 +565,9 @@ class AttnDownBlock(nn.Module):
  resblock_updown: bool = False,
  downsample_padding: int = 1,
  num_head_channels: int = 1,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.resblock_updown = resblock_updown
@@ -565,6 +594,9 @@ class AttnDownBlock(nn.Module):
  num_head_channels=num_head_channels,
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -631,7 +663,11 @@ class CrossAttnDownBlock(nn.Module):
  transformer_num_layers: number of layers of Transformer blocks to use.
  cross_attention_dim: number of context dimensions to use.
  upcast_attention: if True, upcast attention operations to full precision.
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -651,6 +687,9 @@ class CrossAttnDownBlock(nn.Module):
  cross_attention_dim: int | None = None,
  upcast_attention: bool = False,
  dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.resblock_updown = resblock_updown
@@ -683,6 +722,9 @@ class CrossAttnDownBlock(nn.Module):
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -740,6 +782,10 @@ class AttnMidBlock(nn.Module):
  norm_num_groups: number of groups for the group normalization.
  norm_eps: epsilon for the group normalization.
  num_head_channels: number of channels in each attention head.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -750,6 +796,9 @@ class AttnMidBlock(nn.Module):
  norm_num_groups: int = 32,
  norm_eps: float = 1e-6,
  num_head_channels: int = 1,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()

@@ -767,6 +816,9 @@ class AttnMidBlock(nn.Module):
  num_head_channels=num_head_channels,
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )

  self.resnet_2 = DiffusionUNetResnetBlock(
@@ -803,6 +855,10 @@ class CrossAttnMidBlock(nn.Module):
  transformer_num_layers: number of layers of Transformer blocks to use.
  cross_attention_dim: number of context dimensions to use.
  upcast_attention: if True, upcast attention operations to full precision.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -817,6 +873,9 @@ class CrossAttnMidBlock(nn.Module):
  cross_attention_dim: int | None = None,
  upcast_attention: bool = False,
  dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()

@@ -839,6 +898,9 @@ class CrossAttnMidBlock(nn.Module):
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  self.resnet_2 = DiffusionUNetResnetBlock(
  spatial_dims=spatial_dims,
@@ -984,6 +1046,10 @@ class AttnUpBlock(nn.Module):
  add_upsample: if True add downsample block.
  resblock_updown: if True use residual blocks for upsampling.
  num_head_channels: number of channels in each attention head.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -999,6 +1065,9 @@ class AttnUpBlock(nn.Module):
  add_upsample: bool = True,
  resblock_updown: bool = False,
  num_head_channels: int = 1,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.resblock_updown = resblock_updown
@@ -1027,6 +1096,9 @@ class AttnUpBlock(nn.Module):
  num_head_channels=num_head_channels,
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -1111,7 +1183,11 @@ class CrossAttnUpBlock(nn.Module):
  transformer_num_layers: number of layers of Transformer blocks to use.
  cross_attention_dim: number of context dimensions to use.
  upcast_attention: if True, upcast attention operations to full precision.
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -1131,6 +1207,9 @@ class CrossAttnUpBlock(nn.Module):
  cross_attention_dim: int | None = None,
  upcast_attention: bool = False,
  dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.resblock_updown = resblock_updown
@@ -1164,6 +1243,9 @@ class CrossAttnUpBlock(nn.Module):
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -1245,6 +1327,9 @@ def get_down_block(
  cross_attention_dim: int | None,
  upcast_attention: bool = False,
  dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> nn.Module:
  if with_attn:
  return AttnDownBlock(
@@ -1258,6 +1343,9 @@ def get_down_block(
  add_downsample=add_downsample,
  resblock_updown=resblock_updown,
  num_head_channels=num_head_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  elif with_cross_attn:
  return CrossAttnDownBlock(
@@ -1275,6 +1363,9 @@ def get_down_block(
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  else:
  return DownBlock(
@@ -1302,6 +1393,9 @@ def get_mid_block(
  cross_attention_dim: int | None,
  upcast_attention: bool = False,
  dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> nn.Module:
  if with_conditioning:
  return CrossAttnMidBlock(
@@ -1315,6 +1409,9 @@ def get_mid_block(
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  else:
  return AttnMidBlock(
@@ -1324,6 +1421,9 @@ def get_mid_block(
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
  num_head_channels=num_head_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )


@@ -1345,6 +1445,9 @@ def get_up_block(
  cross_attention_dim: int | None,
  upcast_attention: bool = False,
  dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> nn.Module:
  if with_attn:
  return AttnUpBlock(
@@ -1359,6 +1462,9 @@ def get_up_block(
  add_upsample=add_upsample,
  resblock_updown=resblock_updown,
  num_head_channels=num_head_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  elif with_cross_attn:
  return CrossAttnUpBlock(
@@ -1377,6 +1483,9 @@ def get_up_block(
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  else:
  return UpBlock(
@@ -1414,9 +1523,13 @@ class DiffusionModelUNet(nn.Module):
  transformer_num_layers: number of layers of Transformer blocks to use.
  cross_attention_dim: number of context dimensions to use.
  num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
- classes.
+ classes.
  upcast_attention: if True, upcast attention operations to full precision.
- dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -1437,6 +1550,9 @@ class DiffusionModelUNet(nn.Module):
  num_class_embeds: int | None = None,
  upcast_attention: bool = False,
  dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  if with_conditioning is True and cross_attention_dim is None:
@@ -1531,6 +1647,9 @@ class DiffusionModelUNet(nn.Module):
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )

  self.down_blocks.append(down_block)
@@ -1548,6 +1667,9 @@ class DiffusionModelUNet(nn.Module):
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )

  # up
@@ -1582,6 +1704,9 @@ class DiffusionModelUNet(nn.Module):
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )

  self.up_blocks.append(up_block)
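
Taken together, the constructor changes above simply thread the three new flags from DiffusionModelUNet down to every attention block it builds. A hedged sketch of constructing the network with them (argument names other than the three new flags follow the current MONAI API and may differ slightly between versions):

    import torch
    from monai.networks.nets import DiffusionModelUNet

    net = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(32, 64, 64),
        attention_levels=(False, True, True),
        num_res_blocks=1,
        num_head_channels=64,
        include_fc=True,            # keep the final linear projection in each attention block
        use_combined_linear=False,  # separate q/k/v projections instead of one fused qkv layer
        use_flash_attention=True,   # dispatch to torch.nn.functional.scaled_dot_product_attention
    )
    x = torch.randn(1, 1, 64, 64)
    timesteps = torch.randint(0, 1000, (1,))
    out = net(x, timesteps=timesteps)  # same spatial shape as the input, (1, 1, 64, 64)
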
@@ -1709,31 +1834,40 @@ class DiffusionModelUNet(nn.Module):
  # copy over all matching keys
  for k in new_state_dict:
  if k in old_state_dict:
- new_state_dict[k] = old_state_dict[k]
+ new_state_dict[k] = old_state_dict.pop(k)

  # fix the attention blocks
- attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
+ attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k]
  for block in attention_blocks:
- new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat(
- [
- old_state_dict[f"{block}.attn1.to_q.weight"],
- old_state_dict[f"{block}.attn1.to_k.weight"],
- old_state_dict[f"{block}.attn1.to_v.weight"],
- ],
- dim=0,
- )
+ new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
+ new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
+ new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
+ new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
+ new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
+ new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")

  # projection
- new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
- new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]
+ new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
+ new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
+
+ # fix the cross attention blocks
+ cross_attention_blocks = [
+ k.replace(".out_proj.weight", "")
+ for k in new_state_dict
+ if "out_proj.weight" in k and "transformer_blocks" in k
+ ]
+ for block in cross_attention_blocks:
+ new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight")
+ new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")

- new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
- new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]
  # fix the upsample conv blocks which were renamed postconv
  for k in new_state_dict:
  if "postconv" in k:
  old_name = k.replace("postconv", "conv")
- new_state_dict[k] = old_state_dict[old_name]
+ new_state_dict[k] = old_state_dict.pop(old_name)
+ if verbose:
+ # print all remaining keys in old_state_dict
+ print("remaining keys in old_state_dict:", old_state_dict.keys())
  self.load_state_dict(new_state_dict)


@@ -1777,6 +1911,9 @@ class DiffusionModelEncoder(nn.Module):
  cross_attention_dim: int | None = None,
  num_class_embeds: int | None = None,
  upcast_attention: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  if with_conditioning is True and cross_attention_dim is None:
@@ -1861,6 +1998,9 @@ class DiffusionModelEncoder(nn.Module):
  transformer_num_layers=transformer_num_layers,
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )

  self.down_blocks.append(down_block)
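
The state-dict remapping hunk above rewrites old MONAI Generative attention keys (to_q/to_k/to_v, proj_attn, to_out.0) into the new SABlock / CrossAttentionBlock naming and pops each consumed key so leftovers can be reported. A hedged usage sketch (the method name load_old_state_dict and its verbose flag are read off the diff context; the checkpoint path is hypothetical):

    import torch
    from monai.networks.nets import DiffusionModelUNet

    net = DiffusionModelUNet(
        spatial_dims=2, in_channels=1, out_channels=1,
        channels=(32, 64, 64), attention_levels=(False, True, True),
        num_res_blocks=1, num_head_channels=64,
    )
    # hypothetical path to weights saved with the old MONAI Generative key layout
    old_state = torch.load("old_generative_unet.pt", map_location="cpu")
    net.load_old_state_dict(old_state, verbose=True)  # prints any keys left unmapped
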
@@ -11,6 +11,7 @@

  from __future__ import annotations

+ import copy
  from collections.abc import Callable
  from typing import Union

@@ -23,7 +24,7 @@ from monai.networks.layers.factories import Act, Conv, Norm, split_args
  from monai.networks.layers.utils import get_act_layer, get_norm_layer
  from monai.utils import UpsampleMode, has_option

- __all__ = ["SegResNetDS"]
+ __all__ = ["SegResNetDS", "SegResNetDS2"]


  def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
@@ -425,3 +426,128 @@ class SegResNetDS(nn.Module):

  def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
  return self._forward(x)
+
+
+ class SegResNetDS2(SegResNetDS):
+ """
+ SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D
+ <https://arxiv.org/abs/2406.05285>`_.
+
+ Args:
+ spatial_dims: spatial dimension of the input data. Defaults to 3.
+ init_filters: number of output channels for initial convolution layer. Defaults to 32.
+ in_channels: number of input channels for the network. Defaults to 1.
+ out_channels: number of output channels for the network. Defaults to 2.
+ act: activation type and arguments. Defaults to ``RELU``.
+ norm: feature normalization type and arguments. Defaults to ``BATCH``.
+ blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
+ blocks_up: number of upsample blocks (optional).
+ dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
+ At dsdepth==1,only a single output is returned.
+ preprocess: optional callable function to apply before the model's forward pass
+ resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
+ image spacing into an approximately isotropic space.
+ Otherwise, by default, the kernel size and downsampling is always isotropic.
+
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int = 3,
+ init_filters: int = 32,
+ in_channels: int = 1,
+ out_channels: int = 2,
+ act: tuple | str = "relu",
+ norm: tuple | str = "batch",
+ blocks_down: tuple = (1, 2, 2, 4),
+ blocks_up: tuple | None = None,
+ dsdepth: int = 1,
+ preprocess: nn.Module | Callable | None = None,
+ upsample_mode: UpsampleMode | str = "deconv",
+ resolution: tuple | None = None,
+ ):
+ super().__init__(
+ spatial_dims=spatial_dims,
+ init_filters=init_filters,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ act=act,
+ norm=norm,
+ blocks_down=blocks_down,
+ blocks_up=blocks_up,
+ dsdepth=dsdepth,
+ preprocess=preprocess,
+ upsample_mode=upsample_mode,
+ resolution=resolution,
+ )
+
+ self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers])
+
+ def forward( # type: ignore
+ self, x: torch.Tensor, with_point: bool = True, with_label: bool = True
+ ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]:
+ """
+ Args:
+ x: input tensor.
+ with_point: if true, return the point branch output.
+ with_label: if true, return the label branch output.
+ """
+ if self.preprocess is not None:
+ x = self.preprocess(x)
+
+ if not self.is_valid_shape(x):
+ raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")
+
+ x_down = self.encoder(x)
+
+ x_down.reverse()
+ x = x_down.pop(0)
+
+ if len(x_down) == 0:
+ x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]
+
+ outputs: list[torch.Tensor] = []
+ outputs_auto: list[torch.Tensor] = []
+ x_ = x.clone()
+ if with_point:
+ i = 0
+ for level in self.up_layers:
+ x = level["upsample"](x)
+ x = x + x_down[i]
+ x = level["blocks"](x)
+
+ if len(self.up_layers) - i <= self.dsdepth:
+ outputs.append(level["head"](x))
+ i = i + 1
+
+ outputs.reverse()
+ x = x_
+ if with_label:
+ i = 0
+ for level in self.up_layers_auto:
+ x = level["upsample"](x)
+ x = x + x_down[i]
+ x = level["blocks"](x)
+
+ if len(self.up_layers) - i <= self.dsdepth:
+ outputs_auto.append(level["head"](x))
+ i = i + 1
+
+ outputs_auto.reverse()
+
+ return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto
+
+ def set_auto_grad(self, auto_freeze=False, point_freeze=False):
+ """
+ Args:
+ auto_freeze: if true, freeze the image encoder and the auto-branch.
+ point_freeze: if true, freeze the image encoder and the point-branch.
+ """
+ for param in self.encoder.parameters():
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
+
+ for param in self.up_layers_auto.parameters():
+ param.requires_grad = not auto_freeze
+
+ for param in self.up_layers.parameters():
+ param.requires_grad = not point_freeze
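
A hedged sketch of exercising the new SegResNetDS2 dual-decoder forward pass and the set_auto_grad helper added above (shapes are illustrative; the class is assumed to be re-exported from monai.networks.nets, whose __init__.py is also touched in this diff):

    import torch
    from monai.networks.nets import SegResNetDS2

    net = SegResNetDS2(spatial_dims=3, init_filters=32, in_channels=1, out_channels=2, dsdepth=1)
    x = torch.randn(1, 1, 64, 64, 64)  # spatial size must be divisible by the network's shape factor (8 here)

    # point branch and auto (label) branch outputs; single tensors because dsdepth == 1
    point_out, label_out = net(x, with_point=True, with_label=True)
    print(point_out.shape, label_out.shape)  # (1, 2, 64, 64, 64) each

    # freeze the image encoder and the auto-branch, keep the point branch trainable
    net.set_auto_grad(auto_freeze=True, point_freeze=False)
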
@@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module):
  label_nc: number of semantic channels for SPADE normalisation.
  with_nonlocal_attn: if True use non-local attention block.
  spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -152,6 +156,9 @@ class SPADEDecoder(nn.Module):
  label_nc: int,
  with_nonlocal_attn: bool = True,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.spatial_dims = spatial_dims
@@ -200,6 +207,9 @@ class SPADEDecoder(nn.Module):
  num_channels=reversed_block_out_channels[0],
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )
  blocks.append(
@@ -243,6 +253,9 @@ class SPADEDecoder(nn.Module):
  num_channels=block_in_ch,
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -331,6 +344,9 @@ class SPADEAutoencoderKL(nn.Module):
  with_encoder_nonlocal_attn: bool = True,
  with_decoder_nonlocal_attn: bool = True,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()

@@ -360,6 +376,9 @@ class SPADEAutoencoderKL(nn.Module):
  norm_eps=norm_eps,
  attention_levels=attention_levels,
  with_nonlocal_attn=with_encoder_nonlocal_attn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  self.decoder = SPADEDecoder(
  spatial_dims=spatial_dims,
@@ -373,6 +392,9 @@ class SPADEAutoencoderKL(nn.Module):
  label_nc=label_nc,
  with_nonlocal_attn=with_decoder_nonlocal_attn,
  spade_intermediate_channels=spade_intermediate_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  self.quant_conv_mu = Convolution(
  spatial_dims=spatial_dims,
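
Finally, the SPADE hunks thread the same three flags through SPADEAutoencoderKL into its encoder and SPADEDecoder. A hedged sketch of constructing it with them (all argument names other than the three new flags, and the forward signature, are assumptions based on the current MONAI API and may need adjusting):

    import torch
    from monai.networks.nets import SPADEAutoencoderKL

    model = SPADEAutoencoderKL(
        spatial_dims=2,
        label_nc=3,                      # semantic channels used by the SPADE normalisation
        in_channels=1,
        out_channels=1,
        channels=(32, 64, 64),
        latent_channels=8,
        num_res_blocks=1,
        attention_levels=(False, False, True),
        include_fc=True,
        use_combined_linear=False,
        use_flash_attention=False,
    )
    x = torch.randn(1, 1, 64, 64)
    seg = torch.randn(1, 3, 64, 64)      # one-hot style semantic map
    reconstruction, z_mu, z_sigma = model(x, seg)
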