ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/transformer.py
@@ -1,4 +1,4 @@
- # Ultralytics YOLO 🚀, AGPL-3.0 license
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  import math
  from typing import Tuple, Type
@@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock

  class TwoWayTransformer(nn.Module):
  """
- A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
- serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
- is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
- processing.
+ A Two-Way Transformer module for simultaneous attention to image and query points.
+
+ This class implements a specialized transformer decoder that attends to an input image using queries with
+ supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
+ cloud processing.

  Attributes:
- depth (int): The number of layers in the transformer.
- embedding_dim (int): The channel dimension for the input embeddings.
- num_heads (int): The number of heads for multihead attention.
- mlp_dim (int): The internal channel dimension for the MLP block.
- layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
- final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
- norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+ Methods:
+ forward: Processes image and point embeddings through the transformer.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 32, 32)
+ >>> image_pe = torch.randn(1, 256, 32, 32)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+ >>> print(output_queries.shape, output_image.shape)
  """

  def __init__(
@@ -36,15 +48,32 @@ class TwoWayTransformer(nn.Module):
  attention_downsample_rate: int = 2,
  ) -> None:
  """
- A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
+ Initialize a Two-Way Transformer for simultaneous attention to image and query points.

  Args:
- depth (int): number of layers in the transformer
- embedding_dim (int): the channel dimension for the input embeddings
- num_heads (int): the number of heads for multihead attention. Must
- divide embedding_dim
- mlp_dim (int): the channel dimension internal to the MLP block
- activation (nn.Module): the activation to use in the MLP block
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ activation (Type[nn.Module]): Activation function to use in the MLP block.
+ attention_downsample_rate (int): Downsampling rate for attention mechanism.
+
+ Attributes:
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 32, 32)
+ >>> image_pe = torch.randn(1, 256, 32, 32)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+ >>> print(output_queries.shape, output_image.shape)
  """
  super().__init__()
  self.depth = depth
@@ -75,18 +104,25 @@ class TwoWayTransformer(nn.Module):
  point_embedding: Tensor,
  ) -> Tuple[Tensor, Tensor]:
  """
+ Processes image and point embeddings through the Two-Way Transformer.
+
  Args:
- image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
- image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
- point_embedding (torch.Tensor): the embedding to add to the query points.
- Must have shape B x N_points x embedding_dim for any N_points.
+ image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+ image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+ point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).

  Returns:
- (torch.Tensor): the processed point_embedding
- (torch.Tensor): the processed image_embedding
+ (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 32, 32)
+ >>> image_pe = torch.randn(1, 256, 32, 32)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+ >>> print(output_queries.shape, output_image.shape)
  """
  # BxCxHxW -> BxHWxC == B x N_image_tokens x C
- bs, c, h, w = image_embedding.shape
  image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
  image_pe = image_pe.flatten(2).permute(0, 2, 1)

@@ -115,21 +151,34 @@ class TwoWayTransformer(nn.Module):

  class TwoWayAttentionBlock(nn.Module):
  """
- An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
- keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
- of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
- sparse inputs.
+ A two-way attention block for simultaneous attention to image and query points.
+
+ This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
+ cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
+ inputs to sparse inputs.

  Attributes:
- self_attn (Attention): The self-attention layer for the queries.
- norm1 (nn.LayerNorm): Layer normalization following the first attention block.
+ self_attn (Attention): Self-attention layer for queries.
+ norm1 (nn.LayerNorm): Layer normalization after self-attention.
  cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
- norm2 (nn.LayerNorm): Layer normalization following the second attention block.
- mlp (MLPBlock): MLP block that transforms the query embeddings.
- norm3 (nn.LayerNorm): Layer normalization following the MLP block.
- norm4 (nn.LayerNorm): Layer normalization following the third attention block.
+ norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
+ mlp (MLPBlock): MLP block for transforming query embeddings.
+ norm3 (nn.LayerNorm): Layer normalization after MLP block.
+ norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
  cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
- skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+ skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+ Methods:
+ forward: Applies self-attention and cross-attention to queries and keys.
+
+ Examples:
+ >>> embedding_dim, num_heads = 256, 8
+ >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+ >>> queries = torch.randn(1, 100, embedding_dim)
+ >>> keys = torch.randn(1, 1000, embedding_dim)
+ >>> query_pe = torch.randn(1, 100, embedding_dim)
+ >>> key_pe = torch.randn(1, 1000, embedding_dim)
+ >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
  """

  def __init__(
@@ -142,16 +191,28 @@ class TwoWayAttentionBlock(nn.Module):
  skip_first_layer_pe: bool = False,
  ) -> None:
  """
- A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
- inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
- inputs.
+ Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+
+ This block implements a specialized transformer layer with four main components: self-attention on sparse
+ inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
+ of dense inputs to sparse inputs.

  Args:
- embedding_dim (int): the channel dimension of the embeddings
- num_heads (int): the number of heads in the attention layers
- mlp_dim (int): the hidden dimension of the mlp block
- activation (nn.Module): the activation of the mlp block
- skip_first_layer_pe (bool): skip the PE on the first layer
+ embedding_dim (int): Channel dimension of the embeddings.
+ num_heads (int): Number of attention heads in the attention layers.
+ mlp_dim (int): Hidden dimension of the MLP block.
+ activation (Type[nn.Module]): Activation function for the MLP block.
+ attention_downsample_rate (int): Downsampling rate for the attention mechanism.
+ skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+ Examples:
+ >>> embedding_dim, num_heads = 256, 8
+ >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+ >>> queries = torch.randn(1, 100, embedding_dim)
+ >>> keys = torch.randn(1, 1000, embedding_dim)
+ >>> query_pe = torch.randn(1, 100, embedding_dim)
+ >>> key_pe = torch.randn(1, 1000, embedding_dim)
+ >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
  """
  super().__init__()
  self.self_attn = Attention(embedding_dim, num_heads)
@@ -169,8 +230,7 @@ class TwoWayAttentionBlock(nn.Module):
  self.skip_first_layer_pe = skip_first_layer_pe

  def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
- """Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
-
+ """Applies two-way attention to process query and key embeddings in a transformer block."""
  # Self attention block
  if self.skip_first_layer_pe:
  queries = self.self_attn(q=queries, k=queries, v=queries)
@@ -203,8 +263,34 @@ class TwoWayAttentionBlock(nn.Module):


  class Attention(nn.Module):
- """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
- values.
+ """
+ An attention layer with downscaling capability for embedding size after projection.
+
+ This class implements a multi-head attention mechanism with the option to downsample the internal
+ dimension of queries, keys, and values.
+
+ Attributes:
+ embedding_dim (int): Dimensionality of input embeddings.
+ kv_in_dim (int): Dimensionality of key and value inputs.
+ internal_dim (int): Internal dimension after downsampling.
+ num_heads (int): Number of attention heads.
+ q_proj (nn.Linear): Linear projection for queries.
+ k_proj (nn.Linear): Linear projection for keys.
+ v_proj (nn.Linear): Linear projection for values.
+ out_proj (nn.Linear): Linear projection for output.
+
+ Methods:
+ _separate_heads: Separates input tensor into attention heads.
+ _recombine_heads: Recombines separated attention heads.
+ forward: Computes attention output for given query, key, and value tensors.
+
+ Examples:
+ >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+ >>> q = torch.randn(1, 100, 256)
+ >>> k = v = torch.randn(1, 50, 256)
+ >>> output = attn(q, k, v)
+ >>> print(output.shape)
+ torch.Size([1, 100, 256])
  """

  def __init__(
@@ -212,46 +298,59 @@ class Attention(nn.Module):
  embedding_dim: int,
  num_heads: int,
  downsample_rate: int = 1,
+ kv_in_dim: int = None,
  ) -> None:
  """
- Initializes the Attention model with the given dimensions and settings.
+ Initializes the Attention module with specified dimensions and settings.
+
+ This class implements a multi-head attention mechanism with optional downsampling of the internal
+ dimension for queries, keys, and values.

  Args:
- embedding_dim (int): The dimensionality of the input embeddings.
- num_heads (int): The number of attention heads.
- downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
+ embedding_dim (int): Dimensionality of input embeddings.
+ num_heads (int): Number of attention heads.
+ downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+ kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.

  Raises:
- AssertionError: If 'num_heads' does not evenly divide the internal dimension (embedding_dim / downsample_rate).
+ AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
+
+ Examples:
+ >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+ >>> q = torch.randn(1, 100, 256)
+ >>> k = v = torch.randn(1, 50, 256)
+ >>> output = attn(q, k, v)
+ >>> print(output.shape)
+ torch.Size([1, 100, 256])
  """
  super().__init__()
  self.embedding_dim = embedding_dim
+ self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
  self.internal_dim = embedding_dim // downsample_rate
  self.num_heads = num_heads
  assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

  self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
- self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+ self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
  self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

  @staticmethod
  def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
- """Separate the input tensor into the specified number of attention heads."""
+ """Separates the input tensor into the specified number of attention heads."""
  b, n, c = x.shape
  x = x.reshape(b, n, num_heads, c // num_heads)
  return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head

  @staticmethod
  def _recombine_heads(x: Tensor) -> Tensor:
- """Recombine the separated attention heads into a single tensor."""
+ """Recombines separated attention heads into a single tensor."""
  b, n_heads, n_tokens, c_per_head = x.shape
  x = x.transpose(1, 2)
  return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C

  def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
- """Compute the attention output given the input query, key, and value tensors."""
-
+ """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
  # Input projections
  q = self.q_proj(q)
  k = self.k_proj(k)
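The hunk above adds a kv_in_dim argument so that keys and values can enter the layer at a different width than the queries, with k_proj and v_proj now mapping from self.kv_in_dim into the downsampled internal dimension. Below is a minimal usage sketch based only on the signature shown in this diff; it assumes ultralytics 8.3.62 is installed, and the shapes are purely illustrative.

import torch
from ultralytics.models.sam.modules.transformer import Attention

# Queries at embedding_dim=256; keys/values arrive at kv_in_dim=64.
# With downsample_rate=2 the internal attention dimension is 256 // 2 = 128.
attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2, kv_in_dim=64)
q = torch.randn(1, 100, 256)  # (B, N_queries, embedding_dim)
k = torch.randn(1, 50, 64)    # (B, N_keys, kv_in_dim)
v = torch.randn(1, 50, 64)    # (B, N_keys, kv_in_dim)
out = attn(q, k, v)           # out_proj maps the result back to embedding_dim
print(out.shape)              # torch.Size([1, 100, 256])

Leaving kv_in_dim at its default of None preserves the previous behavior, since it then falls back to embedding_dim.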
ultralytics/models/sam/modules/utils.py
@@ -0,0 +1,293 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from typing import Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+ """
+ Selects the closest conditioning frames to a given frame index.
+
+ Args:
+ frame_idx (int): Current frame index.
+ cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
+ max_cond_frame_num (int): Maximum number of conditioning frames to select.
+
+ Returns:
+ (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries:
+ - selected_outputs: Selected items from cond_frame_outputs.
+ - unselected_outputs: Items not selected from cond_frame_outputs.
+
+ Examples:
+ >>> frame_idx = 5
+ >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
+ >>> max_cond_frame_num = 2
+ >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
+ >>> print(selected)
+ {3: 'b', 7: 'c'}
+ >>> print(unselected)
+ {1: 'a', 9: 'd'}
+ """
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+ selected_outputs = cond_frame_outputs
+ unselected_outputs = {}
+ else:
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+ selected_outputs = {}
+
+ # the closest conditioning frame before `frame_idx` (if any)
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+ if idx_before is not None:
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+ # the closest conditioning frame after `frame_idx` (if any)
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+ if idx_after is not None:
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+ # add other temporally closest conditioning frames until reaching a total
+ # of `max_cond_frame_num` conditioning frames.
+ num_remain = max_cond_frame_num - len(selected_outputs)
+ inds_remain = sorted(
+ (t for t in cond_frame_outputs if t not in selected_outputs),
+ key=lambda x: abs(x - frame_idx),
+ )[:num_remain]
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+ unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
+
+ return selected_outputs, unselected_outputs
+
+
+ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+ """Generates 1D sinusoidal positional embeddings for given positions and dimensions."""
+ pe_dim = dim // 2
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+ return pos_embed
+
+
+ def init_t_xy(end_x: int, end_y: int):
+ """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
+ t_x = (t % end_x).float()
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
+ return t_x, t_y
+
+
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+ """Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+ t_x, t_y = init_t_xy(end_x, end_y)
+ freqs_x = torch.outer(t_x, freqs_x)
+ freqs_y = torch.outer(t_y, freqs_y)
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+ """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+ def apply_rotary_enc(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ repeat_freqs_k: bool = False,
+ ):
+ """Applies rotary positional encoding to query and key tensors using complex-valued frequency components."""
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+ if xk_ is None:
+ # no keys to rotate, due to dropout
+ return xq_out.type_as(xq).to(xq.device), xk
+ # repeat freqs along seq_len dim to match k seq_len
+ if repeat_freqs_k:
+ r = xk_.shape[-2] // xq_.shape[-2]
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
+
+
+ def window_partition(x, window_size):
+ """
+ Partitions input tensor into non-overlapping windows with padding if needed.
+
+ Args:
+ x (torch.Tensor): Input tensor with shape (B, H, W, C).
+ window_size (int): Size of each window.
+
+ Returns:
+ (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing:
+ - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
+ - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition.
+
+ Examples:
+ >>> x = torch.randn(1, 16, 16, 3)
+ >>> windows, (Hp, Wp) = window_partition(x, window_size=4)
+ >>> print(windows.shape, Hp, Wp)
+ torch.Size([16, 4, 4, 3]) 16 16
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+
+
+ def window_unpartition(windows, window_size, pad_hw, hw):
+ """
+ Unpartitions windowed sequences into original sequences and removes padding.
+
+ This function reverses the windowing process, reconstructing the original input from windowed segments
+ and removing any padding that was added during the windowing process.
+
+ Args:
+ windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
+ window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
+ the size of each window, and C is the number of channels.
+ window_size (int): Size of each window.
+ pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
+ hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
+
+ Returns:
+ (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
+ are the original height and width, and C is the number of channels.
+
+ Examples:
+ >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
+ >>> pad_hw = (16, 16) # Padded height and width
+ >>> hw = (15, 14) # Original height and width
+ >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
+ >>> print(x.shape)
+ torch.Size([1, 15, 14, 64])
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Extracts relative positional embeddings based on query and key sizes.
+
+ Args:
+ q_size (int): Size of the query.
+ k_size (int): Size of the key.
+ rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
+ distance and C is the embedding dimension.
+
+ Returns:
+ (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
+ k_size, C).
+
+ Examples:
+ >>> q_size, k_size = 8, 16
+ >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1
+ >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
+ >>> print(extracted_pos.shape)
+ torch.Size([8, 16, 64])
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+ def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+ ) -> torch.Tensor:
+ """
+ Adds decomposed Relative Positional Embeddings to the attention map.
+
+ This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
+ paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
+ positions.
+
+ Args:
+ attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
+ q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
+ rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
+ q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
+ k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
+
+ Returns:
+ (torch.Tensor): Updated attention map with added relative positional embeddings, shape
+ (B, q_h * q_w, k_h * k_w).
+
+ Examples:
+ >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
+ >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
+ >>> q = torch.rand(B, q_h * q_w, C)
+ >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
+ >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
+ >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
+ >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
+ >>> print(updated_attn.shape)
+ torch.Size([1, 64, 64])
+
+ References:
+ https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+ B, q_h * q_w, k_h * k_w
+ )
+
+ return attn
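The window helpers added above pad an image-shaped tensor up to a multiple of window_size, tile it into windows, and can then undo both steps. A minimal round-trip sketch, assuming ultralytics 8.3.62 is installed; the shapes are illustrative only:

import torch
from ultralytics.models.sam.modules.utils import window_partition, window_unpartition

x = torch.randn(2, 15, 14, 64)  # (B, H, W, C) with H and W not multiples of the window size
windows, (Hp, Wp) = window_partition(x, window_size=8)  # zero-pads to (16, 16), then tiles into 8x8 windows
print(windows.shape, Hp, Wp)  # torch.Size([8, 8, 8, 64]) 16 16
y = window_unpartition(windows, window_size=8, pad_hw=(Hp, Wp), hw=(15, 14))
print(torch.equal(x, y))  # True: the padding added during partitioning is cropped away again

This pairing is what windowed attention in SAM-style image encoders typically relies on when feature-map sizes are not exact multiples of the window.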