ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py CHANGED
@@ -1,8 +1,9 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import torch
4
4
  import torch.nn as nn
5
5
 
6
+ from . import LOGGER
6
7
  from .checks import check_version
7
8
  from .metrics import bbox_iou, probiou
8
9
  from .ops import xywhr2xyxyxyxy
@@ -58,17 +59,46 @@ class TaskAlignedAssigner(nn.Module):
58
59
  """
59
60
  self.bs = pd_scores.shape[0]
60
61
  self.n_max_boxes = gt_bboxes.shape[1]
62
+ device = gt_bboxes.device
61
63
 
62
64
  if self.n_max_boxes == 0:
63
- device = gt_bboxes.device
64
65
  return (
65
- torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
66
- torch.zeros_like(pd_bboxes).to(device),
67
- torch.zeros_like(pd_scores).to(device),
68
- torch.zeros_like(pd_scores[..., 0]).to(device),
69
- torch.zeros_like(pd_scores[..., 0]).to(device),
66
+ torch.full_like(pd_scores[..., 0], self.bg_idx),
67
+ torch.zeros_like(pd_bboxes),
68
+ torch.zeros_like(pd_scores),
69
+ torch.zeros_like(pd_scores[..., 0]),
70
+ torch.zeros_like(pd_scores[..., 0]),
70
71
  )
71
72
 
73
+ try:
74
+ return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
75
+ except torch.OutOfMemoryError:
76
+ # Move tensors to CPU, compute, then move back to original device
77
+ LOGGER.warning("WARNING: CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
78
+ cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
79
+ result = self._forward(*cpu_tensors)
80
+ return tuple(t.to(device) for t in result)
81
+
82
+ def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
83
+ """
84
+ Compute the task-aligned assignment. Reference code is available at
85
+ https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
86
+
87
+ Args:
88
+ pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
89
+ pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
90
+ anc_points (Tensor): shape(num_total_anchors, 2)
91
+ gt_labels (Tensor): shape(bs, n_max_boxes, 1)
92
+ gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
93
+ mask_gt (Tensor): shape(bs, n_max_boxes, 1)
94
+
95
+ Returns:
96
+ target_labels (Tensor): shape(bs, num_total_anchors)
97
+ target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
98
+ target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
99
+ fg_mask (Tensor): shape(bs, num_total_anchors)
100
+ target_gt_idx (Tensor): shape(bs, num_total_anchors)
101
+ """
72
102
  mask_pos, align_metric, overlaps = self.get_pos_mask(
73
103
  pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
74
104
  )
@@ -140,7 +170,6 @@ class TaskAlignedAssigner(nn.Module):
140
170
  Returns:
141
171
  (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
142
172
  """
143
-
144
173
  # (b, max_num_obj, topk)
145
174
  topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
146
175
  if topk_mask is None:
@@ -184,7 +213,6 @@ class TaskAlignedAssigner(nn.Module):
184
213
  for positive anchor points, where num_classes is the number
185
214
  of object classes.
186
215
  """
187
-
188
216
  # Assigned target labels, (b, 1)
189
217
  batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
190
218
  target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
@@ -212,14 +240,19 @@ class TaskAlignedAssigner(nn.Module):
212
240
  @staticmethod
213
241
  def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
214
242
  """
215
- Select the positive anchor center in gt.
243
+ Select positive anchor centers within ground truth bounding boxes.
216
244
 
217
245
  Args:
218
- xy_centers (Tensor): shape(h*w, 2)
219
- gt_bboxes (Tensor): shape(b, n_boxes, 4)
246
+ xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
247
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
248
+ eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
220
249
 
221
250
  Returns:
222
- (Tensor): shape(b, n_boxes, h*w)
251
+ (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
252
+
253
+ Note:
254
+ b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
255
+ Bounding box format: [x_min, y_min, x_max, y_max].
223
256
  """
224
257
  n_anchors = xy_centers.shape[0]
225
258
  bs, n_boxes, _ = gt_bboxes.shape
@@ -231,18 +264,22 @@ class TaskAlignedAssigner(nn.Module):
231
264
  @staticmethod
232
265
  def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
233
266
  """
234
- If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
267
+ Select anchor boxes with highest IoU when assigned to multiple ground truths.
235
268
 
236
269
  Args:
237
- mask_pos (Tensor): shape(b, n_max_boxes, h*w)
238
- overlaps (Tensor): shape(b, n_max_boxes, h*w)
270
+ mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
271
+ overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
272
+ n_max_boxes (int): Maximum number of ground truth boxes.
239
273
 
240
274
  Returns:
241
- target_gt_idx (Tensor): shape(b, h*w)
242
- fg_mask (Tensor): shape(b, h*w)
243
- mask_pos (Tensor): shape(b, n_max_boxes, h*w)
275
+ target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
276
+ fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
277
+ mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
278
+
279
+ Note:
280
+ b: batch size, h: height, w: width.
244
281
  """
245
- # (b, n_max_boxes, h*w) -> (b, h*w)
282
+ # Convert (b, n_max_boxes, h*w) -> (b, h*w)
246
283
  fg_mask = mask_pos.sum(-2)
247
284
  if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
248
285
  mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
@@ -259,6 +296,8 @@ class TaskAlignedAssigner(nn.Module):
259
296
 
260
297
 
261
298
  class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
299
+ """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
300
+
262
301
  def iou_calculation(self, gt_bboxes, pd_bboxes):
263
302
  """IoU calculation for rotated bounding boxes."""
264
303
  return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
@@ -297,7 +336,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
297
336
  assert feats is not None
298
337
  dtype, device = feats[0].dtype, feats[0].device
299
338
  for i, stride in enumerate(strides):
300
- _, _, h, w = feats[i].shape
339
+ h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
301
340
  sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
302
341
  sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
303
342
  sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
@@ -326,14 +365,16 @@ def bbox2dist(anchor_points, bbox, reg_max):
326
365
 
327
366
  def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
328
367
  """
329
- Decode predicted object bounding box coordinates from anchor points and distribution.
368
+ Decode predicted rotated bounding box coordinates from anchor points and distribution.
330
369
 
331
370
  Args:
332
- pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
333
- pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
334
- anchor_points (torch.Tensor): Anchor points, (h*w, 2).
371
+ pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
372
+ pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
373
+ anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
374
+ dim (int, optional): Dimension along which to split. Defaults to -1.
375
+
335
376
  Returns:
336
- (torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
377
+ (torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
337
378
  """
338
379
  lt, rb = pred_dist.split(2, dim=dim)
339
380
  cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)