dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/utils/plotting.py CHANGED
@@ -20,11 +20,10 @@ from ultralytics.utils.files import increment_path
 
 
 class Colors:
-    """
-    Ultralytics color palette for visualization and plotting.
+    """Ultralytics color palette for visualization and plotting.
 
-    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
-    RGB values and accessing predefined color schemes for object detection and pose estimation.
+    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
+    values and accessing predefined color schemes for object detection and pose estimation.
 
     Attributes:
         palette (list[tuple]): List of RGB color tuples for general use.
@@ -146,8 +145,7 @@ class Colors:
         )
 
     def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
-        """
-        Convert hex color codes to RGB values.
+        """Convert hex color codes to RGB values.
 
         Args:
             i (int | torch.Tensor): Color index.
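For readers skimming the hunks above, a minimal usage sketch of the module-level `colors` palette instance touched here (the index and flag values are arbitrary):

```python
from ultralytics.utils.plotting import colors

# The same index always maps to the same palette entry.
rgb = colors(4)            # (R, G, B) tuple for palette entry 4
bgr = colors(4, bgr=True)  # same color reordered to (B, G, R) for OpenCV-style drawing
print(rgb, bgr)
```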
@@ -169,8 +167,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
 
 
 class Annotator:
-    """
-    Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
+    """Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
 
     Attributes:
         im (Image.Image | np.ndarray): The image to annotate.
@@ -279,8 +276,7 @@ class Annotator:
         }
 
     def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
-        """
-        Assign text color based on background color.
+        """Assign text color based on background color.
 
         Args:
             color (tuple, optional): The background color of the rectangle for text (B, G, R).
@@ -303,8 +299,7 @@ class Annotator:
         return txt_color
 
     def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
-        """
-        Draw a bounding box on an image with a given label.
+        """Draw a bounding box on an image with a given label.
 
         Args:
             box (tuple): The bounding box coordinates (x1, y1, x2, y2).
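A minimal sketch of driving `Annotator.box_label` end to end (the blank canvas, box coordinates, and label text are invented for illustration):

```python
import numpy as np

from ultralytics.utils.plotting import Annotator, colors

im = np.zeros((480, 640, 3), dtype=np.uint8)  # blank BGR canvas standing in for a real frame
annotator = Annotator(im, line_width=2)
annotator.box_label((50, 60, 220, 300), label="person 0.91", color=colors(0, bgr=True))
annotated = annotator.result()                # numpy array with the box and label drawn in
```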
@@ -365,8 +360,7 @@ class Annotator:
         )
 
     def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
-        """
-        Plot masks on image.
+        """Plot masks on image.
 
         Args:
             masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
@@ -424,8 +418,7 @@ class Annotator:
         conf_thres: float = 0.25,
         kpt_color: tuple | None = None,
     ):
-        """
-        Plot keypoints on the image.
+        """Plot keypoints on the image.
 
         Args:
             kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
@@ -435,7 +428,7 @@ class Annotator:
             conf_thres (float, optional): Confidence threshold.
             kpt_color (tuple, optional): Keypoint color (B, G, R).
 
-        Note:
+        Notes:
            - `kpt_line=True` currently only supports human pose plotting.
            - Modifies self.im in-place.
            - If self.pil is True, converts image to numpy array and back to PIL.
@@ -488,8 +481,7 @@ class Annotator:
         self.draw.rectangle(xy, fill, outline, width)
 
     def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
-        """
-        Add text to an image using PIL or cv2.
+        """Add text to an image using PIL or cv2.
 
         Args:
             xy (list[int]): Top-left coordinates for text placement.
@@ -544,8 +536,7 @@ class Annotator:
 
     @staticmethod
     def get_bbox_dimension(bbox: tuple | None = None):
-        """
-        Calculate the dimensions and area of a bounding box.
+        """Calculate the dimensions and area of a bounding box.
 
         Args:
             bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
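Since `get_bbox_dimension` is a plain static helper, a quick illustrative call (box coordinates invented; assuming the width/height/area return order implied by the docstring):

```python
from ultralytics.utils.plotting import Annotator

# (x_min, y_min, x_max, y_max) for a made-up 100x200 box
w, h, area = Annotator.get_bbox_dimension((10, 20, 110, 220))
print(w, h, area)  # 100, 200, 20000 if the helper returns (width, height, width * height)
```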
@@ -570,8 +561,7 @@ class Annotator:
 @TryExcept()
 @plt_settings()
 def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
-    """
-    Plot training labels including class histograms and box statistics.
+    """Plot training labels including class histograms and box statistics.
 
     Args:
         boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
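A hedged sketch of calling `plot_labels` on synthetic label statistics (the random boxes, class ids, class-name mapping, and output directory are placeholders, not real dataset labels):

```python
from pathlib import Path

import numpy as np

from ultralytics.utils.plotting import plot_labels

rng = np.random.default_rng(0)
boxes = rng.random((200, 4))                      # [x, y, width, height] per box (placeholder values)
cls = rng.integers(0, 3, size=200).astype(float)  # class index per box
save_dir = Path("runs/labels_demo")
save_dir.mkdir(parents=True, exist_ok=True)
plot_labels(boxes, cls, names={0: "person", 1: "car", 2: "dog"}, save_dir=save_dir)  # writes labels.jpg
```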
@@ -641,12 +631,11 @@ def save_one_box(
     BGR: bool = False,
     save: bool = True,
 ):
-    """
-    Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
+    """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
 
-    This function takes a bounding box and an image, and then saves a cropped portion of the image according
-    to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
-    adjustments to the bounding box.
+    This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
+    bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
+    bounding box.
 
     Args:
         xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
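A hedged usage sketch for `save_one_box` (the synthetic image, box, and output path are placeholders):

```python
from pathlib import Path

import numpy as np
import torch

from ultralytics.utils.plotting import save_one_box

im = np.zeros((480, 640, 3), dtype=np.uint8)      # stand-in BGR image
xyxy = torch.tensor([100.0, 80.0, 300.0, 260.0])  # made-up box in pixel xyxy coordinates
crop = save_one_box(xyxy, im, file=Path("crops/demo.jpg"), square=True, save=True)
print(crop.shape)                                 # crop grown by the gain/pad arguments and squared
```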
@@ -699,11 +688,11 @@ def plot_images(
     save: bool = True,
     conf_thres: float = 0.25,
 ) -> np.ndarray | None:
-    """
-    Plot image grid with labels, bounding boxes, masks, and keypoints.
+    """Plot image grid with labels, bounding boxes, masks, and keypoints.
 
     Args:
-        labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.
+        labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
+            'keypoints', 'batch_idx', 'img'.
         images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
         paths (Optional[list[str]]): List of file paths for each image in the batch.
         fname (str): Output filename for the plotted image grid.
@@ -717,7 +706,7 @@ def plot_images(
     Returns:
         (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
 
-    Note:
+    Notes:
         This function supports both tensor and numpy array inputs. It will automatically
         convert tensor inputs to numpy arrays for processing.
 
@@ -868,9 +857,9 @@ def plot_images(
 
 @plt_settings()
 def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
-    """
-    Plot training results from a results CSV file. The function supports various types of data including segmentation,
-    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
+    """Plot training results from a results CSV file. The function supports various types of data including
+    segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
+    CSV is located.
 
     Args:
         file (str, optional): Path to the CSV file containing the training results.
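A minimal `plot_results` call; the CSV path is hypothetical and must point at a results file written by an earlier training run:

```python
from ultralytics.utils.plotting import plot_results

# Reads the training-metrics CSV and saves 'results.png' alongside it.
plot_results(file="runs/detect/train/results.csv")
```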
@@ -922,8 +911,7 @@ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Call
 
 
 def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
-    """
-    Plot a scatter plot with points colored based on a 2D histogram.
+    """Plot a scatter plot with points colored based on a 2D histogram.
 
     Args:
         v (array-like): Values for the x-axis.
@@ -956,9 +944,9 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
 
 @plt_settings()
 def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
-    """
-    Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
-    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
+    """Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
+    key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
+    the plots.
 
     Args:
         csv_file (str, optional): Path to the CSV file containing the tuning results.
@@ -1025,8 +1013,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_p
 
 @plt_settings()
 def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
-    """
-    Visualize feature maps of a given model module during inference.
+    """Visualize feature maps of a given model module during inference.
 
     Args:
         x (torch.Tensor): Features to be visualized.
ultralytics/utils/tal.py CHANGED
@@ -10,8 +10,7 @@ from .torch_utils import TORCH_1_11
 
 
 class TaskAlignedAssigner(nn.Module):
-    """
-    A task-aligned assigner for object detection.
+    """A task-aligned assigner for object detection.
 
     This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
     classification and localization information.
@@ -25,8 +24,7 @@ class TaskAlignedAssigner(nn.Module):
     """
 
     def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
-        """
-        Initialize a TaskAlignedAssigner object with customizable hyperparameters.
+        """Initialize a TaskAlignedAssigner object with customizable hyperparameters.
 
         Args:
             topk (int, optional): The number of top candidates to consider.
@@ -44,8 +42,7 @@ class TaskAlignedAssigner(nn.Module):
 
     @torch.no_grad()
     def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute the task-aligned assignment.
+        """Compute the task-aligned assignment.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
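To make the tensor shapes in these docstrings concrete, a hedged sketch that drives the assigner with random dummy data (batch size, anchor count, and object count are arbitrary; real callers pass decoded predictions and dataset labels):

```python
import torch

from ultralytics.utils.tal import TaskAlignedAssigner

bs, num_anchors, nc, max_obj = 2, 8400, 80, 5
assigner = TaskAlignedAssigner(topk=10, num_classes=nc, alpha=0.5, beta=6.0)

pd_scores = torch.rand(bs, num_anchors, nc)          # predicted class scores
pd_bboxes = torch.rand(bs, num_anchors, 4) * 640     # predicted boxes, xyxy
anc_points = torch.rand(num_anchors, 2) * 640        # anchor center coordinates
gt_labels = torch.randint(0, nc, (bs, max_obj, 1))   # ground-truth class ids
gt_bboxes = torch.rand(bs, max_obj, 4) * 640         # ground-truth boxes, xyxy
mask_gt = torch.ones(bs, max_obj, 1)                 # all ground-truth slots treated as valid

target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = assigner(
    pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt
)
print(fg_mask.shape)  # (bs, num_anchors) boolean mask of positive anchors
```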
@@ -88,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
         return tuple(t.to(device) for t in result)
 
     def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute the task-aligned assignment.
+        """Compute the task-aligned assignment.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -125,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
         return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
 
     def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
-        """
-        Get positive mask for each ground truth box.
+        """Get positive mask for each ground truth box.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -139,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
         Returns:
             mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
             align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
-            overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
+            overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
         """
         mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
         # Get anchor_align metric, (b, max_num_obj, h*w)
@@ -152,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
         return mask_pos, align_metric, overlaps
 
     def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute alignment metric given predicted and ground truth bounding boxes.
+        """Compute alignment metric given predicted and ground truth bounding boxes.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -186,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
         return align_metric, overlaps
 
     def iou_calculation(self, gt_bboxes, pd_bboxes):
-        """
-        Calculate IoU for horizontal bounding boxes.
+        """Calculate IoU for horizontal bounding boxes.
 
         Args:
             gt_bboxes (torch.Tensor): Ground truth boxes.
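For context, the assigner combines classification score and IoU into an alignment metric (roughly score**alpha * iou**beta); a hedged sketch of the CIoU call that `iou_calculation` wraps, with two invented boxes:

```python
import torch

from ultralytics.utils.metrics import bbox_iou

gt = torch.tensor([[0.0, 0.0, 100.0, 100.0]])      # one ground-truth box, xyxy
pred = torch.tensor([[25.0, 25.0, 125.0, 125.0]])  # one predicted box, xyxy
ciou = bbox_iou(gt, pred, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
print(ciou)  # plain IoU here is ~0.39; CIoU subtracts a small center/aspect penalty
```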
@@ -199,14 +192,13 @@ class TaskAlignedAssigner(nn.Module):
         return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
 
     def select_topk_candidates(self, metrics, topk_mask=None):
-        """
-        Select the top-k candidates based on the given metrics.
+        """Select the top-k candidates based on the given metrics.
 
         Args:
             metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
                 the maximum number of objects, and h*w represents the total number of anchor points.
-            topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
-                topk is the number of top candidates to consider. If not provided, the top-k values are automatically
+            topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
+                is the number of top candidates to consider. If not provided, the top-k values are automatically
                 computed based on the given metrics.
 
         Returns:
@@ -231,18 +223,16 @@ class TaskAlignedAssigner(nn.Module):
         return count_tensor.to(metrics.dtype)
 
     def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-        """
-        Compute target labels, target bounding boxes, and target scores for the positive anchor points.
+        """Compute target labels, target bounding boxes, and target scores for the positive anchor points.
 
         Args:
-            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
-                batch size and max_num_obj is the maximum number of objects.
+            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
+                max_num_obj is the maximum number of objects.
             gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
-            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
-                anchor points, with shape (b, h*w), where h*w is the total
-                number of anchor points.
-            fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
-                (foreground) anchor points.
+            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
+                shape (b, h*w), where h*w is the total number of anchor points.
+            fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
+                points.
 
         Returns:
             target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
@@ -275,8 +265,7 @@ class TaskAlignedAssigner(nn.Module):
 
     @staticmethod
     def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-        """
-        Select positive anchor centers within ground truth bounding boxes.
+        """Select positive anchor centers within ground truth bounding boxes.
 
         Args:
             xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
@@ -286,9 +275,9 @@ class TaskAlignedAssigner(nn.Module):
         Returns:
             (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
 
-        Note:
-            b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
-            Bounding box format: [x_min, y_min, x_max, y_max].
+        Notes:
+            - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
+            - Bounding box format: [x_min, y_min, x_max, y_max].
         """
         n_anchors = xy_centers.shape[0]
         bs, n_boxes, _ = gt_bboxes.shape
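The Notes above describe a pure broadcasting test; an illustrative standalone re-implementation of that check (not the library's exact code) under the stated shapes and xyxy box format:

```python
import torch


def candidates_in_gts(xy_centers: torch.Tensor, gt_bboxes: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Return a (b, n_boxes, h*w) bool mask of anchor centers that fall inside each gt box."""
    # xy_centers: (h*w, 2); gt_bboxes: (b, n_boxes, 4) in [x_min, y_min, x_max, y_max]
    lt = gt_bboxes[..., None, :2]  # (b, n_boxes, 1, 2) top-left corners
    rb = gt_bboxes[..., None, 2:]  # (b, n_boxes, 1, 2) bottom-right corners
    deltas = torch.cat((xy_centers - lt, rb - xy_centers), dim=-1)  # (b, n_boxes, h*w, 4)
    return deltas.amin(dim=-1) > eps  # inside only if all four signed distances are positive


mask = candidates_in_gts(torch.rand(8400, 2) * 640, torch.rand(2, 5, 4) * 640)
print(mask.shape)  # torch.Size([2, 5, 8400])
```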
@@ -298,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
 
     @staticmethod
     def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-        """
-        Select anchor boxes with highest IoU when assigned to multiple ground truths.
+        """Select anchor boxes with highest IoU when assigned to multiple ground truths.
 
         Args:
             mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
@@ -336,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
 
     @staticmethod
     def select_candidates_in_gts(xy_centers, gt_bboxes):
-        """
-        Select the positive anchor center in gt for rotated bounding boxes.
+        """Select the positive anchor center in gt for rotated bounding boxes.
 
         Args:
             xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
@@ -396,8 +383,7 @@ def bbox2dist(anchor_points, bbox, reg_max):
 
 
 def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
-    """
-    Decode predicted rotated bounding box coordinates from anchor points and distribution.
+    """Decode predicted rotated bounding box coordinates from anchor points and distribution.
 
     Args:
         pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).