ultralytics 8.2.72__py3-none-any.whl → 8.2.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic; see the package's registry page for more details.
- ultralytics/__init__.py +2 -3
- ultralytics/models/__init__.py +1 -2
- ultralytics/models/sam/__init__.py +2 -2
- ultralytics/models/sam/amg.py +27 -21
- ultralytics/models/sam/build.py +200 -9
- ultralytics/models/sam/model.py +86 -34
- ultralytics/models/sam/modules/blocks.py +1131 -0
- ultralytics/models/sam/modules/decoders.py +390 -23
- ultralytics/models/sam/modules/encoders.py +508 -323
- ultralytics/models/{sam2 → sam}/modules/memory_attention.py +73 -6
- ultralytics/models/sam/modules/sam.py +887 -16
- ultralytics/models/sam/modules/tiny_encoder.py +376 -126
- ultralytics/models/sam/modules/transformer.py +155 -54
- ultralytics/models/{sam2 → sam}/modules/utils.py +105 -3
- ultralytics/models/sam/predict.py +382 -92
- {ultralytics-8.2.72.dist-info → ultralytics-8.2.73.dist-info}/METADATA +44 -44
- {ultralytics-8.2.72.dist-info → ultralytics-8.2.73.dist-info}/RECORD +21 -29
- ultralytics/models/sam2/__init__.py +0 -6
- ultralytics/models/sam2/build.py +0 -156
- ultralytics/models/sam2/model.py +0 -97
- ultralytics/models/sam2/modules/__init__.py +0 -1
- ultralytics/models/sam2/modules/decoders.py +0 -305
- ultralytics/models/sam2/modules/encoders.py +0 -332
- ultralytics/models/sam2/modules/sam2.py +0 -804
- ultralytics/models/sam2/modules/sam2_blocks.py +0 -715
- ultralytics/models/sam2/predict.py +0 -177
- {ultralytics-8.2.72.dist-info → ultralytics-8.2.73.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.72.dist-info → ultralytics-8.2.73.dist-info}/WHEEL +0 -0
- {ultralytics-8.2.72.dist-info → ultralytics-8.2.73.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.72.dist-info → ultralytics-8.2.73.dist-info}/top_level.txt +0 -0
|
@@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
|
|
|
11
11
|
|
|
12
12
|
class TwoWayTransformer(nn.Module):
|
|
13
13
|
"""
|
|
14
|
-
A Two-Way Transformer module
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
14
|
+
A Two-Way Transformer module for simultaneous attention to image and query points.
|
|
15
|
+
|
|
16
|
+
This class implements a specialized transformer decoder that attends to an input image using queries with
|
|
17
|
+
supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
|
|
18
|
+
cloud processing.
|
|
18
19
|
|
|
19
20
|
Attributes:
|
|
20
|
-
depth (int):
|
|
21
|
-
embedding_dim (int):
|
|
22
|
-
num_heads (int):
|
|
23
|
-
mlp_dim (int):
|
|
24
|
-
layers (nn.ModuleList):
|
|
25
|
-
final_attn_token_to_image (Attention):
|
|
26
|
-
norm_final_attn (nn.LayerNorm):
|
|
21
|
+
depth (int): Number of layers in the transformer.
|
|
22
|
+
embedding_dim (int): Channel dimension for input embeddings.
|
|
23
|
+
num_heads (int): Number of heads for multihead attention.
|
|
24
|
+
mlp_dim (int): Internal channel dimension for the MLP block.
|
|
25
|
+
layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
|
|
26
|
+
final_attn_token_to_image (Attention): Final attention layer from queries to image.
|
|
27
|
+
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
|
|
28
|
+
|
|
29
|
+
Methods:
|
|
30
|
+
forward: Processes image and point embeddings through the transformer.
|
|
31
|
+
|
|
32
|
+
Examples:
|
|
33
|
+
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
|
34
|
+
>>> image_embedding = torch.randn(1, 256, 32, 32)
|
|
35
|
+
>>> image_pe = torch.randn(1, 256, 32, 32)
|
|
36
|
+
>>> point_embedding = torch.randn(1, 100, 256)
|
|
37
|
+
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
|
|
38
|
+
>>> print(output_queries.shape, output_image.shape)
|
|
27
39
|
"""
|
|
28
40
|
|
|
29
41
|
def __init__(
|
|
@@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module):
|
|
|
36
48
|
attention_downsample_rate: int = 2,
|
|
37
49
|
) -> None:
|
|
38
50
|
"""
|
|
39
|
-
|
|
51
|
+
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
|
|
40
52
|
|
|
41
53
|
Args:
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
54
|
+
depth (int): Number of layers in the transformer.
|
|
55
|
+
embedding_dim (int): Channel dimension for input embeddings.
|
|
56
|
+
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
|
|
57
|
+
mlp_dim (int): Internal channel dimension for the MLP block.
|
|
58
|
+
activation (Type[nn.Module]): Activation function to use in the MLP block.
|
|
59
|
+
attention_downsample_rate (int): Downsampling rate for attention mechanism.
|
|
60
|
+
|
|
61
|
+
Attributes:
|
|
62
|
+
depth (int): Number of layers in the transformer.
|
|
63
|
+
embedding_dim (int): Channel dimension for input embeddings.
|
|
64
|
+
embedding_dim (int): Channel dimension for input embeddings.
|
|
65
|
+
num_heads (int): Number of heads for multihead attention.
|
|
66
|
+
mlp_dim (int): Internal channel dimension for the MLP block.
|
|
67
|
+
layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
|
|
68
|
+
final_attn_token_to_image (Attention): Final attention layer from queries to image.
|
|
69
|
+
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
|
|
70
|
+
|
|
71
|
+
Examples:
|
|
72
|
+
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
|
73
|
+
>>> image_embedding = torch.randn(1, 256, 32, 32)
|
|
74
|
+
>>> image_pe = torch.randn(1, 256, 32, 32)
|
|
75
|
+
>>> point_embedding = torch.randn(1, 100, 256)
|
|
76
|
+
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
|
|
77
|
+
>>> print(output_queries.shape, output_image.shape)
|
|
48
78
|
"""
|
|
49
79
|
super().__init__()
|
|
50
80
|
self.depth = depth
|
|
@@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module):
|
|
|
75
105
|
point_embedding: Tensor,
|
|
76
106
|
) -> Tuple[Tensor, Tensor]:
|
|
77
107
|
"""
|
|
108
|
+
Processes image and point embeddings through the Two-Way Transformer.
|
|
109
|
+
|
|
78
110
|
Args:
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
Must have shape B x N_points x embedding_dim for any N_points.
|
|
111
|
+
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
|
|
112
|
+
image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
|
|
113
|
+
point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
|
|
83
114
|
|
|
84
115
|
Returns:
|
|
85
|
-
|
|
86
|
-
|
|
116
|
+
(Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
|
|
117
|
+
|
|
118
|
+
Examples:
|
|
119
|
+
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
|
120
|
+
>>> image_embedding = torch.randn(1, 256, 32, 32)
|
|
121
|
+
>>> image_pe = torch.randn(1, 256, 32, 32)
|
|
122
|
+
>>> point_embedding = torch.randn(1, 100, 256)
|
|
123
|
+
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
|
|
124
|
+
>>> print(output_queries.shape, output_image.shape)
|
|
87
125
|
"""
|
|
88
126
|
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
|
89
127
|
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
|
@@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module):
|
|
|
114
152
|
|
|
115
153
|
class TwoWayAttentionBlock(nn.Module):
|
|
116
154
|
"""
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
sparse inputs
|
|
155
|
+
A two-way attention block for simultaneous attention to image and query points.
|
|
156
|
+
|
|
157
|
+
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
|
|
158
|
+
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
|
|
159
|
+
inputs to sparse inputs.
|
|
121
160
|
|
|
122
161
|
Attributes:
|
|
123
|
-
self_attn (Attention):
|
|
124
|
-
norm1 (nn.LayerNorm): Layer normalization
|
|
162
|
+
self_attn (Attention): Self-attention layer for queries.
|
|
163
|
+
norm1 (nn.LayerNorm): Layer normalization after self-attention.
|
|
125
164
|
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
|
126
|
-
norm2 (nn.LayerNorm): Layer normalization
|
|
127
|
-
mlp (MLPBlock): MLP block
|
|
128
|
-
norm3 (nn.LayerNorm): Layer normalization
|
|
129
|
-
norm4 (nn.LayerNorm): Layer normalization
|
|
165
|
+
norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
|
|
166
|
+
mlp (MLPBlock): MLP block for transforming query embeddings.
|
|
167
|
+
norm3 (nn.LayerNorm): Layer normalization after MLP block.
|
|
168
|
+
norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
|
|
130
169
|
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
|
131
|
-
skip_first_layer_pe (bool): Whether to skip
|
|
170
|
+
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
|
|
171
|
+
|
|
172
|
+
Methods:
|
|
173
|
+
forward: Applies self-attention and cross-attention to queries and keys.
|
|
174
|
+
|
|
175
|
+
Examples:
|
|
176
|
+
>>> embedding_dim, num_heads = 256, 8
|
|
177
|
+
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
|
|
178
|
+
>>> queries = torch.randn(1, 100, embedding_dim)
|
|
179
|
+
>>> keys = torch.randn(1, 1000, embedding_dim)
|
|
180
|
+
>>> query_pe = torch.randn(1, 100, embedding_dim)
|
|
181
|
+
>>> key_pe = torch.randn(1, 1000, embedding_dim)
|
|
182
|
+
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
|
|
132
183
|
"""
|
|
133
184
|
|
|
134
185
|
def __init__(
|
|
@@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module):
|
|
|
141
192
|
skip_first_layer_pe: bool = False,
|
|
142
193
|
) -> None:
|
|
143
194
|
"""
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
195
|
+
Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
|
|
196
|
+
|
|
197
|
+
This block implements a specialized transformer layer with four main components: self-attention on sparse
|
|
198
|
+
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
|
|
199
|
+
of dense inputs to sparse inputs.
|
|
147
200
|
|
|
148
201
|
Args:
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
202
|
+
embedding_dim (int): Channel dimension of the embeddings.
|
|
203
|
+
num_heads (int): Number of attention heads in the attention layers.
|
|
204
|
+
mlp_dim (int): Hidden dimension of the MLP block.
|
|
205
|
+
activation (Type[nn.Module]): Activation function for the MLP block.
|
|
206
|
+
attention_downsample_rate (int): Downsampling rate for the attention mechanism.
|
|
207
|
+
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
|
|
208
|
+
|
|
209
|
+
Examples:
|
|
210
|
+
>>> embedding_dim, num_heads = 256, 8
|
|
211
|
+
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
|
|
212
|
+
>>> queries = torch.randn(1, 100, embedding_dim)
|
|
213
|
+
>>> keys = torch.randn(1, 1000, embedding_dim)
|
|
214
|
+
>>> query_pe = torch.randn(1, 100, embedding_dim)
|
|
215
|
+
>>> key_pe = torch.randn(1, 1000, embedding_dim)
|
|
216
|
+
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
|
|
154
217
|
"""
|
|
155
218
|
super().__init__()
|
|
156
219
|
self.self_attn = Attention(embedding_dim, num_heads)
|
|
@@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module):
|
|
|
168
231
|
self.skip_first_layer_pe = skip_first_layer_pe
|
|
169
232
|
|
|
170
233
|
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
|
|
171
|
-
"""
|
|
234
|
+
"""Applies two-way attention to process query and key embeddings in a transformer block."""
|
|
172
235
|
|
|
173
236
|
# Self attention block
|
|
174
237
|
if self.skip_first_layer_pe:
|
|
@@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module):
|
|
|
202
265
|
|
|
203
266
|
|
|
204
267
|
class Attention(nn.Module):
|
|
205
|
-
"""
|
|
206
|
-
|
|
268
|
+
"""
|
|
269
|
+
An attention layer with downscaling capability for embedding size after projection.
|
|
270
|
+
|
|
271
|
+
This class implements a multi-head attention mechanism with the option to downsample the internal
|
|
272
|
+
dimension of queries, keys, and values.
|
|
273
|
+
|
|
274
|
+
Attributes:
|
|
275
|
+
embedding_dim (int): Dimensionality of input embeddings.
|
|
276
|
+
kv_in_dim (int): Dimensionality of key and value inputs.
|
|
277
|
+
internal_dim (int): Internal dimension after downsampling.
|
|
278
|
+
num_heads (int): Number of attention heads.
|
|
279
|
+
q_proj (nn.Linear): Linear projection for queries.
|
|
280
|
+
k_proj (nn.Linear): Linear projection for keys.
|
|
281
|
+
v_proj (nn.Linear): Linear projection for values.
|
|
282
|
+
out_proj (nn.Linear): Linear projection for output.
|
|
283
|
+
|
|
284
|
+
Methods:
|
|
285
|
+
_separate_heads: Separates input tensor into attention heads.
|
|
286
|
+
_recombine_heads: Recombines separated attention heads.
|
|
287
|
+
forward: Computes attention output for given query, key, and value tensors.
|
|
288
|
+
|
|
289
|
+
Examples:
|
|
290
|
+
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
|
|
291
|
+
>>> q = torch.randn(1, 100, 256)
|
|
292
|
+
>>> k = v = torch.randn(1, 50, 256)
|
|
293
|
+
>>> output = attn(q, k, v)
|
|
294
|
+
>>> print(output.shape)
|
|
295
|
+
torch.Size([1, 100, 256])
|
|
207
296
|
"""
|
|
208
297
|
|
|
209
298
|
def __init__(
|
|
@@ -214,15 +303,27 @@ class Attention(nn.Module):
|
|
|
214
303
|
kv_in_dim: int = None,
|
|
215
304
|
) -> None:
|
|
216
305
|
"""
|
|
217
|
-
Initializes the Attention
|
|
306
|
+
Initializes the Attention module with specified dimensions and settings.
|
|
307
|
+
|
|
308
|
+
This class implements a multi-head attention mechanism with optional downsampling of the internal
|
|
309
|
+
dimension for queries, keys, and values.
|
|
218
310
|
|
|
219
311
|
Args:
|
|
220
|
-
embedding_dim (int):
|
|
221
|
-
num_heads (int):
|
|
222
|
-
downsample_rate (int
|
|
312
|
+
embedding_dim (int): Dimensionality of input embeddings.
|
|
313
|
+
num_heads (int): Number of attention heads.
|
|
314
|
+
downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
|
|
315
|
+
kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
|
|
223
316
|
|
|
224
317
|
Raises:
|
|
225
|
-
AssertionError: If
|
|
318
|
+
AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
|
|
319
|
+
|
|
320
|
+
Examples:
|
|
321
|
+
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
|
|
322
|
+
>>> q = torch.randn(1, 100, 256)
|
|
323
|
+
>>> k = v = torch.randn(1, 50, 256)
|
|
324
|
+
>>> output = attn(q, k, v)
|
|
325
|
+
>>> print(output.shape)
|
|
326
|
+
torch.Size([1, 100, 256])
|
|
226
327
|
"""
|
|
227
328
|
super().__init__()
|
|
228
329
|
self.embedding_dim = embedding_dim
|
|
@@ -238,20 +339,20 @@ class Attention(nn.Module):
|
|
|
238
339
|
|
|
239
340
|
@staticmethod
|
|
240
341
|
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
|
|
241
|
-
"""
|
|
342
|
+
"""Separates the input tensor into the specified number of attention heads."""
|
|
242
343
|
b, n, c = x.shape
|
|
243
344
|
x = x.reshape(b, n, num_heads, c // num_heads)
|
|
244
345
|
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
|
245
346
|
|
|
246
347
|
@staticmethod
|
|
247
348
|
def _recombine_heads(x: Tensor) -> Tensor:
|
|
248
|
-
"""
|
|
349
|
+
"""Recombines separated attention heads into a single tensor."""
|
|
249
350
|
b, n_heads, n_tokens, c_per_head = x.shape
|
|
250
351
|
x = x.transpose(1, 2)
|
|
251
352
|
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
|
252
353
|
|
|
253
354
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
|
254
|
-
"""
|
|
355
|
+
"""Applies multi-head attention to query, key, and value tensors with optional downsampling."""
|
|
255
356
|
|
|
256
357
|
# Input projections
|
|
257
358
|
q = self.q_proj(q)
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
2
2
|
|
|
3
|
+
from typing import Tuple
|
|
4
|
+
|
|
3
5
|
import torch
|
|
4
6
|
import torch.nn.functional as F
|
|
5
7
|
|
|
@@ -70,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
|
|
|
70
72
|
|
|
71
73
|
|
|
72
74
|
def init_t_xy(end_x: int, end_y: int):
|
|
73
|
-
"""Initializes 1D and 2D coordinate tensors for a grid of
|
|
75
|
+
"""Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
|
|
74
76
|
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
|
75
77
|
t_x = (t % end_x).float()
|
|
76
78
|
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
|
@@ -78,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int):
|
|
|
78
80
|
|
|
79
81
|
|
|
80
82
|
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|
81
|
-
"""Computes axial complex exponential positional encodings for 2D spatial positions."""
|
|
83
|
+
"""Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
|
|
82
84
|
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
|
83
85
|
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
|
84
86
|
|
|
@@ -91,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|
|
91
93
|
|
|
92
94
|
|
|
93
95
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
|
94
|
-
"""Reshapes frequency tensor for broadcasting
|
|
96
|
+
"""Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
|
|
95
97
|
ndim = x.ndim
|
|
96
98
|
assert 0 <= 1 < ndim
|
|
97
99
|
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
|
@@ -189,3 +191,103 @@ def window_unpartition(windows, window_size, pad_hw, hw):
|
|
|
189
191
|
if Hp > H or Wp > W:
|
|
190
192
|
x = x[:, :H, :W, :].contiguous()
|
|
191
193
|
return x
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Gathers relative positional embeddings for every (query index, key index) pair along one axis.

    The lookup table must hold ``2 * max(q_size, k_size) - 1`` rows. If it has a different length, it is first
    resized along its row axis with linear interpolation. When the query and key lengths differ, the
    coordinates of the shorter side are stretched so both sides cover the same range before offsets are taken.

    Args:
        q_size (int): Number of query positions along this axis.
        k_size (int): Number of key positions along this axis.
        rel_pos (torch.Tensor): Relative position table with shape (L, C). L is the number of relative offsets
            and C is the embedding dimension.

    Returns:
        (torch.Tensor): Gathered embeddings with shape (q_size, k_size, C).

    Examples:
        >>> q_size, k_size = 8, 16
        >>> rel_pos = torch.randn(31, 64)  # 31 = 2 * max(8, 16) - 1
        >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
        >>> print(extracted_pos.shape)
        torch.Size([8, 16, 64])
    """
    needed_len = int(2 * max(q_size, k_size) - 1)

    table = rel_pos
    if table.shape[0] != needed_len:
        # F.interpolate works on (N, C, L): treat the embedding dim as channels and the table rows as length.
        as_signal = table.reshape(1, table.shape[0], -1).permute(0, 2, 1)
        resized = F.interpolate(as_signal, size=needed_len, mode="linear")
        table = resized.reshape(-1, needed_len).permute(1, 0)

    # Stretch the shorter axis so query and key coordinates share a common scale.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_pos = torch.arange(q_size)[:, None] * q_scale
    k_pos = torch.arange(k_size)[None, :] * k_scale

    # Shift so the most negative offset maps to row 0 of the table.
    offsets = (q_pos - k_pos) + (k_size - 1) * k_scale
    return table[offsets.long()]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Adds decomposed (height + width) relative positional biases to an attention map.

    A bias is computed separately for the height axis and for the width axis by contracting the queries with
    the gathered relative-position tables. The two biases are broadcast-added to the attention logits, following
    the decomposed relative position scheme from MViTv2.

    Args:
        attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
        q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (torch.Tensor): Relative position embeddings for the height axis with shape (Lh, C).
        rel_pos_w (torch.Tensor): Relative position embeddings for the width axis with shape (Lw, C).
        q_size (Tuple[int, int]): Spatial size of the query as (q_h, q_w).
        k_size (Tuple[int, int]): Spatial size of the key as (k_h, k_w).

    Returns:
        (torch.Tensor): Attention map with the relative positional biases added, shape (B, q_h * q_w, k_h * k_w).

    Examples:
        >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
        >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
        >>> q = torch.rand(B, q_h * q_w, C)
        >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
        >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
        >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
        >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
        >>> print(updated_attn.shape)
        torch.Size([1, 64, 64])

    References:
        https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
    """
    (q_h, q_w), (k_h, k_w) = q_size, k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)

    batch, _, channels = q.shape
    q_grid = q.reshape(batch, q_h, q_w, channels)

    # Height bias broadcasts across key columns; width bias broadcasts across key rows.
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h).unsqueeze(-1)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w).unsqueeze(-2)

    attn_grid = attn.view(batch, q_h, q_w, k_h, k_w)
    return (attn_grid + bias_h + bias_w).view(batch, q_h * q_w, k_h * k_w)
|