monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/apps/vista3d/inferer.py +177 -0
  7. monai/apps/vista3d/sampler.py +179 -0
  8. monai/apps/vista3d/transforms.py +224 -0
  9. monai/bundle/scripts.py +29 -17
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/inferers/utils.py +1 -0
  13. monai/losses/__init__.py +1 -0
  14. monai/losses/dice.py +10 -1
  15. monai/losses/nacl_loss.py +139 -0
  16. monai/networks/blocks/crossattention.py +48 -26
  17. monai/networks/blocks/mlp.py +16 -4
  18. monai/networks/blocks/selfattention.py +75 -23
  19. monai/networks/blocks/spatialattention.py +16 -1
  20. monai/networks/blocks/transformerblock.py +17 -2
  21. monai/networks/layers/filtering.py +6 -2
  22. monai/networks/nets/__init__.py +2 -1
  23. monai/networks/nets/autoencoderkl.py +55 -22
  24. monai/networks/nets/cell_sam_wrapper.py +92 -0
  25. monai/networks/nets/controlnet.py +24 -22
  26. monai/networks/nets/diffusion_model_unet.py +159 -19
  27. monai/networks/nets/segresnet_ds.py +127 -1
  28. monai/networks/nets/spade_autoencoderkl.py +22 -0
  29. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  30. monai/networks/nets/transformer.py +17 -17
  31. monai/networks/nets/vista3d.py +946 -0
  32. monai/networks/utils.py +4 -4
  33. monai/transforms/__init__.py +13 -2
  34. monai/transforms/io/array.py +59 -3
  35. monai/transforms/io/dictionary.py +29 -2
  36. monai/transforms/spatial/functional.py +1 -1
  37. monai/transforms/transform.py +2 -2
  38. monai/transforms/utility/dictionary.py +4 -0
  39. monai/transforms/utils.py +230 -1
  40. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  41. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  42. monai/utils/enums.py +1 -0
  43. monai/utils/module.py +7 -6
  44. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
  45. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
  46. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
  47. /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
  48. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
  49. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/__init__.py CHANGED
@@ -93,4 +93,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "56ee32e36c5c0c7a5cb10afa4ec5589c81171e6b"
+__commit_id__ = "fa1ef8be157d5eb96de17aa78642384f68d99397"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-08-04T02:19:41+0000",
+ "date": "2024-09-01T02:28:54+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "951a77d7a7737a3108afa94623a50b87d21eb4a7",
- "version": "1.4.dev2431"
+ "full-revisionid": "d311b1d7b12a95dd7de995b507ffbb5ed413bab6",
+ "version": "1.4.dev2435"
 }
 ''' # END VERSION_JSON
 
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py CHANGED
@@ -13,25 +13,17 @@ from __future__ import annotations
 
 import gc
 import logging
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import Sequence
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from monai.networks.blocks import Convolution
-from monai.utils import optional_import
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
 from monai.utils.type_conversion import convert_to_tensor
 
-AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
-AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
-ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
-
-if TYPE_CHECKING:
-    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
-else:
-    AutoencoderKLType = cast(type, AutoencoderKL)
-
 # Set up logging configuration
 logger = logging.getLogger(__name__)
 
@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
         in_channels: Number of input channels.
         num_channels: Sequence of block output channels.
         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
@@ -547,6 +541,8 @@ class MaisiEncoder(nn.Module):
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
@@ -603,11 +599,13 @@
             input_channel = output_channel
             if attention_levels[i]:
                 blocks.append(
-                    AttentionBlock(
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -626,7 +624,7 @@
 
         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -636,16 +634,18 @@
             )
 
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
         num_channels: Sequence of block output channels.
         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: Number of output channels.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
@@ -729,6 +731,8 @@
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_convtranspose: bool = False,
     ) -> None:
@@ -758,7 +762,7 @@
 
         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -767,16 +771,18 @@
                 )
             )
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -812,11 +818,13 @@
 
             if reversed_attention_levels[i]:
                 blocks.append(
-                    AttentionBlock(
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=block_in_ch,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -870,7 +878,7 @@
         return x
 
 
-class AutoencoderKlMaisi(AutoencoderKLType):
+class AutoencoderKlMaisi(AutoencoderKL):
     """
     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
 
@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: Epsilon for the normalization.
         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_checkpointing: If True, use activation checkpointing.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
@@ -909,6 +919,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: float = 1e-6,
         with_encoder_nonlocal_attn: bool = False,
         with_decoder_nonlocal_attn: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_checkpointing: bool = False,
         use_convtranspose: bool = False,
@@ -930,12 +942,14 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             norm_eps,
             with_encoder_nonlocal_attn,
             with_decoder_nonlocal_attn,
-            use_flash_attention,
             use_checkpointing,
             use_convtranspose,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )
 
-        self.encoder = MaisiEncoder(
+        self.encoder: nn.Module = MaisiEncoder(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_channels=num_channels,
@@ -945,6 +959,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
            norm_eps=norm_eps,
            attention_levels=attention_levels,
            with_nonlocal_attn=with_encoder_nonlocal_attn,
+           include_fc=include_fc,
+           use_combined_linear=use_combined_linear,
            use_flash_attention=use_flash_attention,
            num_splits=num_splits,
            dim_split=dim_split,
@@ -953,7 +969,7 @@ class AutoencoderKlMaisi(AutoencoderKLType):
            save_mem=save_mem,
         )
 
-        self.decoder = MaisiDecoder(
+        self.decoder: nn.Module = MaisiDecoder(
            spatial_dims=spatial_dims,
            num_channels=num_channels,
            in_channels=latent_channels,
@@ -963,6 +979,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
            norm_eps=norm_eps,
            attention_levels=attention_levels,
            with_nonlocal_attn=with_decoder_nonlocal_attn,
+           include_fc=include_fc,
+           use_combined_linear=use_combined_linear,
            use_flash_attention=use_flash_attention,
            use_convtranspose=use_convtranspose,
            num_splits=num_splits,
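For downstream code, the practical effect of the changes above is that AutoencoderKlMaisi now subclasses MONAI's own AutoencoderKL (the optional `generative` package is no longer imported) and gains two attention-related keyword arguments, `include_fc` and `use_combined_linear`. The following is a minimal construction sketch: the keyword names are taken from the diff, while all numeric values are illustrative placeholders rather than MAISI bundle defaults.

    from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

    # illustrative hyperparameters only; the keyword names come from the signature shown above
    autoencoder = AutoencoderKlMaisi(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        latent_channels=4,
        num_channels=(64, 128, 256),
        num_res_blocks=(2, 2, 2),
        attention_levels=(False, False, False),
        norm_num_groups=32,
        norm_eps=1e-6,
        num_splits=8,                # number of splits for the input tensor, per the docstring
        dim_split=0,                 # dimension along which the input tensor is split
        include_fc=False,            # new in 1.4.dev2435: final linear layer of the attention block
        use_combined_linear=False,   # new in 1.4.dev2435: single vs. separate qkv projections
        use_flash_attention=False,
    )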
monai/apps/generation/maisi/networks/controlnet_maisi.py CHANGED
@@ -11,24 +11,15 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import Sequence
 
 import torch
 
-from monai.utils import optional_import
+from monai.networks.nets.controlnet import ControlNet
+from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
 
-ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
-)
 
-if TYPE_CHECKING:
-    from generative.networks.nets.controlnet import ControlNet as ControlNetType
-else:
-    ControlNetType = cast(type, ControlNet)
-
-
-class ControlNetMaisi(ControlNetType):
+class ControlNetMaisi(ControlNet):
     """
     Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
     Diffusion Models" (https://arxiv.org/abs/2302.05543)
@@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType):
         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
-        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
         conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
         conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
         use_checkpointing: if True, use activation checkpointing to save memory.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -71,10 +64,12 @@ class ControlNetMaisi(ControlNetType):
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
-        use_flash_attention: bool = False,
         conditioning_embedding_in_channels: int = 1,
-        conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
+        conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
         use_checkpointing: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__(
             spatial_dims,
@@ -91,9 +86,11 @@ class ControlNetMaisi(ControlNetType):
             cross_attention_dim,
             num_class_embeds,
             upcast_attention,
-            use_flash_attention,
             conditioning_embedding_in_channels,
             conditioning_embedding_num_channels,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )
         self.use_checkpointing = use_checkpointing
 
@@ -105,7 +102,7 @@ class ControlNetMaisi(ControlNetType):
         conditioning_scale: float = 1.0,
         context: torch.Tensor | None = None,
         class_labels: torch.Tensor | None = None,
-    ) -> tuple[Sequence[torch.Tensor], torch.Tensor]:
+    ) -> tuple[list[torch.Tensor], torch.Tensor]:
         emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
         h = self._apply_initial_convolution(x)
         if self.use_checkpointing:
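Note that `use_flash_attention` moved from the middle of the constructor signature to the end, after the new `include_fc` and `use_combined_linear` arguments, so callers that passed it positionally against 1.4.dev2431 would now bind a different parameter; passing these flags by keyword sidesteps the issue. The forward() return annotation also tightened from tuple[Sequence[torch.Tensor], torch.Tensor] to tuple[list[torch.Tensor], torch.Tensor], i.e. a list of down-block residuals plus the mid-block residual. A hedged construction sketch follows: only the keyword names visible in the diff are taken from it, `in_channels` and the remaining defaults are assumed to come from the base ControlNet signature, and all values are illustrative.

    from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi

    # illustrative values; keyword form avoids the positional reordering of use_flash_attention
    controlnet = ControlNetMaisi(
        spatial_dims=3,
        in_channels=4,                                           # assumed base-class argument
        conditioning_embedding_in_channels=1,
        conditioning_embedding_num_channels=(16, 32, 96, 256),
        use_checkpointing=True,
        include_fc=False,                                        # new argument
        use_combined_linear=False,                               # new argument
        use_flash_attention=False,                               # now last in the signature
    )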
monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py CHANGED
@@ -37,21 +37,15 @@ import torch
 from torch import nn
 
 from monai.networks.blocks import Convolution
-from monai.utils import ensure_tuple_rep, optional_import
-from monai.utils.type_conversion import convert_to_tensor
-
-get_down_block, has_get_down_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_down_block"
-)
-get_mid_block, has_get_mid_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_mid_block"
-)
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
+from monai.networks.nets.diffusion_model_unet import (
+    get_down_block,
+    get_mid_block,
+    get_timestep_embedding,
+    get_up_block,
+    zero_module,
 )
-get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block")
-xformers, has_xformers = optional_import("xformers")
-zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
+from monai.utils import ensure_tuple_rep
+from monai.utils.type_conversion import convert_to_tensor
 
 __all__ = ["DiffusionModelUNetMaisi"]
 
@@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: Number of context dimensions to use.
         num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
         upcast_attention: If True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
         include_top_region_index_input: If True, use top region index input.
@@ -102,6 +98,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         dropout_cattn: float = 0.0,
         include_top_region_index_input: bool = False,
@@ -152,9 +150,6 @@ class DiffusionModelUNetMaisi(nn.Module):
                 "`num_channels`."
             )
 
-        if use_flash_attention and not has_xformers:
-            raise ValueError("use_flash_attention is True but xformers is not installed.")
-
         if use_flash_attention is True and not torch.cuda.is_available():
             raise ValueError(
                 "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
@@ -210,7 +205,6 @@ class DiffusionModelUNetMaisi(nn.Module):
             input_channel = output_channel
             output_channel = num_channels[i]
             is_final_block = i == len(num_channels) - 1
-
             down_block = get_down_block(
                 spatial_dims=spatial_dims,
                 in_channels=input_channel,
@@ -227,6 +221,8 @@ class DiffusionModelUNetMaisi(nn.Module):
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
@@ -245,6 +241,8 @@ class DiffusionModelUNetMaisi(nn.Module):
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             dropout_cattn=dropout_cattn,
         )
@@ -280,6 +278,8 @@ class DiffusionModelUNetMaisi(nn.Module):
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
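For code layered on top of these MAISI networks, the import migration shown above applies the same way: the diffusion UNet helpers now live in MONAI core rather than in the optional `generative` package. A before/after sketch mirroring the removed and added import lines:

    # 1.4.dev2431 and earlier (required the external `generative` package)
    # from monai.utils import optional_import
    # get_down_block, has_get_down_block = optional_import(
    #     "generative.networks.nets.diffusion_model_unet", name="get_down_block"
    # )

    # 1.4.dev2435: import directly from MONAI
    from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding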
monai/apps/vista3d/inferer.py ADDED
@@ -0,0 +1,177 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+
+from monai.data.meta_tensor import MetaTensor
+from monai.utils import optional_import
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
+
+__all__ = ["point_based_window_inferer"]
+
+
+def point_based_window_inferer(
+    inputs: torch.Tensor | MetaTensor,
+    roi_size: Sequence[int],
+    predictor: torch.nn.Module,
+    point_coords: torch.Tensor,
+    point_labels: torch.Tensor,
+    class_vector: torch.Tensor | None = None,
+    prompt_class: torch.Tensor | None = None,
+    prev_mask: torch.Tensor | MetaTensor | None = None,
+    point_start: int = 0,
+    center_only: bool = True,
+    margin: int = 5,
+    **kwargs: Any,
+) -> torch.Tensor:
+    """
+    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
+    The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by
+    patch inference and average output stitching, and finally returns the segmented mask.
+
+    Args:
+        inputs: [1CHWD], input image to be processed.
+        roi_size: the spatial window size for inferences.
+            When its components have None or non-positives, the corresponding inputs dimension will be used.
+            if the components of the `roi_size` are non-positive values, the transform will use the
+            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
+            to `(32, 64)` if the second spatial dimension size of img is `64`.
+        sw_batch_size: the batch size to run window slices.
+        predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].
+            Add transpose=True in kwargs for vista3d.
+        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.
+        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
+            2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
+        class_vector: [B]. Used for class-head automatic segmentation. Can be None value.
+        prompt_class: [B]. The same as class_vector representing the point class and inform point head about
+            supported class or zeroshot, not used for automatic segmentation. If None, point head is default
+            to supported class segmentation.
+        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
+        point_start: only use points starting from this number. All points before this number is used to generate
+            prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask.
+        center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point.
+        margin: if center_only is false, this value is the distance between point to the patch boundary.
+    Returns:
+        stitched_output: [1, B, H, W, D]. The value is before sigmoid.
+    Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
+    """
+    if not point_coords.shape[0] == 1:
+        raise ValueError("Only supports single object point click.")
+    if not len(inputs.shape) == 5:
+        raise ValueError("Input image should be 5D.")
+    image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
+    point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
+    prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
+    stitched_output = None
+    for p in point_coords[0][point_start:]:
+        lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
+        ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
+        lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
+        for i in range(len(lx_)):
+            for j in range(len(ly_)):
+                for k in range(len(lz_)):
+                    lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
+                    unravel_slice = [
+                        slice(None),
+                        slice(None),
+                        slice(int(lx), int(rx)),
+                        slice(int(ly), int(ry)),
+                        slice(int(lz), int(rz)),
+                    ]
+                    batch_image = image[unravel_slice]
+                    output = predictor(
+                        batch_image,
+                        point_coords=point_coords,
+                        point_labels=point_labels,
+                        class_vector=class_vector,
+                        prompt_class=prompt_class,
+                        patch_coords=unravel_slice,
+                        prev_mask=prev_mask,
+                        **kwargs,
+                    )
+                    if stitched_output is None:
+                        stitched_output = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                        stitched_mask = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                    stitched_output[unravel_slice] += output.to("cpu")
+                    stitched_mask[unravel_slice] = 1
+    # if stitched_mask is 0, then NaN value
+    stitched_output = stitched_output / stitched_mask
+    # revert padding
+    stitched_output = stitched_output[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    stitched_mask = stitched_mask[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    if prev_mask is not None:
+        prev_mask = prev_mask[
+            :,
+            :,
+            pad[4] : image.shape[-3] - pad[5],
+            pad[2] : image.shape[-2] - pad[3],
+            pad[0] : image.shape[-1] - pad[1],
+        ]
+        prev_mask = prev_mask.to("cpu")  # type: ignore
+        # for un-calculated place, use previous mask
+        stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]
+    if isinstance(inputs, torch.Tensor):
+        inputs = MetaTensor(inputs)
+    if not hasattr(stitched_output, "meta"):
+        stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
+    return stitched_output
+
+
+def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
+    """Helper function to get the window index."""
+    if p - roi // 2 < 0:
+        left, right = 0, roi
+    elif p + roi // 2 > s:
+        left, right = s - roi, s
+    else:
+        left, right = int(p) - roi // 2, int(p) + roi // 2
+    return left, right
+
+
+def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
+    """Get the window index."""
+    left, right = _get_window_idx_c(p, roi, s)
+    if center_only:
+        return [left], [right]
+    left_most = max(0, p - roi + margin)
+    right_most = min(s, p + roi - margin)
+    left_list = [left_most, right_most - roi, left]
+    right_list = [left_most + roi, right_most, right]
+    return left_list, right_list
+
+
+def _pad_previous_mask(
+    inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
+) -> tuple[torch.Tensor | MetaTensor, list[int]]:
+    """Helper function to pad inputs."""
+    pad_size = []
+    for k in range(len(inputs.shape) - 1, 1, -1):
+        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
+        half = diff // 2
+        pad_size.extend([half, diff - half])
+    if any(pad_size):
+        inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue)  # type: ignore
+    return inputs, pad_size
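To make the calling convention of the new inferer concrete, here is a minimal, self-contained sketch with a stand-in predictor; in real use the predictor would be a VISTA3D network (typically with transpose=True forwarded through kwargs, as the docstring notes). Shapes follow the documented layout: a [1, C, H, W, D] image, [B, N, 3] point coordinates, and [B, N] point labels with B=1.

    import torch

    from monai.apps.vista3d.inferer import point_based_window_inferer


    class _ToyPredictor(torch.nn.Module):
        # stand-in for a VISTA3D network: returns one pre-sigmoid channel per prompted object
        def forward(self, x, point_coords=None, point_labels=None, **kwargs):
            return torch.zeros(1, 1, *x.shape[2:])


    image = torch.rand(1, 1, 128, 128, 128)          # [1, C, H, W, D]
    point_coords = torch.tensor([[[64, 64, 64]]])    # [B=1, N=1, 3], one click
    point_labels = torch.tensor([[1]])               # [B=1, N=1], 1 = positive point

    logits = point_based_window_inferer(
        inputs=image,
        roi_size=(96, 96, 96),
        predictor=_ToyPredictor(),
        point_coords=point_coords,
        point_labels=point_labels,
    )
    mask = torch.sigmoid(logits) > 0.5               # the stitched output is before sigmoid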