ultralytics 8.2.72__py3-none-any.whl → 8.2.74__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (34):
  1. ultralytics/__init__.py +2 -3
  2. ultralytics/cfg/trackers/botsort.yaml +1 -1
  3. ultralytics/cfg/trackers/bytetrack.yaml +1 -1
  4. ultralytics/models/__init__.py +1 -2
  5. ultralytics/models/sam/__init__.py +2 -2
  6. ultralytics/models/sam/amg.py +27 -21
  7. ultralytics/models/sam/build.py +200 -9
  8. ultralytics/models/sam/model.py +86 -34
  9. ultralytics/models/sam/modules/blocks.py +1131 -0
  10. ultralytics/models/sam/modules/decoders.py +390 -23
  11. ultralytics/models/sam/modules/encoders.py +508 -323
  12. ultralytics/models/{sam2 → sam}/modules/memory_attention.py +73 -6
  13. ultralytics/models/sam/modules/sam.py +887 -16
  14. ultralytics/models/sam/modules/tiny_encoder.py +376 -126
  15. ultralytics/models/sam/modules/transformer.py +155 -54
  16. ultralytics/models/{sam2 → sam}/modules/utils.py +105 -3
  17. ultralytics/models/sam/predict.py +382 -92
  18. ultralytics/trackers/bot_sort.py +2 -3
  19. ultralytics/trackers/byte_tracker.py +2 -3
  20. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/METADATA +44 -44
  21. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/RECORD +25 -33
  22. ultralytics/models/sam2/__init__.py +0 -6
  23. ultralytics/models/sam2/build.py +0 -156
  24. ultralytics/models/sam2/model.py +0 -97
  25. ultralytics/models/sam2/modules/__init__.py +0 -1
  26. ultralytics/models/sam2/modules/decoders.py +0 -305
  27. ultralytics/models/sam2/modules/encoders.py +0 -332
  28. ultralytics/models/sam2/modules/sam2.py +0 -804
  29. ultralytics/models/sam2/modules/sam2_blocks.py +0 -715
  30. ultralytics/models/sam2/predict.py +0 -177
  31. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/LICENSE +0 -0
  32. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/WHEEL +0 -0
  33. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/entry_points.txt +0 -0
  34. {ultralytics-8.2.72.dist-info → ultralytics-8.2.74.dist-info}/top_level.txt +0 -0
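
The headline change in this release is structural: the separate ultralytics/models/sam2 package is deleted and its contents are merged into ultralytics/models/sam (new blocks.py, a much larger sam.py and encoders.py, and the relocated memory_attention.py and utils.py), so SAM and SAM 2 now share a single module tree. A minimal sketch of what the merge means for user code, assuming the public SAM entry point accepts SAM 2 checkpoints as documented by Ultralytics; the weight filenames and image path below are illustrative:

    from ultralytics import SAM

    # After the merge, SAM and SAM 2 checkpoints load through the same class;
    # SAM 2 was previously served by the now-removed ultralytics.models.sam2 package.
    model = SAM("sam_b.pt")     # original SAM weights
    model2 = SAM("sam2_b.pt")   # SAM 2 weights (illustrative filename)

    # Prompted prediction with a single foreground point (coordinates are arbitrary).
    results = model2("path/to/image.jpg", points=[900, 370], labels=[1])

The two hunks reproduced below come from files 15 and 16 in the list above: the shared transformer module and the relocated SAM utility functions.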
ultralytics/models/sam/modules/transformer.py

@@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
 
 class TwoWayTransformer(nn.Module):
     """
-    A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
-    serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
-    is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
-    processing.
+    A Two-Way Transformer module for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer decoder that attends to an input image using queries with
+    supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
+    cloud processing.
 
     Attributes:
-        depth (int): The number of layers in the transformer.
-        embedding_dim (int): The channel dimension for the input embeddings.
-        num_heads (int): The number of heads for multihead attention.
-        mlp_dim (int): The internal channel dimension for the MLP block.
-        layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
-        final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
-        norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
+        depth (int): Number of layers in the transformer.
+        embedding_dim (int): Channel dimension for input embeddings.
+        num_heads (int): Number of heads for multihead attention.
+        mlp_dim (int): Internal channel dimension for the MLP block.
+        layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
+        final_attn_token_to_image (Attention): Final attention layer from queries to image.
+        norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+    Methods:
+        forward: Processes image and point embeddings through the transformer.
+
+    Examples:
+        >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+        >>> image_embedding = torch.randn(1, 256, 32, 32)
+        >>> image_pe = torch.randn(1, 256, 32, 32)
+        >>> point_embedding = torch.randn(1, 100, 256)
+        >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+        >>> print(output_queries.shape, output_image.shape)
     """
 
     def __init__(
@@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module):
         attention_downsample_rate: int = 2,
     ) -> None:
         """
-        A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
+        Initialize a Two-Way Transformer for simultaneous attention to image and query points.
 
         Args:
-          depth (int): number of layers in the transformer
-          embedding_dim (int): the channel dimension for the input embeddings
-          num_heads (int): the number of heads for multihead attention. Must
-            divide embedding_dim
-          mlp_dim (int): the channel dimension internal to the MLP block
-          activation (nn.Module): the activation to use in the MLP block
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            activation (Type[nn.Module]): Activation function to use in the MLP block.
+            attention_downsample_rate (int): Downsampling rate for attention mechanism.
+
+        Attributes:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
+            final_attn_token_to_image (Attention): Final attention layer from queries to image.
+            norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
         """
         super().__init__()
         self.depth = depth
@@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module):
         point_embedding: Tensor,
     ) -> Tuple[Tensor, Tensor]:
         """
+        Processes image and point embeddings through the Two-Way Transformer.
+
         Args:
-          image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
-          image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
-          point_embedding (torch.Tensor): the embedding to add to the query points.
-            Must have shape B x N_points x embedding_dim for any N_points.
+            image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+            image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+            point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
 
         Returns:
-          (torch.Tensor): the processed point_embedding
-          (torch.Tensor): the processed image_embedding
+            (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
         """
         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
@@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module):
 
 class TwoWayAttentionBlock(nn.Module):
     """
-    An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
-    keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
-    of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
-    sparse inputs.
+    A two-way attention block for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
+    cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
+    inputs to sparse inputs.
 
     Attributes:
-        self_attn (Attention): The self-attention layer for the queries.
-        norm1 (nn.LayerNorm): Layer normalization following the first attention block.
+        self_attn (Attention): Self-attention layer for queries.
+        norm1 (nn.LayerNorm): Layer normalization after self-attention.
         cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
-        norm2 (nn.LayerNorm): Layer normalization following the second attention block.
-        mlp (MLPBlock): MLP block that transforms the query embeddings.
-        norm3 (nn.LayerNorm): Layer normalization following the MLP block.
-        norm4 (nn.LayerNorm): Layer normalization following the third attention block.
+        norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
+        mlp (MLPBlock): MLP block for transforming query embeddings.
+        norm3 (nn.LayerNorm): Layer normalization after MLP block.
+        norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
        cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
-        skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+        skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+    Methods:
+        forward: Applies self-attention and cross-attention to queries and keys.
+
+    Examples:
+        >>> embedding_dim, num_heads = 256, 8
+        >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+        >>> queries = torch.randn(1, 100, embedding_dim)
+        >>> keys = torch.randn(1, 1000, embedding_dim)
+        >>> query_pe = torch.randn(1, 100, embedding_dim)
+        >>> key_pe = torch.randn(1, 1000, embedding_dim)
+        >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
     """
 
     def __init__(
@@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module):
         skip_first_layer_pe: bool = False,
     ) -> None:
         """
-        A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
-        inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
-        inputs.
+        Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+
+        This block implements a specialized transformer layer with four main components: self-attention on sparse
+        inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
+        of dense inputs to sparse inputs.
 
         Args:
-          embedding_dim (int): the channel dimension of the embeddings
-          num_heads (int): the number of heads in the attention layers
-          mlp_dim (int): the hidden dimension of the mlp block
-          activation (nn.Module): the activation of the mlp block
-          skip_first_layer_pe (bool): skip the PE on the first layer
+            embedding_dim (int): Channel dimension of the embeddings.
+            num_heads (int): Number of attention heads in the attention layers.
+            mlp_dim (int): Hidden dimension of the MLP block.
+            activation (Type[nn.Module]): Activation function for the MLP block.
+            attention_downsample_rate (int): Downsampling rate for the attention mechanism.
+            skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+        Examples:
+            >>> embedding_dim, num_heads = 256, 8
+            >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+            >>> queries = torch.randn(1, 100, embedding_dim)
+            >>> keys = torch.randn(1, 1000, embedding_dim)
+            >>> query_pe = torch.randn(1, 100, embedding_dim)
+            >>> key_pe = torch.randn(1, 1000, embedding_dim)
+            >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
         """
         super().__init__()
         self.self_attn = Attention(embedding_dim, num_heads)
@@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module):
         self.skip_first_layer_pe = skip_first_layer_pe
 
     def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
-        """Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
+        """Applies two-way attention to process query and key embeddings in a transformer block."""
 
         # Self attention block
         if self.skip_first_layer_pe:
@@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
-    values.
+    """
+    An attention layer with downscaling capability for embedding size after projection.
+
+    This class implements a multi-head attention mechanism with the option to downsample the internal
+    dimension of queries, keys, and values.
+
+    Attributes:
+        embedding_dim (int): Dimensionality of input embeddings.
+        kv_in_dim (int): Dimensionality of key and value inputs.
+        internal_dim (int): Internal dimension after downsampling.
+        num_heads (int): Number of attention heads.
+        q_proj (nn.Linear): Linear projection for queries.
+        k_proj (nn.Linear): Linear projection for keys.
+        v_proj (nn.Linear): Linear projection for values.
+        out_proj (nn.Linear): Linear projection for output.
+
+    Methods:
+        _separate_heads: Separates input tensor into attention heads.
+        _recombine_heads: Recombines separated attention heads.
+        forward: Computes attention output for given query, key, and value tensors.
+
+    Examples:
+        >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+        >>> q = torch.randn(1, 100, 256)
+        >>> k = v = torch.randn(1, 50, 256)
+        >>> output = attn(q, k, v)
+        >>> print(output.shape)
+        torch.Size([1, 100, 256])
     """
 
     def __init__(
@@ -214,15 +303,27 @@ class Attention(nn.Module):
         kv_in_dim: int = None,
     ) -> None:
         """
-        Initializes the Attention model with the given dimensions and settings.
+        Initializes the Attention module with specified dimensions and settings.
+
+        This class implements a multi-head attention mechanism with optional downsampling of the internal
+        dimension for queries, keys, and values.
 
         Args:
-            embedding_dim (int): The dimensionality of the input embeddings.
-            num_heads (int): The number of attention heads.
-            downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
+            embedding_dim (int): Dimensionality of input embeddings.
+            num_heads (int): Number of attention heads.
+            downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+            kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
 
         Raises:
-            AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate).
+            AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
+
+        Examples:
+            >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+            >>> q = torch.randn(1, 100, 256)
+            >>> k = v = torch.randn(1, 50, 256)
+            >>> output = attn(q, k, v)
+            >>> print(output.shape)
+            torch.Size([1, 100, 256])
         """
         super().__init__()
         self.embedding_dim = embedding_dim
@@ -238,20 +339,20 @@ class Attention(nn.Module):
 
     @staticmethod
     def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
-        """Separate the input tensor into the specified number of attention heads."""
+        """Separates the input tensor into the specified number of attention heads."""
         b, n, c = x.shape
         x = x.reshape(b, n, num_heads, c // num_heads)
         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
 
    @staticmethod
    def _recombine_heads(x: Tensor) -> Tensor:
-        """Recombine the separated attention heads into a single tensor."""
+        """Recombines separated attention heads into a single tensor."""
         b, n_heads, n_tokens, c_per_head = x.shape
         x = x.transpose(1, 2)
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
-        """Compute the attention output given the input query, key, and value tensors."""
+        """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
 
         # Input projections
         q = self.q_proj(q)
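
The Attention layer shown above also gains a kv_in_dim parameter (visible in the __init__ hunk), which lets keys and values enter with a different channel width than the queries; per the docstring, it falls back to embedding_dim when left as None. A minimal usage sketch based on those docstrings, assuming the import path matches file 15 in the list above:

    import torch
    from ultralytics.models.sam.modules.transformer import Attention

    # downsample_rate=2 projects 256-dim tokens to an internal dim of 128,
    # so num_heads must divide 128 (8 heads -> 16 channels per head).
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(1, 100, 256)
    k = v = torch.randn(1, 50, 256)
    print(attn(q, k, v).shape)  # torch.Size([1, 100, 256])

    # kv_in_dim lets keys/values use a narrower channel width than the queries
    # (the value 64 below is illustrative).
    cross = Attention(embedding_dim=256, num_heads=8, kv_in_dim=64)
    kv = torch.randn(1, 50, 64)
    print(cross(q, kv, kv).shape)  # torch.Size([1, 100, 256])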
ultralytics/models/{sam2 → sam}/modules/utils.py

@@ -1,5 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
+from typing import Tuple
+
 import torch
 import torch.nn.functional as F
 
@@ -70,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
 
 
 def init_t_xy(end_x: int, end_y: int):
-    """Initializes 1D and 2D coordinate tensors for a grid of size end_x by end_y."""
+    """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
     t = torch.arange(end_x * end_y, dtype=torch.float32)
     t_x = (t % end_x).float()
     t_y = torch.div(t, end_x, rounding_mode="floor").float()
@@ -78,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int):
 
 
 def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
-    """Computes axial complex exponential positional encodings for 2D spatial positions."""
+    """Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
     freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
     freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
 
@@ -91,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
 
 
 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    """Reshapes frequency tensor for broadcasting, ensuring compatibility with input tensor dimensions."""
+    """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
     ndim = x.ndim
     assert 0 <= 1 < ndim
     assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
@@ -189,3 +191,103 @@ def window_unpartition(windows, window_size, pad_hw, hw):
     if Hp > H or Wp > W:
         x = x[:, :H, :W, :].contiguous()
     return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Extracts relative positional embeddings based on query and key sizes.
+
+    Args:
+        q_size (int): Size of the query.
+        k_size (int): Size of the key.
+        rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
+            distance and C is the embedding dimension.
+
+    Returns:
+        (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
+            k_size, C).
+
+    Examples:
+        >>> q_size, k_size = 8, 16
+        >>> rel_pos = torch.randn(31, 64)  # 31 = 2 * max(8, 16) - 1
+        >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
+        >>> print(extracted_pos.shape)
+        torch.Size([8, 16, 64])
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Adds decomposed Relative Positional Embeddings to the attention map.
+
+    This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
+    paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
+    positions.
+
+    Args:
+        attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
+        q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
+        rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
+        q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
+        k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
+
+    Returns:
+        (torch.Tensor): Updated attention map with added relative positional embeddings, shape
+            (B, q_h * q_w, k_h * k_w).
+
+    Examples:
+        >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
+        >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
+        >>> q = torch.rand(B, q_h * q_w, C)
+        >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
+        >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
+        >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
+        >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
+        >>> print(updated_attn.shape)
+        torch.Size([1, 64, 64])
+
+    References:
+        https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+        B, q_h * q_w, k_h * k_w
+    )
+
+    return attn
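
For context, these two helpers implement the decomposed relative-position bias used by ViT-style image encoders (per the MVITv2 reference in the docstring): the bias is added to the raw attention logits before the softmax. A minimal sketch of that wiring with single-head, projection-free attention, assuming the import path matches file 16 in the list above:

    import torch
    from ultralytics.models.sam.modules.utils import add_decomposed_rel_pos

    B, H, W, C = 1, 8, 8, 64
    x = torch.randn(B, H * W, C)             # flattened H x W grid of C-dim tokens
    q = k = v = x                            # no q/k/v projections, for brevity
    rel_pos_h = torch.randn(2 * H - 1, C)    # learnable parameters in a real encoder
    rel_pos_w = torch.randn(2 * W - 1, C)

    attn = (q * C**-0.5) @ k.transpose(-2, -1)  # (B, H*W, H*W) attention logits
    attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (H, W), (H, W))
    out = attn.softmax(dim=-1) @ v              # (B, H*W, C) attended tokens
    print(out.shape)  # torch.Size([1, 64, 64])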