dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/transformer.py
@@ -11,12 +11,10 @@ from ultralytics.nn.modules import MLPBlock
 
 
 class TwoWayTransformer(nn.Module):
-    """
-    A Two-Way Transformer module for simultaneous attention to image and query points.
+    """A Two-Way Transformer module for simultaneous attention to image and query points.
 
-    This class implements a specialized transformer decoder that attends to an input image using queries with
-    supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
-    cloud processing.
+    This class implements a specialized transformer decoder that attends to an input image using queries with supplied
+    positional embeddings. It's useful for tasks like object detection, image segmentation, and point cloud processing.
 
     Attributes:
         depth (int): Number of layers in the transformer.
@@ -48,8 +46,7 @@ class TwoWayTransformer(nn.Module):
         activation: type[nn.Module] = nn.ReLU,
         attention_downsample_rate: int = 2,
     ) -> None:
-        """
-        Initialize a Two-Way Transformer for simultaneous attention to image and query points.
+        """Initialize a Two-Way Transformer for simultaneous attention to image and query points.
 
         Args:
            depth (int): Number of layers in the transformer.
@@ -87,8 +84,7 @@ class TwoWayTransformer(nn.Module):
         image_pe: torch.Tensor,
         point_embedding: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Process image and point embeddings through the Two-Way Transformer.
+        """Process image and point embeddings through the Two-Way Transformer.
 
         Args:
             image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
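
The three hunks above cover TwoWayTransformer's class docstring, constructor, and forward pass. As a hedged orientation sketch (not part of the diff; tensor shapes follow the reflowed docstring, and the constructor values are SAM's usual defaults, used here only for illustration):

import torch
from ultralytics.models.sam.modules.transformer import TwoWayTransformer

# Dense image tokens (B, C, H, W), a positional encoding of the same shape,
# and sparse point queries (B, N_points, C), per the docstring above.
transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
point_embedding = torch.randn(1, 5, 256)
queries, keys = transformer(image_embedding, image_pe, point_embedding)
# queries: (1, 5, 256) refined point queries; keys: (1, 4096, 256) flattened image tokens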
@@ -127,12 +123,11 @@ class TwoWayTransformer(nn.Module):
 
 
 class TwoWayAttentionBlock(nn.Module):
-    """
-    A two-way attention block for simultaneous attention to image and query points.
+    """A two-way attention block for simultaneous attention to image and query points.
 
     This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
-    cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
-    inputs to sparse inputs.
+    cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense inputs to
+    sparse inputs.
 
     Attributes:
         self_attn (Attention): Self-attention layer for queries.
@@ -167,12 +162,11 @@ class TwoWayAttentionBlock(nn.Module):
         attention_downsample_rate: int = 2,
         skip_first_layer_pe: bool = False,
     ) -> None:
-        """
-        Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
+        """Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
 
         This block implements a specialized transformer layer with four main components: self-attention on sparse
-        inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
-        of dense inputs to sparse inputs.
+        inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of
+        dense inputs to sparse inputs.
 
         Args:
             embedding_dim (int): Channel dimension of the embeddings.
@@ -200,8 +194,7 @@ class TwoWayAttentionBlock(nn.Module):
     def forward(
         self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Apply two-way attention to process query and key embeddings in a transformer block.
+        """Apply two-way attention to process query and key embeddings in a transformer block.
 
         Args:
             queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
@@ -245,11 +238,10 @@ class TwoWayAttentionBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    """
-    An attention layer with downscaling capability for embedding size after projection.
+    """An attention layer with downscaling capability for embedding size after projection.
 
-    This class implements a multi-head attention mechanism with the option to downsample the internal
-    dimension of queries, keys, and values.
+    This class implements a multi-head attention mechanism with the option to downsample the internal dimension of
+    queries, keys, and values.
 
     Attributes:
         embedding_dim (int): Dimensionality of input embeddings.
@@ -282,8 +274,7 @@ class Attention(nn.Module):
         downsample_rate: int = 1,
         kv_in_dim: int | None = None,
     ) -> None:
-        """
-        Initialize the Attention module with specified dimensions and settings.
+        """Initialize the Attention module with specified dimensions and settings.
 
         Args:
             embedding_dim (int): Dimensionality of input embeddings.
@@ -321,8 +312,7 @@ class Attention(nn.Module):
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-        """
-        Apply multi-head attention to query, key, and value tensors with optional downsampling.
+        """Apply multi-head attention to query, key, and value tensors with optional downsampling.
 
         Args:
             q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).
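
The Attention hunks above document the optional internal downsampling. A minimal sketch of what downsample_rate does in practice (shapes taken from the docstring; illustrative, not part of the diff):

import torch
from ultralytics.models.sam.modules.transformer import Attention

# downsample_rate=2 halves the internal projection width (256 -> 128, split
# across 8 heads) before attention; the output is projected back to 256.
attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
q = torch.randn(1, 5, 256)     # (B, N_q, embedding_dim)
kv = torch.randn(1, 100, 256)  # (B, N_k, embedding_dim)
out = attn(q, kv, kv)          # (1, 5, 256)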
ultralytics/models/sam/modules/utils.py
@@ -9,8 +9,7 @@ import torch.nn.functional as F
 
 
 def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
-    """
-    Select the closest conditioning frames to a given frame index.
+    """Select the closest conditioning frames to a given frame index.
 
     Args:
         frame_idx (int): Current frame index.
@@ -62,8 +61,7 @@ def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any
 
 
 def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
-    """
-    Generate 1D sinusoidal positional embeddings for given positions and dimensions.
+    """Generate 1D sinusoidal positional embeddings for given positions and dimensions.
 
     Args:
         pos_inds (torch.Tensor): Position indices for which to generate embeddings.
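
get_1d_sine_pe follows the standard transformer sine/cosine construction. A self-contained sketch of that construction (an approximation for readers of the diff, not the library's exact code):

import torch

def sine_pe_sketch(pos_inds: torch.Tensor, dim: int, temperature: float = 10000) -> torch.Tensor:
    # Half the channels carry sine components, half cosine, with geometrically
    # spaced frequencies controlled by `temperature`.
    pe_dim = dim // 2
    freqs = temperature ** (torch.arange(pe_dim, dtype=torch.float32) / pe_dim)
    angles = pos_inds.float().unsqueeze(-1) / freqs  # (..., pe_dim)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (..., dim)

pe = sine_pe_sketch(torch.arange(16), dim=256)  # (16, 256)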
@@ -89,11 +87,10 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
 
 
 def init_t_xy(end_x: int, end_y: int):
-    """
-    Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
+    """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
 
-    This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor
-    and corresponding x and y coordinate tensors.
+    This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
+    tensor and corresponding x and y coordinate tensors.
 
     Args:
         end_x (int): Width of the grid (number of columns).
@@ -117,11 +114,10 @@ def init_t_xy(end_x: int, end_y: int):
 
 
 def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
-    """
-    Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
+    """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
 
-    This function generates complex exponential positional encodings for a 2D grid of spatial positions,
-    using separate frequency components for the x and y dimensions.
+    This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
+    frequency components for the x and y dimensions.
 
     Args:
         dim (int): Dimension of the positional encoding.
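
compute_axial_cis builds the complex phases that apply_rotary_enc consumes: half of the channel pairs rotate with the x coordinate, half with the y coordinate. A hedged sketch of the usual axial construction (the flattened t_x/t_y coordinates mirror what init_t_xy above returns; details may differ from the library's code):

import torch

def axial_cis_sketch(dim: int, end_x: int, end_y: int, theta: float = 10000.0) -> torch.Tensor:
    # One set of geometrically spaced frequencies, reused for both axes.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    t_x, t_y = t % end_x, t // end_x  # flattened grid coordinates
    cis_x = torch.polar(torch.ones(end_x * end_y, dim // 4), torch.outer(t_x, freqs))
    cis_y = torch.polar(torch.ones(end_x * end_y, dim // 4), torch.outer(t_y, freqs))
    return torch.cat([cis_x, cis_y], dim=-1)  # (end_x * end_y, dim // 2), complex64

cis = axial_cis_sketch(dim=64, end_x=8, end_y=8)  # one unit-norm phase per channel pair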
@@ -150,11 +146,10 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
 
 
 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    """
-    Reshape frequency tensor for broadcasting with input tensor.
+    """Reshape frequency tensor for broadcasting with input tensor.
 
-    Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
-    This function is typically used in positional encoding operations.
+    Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function
+    is typically used in positional encoding operations.
 
     Args:
         freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
@@ -179,8 +174,7 @@ def apply_rotary_enc(
     freqs_cis: torch.Tensor,
     repeat_freqs_k: bool = False,
 ):
-    """
-    Apply rotary positional encoding to query and key tensors.
+    """Apply rotary positional encoding to query and key tensors.
 
     This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
     components. RoPE is a technique that injects relative position information into self-attention mechanisms.
@@ -188,10 +182,10 @@ def apply_rotary_enc(
     Args:
         xq (torch.Tensor): Query tensor to encode with positional information.
         xk (torch.Tensor): Key tensor to encode with positional information.
-        freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
-            last two dimensions of xq.
-        repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
-            to match key sequence length.
+        freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
+            two dimensions of xq.
+        repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
+            key sequence length.
 
     Returns:
         xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
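
At its core, apply_rotary_enc multiplies channel pairs, viewed as complex numbers, by the unit phases in freqs_cis. A minimal sketch of that core operation (omitting the repeat_freqs_k handling and broadcasting that the real function documents above):

import torch

def rope_core_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # View consecutive channel pairs as complex numbers, rotate them by the
    # position-dependent phases, then flatten back to real channels.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2).type_as(x)

seq_len, dim = 64, 32
phases = torch.polar(torch.ones(seq_len, dim // 2), torch.rand(seq_len, dim // 2))
q = torch.randn(seq_len, dim)
q_rot = rope_core_sketch(q, phases)  # same shape; relative position now lives in the phases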
@@ -220,8 +214,7 @@ def apply_rotary_enc(
 
 
 def window_partition(x: torch.Tensor, window_size: int):
-    """
-    Partition input tensor into non-overlapping windows with padding if needed.
+    """Partition input tensor into non-overlapping windows with padding if needed.
 
     Args:
         x (torch.Tensor): Input tensor with shape (B, H, W, C).
@@ -251,23 +244,22 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
 
 
 def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
-    """
-    Unpartition windowed sequences into original sequences and remove padding.
+    """Unpartition windowed sequences into original sequences and remove padding.
 
-    This function reverses the windowing process, reconstructing the original input from windowed segments
-    and removing any padding that was added during the windowing process.
+    This function reverses the windowing process, reconstructing the original input from windowed segments and removing
+    any padding that was added during the windowing process.
 
     Args:
         windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
-            window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
-            the size of each window, and C is the number of channels.
+            window_size, C), where B is the batch size, num_windows is the number of windows, window_size is the size of
+            each window, and C is the number of channels.
         window_size (int): Size of each window.
         pad_hw (tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
         hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
 
     Returns:
-        (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
-            are the original height and width, and C is the number of channels.
+        (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the
+            original height and width, and C is the number of channels.
 
     Examples:
         >>> windows = torch.rand(32, 8, 8, 64)  # 32 windows of size 8x8 with 64 channels
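
window_partition and window_unpartition above are exact inverses once padding is accounted for. A self-contained sketch of the partitioning mechanics (illustrative; the library's code may differ in detail):

import torch
import torch.nn.functional as F

def window_partition_sketch(x: torch.Tensor, window_size: int):
    # Pad H and W up to multiples of window_size, then carve into windows.
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))  # pad order: C, then W, then H
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)

x = torch.randn(1, 60, 60, 64)
windows, pad_hw = window_partition_sketch(x, window_size=8)  # (64, 8, 8, 64), padded to 64x64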
@@ -289,18 +281,16 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
 
 
 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
-    """
-    Extract relative positional embeddings based on query and key sizes.
+    """Extract relative positional embeddings based on query and key sizes.
 
     Args:
         q_size (int): Size of the query.
         k_size (int): Size of the key.
-        rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
-            distance and C is the embedding dimension.
+        rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative distance
+            and C is the embedding dimension.
 
     Returns:
-        (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
-            k_size, C).
+        (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).
 
     Examples:
         >>> q_size, k_size = 8, 16
@@ -338,8 +328,7 @@ def add_decomposed_rel_pos(
     q_size: tuple[int, int],
     k_size: tuple[int, int],
 ) -> torch.Tensor:
-    """
-    Add decomposed Relative Positional Embeddings to the attention map.
+    """Add decomposed Relative Positional Embeddings to the attention map.
 
     This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
     paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
@@ -354,8 +343,8 @@ def add_decomposed_rel_pos(
         k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
 
     Returns:
-        (torch.Tensor): Updated attention map with added relative positional embeddings, shape
-            (B, q_h * q_w, k_h * k_w).
+        (torch.Tensor): Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h *
+            k_w).
 
     Examples:
         >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
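
The hunk above only shows the trailing parameters of add_decomposed_rel_pos. Assuming the upstream SAM signature (attention map, queries, and per-axis relative position tables lead the argument list), a hedged usage sketch with the docstring's example shapes:

import torch
from ultralytics.models.sam.modules.utils import add_decomposed_rel_pos

B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
attn = torch.zeros(B, q_h * q_w, k_h * k_w)
q = torch.randn(B, q_h * q_w, C)
rel_pos_h = torch.randn(2 * max(q_h, k_h) - 1, C)  # assumed per-axis relative position tables
rel_pos_w = torch.randn(2 * max(q_w, k_w) - 1, C)
attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w), (k_h, k_w))
# attn: (1, 64, 64) attention map with decomposed relative position biases added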