monai-weekly 1.4.dev2430__py3-none-any.whl → 1.4.dev2434__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/bundle/config_parser.py +2 -2
- monai/bundle/reference_resolver.py +18 -1
- monai/bundle/scripts.py +45 -22
- monai/bundle/utils.py +3 -1
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +24 -2
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +908 -0
- monai/networks/utils.py +3 -3
- monai/transforms/__init__.py +1 -0
- monai/transforms/io/array.py +1 -1
- monai/transforms/post/array.py +2 -1
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utils.py +183 -0
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +44 -41
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2024-
|
11
|
+
"date": "2024-08-25T02:21:56+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.4.
|
14
|
+
"full-revisionid": "dc611d231ba670004b1da1b011fe140375fb91af",
|
15
|
+
"version": "1.4.dev2434"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
@@ -13,25 +13,17 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
import gc
|
15
15
|
import logging
|
16
|
-
from typing import
|
16
|
+
from typing import Sequence
|
17
17
|
|
18
18
|
import torch
|
19
19
|
import torch.nn as nn
|
20
20
|
import torch.nn.functional as F
|
21
21
|
|
22
22
|
from monai.networks.blocks import Convolution
|
23
|
-
from monai.
|
23
|
+
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
|
24
|
+
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
|
24
25
|
from monai.utils.type_conversion import convert_to_tensor
|
25
26
|
|
26
|
-
AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
|
27
|
-
AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
|
28
|
-
ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
|
29
|
-
|
30
|
-
if TYPE_CHECKING:
|
31
|
-
from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
|
32
|
-
else:
|
33
|
-
AutoencoderKLType = cast(type, AutoencoderKL)
|
34
|
-
|
35
27
|
# Set up logging configuration
|
36
28
|
logger = logging.getLogger(__name__)
|
37
29
|
|
@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
|
|
518
510
|
in_channels: Number of input channels.
|
519
511
|
num_channels: Sequence of block output channels.
|
520
512
|
out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
|
521
|
-
num_res_blocks: Number of residual blocks (see
|
513
|
+
num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
|
522
514
|
norm_num_groups: Number of groups for the group norm layers.
|
523
515
|
norm_eps: Epsilon for the normalization.
|
524
516
|
attention_levels: Indicate which level from num_channels contain an attention block.
|
525
517
|
with_nonlocal_attn: If True, use non-local attention block.
|
518
|
+
include_fc: whether to include the final linear layer in the attention block. Default to False.
|
519
|
+
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
|
526
520
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
527
521
|
num_splits: Number of splits for the input tensor.
|
528
522
|
dim_split: Dimension of splitting for the input tensor.
|
@@ -547,6 +541,8 @@ class MaisiEncoder(nn.Module):
|
|
547
541
|
print_info: bool = False,
|
548
542
|
save_mem: bool = True,
|
549
543
|
with_nonlocal_attn: bool = True,
|
544
|
+
include_fc: bool = False,
|
545
|
+
use_combined_linear: bool = False,
|
550
546
|
use_flash_attention: bool = False,
|
551
547
|
) -> None:
|
552
548
|
super().__init__()
|
@@ -603,11 +599,13 @@ class MaisiEncoder(nn.Module):
|
|
603
599
|
input_channel = output_channel
|
604
600
|
if attention_levels[i]:
|
605
601
|
blocks.append(
|
606
|
-
|
602
|
+
SpatialAttentionBlock(
|
607
603
|
spatial_dims=spatial_dims,
|
608
604
|
num_channels=input_channel,
|
609
605
|
norm_num_groups=norm_num_groups,
|
610
606
|
norm_eps=norm_eps,
|
607
|
+
include_fc=include_fc,
|
608
|
+
use_combined_linear=use_combined_linear,
|
611
609
|
use_flash_attention=use_flash_attention,
|
612
610
|
)
|
613
611
|
)
|
@@ -626,7 +624,7 @@ class MaisiEncoder(nn.Module):
|
|
626
624
|
|
627
625
|
if with_nonlocal_attn:
|
628
626
|
blocks.append(
|
629
|
-
|
627
|
+
AEKLResBlock(
|
630
628
|
spatial_dims=spatial_dims,
|
631
629
|
in_channels=num_channels[-1],
|
632
630
|
norm_num_groups=norm_num_groups,
|
@@ -636,16 +634,18 @@ class MaisiEncoder(nn.Module):
|
|
636
634
|
)
|
637
635
|
|
638
636
|
blocks.append(
|
639
|
-
|
637
|
+
SpatialAttentionBlock(
|
640
638
|
spatial_dims=spatial_dims,
|
641
639
|
num_channels=num_channels[-1],
|
642
640
|
norm_num_groups=norm_num_groups,
|
643
641
|
norm_eps=norm_eps,
|
642
|
+
include_fc=include_fc,
|
643
|
+
use_combined_linear=use_combined_linear,
|
644
644
|
use_flash_attention=use_flash_attention,
|
645
645
|
)
|
646
646
|
)
|
647
647
|
blocks.append(
|
648
|
-
|
648
|
+
AEKLResBlock(
|
649
649
|
spatial_dims=spatial_dims,
|
650
650
|
in_channels=num_channels[-1],
|
651
651
|
norm_num_groups=norm_num_groups,
|
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
|
|
699
699
|
num_channels: Sequence of block output channels.
|
700
700
|
in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
|
701
701
|
out_channels: Number of output channels.
|
702
|
-
num_res_blocks: Number of residual blocks (see
|
702
|
+
num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
|
703
703
|
norm_num_groups: Number of groups for the group norm layers.
|
704
704
|
norm_eps: Epsilon for the normalization.
|
705
705
|
attention_levels: Indicate which level from num_channels contain an attention block.
|
706
706
|
with_nonlocal_attn: If True, use non-local attention block.
|
707
|
+
include_fc: whether to include the final linear layer in the attention block. Default to False.
|
708
|
+
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
|
707
709
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
708
710
|
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
|
709
711
|
num_splits: Number of splits for the input tensor.
|
@@ -729,6 +731,8 @@ class MaisiDecoder(nn.Module):
|
|
729
731
|
print_info: bool = False,
|
730
732
|
save_mem: bool = True,
|
731
733
|
with_nonlocal_attn: bool = True,
|
734
|
+
include_fc: bool = False,
|
735
|
+
use_combined_linear: bool = False,
|
732
736
|
use_flash_attention: bool = False,
|
733
737
|
use_convtranspose: bool = False,
|
734
738
|
) -> None:
|
@@ -758,7 +762,7 @@ class MaisiDecoder(nn.Module):
|
|
758
762
|
|
759
763
|
if with_nonlocal_attn:
|
760
764
|
blocks.append(
|
761
|
-
|
765
|
+
AEKLResBlock(
|
762
766
|
spatial_dims=spatial_dims,
|
763
767
|
in_channels=reversed_block_out_channels[0],
|
764
768
|
norm_num_groups=norm_num_groups,
|
@@ -767,16 +771,18 @@ class MaisiDecoder(nn.Module):
|
|
767
771
|
)
|
768
772
|
)
|
769
773
|
blocks.append(
|
770
|
-
|
774
|
+
SpatialAttentionBlock(
|
771
775
|
spatial_dims=spatial_dims,
|
772
776
|
num_channels=reversed_block_out_channels[0],
|
773
777
|
norm_num_groups=norm_num_groups,
|
774
778
|
norm_eps=norm_eps,
|
779
|
+
include_fc=include_fc,
|
780
|
+
use_combined_linear=use_combined_linear,
|
775
781
|
use_flash_attention=use_flash_attention,
|
776
782
|
)
|
777
783
|
)
|
778
784
|
blocks.append(
|
779
|
-
|
785
|
+
AEKLResBlock(
|
780
786
|
spatial_dims=spatial_dims,
|
781
787
|
in_channels=reversed_block_out_channels[0],
|
782
788
|
norm_num_groups=norm_num_groups,
|
@@ -812,11 +818,13 @@ class MaisiDecoder(nn.Module):
|
|
812
818
|
|
813
819
|
if reversed_attention_levels[i]:
|
814
820
|
blocks.append(
|
815
|
-
|
821
|
+
SpatialAttentionBlock(
|
816
822
|
spatial_dims=spatial_dims,
|
817
823
|
num_channels=block_in_ch,
|
818
824
|
norm_num_groups=norm_num_groups,
|
819
825
|
norm_eps=norm_eps,
|
826
|
+
include_fc=include_fc,
|
827
|
+
use_combined_linear=use_combined_linear,
|
820
828
|
use_flash_attention=use_flash_attention,
|
821
829
|
)
|
822
830
|
)
|
@@ -870,7 +878,7 @@ class MaisiDecoder(nn.Module):
|
|
870
878
|
return x
|
871
879
|
|
872
880
|
|
873
|
-
class AutoencoderKlMaisi(
|
881
|
+
class AutoencoderKlMaisi(AutoencoderKL):
|
874
882
|
"""
|
875
883
|
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
|
876
884
|
|
@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
886
894
|
norm_eps: Epsilon for the normalization.
|
887
895
|
with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
|
888
896
|
with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
|
897
|
+
include_fc: whether to include the final linear layer. Default to False.
|
898
|
+
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
|
889
899
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
890
900
|
use_checkpointing: If True, use activation checkpointing.
|
891
901
|
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
|
@@ -909,6 +919,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
909
919
|
norm_eps: float = 1e-6,
|
910
920
|
with_encoder_nonlocal_attn: bool = False,
|
911
921
|
with_decoder_nonlocal_attn: bool = False,
|
922
|
+
include_fc: bool = False,
|
923
|
+
use_combined_linear: bool = False,
|
912
924
|
use_flash_attention: bool = False,
|
913
925
|
use_checkpointing: bool = False,
|
914
926
|
use_convtranspose: bool = False,
|
@@ -930,12 +942,14 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
930
942
|
norm_eps,
|
931
943
|
with_encoder_nonlocal_attn,
|
932
944
|
with_decoder_nonlocal_attn,
|
933
|
-
use_flash_attention,
|
934
945
|
use_checkpointing,
|
935
946
|
use_convtranspose,
|
947
|
+
include_fc,
|
948
|
+
use_combined_linear,
|
949
|
+
use_flash_attention,
|
936
950
|
)
|
937
951
|
|
938
|
-
self.encoder = MaisiEncoder(
|
952
|
+
self.encoder: nn.Module = MaisiEncoder(
|
939
953
|
spatial_dims=spatial_dims,
|
940
954
|
in_channels=in_channels,
|
941
955
|
num_channels=num_channels,
|
@@ -945,6 +959,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
945
959
|
norm_eps=norm_eps,
|
946
960
|
attention_levels=attention_levels,
|
947
961
|
with_nonlocal_attn=with_encoder_nonlocal_attn,
|
962
|
+
include_fc=include_fc,
|
963
|
+
use_combined_linear=use_combined_linear,
|
948
964
|
use_flash_attention=use_flash_attention,
|
949
965
|
num_splits=num_splits,
|
950
966
|
dim_split=dim_split,
|
@@ -953,7 +969,7 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
953
969
|
save_mem=save_mem,
|
954
970
|
)
|
955
971
|
|
956
|
-
self.decoder = MaisiDecoder(
|
972
|
+
self.decoder: nn.Module = MaisiDecoder(
|
957
973
|
spatial_dims=spatial_dims,
|
958
974
|
num_channels=num_channels,
|
959
975
|
in_channels=latent_channels,
|
@@ -963,6 +979,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
963
979
|
norm_eps=norm_eps,
|
964
980
|
attention_levels=attention_levels,
|
965
981
|
with_nonlocal_attn=with_decoder_nonlocal_attn,
|
982
|
+
include_fc=include_fc,
|
983
|
+
use_combined_linear=use_combined_linear,
|
966
984
|
use_flash_attention=use_flash_attention,
|
967
985
|
use_convtranspose=use_convtranspose,
|
968
986
|
num_splits=num_splits,
|
@@ -11,24 +11,15 @@
|
|
11
11
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
|
-
from typing import
|
14
|
+
from typing import Sequence
|
15
15
|
|
16
16
|
import torch
|
17
17
|
|
18
|
-
from monai.
|
18
|
+
from monai.networks.nets.controlnet import ControlNet
|
19
|
+
from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
|
19
20
|
|
20
|
-
ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
|
21
|
-
get_timestep_embedding, has_get_timestep_embedding = optional_import(
|
22
|
-
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
|
23
|
-
)
|
24
21
|
|
25
|
-
|
26
|
-
from generative.networks.nets.controlnet import ControlNet as ControlNetType
|
27
|
-
else:
|
28
|
-
ControlNetType = cast(type, ControlNet)
|
29
|
-
|
30
|
-
|
31
|
-
class ControlNetMaisi(ControlNetType):
|
22
|
+
class ControlNetMaisi(ControlNet):
|
32
23
|
"""
|
33
24
|
Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
|
34
25
|
Diffusion Models" (https://arxiv.org/abs/2302.05543)
|
@@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType):
|
|
49
40
|
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
|
50
41
|
classes.
|
51
42
|
upcast_attention: if True, upcast attention operations to full precision.
|
52
|
-
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
|
53
43
|
conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
|
54
44
|
conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
|
55
45
|
use_checkpointing: if True, use activation checkpointing to save memory.
|
46
|
+
include_fc: whether to include the final linear layer. Default to False.
|
47
|
+
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
|
48
|
+
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
|
56
49
|
"""
|
57
50
|
|
58
51
|
def __init__(
|
@@ -71,10 +64,12 @@ class ControlNetMaisi(ControlNetType):
|
|
71
64
|
cross_attention_dim: int | None = None,
|
72
65
|
num_class_embeds: int | None = None,
|
73
66
|
upcast_attention: bool = False,
|
74
|
-
use_flash_attention: bool = False,
|
75
67
|
conditioning_embedding_in_channels: int = 1,
|
76
|
-
conditioning_embedding_num_channels: Sequence[int]
|
68
|
+
conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
|
77
69
|
use_checkpointing: bool = True,
|
70
|
+
include_fc: bool = False,
|
71
|
+
use_combined_linear: bool = False,
|
72
|
+
use_flash_attention: bool = False,
|
78
73
|
) -> None:
|
79
74
|
super().__init__(
|
80
75
|
spatial_dims,
|
@@ -91,9 +86,11 @@ class ControlNetMaisi(ControlNetType):
|
|
91
86
|
cross_attention_dim,
|
92
87
|
num_class_embeds,
|
93
88
|
upcast_attention,
|
94
|
-
use_flash_attention,
|
95
89
|
conditioning_embedding_in_channels,
|
96
90
|
conditioning_embedding_num_channels,
|
91
|
+
include_fc,
|
92
|
+
use_combined_linear,
|
93
|
+
use_flash_attention,
|
97
94
|
)
|
98
95
|
self.use_checkpointing = use_checkpointing
|
99
96
|
|
@@ -105,7 +102,7 @@ class ControlNetMaisi(ControlNetType):
|
|
105
102
|
conditioning_scale: float = 1.0,
|
106
103
|
context: torch.Tensor | None = None,
|
107
104
|
class_labels: torch.Tensor | None = None,
|
108
|
-
) -> tuple[
|
105
|
+
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
109
106
|
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
|
110
107
|
h = self._apply_initial_convolution(x)
|
111
108
|
if self.use_checkpointing:
|
@@ -37,21 +37,15 @@ import torch
|
|
37
37
|
from torch import nn
|
38
38
|
|
39
39
|
from monai.networks.blocks import Convolution
|
40
|
-
from monai.
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
get_mid_block, has_get_mid_block = optional_import(
|
47
|
-
"generative.networks.nets.diffusion_model_unet", name="get_mid_block"
|
48
|
-
)
|
49
|
-
get_timestep_embedding, has_get_timestep_embedding = optional_import(
|
50
|
-
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
|
40
|
+
from monai.networks.nets.diffusion_model_unet import (
|
41
|
+
get_down_block,
|
42
|
+
get_mid_block,
|
43
|
+
get_timestep_embedding,
|
44
|
+
get_up_block,
|
45
|
+
zero_module,
|
51
46
|
)
|
52
|
-
|
53
|
-
|
54
|
-
zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
|
47
|
+
from monai.utils import ensure_tuple_rep
|
48
|
+
from monai.utils.type_conversion import convert_to_tensor
|
55
49
|
|
56
50
|
__all__ = ["DiffusionModelUNetMaisi"]
|
57
51
|
|
@@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
78
72
|
cross_attention_dim: Number of context dimensions to use.
|
79
73
|
num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
|
80
74
|
upcast_attention: If True, upcast attention operations to full precision.
|
75
|
+
include_fc: whether to include the final linear layer. Default to False.
|
76
|
+
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
|
81
77
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
82
78
|
dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
|
83
79
|
include_top_region_index_input: If True, use top region index input.
|
@@ -102,6 +98,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
102
98
|
cross_attention_dim: int | None = None,
|
103
99
|
num_class_embeds: int | None = None,
|
104
100
|
upcast_attention: bool = False,
|
101
|
+
include_fc: bool = False,
|
102
|
+
use_combined_linear: bool = False,
|
105
103
|
use_flash_attention: bool = False,
|
106
104
|
dropout_cattn: float = 0.0,
|
107
105
|
include_top_region_index_input: bool = False,
|
@@ -152,9 +150,6 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
152
150
|
"`num_channels`."
|
153
151
|
)
|
154
152
|
|
155
|
-
if use_flash_attention and not has_xformers:
|
156
|
-
raise ValueError("use_flash_attention is True but xformers is not installed.")
|
157
|
-
|
158
153
|
if use_flash_attention is True and not torch.cuda.is_available():
|
159
154
|
raise ValueError(
|
160
155
|
"torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
|
@@ -210,7 +205,6 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
210
205
|
input_channel = output_channel
|
211
206
|
output_channel = num_channels[i]
|
212
207
|
is_final_block = i == len(num_channels) - 1
|
213
|
-
|
214
208
|
down_block = get_down_block(
|
215
209
|
spatial_dims=spatial_dims,
|
216
210
|
in_channels=input_channel,
|
@@ -227,6 +221,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
227
221
|
transformer_num_layers=transformer_num_layers,
|
228
222
|
cross_attention_dim=cross_attention_dim,
|
229
223
|
upcast_attention=upcast_attention,
|
224
|
+
include_fc=include_fc,
|
225
|
+
use_combined_linear=use_combined_linear,
|
230
226
|
use_flash_attention=use_flash_attention,
|
231
227
|
dropout_cattn=dropout_cattn,
|
232
228
|
)
|
@@ -245,6 +241,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
245
241
|
transformer_num_layers=transformer_num_layers,
|
246
242
|
cross_attention_dim=cross_attention_dim,
|
247
243
|
upcast_attention=upcast_attention,
|
244
|
+
include_fc=include_fc,
|
245
|
+
use_combined_linear=use_combined_linear,
|
248
246
|
use_flash_attention=use_flash_attention,
|
249
247
|
dropout_cattn=dropout_cattn,
|
250
248
|
)
|
@@ -280,6 +278,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
280
278
|
transformer_num_layers=transformer_num_layers,
|
281
279
|
cross_attention_dim=cross_attention_dim,
|
282
280
|
upcast_attention=upcast_attention,
|
281
|
+
include_fc=include_fc,
|
282
|
+
use_combined_linear=use_combined_linear,
|
283
283
|
use_flash_attention=use_flash_attention,
|
284
284
|
dropout_cattn=dropout_cattn,
|
285
285
|
)
|
monai/bundle/config_parser.py
CHANGED
@@ -118,7 +118,7 @@ class ConfigParser:
|
|
118
118
|
self.ref_resolver = ReferenceResolver()
|
119
119
|
if config is None:
|
120
120
|
config = {self.meta_key: {}}
|
121
|
-
self.set(config=config)
|
121
|
+
self.set(config=self.ref_resolver.normalize_meta_id(config))
|
122
122
|
|
123
123
|
def __repr__(self):
|
124
124
|
return f"{self.config}"
|
@@ -221,7 +221,7 @@ class ConfigParser:
|
|
221
221
|
if isinstance(conf_, dict) and k not in conf_:
|
222
222
|
conf_[k] = {}
|
223
223
|
conf_ = conf_[k if isinstance(conf_, dict) else int(k)]
|
224
|
-
self[ReferenceResolver.normalize_id(id)] = config
|
224
|
+
self[ReferenceResolver.normalize_id(id)] = self.ref_resolver.normalize_meta_id(config)
|
225
225
|
|
226
226
|
def update(self, pairs: dict[str, Any]) -> None:
|
227
227
|
"""
|
@@ -17,7 +17,7 @@ from collections.abc import Sequence
|
|
17
17
|
from typing import Any, Iterator
|
18
18
|
|
19
19
|
from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
|
20
|
-
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY
|
20
|
+
from monai.bundle.utils import DEPRECATED_ID_MAPPING, ID_REF_KEY, ID_SEP_KEY
|
21
21
|
from monai.utils import allow_missing_reference, look_up_option
|
22
22
|
|
23
23
|
__all__ = ["ReferenceResolver"]
|
@@ -202,6 +202,23 @@ class ReferenceResolver:
|
|
202
202
|
"""
|
203
203
|
return str(id).replace("#", cls.sep) # backward compatibility `#` is the old separator
|
204
204
|
|
205
|
+
def normalize_meta_id(self, config: Any) -> Any:
|
206
|
+
"""
|
207
|
+
Update deprecated identifiers in `config` using `DEPRECATED_ID_MAPPING`.
|
208
|
+
This will replace names that are marked as deprecated with their replacement.
|
209
|
+
|
210
|
+
Args:
|
211
|
+
config: input config to be updated.
|
212
|
+
"""
|
213
|
+
if isinstance(config, dict):
|
214
|
+
for _id, _new_id in DEPRECATED_ID_MAPPING.items():
|
215
|
+
if _id in config.keys():
|
216
|
+
warnings.warn(
|
217
|
+
f"Detected deprecated name '{_id}' in configuration file, replacing with '{_new_id}'."
|
218
|
+
)
|
219
|
+
config[_new_id] = config.pop(_id)
|
220
|
+
return config
|
221
|
+
|
205
222
|
@classmethod
|
206
223
|
def split_id(cls, id: str | int, last: bool = False) -> list[str]:
|
207
224
|
"""
|
monai/bundle/scripts.py
CHANGED
@@ -18,6 +18,7 @@ import re
|
|
18
18
|
import warnings
|
19
19
|
import zipfile
|
20
20
|
from collections.abc import Mapping, Sequence
|
21
|
+
from functools import partial
|
21
22
|
from pathlib import Path
|
22
23
|
from pydoc import locate
|
23
24
|
from shutil import copyfile
|
@@ -217,10 +218,15 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
|
|
217
218
|
|
218
219
|
|
219
220
|
def _download_from_ngc(
|
220
|
-
download_path: Path,
|
221
|
+
download_path: Path,
|
222
|
+
filename: str,
|
223
|
+
version: str,
|
224
|
+
prefix: str = "monai_",
|
225
|
+
remove_prefix: str | None = "monai_",
|
226
|
+
progress: bool = True,
|
221
227
|
) -> None:
|
222
228
|
# ensure prefix is contained
|
223
|
-
filename = _add_ngc_prefix(filename)
|
229
|
+
filename = _add_ngc_prefix(filename, prefix=prefix)
|
224
230
|
url = _get_ngc_bundle_url(model_name=filename, version=version)
|
225
231
|
filepath = download_path / f"{filename}_v{version}.zip"
|
226
232
|
if remove_prefix:
|
@@ -231,10 +237,16 @@ def _download_from_ngc(
|
|
231
237
|
|
232
238
|
|
233
239
|
def _download_from_ngc_private(
|
234
|
-
download_path: Path,
|
240
|
+
download_path: Path,
|
241
|
+
filename: str,
|
242
|
+
version: str,
|
243
|
+
repo: str,
|
244
|
+
prefix: str = "monai_",
|
245
|
+
remove_prefix: str | None = "monai_",
|
246
|
+
headers: dict | None = None,
|
235
247
|
) -> None:
|
236
248
|
# ensure prefix is contained
|
237
|
-
filename = _add_ngc_prefix(filename)
|
249
|
+
filename = _add_ngc_prefix(filename, prefix=prefix)
|
238
250
|
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
|
239
251
|
if has_requests:
|
240
252
|
headers = {} if headers is None else headers
|
@@ -491,7 +503,7 @@ def download(
|
|
491
503
|
url: url to download the data. If not `None`, data will be downloaded directly
|
492
504
|
and `source` will not be checked.
|
493
505
|
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
|
494
|
-
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
|
506
|
+
remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all ngc bundles
|
495
507
|
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
|
496
508
|
maintain the consistency between these two sources, remove prefix is necessary.
|
497
509
|
Therefore, if specified, downloaded folder name will remove the prefix.
|
@@ -1243,6 +1255,7 @@ def verify_net_in_out(
|
|
1243
1255
|
|
1244
1256
|
def _export(
|
1245
1257
|
converter: Callable,
|
1258
|
+
saver: Callable,
|
1246
1259
|
parser: ConfigParser,
|
1247
1260
|
net_id: str,
|
1248
1261
|
filepath: str,
|
@@ -1257,6 +1270,8 @@ def _export(
|
|
1257
1270
|
Args:
|
1258
1271
|
converter: a callable object that takes a torch.nn.module and kwargs as input and
|
1259
1272
|
converts the module to another type.
|
1273
|
+
saver: a callable object that accepts the converted model to save, a filepath to save to, meta values
|
1274
|
+
(extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
|
1260
1275
|
parser: a ConfigParser of the bundle to be converted.
|
1261
1276
|
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
|
1262
1277
|
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
|
@@ -1296,14 +1311,9 @@ def _export(
|
|
1296
1311
|
# add .json extension to all extra files which are always encoded as JSON
|
1297
1312
|
extra_files = {k + ".json": v for k, v in extra_files.items()}
|
1298
1313
|
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
include_config_vals=False,
|
1303
|
-
append_timestamp=False,
|
1304
|
-
meta_values=parser.get().pop("_meta_", None),
|
1305
|
-
more_extra_files=extra_files,
|
1306
|
-
)
|
1314
|
+
meta_values = parser.get().pop("_meta_", None)
|
1315
|
+
saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)
|
1316
|
+
|
1307
1317
|
logger.info(f"exported to file: {filepath}.")
|
1308
1318
|
|
1309
1319
|
|
@@ -1402,17 +1412,23 @@ def onnx_export(
|
|
1402
1412
|
input_shape_ = _get_fake_input_shape(parser=parser)
|
1403
1413
|
|
1404
1414
|
inputs_ = [torch.rand(input_shape_)]
|
1405
|
-
net = parser.get_parsed_content(net_id_)
|
1406
|
-
if has_ignite:
|
1407
|
-
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
|
1408
|
-
Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
|
1409
|
-
else:
|
1410
|
-
ckpt = torch.load(ckpt_file_)
|
1411
|
-
copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])
|
1412
1415
|
|
1413
1416
|
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
|
1414
|
-
|
1415
|
-
|
1417
|
+
|
1418
|
+
def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:
|
1419
|
+
onnx.save(onnx_obj, filename_prefix_or_stream)
|
1420
|
+
|
1421
|
+
_export(
|
1422
|
+
convert_to_onnx,
|
1423
|
+
save_onnx,
|
1424
|
+
parser,
|
1425
|
+
net_id=net_id_,
|
1426
|
+
filepath=filepath_,
|
1427
|
+
ckpt_file=ckpt_file_,
|
1428
|
+
config_file=config_file_,
|
1429
|
+
key_in_ckpt=key_in_ckpt_,
|
1430
|
+
**converter_kwargs_,
|
1431
|
+
)
|
1416
1432
|
|
1417
1433
|
|
1418
1434
|
def ckpt_export(
|
@@ -1533,8 +1549,12 @@ def ckpt_export(
|
|
1533
1549
|
|
1534
1550
|
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
|
1535
1551
|
# Use the given converter to convert a model and save with metadata, config content
|
1552
|
+
|
1553
|
+
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
|
1554
|
+
|
1536
1555
|
_export(
|
1537
1556
|
convert_to_torchscript,
|
1557
|
+
save_ts,
|
1538
1558
|
parser,
|
1539
1559
|
net_id=net_id_,
|
1540
1560
|
filepath=filepath_,
|
@@ -1704,8 +1724,11 @@ def trt_export(
|
|
1704
1724
|
}
|
1705
1725
|
converter_kwargs_.update(trt_api_parameters)
|
1706
1726
|
|
1727
|
+
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
|
1728
|
+
|
1707
1729
|
_export(
|
1708
1730
|
convert_to_trt,
|
1731
|
+
save_ts,
|
1709
1732
|
parser,
|
1710
1733
|
net_id=net_id_,
|
1711
1734
|
filepath=filepath_,
|
monai/bundle/utils.py
CHANGED
@@ -36,7 +36,7 @@ DEFAULT_METADATA = {
|
|
36
36
|
"monai_version": _conf_values["MONAI"],
|
37
37
|
"pytorch_version": str(_conf_values["Pytorch"]).split("+")[0].split("a")[0], # 1.9.0a0+df837d0 or 1.13.0+cu117
|
38
38
|
"numpy_version": _conf_values["Numpy"],
|
39
|
-
"
|
39
|
+
"required_packages_version": {},
|
40
40
|
"task": "Describe what the network predicts",
|
41
41
|
"description": "A longer description of what the network does, use context, inputs, outputs, etc.",
|
42
42
|
"authors": "Your Name Here",
|
@@ -157,6 +157,8 @@ DEFAULT_MLFLOW_SETTINGS = {
|
|
157
157
|
|
158
158
|
DEFAULT_EXP_MGMT_SETTINGS = {"mlflow": DEFAULT_MLFLOW_SETTINGS} # default experiment management settings
|
159
159
|
|
160
|
+
DEPRECATED_ID_MAPPING = {"optional_packages_version": "required_packages_version"}
|
161
|
+
|
160
162
|
|
161
163
|
def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any) -> Any:
|
162
164
|
"""
|
monai/data/utils.py
CHANGED
@@ -927,7 +927,7 @@ def compute_shape_offset(
|
|
927
927
|
corners = in_affine_ @ corners
|
928
928
|
all_dist = corners_out[:-1].copy()
|
929
929
|
corners_out = corners_out[:-1] / corners_out[-1]
|
930
|
-
out_shape = np.round(
|
930
|
+
out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0)
|
931
931
|
offset = None
|
932
932
|
for i in range(corners.shape[1]):
|
933
933
|
min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
|
monai/data/wsi_datasets.py
CHANGED
@@ -23,7 +23,7 @@ from monai.data.meta_tensor import MetaTensor
|
|
23
23
|
from monai.data.utils import iter_patch_position
|
24
24
|
from monai.data.wsi_reader import BaseWSIReader, WSIReader
|
25
25
|
from monai.transforms import ForegroundMask, Randomizable, apply_transform
|
26
|
-
from monai.utils import convert_to_dst_type, ensure_tuple_rep
|
26
|
+
from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep
|
27
27
|
from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys
|
28
28
|
|
29
29
|
__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"]
|
@@ -123,9 +123,9 @@ class PatchWSIDataset(Dataset):
|
|
123
123
|
def _get_location(self, sample: dict):
|
124
124
|
if self.center_location:
|
125
125
|
size = self._get_size(sample)
|
126
|
-
return
|
126
|
+
return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size)))
|
127
127
|
else:
|
128
|
-
return sample[WSIPatchKeys.LOCATION]
|
128
|
+
return ensure_tuple(sample[WSIPatchKeys.LOCATION])
|
129
129
|
|
130
130
|
def _get_level(self, sample: dict):
|
131
131
|
if self.patch_level is None:
|