monai-weekly 1.4.dev2430-py3-none-any.whl → 1.4.dev2434-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/bundle/config_parser.py +2 -2
  7. monai/bundle/reference_resolver.py +18 -1
  8. monai/bundle/scripts.py +45 -22
  9. monai/bundle/utils.py +3 -1
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/losses/__init__.py +1 -0
  13. monai/losses/dice.py +10 -1
  14. monai/losses/nacl_loss.py +139 -0
  15. monai/networks/blocks/crossattention.py +48 -26
  16. monai/networks/blocks/mlp.py +16 -4
  17. monai/networks/blocks/selfattention.py +75 -23
  18. monai/networks/blocks/spatialattention.py +16 -1
  19. monai/networks/blocks/transformerblock.py +17 -2
  20. monai/networks/nets/__init__.py +2 -1
  21. monai/networks/nets/autoencoderkl.py +55 -22
  22. monai/networks/nets/cell_sam_wrapper.py +92 -0
  23. monai/networks/nets/controlnet.py +24 -22
  24. monai/networks/nets/diffusion_model_unet.py +159 -19
  25. monai/networks/nets/segresnet_ds.py +127 -1
  26. monai/networks/nets/spade_autoencoderkl.py +24 -2
  27. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  28. monai/networks/nets/transformer.py +17 -17
  29. monai/networks/nets/vista3d.py +908 -0
  30. monai/networks/utils.py +3 -3
  31. monai/transforms/__init__.py +1 -0
  32. monai/transforms/io/array.py +1 -1
  33. monai/transforms/post/array.py +2 -1
  34. monai/transforms/spatial/functional.py +1 -1
  35. monai/transforms/transform.py +2 -2
  36. monai/transforms/utils.py +183 -0
  37. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  38. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  39. monai/utils/module.py +7 -6
  40. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
  41. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +44 -41
  42. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
  43. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
  44. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
monai/networks/nets/spade_autoencoderkl.py

@@ -59,7 +59,7 @@ class SPADEResBlock(nn.Module):
  label_nc=label_nc,
  norm_nc=in_channels,
  norm="GROUP",
- norm_params={"num_groups": norm_num_groups, "affine": False},
+ norm_params={"num_groups": norm_num_groups, "affine": False, "eps": norm_eps},
  hidden_channels=spade_intermediate_channels,
  kernel_size=3,
  spatial_dims=spatial_dims,
@@ -77,7 +77,7 @@ class SPADEResBlock(nn.Module):
  label_nc=label_nc,
  norm_nc=out_channels,
  norm="GROUP",
- norm_params={"num_groups": norm_num_groups, "affine": False},
+ norm_params={"num_groups": norm_num_groups, "affine": False, "eps": norm_eps},
  hidden_channels=spade_intermediate_channels,
  kernel_size=3,
  spatial_dims=spatial_dims,
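The two SPADEResBlock hunks above only change the parameter dict handed to SPADE's parameter-free normalisation: the block's `norm_eps` is now forwarded instead of relying on the GroupNorm default. A minimal sketch of the effect, assuming the `norm_params` dict is passed through to an `nn.GroupNorm` layer (as MONAI's GROUP norm factory does):

```python
# Sketch only (not part of the diff): effect of adding "eps": norm_eps to norm_params,
# assuming the dict is forwarded to torch.nn.GroupNorm inside the SPADE block.
import torch.nn as nn

norm_num_groups, channels, norm_eps = 32, 64, 1e-6

# before: eps stays at the GroupNorm default of 1e-5
before = nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels, affine=False)
# after: eps now follows the model's norm_eps
after = nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels, affine=False, eps=norm_eps)
```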
@@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module):
  label_nc: number of semantic channels for SPADE normalisation.
  with_nonlocal_attn: if True use non-local attention block.
  spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -152,6 +156,9 @@ class SPADEDecoder(nn.Module):
  label_nc: int,
  with_nonlocal_attn: bool = True,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.spatial_dims = spatial_dims
@@ -200,6 +207,9 @@ class SPADEDecoder(nn.Module):
  num_channels=reversed_block_out_channels[0],
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )
  blocks.append(
@@ -243,6 +253,9 @@ class SPADEDecoder(nn.Module):
  num_channels=block_in_ch,
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -331,6 +344,9 @@ class SPADEAutoencoderKL(nn.Module):
  with_encoder_nonlocal_attn: bool = True,
  with_decoder_nonlocal_attn: bool = True,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()

@@ -360,6 +376,9 @@ class SPADEAutoencoderKL(nn.Module):
  norm_eps=norm_eps,
  attention_levels=attention_levels,
  with_nonlocal_attn=with_encoder_nonlocal_attn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  self.decoder = SPADEDecoder(
  spatial_dims=spatial_dims,
@@ -373,6 +392,9 @@ class SPADEAutoencoderKL(nn.Module):
  label_nc=label_nc,
  with_nonlocal_attn=with_decoder_nonlocal_attn,
  spade_intermediate_channels=spade_intermediate_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  self.quant_conv_mu = Convolution(
  spatial_dims=spatial_dims,
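Taken together, the SPADEDecoder and SPADEAutoencoderKL hunks add three attention-related switches (`include_fc`, `use_combined_linear`, `use_flash_attention`) that are threaded from the top-level constructor down to every attention block. A hedged construction sketch; apart from the three new flags, the keyword names and the `(x, seg)` forward signature are assumptions based on the MONAI 1.4 API and may need adjusting:

```python
# Sketch only (not part of the diff). Keywords other than include_fc /
# use_combined_linear / use_flash_attention are assumed from the MONAI 1.4 API.
import torch
from monai.networks.nets import SPADEAutoencoderKL

model = SPADEAutoencoderKL(
    spatial_dims=2,
    label_nc=3,                    # semantic channels consumed by the SPADE decoder
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    latent_channels=3,
    num_res_blocks=1,
    include_fc=True,               # keep the final linear projection in attention blocks
    use_combined_linear=False,     # separate q/k/v projections
    use_flash_attention=True,      # route attention through scaled_dot_product_attention
)

x = torch.randn(1, 1, 64, 64)      # input image
seg = torch.randn(1, 3, 64, 64)    # semantic map with label_nc channels
reconstruction, z_mu, z_sigma = model(x, seg)
print(reconstruction.shape)
```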
monai/networks/nets/spade_diffusion_model_unet.py

@@ -325,6 +325,10 @@ class SPADEAttnUpBlock(nn.Module):
  resblock_updown: if True use residual blocks for upsampling.
  num_head_channels: number of channels in each attention head.
  spade_intermediate_channels: number of intermediate channels for SPADE block layer
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -342,6 +346,9 @@ class SPADEAttnUpBlock(nn.Module):
  resblock_updown: bool = False,
  num_head_channels: int = 1,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.resblock_updown = resblock_updown
@@ -371,6 +378,9 @@ class SPADEAttnUpBlock(nn.Module):
  num_head_channels=num_head_channels,
  norm_num_groups=norm_num_groups,
  norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -457,6 +467,8 @@ class SPADECrossAttnUpBlock(nn.Module):
  cross_attention_dim: number of context dimensions to use.
  upcast_attention: if True, upcast attention operations to full precision.
  spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism.
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -477,6 +489,9 @@ class SPADECrossAttnUpBlock(nn.Module):
  cross_attention_dim: int | None = None,
  upcast_attention: bool = False,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.resblock_updown = resblock_updown
@@ -510,6 +525,9 @@ class SPADECrossAttnUpBlock(nn.Module):
  num_layers=transformer_num_layers,
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  )

@@ -592,6 +610,9 @@ def get_spade_up_block(
  cross_attention_dim: int | None,
  upcast_attention: bool = False,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> nn.Module:
  if with_attn:
  return SPADEAttnUpBlock(
@@ -608,6 +629,9 @@ def get_spade_up_block(
  resblock_updown=resblock_updown,
  num_head_channels=num_head_channels,
  spade_intermediate_channels=spade_intermediate_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  elif with_cross_attn:
  return SPADECrossAttnUpBlock(
@@ -627,6 +651,7 @@ def get_spade_up_block(
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
  spade_intermediate_channels=spade_intermediate_channels,
+ use_flash_attention=use_flash_attention,
  )
  else:
  return SPADEUpBlock(
@@ -667,9 +692,11 @@ class SPADEDiffusionModelUNet(nn.Module):
  transformer_num_layers: number of layers of Transformer blocks to use.
  cross_attention_dim: number of context dimensions to use.
  num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
- classes.
+ classes.
  upcast_attention: if True, upcast attention operations to full precision.
- spade_intermediate_channels: number of intermediate channels for SPADE block layer
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -691,6 +718,9 @@ class SPADEDiffusionModelUNet(nn.Module):
  num_class_embeds: int | None = None,
  upcast_attention: bool = False,
  spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  if with_conditioning is True and cross_attention_dim is None:
@@ -783,6 +813,9 @@ class SPADEDiffusionModelUNet(nn.Module):
  transformer_num_layers=transformer_num_layers,
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )

  self.down_blocks.append(down_block)
@@ -799,6 +832,9 @@ class SPADEDiffusionModelUNet(nn.Module):
  transformer_num_layers=transformer_num_layers,
  cross_attention_dim=cross_attention_dim,
  upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )

  # up
@@ -834,6 +870,7 @@ class SPADEDiffusionModelUNet(nn.Module):
  upcast_attention=upcast_attention,
  label_nc=label_nc,
  spade_intermediate_channels=spade_intermediate_channels,
+ use_flash_attention=use_flash_attention,
  )

  self.up_blocks.append(up_block)
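The same three flags now reach every attention and transformer block of SPADEDiffusionModelUNet (the cross-attention up-block path only exposes `use_flash_attention`, as the last hunk above shows). A hedged usage sketch; keyword names beyond the new flags, and the `(x, timesteps, seg)` forward call, are assumptions based on the MONAI 1.4 API:

```python
# Sketch only (not part of the diff). Keywords other than include_fc /
# use_combined_linear / use_flash_attention are assumed from the MONAI 1.4 API.
import torch
from monai.networks.nets import SPADEDiffusionModelUNet

unet = SPADEDiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    label_nc=3,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    num_head_channels=32,
    include_fc=True,
    use_combined_linear=False,
    use_flash_attention=True,  # backed by torch.nn.functional.scaled_dot_product_attention
)

x = torch.randn(1, 1, 64, 64)                    # noisy image at timestep t
timesteps = torch.randint(0, 1000, (1,)).long()
seg = torch.randn(1, 3, 64, 64)                  # semantic map consumed by the SPADE up-blocks
noise_pred = unet(x, timesteps=timesteps, seg=seg)
print(noise_pred.shape)
```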
monai/networks/nets/transformer.py

@@ -51,6 +51,10 @@ class DecoderOnlyTransformer(nn.Module):
  attn_layers_heads: Number of attention heads.
  with_cross_attention: Whether to use cross attention for conditioning.
  embedding_dropout_rate: Dropout rate for the embedding.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
  """

  def __init__(
@@ -62,6 +66,9 @@ class DecoderOnlyTransformer(nn.Module):
  attn_layers_heads: int,
  with_cross_attention: bool = False,
  embedding_dropout_rate: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
  ) -> None:
  super().__init__()
  self.num_tokens = num_tokens
@@ -86,6 +93,9 @@ class DecoderOnlyTransformer(nn.Module):
  causal=True,
  sequence_length=max_seq_len,
  with_cross_attention=with_cross_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
  )
  for _ in range(attn_layers_depth)
  ]
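For DecoderOnlyTransformer the new keywords are handed to every TransformerBlock it stacks. A hedged construction sketch; argument names other than the three new flags follow the existing signature visible in the hunks above and in the 1.4 API, so treat them as assumptions:

```python
# Sketch only (not part of the diff). Constructor keywords before with_cross_attention
# are assumed from the existing MONAI 1.4 DecoderOnlyTransformer signature.
import torch
from monai.networks.nets import DecoderOnlyTransformer

model = DecoderOnlyTransformer(
    num_tokens=256,
    max_seq_len=128,
    attn_layers_dim=64,
    attn_layers_depth=2,
    attn_layers_heads=4,
    include_fc=True,           # keep the output projection in each self-attention block
    use_combined_linear=False, # separate to_q / to_k / to_v projections
    use_flash_attention=True,  # scaled_dot_product_attention backend
)

tokens = torch.randint(0, 256, (1, 128))  # (batch, sequence) of token ids
logits = model(tokens)                    # expected shape: (1, 128, 256) next-token logits
print(logits.shape)
```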
@@ -133,25 +143,15 @@ class DecoderOnlyTransformer(nn.Module):
  # copy over all matching keys
  for k in new_state_dict:
  if k in old_state_dict:
- new_state_dict[k] = old_state_dict[k]
-
- # fix the attention blocks
- attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k]
- for block in attention_blocks:
- new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat(
- [
- old_state_dict[f"{block}.attn.to_q.weight"],
- old_state_dict[f"{block}.attn.to_k.weight"],
- old_state_dict[f"{block}.attn.to_v.weight"],
- ],
- dim=0,
- )
+ new_state_dict[k] = old_state_dict.pop(k)

  # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2
- for k in old_state_dict:
+ for k in list(old_state_dict.keys()):
  if "norm2" in k:
- new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k]
+ new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k)
  if "norm3" in k:
- new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k]
-
+ new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k)
+ if verbose:
+ # print all remaining keys in old_state_dict
+ print("remaining keys in old_state_dict:", old_state_dict.keys())
  self.load_state_dict(new_state_dict)
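Read together with the new `use_combined_linear=False` default, the old separate `to_q`/`to_k`/`to_v` weights can now be copied over directly, so the qkv concatenation step is removed; consumed keys are popped so that, with `verbose=True`, any unmapped leftovers are reported. A hedged usage sketch; the method name and `verbose` flag come from the hunk above, while the checkpoint path and constructor arguments are illustrative only:

```python
# Sketch only (not part of the diff): migrating a MONAI Generative-style checkpoint
# into the 1.4 DecoderOnlyTransformer. Path and constructor values are illustrative.
import torch
from monai.networks.nets import DecoderOnlyTransformer

model = DecoderOnlyTransformer(
    num_tokens=256, max_seq_len=128, attn_layers_dim=64, attn_layers_depth=2, attn_layers_heads=4
)

old_state_dict = torch.load("generative_transformer.pt", map_location="cpu")  # illustrative path
# Keys are popped as they are consumed; verbose=True prints whatever did not transfer,
# which makes missing or renamed weights easy to spot.
model.load_old_state_dict(old_state_dict, verbose=True)
```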