monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/metrics/cumulative_average.py CHANGED
@@ -65,6 +65,7 @@ class CumulativeAverage:
          if self.val is None:
              return 0

+         val: NdarrayOrTensor
          val = self.val.clone()
          val[~torch.isfinite(val)] = 0

@@ -96,6 +97,7 @@ class CumulativeAverage:
              dist.all_reduce(sum)
              dist.all_reduce(count)

+         val: NdarrayOrTensor
          val = torch.where(count > 0, sum / count, sum)

          if to_numpy:
monai/metrics/panoptic_quality.py CHANGED
@@ -274,7 +274,7 @@ def _get_paired_iou(

          return paired_iou, paired_true, paired_pred

-     pairwise_iou = pairwise_iou.cpu().numpy()
+     pairwise_iou = pairwise_iou.cpu().numpy()  # type: ignore[assignment]
      paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
      paired_iou = pairwise_iou[paired_true, paired_pred]
      paired_true = torch.as_tensor(list(paired_true[paired_iou > match_iou_threshold] + 1), device=device)
monai/metrics/rocauc.py CHANGED
@@ -88,8 +88,8 @@ def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:

      n = len(y)
      indices = y_pred.argsort()
-     y = y[indices].cpu().numpy()
-     y_pred = y_pred[indices].cpu().numpy()
+     y = y[indices].cpu().numpy()  # type: ignore[assignment]
+     y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]
      nneg = auc = tmp_pos = tmp_neg = 0.0

      for i in range(n):
monai/networks/blocks/__init__.py CHANGED
@@ -17,6 +17,7 @@ from .aspp import SimpleASPP
  from .backbone_fpn_utils import BackboneWithFPN
  from .convolutions import Convolution, ResidualUnit
  from .crf import CRF
+ from .crossattention import CrossAttentionBlock
  from .denseblock import ConvDenseBlock, DenseBlock
  from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
  from .downsample import MaxAvgPool
@@ -30,6 +31,8 @@ from .patchembedding import PatchEmbed, PatchEmbeddingBlock
  from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
  from .segresnet_block import ResBlock
  from .selfattention import SABlock
+ from .spade_norm import SPADE
+ from .spatialattention import SpatialAttentionBlock
  from .squeeze_and_excitation import (
      ChannelSELayer,
      ResidualSELayer,
monai/networks/blocks/attention_utils.py CHANGED
@@ -0,0 +1,128 @@
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+     """
+     Get relative positional embeddings according to the relative positions of
+     query and key sizes.
+
+     Args:
+         q_size (int): size of query q.
+         k_size (int): size of key k.
+         rel_pos (Tensor): relative position embeddings (L, C).
+
+     Returns:
+         Extracted positional embeddings according to relative positions.
+     """
+     rel_pos_resized: torch.Tensor = torch.Tensor()
+     max_rel_dist = int(2 * max(q_size, k_size) - 1)
+     # Interpolate rel pos if needed.
+     if rel_pos.shape[0] != max_rel_dist:
+         # Interpolate rel pos.
+         rel_pos_resized = F.interpolate(
+             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
+         )
+         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+     else:
+         rel_pos_resized = rel_pos
+
+     # Scale the coords with short length if shapes for q and k are different.
+     q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+     k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+     relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+     return rel_pos_resized[relative_coords.long()]
+
+
+ def add_decomposed_rel_pos(
+     attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
+ ) -> torch.Tensor:
+     r"""
+     Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:
+     https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+     Only 2D and 3D are supported.
+
+     Encoding the relative position of tokens in the attention matrix: tokens spaced a distance
+     `d` apart will have the same embedding value (unlike absolute positional embedding).
+
+     .. math::
+         Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale
+
+     where
+
+     .. math::
+         E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}
+
+     with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`,
+     respectively spatial positions of element :math:`i` and :math:`j`
+
+     When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow:
+
+     .. math::
+         R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
+
+     with :math:`n = 1...dim`
+
+     Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
+     :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.
+
+     Args:
+         attn (Tensor): attention map.
+         q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
+         rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
+         q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
+         k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
+
+     Returns:
+         attn (Tensor): attention logits with added relative positional embeddings.
+     """
+     rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
+     rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
+
+     batch, _, dim = q.shape
+
+     if len(rel_pos_lst) == 2:
+         q_h, q_w = q_size[:2]
+         k_h, k_w = k_size[:2]
+         r_q = q.reshape(batch, q_h, q_w, dim)
+         rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
+         rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
+
+         attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+             batch, q_h * q_w, k_h * k_w
+         )
+     elif len(rel_pos_lst) == 3:
+         q_h, q_w, q_d = q_size[:3]
+         k_h, k_w, k_d = k_size[:3]
+
+         rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])
+
+         r_q = q.reshape(batch, q_h, q_w, q_d, dim)
+         rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
+         rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
+         rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd)
+
+         attn = (
+             attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
+             + rel_h[:, :, :, :, None, None]
+             + rel_w[:, :, :, None, :, None]
+             + rel_d[:, :, :, None, None, :]
+         ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)
+
+     return attn
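The two helpers above are generic: get_rel_pos resizes and indexes a per-axis embedding table, and add_decomposed_rel_pos folds one table per spatial axis into the attention logits. A minimal sketch of how they combine, using hypothetical toy shapes (not part of the diff):

import torch
from torch import nn
from monai.networks.blocks.attention_utils import add_decomposed_rel_pos

batch_heads, h, w, c = 2, 4, 4, 8                  # (batch * num_heads), spatial size, per-head channels
q = torch.randn(batch_heads, h * w, c)             # flattened query tokens
attn = torch.randn(batch_heads, h * w, h * w)      # raw attention logits
rel_pos = nn.ParameterList(                        # one (2*size - 1, c) table per spatial axis
    [nn.Parameter(torch.zeros(2 * h - 1, c)), nn.Parameter(torch.zeros(2 * w - 1, c))]
)
attn = add_decomposed_rel_pos(attn, q, rel_pos, (h, w), (h, w))  # -> shape (2, 16, 16)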
monai/networks/blocks/crossattention.py CHANGED
@@ -0,0 +1,168 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from monai.networks.layers.utils import get_rel_pos_embedding_layer
+ from monai.utils import optional_import
+
+ Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+
+ class CrossAttentionBlock(nn.Module):
+     """
+     A cross-attention block, based on: "Dosovitskiy et al.,
+     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+     One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         dropout_rate: float = 0.0,
+         hidden_input_size: int | None = None,
+         context_input_size: int | None = None,
+         dim_head: int | None = None,
+         qkv_bias: bool = False,
+         save_attn: bool = False,
+         causal: bool = False,
+         sequence_length: int | None = None,
+         rel_pos_embedding: Optional[str] = None,
+         input_size: Optional[Tuple] = None,
+         attention_dtype: Optional[torch.dtype] = None,
+     ) -> None:
+         """
+         Args:
+             hidden_size (int): dimension of hidden layer.
+             num_heads (int): number of attention heads.
+             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+             hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
+             context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
+             dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
+             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+             causal: whether to use causal attention.
+             sequence_length: if causal is True, it is necessary to specify the sequence length.
+             rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
+                 For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
+                 positional parameter size.
+             attention_dtype: cast attention operations to this dtype.
+         """
+
+         super().__init__()
+
+         if not (0 <= dropout_rate <= 1):
+             raise ValueError("dropout_rate should be between 0 and 1.")
+
+         if dim_head:
+             inner_size = num_heads * dim_head
+             self.head_dim = dim_head
+         else:
+             if hidden_size % num_heads != 0:
+                 raise ValueError("hidden size should be divisible by num_heads.")
+             inner_size = hidden_size
+             self.head_dim = hidden_size // num_heads
+
+         if causal and sequence_length is None:
+             raise ValueError("sequence_length is necessary for causal attention.")
+
+         self.num_heads = num_heads
+         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
+         self.context_input_size = context_input_size if context_input_size else hidden_size
+         self.out_proj = nn.Linear(inner_size, self.hidden_input_size)
+         # key, query, value projections
+         self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)
+         self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
+         self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
+         self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
+
+         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+         self.drop_output = nn.Dropout(dropout_rate)
+         self.drop_weights = nn.Dropout(dropout_rate)
+
+         self.scale = self.head_dim**-0.5
+         self.save_attn = save_attn
+         self.attention_dtype = attention_dtype
+
+         self.causal = causal
+         self.sequence_length = sequence_length
+
+         if causal and sequence_length is not None:
+             # causal mask to ensure that attention is only applied to the left in the input sequence
+             self.register_buffer(
+                 "causal_mask",
+                 torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+             )
+             self.causal_mask: torch.Tensor
+         else:
+             self.causal_mask = torch.Tensor()
+
+         self.att_mat = torch.Tensor()
+         self.rel_positional_embedding = (
+             get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
+             if rel_pos_embedding is not None
+             else None
+         )
+         self.input_size = input_size
+
+     def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
+         """
+         Args:
+             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+             context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C
+
+         Return:
+             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+         """
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)
+
+         q = self.to_q(x)
+         kv = context if context is not None else x
+         _, kv_t, _ = kv.size()
+         k = self.to_k(kv)
+         v = self.to_v(kv)
+
+         if self.attention_dtype is not None:
+             q = q.to(self.attention_dtype)
+             k = k.to(self.attention_dtype)
+
+         q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
+         k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
+         v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
+         att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+         # apply relative positional embedding if defined
+         att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+
+         if self.causal:
+             att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+
+         att_mat = att_mat.softmax(dim=-1)
+
+         if self.save_attn:
+             # no gradients and new tensor;
+             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+             self.att_mat = att_mat.detach()
+
+         att_mat = self.drop_weights(att_mat)
+         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+         x = self.out_rearrange(x)
+         x = self.out_proj(x)
+         x = self.drop_output(x)
+         return x
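Assuming einops is installed (the block imports Rearrange via optional_import), a minimal usage sketch with hypothetical sizes:

import torch
from monai.networks.blocks import CrossAttentionBlock

block = CrossAttentionBlock(hidden_size=64, num_heads=4, dropout_rate=0.1)
x = torch.randn(2, 16, 64)         # B x query tokens x hidden_size
context = torch.randn(2, 32, 64)   # conditioning tokens, e.g. text or label embeddings
out = block(x, context=context)    # -> shape (2, 16, 64); without context the block self-attends on x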
monai/networks/blocks/rel_pos_embedding.py CHANGED
@@ -0,0 +1,56 @@
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Iterable, Tuple
+
+ import torch
+ from torch import nn
+
+ from monai.networks.blocks.attention_utils import add_decomposed_rel_pos
+ from monai.utils.misc import ensure_tuple_size
+
+
+ class DecomposedRelativePosEmbedding(nn.Module):
+     def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None:
+         """
+         Args:
+             s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D)
+             c_dim (int): channel dimension
+             num_heads(int): number of attention heads
+         """
+         super().__init__()
+
+         # validate inputs
+         if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]:
+             raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)")
+
+         self.s_input_dims = s_input_dims
+         self.c_dim = c_dim
+         self.num_heads = num_heads
+         self.rel_pos_arr = nn.ParameterList(
+             [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims]
+         )
+
+     def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+         """"""
+         batch = x.shape[0]
+         h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1)
+
+         att_mat = add_decomposed_rel_pos(
+             att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d),
+             q.contiguous().view(batch * self.num_heads, h * w * d, -1),
+             self.rel_pos_arr,
+             (h, w) if d == 1 else (h, w, d),
+             (h, w) if d == 1 else (h, w, d),
+         )
+
+         att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)
+         return att_mat
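DecomposedRelativePosEmbedding is normally constructed for SABlock/CrossAttentionBlock through get_rel_pos_embedding_layer, but it can be exercised directly; a sketch with hypothetical 2D shapes (channels given per head, attention already split into heads):

import torch
from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding

b, heads, h, w, head_dim = 2, 4, 4, 4, 8
emb = DecomposedRelativePosEmbedding(s_input_dims=(h, w), c_dim=head_dim, num_heads=heads)
x = torch.randn(b, h * w, heads * head_dim)   # only the batch size is read from x
q = torch.randn(b, heads, h * w, head_dim)
att_mat = torch.randn(b, heads, h * w, h * w)
att_mat = emb(x, att_mat, q)                  # same shape, with the relative-position bias added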
monai/networks/blocks/selfattention.py CHANGED
@@ -11,9 +11,12 @@

  from __future__ import annotations

+ from typing import Optional, Tuple
+
  import torch
  import torch.nn as nn

+ from monai.networks.layers.utils import get_rel_pos_embedding_layer
  from monai.utils import optional_import

  Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
@@ -33,6 +36,12 @@ class SABlock(nn.Module):
          qkv_bias: bool = False,
          save_attn: bool = False,
          dim_head: int | None = None,
+         hidden_input_size: int | None = None,
+         causal: bool = False,
+         sequence_length: int | None = None,
+         rel_pos_embedding: Optional[str] = None,
+         input_size: Optional[Tuple] = None,
+         attention_dtype: Optional[torch.dtype] = None,
      ) -> None:
          """
          Args:
@@ -42,6 +51,14 @@ class SABlock(nn.Module):
              qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
              save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
              dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
+             hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
+             causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762).
+             sequence_length: if causal is True, it is necessary to specify the sequence length.
+             rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
+                 For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
+                 positional parameter size.
+             attention_dtype: cast attention operations to this dtype.

          """

@@ -53,12 +70,23 @@ class SABlock(nn.Module):
          if hidden_size % num_heads != 0:
              raise ValueError("hidden size should be divisible by num_heads.")

+         if dim_head:
+             self.inner_dim = num_heads * dim_head
+             self.dim_head = dim_head
+         else:
+             if hidden_size % num_heads != 0:
+                 raise ValueError("hidden size should be divisible by num_heads.")
+             self.inner_dim = hidden_size
+             self.dim_head = hidden_size // num_heads
+
+         if causal and sequence_length is None:
+             raise ValueError("sequence_length is necessary for causal attention.")
+
          self.num_heads = num_heads
-         self.dim_head = hidden_size // num_heads if dim_head is None else dim_head
-         self.inner_dim = self.dim_head * num_heads
+         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
+         self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)

-         self.out_proj = nn.Linear(self.inner_dim, hidden_size)
-         self.qkv = nn.Linear(hidden_size, self.inner_dim * 3, bias=qkv_bias)
+         self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
          self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
          self.out_rearrange = Rearrange("b h l d -> b l (h d)")
          self.drop_output = nn.Dropout(dropout_rate)
@@ -66,11 +94,52 @@ class SABlock(nn.Module):
          self.scale = self.dim_head**-0.5
          self.save_attn = save_attn
          self.att_mat = torch.Tensor()
+         self.attention_dtype = attention_dtype
+         self.causal = causal
+         self.sequence_length = sequence_length
+
+         if causal and sequence_length is not None:
+             # causal mask to ensure that attention is only applied to the left in the input sequence
+             self.register_buffer(
+                 "causal_mask",
+                 torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+             )
+             self.causal_mask: torch.Tensor
+         else:
+             self.causal_mask = torch.Tensor()
+
+         self.rel_positional_embedding = (
+             get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads)
+             if rel_pos_embedding is not None
+             else None
+         )
+         self.input_size = input_size

      def forward(self, x):
+         """
+         Args:
+             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+
+         Return:
+             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+         """
          output = self.input_rearrange(self.qkv(x))
          q, k, v = output[0], output[1], output[2]
-         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
+
+         if self.attention_dtype is not None:
+             q = q.to(self.attention_dtype)
+             k = k.to(self.attention_dtype)
+
+         att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+         # apply relative positional embedding if defined
+         att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+
+         if self.causal:
+             att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
+
+         att_mat = att_mat.softmax(dim=-1)
+
          if self.save_attn:
              # no gradients and new tensor;
              # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
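The upgraded SABlock keeps its original call signature; the new arguments are opt-in. A minimal sketch (hypothetical sizes, einops required) exercising causal masking:

import torch
from monai.networks.blocks import SABlock

block = SABlock(hidden_size=64, num_heads=4, causal=True, sequence_length=16)
tokens = torch.randn(2, 16, 64)   # B x sequence x hidden_size
out = block(tokens)               # causal self-attention, output shape (2, 16, 64)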
monai/networks/blocks/spade_norm.py CHANGED
@@ -0,0 +1,95 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from monai.networks.blocks import Convolution
+ from monai.networks.layers.utils import get_norm_layer
+
+
+ class SPADE(nn.Module):
+     """
+     Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a
+     semantic map. This block is used in SPADE-based image-to-image translation models, as described in
+     Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291).
+
+     Args:
+         label_nc: number of semantic labels
+         norm_nc: number of output channels
+         kernel_size: kernel size
+         spatial_dims: number of spatial dimensions
+         hidden_channels: number of channels in the intermediate gamma and beta layers
+         norm: type of base normalisation used before applying the SPADE normalisation
+         norm_params: parameters for the base normalisation
+     """
+
+     def __init__(
+         self,
+         label_nc: int,
+         norm_nc: int,
+         kernel_size: int = 3,
+         spatial_dims: int = 2,
+         hidden_channels: int = 64,
+         norm: str | tuple = "INSTANCE",
+         norm_params: dict | None = None,
+     ) -> None:
+         super().__init__()
+
+         if norm_params is None:
+             norm_params = {}
+         if len(norm_params) != 0:
+             norm = (norm, norm_params)
+         self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc)
+         self.mlp_shared = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=label_nc,
+             out_channels=hidden_channels,
+             kernel_size=kernel_size,
+             norm=None,
+             act="LEAKYRELU",
+         )
+         self.mlp_gamma = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=hidden_channels,
+             out_channels=norm_nc,
+             kernel_size=kernel_size,
+             act=None,
+         )
+         self.mlp_beta = Convolution(
+             spatial_dims=spatial_dims,
+             in_channels=hidden_channels,
+             out_channels=norm_nc,
+             kernel_size=kernel_size,
+             act=None,
+         )
+
+     def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels.
+             segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels.
+                 The map will be interpolated to the dimension of x internally.
+         """
+
+         # Part 1. generate parameter-free normalized activations
+         normalized = self.param_free_norm(x.contiguous())
+
+         # Part 2. produce scaling and bias conditioned on semantic map
+         segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
+         actv = self.mlp_shared(segmap)
+         gamma = self.mlp_gamma(actv)
+         beta = self.mlp_beta(actv)
+         out: torch.Tensor = normalized * (1 + gamma) + beta
+         return out
+ return out