dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py CHANGED
@@ -4,16 +4,13 @@ import torch
4
4
  import torch.nn as nn
5
5
 
6
6
  from . import LOGGER
7
- from .checks import check_version
8
7
  from .metrics import bbox_iou, probiou
9
8
  from .ops import xywhr2xyxyxyxy
10
-
11
- TORCH_1_10 = check_version(torch.__version__, "1.10.0")
9
+ from .torch_utils import TORCH_1_11
12
10
 
13
11
 
14
12
  class TaskAlignedAssigner(nn.Module):
15
- """
16
- A task-aligned assigner for object detection.
13
+ """A task-aligned assigner for object detection.
17
14
 
18
15
  This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
19
16
  classification and localization information.
@@ -21,26 +18,31 @@ class TaskAlignedAssigner(nn.Module):
21
18
  Attributes:
22
19
  topk (int): The number of top candidates to consider.
23
20
  num_classes (int): The number of object classes.
24
- bg_idx (int): Background class index.
25
21
  alpha (float): The alpha parameter for the classification component of the task-aligned metric.
26
22
  beta (float): The beta parameter for the localization component of the task-aligned metric.
27
23
  eps (float): A small value to prevent division by zero.
28
24
  """
29
25
 
30
- def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
31
- """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
26
+ def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
27
+ """Initialize a TaskAlignedAssigner object with customizable hyperparameters.
28
+
29
+ Args:
30
+ topk (int, optional): The number of top candidates to consider.
31
+ num_classes (int, optional): The number of object classes.
32
+ alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
33
+ beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
34
+ eps (float, optional): A small value to prevent division by zero.
35
+ """
32
36
  super().__init__()
33
37
  self.topk = topk
34
38
  self.num_classes = num_classes
35
- self.bg_idx = num_classes
36
39
  self.alpha = alpha
37
40
  self.beta = beta
38
41
  self.eps = eps
39
42
 
40
43
  @torch.no_grad()
41
44
  def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
42
- """
43
- Compute the task-aligned assignment.
45
+ """Compute the task-aligned assignment.
44
46
 
45
47
  Args:
46
48
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -66,7 +68,7 @@ class TaskAlignedAssigner(nn.Module):
66
68
 
67
69
  if self.n_max_boxes == 0:
68
70
  return (
69
- torch.full_like(pd_scores[..., 0], self.bg_idx),
71
+ torch.full_like(pd_scores[..., 0], self.num_classes),
70
72
  torch.zeros_like(pd_bboxes),
71
73
  torch.zeros_like(pd_scores),
72
74
  torch.zeros_like(pd_scores[..., 0]),
@@ -83,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
83
85
  return tuple(t.to(device) for t in result)
84
86
 
85
87
  def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
86
- """
87
- Compute the task-aligned assignment.
88
+ """Compute the task-aligned assignment.
88
89
 
89
90
  Args:
90
91
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -120,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
120
121
  return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
121
122
 
122
123
  def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
123
- """
124
- Get positive mask for each ground truth box.
124
+ """Get positive mask for each ground truth box.
125
125
 
126
126
  Args:
127
127
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -134,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
134
134
  Returns:
135
135
  mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
136
136
  align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
137
- overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
137
+ overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
138
138
  """
139
139
  mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
140
140
  # Get anchor_align metric, (b, max_num_obj, h*w)
@@ -147,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
147
147
  return mask_pos, align_metric, overlaps
148
148
 
149
149
  def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
150
- """
151
- Compute alignment metric given predicted and ground truth bounding boxes.
150
+ """Compute alignment metric given predicted and ground truth bounding boxes.
152
151
 
153
152
  Args:
154
153
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -181,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
181
180
  return align_metric, overlaps
182
181
 
183
182
  def iou_calculation(self, gt_bboxes, pd_bboxes):
184
- """
185
- Calculate IoU for horizontal bounding boxes.
183
+ """Calculate IoU for horizontal bounding boxes.
186
184
 
187
185
  Args:
188
186
  gt_bboxes (torch.Tensor): Ground truth boxes.
@@ -193,24 +191,21 @@ class TaskAlignedAssigner(nn.Module):
193
191
  """
194
192
  return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
195
193
 
196
- def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
197
- """
198
- Select the top-k candidates based on the given metrics.
194
+ def select_topk_candidates(self, metrics, topk_mask=None):
195
+ """Select the top-k candidates based on the given metrics.
199
196
 
200
197
  Args:
201
- metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
202
- max_num_obj is the maximum number of objects, and h*w represents the
203
- total number of anchor points.
204
- largest (bool): If True, select the largest values; otherwise, select the smallest values.
205
- topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
206
- topk is the number of top candidates to consider. If not provided,
207
- the top-k values are automatically computed based on the given metrics.
198
+ metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
199
+ the maximum number of objects, and h*w represents the total number of anchor points.
200
+ topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
201
+ is the number of top candidates to consider. If not provided, the top-k values are automatically
202
+ computed based on the given metrics.
208
203
 
209
204
  Returns:
210
205
  (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
211
206
  """
212
207
  # (b, max_num_obj, topk)
213
- topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
208
+ topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
214
209
  if topk_mask is None:
215
210
  topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
216
211
  # (b, max_num_obj, topk)
@@ -228,25 +223,21 @@ class TaskAlignedAssigner(nn.Module):
228
223
  return count_tensor.to(metrics.dtype)
229
224
 
230
225
  def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
231
- """
232
- Compute target labels, target bounding boxes, and target scores for the positive anchor points.
226
+ """Compute target labels, target bounding boxes, and target scores for the positive anchor points.
233
227
 
234
228
  Args:
235
- gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
236
- batch size and max_num_obj is the maximum number of objects.
229
+ gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
230
+ max_num_obj is the maximum number of objects.
237
231
  gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
238
- target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
239
- anchor points, with shape (b, h*w), where h*w is the total
240
- number of anchor points.
241
- fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
242
- (foreground) anchor points.
232
+ target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
233
+ shape (b, h*w), where h*w is the total number of anchor points.
234
+ fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
235
+ points.
243
236
 
244
237
  Returns:
245
- target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
246
- target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
247
- anchor points.
248
- target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
249
- anchor points.
238
+ target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
239
+ target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
240
+ target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
250
241
  """
251
242
  # Assigned target labels, (b, 1)
252
243
  batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
@@ -274,20 +265,19 @@ class TaskAlignedAssigner(nn.Module):
274
265
 
275
266
  @staticmethod
276
267
  def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
277
- """
278
- Select positive anchor centers within ground truth bounding boxes.
268
+ """Select positive anchor centers within ground truth bounding boxes.
279
269
 
280
270
  Args:
281
271
  xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
282
272
  gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
283
- eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
273
+ eps (float, optional): Small value for numerical stability.
284
274
 
285
275
  Returns:
286
276
  (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
287
277
 
288
- Note:
289
- b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
290
- Bounding box format: [x_min, y_min, x_max, y_max].
278
+ Notes:
279
+ - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
280
+ - Bounding box format: [x_min, y_min, x_max, y_max].
291
281
  """
292
282
  n_anchors = xy_centers.shape[0]
293
283
  bs, n_boxes, _ = gt_bboxes.shape
@@ -297,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
297
287
 
298
288
  @staticmethod
299
289
  def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
300
- """
301
- Select anchor boxes with highest IoU when assigned to multiple ground truths.
290
+ """Select anchor boxes with highest IoU when assigned to multiple ground truths.
302
291
 
303
292
  Args:
304
293
  mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
@@ -335,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
335
324
 
336
325
  @staticmethod
337
326
  def select_candidates_in_gts(xy_centers, gt_bboxes):
338
- """
339
- Select the positive anchor center in gt for rotated bounding boxes.
327
+ """Select the positive anchor center in gt for rotated bounding boxes.
340
328
 
341
329
  Args:
342
330
  xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
@@ -370,7 +358,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
370
358
  h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
371
359
  sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
372
360
  sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
373
- sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
361
+ sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
374
362
  anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
375
363
  stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
376
364
  return torch.cat(anchor_points), torch.cat(stride_tensor)
@@ -384,7 +372,7 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
384
372
  if xywh:
385
373
  c_xy = (x1y1 + x2y2) / 2
386
374
  wh = x2y2 - x1y1
387
- return torch.cat((c_xy, wh), dim) # xywh bbox
375
+ return torch.cat([c_xy, wh], dim) # xywh bbox
388
376
  return torch.cat((x1y1, x2y2), dim) # xyxy bbox
389
377
 
390
378
 
@@ -395,14 +383,13 @@ def bbox2dist(anchor_points, bbox, reg_max):
395
383
 
396
384
 
397
385
  def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
398
- """
399
- Decode predicted rotated bounding box coordinates from anchor points and distribution.
386
+ """Decode predicted rotated bounding box coordinates from anchor points and distribution.
400
387
 
401
388
  Args:
402
389
  pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
403
390
  pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
404
391
  anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
405
- dim (int, optional): Dimension along which to split. Defaults to -1.
392
+ dim (int, optional): Dimension along which to split.
406
393
 
407
394
  Returns:
408
395
  (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).