ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +125 -39
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +34 -33
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +33 -47
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +69 -90
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +31 -38
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +21 -26
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +23 -17
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +29 -24
  53. ultralytics/models/nas/predict.py +14 -11
  54. ultralytics/models/nas/val.py +11 -13
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +21 -21
  57. ultralytics/models/rtdetr/train.py +25 -24
  58. ultralytics/models/rtdetr/val.py +47 -14
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +30 -12
  73. ultralytics/models/yolo/classify/train.py +83 -19
  74. ultralytics/models/yolo/classify/val.py +45 -23
  75. ultralytics/models/yolo/detect/predict.py +29 -19
  76. ultralytics/models/yolo/detect/train.py +90 -23
  77. ultralytics/models/yolo/detect/val.py +150 -29
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +18 -13
  80. ultralytics/models/yolo/obb/train.py +12 -8
  81. ultralytics/models/yolo/obb/val.py +35 -22
  82. ultralytics/models/yolo/pose/predict.py +28 -15
  83. ultralytics/models/yolo/pose/train.py +21 -8
  84. ultralytics/models/yolo/pose/val.py +51 -31
  85. ultralytics/models/yolo/segment/predict.py +27 -16
  86. ultralytics/models/yolo/segment/train.py +11 -8
  87. ultralytics/models/yolo/segment/val.py +110 -29
  88. ultralytics/models/yolo/world/train.py +43 -16
  89. ultralytics/models/yolo/world/train_world.py +61 -36
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +12 -12
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +226 -79
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +37 -35
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +139 -68
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +37 -56
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +117 -52
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +65 -61
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +72 -59
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +202 -64
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +13 -25
  149. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.88.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py CHANGED
@@ -21,6 +21,7 @@ class TaskAlignedAssigner(nn.Module):
21
21
  Attributes:
22
22
  topk (int): The number of top candidates to consider.
23
23
  num_classes (int): The number of object classes.
24
+ bg_idx (int): Background class index.
24
25
  alpha (float): The alpha parameter for the classification component of the task-aligned metric.
25
26
  beta (float): The beta parameter for the localization component of the task-aligned metric.
26
27
  eps (float): A small value to prevent division by zero.
@@ -39,23 +40,25 @@ class TaskAlignedAssigner(nn.Module):
39
40
  @torch.no_grad()
40
41
  def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
41
42
  """
42
- Compute the task-aligned assignment. Reference code is available at
43
- https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
43
+ Compute the task-aligned assignment.
44
44
 
45
45
  Args:
46
- pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
47
- pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
48
- anc_points (Tensor): shape(num_total_anchors, 2)
49
- gt_labels (Tensor): shape(bs, n_max_boxes, 1)
50
- gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
51
- mask_gt (Tensor): shape(bs, n_max_boxes, 1)
46
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
47
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
48
+ anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
49
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
50
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
51
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
52
52
 
53
53
  Returns:
54
- target_labels (Tensor): shape(bs, num_total_anchors)
55
- target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
56
- target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
57
- fg_mask (Tensor): shape(bs, num_total_anchors)
58
- target_gt_idx (Tensor): shape(bs, num_total_anchors)
54
+ target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
55
+ target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
56
+ target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
57
+ fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
58
+ target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
59
+
60
+ References:
61
+ https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
59
62
  """
60
63
  self.bs = pd_scores.shape[0]
61
64
  self.n_max_boxes = gt_bboxes.shape[1]
@@ -81,23 +84,22 @@ class TaskAlignedAssigner(nn.Module):
81
84
 
82
85
  def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
83
86
  """
84
- Compute the task-aligned assignment. Reference code is available at
85
- https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
87
+ Compute the task-aligned assignment.
86
88
 
87
89
  Args:
88
- pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
89
- pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
90
- anc_points (Tensor): shape(num_total_anchors, 2)
91
- gt_labels (Tensor): shape(bs, n_max_boxes, 1)
92
- gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
93
- mask_gt (Tensor): shape(bs, n_max_boxes, 1)
90
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
91
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
92
+ anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
93
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
94
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
95
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
94
96
 
95
97
  Returns:
96
- target_labels (Tensor): shape(bs, num_total_anchors)
97
- target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
98
- target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
99
- fg_mask (Tensor): shape(bs, num_total_anchors)
100
- target_gt_idx (Tensor): shape(bs, num_total_anchors)
98
+ target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
99
+ target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
100
+ target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
101
+ fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
102
+ target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
101
103
  """
