monai-weekly 1.4.dev2430__py3-none-any.whl → 1.4.dev2434__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/bundle/config_parser.py +2 -2
- monai/bundle/reference_resolver.py +18 -1
- monai/bundle/scripts.py +45 -22
- monai/bundle/utils.py +3 -1
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +24 -2
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +908 -0
- monai/networks/utils.py +3 -3
- monai/transforms/__init__.py +1 -0
- monai/transforms/io/array.py +1 -1
- monai/transforms/post/array.py +2 -1
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utils.py +183 -0
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +44 -41
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
monai/networks/nets/spade_autoencoderkl.py:

```diff
@@ -59,7 +59,7 @@ class SPADEResBlock(nn.Module):
             label_nc=label_nc,
             norm_nc=in_channels,
             norm="GROUP",
-            norm_params={"num_groups": norm_num_groups, "affine": False},
+            norm_params={"num_groups": norm_num_groups, "affine": False, "eps": norm_eps},
             hidden_channels=spade_intermediate_channels,
             kernel_size=3,
             spatial_dims=spatial_dims,
@@ -77,7 +77,7 @@ class SPADEResBlock(nn.Module):
             label_nc=label_nc,
             norm_nc=out_channels,
             norm="GROUP",
-            norm_params={"num_groups": norm_num_groups, "affine": False},
+            norm_params={"num_groups": norm_num_groups, "affine": False, "eps": norm_eps},
             hidden_channels=spade_intermediate_channels,
             kernel_size=3,
             spatial_dims=spatial_dims,
```
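The only change to `SPADEResBlock` is that the block's `norm_eps` is now forwarded into the SPADE normalisation instead of being left at the framework default. A minimal sketch of the effect, assuming `norm_params` ends up configuring a parameter-free `torch.nn.GroupNorm` (the SPADE block then applies its own learned scale and shift):

```python
import torch
import torch.nn as nn

# Hypothetical values standing in for the SPADEResBlock arguments.
norm_num_groups, norm_eps, in_channels = 32, 1e-6, 64

# Before: the epsilon silently fell back to GroupNorm's default of 1e-5.
norm_before = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, affine=False)

# After: the block's norm_eps is forwarded, keeping it consistent with the rest of the network.
norm_after = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, affine=False, eps=norm_eps)

x = torch.randn(2, in_channels, 8, 8)
print(norm_before(x).shape, norm_after(x).shape)  # same shapes, slightly different numerics
```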
```diff
@@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module):
         label_nc: number of semantic channels for SPADE normalisation.
         with_nonlocal_attn: if True use non-local attention block.
         spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
```
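`use_flash_attention` routes the attention blocks through PyTorch's fused scaled-dot-product kernel rather than an explicit softmax(QK^T / sqrt(d))V computation. The sketch below is not MONAI's attention block, just a self-contained illustration of the two equivalent paths the flag switches between:

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # (batch, heads, tokens, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Explicit attention: materialises the full (tokens x tokens) score matrix.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
out_explicit = scores.softmax(dim=-1) @ v

# Fused attention: same result, but the kernel avoids storing the score matrix,
# which is where the memory saving comes from.
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_explicit, out_fused, atol=1e-5))  # True
```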
```diff
@@ -152,6 +156,9 @@ class SPADEDecoder(nn.Module):
         label_nc: int,
         with_nonlocal_attn: bool = True,
         spade_intermediate_channels: int = 128,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.spatial_dims = spatial_dims
@@ -200,6 +207,9 @@ class SPADEDecoder(nn.Module):
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
@@ -243,6 +253,9 @@ class SPADEDecoder(nn.Module):
                         num_channels=block_in_ch,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
+                        use_flash_attention=use_flash_attention,
                     )
                 )

@@ -331,6 +344,9 @@ class SPADEAutoencoderKL(nn.Module):
         with_encoder_nonlocal_attn: bool = True,
         with_decoder_nonlocal_attn: bool = True,
         spade_intermediate_channels: int = 128,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -360,6 +376,9 @@ class SPADEAutoencoderKL(nn.Module):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
         self.decoder = SPADEDecoder(
             spatial_dims=spatial_dims,
@@ -373,6 +392,9 @@ class SPADEAutoencoderKL(nn.Module):
             label_nc=label_nc,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
             spade_intermediate_channels=spade_intermediate_channels,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
         self.quant_conv_mu = Convolution(
             spatial_dims=spatial_dims,
```
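`use_combined_linear`, threaded through these constructors, controls whether the self-attention block keeps three separate q/k/v projections or fuses them into one linear layer that is split after the matmul. A small self-contained sketch of the two layouts (hypothetical dimensions, not MONAI's `SABlock` itself); the fused variant does one matrix multiply instead of three, at the cost of a different state-dict layout:

```python
import torch
import torch.nn as nn

dim, tokens = 64, 16
x = torch.randn(2, tokens, dim)

# Separate projections (use_combined_linear=False): three weight matrices.
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
q, k, v = to_q(x), to_k(x), to_v(x)

# Combined projection (use_combined_linear=True): one (3*dim, dim) weight, chunked after the matmul.
qkv = nn.Linear(dim, dim * 3, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
q2, k2, v2 = qkv(x).chunk(3, dim=-1)

# Both layouts produce the same projections.
print(torch.allclose(q, q2, atol=1e-6), torch.allclose(k, k2, atol=1e-6), torch.allclose(v, v2, atol=1e-6))
```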
monai/networks/nets/spade_diffusion_model_unet.py:

```diff
@@ -325,6 +325,10 @@ class SPADEAttnUpBlock(nn.Module):
         resblock_updown: if True use residual blocks for upsampling.
         num_head_channels: number of channels in each attention head.
         spade_intermediate_channels: number of intermediate channels for SPADE block layer
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -342,6 +346,9 @@ class SPADEAttnUpBlock(nn.Module):
         resblock_updown: bool = False,
         num_head_channels: int = 1,
         spade_intermediate_channels: int = 128,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -371,6 +378,9 @@ class SPADEAttnUpBlock(nn.Module):
                     num_head_channels=num_head_channels,
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
             )

@@ -457,6 +467,8 @@ class SPADECrossAttnUpBlock(nn.Module):
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
         spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism.
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
@@ -477,6 +489,9 @@ class SPADECrossAttnUpBlock(nn.Module):
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         spade_intermediate_channels: int = 128,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -510,6 +525,9 @@ class SPADECrossAttnUpBlock(nn.Module):
                     num_layers=transformer_num_layers,
                     cross_attention_dim=cross_attention_dim,
                     upcast_attention=upcast_attention,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
             )

@@ -592,6 +610,9 @@ def get_spade_up_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     spade_intermediate_channels: int = 128,
+    include_fc: bool = True,
+    use_combined_linear: bool = False,
+    use_flash_attention: bool = False,
 ) -> nn.Module:
     if with_attn:
         return SPADEAttnUpBlock(
@@ -608,6 +629,9 @@ def get_spade_up_block(
             resblock_updown=resblock_updown,
             num_head_channels=num_head_channels,
             spade_intermediate_channels=spade_intermediate_channels,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )
     elif with_cross_attn:
         return SPADECrossAttnUpBlock(
@@ -627,6 +651,7 @@ def get_spade_up_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             spade_intermediate_channels=spade_intermediate_channels,
+            use_flash_attention=use_flash_attention,
         )
     else:
         return SPADEUpBlock(
@@ -667,9 +692,11 @@ class SPADEDiffusionModelUNet(nn.Module):
         transformer_num_layers: number of layers of Transformer blocks to use.
         cross_attention_dim: number of context dimensions to use.
         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
-
+            classes.
         upcast_attention: if True, upcast attention operations to full precision.
-        spade_intermediate_channels: number of intermediate channels for SPADE block layer
+        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
```
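`num_class_embeds` makes the UNet class-conditional: the class index is embedded and added to the timestep embedding before it is passed to the residual blocks. A rough, self-contained sketch of that conditioning pattern (hypothetical sizes, not the actual `SPADEDiffusionModelUNet` code):

```python
import torch
import torch.nn as nn

num_class_embeds, time_embed_dim = 10, 256

# One learned vector per class, matching the timestep-embedding width.
class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

timestep_emb = torch.randn(4, time_embed_dim)             # stand-in for the sinusoidal/MLP timestep embedding
class_labels = torch.randint(0, num_class_embeds, (4,))   # one class index per batch element

# The class embedding is added to the timestep embedding, so every block
# downstream sees the combined conditioning signal.
emb = timestep_emb + class_embedding(class_labels)
print(emb.shape)  # torch.Size([4, 256])
```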
```diff
@@ -691,6 +718,9 @@ class SPADEDiffusionModelUNet(nn.Module):
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
         spade_intermediate_channels: int = 128,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -783,6 +813,9 @@ class SPADEDiffusionModelUNet(nn.Module):
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
+                use_flash_attention=use_flash_attention,
             )

             self.down_blocks.append(down_block)
@@ -799,6 +832,9 @@ class SPADEDiffusionModelUNet(nn.Module):
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
+            use_flash_attention=use_flash_attention,
         )

         # up
@@ -834,6 +870,7 @@ class SPADEDiffusionModelUNet(nn.Module):
                 upcast_attention=upcast_attention,
                 label_nc=label_nc,
                 spade_intermediate_channels=spade_intermediate_channels,
+                use_flash_attention=use_flash_attention,
             )

             self.up_blocks.append(up_block)
```
monai/networks/nets/transformer.py:

```diff
@@ -51,6 +51,10 @@ class DecoderOnlyTransformer(nn.Module):
         attn_layers_heads: Number of attention heads.
         with_cross_attention: Whether to use cross attention for conditioning.
         embedding_dropout_rate: Dropout rate for the embedding.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
     """

     def __init__(
```
```diff
@@ -62,6 +66,9 @@ class DecoderOnlyTransformer(nn.Module):
         attn_layers_heads: int,
         with_cross_attention: bool = False,
         embedding_dropout_rate: float = 0.0,
+        include_fc: bool = True,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.num_tokens = num_tokens
@@ -86,6 +93,9 @@ class DecoderOnlyTransformer(nn.Module):
                     causal=True,
                     sequence_length=max_seq_len,
                     with_cross_attention=with_cross_attention,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
+                    use_flash_attention=use_flash_attention,
                 )
                 for _ in range(attn_layers_depth)
             ]
```
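Putting the new flags together, constructing the decoder-only transformer with the fused attention path enabled might look like the sketch below. This assumes the rest of the constructor signature (in particular `attn_layers_dim`) matches the existing class, and it needs a monai-weekly build that already contains this change; the sizes are arbitrary:

```python
import torch
from monai.networks.nets import DecoderOnlyTransformer

model = DecoderOnlyTransformer(
    num_tokens=1000,
    max_seq_len=64,
    attn_layers_dim=128,        # assumed name for the embedding-width argument
    attn_layers_depth=4,
    attn_layers_heads=4,
    include_fc=True,
    use_combined_linear=False,
    use_flash_attention=True,   # route attention through torch.nn.functional.scaled_dot_product_attention
)

tokens = torch.randint(0, 1000, (2, 64))
logits = model(tokens)
print(logits.shape)  # expected: torch.Size([2, 64, 1000])
```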
```diff
@@ -133,25 +143,15 @@ class DecoderOnlyTransformer(nn.Module):
         # copy over all matching keys
         for k in new_state_dict:
             if k in old_state_dict:
-                new_state_dict[k] = old_state_dict[k]
-
-        # fix the attention blocks
-        attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
-        for block in attention_blocks:
-            new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
-                [
-                    old_state_dict[f"{block}.attn.to_q.weight"],
-                    old_state_dict[f"{block}.attn.to_k.weight"],
-                    old_state_dict[f"{block}.attn.to_v.weight"],
-                ],
-                dim=0,
-            )
+                new_state_dict[k] = old_state_dict.pop(k)

         # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2
-        for k in old_state_dict:
+        for k in list(old_state_dict.keys()):
             if "norm2" in k:
-                new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k]
+                new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k)
             if "norm3" in k:
-                new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k]
-
+                new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k)
+        if verbose:
+            # print all remaining keys in old_state_dict
+            print("remaining keys in old_state_dict:", old_state_dict.keys())
         self.load_state_dict(new_state_dict)
```