dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
@@ -0,0 +1,512 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ from typing import Any
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
    """Select the conditioning frames temporally closest to a given frame index.

    Args:
        frame_idx (int): Current frame index.
        cond_frame_outputs (dict[int, Any]): Conditioning frame outputs keyed by frame index.
        max_cond_frame_num (int): Maximum number of conditioning frames to keep; -1 keeps all of them.

    Returns:
        selected_outputs (dict[int, Any]): Selected items from cond_frame_outputs.
        unselected_outputs (dict[int, Any]): Items not selected from cond_frame_outputs.

    Examples:
        >>> selected, unselected = select_closest_cond_frames(5, {1: "a", 3: "b", 7: "c", 9: "d"}, 2)
        >>> selected
        {3: 'b', 7: 'c'}
        >>> unselected
        {1: 'a', 9: 'd'}
    """
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        # Everything fits within the budget; return the mapping unchanged.
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
    selected = {}

    # Always keep the nearest conditioning frame on each side of `frame_idx`, when one exists.
    nearest_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
    nearest_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
    for anchor in (nearest_before, nearest_after):
        if anchor is not None:
            selected[anchor] = cond_frame_outputs[anchor]

    # Fill any remaining slots with the temporally closest leftover frames.
    budget = max_cond_frame_num - len(selected)
    by_distance = sorted((t for t in cond_frame_outputs if t not in selected), key=lambda t: abs(t - frame_idx))
    for t in by_distance[:budget]:
        selected[t] = cond_frame_outputs[t]

    unselected = {t: v for t, v in cond_frame_outputs.items() if t not in selected}
    return selected, unselected