102
104
  mask_pos, align_metric, overlaps = self.get_pos_mask(
103
105
  pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
@@ -118,7 +120,22 @@ class TaskAlignedAssigner(nn.Module):
118
120
  return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
119
121
 
120
122
  def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
121
- """Get in_gts mask, (b, max_num_obj, h*w)."""
123
+ """
124
+ Get positive mask for each ground truth box.
125
+
126
+ Args:
127
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
128
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
129
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
130
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
131
+ anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
132
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
133
+
134
+ Returns:
135
+ mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
136
+ align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
137
+ overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
138
+ """
122
139
  mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
123
140
  # Get anchor_align metric, (b, max_num_obj, h*w)
124
141
  align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
@@ -130,7 +147,20 @@ class TaskAlignedAssigner(nn.Module):
130
147
  return mask_pos, align_metric, overlaps
131
148
 
132
149
  def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
133
- """Compute alignment metric given predicted and ground truth bounding boxes."""
150
+ """
151
+ Compute alignment metric given predicted and ground truth bounding boxes.
152
+
153
+ Args:
154
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
155
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
156
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
157
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
158
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).
159
+
160
+ Returns:
161
+ align_metric (torch.Tensor): Alignment metric combining classification and localization.
162
+ overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
163
+ """
134
164
  na = pd_bboxes.shape[-2]
135
165
  mask_gt = mask_gt.bool() # b, max_num_obj, h*w
136
166
  overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
@@ -151,7 +181,16 @@ class TaskAlignedAssigner(nn.Module):
151
181
  return align_metric, overlaps
152
182
 
153
183
  def iou_calculation(self, gt_bboxes, pd_bboxes):
154
- """IoU calculation for horizontal bounding boxes."""
184
+ """
185
+ Calculate IoU for horizontal bounding boxes.
186
+
187
+ Args:
188
+ gt_bboxes (torch.Tensor): Ground truth boxes.
189
+ pd_bboxes (torch.Tensor): Predicted boxes.
190
+
191
+ Returns:
192
+ (torch.Tensor): IoU values between each pair of boxes.
193
+ """
155
194
  return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
156
195
 
157
196
  def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
@@ -159,16 +198,16 @@ class TaskAlignedAssigner(nn.Module):
159
198
  Select the top-k candidates based on the given metrics.
160
199
 
161
200
  Args:
162
- metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
201
+ metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
163
202
  max_num_obj is the maximum number of objects, and h*w represents the
164
203
  total number of anchor points.
165
204
  largest (bool): If True, select the largest values; otherwise, select the smallest values.
166
- topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
205
+ topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
167
206
  topk is the number of top candidates to consider. If not provided,
168
207
  the top-k values are automatically computed based on the given metrics.
169
208
 
170
209
  Returns:
171
- (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
210
+ (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
172
211
  """
173
212
  # (b, max_num_obj, topk)
174
213
  topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
@@ -183,7 +222,6 @@ class TaskAlignedAssigner(nn.Module):
183
222
  for k in range(self.topk):
184
223
  # Expand topk_idxs for each value of k and add 1 at the specified positions
185
224
  count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
186
- # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
187
225
  # Filter invalid bboxes
188
226
  count_tensor.masked_fill_(count_tensor > 1, 0)
189
227
 
@@ -194,24 +232,21 @@ class TaskAlignedAssigner(nn.Module):
194
232
  Compute target labels, target bounding boxes, and target scores for the positive anchor points.
195
233
 
196
234
  Args:
197
- gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
235
+ gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
198
236
  batch size and max_num_obj is the maximum number of objects.
199
- gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
200
- target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
237
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
238
+ target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
201
239
  anchor points, with shape (b, h*w), where h*w is the total
202
240
  number of anchor points.
203
- fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
241
+ fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
204
242
  (foreground) anchor points.
205
243
 
206
244
  Returns:
207
- (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
208
- - target_labels (Tensor): Shape (b, h*w), containing the target labels for
209
- positive anchor points.
210
- - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
211
- for positive anchor points.
212
- - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
213
- for positive anchor points, where num_classes is the number
214
- of object classes.
245
+ target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
246
+ target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
247
+ anchor points.
248
+ target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
249
+ anchor points.
215
250
  """
216
251
  # Assigned target labels, (b, 1)
217
252
  batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
@@ -258,7 +293,6 @@ class TaskAlignedAssigner(nn.Module):
258
293
  bs, n_boxes, _ = gt_bboxes.shape
259
294
  lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
260
295
  bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
261
- # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
262
296
  return bbox_deltas.amin(3).gt_(eps)
263
297
 
264
298
  @staticmethod
@@ -275,9 +309,6 @@ class TaskAlignedAssigner(nn.Module):
275
309
  target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
276
310
  fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
277
311
  mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
278
-
279
- Note:
280
- b: batch size, h: height, w: width.
281
312
  """
282
313
  # Convert (b, n_max_boxes, h*w) -> (b, h*w)
283
314
  fg_mask = mask_pos.sum(-2)
@@ -299,7 +330,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
299
330
  """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
300
331
 
301
332
  def iou_calculation(self, gt_bboxes, pd_bboxes):
302
- """IoU calculation for rotated bounding boxes."""
333
+ """Calculate IoU for rotated bounding boxes."""
303
334
  return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
304
335
 
305
336
  @staticmethod
@@ -308,11 +339,11 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
308
339
  Select the positive anchor center in gt for rotated bounding boxes.
309
340
 
310
341
  Args:
311
- xy_centers (Tensor): shape(h*w, 2)
312
- gt_bboxes (Tensor): shape(b, n_boxes, 5)
342
+ xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
343
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
313
344
 
314
345
  Returns:
315
- (Tensor): shape(b, n_boxes, h*w)
346
+ (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
316
347
  """
317
348
  # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
318
349
  corners = xywhr2xyxyxyxy(gt_bboxes)
@@ -368,13 +399,13 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
368
399
  Decode predicted rotated bounding box coordinates from anchor points and distribution.
369
400
 
370
401
  Args:
371
- pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
372
- pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
373
- anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
402
+ pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
403
+ pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
404
+ anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
374
405
  dim (int, optional): Dimension along which to split. Defaults to -1.
375
406
 
376
407
  Returns:
377
- (torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
408
+ (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
378
409
  """
379
410
  lt, rb = pred_dist.split(2, dim=dim)
380
411
  cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)