monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (49)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/apps/vista3d/inferer.py +177 -0
  7. monai/apps/vista3d/sampler.py +179 -0
  8. monai/apps/vista3d/transforms.py +224 -0
  9. monai/bundle/scripts.py +29 -17
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/inferers/utils.py +1 -0
  13. monai/losses/__init__.py +1 -0
  14. monai/losses/dice.py +10 -1
  15. monai/losses/nacl_loss.py +139 -0
  16. monai/networks/blocks/crossattention.py +48 -26
  17. monai/networks/blocks/mlp.py +16 -4
  18. monai/networks/blocks/selfattention.py +75 -23
  19. monai/networks/blocks/spatialattention.py +16 -1
  20. monai/networks/blocks/transformerblock.py +17 -2
  21. monai/networks/layers/filtering.py +6 -2
  22. monai/networks/nets/__init__.py +2 -1
  23. monai/networks/nets/autoencoderkl.py +55 -22
  24. monai/networks/nets/cell_sam_wrapper.py +92 -0
  25. monai/networks/nets/controlnet.py +24 -22
  26. monai/networks/nets/diffusion_model_unet.py +159 -19
  27. monai/networks/nets/segresnet_ds.py +127 -1
  28. monai/networks/nets/spade_autoencoderkl.py +22 -0
  29. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  30. monai/networks/nets/transformer.py +17 -17
  31. monai/networks/nets/vista3d.py +946 -0
  32. monai/networks/utils.py +4 -4
  33. monai/transforms/__init__.py +13 -2
  34. monai/transforms/io/array.py +59 -3
  35. monai/transforms/io/dictionary.py +29 -2
  36. monai/transforms/spatial/functional.py +1 -1
  37. monai/transforms/transform.py +2 -2
  38. monai/transforms/utility/dictionary.py +4 -0
  39. monai/transforms/utils.py +230 -1
  40. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  41. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  42. monai/utils/enums.py +1 -0
  43. monai/utils/module.py +7 -6
  44. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
  45. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
  46. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
  47. /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
  48. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
  49. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/losses/nacl_loss.py (new file)
@@ -0,0 +1,139 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.modules.loss import _Loss
+
+ from monai.networks.layers import GaussianFilter, MeanFilter
+
+
+ class NACLLoss(_Loss):
+     """
+     Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
+     NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions
+     to match a soft class proportion of surrounding pixel.
+
+     Murugesan, Balamurali, et al.
+     "Trust your neighbours: Penalty-based constraints for model calibration."
+     International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
+     https://arxiv.org/abs/2303.06268
+
+     Murugesan, Balamurali, et al.
+     "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
+     https://arxiv.org/abs/2401.14487
+     """
+
+     def __init__(
+         self,
+         classes: int,
+         dim: int,
+         kernel_size: int = 3,
+         kernel_ops: str = "mean",
+         distance_type: str = "l1",
+         alpha: float = 0.1,
+         sigma: float = 1.0,
+     ) -> None:
+         """
+         Args:
+             classes: number of classes
+             dim: dimension of data (supports 2d and 3d)
+             kernel_size: size of the spatial kernel
+             distance_type: l1/l2 distance between spatial kernel and predicted logits
+             alpha: weightage between cross entropy and logit constraint
+             sigma: sigma of gaussian
+         """
+
+         super().__init__()
+
+         if kernel_ops not in ["mean", "gaussian"]:
+             raise ValueError("Kernel ops must be either mean or gaussian")
+
+         if dim not in [2, 3]:
+             raise ValueError(f"Support 2d and 3d, got dim={dim}.")
+
+         if distance_type not in ["l1", "l2"]:
+             raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}")
+
+         self.nc = classes
+         self.dim = dim
+         self.cross_entropy = nn.CrossEntropyLoss()
+         self.distance_type = distance_type
+         self.alpha = alpha
+         self.ks = kernel_size
+         self.svls_layer: Any
+
+         if kernel_ops == "mean":
+             self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size)
+             self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim)
+         if kernel_ops == "gaussian":
+             self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma)
+
+     def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
+         """
+         Converts the mask to one hot represenation and is smoothened with the selected spatial filter.
+
+         Args:
+             mask: the shape should be BH[WD].
+
+         Returns:
+             torch.Tensor: the shape would be BNH[WD], N being number of classes.
+         """
+         rmask: torch.Tensor
+
+         if self.dim == 2:
+             oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float()
+             rmask = self.svls_layer(oh_labels)
+
+         if self.dim == 3:
+             oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float()
+             rmask = self.svls_layer(oh_labels)
+
+         return rmask
+
+     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+         """
+         Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.
+
+         Args:
+             inputs: the shape should be BNH[WD], where N is the number of classes.
+             targets: the shape should be BH[WD].
+
+         Returns:
+             torch.Tensor: value of the loss.
+
+         Example:
+             >>> import torch
+             >>> from monai.losses import NACLLoss
+             >>> B, N, H, W = 8, 3, 64, 64
+             >>> input = torch.rand(B, N, H, W)
+             >>> target = torch.randint(0, N, (B, H, W))
+             >>> criterion = NACLLoss(classes = N, dim = 2)
+             >>> loss = criterion(input, target)
+         """
+
+         loss_ce = self.cross_entropy(inputs, targets)
+
+         utargets = self.get_constr_target(targets)
+
+         if self.distance_type == "l1":
+             loss_conf = utargets.sub(inputs).abs_().mean()
+         elif self.distance_type == "l2":
+             loss_conf = utargets.sub(inputs).pow_(2).abs_().mean()
+
+         loss: torch.Tensor = loss_ce + self.alpha * loss_conf
+
+         return loss
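A minimal usage sketch for the 3D case, mirroring the 2D doctest in the docstring above; the shapes and hyper-parameters below are illustrative and not taken from the package:

    import torch
    from monai.losses import NACLLoss

    # illustrative shapes: batch of 2, 4 classes, 32^3 volume
    B, N, H, W, D = 2, 4, 32, 32, 32
    logits = torch.rand(B, N, H, W, D, requires_grad=True)  # network outputs, BNHWD
    labels = torch.randint(0, N, (B, H, W, D))              # integer mask, BHWD

    # use the gaussian neighbourhood filter instead of the default mean filter
    criterion = NACLLoss(classes=N, dim=3, kernel_ops="gaussian", sigma=1.0, alpha=0.1)
    loss = criterion(logits, labels)
    loss.backward()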
monai/networks/blocks/crossattention.py
@@ -17,7 +17,7 @@ import torch
  import torch.nn as nn
 
  from monai.networks.layers.utils import get_rel_pos_embedding_layer
- from monai.utils import optional_import
+ from monai.utils import optional_import, pytorch_after
 
  Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -44,6 +44,7 @@ class CrossAttentionBlock(nn.Module):
          rel_pos_embedding: Optional[str] = None,
          input_size: Optional[Tuple] = None,
          attention_dtype: Optional[torch.dtype] = None,
+         use_flash_attention: bool = False,
      ) -> None:
          """
          Args:
@@ -55,13 +56,15 @@ class CrossAttentionBlock(nn.Module):
              dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
              qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
              save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
-             causal: whether to use causal attention.
-             sequence_length: if causal is True, it is necessary to specify the sequence length.
-             rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
-                 For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
-             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
-                 positional parameter size.
+             causal (bool, optional): whether to use causal attention.
+             sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
+             rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
+                 "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
+                 parameter size.
              attention_dtype: cast attention operations to this dtype.
+             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+                 (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
          """
 
          super().__init__()
@@ -81,6 +84,20 @@ class CrossAttentionBlock(nn.Module):
          if causal and sequence_length is None:
              raise ValueError("sequence_length is necessary for causal attention.")
 
+         if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
+             raise ValueError(
+                 "use_flash_attention is only supported for PyTorch versions >= 2.0."
+                 "Upgrade your PyTorch or set the flag to False."
+             )
+         if use_flash_attention and save_attn:
+             raise ValueError(
+                 "save_attn has been set to True, but use_flash_attention is also set"
+                 "to True. save_attn can only be used if use_flash_attention is False"
+             )
+
+         if use_flash_attention and rel_pos_embedding is not None:
+             raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
          self.num_heads = num_heads
          self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
          self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -91,9 +108,10 @@ class CrossAttentionBlock(nn.Module):
          self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
          self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
 
-         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+         self.out_rearrange = Rearrange("b l h d -> b h (l d)")
          self.drop_output = nn.Dropout(dropout_rate)
          self.drop_weights = nn.Dropout(dropout_rate)
+         self.dropout_rate = dropout_rate
 
          self.scale = self.head_dim**-0.5
          self.save_attn = save_attn
@@ -101,6 +119,7 @@ class CrossAttentionBlock(nn.Module):
 
          self.causal = causal
          self.sequence_length = sequence_length
+         self.use_flash_attention = use_flash_attention
 
          if causal and sequence_length is not None:
              # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -132,36 +151,39 @@ class CrossAttentionBlock(nn.Module):
          # calculate query, key, values for all heads in batch and move head forward to be the batch dim
          b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)
 
-         q = self.to_q(x)
+         q = self.input_rearrange(self.to_q(x))
          kv = context if context is not None else x
          _, kv_t, _ = kv.size()
-         k = self.to_k(kv)
-         v = self.to_v(kv)
+         k = self.input_rearrange(self.to_k(kv))
+         v = self.input_rearrange(self.to_v(kv))
 
          if self.attention_dtype is not None:
              q = q.to(self.attention_dtype)
              k = k.to(self.attention_dtype)
 
-         q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
-         k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
-         v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
-         att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+         if self.use_flash_attention:
+             x = torch.nn.functional.scaled_dot_product_attention(
+                 query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+             )
+         else:
+             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+             # apply relative positional embedding if defined
+             if self.rel_positional_embedding is not None:
+                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
-         # apply relative positional embedding if defined
-         att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+             if self.causal:
+                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
 
-         if self.causal:
-             att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+             att_mat = att_mat.softmax(dim=-1)
 
-         att_mat = att_mat.softmax(dim=-1)
+             if self.save_attn:
+                 # no gradients and new tensor;
+                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                 self.att_mat = att_mat.detach()
 
-         if self.save_attn:
-             # no gradients and new tensor;
-             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-             self.att_mat = att_mat.detach()
+             att_mat = self.drop_weights(att_mat)
+             x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
 
-         att_mat = self.drop_weights(att_mat)
-         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
          x = self.out_rearrange(x)
          x = self.out_proj(x)
          x = self.drop_output(x)
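A short sketch of how the new flag might be used, assuming a PyTorch build that provides torch.nn.functional.scaled_dot_product_attention; the tensor shapes are illustrative:

    import torch
    from monai.networks.blocks.crossattention import CrossAttentionBlock

    x = torch.randn(2, 64, 256)    # (batch, query tokens, hidden_size)
    ctx = torch.randn(2, 77, 256)  # (batch, context tokens, hidden_size)

    # flash attention replaces the explicit einsum/softmax path inside forward()
    block = CrossAttentionBlock(hidden_size=256, num_heads=8, use_flash_attention=True)
    out = block(x, context=ctx)    # -> (2, 64, 256)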
monai/networks/blocks/mlp.py
@@ -11,12 +11,15 @@
 
  from __future__ import annotations
 
+ from typing import Union
+
  import torch.nn as nn
 
  from monai.networks.layers import get_act_layer
+ from monai.networks.layers.factories import split_args
  from monai.utils import look_up_option
 
- SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
+ SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}
 
 
  class MLPBlock(nn.Module):
@@ -39,7 +42,7 @@ class MLPBlock(nn.Module):
                  https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
                  "swin" corresponds to one instance as implemented in
                  https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
-
+                 "vista3d" mode does not use dropout.
 
          """
 
@@ -48,15 +51,24 @@ class MLPBlock(nn.Module):
          if not (0 <= dropout_rate <= 1):
              raise ValueError("dropout_rate should be between 0 and 1.")
          mlp_dim = mlp_dim or hidden_size
-         self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
+         act_name, _ = split_args(act)
+         self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
          self.linear2 = nn.Linear(mlp_dim, hidden_size)
          self.fn = get_act_layer(act)
-         self.drop1 = nn.Dropout(dropout_rate)
+         # Use Union[nn.Dropout, nn.Identity] for type annotations
+         self.drop1: Union[nn.Dropout, nn.Identity]
+         self.drop2: Union[nn.Dropout, nn.Identity]
+
          dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
          if dropout_opt == "vit":
+             self.drop1 = nn.Dropout(dropout_rate)
              self.drop2 = nn.Dropout(dropout_rate)
          elif dropout_opt == "swin":
+             self.drop1 = nn.Dropout(dropout_rate)
              self.drop2 = self.drop1
+         elif dropout_opt == "vista3d":
+             self.drop1 = nn.Identity()
+             self.drop2 = nn.Identity()
          else:
              raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
 
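A brief sketch of the two behaviours this change enables; the shapes and the ("GELU", {"approximate": "tanh"}) activation tuple are illustrative assumptions, not values from the package:

    import torch
    from monai.networks.blocks.mlp import MLPBlock

    x = torch.randn(2, 64, 256)

    # the new "vista3d" mode keeps the block dropout-free regardless of dropout_rate
    mlp = MLPBlock(hidden_size=256, mlp_dim=1024, dropout_rate=0.1, dropout_mode="vista3d")
    y = mlp(x)  # -> (2, 64, 256)

    # act may now be passed as a (name, args) tuple; split_args extracts the name
    # so the GEGLU width check still works for parameterized activations
    mlp_gelu = MLPBlock(hidden_size=256, mlp_dim=1024, act=("GELU", {"approximate": "tanh"}))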
monai/networks/blocks/selfattention.py
@@ -11,13 +11,14 @@
 
  from __future__ import annotations
 
- from typing import Optional, Tuple
+ from typing import Tuple, Union
 
  import torch
  import torch.nn as nn
+ import torch.nn.functional as F
 
  from monai.networks.layers.utils import get_rel_pos_embedding_layer
- from monai.utils import optional_import
+ from monai.utils import optional_import, pytorch_after
 
  Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -39,9 +40,12 @@ class SABlock(nn.Module):
          hidden_input_size: int | None = None,
          causal: bool = False,
          sequence_length: int | None = None,
-         rel_pos_embedding: Optional[str] = None,
-         input_size: Optional[Tuple] = None,
-         attention_dtype: Optional[torch.dtype] = None,
+         rel_pos_embedding: str | None = None,
+         input_size: Tuple | None = None,
+         attention_dtype: torch.dtype | None = None,
+         include_fc: bool = True,
+         use_combined_linear: bool = True,
+         use_flash_attention: bool = False,
      ) -> None:
          """
          Args:
@@ -59,6 +63,10 @@ class SABlock(nn.Module):
              input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                  positional parameter size.
              attention_dtype: cast attention operations to this dtype.
+             include_fc: whether to include the final linear layer. Default to True.
+             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+                 (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
 
          """
 
@@ -82,21 +90,52 @@ class SABlock(nn.Module):
          if causal and sequence_length is None:
              raise ValueError("sequence_length is necessary for causal attention.")
 
+         if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
+             raise ValueError(
+                 "use_flash_attention is only supported for PyTorch versions >= 2.0."
+                 "Upgrade your PyTorch or set the flag to False."
+             )
+         if use_flash_attention and save_attn:
+             raise ValueError(
+                 "save_attn has been set to True, but use_flash_attention is also set"
+                 "to True. save_attn can only be used if use_flash_attention is False."
+             )
+
+         if use_flash_attention and rel_pos_embedding is not None:
+             raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
          self.num_heads = num_heads
          self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
          self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
 
-         self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
-         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
-         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+         self.qkv: Union[nn.Linear, nn.Identity]
+         self.to_q: Union[nn.Linear, nn.Identity]
+         self.to_k: Union[nn.Linear, nn.Identity]
+         self.to_v: Union[nn.Linear, nn.Identity]
+
+         if use_combined_linear:
+             self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
+             self.to_q = self.to_k = self.to_v = nn.Identity()  # add to enable torchscript
+             self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
+         else:
+             self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+             self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+             self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+             self.qkv = nn.Identity()  # add to enable torchscript
+             self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
+         self.out_rearrange = Rearrange("b l h d -> b h (l d)")
          self.drop_output = nn.Dropout(dropout_rate)
          self.drop_weights = nn.Dropout(dropout_rate)
+         self.dropout_rate = dropout_rate
          self.scale = self.dim_head**-0.5
          self.save_attn = save_attn
          self.att_mat = torch.Tensor()
          self.attention_dtype = attention_dtype
          self.causal = causal
          self.sequence_length = sequence_length
+         self.include_fc = include_fc
+         self.use_combined_linear = use_combined_linear
+         self.use_flash_attention = use_flash_attention
 
          if causal and sequence_length is not None:
              # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -123,31 +162,44 @@ class SABlock(nn.Module):
          Return:
              torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
          """
-         output = self.input_rearrange(self.qkv(x))
-         q, k, v = output[0], output[1], output[2]
+         if self.use_combined_linear:
+             output = self.input_rearrange(self.qkv(x))
+             q, k, v = output[0], output[1], output[2]
+         else:
+             q = self.input_rearrange(self.to_q(x))
+             k = self.input_rearrange(self.to_k(x))
+             v = self.input_rearrange(self.to_v(x))
 
          if self.attention_dtype is not None:
              q = q.to(self.attention_dtype)
             k = k.to(self.attention_dtype)
 
-         att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+         if self.use_flash_attention:
+             x = F.scaled_dot_product_attention(
+                 query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+             )
+         else:
+             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+             # apply relative positional embedding if defined
+             if self.rel_positional_embedding is not None:
+                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
-         # apply relative positional embedding if defined
-         att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+             if self.causal:
+                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-         if self.causal:
-             att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
+             att_mat = att_mat.softmax(dim=-1)
 
-         att_mat = att_mat.softmax(dim=-1)
+             if self.save_attn:
+                 # no gradients and new tensor;
+                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                 self.att_mat = att_mat.detach()
 
-         if self.save_attn:
-             # no gradients and new tensor;
-             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-             self.att_mat = att_mat.detach()
+             att_mat = self.drop_weights(att_mat)
+             x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
 
-         att_mat = self.drop_weights(att_mat)
-         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
          x = self.out_rearrange(x)
-         x = self.out_proj(x)
+         if self.include_fc:
+             x = self.out_proj(x)
          x = self.drop_output(x)
          return x
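A minimal sketch combining the new SABlock options; the hidden size and token count are illustrative assumptions:

    import torch
    from monai.networks.blocks.selfattention import SABlock

    x = torch.randn(2, 196, 384)  # (batch, tokens, hidden_size)

    # separate q/k/v projections, skip the final projection layer, use flash attention
    attn = SABlock(
        hidden_size=384,
        num_heads=6,
        use_combined_linear=False,
        include_fc=False,
        use_flash_attention=True,
    )
    y = attn(x)  # -> (2, 196, 384)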
monai/networks/blocks/spatialattention.py
@@ -32,7 +32,13 @@ class SpatialAttentionBlock(nn.Module):
          spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
          num_channels: number of input channels. Must be divisible by num_head_channels.
          num_head_channels: number of channels per head.
+         norm_num_groups: Number of groups for the group norm layer.
+         norm_eps: Epsilon for the normalization.
          attention_dtype: cast attention operations to this dtype.
+         include_fc: whether to include the final linear layer. Default to True.
+         use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+         use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+             (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
 
      """
 
@@ -44,6 +50,9 @@ class SpatialAttentionBlock(nn.Module):
          norm_num_groups: int = 32,
          norm_eps: float = 1e-6,
          attention_dtype: Optional[torch.dtype] = None,
+         include_fc: bool = True,
+         use_combined_linear: bool = False,
+         use_flash_attention: bool = False,
      ) -> None:
          super().__init__()
 
@@ -54,7 +63,13 @@ class SpatialAttentionBlock(nn.Module):
              raise ValueError("num_channels must be divisible by num_head_channels")
          num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
          self.attn = SABlock(
-             hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
+             hidden_size=num_channels,
+             num_heads=num_heads,
+             qkv_bias=True,
+             attention_dtype=attention_dtype,
+             include_fc=include_fc,
+             use_combined_linear=use_combined_linear,
+             use_flash_attention=use_flash_attention,
          )
 
      def forward(self, x: torch.Tensor):
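A small illustrative sketch of the block with the forwarded options; the channel and grid sizes are assumptions:

    import torch
    from monai.networks.blocks.spatialattention import SpatialAttentionBlock

    x = torch.randn(1, 64, 32, 32, 32)  # (batch, channels, D, H, W)

    block = SpatialAttentionBlock(
        spatial_dims=3,
        num_channels=64,
        num_head_channels=32,      # 2 heads
        use_flash_attention=True,  # passed straight through to the inner SABlock
    )
    y = block(x)  # residual attention over the flattened spatial grid, same shape as x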
monai/networks/blocks/transformerblock.py
@@ -36,6 +36,9 @@ class TransformerBlock(nn.Module):
          causal: bool = False,
          sequence_length: int | None = None,
          with_cross_attention: bool = False,
+         use_flash_attention: bool = False,
+         include_fc: bool = True,
+         use_combined_linear: bool = True,
      ) -> None:
          """
          Args:
@@ -43,8 +46,12 @@ class TransformerBlock(nn.Module):
              mlp_dim (int): dimension of feedforward layer.
              num_heads (int): number of attention heads.
              dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
-             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
+             qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
              save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+             use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+                 (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+             include_fc: whether to include the final linear layer. Default to True.
+             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
 
          """
 
@@ -66,13 +73,21 @@ class TransformerBlock(nn.Module):
              save_attn=save_attn,
              causal=causal,
              sequence_length=sequence_length,
+             include_fc=include_fc,
+             use_combined_linear=use_combined_linear,
+             use_flash_attention=use_flash_attention,
          )
          self.norm2 = nn.LayerNorm(hidden_size)
          self.with_cross_attention = with_cross_attention
 
          self.norm_cross_attn = nn.LayerNorm(hidden_size)
          self.cross_attn = CrossAttentionBlock(
-             hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+             hidden_size=hidden_size,
+             num_heads=num_heads,
+             dropout_rate=dropout_rate,
+             qkv_bias=qkv_bias,
+             causal=False,
+             use_flash_attention=use_flash_attention,
          )
 
      def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
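A short sketch showing the flag being propagated from TransformerBlock to both its self-attention and cross-attention sub-blocks; shapes are illustrative:

    import torch
    from monai.networks.blocks.transformerblock import TransformerBlock

    x = torch.randn(2, 100, 256)   # token sequence
    ctx = torch.randn(2, 77, 256)  # conditioning sequence

    block = TransformerBlock(
        hidden_size=256,
        mlp_dim=1024,
        num_heads=8,
        with_cross_attention=True,
        use_flash_attention=True,  # forwarded to both SABlock and CrossAttentionBlock
    )
    y = block(x, context=ctx)      # -> (2, 100, 256)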
monai/networks/layers/filtering.py
@@ -51,6 +51,8 @@ class BilateralFilter(torch.autograd.Function):
          ctx.cs = color_sigma
          ctx.fa = fast_approx
          output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
          return output_data
 
      @staticmethod
@@ -139,7 +141,8 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):
              do_dsig_y,
              do_dsig_z,
          )
-
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
          return output_tensor
 
      @staticmethod
@@ -301,7 +304,8 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):
              do_dsig_z,
              guidance_img,
          )
-
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
          return output_tensor
 
      @staticmethod
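These filtering changes insert an explicit device synchronization after each call into the compiled C++/CUDA extension, so the returned tensor is fully materialized before autograd continues. A minimal sketch of that guard pattern in isolation; run_custom_kernel below is a hypothetical stand-in, not a MONAI API:

    import torch

    def call_and_sync(run_custom_kernel, *args):
        """Run a potentially asynchronous CUDA kernel and block until its result is ready."""
        output = run_custom_kernel(*args)
        # same guard added in filtering.py: only synchronize when CUDA is present
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return output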
monai/networks/nets/__init__.py
@@ -76,7 +76,7 @@ from .resnet import (
      resnet200,
  )
  from .segresnet import SegResNet, SegResNetVAE
- from .segresnet_ds import SegResNetDS
+ from .segresnet_ds import SegResNetDS, SegResNetDS2
  from .senet import (
      SENet,
      SEnet,
@@ -118,6 +118,7 @@ from .transformer import DecoderOnlyTransformer
  from .unet import UNet, Unet
  from .unetr import UNETR
  from .varautoencoder import VarAutoEncoder
+ from .vista3d import VISTA3D, vista3d132
  from .vit import ViT
  from .vitautoenc import ViTAutoEnc
  from .vnet import VNet
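A quick smoke test that the new re-exports resolve at package level; constructor arguments are intentionally omitted, since the corresponding vista3d.py and segresnet_ds.py diffs are not shown in full here:

    # the new symbols are available directly from monai.networks.nets
    from monai.networks.nets import VISTA3D, SegResNetDS2, vista3d132

    print(VISTA3D, SegResNetDS2, vista3d132)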