monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/scripts.py +29 -17
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/inferers/utils.py +1 -0
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +22 -0
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +946 -0
- monai/networks/utils.py +4 -4
- monai/transforms/__init__.py +13 -2
- monai/transforms/io/array.py +59 -3
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utility/dictionary.py +4 -0
- monai/transforms/utils.py +230 -1
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/enums.py +1 -0
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
-    "date": "2024-
+    "date": "2024-09-01T02:28:54+0000",
     "dirty": false,
     "error": null,
-    "full-revisionid": "
-    "version": "1.4.
+    "full-revisionid": "d311b1d7b12a95dd7de995b507ffbb5ed413bab6",
+    "version": "1.4.dev2435"
 }
 ''' # END VERSION_JSON
 
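For orientation, the bumped metadata above is what the package reports at runtime; a quick check after installing the new wheel (standard MONAI attribute, shown as a sketch):

import monai

# expected to print "1.4.dev2435" once the new wheel is installed
print(monai.__version__)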
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
CHANGED
@@ -13,25 +13,17 @@ from __future__ import annotations
 
 import gc
 import logging
-from typing import
+from typing import Sequence
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from monai.networks.blocks import Convolution
-from monai.
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
 from monai.utils.type_conversion import convert_to_tensor
 
-AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
-AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
-ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
-
-if TYPE_CHECKING:
-    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
-else:
-    AutoencoderKLType = cast(type, AutoencoderKL)
-
 # Set up logging configuration
 logger = logging.getLogger(__name__)
 
@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
         in_channels: Number of input channels.
         num_channels: Sequence of block output channels.
         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
-        num_res_blocks: Number of residual blocks (see
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
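The two new arguments documented above are forwarded verbatim to SpatialAttentionBlock. A minimal sketch of that block on its own, using only the keyword arguments visible in this diff (the channel and group sizes are illustrative):

import torch
from monai.networks.blocks.spatialattention import SpatialAttentionBlock

# illustrative sizes; norm_num_groups must divide num_channels
block = SpatialAttentionBlock(
    spatial_dims=3,
    num_channels=64,
    norm_num_groups=32,
    norm_eps=1e-6,
    include_fc=False,            # new option: skip the final linear projection
    use_combined_linear=False,   # new option: separate q/k/v projections instead of one fused linear
    use_flash_attention=False,
)
x = torch.rand(1, 64, 8, 8, 8)   # [B, C, H, W, D]
print(block(x).shape)            # the attention is residual, so the input shape is preserved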
@@ -547,6 +541,8 @@ class MaisiEncoder(nn.Module):
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
@@ -603,11 +599,13 @@ class MaisiEncoder(nn.Module):
             input_channel = output_channel
             if attention_levels[i]:
                 blocks.append(
-
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -626,7 +624,7 @@ class MaisiEncoder(nn.Module):
 
         if with_nonlocal_attn:
             blocks.append(
-
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -636,16 +634,18 @@ class MaisiEncoder(nn.Module):
             )
 
             blocks.append(
-
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
         num_channels: Sequence of block output channels.
         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: Number of output channels.
-        num_res_blocks: Number of residual blocks (see
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
@@ -729,6 +731,8 @@ class MaisiDecoder(nn.Module):
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_convtranspose: bool = False,
     ) -> None:
@@ -758,7 +762,7 @@ class MaisiDecoder(nn.Module):
 
         if with_nonlocal_attn:
             blocks.append(
-
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -767,16 +771,18 @@ class MaisiDecoder(nn.Module):
                 )
             )
             blocks.append(
-
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -812,11 +818,13 @@ class MaisiDecoder(nn.Module):
 
             if reversed_attention_levels[i]:
                 blocks.append(
-
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=block_in_ch,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -870,7 +878,7 @@ class MaisiDecoder(nn.Module):
         return x
 
 
-class AutoencoderKlMaisi(
+class AutoencoderKlMaisi(AutoencoderKL):
     """
     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
 
@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: Epsilon for the normalization.
         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_checkpointing: If True, use activation checkpointing.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
@@ -909,6 +919,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: float = 1e-6,
         with_encoder_nonlocal_attn: bool = False,
         with_decoder_nonlocal_attn: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_checkpointing: bool = False,
         use_convtranspose: bool = False,
@@ -930,12 +942,14 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             norm_eps,
             with_encoder_nonlocal_attn,
             with_decoder_nonlocal_attn,
-            use_flash_attention,
             use_checkpointing,
             use_convtranspose,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )
 
-        self.encoder = MaisiEncoder(
+        self.encoder: nn.Module = MaisiEncoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
@@ -945,6 +959,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             num_splits=num_splits,
             dim_split=dim_split,
@@ -953,7 +969,7 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             save_mem=save_mem,
         )
 
-        self.decoder = MaisiDecoder(
+        self.decoder: nn.Module = MaisiDecoder(
             spatial_dims=spatial_dims,
             num_channels=num_channels,
             in_channels=latent_channels,
@@ -963,6 +979,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             use_convtranspose=use_convtranspose,
             num_splits=num_splits,
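The net effect of this file's changes is that the MAISI autoencoder now builds on monai core classes rather than the optional MONAI Generative ("generative") package. A small sketch of what that implies for imports, based only on the hunks above:

# these symbols now resolve inside monai itself; no "generative" install is needed for them
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

# AutoencoderKlMaisi subclasses monai's AutoencoderKL directly (previously a generative-package alias)
assert issubclass(AutoencoderKlMaisi, AutoencoderKL)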
monai/apps/generation/maisi/networks/controlnet_maisi.py
CHANGED
@@ -11,24 +11,15 @@
 
 from __future__ import annotations
 
-from typing import
+from typing import Sequence
 
 import torch
 
-from monai.
+from monai.networks.nets.controlnet import ControlNet
+from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
 
-ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
-)
 
-
-    from generative.networks.nets.controlnet import ControlNet as ControlNetType
-else:
-    ControlNetType = cast(type, ControlNet)
-
-
-class ControlNetMaisi(ControlNetType):
+class ControlNetMaisi(ControlNet):
     """
     Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
     Diffusion Models" (https://arxiv.org/abs/2302.05543)
@@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType):
         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
-        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
         conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
         conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
         use_checkpointing: if True, use activation checkpointing to save memory.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -71,10 +64,12 @@ class ControlNetMaisi(ControlNetType):
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
-        use_flash_attention: bool = False,
         conditioning_embedding_in_channels: int = 1,
-        conditioning_embedding_num_channels: Sequence[int]
+        conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
        use_checkpointing: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__(
             spatial_dims,
@@ -91,9 +86,11 @@ class ControlNetMaisi(ControlNetType):
             cross_attention_dim,
             num_class_embeds,
             upcast_attention,
-            use_flash_attention,
             conditioning_embedding_in_channels,
             conditioning_embedding_num_channels,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )
         self.use_checkpointing = use_checkpointing
 
@@ -105,7 +102,7 @@ class ControlNetMaisi(ControlNetType):
         conditioning_scale: float = 1.0,
         context: torch.Tensor | None = None,
         class_labels: torch.Tensor | None = None,
-    ) -> tuple[
+    ) -> tuple[list[torch.Tensor], torch.Tensor]:
         emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
         h = self._apply_initial_convolution(x)
         if self.use_checkpointing:
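As with the autoencoder, ControlNetMaisi now subclasses monai's own ControlNet, and its forward return type is annotated as tuple[list[torch.Tensor], torch.Tensor] (presumably the down-block residuals plus the mid-block sample). A hedged sketch of the new import surface; the get_timestep_embedding call assumes the usual (timesteps, embedding_dim) signature:

import torch
from monai.networks.nets.controlnet import ControlNet
from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi

assert issubclass(ControlNetMaisi, ControlNet)  # previously an alias pulled from the "generative" package

# sinusoidal timestep embedding helper, now importable from monai core
emb = get_timestep_embedding(torch.tensor([0, 10, 500]), 128)
print(emb.shape)  # expected (3, 128)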
monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
CHANGED
@@ -37,21 +37,15 @@ import torch
 from torch import nn
 
 from monai.networks.blocks import Convolution
-from monai.
-
-
-
-
-
-get_mid_block, has_get_mid_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_mid_block"
-)
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
+from monai.networks.nets.diffusion_model_unet import (
+    get_down_block,
+    get_mid_block,
+    get_timestep_embedding,
+    get_up_block,
+    zero_module,
 )
-
-
-zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
+from monai.utils import ensure_tuple_rep
+from monai.utils.type_conversion import convert_to_tensor
 
 __all__ = ["DiffusionModelUNetMaisi"]
 
@@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: Number of context dimensions to use.
         num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
         upcast_attention: If True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
         include_top_region_index_input: If True, use top region index input.
@@ -102,6 +98,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         dropout_cattn: float = 0.0,
         include_top_region_index_input: bool = False,
@@ -152,9 +150,6 @@ class DiffusionModelUNetMaisi(nn.Module):
                 "`num_channels`."
             )
 
-        if use_flash_attention and not has_xformers:
-            raise ValueError("use_flash_attention is True but xformers is not installed.")
-
         if use_flash_attention is True and not torch.cuda.is_available():
             raise ValueError(
                 "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
@@ -210,7 +205,6 @@ class DiffusionModelUNetMaisi(nn.Module):
             input_channel = output_channel
             output_channel = num_channels[i]
             is_final_block = i == len(num_channels) - 1
-
             down_block = get_down_block(
                 spatial_dims=spatial_dims,
                 in_channels=input_channel,
@@ -227,6 +221,8 @@ class DiffusionModelUNetMaisi(nn.Module):
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
@@ -245,6 +241,8 @@ class DiffusionModelUNetMaisi(nn.Module):
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             dropout_cattn=dropout_cattn,
         )
@@ -280,6 +278,8 @@ class DiffusionModelUNetMaisi(nn.Module):
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
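The UNet helpers (get_down_block, get_mid_block, get_up_block, get_timestep_embedding, zero_module) likewise come from monai core now, and the xformers check is dropped in favour of the CUDA-availability check alone. A quick sketch of one of the re-imported helpers; it assumes zero_module keeps its usual behaviour of zeroing a module's parameters in place and returning it:

import torch.nn as nn
from monai.networks.nets.diffusion_model_unet import zero_module

conv = zero_module(nn.Conv3d(8, 8, kernel_size=1))
# every parameter should now be zero (used for the UNet's zero-initialised convolutions)
assert all(float(p.abs().sum()) == 0.0 for p in conv.parameters())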
monai/apps/vista3d/inferer.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+
+from monai.data.meta_tensor import MetaTensor
+from monai.utils import optional_import
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
+
+__all__ = ["point_based_window_inferer"]
+
+
+def point_based_window_inferer(
+    inputs: torch.Tensor | MetaTensor,
+    roi_size: Sequence[int],
+    predictor: torch.nn.Module,
+    point_coords: torch.Tensor,
+    point_labels: torch.Tensor,
+    class_vector: torch.Tensor | None = None,
+    prompt_class: torch.Tensor | None = None,
+    prev_mask: torch.Tensor | MetaTensor | None = None,
+    point_start: int = 0,
+    center_only: bool = True,
+    margin: int = 5,
+    **kwargs: Any,
+) -> torch.Tensor:
+    """
+    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
+    The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by
+    patch inference and average output stitching, and finally returns the segmented mask.
+
+    Args:
+        inputs: [1CHWD], input image to be processed.
+        roi_size: the spatial window size for inferences.
+            When its components have None or non-positives, the corresponding inputs dimension will be used.
+            if the components of the `roi_size` are non-positive values, the transform will use the
+            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
+            to `(32, 64)` if the second spatial dimension size of img is `64`.
+        sw_batch_size: the batch size to run window slices.
+        predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].
+            Add transpose=True in kwargs for vista3d.
+        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.
+        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
+            2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
+        class_vector: [B]. Used for class-head automatic segmentation. Can be None value.
+        prompt_class: [B]. The same as class_vector representing the point class and inform point head about
+            supported class or zeroshot, not used for automatic segmentation. If None, point head is default
+            to supported class segmentation.
+        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
+        point_start: only use points starting from this number. All points before this number is used to generate
+            prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask.
+        center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point.
+        margin: if center_only is false, this value is the distance between point to the patch boundary.
+    Returns:
+        stitched_output: [1, B, H, W, D]. The value is before sigmoid.
+    Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
+    """
+    if not point_coords.shape[0] == 1:
+        raise ValueError("Only supports single object point click.")
+    if not len(inputs.shape) == 5:
+        raise ValueError("Input image should be 5D.")
+    image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
+    point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
+    prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
+    stitched_output = None
+    for p in point_coords[0][point_start:]:
+        lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
+        ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
+        lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
+        for i in range(len(lx_)):
+            for j in range(len(ly_)):
+                for k in range(len(lz_)):
+                    lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
+                    unravel_slice = [
+                        slice(None),
+                        slice(None),
+                        slice(int(lx), int(rx)),
+                        slice(int(ly), int(ry)),
+                        slice(int(lz), int(rz)),
+                    ]
+                    batch_image = image[unravel_slice]
+                    output = predictor(
+                        batch_image,
+                        point_coords=point_coords,
+                        point_labels=point_labels,
+                        class_vector=class_vector,
+                        prompt_class=prompt_class,
+                        patch_coords=unravel_slice,
+                        prev_mask=prev_mask,
+                        **kwargs,
+                    )
+                    if stitched_output is None:
+                        stitched_output = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                        stitched_mask = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                    stitched_output[unravel_slice] += output.to("cpu")
+                    stitched_mask[unravel_slice] = 1
+    # if stitched_mask is 0, then NaN value
+    stitched_output = stitched_output / stitched_mask
+    # revert padding
+    stitched_output = stitched_output[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    stitched_mask = stitched_mask[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    if prev_mask is not None:
+        prev_mask = prev_mask[
+            :,
+            :,
+            pad[4] : image.shape[-3] - pad[5],
+            pad[2] : image.shape[-2] - pad[3],
+            pad[0] : image.shape[-1] - pad[1],
+        ]
+        prev_mask = prev_mask.to("cpu")  # type: ignore
+        # for un-calculated place, use previous mask
+        stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]
+    if isinstance(inputs, torch.Tensor):
+        inputs = MetaTensor(inputs)
+    if not hasattr(stitched_output, "meta"):
+        stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
+    return stitched_output
+
+
+def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
+    """Helper function to get the window index."""
+    if p - roi // 2 < 0:
+        left, right = 0, roi
+    elif p + roi // 2 > s:
+        left, right = s - roi, s
+    else:
+        left, right = int(p) - roi // 2, int(p) + roi // 2
+    return left, right
+
+
+def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
+    """Get the window index."""
+    left, right = _get_window_idx_c(p, roi, s)
+    if center_only:
+        return [left], [right]
+    left_most = max(0, p - roi + margin)
+    right_most = min(s, p + roi - margin)
+    left_list = [left_most, right_most - roi, left]
+    right_list = [left_most + roi, right_most, right]
+    return left_list, right_list
+
+
+def _pad_previous_mask(
+    inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
+) -> tuple[torch.Tensor | MetaTensor, list[int]]:
+    """Helper function to pad inputs."""
+    pad_size = []
+    for k in range(len(inputs.shape) - 1, 1, -1):
+        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
+        half = diff // 2
+        pad_size.extend([half, diff - half])
+    if any(pad_size):
+        inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue)  # type: ignore
+    return inputs, pad_size
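A hedged usage sketch for the new inferer. The predictor below is a stand-in that just returns zero logits of the right shape; a real run would pass a VISTA-3D style model (with transpose=True in kwargs, per the docstring) and real click coordinates:

import torch
from monai.apps.vista3d.inferer import point_based_window_inferer

class _DummyPredictor(torch.nn.Module):
    def forward(self, x, point_coords=None, point_labels=None, **kwargs):
        # zero logits of shape [B=1, 1, *patch spatial dims], standing in for a real model
        return torch.zeros(1, 1, *x.shape[2:])

image = torch.rand(1, 1, 96, 96, 96)            # [1, C, H, W, D]
points = torch.tensor([[[48.0, 48.0, 48.0]]])   # [B=1, N=1, 3], one click near the centre
labels = torch.tensor([[1]])                    # [B=1, N=1], positive click

logits = point_based_window_inferer(
    inputs=image,
    roi_size=(64, 64, 64),
    predictor=_DummyPredictor(),
    point_coords=points,
    point_labels=labels,
)
print(logits.shape)  # [1, 1, 96, 96, 96]; values are pre-sigmoid and NaN outside the visited patches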