62
+
63
+
64
def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
    """Generate 1D sinusoidal positional embeddings for the given positions.

    Args:
        pos_inds (torch.Tensor): Position indices to embed.
        dim (int): Embedding dimension; should be even (half sine, half cosine).
        temperature (float, optional): Base controlling the frequency spectrum.

    Returns:
        (torch.Tensor): Positional embeddings with shape (*pos_inds.shape, dim).

    Examples:
        >>> pos = torch.tensor([0, 1, 2, 3])
        >>> get_1d_sine_pe(pos, 128).shape
        torch.Size([4, 128])
    """
    half = dim // 2
    # One frequency per embedding channel; pairs of channels share a frequency.
    freq_seq = torch.arange(half, dtype=pos_inds.dtype, device=pos_inds.device)
    inv_freq = temperature ** (2 * (freq_seq // 2) / half)
    angles = pos_inds.unsqueeze(-1) / inv_freq
    return torch.cat([angles.sin(), angles.cos()], dim=-1)
88
+
89
+
90
def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
    """Initialize flattened x/y coordinate tensors for an end_x × end_y grid.

    Positions are enumerated row-major (x varies fastest), then optionally scaled and offset.

    Args:
        end_x (int): Width of the grid (number of columns).
        end_y (int): Height of the grid (number of rows).
        scale (float): Scaling factor applied to the coordinates.
        offset (int): Offset added to the scaled coordinates.

    Returns:
        t_x (torch.Tensor): X-coordinate per position, shape (end_x * end_y,).
        t_y (torch.Tensor): Y-coordinate per position, shape (end_x * end_y,).

    Examples:
        >>> t_x, t_y = init_t_xy(3, 2)
        >>> print(t_x)
        tensor([0., 1., 2., 0., 1., 2.])
        >>> print(t_y)
        tensor([0., 0., 0., 1., 1., 1.])
    """
    flat = torch.arange(end_x * end_y, dtype=torch.float32)
    # Column index cycles every end_x entries; row index advances once per full row.
    xs = torch.remainder(flat, end_x).float()
    ys = torch.div(flat, end_x, rounding_mode="floor").float()
    return xs * scale + offset, ys * scale + offset
117
+
118
+
119
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
    """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.

    This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
    frequency components for the x and y dimensions.

    Args:
        dim (int): Dimension of the positional encoding.
        end_x (int): Width of the 2D grid.
        end_y (int): Height of the 2D grid.
        theta (float, optional): Scaling factor for frequency computation.
        scale_pos (float, optional): Scaling factor for position coordinates.

    Returns:
        (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).

    Examples:
        >>> dim, end_x, end_y = 128, 8, 8
        >>> freqs_cis = compute_axial_cis(dim, end_x, end_y)
        >>> freqs_cis.shape
        torch.Size([64, 64])
    """
    # Both axes use the identical frequency spectrum; compute it once instead of twice.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

    t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
    freqs_x = torch.outer(t_x, freqs)  # (end_x*end_y, dim//4) phase angles for the x axis
    freqs_y = torch.outer(t_y, freqs)  # (end_x*end_y, dim//4) phase angles for the y axis
    # Unit-magnitude complex exponentials e^{i*phase} for each axis, concatenated channel-wise.
    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
150
+
151
+
152
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape a frequency tensor so it broadcasts against the input tensor.

    The frequency tensor must match the last two dimensions of x; all leading dimensions are
    collapsed to size 1 so that elementwise multiplication broadcasts over them.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor with shape (x.shape[-2], x.shape[-1]).
        x (torch.Tensor): Input tensor to broadcast with.

    Returns:
        (torch.Tensor): View of freqs_cis with shape (1, ..., 1, x.shape[-2], x.shape[-1]).

    Raises:
        AssertionError: If freqs_cis does not match the last two dimensions of x.
    """
    ndim = x.ndim
    assert ndim >= 2
    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
    # Keep the trailing two dims, turn every leading dim into a broadcastable 1.
    target_shape = [1] * (ndim - 2) + list(x.shape[-2:])
    return freqs_cis.view(*target_shape)
173
+
174
+
175
def apply_rotary_enc(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    repeat_freqs_k: bool = False,
):
    """Apply rotary positional encoding to query and key tensors.

    This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
    components. RoPE is a technique that injects relative position information into self-attention mechanisms.

    Args:
        xq (torch.Tensor): Query tensor to encode with positional information.
        xk (torch.Tensor): Key tensor to encode with positional information.
        freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
            two dimensions of xq.
        repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
            key sequence length.

    Returns:
        xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
        xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.

    Examples:
        >>> import torch
        >>> xq = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_len, dim]
        >>> xk = torch.randn(2, 8, 16, 64)
        >>> freqs_cis = compute_axial_cis(64, 4, 4)  # For a 4x4 spatial grid with dim=64
        >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)
    """
    # Pair up the last dim into (real, imag) and reinterpret as complex numbers so the
    # rotation becomes a single complex multiplication.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    # An empty key sequence (seq_len 0) cannot be viewed as complex; mark it with None.
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Rotate queries, convert back to interleaved real pairs, and flatten to the original dim.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    if xk_ is None:
        # No keys to rotate, due to dropout
        return xq_out.type_as(xq).to(xq.device), xk
    # Repeat freqs along seq_len dim to match k seq_len
    if repeat_freqs_k and (r := xk_.shape[-2] // xq_.shape[-2]) > 1:
        # MPS doesn't support repeat on complex tensors, decompose to real representation
        if freqs_cis.device.type == "mps":
            freqs_cis = torch.view_as_real(freqs_cis)
            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 3)), r, 1, 1)
            freqs_cis = torch.view_as_complex(freqs_cis.contiguous())
        else:
            freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # Restore each output to its source tensor's dtype and device (inputs were upcast to float).
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
223
+
224
+
225
def window_partition(x: torch.Tensor, window_size: int):
    """Partition an input tensor into non-overlapping square windows, padding if needed.

    Args:
        x (torch.Tensor): Input tensor with shape (B, H, W, C).
        window_size (int): Side length of each window.

    Returns:
        windows (torch.Tensor): Windows with shape (B * num_windows, window_size, window_size, C).
        padded_h_w (tuple[int, int]): Padded (height, width) used before partitioning.

    Examples:
        >>> x = torch.randn(1, 16, 16, 3)
        >>> windows, (Hp, Wp) = window_partition(x, window_size=4)
        >>> print(windows.shape, Hp, Wp)
        torch.Size([16, 4, 4, 3]) 16 16
    """
    B, H, W, C = x.shape

    # Pad bottom/right so both spatial dims divide evenly by window_size.
    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h or pad_w:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    # Split each spatial dim into (num_windows, window_size), then gather window axes together.
    grid = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
253
+
254
+
255
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
    """Reverse window partitioning, reconstructing the original tensor and removing padding.

    Args:
        windows (torch.Tensor): Windowed tensor with shape (B * num_windows, window_size, window_size, C).
        window_size (int): Side length of each window.
        pad_hw (tuple[int, int]): Padded (Hp, Wp) used when the input was partitioned.
        hw (tuple[int, int]): Original (H, W) before padding and windowing.

    Returns:
        (torch.Tensor): Reconstructed tensor with shape (B, H, W, C).

    Examples:
        >>> windows = torch.rand(32, 8, 8, 64)  # 32 windows of size 8x8 with 64 channels
        >>> x = window_unpartition(windows, window_size=8, pad_hw=(16, 16), hw=(15, 14))
        >>> print(x.shape)
        torch.Size([1, 15, 14, 64])
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Recover the batch size from the total window count.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    # Inverse of the partition reshape: (B, grid_h, grid_w, ws, ws, C) -> (B, Hp, Wp, C).
    grid = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    # Crop away any padding added during partitioning.
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
290
+
291
+
292
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """Extract relative positional embeddings for every (query, key) position pair.

    Args:
        q_size (int): Size of the query axis.
        k_size (int): Size of the key axis.
        rel_pos (torch.Tensor): Relative position embedding table with shape (L, C), where L covers
            the maximum relative distance.

    Returns:
        (torch.Tensor): Embeddings indexed by relative position, shape (q_size, k_size, C).

    Examples:
        >>> q_size, k_size = 8, 16
        >>> rel_pos = torch.randn(31, 64)  # 31 = 2 * max(8, 16) - 1
        >>> get_rel_pos(q_size, k_size, rel_pos).shape
        torch.Size([8, 16, 64])
    """
    max_rel_dist = 2 * max(q_size, k_size) - 1
    if rel_pos.shape[0] == max_rel_dist:
        resized = rel_pos
    else:
        # Table length doesn't match: linearly interpolate it to the required number of offsets.
        resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        resized = resized.reshape(-1, max_rel_dist).permute(1, 0)

    # When q and k sizes differ, scale coordinates by the short/long ratio so offsets align.
    q_ratio = max(k_size / q_size, 1.0)
    k_ratio = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_ratio
    k_coords = torch.arange(k_size)[None, :] * k_ratio
    # Shift so the most negative relative offset maps to table index 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * k_ratio

    return resized[relative_coords.long()]
