monai-weekly 1.4.dev2430__py3-none-any.whl → 1.4.dev2434__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (44)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/bundle/config_parser.py +2 -2
  7. monai/bundle/reference_resolver.py +18 -1
  8. monai/bundle/scripts.py +45 -22
  9. monai/bundle/utils.py +3 -1
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/losses/__init__.py +1 -0
  13. monai/losses/dice.py +10 -1
  14. monai/losses/nacl_loss.py +139 -0
  15. monai/networks/blocks/crossattention.py +48 -26
  16. monai/networks/blocks/mlp.py +16 -4
  17. monai/networks/blocks/selfattention.py +75 -23
  18. monai/networks/blocks/spatialattention.py +16 -1
  19. monai/networks/blocks/transformerblock.py +17 -2
  20. monai/networks/nets/__init__.py +2 -1
  21. monai/networks/nets/autoencoderkl.py +55 -22
  22. monai/networks/nets/cell_sam_wrapper.py +92 -0
  23. monai/networks/nets/controlnet.py +24 -22
  24. monai/networks/nets/diffusion_model_unet.py +159 -19
  25. monai/networks/nets/segresnet_ds.py +127 -1
  26. monai/networks/nets/spade_autoencoderkl.py +24 -2
  27. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  28. monai/networks/nets/transformer.py +17 -17
  29. monai/networks/nets/vista3d.py +908 -0
  30. monai/networks/utils.py +3 -3
  31. monai/transforms/__init__.py +1 -0
  32. monai/transforms/io/array.py +1 -1
  33. monai/transforms/post/array.py +2 -1
  34. monai/transforms/spatial/functional.py +1 -1
  35. monai/transforms/transform.py +2 -2
  36. monai/transforms/utils.py +183 -0
  37. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  38. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  39. monai/utils/module.py +7 -6
  40. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/METADATA +83 -81
  41. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/RECORD +44 -41
  42. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/WHEEL +1 -1
  43. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/LICENSE +0 -0
  44. {monai_weekly-1.4.dev2430.dist-info → monai_weekly-1.4.dev2434.dist-info}/top_level.txt +0 -0
monai/__init__.py CHANGED
@@ -93,4 +93,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "2e53df78e580131046dc8db7f7638063db1f5045"
+__commit_id__ = "a5fbe716378948630783deef8ee435e7e3bdc918"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-07-28T02:19:22+0000",
+ "date": "2024-08-25T02:21:56+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "9dd92b4a07706d4b80edace3d39fe008dc805d5a",
- "version": "1.4.dev2430"
+ "full-revisionid": "dc611d231ba670004b1da1b011fe140375fb91af",
+ "version": "1.4.dev2434"
 }
 ''' # END VERSION_JSON
 
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py CHANGED
@@ -13,25 +13,17 @@ from __future__ import annotations
 
 import gc
 import logging
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import Sequence
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from monai.networks.blocks import Convolution
-from monai.utils import optional_import
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
 from monai.utils.type_conversion import convert_to_tensor
 
-AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
-AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
-ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
-
-if TYPE_CHECKING:
-    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
-else:
-    AutoencoderKLType = cast(type, AutoencoderKL)
-
 # Set up logging configuration
 logger = logging.getLogger(__name__)
 
@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
         in_channels: Number of input channels.
         num_channels: Sequence of block output channels.
         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
@@ -547,6 +541,8 @@ class MaisiEncoder(nn.Module):
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
@@ -603,11 +599,13 @@ class MaisiEncoder(nn.Module):
             input_channel = output_channel
             if attention_levels[i]:
                 blocks.append(
-                    AttentionBlock(
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -626,7 +624,7 @@ class MaisiEncoder(nn.Module):
 
         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -636,16 +634,18 @@
             )
 
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
         num_channels: Sequence of block output channels.
         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: Number of output channels.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
@@ -729,6 +731,8 @@ class MaisiDecoder(nn.Module):
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_convtranspose: bool = False,
     ) -> None:
@@ -758,7 +762,7 @@
 
         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -767,16 +771,18 @@
                 )
             )
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
            )
            blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -812,11 +818,13 @@
 
             if reversed_attention_levels[i]:
                 blocks.append(
-                    AttentionBlock(
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=block_in_ch,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -870,7 +878,7 @@
         return x
 
 
-class AutoencoderKlMaisi(AutoencoderKLType):
+class AutoencoderKlMaisi(AutoencoderKL):
     """
     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
 
@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: Epsilon for the normalization.
         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_checkpointing: If True, use activation checkpointing.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
@@ -909,6 +919,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: float = 1e-6,
         with_encoder_nonlocal_attn: bool = False,
         with_decoder_nonlocal_attn: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_checkpointing: bool = False,
         use_convtranspose: bool = False,
@@ -930,12 +942,14 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             norm_eps,
             with_encoder_nonlocal_attn,
             with_decoder_nonlocal_attn,
-            use_flash_attention,
             use_checkpointing,
             use_convtranspose,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )
 
-        self.encoder = MaisiEncoder(
+        self.encoder: nn.Module = MaisiEncoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
@@ -945,6 +959,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             num_splits=num_splits,
             dim_split=dim_split,
@@ -953,7 +969,7 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             save_mem=save_mem,
         )
 
-        self.decoder = MaisiDecoder(
+        self.decoder: nn.Module = MaisiDecoder(
            spatial_dims=spatial_dims,
            num_channels=num_channels,
            in_channels=latent_channels,
@@ -963,6 +979,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             use_convtranspose=use_convtranspose,
             num_splits=num_splits,
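
Taken together, the changes in this file drop the optional `generative` package in favor of the equivalents that now live in core MONAI (`SpatialAttentionBlock`, `AEKLResBlock`, `AutoencoderKL`) and thread two new attention-layout flags through the encoder, the decoder, and the top-level class. A minimal construction sketch, using only parameter names visible in this diff; the values are illustrative, not MAISI bundle defaults:

```python
from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

# Illustrative values only -- not taken from any MAISI bundle config.
model = AutoencoderKlMaisi(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    latent_channels=4,
    num_channels=(64, 128, 256),
    num_res_blocks=(2, 2, 2),
    norm_num_groups=32,
    norm_eps=1e-6,
    attention_levels=(False, False, True),
    num_splits=8,  # MAISI-specific tensor-splitting options, documented above
    dim_split=1,
    include_fc=False,  # new: whether attention blocks keep their output linear layer
    use_combined_linear=False,  # new: one fused qkv linear vs. separate q/k/v projections
    use_flash_attention=False,
)
```

Because `include_fc` and `use_combined_linear` were inserted into the middle of the signature, callers passing `use_flash_attention` or anything after it positionally will break; keyword arguments are the safe way to construct this class across versions. The two flags plausibly exist so that checkpoints trained with either attention parameter layout can be mapped onto the refactored blocks.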
monai/apps/generation/maisi/networks/controlnet_maisi.py CHANGED
@@ -11,24 +11,15 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import Sequence
 
 import torch
 
-from monai.utils import optional_import
+from monai.networks.nets.controlnet import ControlNet
+from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
 
-ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
-)
 
-if TYPE_CHECKING:
-    from generative.networks.nets.controlnet import ControlNet as ControlNetType
-else:
-    ControlNetType = cast(type, ControlNet)
-
-
-class ControlNetMaisi(ControlNetType):
+class ControlNetMaisi(ControlNet):
     """
     Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
     Diffusion Models" (https://arxiv.org/abs/2302.05543)
@@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType):
         num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
-        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
         conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
         conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
         use_checkpointing: if True, use activation checkpointing to save memory.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
     """
 
     def __init__(
@@ -71,10 +64,12 @@
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
-        use_flash_attention: bool = False,
         conditioning_embedding_in_channels: int = 1,
-        conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
+        conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
         use_checkpointing: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__(
             spatial_dims,
@@ -91,9 +86,11 @@
             cross_attention_dim,
             num_class_embeds,
             upcast_attention,
-            use_flash_attention,
             conditioning_embedding_in_channels,
             conditioning_embedding_num_channels,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )
         self.use_checkpointing = use_checkpointing
 
@@ -105,7 +102,7 @@
         conditioning_scale: float = 1.0,
         context: torch.Tensor | None = None,
         class_labels: torch.Tensor | None = None,
-    ) -> tuple[Sequence[torch.Tensor], torch.Tensor]:
+    ) -> tuple[list[torch.Tensor], torch.Tensor]:
         emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
         h = self._apply_initial_convolution(x)
         if self.use_checkpointing:
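
The same migration pattern applies here: `ControlNetMaisi` now subclasses core MONAI's `ControlNet` directly instead of a type resolved via `optional_import`. Two caller-visible details: `use_flash_attention` moved from the middle of the parameter list to the end, and `conditioning_embedding_num_channels` is no longer `Optional`; the `forward` return annotation also tightens `Sequence[torch.Tensor]` to `list[torch.Tensor]`. A hedged construction sketch; the channel sizes are invented, and any argument not shown is assumed to keep its signature default:

```python
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi

# Keyword arguments shield callers from the parameter reordering in this release.
controlnet = ControlNetMaisi(
    spatial_dims=3,
    in_channels=4,  # invented: latent channels of the paired autoencoder
    conditioning_embedding_in_channels=1,
    conditioning_embedding_num_channels=(16, 32, 96, 256),  # no longer Optional
    use_checkpointing=True,
    include_fc=False,
    use_combined_linear=False,
    use_flash_attention=False,
)
```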
monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py CHANGED
@@ -37,21 +37,15 @@ import torch
 from torch import nn
 
 from monai.networks.blocks import Convolution
-from monai.utils import ensure_tuple_rep, optional_import
-from monai.utils.type_conversion import convert_to_tensor
-
-get_down_block, has_get_down_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_down_block"
-)
-get_mid_block, has_get_mid_block = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_mid_block"
-)
-get_timestep_embedding, has_get_timestep_embedding = optional_import(
-    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
+from monai.networks.nets.diffusion_model_unet import (
+    get_down_block,
+    get_mid_block,
+    get_timestep_embedding,
+    get_up_block,
+    zero_module,
 )
-get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block")
-xformers, has_xformers = optional_import("xformers")
-zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")
+from monai.utils import ensure_tuple_rep
+from monai.utils.type_conversion import convert_to_tensor
 
 __all__ = ["DiffusionModelUNetMaisi"]
 
@@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module):
         cross_attention_dim: Number of context dimensions to use.
         num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
         upcast_attention: If True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
         include_top_region_index_input: If True, use top region index input.
@@ -102,6 +98,8 @@
         cross_attention_dim: int | None = None,
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         dropout_cattn: float = 0.0,
         include_top_region_index_input: bool = False,
@@ -152,9 +150,6 @@
                 "`num_channels`."
             )
 
-        if use_flash_attention and not has_xformers:
-            raise ValueError("use_flash_attention is True but xformers is not installed.")
-
         if use_flash_attention is True and not torch.cuda.is_available():
             raise ValueError(
                 "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
@@ -210,7 +205,6 @@
             input_channel = output_channel
             output_channel = num_channels[i]
             is_final_block = i == len(num_channels) - 1
-
             down_block = get_down_block(
                 spatial_dims=spatial_dims,
                 in_channels=input_channel,
@@ -227,6 +221,8 @@
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
@@ -245,6 +241,8 @@
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             dropout_cattn=dropout_cattn,
         )
@@ -280,6 +278,8 @@
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
+                include_fc=include_fc,
+                use_combined_linear=use_combined_linear,
                 use_flash_attention=use_flash_attention,
                 dropout_cattn=dropout_cattn,
             )
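
Besides threading the same two attention flags through every block builder, this file consolidates five `optional_import` calls into one regular import from `monai.networks.nets.diffusion_model_unet` and deletes the `xformers` availability check: only the CUDA check still guards `use_flash_attention`. That suggests (an inference from this diff, not something it states) that flash attention is now served by PyTorch's own fused kernel inside the underlying attention blocks:

```python
import torch
import torch.nn.functional as F

# PyTorch >= 2.0 ships fused scaled dot-product attention; no xformers needed.
q = k = v = torch.randn(2, 8, 64, 32)  # (batch, heads, tokens, head_dim)
out = F.scaled_dot_product_attention(q, k, v)  # picks a flash/fused kernel when eligible
print(out.shape)  # torch.Size([2, 8, 64, 32])
```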
monai/bundle/config_parser.py CHANGED
@@ -118,7 +118,7 @@ class ConfigParser:
         self.ref_resolver = ReferenceResolver()
         if config is None:
             config = {self.meta_key: {}}
-        self.set(config=config)
+        self.set(config=self.ref_resolver.normalize_meta_id(config))
 
     def __repr__(self):
         return f"{self.config}"
@@ -221,7 +221,7 @@
             if isinstance(conf_, dict) and k not in conf_:
                 conf_[k] = {}
             conf_ = conf_[k if isinstance(conf_, dict) else int(k)]
-        self[ReferenceResolver.normalize_id(id)] = config
+        self[ReferenceResolver.normalize_id(id)] = self.ref_resolver.normalize_meta_id(config)
 
     def update(self, pairs: dict[str, Any]) -> None:
         """
monai/bundle/reference_resolver.py CHANGED
@@ -17,7 +17,7 @@ from collections.abc import Sequence
 from typing import Any, Iterator
 
 from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
-from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY
+from monai.bundle.utils import DEPRECATED_ID_MAPPING, ID_REF_KEY, ID_SEP_KEY
 from monai.utils import allow_missing_reference, look_up_option
 
 __all__ = ["ReferenceResolver"]
@@ -202,6 +202,23 @@ class ReferenceResolver:
         """
         return str(id).replace("#", cls.sep) # backward compatibility `#` is the old separator
 
+    def normalize_meta_id(self, config: Any) -> Any:
+        """
+        Update deprecated identifiers in `config` using `DEPRECATED_ID_MAPPING`.
+        This will replace names that are marked as deprecated with their replacement.
+
+        Args:
+            config: input config to be updated.
+        """
+        if isinstance(config, dict):
+            for _id, _new_id in DEPRECATED_ID_MAPPING.items():
+                if _id in config.keys():
+                    warnings.warn(
+                        f"Detected deprecated name '{_id}' in configuration file, replacing with '{_new_id}'."
+                    )
+                    config[_new_id] = config.pop(_id)
+        return config
+
     @classmethod
     def split_id(cls, id: str | int, last: bool = False) -> list[str]:
         """
monai/bundle/scripts.py CHANGED
@@ -18,6 +18,7 @@ import re
 import warnings
 import zipfile
 from collections.abc import Mapping, Sequence
+from functools import partial
 from pathlib import Path
 from pydoc import locate
 from shutil import copyfile
@@ -217,10 +218,15 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
 
 
 def _download_from_ngc(
-    download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool
+    download_path: Path,
+    filename: str,
+    version: str,
+    prefix: str = "monai_",
+    remove_prefix: str | None = "monai_",
+    progress: bool = True,
 ) -> None:
     # ensure prefix is contained
-    filename = _add_ngc_prefix(filename)
+    filename = _add_ngc_prefix(filename, prefix=prefix)
     url = _get_ngc_bundle_url(model_name=filename, version=version)
     filepath = download_path / f"{filename}_v{version}.zip"
     if remove_prefix:
@@ -231,10 +237,16 @@
 
 
 def _download_from_ngc_private(
-    download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
+    download_path: Path,
+    filename: str,
+    version: str,
+    repo: str,
+    prefix: str = "monai_",
+    remove_prefix: str | None = "monai_",
+    headers: dict | None = None,
 ) -> None:
     # ensure prefix is contained
-    filename = _add_ngc_prefix(filename)
+    filename = _add_ngc_prefix(filename, prefix=prefix)
     request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
     if has_requests:
         headers = {} if headers is None else headers
@@ -491,7 +503,7 @@ def download(
         url: url to download the data. If not `None`, data will be downloaded directly
             and `source` will not be checked.
            If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
-        remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
+        remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all ngc bundles
            have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
            maintain the consistency between these two sources, remove prefix is necessary.
            Therefore, if specified, downloaded folder name will remove the prefix.
@@ -1243,6 +1255,7 @@ def verify_net_in_out(
 
 def _export(
     converter: Callable,
+    saver: Callable,
     parser: ConfigParser,
     net_id: str,
     filepath: str,
@@ -1257,6 +1270,8 @@
     Args:
         converter: a callable object that takes a torch.nn.module and kwargs as input and
             converts the module to another type.
+        saver: a callable object that accepts the converted model to save, a filepath to save to, meta values
+            (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
         parser: a ConfigParser of the bundle to be converted.
         net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
         filepath: filepath to export, if filename has no extension, it becomes `.ts`.
@@ -1296,14 +1311,9 @@
     # add .json extension to all extra files which are always encoded as JSON
     extra_files = {k + ".json": v for k, v in extra_files.items()}
 
-    save_net_with_metadata(
-        jit_obj=net,
-        filename_prefix_or_stream=filepath,
-        include_config_vals=False,
-        append_timestamp=False,
-        meta_values=parser.get().pop("_meta_", None),
-        more_extra_files=extra_files,
-    )
+    meta_values = parser.get().pop("_meta_", None)
+    saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)
+
     logger.info(f"exported to file: {filepath}.")
 
 
@@ -1402,17 +1412,23 @@ def onnx_export(
     input_shape_ = _get_fake_input_shape(parser=parser)
 
     inputs_ = [torch.rand(input_shape_)]
-    net = parser.get_parsed_content(net_id_)
-    if has_ignite:
-        # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
-        Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
-    else:
-        ckpt = torch.load(ckpt_file_)
-        copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])
 
     converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
-    onnx_model = convert_to_onnx(model=net, **converter_kwargs_)
-    onnx.save(onnx_model, filepath_)
+
+    def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:
+        onnx.save(onnx_obj, filename_prefix_or_stream)
+
+    _export(
+        convert_to_onnx,
+        save_onnx,
+        parser,
+        net_id=net_id_,
+        filepath=filepath_,
+        ckpt_file=ckpt_file_,
+        config_file=config_file_,
+        key_in_ckpt=key_in_ckpt_,
+        **converter_kwargs_,
+    )
 
 
 def ckpt_export(
@@ -1533,8 +1549,12 @@
 
     converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
     # Use the given converter to convert a model and save with metadata, config content
+
+    save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
+
     _export(
         convert_to_torchscript,
+        save_ts,
         parser,
         net_id=net_id_,
         filepath=filepath_,
@@ -1704,8 +1724,11 @@ def trt_export(
     }
     converter_kwargs_.update(trt_api_parameters)
 
+    save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
+
     _export(
         convert_to_trt,
+        save_ts,
         parser,
         net_id=net_id_,
         filepath=filepath_,
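
The thread running through this file's changes: `_export` no longer hard-codes `save_net_with_metadata` but receives a `saver` callable, which lets `onnx_export` reuse the same checkpoint-loading and config-bundling path as the TorchScript and TensorRT exporters. Roughly, the contract looks like this:

```python
from functools import partial

from monai.networks import save_net_with_metadata

# The shape of the callable _export now expects:
#     saver(converted_model, filepath, meta_values=..., more_extra_files=...)
# TorchScript and TensorRT exports satisfy it with a partial, per this diff:
save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
# so the saver call inside _export expands to:
#     save_net_with_metadata(net, filepath, include_config_vals=False,
#                            append_timestamp=False, meta_values=..., more_extra_files=...)
```

The ONNX path instead passes the local `save_onnx` wrapper, whose `**kwargs` quietly discards `meta_values` and `more_extra_files`, since `onnx.save` has no metadata sidecar to put them in.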
monai/bundle/utils.py CHANGED
@@ -36,7 +36,7 @@ DEFAULT_METADATA = {
     "monai_version": _conf_values["MONAI"],
     "pytorch_version": str(_conf_values["Pytorch"]).split("+")[0].split("a")[0], # 1.9.0a0+df837d0 or 1.13.0+cu117
     "numpy_version": _conf_values["Numpy"],
-    "optional_packages_version": {},
+    "required_packages_version": {},
     "task": "Describe what the network predicts",
     "description": "A longer description of what the network does, use context, inputs, outputs, etc.",
     "authors": "Your Name Here",
@@ -157,6 +157,8 @@ DEFAULT_MLFLOW_SETTINGS = {
 
 DEFAULT_EXP_MGMT_SETTINGS = {"mlflow": DEFAULT_MLFLOW_SETTINGS} # default experiment management settings
 
+DEPRECATED_ID_MAPPING = {"optional_packages_version": "required_packages_version"}
+
 
 def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any) -> Any:
     """
monai/data/utils.py CHANGED
@@ -927,7 +927,7 @@ def compute_shape_offset(
     corners = in_affine_ @ corners
     all_dist = corners_out[:-1].copy()
     corners_out = corners_out[:-1] / corners_out[-1]
-    out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0)
+    out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0)
     offset = None
     for i in range(corners.shape[1]):
         min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
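
This one-line change swaps the `ndarray.ptp` method for the `np.ptp` free function, which keeps `compute_shape_offset` working on NumPy 2.0, where the method form was removed. The two spellings are equivalent on older NumPy:

```python
import numpy as np

corners_out = np.array([[0.0, 3.0, 1.0], [2.0, 7.0, 4.0]])
extent = np.ptp(corners_out, axis=1)  # peak-to-peak (max - min) per row
assert np.array_equal(extent, corners_out.max(axis=1) - corners_out.min(axis=1))
print(extent)  # [3. 5.]
```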
monai/data/wsi_datasets.py CHANGED
@@ -23,7 +23,7 @@ from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import iter_patch_position
 from monai.data.wsi_reader import BaseWSIReader, WSIReader
 from monai.transforms import ForegroundMask, Randomizable, apply_transform
-from monai.utils import convert_to_dst_type, ensure_tuple_rep
+from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep
 from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys
 
 __all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"]
@@ -123,9 +123,9 @@ class PatchWSIDataset(Dataset):
     def _get_location(self, sample: dict):
         if self.center_location:
             size = self._get_size(sample)
-            return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))]
+            return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size)))
         else:
-            return sample[WSIPatchKeys.LOCATION]
+            return ensure_tuple(sample[WSIPatchKeys.LOCATION])
 
     def _get_level(self, sample: dict):
         if self.patch_level is None:
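
With `ensure_tuple`, `_get_location` now returns a tuple on both branches instead of a list on one and whatever sequence type the sample happened to store on the other, giving downstream readers a consistent type. A quick sketch with made-up values:

```python
from monai.utils import ensure_tuple

location = [512, 768]  # made-up patch location from a sample dict
print(ensure_tuple(location))  # (512, 768)
print(ensure_tuple(v - 128 // 2 for v in location))  # centered variant: (448, 704)
```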