dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py CHANGED
@@ -1,51 +1,65 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from __future__ import annotations
+
  import torch
  import torch.nn as nn

  from . import LOGGER
  from .metrics import bbox_iou, probiou
- from .ops import xywhr2xyxyxyxy
+ from .ops import xywh2xyxy, xywhr2xyxyxyxy, xyxy2xywh
  from .torch_utils import TORCH_1_11


  class TaskAlignedAssigner(nn.Module):
-     """
-     A task-aligned assigner for object detection.
+     """A task-aligned assigner for object detection.

      This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
      classification and localization information.

      Attributes:
          topk (int): The number of top candidates to consider.
+         topk2 (int): Secondary topk value for additional filtering.
          num_classes (int): The number of object classes.
          alpha (float): The alpha parameter for the classification component of the task-aligned metric.
          beta (float): The beta parameter for the localization component of the task-aligned metric.
+         stride (list): List of stride values for different feature levels.
          eps (float): A small value to prevent division by zero.
      """

-     def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
-         """
-         Initialize a TaskAlignedAssigner object with customizable hyperparameters.
+     def __init__(
+         self,
+         topk: int = 13,
+         num_classes: int = 80,
+         alpha: float = 1.0,
+         beta: float = 6.0,
+         stride: list = [8, 16, 32],
+         eps: float = 1e-9,
+         topk2=None,
+     ):
+         """Initialize a TaskAlignedAssigner object with customizable hyperparameters.

          Args:
              topk (int, optional): The number of top candidates to consider.
              num_classes (int, optional): The number of object classes.
              alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
              beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
+             stride (list, optional): List of stride values for different feature levels.
              eps (float, optional): A small value to prevent division by zero.
+             topk2 (int, optional): Secondary topk value for additional filtering.
          """
          super().__init__()
          self.topk = topk
+         self.topk2 = topk2 or topk
          self.num_classes = num_classes
          self.alpha = alpha
          self.beta = beta
+         self.stride = stride
          self.eps = eps

      @torch.no_grad()
      def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-         """
-         Compute the task-aligned assignment.
+         """Compute the task-aligned assignment.

          Args:
              pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -80,16 +94,17 @@ class TaskAlignedAssigner(nn.Module):

          try:
              return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
-         except torch.cuda.OutOfMemoryError:
-             # Move tensors to CPU, compute, then move back to original device
-             LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
-             cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
-             result = self._forward(*cpu_tensors)
-             return tuple(t.to(device) for t in result)
+         except RuntimeError as e:
+             if "out of memory" in str(e).lower():
+                 # Move tensors to CPU, compute, then move back to original device
+                 LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
+                 cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
+                 result = self._forward(*cpu_tensors)
+                 return tuple(t.to(device) for t in result)
+             raise

      def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-         """
-         Compute the task-aligned assignment.
+         """Compute the task-aligned assignment.

          Args:
              pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -110,7 +125,9 @@ class TaskAlignedAssigner(nn.Module):
              pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
          )

-         target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
+         target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(
+             mask_pos, overlaps, self.n_max_boxes, align_metric
+         )

          # Assigned target
          target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
@@ -125,8 +142,7 @@ class TaskAlignedAssigner(nn.Module):
          return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

      def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
-         """
-         Get positive mask for each ground truth box.
+         """Get positive mask for each ground truth box.

          Args:
              pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -139,9 +155,9 @@ class TaskAlignedAssigner(nn.Module):
          Returns:
              mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
              align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
-             overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
+             overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
          """
-         mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
+         mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes, mask_gt)
          # Get anchor_align metric, (b, max_num_obj, h*w)
          align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
          # Get topk_metric mask, (b, max_num_obj, h*w)
@@ -152,8 +168,7 @@ class TaskAlignedAssigner(nn.Module):
          return mask_pos, align_metric, overlaps

      def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
-         """
-         Compute alignment metric given predicted and ground truth bounding boxes.
+         """Compute alignment metric given predicted and ground truth bounding boxes.

          Args:
              pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -186,8 +201,7 @@ class TaskAlignedAssigner(nn.Module):
          return align_metric, overlaps

      def iou_calculation(self, gt_bboxes, pd_bboxes):
-         """
-         Calculate IoU for horizontal bounding boxes.
+         """Calculate IoU for horizontal bounding boxes.

          Args:
              gt_bboxes (torch.Tensor): Ground truth boxes.
@@ -199,14 +213,13 @@ class TaskAlignedAssigner(nn.Module):
          return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

      def select_topk_candidates(self, metrics, topk_mask=None):
-         """
-         Select the top-k candidates based on the given metrics.
+         """Select the top-k candidates based on the given metrics.

          Args:
              metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
                  the maximum number of objects, and h*w represents the total number of anchor points.
-             topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
-                 topk is the number of top candidates to consider. If not provided, the top-k values are automatically
+             topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
+                 is the number of top candidates to consider. If not provided, the top-k values are automatically
                  computed based on the given metrics.

          Returns:
@@ -231,18 +244,16 @@ class TaskAlignedAssigner(nn.Module):
          return count_tensor.to(metrics.dtype)

      def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-         """
-         Compute target labels, target bounding boxes, and target scores for the positive anchor points.
+         """Compute target labels, target bounding boxes, and target scores for the positive anchor points.

          Args:
-             gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
-                 batch size and max_num_obj is the maximum number of objects.
+             gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
+                 max_num_obj is the maximum number of objects.
              gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
-             target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
-                 anchor points, with shape (b, h*w), where h*w is the total
-                 number of anchor points.
-             fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
-                 (foreground) anchor points.
+             target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
+                 shape (b, h*w), where h*w is the total number of anchor points.
+             fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
+                 points.

          Returns:
              target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
@@ -273,38 +284,42 @@ class TaskAlignedAssigner(nn.Module):

          return target_labels, target_bboxes, target_scores

-     @staticmethod
-     def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-         """
-         Select positive anchor centers within ground truth bounding boxes.
+     def select_candidates_in_gts(self, xy_centers, gt_bboxes, mask_gt, eps=1e-9):
+         """Select positive anchor centers within ground truth bounding boxes.

          Args:
              xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
              gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
+             mask_gt (torch.Tensor): Mask for valid ground truth boxes, shape (b, n_boxes, 1).
              eps (float, optional): Small value for numerical stability.

          Returns:
              (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).

-         Note:
-             b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
-             Bounding box format: [x_min, y_min, x_max, y_max].
+         Notes:
+             - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
+             - Bounding box format: [x_min, y_min, x_max, y_max].
          """
+         gt_bboxes_xywh = xyxy2xywh(gt_bboxes)
+         wh_mask = gt_bboxes_xywh[..., 2:] < self.stride[0]  # the smallest stride
+         stride_val = torch.tensor(self.stride[1], dtype=gt_bboxes_xywh.dtype, device=gt_bboxes_xywh.device)
+         gt_bboxes_xywh[..., 2:] = torch.where((wh_mask * mask_gt).bool(), stride_val, gt_bboxes_xywh[..., 2:])
+         gt_bboxes = xywh2xyxy(gt_bboxes_xywh)
+
          n_anchors = xy_centers.shape[0]
          bs, n_boxes, _ = gt_bboxes.shape
          lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
          bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
          return bbox_deltas.amin(3).gt_(eps)

-     @staticmethod
-     def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-         """
-         Select anchor boxes with highest IoU when assigned to multiple ground truths.
+     def select_highest_overlaps(self, mask_pos, overlaps, n_max_boxes, align_metric):
+         """Select anchor boxes with highest IoU when assigned to multiple ground truths.

          Args:
              mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
              overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
              n_max_boxes (int): Maximum number of ground truth boxes.
+             align_metric (torch.Tensor): Alignment metric for selecting best matches.

          Returns:
              target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
@@ -315,12 +330,20 @@ class TaskAlignedAssigner(nn.Module):
          fg_mask = mask_pos.sum(-2)
          if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
              mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
-             max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

+             max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
              is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
              is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
              mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+
+             fg_mask = mask_pos.sum(-2)
+
+         if self.topk2 != self.topk:
+             align_metric = align_metric * mask_pos  # update overlaps
+             max_overlaps_idx = torch.topk(align_metric, self.topk2, dim=-1, largest=True).indices  # (b, n_max_boxes)
+             topk_idx = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)  # update mask_pos
+             topk_idx.scatter_(-1, max_overlaps_idx, 1.0)
+             mask_pos *= topk_idx
              fg_mask = mask_pos.sum(-2)
          # Find each grid serve which gt(index)
          target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
@@ -335,13 +358,14 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
          return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

      @staticmethod
-     def select_candidates_in_gts(xy_centers, gt_bboxes):
-         """
-         Select the positive anchor center in gt for rotated bounding boxes.
+     def select_candidates_in_gts(xy_centers, gt_bboxes, mask_gt):
+         """Select the positive anchor center in gt for rotated bounding boxes.

          Args:
              xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
              gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
+             mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (b, n_boxes, 1).
+             stride (list[int]): List of stride values for each feature map level.

          Returns:
              (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
@@ -367,7 +391,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
      anchor_points, stride_tensor = [], []
      assert feats is not None
      dtype, device = feats[0].dtype, feats[0].device
-     for i, stride in enumerate(strides):
+     for i in range(len(feats)):  # use len(feats) to avoid TracerWarning from iterating over strides tensor
+         stride = strides[i]
          h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
          sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
          sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
@@ -389,15 +414,17 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
      return torch.cat((x1y1, x2y2), dim)  # xyxy bbox


- def bbox2dist(anchor_points, bbox, reg_max):
+ def bbox2dist(anchor_points: torch.Tensor, bbox: torch.Tensor, reg_max: int | None = None) -> torch.Tensor:
      """Transform bbox(xyxy) to dist(ltrb)."""
      x1y1, x2y2 = bbox.chunk(2, -1)
-     return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)
+     dist = torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1)
+     if reg_max is not None:
+         dist = dist.clamp_(0, reg_max - 0.01)  # dist (lt, rb)
+     return dist


  def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
-     """
-     Decode predicted rotated bounding box coordinates from anchor points and distribution.
+     """Decode predicted rotated bounding box coordinates from anchor points and distribution.

      Args:
          pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
@@ -415,3 +442,42 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
      x, y = xf * cos - yf * sin, xf * sin + yf * cos
      xy = torch.cat([x, y], dim=dim) + anchor_points
      return torch.cat([xy, lt + rb], dim=dim)
+
+
+ def rbox2dist(
+     target_bboxes: torch.Tensor,
+     anchor_points: torch.Tensor,
+     target_angle: torch.Tensor,
+     dim: int = -1,
+     reg_max: int | None = None,
+ ):
+     """Decode rotated bounding box (xywh) to distance(ltrb). This is the inverse of dist2rbox.
+
+     Args:
+         target_bboxes (torch.Tensor): Target rotated bounding boxes with shape (bs, h*w, 4), format [x, y, w, h].
+         anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
+         target_angle (torch.Tensor): Target angle with shape (bs, h*w, 1).
+         dim (int, optional): Dimension along which to split.
+         reg_max (int, optional): Maximum regression value for clamping.
+
+     Returns:
+         (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4), format [l, t, r, b].
+     """
+     xy, wh = target_bboxes.split(2, dim=dim)
+     offset = xy - anchor_points  # (bs, h*w, 2)
+     offset_x, offset_y = offset.split(1, dim=dim)
+     cos, sin = torch.cos(target_angle), torch.sin(target_angle)
+     xf = offset_x * cos + offset_y * sin
+     yf = -offset_x * sin + offset_y * cos
+
+     w, h = wh.split(1, dim=dim)
+     target_l = w / 2 - xf
+     target_t = h / 2 - yf
+     target_r = w / 2 + xf
+     target_b = h / 2 + yf
+
+     dist = torch.cat([target_l, target_t, target_r, target_b], dim=dim)
+     if reg_max is not None:
+         dist = dist.clamp_(0, reg_max - 0.01)
+
+     return dist
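
As a quick sanity check of the tal.py changes above, the minimal sketch below exercises the new TaskAlignedAssigner keyword arguments (stride, topk2) and verifies that the newly added rbox2dist() inverts the existing dist2rbox(). It assumes the 8.4.7 wheel is installed and importable as ultralytics; the topk2 value and all tensor values are made up purely for illustration.

import torch

from ultralytics.utils.tal import TaskAlignedAssigner, dist2rbox, rbox2dist

# New constructor arguments introduced in this release (topk2=10 is an arbitrary illustrative value)
assigner = TaskAlignedAssigner(topk=13, num_classes=80, alpha=1.0, beta=6.0, stride=[8, 16, 32], topk2=10)

# Round trip: rotated boxes -> anchor-relative [l, t, r, b] distances -> rotated boxes
anchor_points = torch.tensor([[8.0, 8.0], [24.0, 8.0]])        # (h*w, 2) anchor centers
target_angle = torch.tensor([[[0.3], [1.1]]])                  # (bs, h*w, 1) angles in radians
target_bboxes = torch.tensor([[[9.0, 7.5, 12.0, 6.0],          # (bs, h*w, 4) rotated boxes as [x, y, w, h]
                               [23.0, 9.0, 10.0, 8.0]]])

dist = rbox2dist(target_bboxes, anchor_points, target_angle)   # reg_max=None, so no clamping is applied
decoded = dist2rbox(dist, target_angle, anchor_points)         # decode back to [x, y, w, h]
print(torch.allclose(decoded, target_bboxes, atol=1e-5))       # expected: True

With reg_max left as None no clamping is applied, so the encode/decode round trip is exact; passing reg_max (as the loss code does) clamps the distances the same way bbox2dist does.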