dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272) hide show
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,416 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from . import LOGGER
7
+ from .checks import check_version
8
+ from .metrics import bbox_iou, probiou
9
+ from .ops import xywhr2xyxyxyxy
10
+
11
+ TORCH_1_10 = check_version(torch.__version__, "1.10.0")
12
+
13
+
14
class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        bg_idx (int): Background class index (equal to `num_classes`).
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
        bs (int): Batch size of the most recent `forward` call (set inside `forward`).
        n_max_boxes (int): Padded number of gt boxes per image of the most recent `forward` call (set inside
            `forward`).
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): int64 target labels with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Boolean foreground mask with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): int64 target ground truth indices with shape (bs, num_total_anchors).

        References:
            https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]
        device = gt_bboxes.device

        if self.n_max_boxes == 0:
            # No ground truth at all: every anchor is background. The dtypes deliberately mirror the regular
            # path (int64 labels / bool mask / int64 indices) so callers can treat both cases uniformly,
            # e.g. use fg_mask directly as a boolean index.
            per_anchor = pd_scores[..., 0]
            return (
                torch.full_like(per_anchor, self.bg_idx, dtype=torch.long),
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(per_anchor, dtype=torch.bool),
                torch.zeros_like(per_anchor, dtype=torch.long),
            )

        try:
            return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
        except torch.cuda.OutOfMemoryError:
            # Move tensors to CPU, compute, then move back to original device
            LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
            cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
            result = self._forward(*cpu_tensors)
            return tuple(t.to(device) for t in result)

    def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment on whatever device the inputs live on.

        Relies on `self.bs` and `self.n_max_boxes` having been set by `forward`.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Boolean foreground mask with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
        """
        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize: rescale each gt's alignment metric so that its best anchor gets that gt's best IoU
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """
        Get positive mask for each ground truth box.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
            align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
            overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
        """
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        # Get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        # Get topk_metric mask, (b, max_num_obj, h*w)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        # Merge all mask to a final mask, (b, max_num_obj, h*w)
        mask_pos = mask_topk * mask_in_gts * mask_gt

        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """
        Compute alignment metric given predicted and ground truth bounding boxes.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).

        Returns:
            align_metric (torch.Tensor): Alignment metric combining classification and localization,
                `score ** alpha * overlap ** beta`, with shape (bs, n_max_boxes, h*w).
            overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes with shape
                (bs, n_max_boxes, h*w); zero wherever `mask_gt` is False.
        """
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        # Build the gather indices on the same device as the scores so that no host<->device copies are needed
        idx_device = pd_scores.device
        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long, device=idx_device)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=self.bs, device=idx_device).view(-1, 1).expand(-1, self.n_max_boxes)  # batch index
        ind[1] = gt_labels.squeeze(-1)  # class index of every gt, b, max_num_obj
        # Get the scores of each grid for each gt cls
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w

        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4) -> only the masked (gt, anchor) pairs are materialised
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """
        Calculate IoU for horizontal bounding boxes.

        Args:
            gt_bboxes (torch.Tensor): Ground truth boxes.
            pd_bboxes (torch.Tensor): Predicted boxes.

        Returns:
            (torch.Tensor): IoU values between each pair of boxes, clamped to be non-negative.
        """
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """
        Select the top-k candidates based on the given metrics.

        Args:
            metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
                max_num_obj is the maximum number of objects, and h*w represents the
                total number of anchor points.
            largest (bool): If True, select the largest values; otherwise, select the smallest values.
            topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
                topk is the number of top candidates to consider. If not provided,
                the top-k values are automatically computed based on the given metrics.

        Returns:
            (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
        """
        # (b, max_num_obj, topk)
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        # (b, max_num_obj, topk); masked-out entries all collapse onto index 0
        topk_idxs.masked_fill_(~topk_mask, 0)

        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            # Expand topk_idxs for each value of k and add 1 at the specified positions
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # Filter invalid bboxes: positions hit more than once come from the collapsed (masked) entries
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """
        Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                batch size and max_num_obj is the maximum number of objects.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
                anchor points, with shape (b, h*w), where h*w is the total
                number of anchor points.
            fg_mask (torch.Tensor): A tensor of shape (b, h*w) that is positive for the
                (foreground) anchor points.

        Returns:
            target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
            target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
                anchor points.
            target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
                anchor points.
        """
        # Assigned target labels, (b, 1)
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        # Offset the per-image gt index so it addresses the flattened (b * max_num_obj) gt arrays
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores
        target_labels.clamp_(0)

        # 10x faster than F.one_hot()
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        # Zero the one-hot rows of background anchors
        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select positive anchor centers within ground truth bounding boxes.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
            eps (float, optional): Small value for numerical stability. Defaults to 1e-9.

        Returns:
            (torch.Tensor): 0/1 mask of positive anchors, shape (b, n_boxes, h*w). The mask keeps the dtype of
                the coordinate tensors (the comparison is done in place), it is not a torch.bool tensor.

        Note:
            b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
            Bounding box format: [x_min, y_min, x_max, y_max].
        """
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        # Signed distances from every anchor to the four box sides; all positive <=> anchor strictly inside
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        Select anchor boxes with highest IoU when assigned to multiple ground truths.

        Args:
            mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
            overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
            n_max_boxes (int): Maximum number of ground truth boxes.

        Returns:
            target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
            fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
            mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
        """
        # Convert (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            # One-hot over the gt axis marking, per anchor, the gt with the largest overlap
            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)
        # Find each grid serve which gt(index)
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos
327
+
328
+
329
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    """Task-aligned assigner variant that works on rotated bounding boxes."""

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """Return the non-negative probabilistic IoU between rotated gt and predicted boxes."""
        rotated_iou = probiou(gt_bboxes, pd_bboxes)
        return rotated_iou.squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Mark the anchor centers that fall inside each rotated ground truth box.

        An anchor P lies inside the rectangle when its offset from corner A projects onto both
        edges AB and AD within the edge extents, i.e. 0 <= AP.AB <= AB.AB and 0 <= AP.AD <= AD.AD.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).

        Returns:
            (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
        box_corners = xywhr2xyxyxyxy(gt_bboxes)
        # Each corner: (b, n_boxes, 1, 2); the third corner is not needed
        corner_a, corner_b, _, corner_d = box_corners.split(1, dim=-2)
        edge_ab = corner_b - corner_a
        edge_ad = corner_d - corner_a

        # Offsets from corner A to every anchor, (b, n_boxes, h*w, 2)
        offset = xy_centers - corner_a

        sq_len_ab = (edge_ab * edge_ab).sum(dim=-1)
        sq_len_ad = (edge_ad * edge_ad).sum(dim=-1)
        proj_ab = (offset * edge_ab).sum(dim=-1)
        proj_ad = (offset * edge_ad).sum(dim=-1)

        within_ab = (proj_ab >= 0) & (proj_ab <= sq_len_ab)
        within_ad = (proj_ad >= 0) & (proj_ad <= sq_len_ad)
        return within_ab & within_ad  # is_in_box
362
+
363
+
364
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """
    Generate anchor points and the matching per-anchor strides from feature maps.

    Args:
        feats (list | Any): Either a list of feature tensors whose trailing two dimensions are (h, w), or an
            indexable whose i-th entry holds (h, w) as its first two elements. Dtype and device are taken
            from `feats[0]`.
        strides (Iterable): Stride of each feature level, in the same order as `feats`.
        grid_cell_offset (float): Offset added to every grid index to obtain the anchor location.

    Returns:
        (torch.Tensor): Anchor points for all levels concatenated, shape (sum(h*w), 2), ordered (x, y).
        (torch.Tensor): Stride of every anchor, shape (sum(h*w), 1).
    """
    assert feats is not None
    reference = feats[0]
    dtype, device = reference.dtype, reference.device
    feats_are_maps = isinstance(feats, list)

    points_per_level, strides_per_level = [], []
    for level, stride in enumerate(strides):
        if feats_are_maps:
            h, w = feats[level].shape[2:]
        else:
            h, w = int(feats[level][0]), int(feats[level][1])

        shift_x = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset
        shift_y = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset
        # The `indexing` keyword is only passed when the TORCH_1_10 flag is set
        if TORCH_1_10:
            grid_y, grid_x = torch.meshgrid(shift_y, shift_x, indexing="ij")
        else:
            grid_y, grid_x = torch.meshgrid(shift_y, shift_x)

        points_per_level.append(torch.stack((grid_x, grid_y), -1).view(-1, 2))
        strides_per_level.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))

    return torch.cat(points_per_level), torch.cat(strides_per_level)
377
+
378
+
379
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """
    Convert (left, top, right, bottom) distances around anchor points into boxes.

    Args:
        distance (torch.Tensor): Distances whose `dim` axis splits into two equal halves: lt and rb.
        anchor_points (torch.Tensor): Anchor coordinates, broadcastable against each half.
        xywh (bool): If True return (center_x, center_y, w, h), otherwise (x1, y1, x2, y2).
        dim (int): Axis holding the box coordinates.

    Returns:
        (torch.Tensor): Boxes in the requested format, concatenated along `dim`.
    """
    dist_lt, dist_rb = torch.chunk(distance, 2, dim)
    top_left = anchor_points - dist_lt
    bottom_right = anchor_points + dist_rb
    if not xywh:
        return torch.cat((top_left, bottom_right), dim)  # xyxy bbox
    center = (top_left + bottom_right) / 2
    size = bottom_right - top_left
    return torch.cat((center, size), dim)  # xywh bbox
389
+
390
+
391
def bbox2dist(anchor_points, bbox, reg_max):
    """
    Convert xyxy boxes into (left, top, right, bottom) distances from the anchor points.

    Args:
        anchor_points (torch.Tensor): Anchor coordinates, broadcastable against each half of `bbox`.
        bbox (torch.Tensor): Boxes as (x1, y1, x2, y2) along the last axis.
        reg_max (float): Upper bound; distances are clamped to the range [0, reg_max - 0.01].

    Returns:
        (torch.Tensor): Clamped distances (lt, rb) concatenated along the last axis.
    """
    top_left, bottom_right = torch.chunk(bbox, 2, -1)
    ltrb = torch.cat((anchor_points - top_left, bottom_right - anchor_points), -1)
    return ltrb.clamp_(0, reg_max - 0.01)  # dist (lt, rb)
395
+
396
+
397
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode rotated boxes from predicted side distances, angles and anchor points.

    The center offset ((rb - lt) / 2) is rotated by the predicted angle and added to the anchor,
    while the box size is lt + rb.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
        dim (int, optional): Dimension along which to split. Defaults to -1.

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes (center xy, then size) with shape (bs, h*w, 4).
    """
    dist_lt, dist_rb = torch.split(pred_dist, 2, dim=dim)
    cos_a = torch.cos(pred_angle)
    sin_a = torch.sin(pred_angle)

    # Unrotated center offset, each component (bs, h*w, 1)
    half_diff = (dist_rb - dist_lt) / 2
    off_x, off_y = torch.split(half_diff, 1, dim=dim)

    # Rotate the offset by the predicted angle
    rot_x = off_x * cos_a - off_y * sin_a
    rot_y = off_x * sin_a + off_y * cos_a

    center = torch.cat([rot_x, rot_y], dim=dim) + anchor_points
    return torch.cat([center, dist_lt + dist_rb], dim=dim)