dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/transformer.py
@@ -1,7 +1,8 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from __future__ import annotations
+
  import math
- from typing import Tuple, Type

  import torch
  from torch import Tensor, nn
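Note: the recurring change across both modules in this diff is dropping typing.Tuple/typing.Type in favor of the built-in generics, which the new `from __future__ import annotations` import makes safe on older Python versions. A minimal sketch of the two styles (the make_grid helper is hypothetical, not from the package):

from __future__ import annotations  # defers annotation evaluation (PEP 563)

# Old style (required the now-removed typing imports):
# from typing import Optional, Tuple, Type
# def make_grid(size: Tuple[int, int], cls: Type[object], pad: Optional[int] = None) -> Tuple[int, int]: ...

# New style used in 8.3.224: built-in generics (PEP 585) and X | None unions (PEP 604),
# valid as annotations even on older runtimes because evaluation is deferred.
def make_grid(size: tuple[int, int], cls: type[object], pad: int | None = None) -> tuple[int, int]:
    """Hypothetical helper illustrating the annotation style only."""
    h, w = size
    extra = pad or 0
    return h + extra, w + extra

print(make_grid((16, 16), object, pad=2))  # (18, 18)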
@@ -10,12 +11,10 @@ from ultralytics.nn.modules import MLPBlock


  class TwoWayTransformer(nn.Module):
- """
- A Two-Way Transformer module for simultaneous attention to image and query points.
+ """A Two-Way Transformer module for simultaneous attention to image and query points.

- This class implements a specialized transformer decoder that attends to an input image using queries with
- supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
- cloud processing.
+ This class implements a specialized transformer decoder that attends to an input image using queries with supplied
+ positional embeddings. It's useful for tasks like object detection, image segmentation, and point cloud processing.

  Attributes:
  depth (int): Number of layers in the transformer.
@@ -27,7 +26,7 @@ class TwoWayTransformer(nn.Module):
  norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.

  Methods:
- forward: Processes image and point embeddings through the transformer.
+ forward: Process image and point embeddings through the transformer.

  Examples:
  >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
@@ -44,19 +43,18 @@
  embedding_dim: int,
  num_heads: int,
  mlp_dim: int,
- activation: Type[nn.Module] = nn.ReLU,
+ activation: type[nn.Module] = nn.ReLU,
  attention_downsample_rate: int = 2,
  ) -> None:
- """
- Initialize a Two-Way Transformer for simultaneous attention to image and query points.
+ """Initialize a Two-Way Transformer for simultaneous attention to image and query points.

  Args:
  depth (int): Number of layers in the transformer.
  embedding_dim (int): Channel dimension for input embeddings.
  num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
  mlp_dim (int): Internal channel dimension for the MLP block.
- activation (Type[nn.Module]): Activation function to use in the MLP block.
- attention_downsample_rate (int): Downsampling rate for attention mechanism.
+ activation (Type[nn.Module], optional): Activation function to use in the MLP block.
+ attention_downsample_rate (int, optional): Downsampling rate for attention mechanism.
  """
  super().__init__()
  self.depth = depth
@@ -82,21 +80,20 @@ class TwoWayTransformer(nn.Module):

  def forward(
  self,
- image_embedding: Tensor,
- image_pe: Tensor,
- point_embedding: Tensor,
- ) -> Tuple[Tensor, Tensor]:
- """
- Process image and point embeddings through the Two-Way Transformer.
+ image_embedding: torch.Tensor,
+ image_pe: torch.Tensor,
+ point_embedding: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Process image and point embeddings through the Two-Way Transformer.

  Args:
- image_embedding (Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
- image_pe (Tensor): Positional encoding to add to the image, with same shape as image_embedding.
- point_embedding (Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
+ image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+ image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+ point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).

  Returns:
- queries (Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
- keys (Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).
+ queries (torch.Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
+ keys (torch.Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).
  """
  # BxCxHxW -> BxHWxC == B x N_image_tokens x C
  image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
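Note: the final context line of this hunk shows the flatten/permute that turns the spatial feature map into a token sequence before attention. A standalone shape check of that step, with assumed sizes (B=1, C=256, a 64x64 grid):

import torch

B, C, H, W = 1, 256, 64, 64  # assumed example sizes
image_embedding = torch.randn(B, C, H, W)

# BxCxHxW -> BxHWxC: flatten the spatial grid into H*W image tokens of C channels each
tokens = image_embedding.flatten(2).permute(0, 2, 1)
print(tokens.shape)  # torch.Size([1, 4096, 256]) == (B, H*W, C)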
@@ -126,12 +123,11 @@ class TwoWayTransformer(nn.Module):


  class TwoWayAttentionBlock(nn.Module):
- """
- A two-way attention block for simultaneous attention to image and query points.
+ """A two-way attention block for simultaneous attention to image and query points.

  This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
- cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
- inputs to sparse inputs.
+ cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense inputs to
+ sparse inputs.

  Attributes:
  self_attn (Attention): Self-attention layer for queries.
@@ -145,7 +141,7 @@ class TwoWayAttentionBlock(nn.Module):
  skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.

  Methods:
- forward: Applies self-attention and cross-attention to queries and keys.
+ forward: Apply self-attention and cross-attention to queries and keys.

  Examples:
  >>> embedding_dim, num_heads = 256, 8
@@ -162,24 +158,23 @@ class TwoWayAttentionBlock(nn.Module):
  embedding_dim: int,
  num_heads: int,
  mlp_dim: int = 2048,
- activation: Type[nn.Module] = nn.ReLU,
+ activation: type[nn.Module] = nn.ReLU,
  attention_downsample_rate: int = 2,
  skip_first_layer_pe: bool = False,
  ) -> None:
- """
- Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
+ """Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.

  This block implements a specialized transformer layer with four main components: self-attention on sparse
- inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
- of dense inputs to sparse inputs.
+ inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of
+ dense inputs to sparse inputs.

  Args:
  embedding_dim (int): Channel dimension of the embeddings.
  num_heads (int): Number of attention heads in the attention layers.
- mlp_dim (int): Hidden dimension of the MLP block.
- activation (Type[nn.Module]): Activation function for the MLP block.
- attention_downsample_rate (int): Downsampling rate for the attention mechanism.
- skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+ mlp_dim (int, optional): Hidden dimension of the MLP block.
+ activation (Type[nn.Module], optional): Activation function for the MLP block.
+ attention_downsample_rate (int, optional): Downsampling rate for the attention mechanism.
+ skip_first_layer_pe (bool, optional): Whether to skip positional encoding in the first layer.
  """
  super().__init__()
  self.self_attn = Attention(embedding_dim, num_heads)
@@ -196,19 +191,20 @@ class TwoWayAttentionBlock(nn.Module):

  self.skip_first_layer_pe = skip_first_layer_pe

- def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
- """
- Apply two-way attention to process query and key embeddings in a transformer block.
+ def forward(
+ self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Apply two-way attention to process query and key embeddings in a transformer block.

  Args:
- queries (Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
- keys (Tensor): Key embeddings with shape (B, N_keys, embedding_dim).
- query_pe (Tensor): Positional encodings for queries with same shape as queries.
- key_pe (Tensor): Positional encodings for keys with same shape as keys.
+ queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
+ keys (torch.Tensor): Key embeddings with shape (B, N_keys, embedding_dim).
+ query_pe (torch.Tensor): Positional encodings for queries with same shape as queries.
+ key_pe (torch.Tensor): Positional encodings for keys with same shape as keys.

  Returns:
- queries (Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
- keys (Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).
+ queries (torch.Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
+ keys (torch.Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).
  """
  # Self attention block
  if self.skip_first_layer_pe:
@@ -242,11 +238,10 @@ class TwoWayAttentionBlock(nn.Module):


  class Attention(nn.Module):
- """
- An attention layer with downscaling capability for embedding size after projection.
+ """An attention layer with downscaling capability for embedding size after projection.

- This class implements a multi-head attention mechanism with the option to downsample the internal
- dimension of queries, keys, and values.
+ This class implements a multi-head attention mechanism with the option to downsample the internal dimension of
+ queries, keys, and values.

  Attributes:
  embedding_dim (int): Dimensionality of input embeddings.
@@ -259,9 +254,9 @@ class Attention(nn.Module):
  out_proj (nn.Linear): Linear projection for output.

  Methods:
- _separate_heads: Separates input tensor into attention heads.
- _recombine_heads: Recombines separated attention heads.
- forward: Computes attention output for given query, key, and value tensors.
+ _separate_heads: Separate input tensor into attention heads.
+ _recombine_heads: Recombine separated attention heads.
+ forward: Compute attention output for given query, key, and value tensors.

  Examples:
  >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
@@ -277,16 +272,15 @@ class Attention(nn.Module):
  embedding_dim: int,
  num_heads: int,
  downsample_rate: int = 1,
- kv_in_dim: int = None,
+ kv_in_dim: int | None = None,
  ) -> None:
- """
- Initialize the Attention module with specified dimensions and settings.
+ """Initialize the Attention module with specified dimensions and settings.

  Args:
  embedding_dim (int): Dimensionality of input embeddings.
  num_heads (int): Number of attention heads.
- downsample_rate (int): Factor by which internal dimensions are downsampled.
- kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
+ downsample_rate (int, optional): Factor by which internal dimensions are downsampled.
+ kv_in_dim (int | None, optional): Dimensionality of key and value inputs. If None, uses embedding_dim.

  Raises:
  AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
@@ -304,7 +298,7 @@ class Attention(nn.Module):
  self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

  @staticmethod
- def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
+ def _separate_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
  """Separate the input tensor into the specified number of attention heads."""
  b, n, c = x.shape
  x = x.reshape(b, n, num_heads, c // num_heads)
@@ -317,17 +311,16 @@ class Attention(nn.Module):
  x = x.transpose(1, 2)
  return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C

- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
- """
- Apply multi-head attention to query, key, and value tensors with optional downsampling.
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ """Apply multi-head attention to query, key, and value tensors with optional downsampling.

  Args:
- q (Tensor): Query tensor with shape (B, N_q, embedding_dim).
- k (Tensor): Key tensor with shape (B, N_k, embedding_dim).
- v (Tensor): Value tensor with shape (B, N_k, embedding_dim).
+ q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).
+ k (torch.Tensor): Key tensor with shape (B, N_k, embedding_dim).
+ v (torch.Tensor): Value tensor with shape (B, N_k, embedding_dim).

  Returns:
- (Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).
+ (torch.Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).
  """
  # Input projections
  q = self.q_proj(q)
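Note: the Attention hunks document the downsampling rule (internal_dim = embedding_dim // downsample_rate, which must stay divisible by num_heads) and the head split/recombine helpers. A small independent shape walk-through using the docstring example's settings (embedding_dim=256, num_heads=8, downsample_rate=2); the tensors and sizes here are illustrative only:

import torch

B, N_q, embedding_dim = 1, 100, 256
num_heads, downsample_rate = 8, 2

internal_dim = embedding_dim // downsample_rate  # 128
assert internal_dim % num_heads == 0, "num_heads must divide embedding_dim // downsample_rate"

# After the q/k/v projections the channel dimension is internal_dim, not embedding_dim
q = torch.randn(B, N_q, internal_dim)

# _separate_heads: (B, N, C) -> (B, num_heads, N, C // num_heads)
b, n, c = q.shape
heads = q.reshape(b, n, num_heads, c // num_heads).transpose(1, 2)
print(heads.shape)  # torch.Size([1, 8, 100, 16])

# _recombine_heads: (B, num_heads, N, C_per_head) -> (B, N, C)
recombined = heads.transpose(1, 2).reshape(b, n, num_heads * (c // num_heads))
print(recombined.shape)  # torch.Size([1, 100, 128]); out_proj then maps 128 back to embedding_dim=256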
ultralytics/models/sam/modules/utils.py
@@ -1,23 +1,24 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

- from typing import Tuple
+ from __future__ import annotations
+
+ from typing import Any

  import torch
  import torch.nn.functional as F


- def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
- """
- Select the closest conditioning frames to a given frame index.
+ def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
+ """Select the closest conditioning frames to a given frame index.

  Args:
  frame_idx (int): Current frame index.
- cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
+ cond_frame_outputs (dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
  max_cond_frame_num (int): Maximum number of conditioning frames to select.

  Returns:
- selected_outputs (Dict[int, Any]): Selected items from cond_frame_outputs.
- unselected_outputs (Dict[int, Any]): Items not selected from cond_frame_outputs.
+ selected_outputs (dict[int, Any]): Selected items from cond_frame_outputs.
+ unselected_outputs (dict[int, Any]): Items not selected from cond_frame_outputs.

  Examples:
  >>> frame_idx = 5
@@ -59,14 +60,13 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num
  return selected_outputs, unselected_outputs


- def get_1d_sine_pe(pos_inds, dim, temperature=10000):
- """
- Generate 1D sinusoidal positional embeddings for given positions and dimensions.
+ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
+ """Generate 1D sinusoidal positional embeddings for given positions and dimensions.

  Args:
  pos_inds (torch.Tensor): Position indices for which to generate embeddings.
  dim (int): Dimension of the positional embeddings. Should be an even number.
- temperature (float): Scaling factor for the frequency of the sinusoidal functions.
+ temperature (float, optional): Scaling factor for the frequency of the sinusoidal functions.

  Returns:
  (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).
@@ -78,7 +78,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
  torch.Size([4, 128])
  """
  pe_dim = dim // 2
- dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+ dim_t = torch.arange(pe_dim, dtype=pos_inds.dtype, device=pos_inds.device)
  dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)

  pos_embed = pos_inds.unsqueeze(-1) / dim_t
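Note: the one functional change in this hunk is that dim_t now inherits pos_inds.dtype instead of a hard-coded torch.float32, so the embeddings follow the caller's precision. An independent sketch of the same sinusoidal construction; the final sin/cos concatenation is an assumption, since the hunk ends before it:

import torch

def sine_pe_sketch(pos_inds: torch.Tensor, dim: int, temperature: float = 10000) -> torch.Tensor:
    """Independent sketch of 1D sinusoidal PE; the final sin/cos concat is assumed, not shown above."""
    pe_dim = dim // 2
    # 8.3.224 behaviour: frequencies follow pos_inds.dtype rather than forcing float32
    dim_t = torch.arange(pe_dim, dtype=pos_inds.dtype, device=pos_inds.device)
    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
    pos_embed = pos_inds.unsqueeze(-1) / dim_t
    return torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)  # assumed final step

pos = torch.arange(4, dtype=torch.float32)
print(sine_pe_sketch(pos, 128).shape)  # torch.Size([4, 128]), matching the docstring example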
@@ -87,25 +87,21 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):


  def init_t_xy(end_x: int, end_y: int):
- """
- Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
+ """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.

- This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor
- and corresponding x and y coordinate tensors.
+ This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
+ tensor and corresponding x and y coordinate tensors.

  Args:
  end_x (int): Width of the grid (number of columns).
  end_y (int): Height of the grid (number of rows).

  Returns:
- t (torch.Tensor): Linear indices for each position in the grid, with shape (end_x * end_y).
  t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
  t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).

  Examples:
- >>> t, t_x, t_y = init_t_xy(3, 2)
- >>> print(t)
- tensor([0., 1., 2., 3., 4., 5.])
+ >>> t_x, t_y = init_t_xy(3, 2)
  >>> print(t_x)
  tensor([0., 1., 2., 0., 1., 2.])
  >>> print(t_y)
@@ -118,11 +114,10 @@ def init_t_xy(end_x: int, end_y: int):


  def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
- """
- Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
+ """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.

- This function generates complex exponential positional encodings for a 2D grid of spatial positions,
- using separate frequency components for the x and y dimensions.
+ This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
+ frequency components for the x and y dimensions.

  Args:
  dim (int): Dimension of the positional encoding.
@@ -131,18 +126,13 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
  theta (float, optional): Scaling factor for frequency computation.

  Returns:
- freqs_cis_x (torch.Tensor): Complex exponential positional encodings for x-dimension with shape
- (end_x*end_y, dim//4).
- freqs_cis_y (torch.Tensor): Complex exponential positional encodings for y-dimension with shape
- (end_x*end_y, dim//4).
+ (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).

  Examples:
  >>> dim, end_x, end_y = 128, 8, 8
- >>> freqs_cis_x, freqs_cis_y = compute_axial_cis(dim, end_x, end_y)
- >>> freqs_cis_x.shape
- torch.Size([64, 32])
- >>> freqs_cis_y.shape
- torch.Size([64, 32])
+ >>> freqs_cis = compute_axial_cis(dim, end_x, end_y)
+ >>> freqs_cis.shape
+ torch.Size([64, 64])
  """
  freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
  freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
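Note: this hunk is docstring-only; the documented return is now a single tensor of shape (end_x*end_y, dim//2), i.e. the x and y encodings of dim//4 complex entries each, concatenated. A worked shape check under the documented sizes (the torch.polar/torch.cat construction below is an assumption about the internals, not code copied from the package):

import torch

dim, end_x, end_y, theta = 128, 8, 8, 10000.0

# Per-axis frequencies: dim // 4 entries each, as in the hunk's freqs_x / freqs_y lines
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))

# Grid coordinates: x repeats along rows, y along columns (matches the init_t_xy docstring example)
t_x = torch.arange(end_x).repeat(end_y).float()             # shape (64,)
t_y = torch.arange(end_y).repeat_interleave(end_x).float()  # shape (64,)

# Complex exponentials per axis, then concatenated along the feature dimension
freqs_cis_x = torch.polar(torch.ones(end_x * end_y, dim // 4), torch.outer(t_x, freqs))
freqs_cis_y = torch.polar(torch.ones(end_x * end_y, dim // 4), torch.outer(t_y, freqs))
freqs_cis = torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
print(freqs_cis.shape)  # torch.Size([64, 64]), matching the updated docstring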
@@ -156,11 +146,10 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):


  def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
- """
- Reshape frequency tensor for broadcasting with input tensor.
+ """Reshape frequency tensor for broadcasting with input tensor.

- Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
- This function is typically used in positional encoding operations.
+ Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function
+ is typically used in positional encoding operations.

  Args:
  freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
@@ -185,8 +174,7 @@ def apply_rotary_enc(
  freqs_cis: torch.Tensor,
  repeat_freqs_k: bool = False,
  ):
- """
- Apply rotary positional encoding to query and key tensors.
+ """Apply rotary positional encoding to query and key tensors.

  This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
  components. RoPE is a technique that injects relative position information into self-attention mechanisms.
@@ -194,10 +182,10 @@ def apply_rotary_enc(
  Args:
  xq (torch.Tensor): Query tensor to encode with positional information.
  xk (torch.Tensor): Key tensor to encode with positional information.
- freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
- last two dimensions of xq.
- repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
- to match key sequence length.
+ freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
+ two dimensions of xq.
+ repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
+ key sequence length.

  Returns:
  xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
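Note: these hunks only reflow docstrings, but the technique they describe, rotary positional encoding via complex multiplication, is compact enough to sketch. The following is not the package's apply_rotary_enc; it is a minimal stand-alone version assuming a (B, heads, seq, head_dim) layout with an even head_dim:

import torch

def rope_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Minimal RoPE: rotate consecutive feature pairs of x by the complex frequencies in freqs_cis."""
    # Pair up the last dimension and view as complex numbers: (..., seq, head_dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Element-wise complex multiplication applies a position-dependent rotation
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)

B, heads, seq, head_dim = 1, 8, 16, 64
xq = torch.randn(B, heads, seq, head_dim)
# One unit-magnitude complex rotation per (position, feature pair); broadcast over batch and heads
angles = torch.randn(seq, head_dim // 2)
freqs_cis = torch.polar(torch.ones(seq, head_dim // 2), angles)
print(rope_sketch(xq, freqs_cis).shape)  # torch.Size([1, 8, 16, 64])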
@@ -225,9 +213,8 @@ def apply_rotary_enc(
  return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)


- def window_partition(x, window_size):
- """
- Partition input tensor into non-overlapping windows with padding if needed.
+ def window_partition(x: torch.Tensor, window_size: int):
+ """Partition input tensor into non-overlapping windows with padding if needed.

  Args:
  x (torch.Tensor): Input tensor with shape (B, H, W, C).
@@ -235,7 +222,7 @@ def window_partition(x, window_size):

  Returns:
  windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
- padded_h_w (Tuple[int, int]): Padded height and width before partition.
+ padded_h_w (tuple[int, int]): Padded height and width before partition.

  Examples:
  >>> x = torch.randn(1, 16, 16, 3)
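Note: the window_partition change is annotation-only; the shape bookkeeping it documents can be verified independently. A sketch of the partition step under the docstring's assumptions (pad H and W up to multiples of window_size, then tile), written as a stand-alone example rather than the package's implementation:

import torch
import torch.nn.functional as F

B, H, W, C, window_size = 1, 14, 14, 3, 8
x = torch.randn(B, H, W, C)

# Pad height/width up to multiples of window_size (14 -> 16 here)
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))  # pad the last three dims: C (none), W, H
Hp, Wp = H + pad_h, W + pad_w

# Split the padded grid into non-overlapping window_size x window_size tiles
windows = (
    x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, window_size, window_size, C)
)
print(windows.shape, (Hp, Wp))  # torch.Size([4, 8, 8, 3]) (16, 16)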
@@ -256,24 +243,23 @@ def window_partition(x, window_size):
  return windows, (Hp, Wp)


- def window_unpartition(windows, window_size, pad_hw, hw):
- """
- Unpartition windowed sequences into original sequences and remove padding.
+ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
+ """Unpartition windowed sequences into original sequences and remove padding.

- This function reverses the windowing process, reconstructing the original input from windowed segments
- and removing any padding that was added during the windowing process.
+ This function reverses the windowing process, reconstructing the original input from windowed segments and removing
+ any padding that was added during the windowing process.

  Args:
  windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
- window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
- the size of each window, and C is the number of channels.
+ window_size, C), where B is the batch size, num_windows is the number of windows, window_size is the size of
+ each window, and C is the number of channels.
  window_size (int): Size of each window.
- pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
- hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
+ pad_hw (tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
+ hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.

  Returns:
- (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
- are the original height and width, and C is the number of channels.
+ (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the
+ original height and width, and C is the number of channels.

  Examples:
  >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
@@ -295,18 +281,16 @@ def window_unpartition(windows, window_size, pad_hw, hw):


  def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
- """
- Extract relative positional embeddings based on query and key sizes.
+ """Extract relative positional embeddings based on query and key sizes.

  Args:
  q_size (int): Size of the query.
  k_size (int): Size of the key.
- rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
- distance and C is the embedding dimension.
+ rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative distance
+ and C is the embedding dimension.

  Returns:
- (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
- k_size, C).
+ (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).

  Examples:
  >>> q_size, k_size = 8, 16
@@ -341,11 +325,10 @@ def add_decomposed_rel_pos(
  q: torch.Tensor,
  rel_pos_h: torch.Tensor,
  rel_pos_w: torch.Tensor,
- q_size: Tuple[int, int],
- k_size: Tuple[int, int],
+ q_size: tuple[int, int],
+ k_size: tuple[int, int],
  ) -> torch.Tensor:
- """
- Add decomposed Relative Positional Embeddings to the attention map.
+ """Add decomposed Relative Positional Embeddings to the attention map.

  This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
  paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
@@ -356,12 +339,12 @@ def add_decomposed_rel_pos(
  q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
  rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
  rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
- q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
- k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
+ q_size (tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
+ k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).

  Returns:
- (torch.Tensor): Updated attention map with added relative positional embeddings, shape
- (B, q_h * q_w, k_h * k_w).
+ (torch.Tensor): Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h *
+ k_w).

  Examples:
  >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8