monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/scripts.py +29 -17
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/inferers/utils.py +1 -0
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +22 -0
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +946 -0
- monai/networks/utils.py +4 -4
- monai/transforms/__init__.py +13 -2
- monai/transforms/io/array.py +59 -3
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utility/dictionary.py +4 -0
- monai/transforms/utils.py +230 -1
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/enums.py +1 -0
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/losses/nacl_loss.py
ADDED
@@ -0,0 +1,139 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+
+from monai.networks.layers import GaussianFilter, MeanFilter
+
+
+class NACLLoss(_Loss):
+    """
+    Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
+    NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions
+    to match a soft class proportion of surrounding pixel.
+
+    Murugesan, Balamurali, et al.
+    "Trust your neighbours: Penalty-based constraints for model calibration."
+    International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
+    https://arxiv.org/abs/2303.06268
+
+    Murugesan, Balamurali, et al.
+    "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
+    https://arxiv.org/abs/2401.14487
+    """
+
+    def __init__(
+        self,
+        classes: int,
+        dim: int,
+        kernel_size: int = 3,
+        kernel_ops: str = "mean",
+        distance_type: str = "l1",
+        alpha: float = 0.1,
+        sigma: float = 1.0,
+    ) -> None:
+        """
+        Args:
+            classes: number of classes
+            dim: dimension of data (supports 2d and 3d)
+            kernel_size: size of the spatial kernel
+            distance_type: l1/l2 distance between spatial kernel and predicted logits
+            alpha: weightage between cross entropy and logit constraint
+            sigma: sigma of gaussian
+        """
+
+        super().__init__()
+
+        if kernel_ops not in ["mean", "gaussian"]:
+            raise ValueError("Kernel ops must be either mean or gaussian")
+
+        if dim not in [2, 3]:
+            raise ValueError(f"Support 2d and 3d, got dim={dim}.")
+
+        if distance_type not in ["l1", "l2"]:
+            raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}")
+
+        self.nc = classes
+        self.dim = dim
+        self.cross_entropy = nn.CrossEntropyLoss()
+        self.distance_type = distance_type
+        self.alpha = alpha
+        self.ks = kernel_size
+        self.svls_layer: Any
+
+        if kernel_ops == "mean":
+            self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size)
+            self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim)
+        if kernel_ops == "gaussian":
+            self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma)
+
+    def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
+        """
+        Converts the mask to one hot represenation and is smoothened with the selected spatial filter.
+
+        Args:
+            mask: the shape should be BH[WD].
+
+        Returns:
+            torch.Tensor: the shape would be BNH[WD], N being number of classes.
+        """
+        rmask: torch.Tensor
+
+        if self.dim == 2:
+            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float()
+            rmask = self.svls_layer(oh_labels)
+
+        if self.dim == 3:
+            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float()
+            rmask = self.svls_layer(oh_labels)
+
+        return rmask
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        """
+        Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.
+
+        Args:
+            inputs: the shape should be BNH[WD], where N is the number of classes.
+            targets: the shape should be BH[WD].
+
+        Returns:
+            torch.Tensor: value of the loss.
+
+        Example:
+            >>> import torch
+            >>> from monai.losses import NACLLoss
+            >>> B, N, H, W = 8, 3, 64, 64
+            >>> input = torch.rand(B, N, H, W)
+            >>> target = torch.randint(0, N, (B, H, W))
+            >>> criterion = NACLLoss(classes = N, dim = 2)
+            >>> loss = criterion(input, target)
+        """
+
+        loss_ce = self.cross_entropy(inputs, targets)
+
+        utargets = self.get_constr_target(targets)
+
+        if self.distance_type == "l1":
+            loss_conf = utargets.sub(inputs).abs_().mean()
+        elif self.distance_type == "l2":
+            loss_conf = utargets.sub(inputs).pow_(2).abs_().mean()
+
+        loss: torch.Tensor = loss_ce + self.alpha * loss_conf
+
+        return loss
monai/networks/blocks/crossattention.py
CHANGED
@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -44,6 +44,7 @@ class CrossAttentionBlock(nn.Module):
         rel_pos_embedding: Optional[str] = None,
         input_size: Optional[Tuple] = None,
         attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -55,13 +56,15 @@ class CrossAttentionBlock(nn.Module):
            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
-           causal: whether to use causal attention.
-           sequence_length: if causal is True, it is necessary to specify the sequence length.
-           rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
-               For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
-           input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
-               positional parameter size.
+           causal (bool, optional): whether to use causal attention.
+           sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
+           rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
+               "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+           input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
+               parameter size.
            attention_dtype: cast attention operations to this dtype.
+           use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+               (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
        """
 
        super().__init__()
@@ -81,6 +84,20 @@ class CrossAttentionBlock(nn.Module):
        if causal and sequence_length is None:
            raise ValueError("sequence_length is necessary for causal attention.")
 
+       if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
+           raise ValueError(
+               "use_flash_attention is only supported for PyTorch versions >= 2.0."
+               "Upgrade your PyTorch or set the flag to False."
+           )
+       if use_flash_attention and save_attn:
+           raise ValueError(
+               "save_attn has been set to True, but use_flash_attention is also set"
+               "to True. save_attn can only be used if use_flash_attention is False"
+           )
+
+       if use_flash_attention and rel_pos_embedding is not None:
+           raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
        self.num_heads = num_heads
        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
        self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -91,9 +108,10 @@ class CrossAttentionBlock(nn.Module):
        self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
        self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
 
-       self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+       self.out_rearrange = Rearrange("b l h d -> b h (l d)")
        self.drop_output = nn.Dropout(dropout_rate)
        self.drop_weights = nn.Dropout(dropout_rate)
+       self.dropout_rate = dropout_rate
 
        self.scale = self.head_dim**-0.5
        self.save_attn = save_attn
@@ -101,6 +119,7 @@ class CrossAttentionBlock(nn.Module):
 
        self.causal = causal
        self.sequence_length = sequence_length
+       self.use_flash_attention = use_flash_attention
 
        if causal and sequence_length is not None:
            # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -132,36 +151,39 @@ class CrossAttentionBlock(nn.Module):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)
 
-       q = self.to_q(x)
+       q = self.input_rearrange(self.to_q(x))
        kv = context if context is not None else x
        _, kv_t, _ = kv.size()
-       k = self.to_k(kv)
-       v = self.to_v(kv)
+       k = self.input_rearrange(self.to_k(kv))
+       v = self.input_rearrange(self.to_v(kv))
 
        if self.attention_dtype is not None:
            q = q.to(self.attention_dtype)
            k = k.to(self.attention_dtype)
 
-       q = self.input_rearrange(q)
-       k = self.input_rearrange(k)
-       v = self.input_rearrange(v)
-       att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+       if self.use_flash_attention:
+           x = torch.nn.functional.scaled_dot_product_attention(
+               query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+           )
+       else:
+           att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+           # apply relative positional embedding if defined
+           if self.rel_positional_embedding is not None:
+               att_mat = self.rel_positional_embedding(x, att_mat, q)
 
-       if self.rel_positional_embedding is not None:
-           att_mat = self.rel_positional_embedding(x, att_mat, q)
+           if self.causal:
+               att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
 
-       if self.causal:
-           att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+           att_mat = att_mat.softmax(dim=-1)
 
-       att_mat = att_mat.softmax(dim=-1)
+           if self.save_attn:
+               # no gradients and new tensor;
+               # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+               self.att_mat = att_mat.detach()
 
-       if self.save_attn:
-           # no gradients and new tensor;
-           # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-           self.att_mat = att_mat.detach()
+           att_mat = self.drop_weights(att_mat)
+           x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
 
-       att_mat = self.drop_weights(att_mat)
-       x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
        x = self.out_rearrange(x)
        x = self.out_proj(x)
        x = self.drop_output(x)
monai/networks/blocks/mlp.py
CHANGED
@@ -11,12 +11,15 @@
 
 from __future__ import annotations
 
+from typing import Union
+
 import torch.nn as nn
 
 from monai.networks.layers import get_act_layer
+from monai.networks.layers.factories import split_args
 from monai.utils import look_up_option
 
-SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
+SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}
 
 
 class MLPBlock(nn.Module):
@@ -39,7 +42,7 @@ class MLPBlock(nn.Module):
                https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
                "swin" corresponds to one instance as implemented in
                https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
-
+               "vista3d" mode does not use dropout.
 
        """
@@ -48,15 +51,24 @@ class MLPBlock(nn.Module):
        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        mlp_dim = mlp_dim or hidden_size
-       self.linear1 = nn.Linear(hidden_size, mlp_dim)
+       act_name, _ = split_args(act)
+       self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
        self.linear2 = nn.Linear(mlp_dim, hidden_size)
        self.fn = get_act_layer(act)
-       self.drop1 = nn.Dropout(dropout_rate)
+       # Use Union[nn.Dropout, nn.Identity] for type annotations
+       self.drop1: Union[nn.Dropout, nn.Identity]
+       self.drop2: Union[nn.Dropout, nn.Identity]
+
        dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
        if dropout_opt == "vit":
+           self.drop1 = nn.Dropout(dropout_rate)
            self.drop2 = nn.Dropout(dropout_rate)
        elif dropout_opt == "swin":
+           self.drop1 = nn.Dropout(dropout_rate)
            self.drop2 = self.drop1
+       elif dropout_opt == "vista3d":
+           self.drop1 = nn.Identity()
+           self.drop2 = nn.Identity()
        else:
            raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
 
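Below is a small, hedged sketch of the new "vista3d" dropout mode, in which both dropout slots become nn.Identity so the configured dropout_rate has no effect; the remaining constructor arguments are assumed to keep their existing names (hidden_size, mlp_dim, dropout_rate, act, dropout_mode).

```python
import torch
from monai.networks.blocks.mlp import MLPBlock

# "vista3d" mode: drop1/drop2 are nn.Identity, so dropout_rate is effectively ignored
mlp = MLPBlock(hidden_size=256, mlp_dim=1024, dropout_rate=0.1, dropout_mode="vista3d")
y = mlp(torch.randn(2, 16, 256))  # (batch, tokens, hidden_size) in and out
```

The same constructor change also doubles the output width of linear1 when split_args(act) reports a "GEGLU" activation, so that the gated activation can halve it again before linear2.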
monai/networks/blocks/selfattention.py
CHANGED
@@ -11,13 +11,14 @@
 
 from __future__ import annotations
 
-from typing import Optional, Tuple
+from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
@@ -39,9 +40,12 @@ class SABlock(nn.Module):
        hidden_input_size: int | None = None,
        causal: bool = False,
        sequence_length: int | None = None,
-       rel_pos_embedding: Optional[str] = None,
-       input_size: Optional[Tuple] = None,
-       attention_dtype: Optional[torch.dtype] = None,
+       rel_pos_embedding: str | None = None,
+       input_size: Tuple | None = None,
+       attention_dtype: torch.dtype | None = None,
+       include_fc: bool = True,
+       use_combined_linear: bool = True,
+       use_flash_attention: bool = False,
    ) -> None:
        """
        Args:
@@ -59,6 +63,10 @@ class SABlock(nn.Module):
            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                positional parameter size.
            attention_dtype: cast attention operations to this dtype.
+           include_fc: whether to include the final linear layer. Default to True.
+           use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+           use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+               (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
 
        """
 
@@ -82,21 +90,52 @@ class SABlock(nn.Module):
        if causal and sequence_length is None:
            raise ValueError("sequence_length is necessary for causal attention.")
 
+       if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
+           raise ValueError(
+               "use_flash_attention is only supported for PyTorch versions >= 2.0."
+               "Upgrade your PyTorch or set the flag to False."
+           )
+       if use_flash_attention and save_attn:
+           raise ValueError(
+               "save_attn has been set to True, but use_flash_attention is also set"
+               "to True. save_attn can only be used if use_flash_attention is False."
+           )
+
+       if use_flash_attention and rel_pos_embedding is not None:
+           raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
        self.num_heads = num_heads
        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
        self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
 
-       self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
-       self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
-       self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+       self.qkv: Union[nn.Linear, nn.Identity]
+       self.to_q: Union[nn.Linear, nn.Identity]
+       self.to_k: Union[nn.Linear, nn.Identity]
+       self.to_v: Union[nn.Linear, nn.Identity]
+
+       if use_combined_linear:
+           self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
+           self.to_q = self.to_k = self.to_v = nn.Identity()  # add to enable torchscript
+           self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
+       else:
+           self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+           self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+           self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+           self.qkv = nn.Identity()  # add to enable torchscript
+           self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
+       self.out_rearrange = Rearrange("b l h d -> b h (l d)")
        self.drop_output = nn.Dropout(dropout_rate)
        self.drop_weights = nn.Dropout(dropout_rate)
+       self.dropout_rate = dropout_rate
        self.scale = self.dim_head**-0.5
        self.save_attn = save_attn
        self.att_mat = torch.Tensor()
        self.attention_dtype = attention_dtype
        self.causal = causal
        self.sequence_length = sequence_length
+       self.include_fc = include_fc
+       self.use_combined_linear = use_combined_linear
+       self.use_flash_attention = use_flash_attention
 
        if causal and sequence_length is not None:
            # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -123,31 +162,44 @@ class SABlock(nn.Module):
        Return:
            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
        """
-       output = self.input_rearrange(self.qkv(x))
-       q, k, v = output[0], output[1], output[2]
+       if self.use_combined_linear:
+           output = self.input_rearrange(self.qkv(x))
+           q, k, v = output[0], output[1], output[2]
+       else:
+           q = self.input_rearrange(self.to_q(x))
+           k = self.input_rearrange(self.to_k(x))
+           v = self.input_rearrange(self.to_v(x))
 
        if self.attention_dtype is not None:
            q = q.to(self.attention_dtype)
            k = k.to(self.attention_dtype)
 
-       att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+       if self.use_flash_attention:
+           x = F.scaled_dot_product_attention(
+               query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+           )
+       else:
+           att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+           # apply relative positional embedding if defined
+           if self.rel_positional_embedding is not None:
+               att_mat = self.rel_positional_embedding(x, att_mat, q)
 
-       if self.rel_positional_embedding is not None:
-           att_mat = self.rel_positional_embedding(x, att_mat, q)
+           if self.causal:
+               att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-       if self.causal:
-           att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
+           att_mat = att_mat.softmax(dim=-1)
 
-       att_mat = att_mat.softmax(dim=-1)
+           if self.save_attn:
+               # no gradients and new tensor;
+               # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+               self.att_mat = att_mat.detach()
 
-       if self.save_attn:
-           # no gradients and new tensor;
-           # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-           self.att_mat = att_mat.detach()
+           att_mat = self.drop_weights(att_mat)
+           x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
 
-       att_mat = self.drop_weights(att_mat)
-       x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
        x = self.out_rearrange(x)
-       x = self.out_proj(x)
+       if self.include_fc:
+           x = self.out_proj(x)
        x = self.drop_output(x)
        return x
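A minimal sketch of the new SABlock options follows: separate q/k/v projections (use_combined_linear=False), PyTorch's scaled_dot_product_attention (use_flash_attention=True), and the include_fc switch. The keyword names come from the constructor shown above; the sizes are arbitrary and a recent PyTorch is assumed.

```python
import torch
from monai.networks.blocks.selfattention import SABlock

attn = SABlock(
    hidden_size=128,
    num_heads=8,
    use_combined_linear=False,  # three nn.Linear projections instead of one fused qkv layer
    use_flash_attention=True,   # route through F.scaled_dot_product_attention
    include_fc=True,            # keep the final out_proj linear layer
)
y = attn(torch.randn(2, 64, 128))  # (batch, sequence length, hidden_size), shape preserved
```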
monai/networks/blocks/spatialattention.py
CHANGED
@@ -32,7 +32,13 @@ class SpatialAttentionBlock(nn.Module):
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        num_channels: number of input channels. Must be divisible by num_head_channels.
        num_head_channels: number of channels per head.
+       norm_num_groups: Number of groups for the group norm layer.
+       norm_eps: Epsilon for the normalization.
        attention_dtype: cast attention operations to this dtype.
+       include_fc: whether to include the final linear layer. Default to True.
+       use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+       use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+           (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
 
    """
 
@@ -44,6 +50,9 @@ class SpatialAttentionBlock(nn.Module):
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        attention_dtype: Optional[torch.dtype] = None,
+       include_fc: bool = True,
+       use_combined_linear: bool = False,
+       use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
 
@@ -54,7 +63,13 @@ class SpatialAttentionBlock(nn.Module):
            raise ValueError("num_channels must be divisible by num_head_channels")
        num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
        self.attn = SABlock(
-           hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
+           hidden_size=num_channels,
+           num_heads=num_heads,
+           qkv_bias=True,
+           attention_dtype=attention_dtype,
+           include_fc=include_fc,
+           use_combined_linear=use_combined_linear,
+           use_flash_attention=use_flash_attention,
        )
 
    def forward(self, x: torch.Tensor):
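The new flags are simply forwarded to the inner SABlock, as the constructor above shows. A rough usage sketch, with illustrative sizes and the same PyTorch-version assumption as before:

```python
import torch
from monai.networks.blocks.spatialattention import SpatialAttentionBlock

block = SpatialAttentionBlock(
    spatial_dims=3,
    num_channels=64,
    num_head_channels=8,       # 64 // 8 = 8 attention heads
    use_combined_linear=False,
    use_flash_attention=True,
)
y = block(torch.randn(1, 64, 16, 16, 16))  # channel-first 3D feature map, shape preserved
```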
monai/networks/blocks/transformerblock.py
CHANGED
@@ -36,6 +36,9 @@ class TransformerBlock(nn.Module):
        causal: bool = False,
        sequence_length: int | None = None,
        with_cross_attention: bool = False,
+       use_flash_attention: bool = False,
+       include_fc: bool = True,
+       use_combined_linear: bool = True,
    ) -> None:
        """
        Args:
@@ -43,8 +46,12 @@ class TransformerBlock(nn.Module):
            mlp_dim (int): dimension of feedforward layer.
            num_heads (int): number of attention heads.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
-           qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
+           qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+           use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+               (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+           include_fc: whether to include the final linear layer. Default to True.
+           use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
 
        """
 
@@ -66,13 +73,21 @@ class TransformerBlock(nn.Module):
            save_attn=save_attn,
            causal=causal,
            sequence_length=sequence_length,
+           include_fc=include_fc,
+           use_combined_linear=use_combined_linear,
+           use_flash_attention=use_flash_attention,
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.with_cross_attention = with_cross_attention
 
        self.norm_cross_attn = nn.LayerNorm(hidden_size)
        self.cross_attn = CrossAttentionBlock(
-           hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+           hidden_size=hidden_size,
+           num_heads=num_heads,
+           dropout_rate=dropout_rate,
+           qkv_bias=qkv_bias,
+           causal=False,
+           use_flash_attention=use_flash_attention,
        )
 
    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
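A sketch of a TransformerBlock built with the new pass-through flags; per the constructor above, use_flash_attention reaches both the self-attention and the cross-attention sub-blocks, while include_fc and use_combined_linear only affect the SABlock. Sizes below are illustrative.

```python
import torch
from monai.networks.blocks.transformerblock import TransformerBlock

blk = TransformerBlock(
    hidden_size=128,
    mlp_dim=512,
    num_heads=4,
    with_cross_attention=True,
    use_flash_attention=True,   # forwarded to SABlock and CrossAttentionBlock
    use_combined_linear=False,  # affects only the SABlock qkv projection
)
x = torch.randn(2, 64, 128)
ctx = torch.randn(2, 16, 128)
y = blk(x, context=ctx)  # (2, 64, 128)
```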
monai/networks/layers/filtering.py
CHANGED
@@ -51,6 +51,8 @@ class BilateralFilter(torch.autograd.Function):
        ctx.cs = color_sigma
        ctx.fa = fast_approx
        output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
+       if torch.cuda.is_available():
+           torch.cuda.synchronize()
        return output_data
 
    @staticmethod
@@ -139,7 +141,8 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):
            do_dsig_y,
            do_dsig_z,
        )
-
+       if torch.cuda.is_available():
+           torch.cuda.synchronize()
        return output_tensor
 
    @staticmethod
@@ -301,7 +304,8 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):
            do_dsig_z,
            guidance_img,
        )
-
+       if torch.cuda.is_available():
+           torch.cuda.synchronize()
        return output_tensor
 
    @staticmethod
monai/networks/nets/__init__.py
CHANGED
@@ -76,7 +76,7 @@ from .resnet import (
    resnet200,
 )
 from .segresnet import SegResNet, SegResNetVAE
-from .segresnet_ds import SegResNetDS
+from .segresnet_ds import SegResNetDS, SegResNetDS2
 from .senet import (
    SENet,
    SEnet,
@@ -118,6 +118,7 @@ from .transformer import DecoderOnlyTransformer
 from .unet import UNet, Unet
 from .unetr import UNETR
 from .varautoencoder import VarAutoEncoder
+from .vista3d import VISTA3D, vista3d132
 from .vit import ViT
 from .vitautoenc import ViTAutoEnc
 from .vnet import VNet