dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py
CHANGED
|
@@ -1,51 +1,65 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import torch
|
|
4
6
|
import torch.nn as nn
|
|
5
7
|
|
|
6
8
|
from . import LOGGER
|
|
7
9
|
from .metrics import bbox_iou, probiou
|
|
8
|
-
from .ops import xywhr2xyxyxyxy
|
|
10
|
+
from .ops import xywh2xyxy, xywhr2xyxyxyxy, xyxy2xywh
|
|
9
11
|
from .torch_utils import TORCH_1_11
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
class TaskAlignedAssigner(nn.Module):
|
|
13
|
-
"""
|
|
14
|
-
A task-aligned assigner for object detection.
|
|
15
|
+
"""A task-aligned assigner for object detection.
|
|
15
16
|
|
|
16
17
|
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
|
|
17
18
|
classification and localization information.
|
|
18
19
|
|
|
19
20
|
Attributes:
|
|
20
21
|
topk (int): The number of top candidates to consider.
|
|
22
|
+
topk2 (int): Secondary topk value for additional filtering.
|
|
21
23
|
num_classes (int): The number of object classes.
|
|
22
24
|
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
|
|
23
25
|
beta (float): The beta parameter for the localization component of the task-aligned metric.
|
|
26
|
+
stride (list): List of stride values for different feature levels.
|
|
24
27
|
eps (float): A small value to prevent division by zero.
|
|
25
28
|
"""
|
|
26
29
|
|
|
27
|
-
def __init__(
|
|
28
|
-
|
|
29
|
-
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
topk: int = 13,
|
|
33
|
+
num_classes: int = 80,
|
|
34
|
+
alpha: float = 1.0,
|
|
35
|
+
beta: float = 6.0,
|
|
36
|
+
stride: list = [8, 16, 32],
|
|
37
|
+
eps: float = 1e-9,
|
|
38
|
+
topk2=None,
|
|
39
|
+
):
|
|
40
|
+
"""Initialize a TaskAlignedAssigner object with customizable hyperparameters.
|
|
30
41
|
|
|
31
42
|
Args:
|
|
32
43
|
topk (int, optional): The number of top candidates to consider.
|
|
33
44
|
num_classes (int, optional): The number of object classes.
|
|
34
45
|
alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
|
|
35
46
|
beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
|
|
47
|
+
stride (list, optional): List of stride values for different feature levels.
|
|
36
48
|
eps (float, optional): A small value to prevent division by zero.
|
|
49
|
+
topk2 (int, optional): Secondary topk value for additional filtering.
|
|
37
50
|
"""
|
|
38
51
|
super().__init__()
|
|
39
52
|
self.topk = topk
|
|
53
|
+
self.topk2 = topk2 or topk
|
|
40
54
|
self.num_classes = num_classes
|
|
41
55
|
self.alpha = alpha
|
|
42
56
|
self.beta = beta
|
|
57
|
+
self.stride = stride
|
|
43
58
|
self.eps = eps
|
|
44
59
|
|
|
45
60
|
@torch.no_grad()
|
|
46
61
|
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
|
47
|
-
"""
|
|
48
|
-
Compute the task-aligned assignment.
|
|
62
|
+
"""Compute the task-aligned assignment.
|
|
49
63
|
|
|
50
64
|
Args:
|
|
51
65
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -80,16 +94,17 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
80
94
|
|
|
81
95
|
try:
|
|
82
96
|
return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
|
|
83
|
-
except
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
97
|
+
except RuntimeError as e:
|
|
98
|
+
if "out of memory" in str(e).lower():
|
|
99
|
+
# Move tensors to CPU, compute, then move back to original device
|
|
100
|
+
LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
|
|
101
|
+
cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
|
|
102
|
+
result = self._forward(*cpu_tensors)
|
|
103
|
+
return tuple(t.to(device) for t in result)
|
|
104
|
+
raise
|
|
89
105
|
|
|
90
106
|
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
|
91
|
-
"""
|
|
92
|
-
Compute the task-aligned assignment.
|
|
107
|
+
"""Compute the task-aligned assignment.
|
|
93
108
|
|
|
94
109
|
Args:
|
|
95
110
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -110,7 +125,9 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
110
125
|
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
|
|
111
126
|
)
|
|
112
127
|
|
|
113
|
-
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(
|
|
128
|
+
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(
|
|
129
|
+
mask_pos, overlaps, self.n_max_boxes, align_metric
|
|
130
|
+
)
|
|
114
131
|
|
|
115
132
|
# Assigned target
|
|
116
133
|
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
|
|
@@ -125,8 +142,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
125
142
|
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
|
126
143
|
|
|
127
144
|
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
|
128
|
-
"""
|
|
129
|
-
Get positive mask for each ground truth box.
|
|
145
|
+
"""Get positive mask for each ground truth box.
|
|
130
146
|
|
|
131
147
|
Args:
|
|
132
148
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -139,9 +155,9 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
139
155
|
Returns:
|
|
140
156
|
mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
|
|
141
157
|
align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
|
|
142
|
-
overlaps (torch.Tensor): Overlaps between predicted
|
|
158
|
+
overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
|
|
143
159
|
"""
|
|
144
|
-
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
|
|
160
|
+
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes, mask_gt)
|
|
145
161
|
# Get anchor_align metric, (b, max_num_obj, h*w)
|
|
146
162
|
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
|
|
147
163
|
# Get topk_metric mask, (b, max_num_obj, h*w)
|
|
@@ -152,8 +168,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
152
168
|
return mask_pos, align_metric, overlaps
|
|
153
169
|
|
|
154
170
|
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
|
155
|
-
"""
|
|
156
|
-
Compute alignment metric given predicted and ground truth bounding boxes.
|
|
171
|
+
"""Compute alignment metric given predicted and ground truth bounding boxes.
|
|
157
172
|
|
|
158
173
|
Args:
|
|
159
174
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -186,8 +201,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
186
201
|
return align_metric, overlaps
|
|
187
202
|
|
|
188
203
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
|
189
|
-
"""
|
|
190
|
-
Calculate IoU for horizontal bounding boxes.
|
|
204
|
+
"""Calculate IoU for horizontal bounding boxes.
|
|
191
205
|
|
|
192
206
|
Args:
|
|
193
207
|
gt_bboxes (torch.Tensor): Ground truth boxes.
|
|
@@ -199,14 +213,13 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
199
213
|
return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
|
200
214
|
|
|
201
215
|
def select_topk_candidates(self, metrics, topk_mask=None):
|
|
202
|
-
"""
|
|
203
|
-
Select the top-k candidates based on the given metrics.
|
|
216
|
+
"""Select the top-k candidates based on the given metrics.
|
|
204
217
|
|
|
205
218
|
Args:
|
|
206
219
|
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
|
|
207
220
|
the maximum number of objects, and h*w represents the total number of anchor points.
|
|
208
|
-
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
|
209
|
-
|
|
221
|
+
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
|
|
222
|
+
is the number of top candidates to consider. If not provided, the top-k values are automatically
|
|
210
223
|
computed based on the given metrics.
|
|
211
224
|
|
|
212
225
|
Returns:
|
|
@@ -231,18 +244,16 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
231
244
|
return count_tensor.to(metrics.dtype)
|
|
232
245
|
|
|
233
246
|
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
|
|
234
|
-
"""
|
|
235
|
-
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
|
247
|
+
"""Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
|
236
248
|
|
|
237
249
|
Args:
|
|
238
|
-
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
|
239
|
-
|
|
250
|
+
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
|
|
251
|
+
max_num_obj is the maximum number of objects.
|
|
240
252
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
|
241
|
-
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
(foreground) anchor points.
|
|
253
|
+
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
|
|
254
|
+
shape (b, h*w), where h*w is the total number of anchor points.
|
|
255
|
+
fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
|
|
256
|
+
points.
|
|
246
257
|
|
|
247
258
|
Returns:
|
|
248
259
|
target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
|
|
@@ -273,38 +284,42 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
273
284
|
|
|
274
285
|
return target_labels, target_bboxes, target_scores
|
|
275
286
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
"""
|
|
279
|
-
Select positive anchor centers within ground truth bounding boxes.
|
|
287
|
+
def select_candidates_in_gts(self, xy_centers, gt_bboxes, mask_gt, eps=1e-9):
|
|
288
|
+
"""Select positive anchor centers within ground truth bounding boxes.
|
|
280
289
|
|
|
281
290
|
Args:
|
|
282
291
|
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
|
|
283
292
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
|
|
293
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes, shape (b, n_boxes, 1).
|
|
284
294
|
eps (float, optional): Small value for numerical stability.
|
|
285
295
|
|
|
286
296
|
Returns:
|
|
287
297
|
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
|
|
288
298
|
|
|
289
|
-
|
|
290
|
-
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
|
|
291
|
-
Bounding box format: [x_min, y_min, x_max, y_max].
|
|
299
|
+
Notes:
|
|
300
|
+
- b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
|
|
301
|
+
- Bounding box format: [x_min, y_min, x_max, y_max].
|
|
292
302
|
"""
|
|
303
|
+
gt_bboxes_xywh = xyxy2xywh(gt_bboxes)
|
|
304
|
+
wh_mask = gt_bboxes_xywh[..., 2:] < self.stride[0] # the smallest stride
|
|
305
|
+
stride_val = torch.tensor(self.stride[1], dtype=gt_bboxes_xywh.dtype, device=gt_bboxes_xywh.device)
|
|
306
|
+
gt_bboxes_xywh[..., 2:] = torch.where((wh_mask * mask_gt).bool(), stride_val, gt_bboxes_xywh[..., 2:])
|
|
307
|
+
gt_bboxes = xywh2xyxy(gt_bboxes_xywh)
|
|
308
|
+
|
|
293
309
|
n_anchors = xy_centers.shape[0]
|
|
294
310
|
bs, n_boxes, _ = gt_bboxes.shape
|
|
295
311
|
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
|
|
296
312
|
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
|
|
297
313
|
return bbox_deltas.amin(3).gt_(eps)
|
|
298
314
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
"""
|
|
302
|
-
Select anchor boxes with highest IoU when assigned to multiple ground truths.
|
|
315
|
+
def select_highest_overlaps(self, mask_pos, overlaps, n_max_boxes, align_metric):
|
|
316
|
+
"""Select anchor boxes with highest IoU when assigned to multiple ground truths.
|
|
303
317
|
|
|
304
318
|
Args:
|
|
305
319
|
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
|
|
306
320
|
overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
|
|
307
321
|
n_max_boxes (int): Maximum number of ground truth boxes.
|
|
322
|
+
align_metric (torch.Tensor): Alignment metric for selecting best matches.
|
|
308
323
|
|
|
309
324
|
Returns:
|
|
310
325
|
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
|
|
@@ -315,12 +330,20 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
315
330
|
fg_mask = mask_pos.sum(-2)
|
|
316
331
|
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
|
|
317
332
|
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
|
|
318
|
-
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
|
|
319
333
|
|
|
334
|
+
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
|
|
320
335
|
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
|
|
321
336
|
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
|
|
322
|
-
|
|
323
337
|
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
|
|
338
|
+
|
|
339
|
+
fg_mask = mask_pos.sum(-2)
|
|
340
|
+
|
|
341
|
+
if self.topk2 != self.topk:
|
|
342
|
+
align_metric = align_metric * mask_pos # update overlaps
|
|
343
|
+
max_overlaps_idx = torch.topk(align_metric, self.topk2, dim=-1, largest=True).indices # (b, n_max_boxes)
|
|
344
|
+
topk_idx = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) # update mask_pos
|
|
345
|
+
topk_idx.scatter_(-1, max_overlaps_idx, 1.0)
|
|
346
|
+
mask_pos *= topk_idx
|
|
324
347
|
fg_mask = mask_pos.sum(-2)
|
|
325
348
|
# Find each grid serve which gt(index)
|
|
326
349
|
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
|
|
@@ -335,13 +358,14 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
|
|
335
358
|
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
|
|
336
359
|
|
|
337
360
|
@staticmethod
|
|
338
|
-
def select_candidates_in_gts(xy_centers, gt_bboxes):
|
|
339
|
-
"""
|
|
340
|
-
Select the positive anchor center in gt for rotated bounding boxes.
|
|
361
|
+
def select_candidates_in_gts(xy_centers, gt_bboxes, mask_gt):
|
|
362
|
+
"""Select the positive anchor center in gt for rotated bounding boxes.
|
|
341
363
|
|
|
342
364
|
Args:
|
|
343
365
|
xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
|
|
344
366
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
|
|
367
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (b, n_boxes, 1).
|
|
368
|
+
stride (list[int]): List of stride values for each feature map level.
|
|
345
369
|
|
|
346
370
|
Returns:
|
|
347
371
|
(torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
|
|
@@ -367,7 +391,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
|
|
|
367
391
|
anchor_points, stride_tensor = [], []
|
|
368
392
|
assert feats is not None
|
|
369
393
|
dtype, device = feats[0].dtype, feats[0].device
|
|
370
|
-
for i
|
|
394
|
+
for i in range(len(feats)): # use len(feats) to avoid TracerWarning from iterating over strides tensor
|
|
395
|
+
stride = strides[i]
|
|
371
396
|
h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
|
|
372
397
|
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
|
|
373
398
|
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
|
|
@@ -389,15 +414,17 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
|
|
389
414
|
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
|
|
390
415
|
|
|
391
416
|
|
|
392
|
-
def bbox2dist(anchor_points, bbox, reg_max):
|
|
417
|
+
def bbox2dist(anchor_points: torch.Tensor, bbox: torch.Tensor, reg_max: int | None = None) -> torch.Tensor:
|
|
393
418
|
"""Transform bbox(xyxy) to dist(ltrb)."""
|
|
394
419
|
x1y1, x2y2 = bbox.chunk(2, -1)
|
|
395
|
-
|
|
420
|
+
dist = torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1)
|
|
421
|
+
if reg_max is not None:
|
|
422
|
+
dist = dist.clamp_(0, reg_max - 0.01) # dist (lt, rb)
|
|
423
|
+
return dist
|
|
396
424
|
|
|
397
425
|
|
|
398
426
|
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
399
|
-
"""
|
|
400
|
-
Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
|
427
|
+
"""Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
|
401
428
|
|
|
402
429
|
Args:
|
|
403
430
|
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
|
|
@@ -415,3 +442,42 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
|
415
442
|
x, y = xf * cos - yf * sin, xf * sin + yf * cos
|
|
416
443
|
xy = torch.cat([x, y], dim=dim) + anchor_points
|
|
417
444
|
return torch.cat([xy, lt + rb], dim=dim)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def rbox2dist(
|
|
448
|
+
target_bboxes: torch.Tensor,
|
|
449
|
+
anchor_points: torch.Tensor,
|
|
450
|
+
target_angle: torch.Tensor,
|
|
451
|
+
dim: int = -1,
|
|
452
|
+
reg_max: int | None = None,
|
|
453
|
+
):
|
|
454
|
+
"""Decode rotated bounding box (xywh) to distance(ltrb). This is the inverse of dist2rbox.
|
|
455
|
+
|
|
456
|
+
Args:
|
|
457
|
+
target_bboxes (torch.Tensor): Target rotated bounding boxes with shape (bs, h*w, 4), format [x, y, w, h].
|
|
458
|
+
anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
|
|
459
|
+
target_angle (torch.Tensor): Target angle with shape (bs, h*w, 1).
|
|
460
|
+
dim (int, optional): Dimension along which to split.
|
|
461
|
+
reg_max (int, optional): Maximum regression value for clamping.
|
|
462
|
+
|
|
463
|
+
Returns:
|
|
464
|
+
(torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4), format [l, t, r, b].
|
|
465
|
+
"""
|
|
466
|
+
xy, wh = target_bboxes.split(2, dim=dim)
|
|
467
|
+
offset = xy - anchor_points # (bs, h*w, 2)
|
|
468
|
+
offset_x, offset_y = offset.split(1, dim=dim)
|
|
469
|
+
cos, sin = torch.cos(target_angle), torch.sin(target_angle)
|
|
470
|
+
xf = offset_x * cos + offset_y * sin
|
|
471
|
+
yf = -offset_x * sin + offset_y * cos
|
|
472
|
+
|
|
473
|
+
w, h = wh.split(1, dim=dim)
|
|
474
|
+
target_l = w / 2 - xf
|
|
475
|
+
target_t = h / 2 - yf
|
|
476
|
+
target_r = w / 2 + xf
|
|
477
|
+
target_b = h / 2 + yf
|
|
478
|
+
|
|
479
|
+
dist = torch.cat([target_l, target_t, target_r, target_b], dim=dim)
|
|
480
|
+
if reg_max is not None:
|
|
481
|
+
dist = dist.clamp_(0, reg_max - 0.01)
|
|
482
|
+
|
|
483
|
+
return dist
|