monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/metrics/cumulative_average.py
CHANGED
@@ -65,6 +65,7 @@ class CumulativeAverage:
         if self.val is None:
             return 0
 
+        val: NdarrayOrTensor
         val = self.val.clone()
         val[~torch.isfinite(val)] = 0
 
@@ -96,6 +97,7 @@ class CumulativeAverage:
             dist.all_reduce(sum)
             dist.all_reduce(count)
 
+        val: NdarrayOrTensor
         val = torch.where(count > 0, sum / count, sum)
 
         if to_numpy:
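The hunk above only adds `NdarrayOrTensor` annotations inside `get_current` and `aggregate`. For orientation, a minimal usage sketch of the `CumulativeAverage` API those methods belong to (assuming the standard `monai.metrics.CumulativeAverage` interface; values are illustrative):

```python
import torch
from monai.metrics import CumulativeAverage

avg = CumulativeAverage()
avg.append(torch.tensor(0.8))           # per-iteration value
avg.append(torch.tensor(0.6), count=2)  # weighted contribution
print(avg.get_current())  # most recently appended value (numpy by default)
print(avg.aggregate())    # overall average: (0.8 + 0.6 * 2) / 3 ≈ 0.667
```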
monai/metrics/panoptic_quality.py
CHANGED
@@ -274,7 +274,7 @@ def _get_paired_iou(
 
         return paired_iou, paired_true, paired_pred
 
-    pairwise_iou = pairwise_iou.cpu().numpy()
+    pairwise_iou = pairwise_iou.cpu().numpy()  # type: ignore[assignment]
    paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
     paired_iou = pairwise_iou[paired_true, paired_pred]
     paired_true = torch.as_tensor(list(paired_true[paired_iou > match_iou_threshold] + 1), device=device)
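The changed line feeds the IoU matrix to SciPy's optimal-assignment solver; the `# type: ignore` only acknowledges the tensor-to-ndarray reassignment. A small standalone sketch of that matching step (toy IoU values, not taken from the diff):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# maximise total IoU by minimising its negation (Hungarian / optimal assignment)
pairwise_iou = np.array([[0.9, 0.1], [0.2, 0.7]])
paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
print(paired_true, paired_pred)                # [0 1] [0 1]
print(pairwise_iou[paired_true, paired_pred])  # [0.9 0.7]
```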
monai/metrics/rocauc.py
CHANGED
@@ -88,8 +88,8 @@ def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
 
     n = len(y)
     indices = y_pred.argsort()
-    y = y[indices].cpu().numpy()
-    y_pred = y_pred[indices].cpu().numpy()
+    y = y[indices].cpu().numpy()  # type: ignore[assignment]
+    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]
     nneg = auc = tmp_pos = tmp_neg = 0.0
 
     for i in range(n):
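`_calculate` is the internal helper behind MONAI's public ROC AUC metric; the change only appeases the type checker. A toy call through the public entry point (assuming `monai.metrics.compute_roc_auc` with its usual `(y_pred, y)` signature):

```python
import torch
from monai.metrics import compute_roc_auc

y_pred = torch.tensor([0.1, 0.4, 0.35, 0.8])  # predicted scores
y = torch.tensor([0, 0, 1, 1])                # binary ground truth
print(compute_roc_auc(y_pred, y))             # 0.75 for this toy example
```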
monai/networks/blocks/__init__.py
CHANGED
@@ -17,6 +17,7 @@ from .aspp import SimpleASPP
 from .backbone_fpn_utils import BackboneWithFPN
 from .convolutions import Convolution, ResidualUnit
 from .crf import CRF
+from .crossattention import CrossAttentionBlock
 from .denseblock import ConvDenseBlock, DenseBlock
 from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
 from .downsample import MaxAvgPool
@@ -30,6 +31,8 @@ from .patchembedding import PatchEmbed, PatchEmbeddingBlock
 from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
 from .segresnet_block import ResBlock
 from .selfattention import SABlock
+from .spade_norm import SPADE
+from .spatialattention import SpatialAttentionBlock
 from .squeeze_and_excitation import (
     ChannelSELayer,
     ResidualSELayer,
monai/networks/blocks/attention_utils.py
ADDED
@@ -0,0 +1,128 @@
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Get relative positional embeddings according to the relative positions of
+    query and key sizes.
+
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    rel_pos_resized: torch.Tensor = torch.Tensor()
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
+) -> torch.Tensor:
+    r"""
+    Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+    Only 2D and 3D are supported.
+
+    Encoding the relative position of tokens in the attention matrix: tokens spaced a distance
+    `d` apart will have the same embedding value (unlike absolute positional embedding).
+
+    .. math::
+        Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale
+
+    where
+
+    .. math::
+        E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}
+
+    with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`,
+    respectively spatial positions of element :math:`i` and :math:`j`
+
+    When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow:
+
+    .. math::
+        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
+
+    with :math:`n = 1...dim`
+
+    Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
+    :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.
+
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
+        rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
+        q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
+        k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
+
+    Returns:
+        attn (Tensor): attention logits with added relative positional embeddings.
+    """
+    rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
+    rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
+
+    batch, _, dim = q.shape
+
+    if len(rel_pos_lst) == 2:
+        q_h, q_w = q_size[:2]
+        k_h, k_w = k_size[:2]
+        r_q = q.reshape(batch, q_h, q_w, dim)
+        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
+        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
+
+        attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+            batch, q_h * q_w, k_h * k_w
+        )
+    elif len(rel_pos_lst) == 3:
+        q_h, q_w, q_d = q_size[:3]
+        k_h, k_w, k_d = k_size[:3]
+
+        rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])
+
+        r_q = q.reshape(batch, q_h, q_w, q_d, dim)
+        rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
+        rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
+        rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd)
+
+        attn = (
+            attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
+            + rel_h[:, :, :, :, None, None]
+            + rel_w[:, :, :, None, :, None]
+            + rel_d[:, :, :, None, None, :]
+        ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)
+
+    return attn
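To illustrate what `add_decomposed_rel_pos` expects, a minimal 2D sketch with made-up sizes (4x4 tokens, per-head dimension 8); the `ParameterList` layout mirrors the `(2*dim-1, C)` shape documented above:

```python
import torch
from torch import nn
from monai.networks.blocks.attention_utils import add_decomposed_rel_pos

h = w = 4
head_dim = 8
q = torch.randn(1, h * w, head_dim)  # (B, HW, C) query tokens
attn = torch.randn(1, h * w, h * w)  # raw attention logits (B, HW, HW)
rel_pos = nn.ParameterList([nn.Parameter(torch.zeros(2 * s - 1, head_dim)) for s in (h, w)])
attn = add_decomposed_rel_pos(attn, q, rel_pos, (h, w), (h, w))
print(attn.shape)  # torch.Size([1, 16, 16])
```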
monai/networks/blocks/crossattention.py
ADDED
@@ -0,0 +1,168 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from monai.networks.layers.utils import get_rel_pos_embedding_layer
+from monai.utils import optional_import
+
+Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+
+class CrossAttentionBlock(nn.Module):
+    """
+    A cross-attention block, based on: "Dosovitskiy et al.,
+    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        dropout_rate: float = 0.0,
+        hidden_input_size: int | None = None,
+        context_input_size: int | None = None,
+        dim_head: int | None = None,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
+        causal: bool = False,
+        sequence_length: int | None = None,
+        rel_pos_embedding: Optional[str] = None,
+        input_size: Optional[Tuple] = None,
+        attention_dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        """
+        Args:
+            hidden_size (int): dimension of hidden layer.
+            num_heads (int): number of attention heads.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
+            context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
+            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
+            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            causal: whether to use causal attention.
+            sequence_length: if causal is True, it is necessary to specify the sequence length.
+            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
+                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
+                positional parameter size.
+            attention_dtype: cast attention operations to this dtype.
+        """
+
+        super().__init__()
+
+        if not (0 <= dropout_rate <= 1):
+            raise ValueError("dropout_rate should be between 0 and 1.")
+
+        if dim_head:
+            inner_size = num_heads * dim_head
+            self.head_dim = dim_head
+        else:
+            if hidden_size % num_heads != 0:
+                raise ValueError("hidden size should be divisible by num_heads.")
+            inner_size = hidden_size
+            self.head_dim = hidden_size // num_heads
+
+        if causal and sequence_length is None:
+            raise ValueError("sequence_length is necessary for causal attention.")
+
+        self.num_heads = num_heads
+        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
+        self.context_input_size = context_input_size if context_input_size else hidden_size
+        self.out_proj = nn.Linear(inner_size, self.hidden_input_size)
+        # key, query, value projections
+        self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)
+        self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
+        self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
+        self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
+
+        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+        self.drop_output = nn.Dropout(dropout_rate)
+        self.drop_weights = nn.Dropout(dropout_rate)
+
+        self.scale = self.head_dim**-0.5
+        self.save_attn = save_attn
+        self.attention_dtype = attention_dtype
+
+        self.causal = causal
+        self.sequence_length = sequence_length
+
+        if causal and sequence_length is not None:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            self.register_buffer(
+                "causal_mask",
+                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+            )
+            self.causal_mask: torch.Tensor
+        else:
+            self.causal_mask = torch.Tensor()
+
+        self.att_mat = torch.Tensor()
+        self.rel_positional_embedding = (
+            get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
+            if rel_pos_embedding is not None
+            else None
+        )
+        self.input_size = input_size
+
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
+        """
+        Args:
+            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C
+
+        Return:
+            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+        """
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)
+
+        q = self.to_q(x)
+        kv = context if context is not None else x
+        _, kv_t, _ = kv.size()
+        k = self.to_k(kv)
+        v = self.to_v(kv)
+
+        if self.attention_dtype is not None:
+            q = q.to(self.attention_dtype)
+            k = k.to(self.attention_dtype)
+
+        q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
+        k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
+        v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
+        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+        # apply relative positional embedding if defined
+        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+
+        if self.causal:
+            att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+
+        att_mat = att_mat.softmax(dim=-1)
+
+        if self.save_attn:
+            # no gradients and new tensor;
+            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+            self.att_mat = att_mat.detach()
+
+        att_mat = self.drop_weights(att_mat)
+        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+        x = self.out_rearrange(x)
+        x = self.out_proj(x)
+        x = self.drop_output(x)
+        return x
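A usage sketch of the new block with hypothetical sizes (image tokens attending to a smaller context sequence); note that `einops` must be installed for the internal `Rearrange` import:

```python
import torch
from monai.networks.blocks import CrossAttentionBlock

block = CrossAttentionBlock(hidden_size=64, num_heads=4, dropout_rate=0.1, context_input_size=32)
x = torch.randn(2, 16, 64)        # B x tokens x hidden_size
context = torch.randn(2, 10, 32)  # B x context tokens x context_input_size
out = block(x, context)
print(out.shape)  # torch.Size([2, 16, 64])
```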
monai/networks/blocks/rel_pos_embedding.py
ADDED
@@ -0,0 +1,56 @@
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Iterable, Tuple
+
+import torch
+from torch import nn
+
+from monai.networks.blocks.attention_utils import add_decomposed_rel_pos
+from monai.utils.misc import ensure_tuple_size
+
+
+class DecomposedRelativePosEmbedding(nn.Module):
+    def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None:
+        """
+        Args:
+            s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D)
+            c_dim (int): channel dimension
+            num_heads(int): number of attention heads
+        """
+        super().__init__()
+
+        # validate inputs
+        if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]:
+            raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)")
+
+        self.s_input_dims = s_input_dims
+        self.c_dim = c_dim
+        self.num_heads = num_heads
+        self.rel_pos_arr = nn.ParameterList(
+            [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims]
+        )
+
+    def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+        """"""
+        batch = x.shape[0]
+        h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1)
+
+        att_mat = add_decomposed_rel_pos(
+            att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d),
+            q.contiguous().view(batch * self.num_heads, h * w * d, -1),
+            self.rel_pos_arr,
+            (h, w) if d == 1 else (h, w, d),
+            (h, w) if d == 1 else (h, w, d),
+        )
+
+        att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)
+        return att_mat
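A shape-only sketch of how `DecomposedRelativePosEmbedding` is meant to be called (matching the `(x, att_mat, q)` signature above; sizes are hypothetical, with `c_dim` equal to the per-head dimension):

```python
import torch
from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding

h = w = 4
heads, head_dim = 2, 8
rel_pos = DecomposedRelativePosEmbedding(s_input_dims=(h, w), c_dim=head_dim, num_heads=heads)

x = torch.randn(1, h * w, heads * head_dim)    # B x tokens x hidden
q = torch.randn(1, heads, h * w, head_dim)     # B x heads x tokens x head_dim
att_mat = torch.randn(1, heads, h * w, h * w)  # raw attention logits
print(rel_pos(x, att_mat, q).shape)            # torch.Size([1, 2, 16, 16])
```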
monai/networks/blocks/selfattention.py
CHANGED
@@ -11,9 +11,12 @@
 
 from __future__ import annotations
 
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
 
+from monai.networks.layers.utils import get_rel_pos_embedding_layer
 from monai.utils import optional_import
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
@@ -33,6 +36,12 @@ class SABlock(nn.Module):
         qkv_bias: bool = False,
         save_attn: bool = False,
         dim_head: int | None = None,
+        hidden_input_size: int | None = None,
+        causal: bool = False,
+        sequence_length: int | None = None,
+        rel_pos_embedding: Optional[str] = None,
+        input_size: Optional[Tuple] = None,
+        attention_dtype: Optional[torch.dtype] = None,
     ) -> None:
         """
         Args:
@@ -42,6 +51,14 @@ class SABlock(nn.Module):
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
             dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
+            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
+            causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762).
+            sequence_length: if causal is True, it is necessary to specify the sequence length.
+            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
+                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
+                positional parameter size.
+            attention_dtype: cast attention operations to this dtype.
 
         """
 
@@ -53,12 +70,23 @@ class SABlock(nn.Module):
         if hidden_size % num_heads != 0:
             raise ValueError("hidden size should be divisible by num_heads.")
 
+        if dim_head:
+            self.inner_dim = num_heads * dim_head
+            self.dim_head = dim_head
+        else:
+            if hidden_size % num_heads != 0:
+                raise ValueError("hidden size should be divisible by num_heads.")
+            self.inner_dim = hidden_size
+            self.dim_head = hidden_size // num_heads
+
+        if causal and sequence_length is None:
+            raise ValueError("sequence_length is necessary for causal attention.")
+
         self.num_heads = num_heads
-        self.
-        self.
+        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
+        self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
 
-        self.
-        self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias)
+        self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
         self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
@@ -66,11 +94,52 @@ class SABlock(nn.Module):
         self.scale = self.dim_head**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
+        self.attention_dtype = attention_dtype
+        self.causal = causal
+        self.sequence_length = sequence_length
+
+        if causal and sequence_length is not None:
+            # causal mask to ensure that attention is only applied to the left in the input sequence
+            self.register_buffer(
+                "causal_mask",
+                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+            )
+            self.causal_mask: torch.Tensor
+        else:
+            self.causal_mask = torch.Tensor()
+
+        self.rel_positional_embedding = (
+            get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads)
+            if rel_pos_embedding is not None
+            else None
+        )
+        self.input_size = input_size
 
     def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+
+        Return:
+            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+        """
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
 
+        if self.attention_dtype is not None:
+            q = q.to(self.attention_dtype)
+            k = k.to(self.attention_dtype)
+
+        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+        # apply relative positional embedding if defined
+        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+
+        if self.causal:
+            att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
+
+        att_mat = att_mat.softmax(dim=-1)
+
         if self.save_attn:
             # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
monai/networks/blocks/spade_norm.py
ADDED
@@ -0,0 +1,95 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks import Convolution
+from monai.networks.layers.utils import get_norm_layer
+
+
+class SPADE(nn.Module):
+    """
+    Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a
+    semantic map. This block is used in SPADE-based image-to-image translation models, as described in
+    Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291).
+
+    Args:
+        label_nc: number of semantic labels
+        norm_nc: number of output channels
+        kernel_size: kernel size
+        spatial_dims: number of spatial dimensions
+        hidden_channels: number of channels in the intermediate gamma and beta layers
+        norm: type of base normalisation used before applying the SPADE normalisation
+        norm_params: parameters for the base normalisation
+    """
+
+    def __init__(
+        self,
+        label_nc: int,
+        norm_nc: int,
+        kernel_size: int = 3,
+        spatial_dims: int = 2,
+        hidden_channels: int = 64,
+        norm: str | tuple = "INSTANCE",
+        norm_params: dict | None = None,
+    ) -> None:
+        super().__init__()
+
+        if norm_params is None:
+            norm_params = {}
+        if len(norm_params) != 0:
+            norm = (norm, norm_params)
+        self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc)
+        self.mlp_shared = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=label_nc,
+            out_channels=hidden_channels,
+            kernel_size=kernel_size,
+            norm=None,
+            act="LEAKYRELU",
+        )
+        self.mlp_gamma = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=hidden_channels,
+            out_channels=norm_nc,
+            kernel_size=kernel_size,
+            act=None,
+        )
+        self.mlp_beta = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=hidden_channels,
+            out_channels=norm_nc,
+            kernel_size=kernel_size,
+            act=None,
+        )
+
+    def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels.
+            segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels.
+                The map will be interpolated to the dimension of x internally.
+        """
+
+        # Part 1. generate parameter-free normalized activations
+        normalized = self.param_free_norm(x.contiguous())
+
+        # Part 2. produce scaling and bias conditioned on semantic map
+        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
+        actv = self.mlp_shared(segmap)
+        gamma = self.mlp_gamma(actv)
+        beta = self.mlp_beta(actv)
+        out: torch.Tensor = normalized * (1 + gamma) + beta
+        return out
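A minimal forward pass through the new SPADE block with made-up shapes (in practice `segmap` would be a one-hot semantic map with `label_nc` channels):

```python
import torch
from monai.networks.blocks import SPADE

spade = SPADE(label_nc=3, norm_nc=8, spatial_dims=2)
feat = torch.randn(2, 8, 32, 32)    # activations to be normalised
segmap = torch.randn(2, 3, 64, 64)  # semantic map; resized to feat's spatial size internally
print(spade(feat, segmap).shape)    # torch.Size([2, 8, 32, 32])
```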