330
+
331
+
332
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: tuple[int, int],
    k_size: tuple[int, int],
) -> torch.Tensor:
    """Add decomposed relative positional embeddings to an attention map.

    Implements the decomposed relative position bias from the MViTv2 paper: per-axis biases are
    computed from the queries and the height/width embedding tables, then added to the attention
    logits.

    Args:
        attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
        q (torch.Tensor): Query tensor with shape (B, q_h * q_w, C).
        rel_pos_h (torch.Tensor): Relative position embeddings for the height axis, shape (Lh, C).
        rel_pos_w (torch.Tensor): Relative position embeddings for the width axis, shape (Lw, C).
        q_size (tuple[int, int]): Query spatial size (q_h, q_w).
        k_size (tuple[int, int]): Key spatial size (k_h, k_w).

    Returns:
        (torch.Tensor): Attention map with relative position bias added, shape (B, q_h * q_w, k_h * k_w).

    Examples:
        >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
        >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
        >>> q = torch.rand(B, q_h * q_w, C)
        >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
        >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
        >>> add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w), (k_h, k_w)).shape
        torch.Size([1, 64, 64])

    References:
        https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    # Per-axis embedding tables indexed by relative offset: (q_h, k_h, C) and (q_w, k_w, C).
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Project queries onto each axis's embeddings: (B, q_h, q_w, k_h) and (B, q_h, q_w, k_w).
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # Broadcast-add the axial biases in the 5-D spatial view, then flatten back to 3-D.
    attn_5d = attn.view(B, q_h, q_w, k_h, k_w)
    attn_5d = attn_5d + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return attn_5d.view(B, q_h * q_w, k_h * k_w)
387
+
388
+
389
def get_abs_pos(
    abs_pos: torch.Tensor,
    has_cls_token: bool,
    hw: tuple[int, int],
    retain_cls_token: bool = False,
    tiling: bool = False,
) -> torch.Tensor:
    """Resize absolute positional embeddings to a target token-grid size.

    Splits off the cls-token embedding when present, resizes the spatial embeddings to the
    requested grid (by tiling or bicubic interpolation), and optionally re-attaches the
    cls-token embedding.

    Args:
        abs_pos (torch.Tensor): Absolute positional embeddings with shape (1, num_position, C).
        has_cls_token (bool): If True, abs_pos includes one leading cls-token embedding.
        hw (tuple[int, int]): Target (height, width) of the image token grid.
        retain_cls_token (bool): Whether to keep the cls-token embedding in the output; requires
            has_cls_token.
        tiling (bool): Tile the embeddings to the target size *instead* of interpolating (a la
            abs_win).

    Returns:
        (torch.Tensor): Embeddings with shape (1, H, W, C) if retain_cls_token is False,
            otherwise (1, 1 + H * W, C).
    """
    if retain_cls_token:
        assert has_cls_token

    h, w = hw
    cls_pos = None
    if has_cls_token:
        # Separate the cls-token embedding from the spatial ones.
        cls_pos, abs_pos = abs_pos[:, :1], abs_pos[:, 1:]

    num_tokens = abs_pos.shape[1]
    side = int(math.sqrt(num_tokens))
    assert side * side == num_tokens  # embeddings must form a square grid

    if side == h and side == w:
        # Already at the target size — no resampling needed.
        if retain_cls_token:
            assert has_cls_token
            return torch.cat([cls_pos, abs_pos], dim=1)
        return abs_pos.reshape(1, h, w, -1)

    # Move to NCHW for spatial resampling.
    resized = abs_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2)
    if tiling:
        # Repeat the grid enough times to cover (h, w), then crop.
        reps = [1, 1] + [target // current + 1 for target, current in zip((h, w), resized.shape[2:])]
        resized = resized.tile(reps)[:, :, :h, :w]
    else:
        resized = F.interpolate(resized, size=(h, w), mode="bicubic", align_corners=False)

    if retain_cls_token:
        # Re-attach cls_token and flatten the spatial dims.
        assert has_cls_token
        return torch.cat([cls_pos, resized.permute(0, 2, 3, 1).reshape(1, h * w, -1)], dim=1)
    return resized.permute(0, 2, 3, 1)
452
+
453
+
454
def concat_rel_pos(
    q: torch.Tensor,
    k: torch.Tensor,
    q_hw: tuple[int, int],
    k_hw: tuple[int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rescale: bool = False,
    relative_coords: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fold relative positional biases into q and k via channel concatenation.

    Extra channels are appended to q (the per-position bias coefficients) and to k (one-hot
    selectors), so the plain product qk^T already includes the decomposed relative positional
    terms.

    Args:
        q (torch.Tensor): q tensor with shape (B, L_q, C).
        k (torch.Tensor): k tensor with shape (B, L_k, C).
        q_hw: These are spatial size of q tensors.
        k_hw: These are spatial size of k tensors.
        rel_pos_h: These are relative pos embeddings/params of height.
        rel_pos_w: These are relative pos embeddings/params of width.
        rescale (bool): whether to rescale. e.g. for use when using sdpa, pytorch will scale by
            the wrong factor due to the concat.
        relative_coords (torch.Tensor, optional): Precomputed relative coords index tensor.

    Returns:
        q, k: But, padded so that qk^T accounts for rel pos biases.
    """
    q_h, q_w = q_hw
    k_h, k_w = k_hw

    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"

    # Resolve the per-axis embedding tables, via precomputed indices when available.
    if relative_coords is None:
        table_h = get_rel_pos(q_h, k_h, rel_pos_h)
        table_w = get_rel_pos(q_w, k_w, rel_pos_w)
    else:
        table_h = rel_pos_h[relative_coords]
        table_w = rel_pos_w[relative_coords]

    batch, _, dim = q.shape
    q_spatial = q.reshape(batch, q_h, q_w, dim)

    # sdpa divides logits by sqrt(widened head dim); pre-multiply the biases by that factor so
    # they come out at scale 1, and rescale q so it is effectively divided by sqrt(dim).
    old_scale = dim**0.5
    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale  # for sdpa
    scale_ratio = new_scale / old_scale

    bias_h = torch.einsum("bhwc,hkc->bhwk", q_spatial, table_h) * new_scale  # (B, q_h, q_w, k_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_spatial, table_w) * new_scale  # (B, q_h, q_w, k_w)

    # One-hot selectors so each key position picks out its own bias entry in qk^T.
    sel_h = torch.eye(k_h, dtype=q.dtype, device=q.device).view(1, k_h, 1, k_h).expand(batch, k_h, k_w, k_h)
    sel_w = torch.eye(k_w, dtype=q.dtype, device=q.device).view(1, 1, k_w, k_w).expand(batch, k_h, k_w, k_w)

    q_padded = torch.cat([q_spatial * scale_ratio, bias_h, bias_w], dim=-1).view(batch, q_h * q_w, -1)
    k_padded = torch.cat([k.view(batch, k_h, k_w, -1), sel_h, sel_w], dim=-1).view(batch, k_h * k_w, -1)

    return q_padded, k_padded