monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/scripts.py +29 -17
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/inferers/utils.py +1 -0
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +22 -0
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +946 -0
- monai/networks/utils.py +4 -4
- monai/transforms/__init__.py +13 -2
- monai/transforms/io/array.py +59 -3
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utility/dictionary.py +4 -0
- monai/transforms/utils.py +230 -1
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/enums.py +1 -0
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
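The vista3d app modules, the NACL loss, the cell SAM wrapper and the SegResNetDS2/VISTA3D networks in the list above are new in this range. A quick way to confirm which weekly build is installed is to import one of the newly added modules; this is a sketch, not part of the diff:

```python
import monai

print(monai.__version__)  # expected to report a 1.4.dev24xx weekly build

# monai/apps/vista3d/inferer.py is listed above as a new file, so this import
# should succeed on 1.4.dev2435 and raise ImportError on 1.4.dev2431.
import monai.apps.vista3d.inferer  # noqa: F401
```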
monai/networks/nets/diffusion_model_unet.py

@@ -66,6 +66,10 @@ class DiffusionUNetTransformerBlock(nn.Module):
         dropout: dropout probability to use.
         cross_attention_dim: size of the context vector for cross attention.
         upcast_attention: if True, upcast attention operations to full precision.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.

     """

@@ -77,6 +81,9 @@ class DiffusionUNetTransformerBlock(nn.Module):
         dropout: float = 0.0,
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
     ) -> None:
         super().__init__()
         self.attn1 = SABlock(
@@ -86,6 +93,9 @@ class DiffusionUNetTransformerBlock(nn.Module):
             dim_head=num_head_channels,
             dropout_rate=dropout,
             attention_dtype=torch.float if upcast_attention else None,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
         self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
         self.attn2 = CrossAttentionBlock(
@@ -96,6 +106,7 @@ class DiffusionUNetTransformerBlock(nn.Module):
             dim_head=num_head_channels,
             dropout_rate=dropout,
             attention_dtype=torch.float if upcast_attention else None,
+            use_flash_attention=use_flash_attention,
         )
         self.norm1 = nn.LayerNorm(num_channels)
         self.norm2 = nn.LayerNorm(num_channels)
@@ -129,6 +140,11 @@ class SpatialTransformer(nn.Module):
         norm_eps: epsilon for the normalization.
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+
     """

     def __init__(
@@ -143,6 +159,9 @@ class SpatialTransformer(nn.Module):
         norm_eps: float = 1e-6,
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -170,6 +189,9 @@ class SpatialTransformer(nn.Module):
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
+                   include_fc=include_fc,
+                   use_combined_linear=use_combined_linear,
+                   use_flash_attention=use_flash_attention,
                )
                for _ in range(num_layers)
            ]
@@ -524,6 +546,10 @@ class AttnDownBlock(nn.Module):
         resblock_updown: if True use residual blocks for downsampling.
         downsample_padding: padding used in the downsampling block.
         num_head_channels: number of channels in each attention head.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -539,6 +565,9 @@ class AttnDownBlock(nn.Module):
         resblock_updown: bool = False,
         downsample_padding: int = 1,
         num_head_channels: int = 1,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -565,6 +594,9 @@ class AttnDownBlock(nn.Module):
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
+                   include_fc=include_fc,
+                   use_combined_linear=use_combined_linear,
+                   use_flash_attention=use_flash_attention,
                )
            )

@@ -631,7 +663,11 @@ class CrossAttnDownBlock(nn.Module):
         transformer_num_layers: number of layers of Transformer blocks to use.
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
-        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -651,6 +687,9 @@ class CrossAttnDownBlock(nn.Module):
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         dropout_cattn: float = 0.0,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -683,6 +722,9 @@ class CrossAttnDownBlock(nn.Module):
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    dropout=dropout_cattn,
+                   include_fc=include_fc,
+                   use_combined_linear=use_combined_linear,
+                   use_flash_attention=use_flash_attention,
                )
            )

@@ -740,6 +782,10 @@ class AttnMidBlock(nn.Module):
         norm_num_groups: number of groups for the group normalization.
         norm_eps: epsilon for the group normalization.
         num_head_channels: number of channels in each attention head.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -750,6 +796,9 @@ class AttnMidBlock(nn.Module):
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
         num_head_channels: int = 1,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -767,6 +816,9 @@ class AttnMidBlock(nn.Module):
            num_head_channels=num_head_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
+           include_fc=include_fc,
+           use_combined_linear=use_combined_linear,
+           use_flash_attention=use_flash_attention,
        )

        self.resnet_2 = DiffusionUNetResnetBlock(
@@ -803,6 +855,10 @@ class CrossAttnMidBlock(nn.Module):
         transformer_num_layers: number of layers of Transformer blocks to use.
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -817,6 +873,9 @@ class CrossAttnMidBlock(nn.Module):
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         dropout_cattn: float = 0.0,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -839,6 +898,9 @@ class CrossAttnMidBlock(nn.Module):
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            dropout=dropout_cattn,
+           include_fc=include_fc,
+           use_combined_linear=use_combined_linear,
+           use_flash_attention=use_flash_attention,
        )
        self.resnet_2 = DiffusionUNetResnetBlock(
            spatial_dims=spatial_dims,
@@ -984,6 +1046,10 @@ class AttnUpBlock(nn.Module):
         add_upsample: if True add downsample block.
         resblock_updown: if True use residual blocks for upsampling.
         num_head_channels: number of channels in each attention head.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -999,6 +1065,9 @@ class AttnUpBlock(nn.Module):
         add_upsample: bool = True,
         resblock_updown: bool = False,
         num_head_channels: int = 1,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1027,6 +1096,9 @@ class AttnUpBlock(nn.Module):
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
+                   include_fc=include_fc,
+                   use_combined_linear=use_combined_linear,
+                   use_flash_attention=use_flash_attention,
                )
            )

@@ -1111,7 +1183,11 @@ class CrossAttnUpBlock(nn.Module):
         transformer_num_layers: number of layers of Transformer blocks to use.
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
-        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -1131,6 +1207,9 @@ class CrossAttnUpBlock(nn.Module):
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         dropout_cattn: float = 0.0,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1164,6 +1243,9 @@ class CrossAttnUpBlock(nn.Module):
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    dropout=dropout_cattn,
+                   include_fc=include_fc,
+                   use_combined_linear=use_combined_linear,
+                   use_flash_attention=use_flash_attention,
                )
            )

@@ -1245,6 +1327,9 @@ def get_down_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     dropout_cattn: float = 0.0,
+    include_fc: bool = True,
+    use_combined_linear: bool = False,
+    use_flash_attention: bool = False,
 ) -> nn.Module:
     if with_attn:
         return AttnDownBlock(
@@ -1258,6 +1343,9 @@ def get_down_block(
             add_downsample=add_downsample,
             resblock_updown=resblock_updown,
             num_head_channels=num_head_channels,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
     elif with_cross_attn:
         return CrossAttnDownBlock(
@@ -1275,6 +1363,9 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             dropout_cattn=dropout_cattn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
     else:
         return DownBlock(
@@ -1302,6 +1393,9 @@ def get_mid_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     dropout_cattn: float = 0.0,
+    include_fc: bool = True,
+    use_combined_linear: bool = False,
+    use_flash_attention: bool = False,
 ) -> nn.Module:
     if with_conditioning:
         return CrossAttnMidBlock(
@@ -1315,6 +1409,9 @@ def get_mid_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             dropout_cattn=dropout_cattn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
     else:
         return AttnMidBlock(
@@ -1324,6 +1421,9 @@ def get_mid_block(
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
             num_head_channels=num_head_channels,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )


@@ -1345,6 +1445,9 @@ def get_up_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     dropout_cattn: float = 0.0,
+    include_fc: bool = True,
+    use_combined_linear: bool = False,
+    use_flash_attention: bool = False,
 ) -> nn.Module:
     if with_attn:
         return AttnUpBlock(
@@ -1359,6 +1462,9 @@ def get_up_block(
             add_upsample=add_upsample,
             resblock_updown=resblock_updown,
             num_head_channels=num_head_channels,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
     elif with_cross_attn:
         return CrossAttnUpBlock(
@@ -1377,6 +1483,9 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             dropout_cattn=dropout_cattn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
     else:
         return UpBlock(
@@ -1414,9 +1523,13 @@ class DiffusionModelUNet(nn.Module):
         transformer_num_layers: number of layers of Transformer blocks to use.
         cross_attention_dim: number of context dimensions to use.
         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
-
+            classes.
         upcast_attention: if True, upcast attention operations to full precision.
-        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -1437,6 +1550,9 @@ class DiffusionModelUNet(nn.Module):
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
         dropout_cattn: float = 0.0,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1531,6 +1647,9 @@ class DiffusionModelUNet(nn.Module):
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                dropout_cattn=dropout_cattn,
+               include_fc=include_fc,
+               use_combined_linear=use_combined_linear,
+               use_flash_attention=use_flash_attention,
            )

            self.down_blocks.append(down_block)
@@ -1548,6 +1667,9 @@ class DiffusionModelUNet(nn.Module):
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            dropout_cattn=dropout_cattn,
+           include_fc=include_fc,
+           use_combined_linear=use_combined_linear,
+           use_flash_attention=use_flash_attention,
        )

        # up
@@ -1582,6 +1704,9 @@ class DiffusionModelUNet(nn.Module):
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                dropout_cattn=dropout_cattn,
+               include_fc=include_fc,
+               use_combined_linear=use_combined_linear,
+               use_flash_attention=use_flash_attention,
            )

            self.up_blocks.append(up_block)
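The hunks above thread three new attention options (`include_fc`, `use_combined_linear`, `use_flash_attention`) from `DiffusionModelUNet` down through every attention-bearing block. A minimal construction sketch follows; the three flags come from this diff, while the remaining constructor arguments are assumed from the 1.4 API and should be checked against the installed version:

```python
import torch
from monai.networks.nets import DiffusionModelUNet

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=16,
    include_fc=True,            # keep the final linear layer in each attention block
    use_combined_linear=False,  # separate q/k/v projections
    use_flash_attention=True,   # route attention through torch.nn.functional.scaled_dot_product_attention
)

x = torch.randn(1, 1, 64, 64)
timesteps = torch.randint(0, 1000, (1,))
prediction = model(x, timesteps)
```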
@@ -1709,31 +1834,40 @@ class DiffusionModelUNet(nn.Module):
         # copy over all matching keys
         for k in new_state_dict:
             if k in old_state_dict:
-                new_state_dict[k] = old_state_dict
+                new_state_dict[k] = old_state_dict.pop(k)

         # fix the attention blocks
-        attention_blocks = [k.replace(".
+        attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k]
         for block in attention_blocks:
-            new_state_dict[f"{block}.
-
-
-
-
-
-                dim=0,
-            )
+            new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
+            new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
+            new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
+            new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
+            new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
+            new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")

             # projection
-            new_state_dict[f"{block}.
-            new_state_dict[f"{block}.
+            new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
+            new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
+
+        # fix the cross attention blocks
+        cross_attention_blocks = [
+            k.replace(".out_proj.weight", "")
+            for k in new_state_dict
+            if "out_proj.weight" in k and "transformer_blocks" in k
+        ]
+        for block in cross_attention_blocks:
+            new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight")
+            new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")

-            new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
-            new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]
         # fix the upsample conv blocks which were renamed postconv
         for k in new_state_dict:
             if "postconv" in k:
                 old_name = k.replace("postconv", "conv")
-                new_state_dict[k] = old_state_dict
+                new_state_dict[k] = old_state_dict.pop(old_name)
+        if verbose:
+            # print all remaining keys in old_state_dict
+            print("remaining keys in old_state_dict:", old_state_dict.keys())
         self.load_state_dict(new_state_dict)

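The hunk above rewrites the old-checkpoint conversion helper so that consumed keys are popped from `old_state_dict`; with `verbose=True`, anything left over after conversion is printed, which makes silent key mismatches visible. A hypothetical migration sketch follows; the helper's name (`load_old_state_dict`) and the checkpoint path are assumptions, only the `verbose` flag and the pop-based behaviour are shown in the diff:

```python
import torch

# Continuing the DiffusionModelUNet sketch above, load weights saved with the
# pre-refactor (MONAI Generative style) attention layout.
old_weights = torch.load("generative_unet_checkpoint.pt", map_location="cpu")  # hypothetical file

model.load_old_state_dict(old_weights, verbose=True)
# With verbose=True the helper prints "remaining keys in old_state_dict: ...",
# so any key that was not converted is easy to spot.
```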
@@ -1777,6 +1911,9 @@ class DiffusionModelEncoder(nn.Module):
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1861,6 +1998,9 @@ class DiffusionModelEncoder(nn.Module):
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
+               include_fc=include_fc,
+               use_combined_linear=use_combined_linear,
+               use_flash_attention=use_flash_attention,
            )

            self.down_blocks.append(down_block)
monai/networks/nets/segresnet_ds.py

@@ -11,6 +11,7 @@

 from __future__ import annotations

+import copy
 from collections.abc import Callable
 from typing import Union

@@ -23,7 +24,7 @@ from monai.networks.layers.factories import Act, Conv, Norm, split_args
 from monai.networks.layers.utils import get_act_layer, get_norm_layer
 from monai.utils import UpsampleMode, has_option

-__all__ = ["SegResNetDS"]
+__all__ = ["SegResNetDS", "SegResNetDS2"]


 def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
@@ -425,3 +426,128 @@ class SegResNetDS(nn.Module):

     def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
         return self._forward(x)
+
+
+class SegResNetDS2(SegResNetDS):
+    """
+    SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D
+    <https://arxiv.org/abs/2406.05285>`_.
+
+    Args:
+        spatial_dims: spatial dimension of the input data. Defaults to 3.
+        init_filters: number of output channels for initial convolution layer. Defaults to 32.
+        in_channels: number of input channels for the network. Defaults to 1.
+        out_channels: number of output channels for the network. Defaults to 2.
+        act: activation type and arguments. Defaults to ``RELU``.
+        norm: feature normalization type and arguments. Defaults to ``BATCH``.
+        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
+        blocks_up: number of upsample blocks (optional).
+        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
+            At dsdepth==1,only a single output is returned.
+        preprocess: optional callable function to apply before the model's forward pass
+        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
+            image spacing into an approximately isotropic space.
+            Otherwise, by default, the kernel size and downsampling is always isotropic.
+
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int = 3,
+        init_filters: int = 32,
+        in_channels: int = 1,
+        out_channels: int = 2,
+        act: tuple | str = "relu",
+        norm: tuple | str = "batch",
+        blocks_down: tuple = (1, 2, 2, 4),
+        blocks_up: tuple | None = None,
+        dsdepth: int = 1,
+        preprocess: nn.Module | Callable | None = None,
+        upsample_mode: UpsampleMode | str = "deconv",
+        resolution: tuple | None = None,
+    ):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            init_filters=init_filters,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            act=act,
+            norm=norm,
+            blocks_down=blocks_down,
+            blocks_up=blocks_up,
+            dsdepth=dsdepth,
+            preprocess=preprocess,
+            upsample_mode=upsample_mode,
+            resolution=resolution,
+        )
+
+        self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers])
+
+    def forward(  # type: ignore
+        self, x: torch.Tensor, with_point: bool = True, with_label: bool = True
+    ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]:
+        """
+        Args:
+            x: input tensor.
+            with_point: if true, return the point branch output.
+            with_label: if true, return the label branch output.
+        """
+        if self.preprocess is not None:
+            x = self.preprocess(x)
+
+        if not self.is_valid_shape(x):
+            raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")
+
+        x_down = self.encoder(x)
+
+        x_down.reverse()
+        x = x_down.pop(0)
+
+        if len(x_down) == 0:
+            x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]
+
+        outputs: list[torch.Tensor] = []
+        outputs_auto: list[torch.Tensor] = []
+        x_ = x.clone()
+        if with_point:
+            i = 0
+            for level in self.up_layers:
+                x = level["upsample"](x)
+                x = x + x_down[i]
+                x = level["blocks"](x)
+
+                if len(self.up_layers) - i <= self.dsdepth:
+                    outputs.append(level["head"](x))
+                i = i + 1
+
+            outputs.reverse()
+        x = x_
+        if with_label:
+            i = 0
+            for level in self.up_layers_auto:
+                x = level["upsample"](x)
+                x = x + x_down[i]
+                x = level["blocks"](x)
+
+                if len(self.up_layers) - i <= self.dsdepth:
+                    outputs_auto.append(level["head"](x))
+                i = i + 1
+
+            outputs_auto.reverse()
+
+        return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto
+
+    def set_auto_grad(self, auto_freeze=False, point_freeze=False):
+        """
+        Args:
+            auto_freeze: if true, freeze the image encoder and the auto-branch.
+            point_freeze: if true, freeze the image encoder and the point-branch.
+        """
+        for param in self.encoder.parameters():
+            param.requires_grad = (not auto_freeze) and (not point_freeze)
+
+        for param in self.up_layers_auto.parameters():
+            param.requires_grad = not auto_freeze
+
+        for param in self.up_layers.parameters():
+            param.requires_grad = not point_freeze
monai/networks/nets/spade_autoencoderkl.py

@@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module):
         label_nc: number of semantic channels for SPADE normalisation.
         with_nonlocal_attn: if True use non-local attention block.
         spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -152,6 +156,9 @@ class SPADEDecoder(nn.Module):
         label_nc: int,
         with_nonlocal_attn: bool = True,
         spade_intermediate_channels: int = 128,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -200,6 +207,9 @@ class SPADEDecoder(nn.Module):
                    num_channels=reversed_block_out_channels[0],
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
+                   include_fc=include_fc,
+                   use_combined_linear=use_combined_linear,
+                   use_flash_attention=use_flash_attention,
                )
            )
            blocks.append(
@@ -243,6 +253,9 @@ class SPADEDecoder(nn.Module):
                        num_channels=block_in_ch,
                        norm_num_groups=norm_num_groups,
                        norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
+                        use_flash_attention=use_flash_attention,
                    )
                )

@@ -331,6 +344,9 @@ class SPADEAutoencoderKL(nn.Module):
         with_encoder_nonlocal_attn: bool = True,
         with_decoder_nonlocal_attn: bool = True,
         spade_intermediate_channels: int = 128,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -360,6 +376,9 @@ class SPADEAutoencoderKL(nn.Module):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
         self.decoder = SPADEDecoder(
             spatial_dims=spatial_dims,
@@ -373,6 +392,9 @@ class SPADEAutoencoderKL(nn.Module):
             label_nc=label_nc,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
             spade_intermediate_channels=spade_intermediate_channels,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
         self.quant_conv_mu = Convolution(
             spatial_dims=spatial_dims,
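As in the diffusion U-Net, `SPADEAutoencoderKL` and `SPADEDecoder` simply pass the three new flags down to their attention blocks; they ultimately land on `SABlock` (see the `selfattention.py` and `crossattention.py` entries in the file list above). A block-level sketch, using SABlock's pre-existing `hidden_size`/`num_heads` arguments plus the flags added in this range:

```python
import torch
from monai.networks.blocks import SABlock

attn = SABlock(
    hidden_size=64,
    num_heads=8,
    include_fc=True,
    use_combined_linear=True,  # single fused qkv projection instead of separate q/k/v layers
    use_flash_attention=True,  # use torch.nn.functional.scaled_dot_product_attention internally
)

tokens = torch.randn(2, 128, 64)  # (batch, sequence length, hidden_size)
out = attn(tokens)                # same shape as the input
```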