dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py
CHANGED
|
@@ -3,17 +3,14 @@
|
|
|
3
3
|
import torch
|
|
4
4
|
import torch.nn as nn
|
|
5
5
|
|
|
6
|
-
from . import LOGGER
|
|
7
|
-
from .checks import check_version
|
|
6
|
+
from . import LOGGER
|
|
8
7
|
from .metrics import bbox_iou, probiou
|
|
9
8
|
from .ops import xywhr2xyxyxyxy
|
|
10
|
-
|
|
11
|
-
TORCH_1_10 = check_version(TORCH_VERSION, "1.10.0")
|
|
9
|
+
from .torch_utils import TORCH_1_11
|
|
12
10
|
|
|
13
11
|
|
|
14
12
|
class TaskAlignedAssigner(nn.Module):
|
|
15
|
-
"""
|
|
16
|
-
A task-aligned assigner for object detection.
|
|
13
|
+
"""A task-aligned assigner for object detection.
|
|
17
14
|
|
|
18
15
|
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
|
|
19
16
|
classification and localization information.
|
|
@@ -27,8 +24,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
27
24
|
"""
|
|
28
25
|
|
|
29
26
|
def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
|
|
30
|
-
"""
|
|
31
|
-
Initialize a TaskAlignedAssigner object with customizable hyperparameters.
|
|
27
|
+
"""Initialize a TaskAlignedAssigner object with customizable hyperparameters.
|
|
32
28
|
|
|
33
29
|
Args:
|
|
34
30
|
topk (int, optional): The number of top candidates to consider.
|
|
@@ -46,8 +42,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
46
42
|
|
|
47
43
|
@torch.no_grad()
|
|
48
44
|
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
|
49
|
-
"""
|
|
50
|
-
Compute the task-aligned assignment.
|
|
45
|
+
"""Compute the task-aligned assignment.
|
|
51
46
|
|
|
52
47
|
Args:
|
|
53
48
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -90,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
90
85
|
return tuple(t.to(device) for t in result)
|
|
91
86
|
|
|
92
87
|
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
|
93
|
-
"""
|
|
94
|
-
Compute the task-aligned assignment.
|
|
88
|
+
"""Compute the task-aligned assignment.
|
|
95
89
|
|
|
96
90
|
Args:
|
|
97
91
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -127,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
127
121
|
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
|
128
122
|
|
|
129
123
|
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
|
130
|
-
"""
|
|
131
|
-
Get positive mask for each ground truth box.
|
|
124
|
+
"""Get positive mask for each ground truth box.
|
|
132
125
|
|
|
133
126
|
Args:
|
|
134
127
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -141,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
141
134
|
Returns:
|
|
142
135
|
mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
|
|
143
136
|
align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
|
|
144
|
-
overlaps (torch.Tensor): Overlaps between predicted
|
|
137
|
+
overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
|
|
145
138
|
"""
|
|
146
139
|
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
|
|
147
140
|
# Get anchor_align metric, (b, max_num_obj, h*w)
|
|
@@ -154,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
154
147
|
return mask_pos, align_metric, overlaps
|
|
155
148
|
|
|
156
149
|
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
|
157
|
-
"""
|
|
158
|
-
Compute alignment metric given predicted and ground truth bounding boxes.
|
|
150
|
+
"""Compute alignment metric given predicted and ground truth bounding boxes.
|
|
159
151
|
|
|
160
152
|
Args:
|
|
161
153
|
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
|
@@ -188,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
188
180
|
return align_metric, overlaps
|
|
189
181
|
|
|
190
182
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
|
191
|
-
"""
|
|
192
|
-
Calculate IoU for horizontal bounding boxes.
|
|
183
|
+
"""Calculate IoU for horizontal bounding boxes.
|
|
193
184
|
|
|
194
185
|
Args:
|
|
195
186
|
gt_bboxes (torch.Tensor): Ground truth boxes.
|
|
@@ -201,14 +192,13 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
201
192
|
return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
|
202
193
|
|
|
203
194
|
def select_topk_candidates(self, metrics, topk_mask=None):
|
|
204
|
-
"""
|
|
205
|
-
Select the top-k candidates based on the given metrics.
|
|
195
|
+
"""Select the top-k candidates based on the given metrics.
|
|
206
196
|
|
|
207
197
|
Args:
|
|
208
198
|
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
|
|
209
199
|
the maximum number of objects, and h*w represents the total number of anchor points.
|
|
210
|
-
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
|
211
|
-
|
|
200
|
+
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
|
|
201
|
+
is the number of top candidates to consider. If not provided, the top-k values are automatically
|
|
212
202
|
computed based on the given metrics.
|
|
213
203
|
|
|
214
204
|
Returns:
|
|
@@ -233,18 +223,16 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
233
223
|
return count_tensor.to(metrics.dtype)
|
|
234
224
|
|
|
235
225
|
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
|
|
236
|
-
"""
|
|
237
|
-
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
|
226
|
+
"""Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
|
238
227
|
|
|
239
228
|
Args:
|
|
240
|
-
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
|
241
|
-
|
|
229
|
+
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
|
|
230
|
+
max_num_obj is the maximum number of objects.
|
|
242
231
|
gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
|
243
|
-
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
(foreground) anchor points.
|
|
232
|
+
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
|
|
233
|
+
shape (b, h*w), where h*w is the total number of anchor points.
|
|
234
|
+
fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
|
|
235
|
+
points.
|
|
248
236
|
|
|
249
237
|
Returns:
|
|
250
238
|
target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
|
|
@@ -277,8 +265,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
277
265
|
|
|
278
266
|
@staticmethod
|
|
279
267
|
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
|
|
280
|
-
"""
|
|
281
|
-
Select positive anchor centers within ground truth bounding boxes.
|
|
268
|
+
"""Select positive anchor centers within ground truth bounding boxes.
|
|
282
269
|
|
|
283
270
|
Args:
|
|
284
271
|
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
|
|
@@ -288,9 +275,9 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
288
275
|
Returns:
|
|
289
276
|
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
|
|
290
277
|
|
|
291
|
-
|
|
292
|
-
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
|
|
293
|
-
Bounding box format: [x_min, y_min, x_max, y_max].
|
|
278
|
+
Notes:
|
|
279
|
+
- b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
|
|
280
|
+
- Bounding box format: [x_min, y_min, x_max, y_max].
|
|
294
281
|
"""
|
|
295
282
|
n_anchors = xy_centers.shape[0]
|
|
296
283
|
bs, n_boxes, _ = gt_bboxes.shape
|
|
@@ -300,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
|
300
287
|
|
|
301
288
|
@staticmethod
|
|
302
289
|
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
|
303
|
-
"""
|
|
304
|
-
Select anchor boxes with highest IoU when assigned to multiple ground truths.
|
|
290
|
+
"""Select anchor boxes with highest IoU when assigned to multiple ground truths.
|
|
305
291
|
|
|
306
292
|
Args:
|
|
307
293
|
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
|
|
@@ -338,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
|
|
338
324
|
|
|
339
325
|
@staticmethod
|
|
340
326
|
def select_candidates_in_gts(xy_centers, gt_bboxes):
|
|
341
|
-
"""
|
|
342
|
-
Select the positive anchor center in gt for rotated bounding boxes.
|
|
327
|
+
"""Select the positive anchor center in gt for rotated bounding boxes.
|
|
343
328
|
|
|
344
329
|
Args:
|
|
345
330
|
xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
|
|
@@ -373,7 +358,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
|
|
|
373
358
|
h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
|
|
374
359
|
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
|
|
375
360
|
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
|
|
376
|
-
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if
|
|
361
|
+
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
|
|
377
362
|
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
|
|
378
363
|
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
|
379
364
|
return torch.cat(anchor_points), torch.cat(stride_tensor)
|
|
@@ -398,8 +383,7 @@ def bbox2dist(anchor_points, bbox, reg_max):
|
|
|
398
383
|
|
|
399
384
|
|
|
400
385
|
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
401
|
-
"""
|
|
402
|
-
Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
|
386
|
+
"""Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
|
403
387
|
|
|
404
388
|
Args:
|
|
405
389
|
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
|