monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2434__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/bundle/scripts.py +29 -17
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +22 -0
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +908 -0
- monai/networks/utils.py +3 -3
- monai/transforms/__init__.py +1 -0
- monai/transforms/io/array.py +1 -1
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utils.py +183 -0
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +40 -37
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2024-08-
|
11
|
+
"date": "2024-08-25T02:21:56+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.4.
|
14
|
+
"full-revisionid": "dc611d231ba670004b1da1b011fe140375fb91af",
|
15
|
+
"version": "1.4.dev2434"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
@@ -13,25 +13,17 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
import gc
|
15
15
|
import logging
|
16
|
-
from typing import
|
16
|
+
from typing import Sequence
|
17
17
|
|
18
18
|
import torch
|
19
19
|
import torch.nn as nn
|
20
20
|
import torch.nn.functional as F
|
21
21
|
|
22
22
|
from monai.networks.blocks import Convolution
|
23
|
-
from monai.
|
23
|
+
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
|
24
|
+
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
|
24
25
|
from monai.utils.type_conversion import convert_to_tensor
|
25
26
|
|
26
|
-
AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
|
27
|
-
AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
|
28
|
-
ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
|
29
|
-
|
30
|
-
if TYPE_CHECKING:
|
31
|
-
from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
|
32
|
-
else:
|
33
|
-
AutoencoderKLType = cast(type, AutoencoderKL)
|
34
|
-
|
35
27
|
# Set up logging configuration
|
36
28
|
logger = logging.getLogger(__name__)
|
37
29
|
|
@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
|
|
518
510
|
in_channels: Number of input channels.
|
519
511
|
num_channels: Sequence of block output channels.
|
520
512
|
out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
|
521
|
-
num_res_blocks: Number of residual blocks (see
|
513
|
+
num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
|
522
514
|
norm_num_groups: Number of groups for the group norm layers.
|
523
515
|
norm_eps: Epsilon for the normalization.
|
524
516
|
attention_levels: Indicate which level from num_channels contain an attention block.
|
525
517
|
with_nonlocal_attn: If True, use non-local attention block.
|
518
|
+
include_fc: whether to include the final linear layer in the attention block. Default to False.
|
519
|
+
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
|
526
520
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
527
521
|
num_splits: Number of splits for the input tensor.
|
528
522
|
dim_split: Dimension of splitting for the input tensor.
|
@@ -547,6 +541,8 @@ class MaisiEncoder(nn.Module):
|
|
547
541
|
print_info: bool = False,
|
548
542
|
save_mem: bool = True,
|
549
543
|
with_nonlocal_attn: bool = True,
|
544
|
+
include_fc: bool = False,
|
545
|
+
use_combined_linear: bool = False,
|
550
546
|
use_flash_attention: bool = False,
|
551
547
|
) -> None:
|
552
548
|
super().__init__()
|
@@ -603,11 +599,13 @@ class MaisiEncoder(nn.Module):
|
|
603
599
|
input_channel = output_channel
|
604
600
|
if attention_levels[i]:
|
605
601
|
blocks.append(
|
606
|
-
|
602
|
+
SpatialAttentionBlock(
|
607
603
|
spatial_dims=spatial_dims,
|
608
604
|
num_channels=input_channel,
|
609
605
|
norm_num_groups=norm_num_groups,
|
610
606
|
norm_eps=norm_eps,
|
607
|
+
include_fc=include_fc,
|
608
|
+
use_combined_linear=use_combined_linear,
|
611
609
|
use_flash_attention=use_flash_attention,
|
612
610
|
)
|
613
611
|
)
|
@@ -626,7 +624,7 @@ class MaisiEncoder(nn.Module):
|
|
626
624
|
|
627
625
|
if with_nonlocal_attn:
|
628
626
|
blocks.append(
|
629
|
-
|
627
|
+
AEKLResBlock(
|
630
628
|
spatial_dims=spatial_dims,
|
631
629
|
in_channels=num_channels[-1],
|
632
630
|
norm_num_groups=norm_num_groups,
|
@@ -636,16 +634,18 @@ class MaisiEncoder(nn.Module):
|
|
636
634
|
)
|
637
635
|
|
638
636
|
blocks.append(
|
639
|
-
|
637
|
+
SpatialAttentionBlock(
|
640
638
|
spatial_dims=spatial_dims,
|
641
639
|
num_channels=num_channels[-1],
|
642
640
|
norm_num_groups=norm_num_groups,
|
643
641
|
norm_eps=norm_eps,
|
642
|
+
include_fc=include_fc,
|
643
|
+
use_combined_linear=use_combined_linear,
|
644
644
|
use_flash_attention=use_flash_attention,
|
645
645
|
)
|
646
646
|
)
|
647
647
|
blocks.append(
|
648
|
-
|
648
|
+
AEKLResBlock(
|
649
649
|
spatial_dims=spatial_dims,
|
650
650
|
in_channels=num_channels[-1],
|
651
651
|
norm_num_groups=norm_num_groups,
|
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
|
|
699
699
|
num_channels: Sequence of block output channels.
|
700
700
|
in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
|
701
701
|
out_channels: Number of output channels.
|
702
|
-
num_res_blocks: Number of residual blocks (see
|
702
|
+
num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
|
703
703
|
norm_num_groups: Number of groups for the group norm layers.
|
704
704
|
norm_eps: Epsilon for the normalization.
|
705
705
|
attention_levels: Indicate which level from num_channels contain an attention block.
|
706
706
|
with_nonlocal_attn: If True, use non-local attention block.
|
707
|
+
include_fc: whether to include the final linear layer in the attention block. Default to False.
|
708
|
+
use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
|
707
709
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
708
710
|
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
|
709
711
|
num_splits: Number of splits for the input tensor.
|
@@ -729,6 +731,8 @@ class MaisiDecoder(nn.Module):
|
|
729
731
|
print_info: bool = False,
|
730
732
|
save_mem: bool = True,
|
731
733
|
with_nonlocal_attn: bool = True,
|
734
|
+
include_fc: bool = False,
|
735
|
+
use_combined_linear: bool = False,
|
732
736
|
use_flash_attention: bool = False,
|
733
737
|
use_convtranspose: bool = False,
|
734
738
|
) -> None:
|
@@ -758,7 +762,7 @@ class MaisiDecoder(nn.Module):
|
|
758
762
|
|
759
763
|
if with_nonlocal_attn:
|
760
764
|
blocks.append(
|
761
|
-
|
765
|
+
AEKLResBlock(
|
762
766
|
spatial_dims=spatial_dims,
|
763
767
|
in_channels=reversed_block_out_channels[0],
|
764
768
|
norm_num_groups=norm_num_groups,
|
@@ -767,16 +771,18 @@ class MaisiDecoder(nn.Module):
|
|
767
771
|
)
|
768
772
|
)
|
769
773
|
blocks.append(
|
770
|
-
|
774
|
+
SpatialAttentionBlock(
|
771
775
|
spatial_dims=spatial_dims,
|
772
776
|
num_channels=reversed_block_out_channels[0],
|
773
777
|
norm_num_groups=norm_num_groups,
|
774
778
|
norm_eps=norm_eps,
|
779
|
+
include_fc=include_fc,
|
780
|
+
use_combined_linear=use_combined_linear,
|
775
781
|
use_flash_attention=use_flash_attention,
|
776
782
|
)
|
777
783
|
)
|
778
784
|
blocks.append(
|
779
|
-
|
785
|
+
AEKLResBlock(
|
780
786
|
spatial_dims=spatial_dims,
|
781
787
|
in_channels=reversed_block_out_channels[0],
|
782
788
|
norm_num_groups=norm_num_groups,
|
@@ -812,11 +818,13 @@ class MaisiDecoder(nn.Module):
|
|
812
818
|
|
813
819
|
if reversed_attention_levels[i]:
|
814
820
|
blocks.append(
|
815
|
-
|
821
|
+
SpatialAttentionBlock(
|
816
822
|
spatial_dims=spatial_dims,
|
817
823
|
num_channels=block_in_ch,
|
818
824
|
norm_num_groups=norm_num_groups,
|
819
825
|
norm_eps=norm_eps,
|
826
|
+
include_fc=include_fc,
|
827
|
+
use_combined_linear=use_combined_linear,
|
820
828
|
use_flash_attention=use_flash_attention,
|
821
829
|
)
|
822
830
|
)
|
@@ -870,7 +878,7 @@ class MaisiDecoder(nn.Module):
|
|
870
878
|
return x
|
871
879
|
|
872
880
|
|
873
|
-
class AutoencoderKlMaisi(
|
881
|
+
class AutoencoderKlMaisi(AutoencoderKL):
|
874
882
|
"""
|
875
883
|
AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
|
876
884
|
|
@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
886
894
|
norm_eps: Epsilon for the normalization.
|
887
895
|
with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
|
888
896
|
with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
|
897
|
+
include_fc: whether to include the final linear layer. Default to False.
|
898
|
+
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
|
889
899
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
890
900
|
use_checkpointing: If True, use activation checkpointing.
|
891
901
|
use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
|
@@ -909,6 +919,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
909
919
|
norm_eps: float = 1e-6,
|
910
920
|
with_encoder_nonlocal_attn: bool = False,
|
911
921
|
with_decoder_nonlocal_attn: bool = False,
|
922
|
+
include_fc: bool = False,
|
923
|
+
use_combined_linear: bool = False,
|
912
924
|
use_flash_attention: bool = False,
|
913
925
|
use_checkpointing: bool = False,
|
914
926
|
use_convtranspose: bool = False,
|
@@ -930,12 +942,14 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
930
942
|
norm_eps,
|
931
943
|
with_encoder_nonlocal_attn,
|
932
944
|
with_decoder_nonlocal_attn,
|
933
|
-
use_flash_attention,
|
934
945
|
use_checkpointing,
|
935
946
|
use_convtranspose,
|
947
|
+
include_fc,
|
948
|
+
use_combined_linear,
|
949
|
+
use_flash_attention,
|
936
950
|
)
|
937
951
|
|
938
|
-
self.encoder = MaisiEncoder(
|
952
|
+
self.encoder: nn.Module = MaisiEncoder(
|
939
953
|
spatial_dims=spatial_dims,
|
940
954
|
in_channels=in_channels,
|
941
955
|
num_channels=num_channels,
|
@@ -945,6 +959,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
945
959
|
norm_eps=norm_eps,
|
946
960
|
attention_levels=attention_levels,
|
947
961
|
with_nonlocal_attn=with_encoder_nonlocal_attn,
|
962
|
+
include_fc=include_fc,
|
963
|
+
use_combined_linear=use_combined_linear,
|
948
964
|
use_flash_attention=use_flash_attention,
|
949
965
|
num_splits=num_splits,
|
950
966
|
dim_split=dim_split,
|
@@ -953,7 +969,7 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
953
969
|
save_mem=save_mem,
|
954
970
|
)
|
955
971
|
|
956
|
-
self.decoder = MaisiDecoder(
|
972
|
+
self.decoder: nn.Module = MaisiDecoder(
|
957
973
|
spatial_dims=spatial_dims,
|
958
974
|
num_channels=num_channels,
|
959
975
|
in_channels=latent_channels,
|
@@ -963,6 +979,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
|
|
963
979
|
norm_eps=norm_eps,
|
964
980
|
attention_levels=attention_levels,
|
965
981
|
with_nonlocal_attn=with_decoder_nonlocal_attn,
|
982
|
+
include_fc=include_fc,
|
983
|
+
use_combined_linear=use_combined_linear,
|
966
984
|
use_flash_attention=use_flash_attention,
|
967
985
|
use_convtranspose=use_convtranspose,
|
968
986
|
num_splits=num_splits,
|
@@ -11,24 +11,15 @@
|
|
11
11
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
|
-
from typing import
|
14
|
+
from typing import Sequence
|
15
15
|
|
16
16
|
import torch
|
17
17
|
|
18
|
-
from monai.
|
18
|
+
from monai.networks.nets.controlnet import ControlNet
|
19
|
+
from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
|
19
20
|
|
20
|
-
ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
|
21
|
-
get_timestep_embedding, has_get_timestep_embedding = optional_import(
|
22
|
-
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
|
23
|
-
)
|
24
21
|
|
25
|
-
|
26
|
-
from generative.networks.nets.controlnet import ControlNet as ControlNetType
|
27
|
-
else:
|
28
|
-
ControlNetType = cast(type, ControlNet)
|
29
|
-
|
30
|
-
|
31
|
-
class ControlNetMaisi(ControlNetType):
|
22
|
+
class ControlNetMaisi(ControlNet):
|
32
23
|
"""
|
33
24
|
Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
|
34
25
|
Diffusion Models" (https://arxiv.org/abs/2302.05543)
|
@@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType):
|
|
49
40
|
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
|
50
41
|
classes.
|
51
42
|
upcast_attention: if True, upcast attention operations to full precision.
|
52
|
-
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
|
53
43
|
conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
|
54
44
|
conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
|
55
45
|
use_checkpointing: if True, use activation checkpointing to save memory.
|
46
|
+
include_fc: whether to include the final linear layer. Default to False.
|
47
|
+
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
|
48
|
+
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
|
56
49
|
"""
|
57
50
|
|
58
51
|
def __init__(
|
@@ -71,10 +64,12 @@ class ControlNetMaisi(ControlNetType):
|
|
71
64
|
cross_attention_dim: int | None = None,
|
72
65
|
num_class_embeds: int | None = None,
|
73
66
|
upcast_attention: bool = False,
|
74
|
-
use_flash_attention: bool = False,
|
75
67
|
conditioning_embedding_in_channels: int = 1,
|
76
|
-
conditioning_embedding_num_channels: Sequence[int]
|
68
|
+
conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
|
77
69
|
use_checkpointing: bool = True,
|
70
|
+
include_fc: bool = False,
|
71
|
+
use_combined_linear: bool = False,
|
72
|
+
use_flash_attention: bool = False,
|
78
73
|
) -> None:
|
79
74
|
super().__init__(
|
80
75
|
spatial_dims,
|
@@ -91,9 +86,11 @@ class ControlNetMaisi(ControlNetType):
|
|
91
86
|
cross_attention_dim,
|
92
87
|
num_class_embeds,
|
93
88
|
upcast_attention,
|
94
|
-
use_flash_attention,
|
95
89
|
conditioning_embedding_in_channels,
|
96
90
|
conditioning_embedding_num_channels,
|
91
|
+
include_fc,
|
92
|
+
use_combined_linear,
|
93
|
+
use_flash_attention,
|
97
94
|
)
|
98
95
|
self.use_checkpointing = use_checkpointing
|
99
96
|
|
@@ -105,7 +102,7 @@ class ControlNetMaisi(ControlNetType):
|
|
105
102
|
conditioning_scale: float = 1.0,
|
106
103
|
context: torch.Tensor | None = None,
|
107
104
|
class_labels: torch.Tensor | None = None,
|
108
|
-
) -> tuple[
|
105
|
+
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
109
106
|
emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
|
110
107
|
h = self._apply_initial_convolution(x)
|
111
108
|
if self.use_checkpointing:
|
@@ -37,21 +37,15 @@ import torch
|
|
37
37
|
from torch import nn
|
38
38
|
|
39
39
|
from monai.networks.blocks import Convolution
|
40
|
-
from monai.
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
get_mid_block, has_get_mid_block = optional_import(
|
47
|
-
"generative.networks.nets.diffusion_model_unet", name="get_mid_block"
|
48
|
-
)
|
49
|
-
get_timestep_embedding, has_get_timestep_embedding = optional_import(
|
50
|
-
"generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
|
40
|
+
from monai.networks.nets.diffusion_model_unet import (
|
41
|
+
get_down_block,
|
42
|
+
get_mid_block,
|
43
|
+
get_timestep_embedding,
|
44
|
+
get_up_block,
|
45
|
+
zero_module,
|
51
46
|
)
|
52
|
-
|
53
|
-
|
54
|
-
zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
|
47
|
+
from monai.utils import ensure_tuple_rep
|
48
|
+
from monai.utils.type_conversion import convert_to_tensor
|
55
49
|
|
56
50
|
__all__ = ["DiffusionModelUNetMaisi"]
|
57
51
|
|
@@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
78
72
|
cross_attention_dim: Number of context dimensions to use.
|
79
73
|
num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
|
80
74
|
upcast_attention: If True, upcast attention operations to full precision.
|
75
|
+
include_fc: whether to include the final linear layer. Default to False.
|
76
|
+
use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
|
81
77
|
use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
|
82
78
|
dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
|
83
79
|
include_top_region_index_input: If True, use top region index input.
|
@@ -102,6 +98,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
102
98
|
cross_attention_dim: int | None = None,
|
103
99
|
num_class_embeds: int | None = None,
|
104
100
|
upcast_attention: bool = False,
|
101
|
+
include_fc: bool = False,
|
102
|
+
use_combined_linear: bool = False,
|
105
103
|
use_flash_attention: bool = False,
|
106
104
|
dropout_cattn: float = 0.0,
|
107
105
|
include_top_region_index_input: bool = False,
|
@@ -152,9 +150,6 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
152
150
|
"`num_channels`."
|
153
151
|
)
|
154
152
|
|
155
|
-
if use_flash_attention and not has_xformers:
|
156
|
-
raise ValueError("use_flash_attention is True but xformers is not installed.")
|
157
|
-
|
158
153
|
if use_flash_attention is True and not torch.cuda.is_available():
|
159
154
|
raise ValueError(
|
160
155
|
"torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
|
@@ -210,7 +205,6 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
210
205
|
input_channel = output_channel
|
211
206
|
output_channel = num_channels[i]
|
212
207
|
is_final_block = i == len(num_channels) - 1
|
213
|
-
|
214
208
|
down_block = get_down_block(
|
215
209
|
spatial_dims=spatial_dims,
|
216
210
|
in_channels=input_channel,
|
@@ -227,6 +221,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
227
221
|
transformer_num_layers=transformer_num_layers,
|
228
222
|
cross_attention_dim=cross_attention_dim,
|
229
223
|
upcast_attention=upcast_attention,
|
224
|
+
include_fc=include_fc,
|
225
|
+
use_combined_linear=use_combined_linear,
|
230
226
|
use_flash_attention=use_flash_attention,
|
231
227
|
dropout_cattn=dropout_cattn,
|
232
228
|
)
|
@@ -245,6 +241,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
245
241
|
transformer_num_layers=transformer_num_layers,
|
246
242
|
cross_attention_dim=cross_attention_dim,
|
247
243
|
upcast_attention=upcast_attention,
|
244
|
+
include_fc=include_fc,
|
245
|
+
use_combined_linear=use_combined_linear,
|
248
246
|
use_flash_attention=use_flash_attention,
|
249
247
|
dropout_cattn=dropout_cattn,
|
250
248
|
)
|
@@ -280,6 +278,8 @@ class DiffusionModelUNetMaisi(nn.Module):
|
|
280
278
|
transformer_num_layers=transformer_num_layers,
|
281
279
|
cross_attention_dim=cross_attention_dim,
|
282
280
|
upcast_attention=upcast_attention,
|
281
|
+
include_fc=include_fc,
|
282
|
+
use_combined_linear=use_combined_linear,
|
283
283
|
use_flash_attention=use_flash_attention,
|
284
284
|
dropout_cattn=dropout_cattn,
|
285
285
|
)
|
monai/bundle/scripts.py
CHANGED
@@ -18,6 +18,7 @@ import re
|
|
18
18
|
import warnings
|
19
19
|
import zipfile
|
20
20
|
from collections.abc import Mapping, Sequence
|
21
|
+
from functools import partial
|
21
22
|
from pathlib import Path
|
22
23
|
from pydoc import locate
|
23
24
|
from shutil import copyfile
|
@@ -1254,6 +1255,7 @@ def verify_net_in_out(
|
|
1254
1255
|
|
1255
1256
|
def _export(
|
1256
1257
|
converter: Callable,
|
1258
|
+
saver: Callable,
|
1257
1259
|
parser: ConfigParser,
|
1258
1260
|
net_id: str,
|
1259
1261
|
filepath: str,
|
@@ -1268,6 +1270,8 @@ def _export(
|
|
1268
1270
|
Args:
|
1269
1271
|
converter: a callable object that takes a torch.nn.module and kwargs as input and
|
1270
1272
|
converts the module to another type.
|
1273
|
+
saver: a callable object that accepts the converted model to save, a filepath to save to, meta values
|
1274
|
+
(extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
|
1271
1275
|
parser: a ConfigParser of the bundle to be converted.
|
1272
1276
|
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
|
1273
1277
|
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
|
@@ -1307,14 +1311,9 @@ def _export(
|
|
1307
1311
|
# add .json extension to all extra files which are always encoded as JSON
|
1308
1312
|
extra_files = {k + ".json": v for k, v in extra_files.items()}
|
1309
1313
|
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
include_config_vals=False,
|
1314
|
-
append_timestamp=False,
|
1315
|
-
meta_values=parser.get().pop("_meta_", None),
|
1316
|
-
more_extra_files=extra_files,
|
1317
|
-
)
|
1314
|
+
meta_values = parser.get().pop("_meta_", None)
|
1315
|
+
saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)
|
1316
|
+
|
1318
1317
|
logger.info(f"exported to file: {filepath}.")
|
1319
1318
|
|
1320
1319
|
|
@@ -1413,17 +1412,23 @@ def onnx_export(
|
|
1413
1412
|
input_shape_ = _get_fake_input_shape(parser=parser)
|
1414
1413
|
|
1415
1414
|
inputs_ = [torch.rand(input_shape_)]
|
1416
|
-
net = parser.get_parsed_content(net_id_)
|
1417
|
-
if has_ignite:
|
1418
|
-
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
|
1419
|
-
Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
|
1420
|
-
else:
|
1421
|
-
ckpt = torch.load(ckpt_file_)
|
1422
|
-
copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])
|
1423
1415
|
|
1424
1416
|
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
|
1425
|
-
|
1426
|
-
|
1417
|
+
|
1418
|
+
def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:
|
1419
|
+
onnx.save(onnx_obj, filename_prefix_or_stream)
|
1420
|
+
|
1421
|
+
_export(
|
1422
|
+
convert_to_onnx,
|
1423
|
+
save_onnx,
|
1424
|
+
parser,
|
1425
|
+
net_id=net_id_,
|
1426
|
+
filepath=filepath_,
|
1427
|
+
ckpt_file=ckpt_file_,
|
1428
|
+
config_file=config_file_,
|
1429
|
+
key_in_ckpt=key_in_ckpt_,
|
1430
|
+
**converter_kwargs_,
|
1431
|
+
)
|
1427
1432
|
|
1428
1433
|
|
1429
1434
|
def ckpt_export(
|
@@ -1544,8 +1549,12 @@ def ckpt_export(
|
|
1544
1549
|
|
1545
1550
|
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
|
1546
1551
|
# Use the given converter to convert a model and save with metadata, config content
|
1552
|
+
|
1553
|
+
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
|
1554
|
+
|
1547
1555
|
_export(
|
1548
1556
|
convert_to_torchscript,
|
1557
|
+
save_ts,
|
1549
1558
|
parser,
|
1550
1559
|
net_id=net_id_,
|
1551
1560
|
filepath=filepath_,
|
@@ -1715,8 +1724,11 @@ def trt_export(
|
|
1715
1724
|
}
|
1716
1725
|
converter_kwargs_.update(trt_api_parameters)
|
1717
1726
|
|
1727
|
+
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
|
1728
|
+
|
1718
1729
|
_export(
|
1719
1730
|
convert_to_trt,
|
1731
|
+
save_ts,
|
1720
1732
|
parser,
|
1721
1733
|
net_id=net_id_,
|
1722
1734
|
filepath=filepath_,
|
monai/data/utils.py
CHANGED
@@ -927,7 +927,7 @@ def compute_shape_offset(
|
|
927
927
|
corners = in_affine_ @ corners
|
928
928
|
all_dist = corners_out[:-1].copy()
|
929
929
|
corners_out = corners_out[:-1] / corners_out[-1]
|
930
|
-
out_shape = np.round(
|
930
|
+
out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0)
|
931
931
|
offset = None
|
932
932
|
for i in range(corners.shape[1]):
|
933
933
|
min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
|
monai/data/wsi_datasets.py
CHANGED
@@ -23,7 +23,7 @@ from monai.data.meta_tensor import MetaTensor
|
|
23
23
|
from monai.data.utils import iter_patch_position
|
24
24
|
from monai.data.wsi_reader import BaseWSIReader, WSIReader
|
25
25
|
from monai.transforms import ForegroundMask, Randomizable, apply_transform
|
26
|
-
from monai.utils import convert_to_dst_type, ensure_tuple_rep
|
26
|
+
from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep
|
27
27
|
from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys
|
28
28
|
|
29
29
|
__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"]
|
@@ -123,9 +123,9 @@ class PatchWSIDataset(Dataset):
|
|
123
123
|
def _get_location(self, sample: dict):
|
124
124
|
if self.center_location:
|
125
125
|
size = self._get_size(sample)
|
126
|
-
return
|
126
|
+
return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size)))
|
127
127
|
else:
|
128
|
-
return sample[WSIPatchKeys.LOCATION]
|
128
|
+
return ensure_tuple(sample[WSIPatchKeys.LOCATION])
|
129
129
|
|
130
130
|
def _get_level(self, sample: dict):
|
131
131
|
if self.patch_level is None:
|
monai/losses/__init__.py
CHANGED
@@ -37,6 +37,7 @@ from .giou_loss import BoxGIoULoss, giou
|
|
37
37
|
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
|
38
38
|
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
|
39
39
|
from .multi_scale import MultiScaleLoss
|
40
|
+
from .nacl_loss import NACLLoss
|
40
41
|
from .perceptual import PerceptualLoss
|
41
42
|
from .spatial_mask import MaskedLoss
|
42
43
|
from .spectral_loss import JukeboxLoss
|
monai/losses/dice.py
CHANGED
@@ -666,6 +666,7 @@ class DiceCELoss(_Loss):
|
|
666
666
|
weight: torch.Tensor | None = None,
|
667
667
|
lambda_dice: float = 1.0,
|
668
668
|
lambda_ce: float = 1.0,
|
669
|
+
label_smoothing: float = 0.0,
|
669
670
|
) -> None:
|
670
671
|
"""
|
671
672
|
Args:
|
@@ -704,6 +705,9 @@ class DiceCELoss(_Loss):
|
|
704
705
|
Defaults to 1.0.
|
705
706
|
lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
|
706
707
|
Defaults to 1.0.
|
708
|
+
label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
|
709
|
+
by the given factor to reduce overfitting.
|
710
|
+
Defaults to 0.0.
|
707
711
|
|
708
712
|
"""
|
709
713
|
super().__init__()
|
@@ -728,7 +732,12 @@ class DiceCELoss(_Loss):
|
|
728
732
|
batch=batch,
|
729
733
|
weight=dice_weight,
|
730
734
|
)
|
731
|
-
|
735
|
+
if pytorch_after(1, 10):
|
736
|
+
self.cross_entropy = nn.CrossEntropyLoss(
|
737
|
+
weight=weight, reduction=reduction, label_smoothing=label_smoothing
|
738
|
+
)
|
739
|
+
else:
|
740
|
+
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
|
732
741
|
self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
|
733
742
|
if lambda_dice < 0.0:
|
734
743
|
raise ValueError("lambda_dice should be no less than 0.0.")
|