dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
- tests/conftest.py +5 -8
- tests/test_cli.py +1 -8
- tests/test_python.py +1 -2
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +34 -49
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +244 -323
- ultralytics/data/base.py +12 -22
- ultralytics/data/build.py +47 -40
- ultralytics/data/converter.py +32 -42
- ultralytics/data/dataset.py +43 -71
- ultralytics/data/loaders.py +22 -34
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +27 -36
- ultralytics/engine/exporter.py +49 -116
- ultralytics/engine/model.py +144 -180
- ultralytics/engine/predictor.py +18 -29
- ultralytics/engine/results.py +165 -231
- ultralytics/engine/trainer.py +11 -19
- ultralytics/engine/tuner.py +13 -23
- ultralytics/engine/validator.py +6 -10
- ultralytics/hub/__init__.py +7 -12
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +3 -6
- ultralytics/models/fastsam/model.py +6 -8
- ultralytics/models/fastsam/predict.py +5 -10
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +2 -4
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -18
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +13 -20
- ultralytics/models/sam/amg.py +12 -18
- ultralytics/models/sam/build.py +6 -9
- ultralytics/models/sam/model.py +16 -23
- ultralytics/models/sam/modules/blocks.py +62 -84
- ultralytics/models/sam/modules/decoders.py +17 -24
- ultralytics/models/sam/modules/encoders.py +40 -56
- ultralytics/models/sam/modules/memory_attention.py +10 -16
- ultralytics/models/sam/modules/sam.py +41 -47
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +17 -27
- ultralytics/models/sam/modules/utils.py +31 -42
- ultralytics/models/sam/predict.py +172 -209
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/classify/predict.py +8 -11
- ultralytics/models/yolo/classify/train.py +8 -16
- ultralytics/models/yolo/classify/val.py +13 -20
- ultralytics/models/yolo/detect/predict.py +4 -8
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +38 -48
- ultralytics/models/yolo/model.py +35 -47
- ultralytics/models/yolo/obb/predict.py +5 -8
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +20 -28
- ultralytics/models/yolo/pose/predict.py +5 -8
- ultralytics/models/yolo/pose/train.py +4 -8
- ultralytics/models/yolo/pose/val.py +31 -39
- ultralytics/models/yolo/segment/predict.py +9 -14
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +16 -26
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -16
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/autobackend.py +10 -18
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +99 -185
- ultralytics/nn/modules/conv.py +45 -90
- ultralytics/nn/modules/head.py +44 -98
- ultralytics/nn/modules/transformer.py +44 -76
- ultralytics/nn/modules/utils.py +14 -19
- ultralytics/nn/tasks.py +86 -146
- ultralytics/nn/text_model.py +25 -40
- ultralytics/solutions/ai_gym.py +10 -16
- ultralytics/solutions/analytics.py +7 -10
- ultralytics/solutions/config.py +4 -5
- ultralytics/solutions/distance_calculation.py +9 -12
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +8 -12
- ultralytics/solutions/object_cropper.py +5 -8
- ultralytics/solutions/parking_management.py +12 -14
- ultralytics/solutions/queue_management.py +4 -6
- ultralytics/solutions/region_counter.py +7 -10
- ultralytics/solutions/security_alarm.py +14 -19
- ultralytics/solutions/similarity_search.py +7 -12
- ultralytics/solutions/solutions.py +31 -53
- ultralytics/solutions/speed_estimation.py +6 -9
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/basetrack.py +2 -4
- ultralytics/trackers/bot_sort.py +6 -11
- ultralytics/trackers/byte_tracker.py +10 -15
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +6 -12
- ultralytics/trackers/utils/kalman_filter.py +35 -43
- ultralytics/trackers/utils/matching.py +6 -10
- ultralytics/utils/__init__.py +61 -100
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +11 -13
- ultralytics/utils/benchmarks.py +25 -35
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +2 -4
- ultralytics/utils/callbacks/comet.py +30 -44
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +4 -6
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +4 -6
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +29 -56
- ultralytics/utils/cpu.py +1 -2
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +17 -27
- ultralytics/utils/errors.py +6 -8
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -239
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +11 -17
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +10 -15
- ultralytics/utils/git.py +5 -7
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +11 -15
- ultralytics/utils/loss.py +8 -14
- ultralytics/utils/metrics.py +98 -138
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +47 -74
- ultralytics/utils/patches.py +11 -18
- ultralytics/utils/plotting.py +29 -42
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +45 -73
- ultralytics/utils/tqdm.py +6 -8
- ultralytics/utils/triton.py +9 -12
- ultralytics/utils/tuner.py +1 -2
- dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/utils/plotting.py
CHANGED
|
@@ -20,11 +20,10 @@ from ultralytics.utils.files import increment_path
|
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
class Colors:
|
|
23
|
-
"""
|
|
24
|
-
Ultralytics color palette for visualization and plotting.
|
|
23
|
+
"""Ultralytics color palette for visualization and plotting.
|
|
25
24
|
|
|
26
|
-
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
|
|
27
|
-
|
|
25
|
+
This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
|
|
26
|
+
values and accessing predefined color schemes for object detection and pose estimation.
|
|
28
27
|
|
|
29
28
|
Attributes:
|
|
30
29
|
palette (list[tuple]): List of RGB color tuples for general use.
|
|
@@ -146,8 +145,7 @@ class Colors:
|
|
|
146
145
|
)
|
|
147
146
|
|
|
148
147
|
def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
|
|
149
|
-
"""
|
|
150
|
-
Convert hex color codes to RGB values.
|
|
148
|
+
"""Convert hex color codes to RGB values.
|
|
151
149
|
|
|
152
150
|
Args:
|
|
153
151
|
i (int | torch.Tensor): Color index.
|
|
@@ -169,8 +167,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
|
|
|
169
167
|
|
|
170
168
|
|
|
171
169
|
class Annotator:
|
|
172
|
-
"""
|
|
173
|
-
Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
170
|
+
"""Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
|
|
174
171
|
|
|
175
172
|
Attributes:
|
|
176
173
|
im (Image.Image | np.ndarray): The image to annotate.
|
|
@@ -279,8 +276,7 @@ class Annotator:
|
|
|
279
276
|
}
|
|
280
277
|
|
|
281
278
|
def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
|
|
282
|
-
"""
|
|
283
|
-
Assign text color based on background color.
|
|
279
|
+
"""Assign text color based on background color.
|
|
284
280
|
|
|
285
281
|
Args:
|
|
286
282
|
color (tuple, optional): The background color of the rectangle for text (B, G, R).
|
|
@@ -303,8 +299,7 @@ class Annotator:
|
|
|
303
299
|
return txt_color
|
|
304
300
|
|
|
305
301
|
def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
|
|
306
|
-
"""
|
|
307
|
-
Draw a bounding box on an image with a given label.
|
|
302
|
+
"""Draw a bounding box on an image with a given label.
|
|
308
303
|
|
|
309
304
|
Args:
|
|
310
305
|
box (tuple): The bounding box coordinates (x1, y1, x2, y2).
|
|
@@ -365,8 +360,7 @@ class Annotator:
|
|
|
365
360
|
)
|
|
366
361
|
|
|
367
362
|
def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
|
|
368
|
-
"""
|
|
369
|
-
Plot masks on image.
|
|
363
|
+
"""Plot masks on image.
|
|
370
364
|
|
|
371
365
|
Args:
|
|
372
366
|
masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
|
|
@@ -424,8 +418,7 @@ class Annotator:
|
|
|
424
418
|
conf_thres: float = 0.25,
|
|
425
419
|
kpt_color: tuple | None = None,
|
|
426
420
|
):
|
|
427
|
-
"""
|
|
428
|
-
Plot keypoints on the image.
|
|
421
|
+
"""Plot keypoints on the image.
|
|
429
422
|
|
|
430
423
|
Args:
|
|
431
424
|
kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
|
|
@@ -435,7 +428,7 @@ class Annotator:
|
|
|
435
428
|
conf_thres (float, optional): Confidence threshold.
|
|
436
429
|
kpt_color (tuple, optional): Keypoint color (B, G, R).
|
|
437
430
|
|
|
438
|
-
|
|
431
|
+
Notes:
|
|
439
432
|
- `kpt_line=True` currently only supports human pose plotting.
|
|
440
433
|
- Modifies self.im in-place.
|
|
441
434
|
- If self.pil is True, converts image to numpy array and back to PIL.
|
|
@@ -488,8 +481,7 @@ class Annotator:
|
|
|
488
481
|
self.draw.rectangle(xy, fill, outline, width)
|
|
489
482
|
|
|
490
483
|
def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
|
|
491
|
-
"""
|
|
492
|
-
Add text to an image using PIL or cv2.
|
|
484
|
+
"""Add text to an image using PIL or cv2.
|
|
493
485
|
|
|
494
486
|
Args:
|
|
495
487
|
xy (list[int]): Top-left coordinates for text placement.
|
|
@@ -544,8 +536,7 @@ class Annotator:
|
|
|
544
536
|
|
|
545
537
|
@staticmethod
|
|
546
538
|
def get_bbox_dimension(bbox: tuple | None = None):
|
|
547
|
-
"""
|
|
548
|
-
Calculate the dimensions and area of a bounding box.
|
|
539
|
+
"""Calculate the dimensions and area of a bounding box.
|
|
549
540
|
|
|
550
541
|
Args:
|
|
551
542
|
bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
|
|
@@ -570,8 +561,7 @@ class Annotator:
|
|
|
570
561
|
@TryExcept()
|
|
571
562
|
@plt_settings()
|
|
572
563
|
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
|
|
573
|
-
"""
|
|
574
|
-
Plot training labels including class histograms and box statistics.
|
|
564
|
+
"""Plot training labels including class histograms and box statistics.
|
|
575
565
|
|
|
576
566
|
Args:
|
|
577
567
|
boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
|
|
@@ -641,12 +631,11 @@ def save_one_box(
|
|
|
641
631
|
BGR: bool = False,
|
|
642
632
|
save: bool = True,
|
|
643
633
|
):
|
|
644
|
-
"""
|
|
645
|
-
Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
|
634
|
+
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
|
|
646
635
|
|
|
647
|
-
This function takes a bounding box and an image, and then saves a cropped portion of the image according
|
|
648
|
-
|
|
649
|
-
|
|
636
|
+
This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
|
|
637
|
+
bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
|
|
638
|
+
bounding box.
|
|
650
639
|
|
|
651
640
|
Args:
|
|
652
641
|
xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
|
|
@@ -699,11 +688,11 @@ def plot_images(
|
|
|
699
688
|
save: bool = True,
|
|
700
689
|
conf_thres: float = 0.25,
|
|
701
690
|
) -> np.ndarray | None:
|
|
702
|
-
"""
|
|
703
|
-
Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
691
|
+
"""Plot image grid with labels, bounding boxes, masks, and keypoints.
|
|
704
692
|
|
|
705
693
|
Args:
|
|
706
|
-
labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
|
|
694
|
+
labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
|
|
695
|
+
'keypoints', 'batch_idx', 'img'.
|
|
707
696
|
images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
|
|
708
697
|
paths (Optional[list[str]]): List of file paths for each image in the batch.
|
|
709
698
|
fname (str): Output filename for the plotted image grid.
|
|
@@ -717,7 +706,7 @@ def plot_images(
|
|
|
717
706
|
Returns:
|
|
718
707
|
(np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
|
|
719
708
|
|
|
720
|
-
|
|
709
|
+
Notes:
|
|
721
710
|
This function supports both tensor and numpy array inputs. It will automatically
|
|
722
711
|
convert tensor inputs to numpy arrays for processing.
|
|
723
712
|
|
|
@@ -868,9 +857,9 @@ def plot_images(
|
|
|
868
857
|
|
|
869
858
|
@plt_settings()
|
|
870
859
|
def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
|
|
871
|
-
"""
|
|
872
|
-
|
|
873
|
-
|
|
860
|
+
"""Plot training results from a results CSV file. The function supports various types of data including
|
|
861
|
+
segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
|
|
862
|
+
CSV is located.
|
|
874
863
|
|
|
875
864
|
Args:
|
|
876
865
|
file (str, optional): Path to the CSV file containing the training results.
|
|
@@ -922,8 +911,7 @@ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Call
|
|
|
922
911
|
|
|
923
912
|
|
|
924
913
|
def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
|
|
925
|
-
"""
|
|
926
|
-
Plot a scatter plot with points colored based on a 2D histogram.
|
|
914
|
+
"""Plot a scatter plot with points colored based on a 2D histogram.
|
|
927
915
|
|
|
928
916
|
Args:
|
|
929
917
|
v (array-like): Values for the x-axis.
|
|
@@ -956,9 +944,9 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
|
|
|
956
944
|
|
|
957
945
|
@plt_settings()
|
|
958
946
|
def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
|
|
959
|
-
"""
|
|
960
|
-
|
|
961
|
-
|
|
947
|
+
"""Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
|
|
948
|
+
key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
|
|
949
|
+
the plots.
|
|
962
950
|
|
|
963
951
|
Args:
|
|
964
952
|
csv_file (str, optional): Path to the CSV file containing the tuning results.
|
|
@@ -1025,8 +1013,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_p
|
|
|
1025
1013
|
|
|
1026
1014
|
@plt_settings()
|
|
1027
1015
|
def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
|
|
1028
|
-
"""
|
|
1029
|
-
Visualize feature maps of a given model module during inference.
|
|
1016
|
+
"""Visualize feature maps of a given model module during inference.
|
|
1030
1017
|
|
|
1031
1018
|
Args:
|
|
1032
1019
|
x (torch.Tensor): Features to be visualized.
|
ultralytics/utils/tal.py
CHANGED
|
@@ -10,8 +10,7 @@ from .torch_utils import TORCH_1_11
|
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class TaskAlignedAssigner(nn.Module):
|
|
13
|
-
"""
|
|
14
|
-
A task-aligned assigner for object detection.
|
|
13
|
+
"""A task-aligned assigner for object detection.
|
|
15
14
|
|
|
16
15
|
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
|
|
17
16
|
classification and localization information.
|
|
@@ -25,8 +24,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
25
24
|
"""
|
|
26
25
|
|
|
27
26
|
def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
|
|
28
|
-
"""
|
|
29
|
-
Initialize a TaskAlignedAssigner object with customizable hyperparameters.
|
|
27
|
+
"""Initialize a TaskAlignedAssigner object with customizable hyperparameters.
|
|
30
28
|
|
|
31
29
|
Args:
|
|
32
30
|
topk (int, optional): The number of top candidates to consider.
|
|
@@ -44,8 +42,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
44
42
|
|
|
45
43
|
@torch.no_grad()
|
|
46
44
|
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
|
47
|
-
"""
|
|
48
|
-
Compute the task-aligned assignment.
|
|
45
|
+
"""Compute the task-aligned assignment.
|
|
49
46
|
|
|
50
47
|
Args:
|
|
51
48
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -88,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
88
85
|
return tuple(t.to(device) for t in result)
|
|
89
86
|
|
|
90
87
|
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
|
91
|
-
"""
|
|
92
|
-
Compute the task-aligned assignment.
|
|
88
|
+
"""Compute the task-aligned assignment.
|
|
93
89
|
|
|
94
90
|
Args:
|
|
95
91
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -125,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
125
121
|
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
|
126
122
|
|
|
127
123
|
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
|
128
|
-
"""
|
|
129
|
-
Get positive mask for each ground truth box.
|
|
124
|
+
"""Get positive mask for each ground truth box.
|
|
130
125
|
|
|
131
126
|
Args:
|
|
132
127
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -139,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
139
134
|
Returns:
|
|
140
135
|
mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
|
|
141
136
|
align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
|
|
142
|
-
overlaps (torch.Tensor): Overlaps between predicted
|
|
137
|
+
overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
|
|
143
138
|
"""
|
|
144
139
|
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
|
|
145
140
|
# Get anchor_align metric, (b, max_num_obj, h*w)
|
|
@@ -152,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
152
147
|
return mask_pos, align_metric, overlaps
|
|
153
148
|
|
|
154
149
|
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
|
155
|
-
"""
|
|
156
|
-
Compute alignment metric given predicted and ground truth bounding boxes.
|
|
150
|
+
"""Compute alignment metric given predicted and ground truth bounding boxes.
|
|
157
151
|
|
|
158
152
|
Args:
|
|
159
153
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -186,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
186
180
|
return align_metric, overlaps
|
|
187
181
|
|
|
188
182
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
|
189
|
-
"""
|
|
190
|
-
Calculate IoU for horizontal bounding boxes.
|
|
183
|
+
"""Calculate IoU for horizontal bounding boxes.
|
|
191
184
|
|
|
192
185
|
Args:
|
|
193
186
|
gt_bboxes (torch.Tensor): Ground truth boxes.
|
|
@@ -199,14 +192,13 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
199
192
|
return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
|
200
193
|
|
|
201
194
|
def select_topk_candidates(self, metrics, topk_mask=None):
|
|
202
|
-
"""
|
|
203
|
-
Select the top-k candidates based on the given metrics.
|
|
195
|
+
"""Select the top-k candidates based on the given metrics.
|
|
204
196
|
|
|
205
197
|
Args:
|
|
206
198
|
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
|
|
207
199
|
the maximum number of objects, and h*w represents the total number of anchor points.
|
|
208
|
-
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
|
209
|
-
|
|
200
|
+
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
|
|
201
|
+
is the number of top candidates to consider. If not provided, the top-k values are automatically
|
|
210
202
|
computed based on the given metrics.
|
|
211
203
|
|
|
212
204
|
Returns:
|
|
@@ -231,18 +223,16 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
231
223
|
return count_tensor.to(metrics.dtype)
|
|
232
224
|
|
|
233
225
|
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
|
|
234
|
-
"""
|
|
235
|
-
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
|
226
|
+
"""Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
|
236
227
|
|
|
237
228
|
Args:
|
|
238
|
-
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
|
239
|
-
|
|
229
|
+
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
|
|
230
|
+
max_num_obj is the maximum number of objects.
|
|
240
231
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
|
241
|
-
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
(foreground) anchor points.
|
|
232
|
+
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
|
|
233
|
+
shape (b, h*w), where h*w is the total number of anchor points.
|
|
234
|
+
fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
|
|
235
|
+
points.
|
|
246
236
|
|
|
247
237
|
Returns:
|
|
248
238
|
target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
|
|
@@ -275,8 +265,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
275
265
|
|
|
276
266
|
@staticmethod
|
|
277
267
|
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
|
|
278
|
-
"""
|
|
279
|
-
Select positive anchor centers within ground truth bounding boxes.
|
|
268
|
+
"""Select positive anchor centers within ground truth bounding boxes.
|
|
280
269
|
|
|
281
270
|
Args:
|
|
282
271
|
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
|
|
@@ -286,9 +275,9 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
286
275
|
Returns:
|
|
287
276
|
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
|
|
288
277
|
|
|
289
|
-
|
|
290
|
-
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
|
|
291
|
-
Bounding box format: [x_min, y_min, x_max, y_max].
|
|
278
|
+
Notes:
|
|
279
|
+
- b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
|
|
280
|
+
- Bounding box format: [x_min, y_min, x_max, y_max].
|
|
292
281
|
"""
|
|
293
282
|
n_anchors = xy_centers.shape[0]
|
|
294
283
|
bs, n_boxes, _ = gt_bboxes.shape
|
|
@@ -298,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
298
287
|
|
|
299
288
|
@staticmethod
|
|
300
289
|
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
|
301
|
-
"""
|
|
302
|
-
Select anchor boxes with highest IoU when assigned to multiple ground truths.
|
|
290
|
+
"""Select anchor boxes with highest IoU when assigned to multiple ground truths.
|
|
303
291
|
|
|
304
292
|
Args:
|
|
305
293
|
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
|
|
@@ -336,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
|
|
336
324
|
|
|
337
325
|
@staticmethod
|
|
338
326
|
def select_candidates_in_gts(xy_centers, gt_bboxes):
|
|
339
|
-
"""
|
|
340
|
-
Select the positive anchor center in gt for rotated bounding boxes.
|
|
327
|
+
"""Select the positive anchor center in gt for rotated bounding boxes.
|
|
341
328
|
|
|
342
329
|
Args:
|
|
343
330
|
xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
|
|
@@ -396,8 +383,7 @@ def bbox2dist(anchor_points, bbox, reg_max):
|
|
|
396
383
|
|
|
397
384
|
|
|
398
385
|
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
399
|
-
"""
|
|
400
|
-
Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
|
386
|
+
"""Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
|
401
387
|
|
|
402
388
|
Args:
|
|
403
389
|
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
|