dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import math
5
6
  from typing import Any
6
7
 
7
8
  import torch
@@ -9,8 +10,7 @@ import torch.nn.functional as F
9
10
 
10
11
 
11
12
  def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
12
- """
13
- Select the closest conditioning frames to a given frame index.
13
+ """Select the closest conditioning frames to a given frame index.
14
14
 
15
15
  Args:
16
16
  frame_idx (int): Current frame index.
@@ -62,8 +62,7 @@ def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any
62
62
 
63
63
 
64
64
  def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
65
- """
66
- Generate 1D sinusoidal positional embeddings for given positions and dimensions.
65
+ """Generate 1D sinusoidal positional embeddings for given positions and dimensions.
67
66
 
68
67
  Args:
69
68
  pos_inds (torch.Tensor): Position indices for which to generate embeddings.
@@ -88,16 +87,17 @@ def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000)
88
87
  return pos_embed
89
88
 
90
89
 
91
- def init_t_xy(end_x: int, end_y: int):
92
- """
93
- Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
90
+ def init_t_xy(end_x: int, end_y: int, scale: float = 1.0, offset: int = 0):
91
+ """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
94
92
 
95
- This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor
96
- and corresponding x and y coordinate tensors.
93
+ This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index
94
+ tensor and corresponding x and y coordinate tensors.
97
95
 
98
96
  Args:
99
97
  end_x (int): Width of the grid (number of columns).
100
98
  end_y (int): Height of the grid (number of rows).
99
+ scale (float): Scaling factor to apply to the coordinates.
100
+ offset (int): Offset to add to the coordinates.
101
101
 
102
102
  Returns:
103
103
  t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
@@ -113,21 +113,21 @@ def init_t_xy(end_x: int, end_y: int):
113
113
  t = torch.arange(end_x * end_y, dtype=torch.float32)
114
114
  t_x = (t % end_x).float()
115
115
  t_y = torch.div(t, end_x, rounding_mode="floor").float()
116
- return t_x, t_y
116
+ return t_x * scale + offset, t_y * scale + offset
117
117
 
118
118
 
119
- def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
120
- """
121
- Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
119
+ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0, scale_pos: float = 1.0):
120
+ """Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
122
121
 
123
- This function generates complex exponential positional encodings for a 2D grid of spatial positions,
124
- using separate frequency components for the x and y dimensions.
122
+ This function generates complex exponential positional encodings for a 2D grid of spatial positions, using separate
123
+ frequency components for the x and y dimensions.
125
124
 
126
125
  Args:
127
126
  dim (int): Dimension of the positional encoding.
128
127
  end_x (int): Width of the 2D grid.
129
128
  end_y (int): Height of the 2D grid.
130
129
  theta (float, optional): Scaling factor for frequency computation.
130
+ scale_pos (float, optional): Scaling factor for position coordinates.
131
131
 
132
132
  Returns:
133
133
  (torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
@@ -141,7 +141,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
141
141
  freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
142
142
  freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
143
143
 
144
- t_x, t_y = init_t_xy(end_x, end_y)
144
+ t_x, t_y = init_t_xy(end_x, end_y, scale=scale_pos)
145
145
  freqs_x = torch.outer(t_x, freqs_x)
146
146
  freqs_y = torch.outer(t_y, freqs_y)
147
147
  freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
@@ -150,11 +150,10 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
150
150
 
151
151
 
152
152
  def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
153
- """
154
- Reshape frequency tensor for broadcasting with input tensor.
153
+ """Reshape frequency tensor for broadcasting with input tensor.
155
154
 
156
- Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
157
- This function is typically used in positional encoding operations.
155
+ Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor. This function
156
+ is typically used in positional encoding operations.
158
157
 
159
158
  Args:
160
159
  freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
@@ -167,7 +166,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
167
166
  AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.
168
167
  """
169
168
  ndim = x.ndim
170
- assert 0 <= 1 < ndim
169
+ assert ndim >= 2
171
170
  assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
172
171
  shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
173
172
  return freqs_cis.view(*shape)
@@ -179,8 +178,7 @@ def apply_rotary_enc(
179
178
  freqs_cis: torch.Tensor,
180
179
  repeat_freqs_k: bool = False,
181
180
  ):
182
- """
183
- Apply rotary positional encoding to query and key tensors.
181
+ """Apply rotary positional encoding to query and key tensors.
184
182
 
185
183
  This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
186
184
  components. RoPE is a technique that injects relative position information into self-attention mechanisms.
@@ -188,10 +186,10 @@ def apply_rotary_enc(
188
186
  Args:
189
187
  xq (torch.Tensor): Query tensor to encode with positional information.
190
188
  xk (torch.Tensor): Key tensor to encode with positional information.
191
- freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
192
- last two dimensions of xq.
193
- repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
194
- to match key sequence length.
189
+ freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the last
190
+ two dimensions of xq.
191
+ repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension to match
192
+ key sequence length.
195
193
 
196
194
  Returns:
197
195
  xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
@@ -212,16 +210,20 @@ def apply_rotary_enc(
212
210
  # No keys to rotate, due to dropout
213
211
  return xq_out.type_as(xq).to(xq.device), xk
214
212
  # Repeat freqs along seq_len dim to match k seq_len
215
- if repeat_freqs_k:
216
- r = xk_.shape[-2] // xq_.shape[-2]
217
- freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
213
+ if repeat_freqs_k and (r := xk_.shape[-2] // xq_.shape[-2]) > 1:
214
+ # MPS doesn't support repeat on complex tensors, decompose to real representation
215
+ if freqs_cis.device.type == "mps":
216
+ freqs_cis = torch.view_as_real(freqs_cis)
217
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 3)), r, 1, 1)
218
+ freqs_cis = torch.view_as_complex(freqs_cis.contiguous())
219
+ else:
220
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
218
221
  xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
219
222
  return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
220
223
 
221
224
 
222
225
  def window_partition(x: torch.Tensor, window_size: int):
223
- """
224
- Partition input tensor into non-overlapping windows with padding if needed.
226
+ """Partition input tensor into non-overlapping windows with padding if needed.
225
227
 
226
228
  Args:
227
229
  x (torch.Tensor): Input tensor with shape (B, H, W, C).
@@ -251,23 +253,22 @@ def window_partition(x: torch.Tensor, window_size: int):
251
253
 
252
254
 
253
255
  def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
254
- """
255
- Unpartition windowed sequences into original sequences and remove padding.
256
+ """Unpartition windowed sequences into original sequences and remove padding.
256
257
 
257
- This function reverses the windowing process, reconstructing the original input from windowed segments
258
- and removing any padding that was added during the windowing process.
258
+ This function reverses the windowing process, reconstructing the original input from windowed segments and removing
259
+ any padding that was added during the windowing process.
259
260
 
260
261
  Args:
261
262
  windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
262
- window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
263
- the size of each window, and C is the number of channels.
263
+ window_size, C), where B is the batch size, num_windows is the number of windows, window_size is the size of
264
+ each window, and C is the number of channels.
264
265
  window_size (int): Size of each window.
265
266
  pad_hw (tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
266
267
  hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
267
268
 
268
269
  Returns:
269
- (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
270
- are the original height and width, and C is the number of channels.
270
+ (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W are the
271
+ original height and width, and C is the number of channels.
271
272
 
272
273
  Examples:
273
274
  >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
@@ -289,18 +290,16 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[in
289
290
 
290
291
 
291
292
  def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
292
- """
293
- Extract relative positional embeddings based on query and key sizes.
293
+ """Extract relative positional embeddings based on query and key sizes.
294
294
 
295
295
  Args:
296
296
  q_size (int): Size of the query.
297
297
  k_size (int): Size of the key.
298
- rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
299
- distance and C is the embedding dimension.
298
+ rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative distance
299
+ and C is the embedding dimension.
300
300
 
301
301
  Returns:
302
- (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
303
- k_size, C).
302
+ (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, k_size, C).
304
303
 
305
304
  Examples:
306
305
  >>> q_size, k_size = 8, 16
@@ -338,8 +337,7 @@ def add_decomposed_rel_pos(
338
337
  q_size: tuple[int, int],
339
338
  k_size: tuple[int, int],
340
339
  ) -> torch.Tensor:
341
- """
342
- Add decomposed Relative Positional Embeddings to the attention map.
340
+ """Add decomposed Relative Positional Embeddings to the attention map.
343
341
 
344
342
  This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
345
343
  paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
@@ -354,8 +352,8 @@ def add_decomposed_rel_pos(
354
352
  k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
355
353
 
356
354
  Returns:
357
- (torch.Tensor): Updated attention map with added relative positional embeddings, shape
358
- (B, q_h * q_w, k_h * k_w).
355
+ (torch.Tensor): Updated attention map with added relative positional embeddings, shape (B, q_h * q_w, k_h *
356
+ k_w).
359
357
 
360
358
  Examples:
361
359
  >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
@@ -386,3 +384,129 @@ def add_decomposed_rel_pos(
386
384
  )
387
385
 
388
386
  return attn
387
+
388
+
389
def get_abs_pos(
    abs_pos: torch.Tensor,
    has_cls_token: bool,
    hw: tuple[int, int],
    retain_cls_token: bool = False,
    tiling: bool = False,
) -> torch.Tensor:
    """Calculate absolute positional embeddings, resizing them to the target token grid when needed.

    The spatial embeddings are stored as a flattened square grid, optionally preceded by one cls-token embedding. If
    the stored grid side differs from the requested (H, W), the grid is either tiled (repeated, then cropped) or
    resized with bicubic interpolation.

    Args:
        abs_pos (torch.Tensor): Absolute positional embeddings with shape (1, num_positions, C), where the spatial part
            (excluding an optional leading cls embedding) must contain S*S positions for some integer S.
        has_cls_token (bool): Whether the first embedding in abs_pos belongs to a cls token.
        hw (tuple[int, int]): Target height and width (H, W) of the image token grid.
        retain_cls_token (bool): Whether to keep the cls-token embedding in the output. Requires has_cls_token.
        tiling (bool): If True, tile the stored grid to (H, W) instead of interpolating it.

    Returns:
        (torch.Tensor): Positional embeddings with shape (1, H, W, C) when retain_cls_token is False, otherwise
            (1, 1 + H*W, C) with the cls-token embedding first.

    Raises:
        AssertionError: If retain_cls_token is set without has_cls_token, or if the spatial positions do not form a
            square grid.
    """
    if retain_cls_token:
        assert has_cls_token
    height, width = hw

    if has_cls_token:
        cls_pos, grid_pos = abs_pos[:, :1], abs_pos[:, 1:]
    else:
        grid_pos = abs_pos

    num_grid = grid_pos.shape[1]
    side = int(math.sqrt(num_grid))
    assert side * side == num_grid

    if side == height and side == width:
        # Stored grid already matches the target size: no resampling needed.
        if retain_cls_token:
            return torch.cat([cls_pos, grid_pos], dim=1)
        return grid_pos.reshape(1, height, width, -1)

    grid = grid_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2)  # (1, C, S, S)
    if tiling:
        # Repeat enough times to cover the target, then crop to exactly (H, W).
        reps = [target // cur + 1 for target, cur in zip((height, width), grid.shape[2:])]
        grid = grid.tile([1, 1, *reps])[:, :, :height, :width]
    else:
        grid = F.interpolate(grid, size=(height, width), mode="bicubic", align_corners=False)
    spatial = grid.permute(0, 2, 3, 1)  # (1, H, W, C)

    if retain_cls_token:
        # Put the cls embedding back in front of the flattened spatial grid.
        return torch.cat([cls_pos, spatial.reshape(1, height * width, -1)], dim=1)
    return spatial
452
+
453
+
454
def concat_rel_pos(
    q: torch.Tensor,
    k: torch.Tensor,
    q_hw: tuple[int, int],
    k_hw: tuple[int, int],
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    rescale: bool = False,
    relative_coords: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Concatenate relative-position coefficients to q and k so that q @ k^T already includes the rel-pos bias.

    For a query at grid cell (i, j) and a key at (a, b), the returned tensors satisfy
    ``q_out . k_out / new_scale == q . k / sqrt(dim) + q . Rh[i, a] + q . Rw[j, b]``. Here ``new_scale`` is
    ``sqrt(dim)``, or ``sqrt(dim + k_h + k_w)`` when ``rescale`` is True. A standard scaled dot-product attention on
    the outputs therefore reproduces decomposed relative-position attention.

    Args:
        q (torch.Tensor): Query tensor with shape (B, q_h * q_w, dim).
        k (torch.Tensor): Key tensor with shape (B, k_h * k_w, C).
        q_hw (tuple[int, int]): Spatial size (q_h, q_w) of the query grid.
        k_hw (tuple[int, int]): Spatial size (k_h, k_w) of the key grid.
        rel_pos_h (torch.Tensor): Relative position embeddings/parameters along the height axis.
        rel_pos_w (torch.Tensor): Relative position embeddings/parameters along the width axis.
        rescale (bool): Set True when the attention divides q @ k^T by the square root of the *concatenated* feature
            size (dim + k_h + k_w), e.g. scaled_dot_product_attention with its default scale. The content term is
            then still effectively divided by sqrt(dim).
        relative_coords (torch.Tensor | None): Optional precomputed relative-coordinate index tensor. It is used to
            index both rel_pos_h and rel_pos_w. When None, the embeddings are computed with get_rel_pos.

    Returns:
        q (torch.Tensor): Augmented queries with shape (B, q_h * q_w, dim + k_h + k_w).
        k (torch.Tensor): Augmented keys with shape (B, k_h * k_w, C + k_h + k_w).

    Raises:
        AssertionError: If the query or key grid is not square.
    """
    q_h, q_w = q_hw
    k_h, k_w = k_hw

    assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"

    if relative_coords is not None:
        # The same index tensor is used for both the height and width tables.
        Rh = rel_pos_h[relative_coords]
        Rw = rel_pos_w[relative_coords]
    else:
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)

    old_scale = dim**0.5
    new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale  # for sdpa
    # Attention will divide by new_scale. Scaling the content part of q by new_scale/old_scale keeps q.k divided by
    # old_scale. Multiplying the bias terms by new_scale cancels the division, so the biases are added unscaled.
    scale_ratio = new_scale / old_scale

    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale  # (B, q_h, q_w, k_w)

    # One-hot selectors on the key side: key (a, b) picks rel_h[..., a] and rel_w[..., b] in the dot product.
    eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
    eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)

    eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
    eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])

    q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
    # reshape (not view) so k's memory layout never matters; it matches how q is handled above.
    k = torch.cat([k.reshape(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(B, k_h * k_w, -1)

    return q, k