dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243) hide show
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py CHANGED
@@ -3,17 +3,14 @@
3
3
  import torch
4
4
  import torch.nn as nn
5
5
 
6
- from . import LOGGER, TORCH_VERSION
7
- from .checks import check_version
6
+ from . import LOGGER
8
7
  from .metrics import bbox_iou, probiou
9
8
  from .ops import xywhr2xyxyxyxy
10
-
11
- TORCH_1_10 = check_version(TORCH_VERSION, "1.10.0")
9
+ from .torch_utils import TORCH_1_11
12
10
 
13
11
 
14
12
  class TaskAlignedAssigner(nn.Module):
15
- """
16
- A task-aligned assigner for object detection.
13
+ """A task-aligned assigner for object detection.
17
14
 
18
15
  This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
19
16
  classification and localization information.
@@ -27,8 +24,7 @@ class TaskAlignedAssigner(nn.Module):
27
24
  """
28
25
 
29
26
  def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
30
- """
31
- Initialize a TaskAlignedAssigner object with customizable hyperparameters.
27
+ """Initialize a TaskAlignedAssigner object with customizable hyperparameters.
32
28
 
33
29
  Args:
34
30
  topk (int, optional): The number of top candidates to consider.
@@ -46,8 +42,7 @@ class TaskAlignedAssigner(nn.Module):
46
42
 
47
43
  @torch.no_grad()
48
44
  def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
49
- """
50
- Compute the task-aligned assignment.
45
+ """Compute the task-aligned assignment.
51
46
 
52
47
  Args:
53
48
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -90,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
90
85
  return tuple(t.to(device) for t in result)
91
86
 
92
87
  def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
93
- """
94
- Compute the task-aligned assignment.
88
+ """Compute the task-aligned assignment.
95
89
 
96
90
  Args:
97
91
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -127,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
127
121
  return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
128
122
 
129
123
  def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
130
- """
131
- Get positive mask for each ground truth box.
124
+ """Get positive mask for each ground truth box.
132
125
 
133
126
  Args:
134
127
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -141,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
141
134
  Returns:
142
135
  mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
143
136
  align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
144
- overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
137
+ overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
145
138
  """
146
139
  mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
147
140
  # Get anchor_align metric, (b, max_num_obj, h*w)
@@ -154,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
154
147
  return mask_pos, align_metric, overlaps
155
148
 
156
149
  def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
157
- """
158
- Compute alignment metric given predicted and ground truth bounding boxes.
150
+ """Compute alignment metric given predicted and ground truth bounding boxes.
159
151
 
160
152
  Args:
161
153
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -188,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
188
180
  return align_metric, overlaps
189
181
 
190
182
  def iou_calculation(self, gt_bboxes, pd_bboxes):
191
- """
192
- Calculate IoU for horizontal bounding boxes.
183
+ """Calculate IoU for horizontal bounding boxes.
193
184
 
194
185
  Args:
195
186
  gt_bboxes (torch.Tensor): Ground truth boxes.
@@ -201,14 +192,13 @@ class TaskAlignedAssigner(nn.Module):
201
192
  return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
202
193
 
203
194
  def select_topk_candidates(self, metrics, topk_mask=None):
204
- """
205
- Select the top-k candidates based on the given metrics.
195
+ """Select the top-k candidates based on the given metrics.
206
196
 
207
197
  Args:
208
198
  metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
209
199
  the maximum number of objects, and h*w represents the total number of anchor points.
210
- topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
211
- topk is the number of top candidates to consider. If not provided, the top-k values are automatically
200
+ topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
201
+ is the number of top candidates to consider. If not provided, the top-k values are automatically
212
202
  computed based on the given metrics.
213
203
 
214
204
  Returns:
@@ -233,18 +223,16 @@ class TaskAlignedAssigner(nn.Module):
233
223
  return count_tensor.to(metrics.dtype)
234
224
 
235
225
  def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
236
- """
237
- Compute target labels, target bounding boxes, and target scores for the positive anchor points.
226
+ """Compute target labels, target bounding boxes, and target scores for the positive anchor points.
238
227
 
239
228
  Args:
240
- gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
241
- batch size and max_num_obj is the maximum number of objects.
229
+ gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
230
+ max_num_obj is the maximum number of objects.
242
231
  gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
243
- target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
244
- anchor points, with shape (b, h*w), where h*w is the total
245
- number of anchor points.
246
- fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
247
- (foreground) anchor points.
232
+ target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
233
+ shape (b, h*w), where h*w is the total number of anchor points.
234
+ fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
235
+ points.
248
236
 
249
237
  Returns:
250
238
  target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
@@ -277,8 +265,7 @@ class TaskAlignedAssigner(nn.Module):
277
265
 
278
266
  @staticmethod
279
267
  def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
280
- """
281
- Select positive anchor centers within ground truth bounding boxes.
268
+ """Select positive anchor centers within ground truth bounding boxes.
282
269
 
283
270
  Args:
284
271
  xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
@@ -288,9 +275,9 @@ class TaskAlignedAssigner(nn.Module):
288
275
  Returns:
289
276
  (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
290
277
 
291
- Note:
292
- b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
293
- Bounding box format: [x_min, y_min, x_max, y_max].
278
+ Notes:
279
+ - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
280
+ - Bounding box format: [x_min, y_min, x_max, y_max].
294
281
  """
295
282
  n_anchors = xy_centers.shape[0]
296
283
  bs, n_boxes, _ = gt_bboxes.shape
@@ -300,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
300
287
 
301
288
  @staticmethod
302
289
  def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
303
- """
304
- Select anchor boxes with highest IoU when assigned to multiple ground truths.
290
+ """Select anchor boxes with highest IoU when assigned to multiple ground truths.
305
291
 
306
292
  Args:
307
293
  mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
@@ -338,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
338
324
 
339
325
  @staticmethod
340
326
  def select_candidates_in_gts(xy_centers, gt_bboxes):
341
- """
342
- Select the positive anchor center in gt for rotated bounding boxes.
327
+ """Select the positive anchor center in gt for rotated bounding boxes.
343
328
 
344
329
  Args:
345
330
  xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
@@ -373,7 +358,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
373
358
  h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
374
359
  sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
375
360
  sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
376
- sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
361
+ sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
377
362
  anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
378
363
  stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
379
364
  return torch.cat(anchor_points), torch.cat(stride_tensor)
@@ -398,8 +383,7 @@ def bbox2dist(anchor_points, bbox, reg_max):
398
383
 
399
384
 
400
385
  def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
401
- """
402
- Decode predicted rotated bounding box coordinates from anchor points and distribution.
386
+ """Decode predicted rotated bounding box coordinates from anchor points and distribution.
403
387
 
404
388
  Args:
405
389
  pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).