dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
ultralytics/utils/tal.py
ADDED
@@ -0,0 +1,416 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import torch
|
4
|
+
import torch.nn as nn
|
5
|
+
|
6
|
+
from . import LOGGER
|
7
|
+
from .checks import check_version
|
8
|
+
from .metrics import bbox_iou, probiou
|
9
|
+
from .ops import xywhr2xyxyxyxy
|
10
|
+
|
11
|
+
# Version gate read by make_anchors to choose between torch.meshgrid(..., indexing="ij") and the legacy call
# (assumes check_version returns True when torch.__version__ >= "1.10.0" — TODO confirm against utils.checks).
TORCH_1_10 = check_version(torch.__version__, "1.10.0")
|
12
|
+
|
13
|
+
|
14
|
+
class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        bg_idx (int): Background class index.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
        bs (int): Batch size of the most recent `forward` call (overwritten on every call).
        n_max_boxes (int): Padded number of gt boxes per image in the most recent `forward` call.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): Target labels (long) with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Foreground mask (bool) with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): Target ground truth indices (long) with shape (bs, num_total_anchors).

        References:
            https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]
        device = gt_bboxes.device

        if self.n_max_boxes == 0:
            # No gt boxes in the batch: every anchor is background. Use the same dtypes `_forward` produces
            # (long labels/indices, bool foreground mask) so callers see one consistent output contract.
            return (
                torch.full_like(pd_scores[..., 0], self.bg_idx, dtype=torch.long),
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(pd_scores[..., 0], dtype=torch.bool),
                torch.zeros_like(pd_scores[..., 0], dtype=torch.long),
            )

        # `torch.cuda.OutOfMemoryError` was added in torch 1.13. On older versions an empty tuple catches nothing,
        # instead of raising AttributeError (and hiding the real error) while the `except` clause is evaluated.
        oom_error = getattr(torch.cuda, "OutOfMemoryError", ())
        try:
            return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
        except oom_error:
            # Move tensors to CPU, compute, then move back to original device
            LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
            cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
            result = self._forward(*cpu_tensors)
            return tuple(t.to(device) for t in result)

    def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
        """
        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize: rescale each gt's metrics so its best anchor gets that anchor's max IoU
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """
        Get positive mask for each ground truth box.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            mask_pos (torch.Tensor): Positive mask (0/1 values) with shape (bs, max_num_obj, h*w).
            align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
            overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
        """
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        # Get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        # Get topk_metric mask, (b, max_num_obj, h*w)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        # Merge all mask to a final mask, (b, max_num_obj, h*w)
        mask_pos = mask_topk * mask_in_gts * mask_gt

        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """
        Compute alignment metric given predicted and ground truth bounding boxes.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).

        Returns:
            align_metric (torch.Tensor): Alignment metric combining classification and localization.
            overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
        """
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        # Build the (batch, class) gather indices on the score device, avoiding a device->host copy of gt_labels.
        # Both are (b, max_num_obj); labels are truncated to long exactly as a copy into a long tensor would do.
        batch_idx = torch.arange(self.bs, device=pd_scores.device).view(-1, 1).expand(-1, self.n_max_boxes)
        cls_idx = gt_labels.squeeze(-1).to(device=pd_scores.device, dtype=torch.long)
        # Get the scores of each grid for each gt cls
        bbox_scores[mask_gt] = pd_scores[batch_idx, :, cls_idx][mask_gt]  # b, max_num_obj, h*w

        # Broadcast to (b, max_num_obj, h*w, 4) and keep only masked pairs -> (num_masked, 4) each
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """
        Calculate IoU for horizontal bounding boxes.

        Args:
            gt_bboxes (torch.Tensor): Ground truth boxes.
            pd_bboxes (torch.Tensor): Predicted boxes.

        Returns:
            (torch.Tensor): IoU values between each pair of boxes.
        """
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """
        Select the top-k candidates based on the given metrics.

        Args:
            metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
                max_num_obj is the maximum number of objects, and h*w represents the
                total number of anchor points.
            largest (bool): If True, select the largest values; otherwise, select the smallest values.
            topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
                topk is the number of top candidates to consider. If not provided,
                the top-k values are automatically computed based on the given metrics.

        Returns:
            (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
        """
        # (b, max_num_obj, topk)
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        # (b, max_num_obj, topk); masked-out slots all point at anchor 0 and are removed below
        topk_idxs.masked_fill_(~topk_mask, 0)

        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            # Expand topk_idxs for each value of k and add 1 at the specified positions
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # Filter invalid bboxes (anchors hit more than once only arise from the masked-out slots)
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """
        Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                batch size and max_num_obj is the maximum number of objects.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
                anchor points, with shape (b, h*w), where h*w is the total
                number of anchor points.
            fg_mask (torch.Tensor): A tensor of shape (b, h*w) whose positive entries mark the
                (foreground) anchor points.

        Returns:
            target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
            target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
                anchor points.
            target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
                anchor points.
        """
        # Assigned target labels, (b, 1)
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w), indices into the flattened batch
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores
        target_labels.clamp_(0)

        # 10x faster than F.one_hot()
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        # Zero the one-hot rows of background anchors; the (b, h*w, 1) condition broadcasts over classes
        target_scores = torch.where(fg_mask[:, :, None] > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select positive anchor centers within ground truth bounding boxes.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
            eps (float, optional): Small value for numerical stability. Defaults to 1e-9.

        Returns:
            (torch.Tensor): Mask of positive anchors, shape (b, n_boxes, h*w). Values are 1.0/0.0 in the input's
                floating dtype, because the in-place `gt_` keeps the dtype of the tensor it modifies.

        Note:
            b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
            Bounding box format: [x_min, y_min, x_max, y_max].
        """
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        # An anchor is inside a box when its smallest distance to any side is strictly positive
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        Select anchor boxes with highest IoU when assigned to multiple ground truths.

        Args:
            mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
            overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
            n_max_boxes (int): Maximum number of ground truth boxes.

        Returns:
            target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
            fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
            mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
        """
        # Convert (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            # For contested anchors keep only the gt with the highest IoU
            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)
        # Find each grid serve which gt(index)
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos
|
327
|
+
|
328
|
+
|
329
|
+
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """
        Calculate IoU for rotated bounding boxes.

        Args:
            gt_bboxes (torch.Tensor): Ground truth rotated boxes.
            pd_bboxes (torch.Tensor): Predicted rotated boxes.

        Returns:
            (torch.Tensor): probiou values after `squeeze(-1)`, clamped in place to be non-negative.
        """
        iou = probiou(gt_bboxes, pd_bboxes)
        return iou.squeeze(-1).clamp_(min=0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Select the positive anchor center in gt for rotated bounding boxes.

        A point P is inside rectangle ABCD when its projections onto edges AB and AD both lie within the edge:
        0 <= AP.AB <= AB.AB and 0 <= AP.AD <= AD.AD.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).

        Returns:
            (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
        corners = xywhr2xyxyxyxy(gt_bboxes)
        # Each piece is (b, n_boxes, 1, 2). Corner 2 is not needed, assuming xywhr2xyxyxyxy returns the corners
        # in consecutive order so corners 1 and 3 are the neighbours of corner 0 (TODO confirm in utils.ops).
        origin, corner_b, _, corner_d = corners.split(1, dim=-2)
        edge_ab = corner_b - origin
        edge_ad = corner_d - origin

        # Vector from the origin corner to every anchor, (b, n_boxes, h*w, 2)
        to_point = xy_centers - origin
        sq_len_ab = (edge_ab * edge_ab).sum(dim=-1)
        sq_len_ad = (edge_ad * edge_ad).sum(dim=-1)
        proj_ab = (to_point * edge_ab).sum(dim=-1)
        proj_ad = (to_point * edge_ad).sum(dim=-1)

        inside_ab = proj_ab.ge(0) & proj_ab.le(sq_len_ab)
        inside_ad = proj_ad.ge(0) & proj_ad.le(sq_len_ad)
        return inside_ab & inside_ad  # is_in_box
|
362
|
+
|
363
|
+
|
364
|
+
def make_anchors(feats, strides, grid_cell_offset=0.5):
    """
    Generate anchor points and per-anchor strides for every feature level.

    Args:
        feats (list[torch.Tensor] | torch.Tensor): Per-level inputs. When a list, each level's grid height and width
            are read from `shape[2:]`; otherwise `feats[i][0]` and `feats[i][1]` are used as height and width.
        strides (Iterable): Stride value for each level, consumed in the same order as `feats`.
        grid_cell_offset (float): Offset added to the integer grid coordinates (0.5 puts points at cell centres).

    Returns:
        (torch.Tensor): Anchor points of shape (N, 2) in (x, y) order, concatenated over levels.
        (torch.Tensor): Strides of shape (N, 1), one per anchor point.
    """
    assert feats is not None
    reference = feats[0]
    dtype, device = reference.dtype, reference.device
    points, stride_columns = [], []
    for level, stride in enumerate(strides):
        if isinstance(feats, list):
            h, w = feats[level].shape[2:]
        else:
            h, w = int(feats[level][0]), int(feats[level][1])
        xs = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        ys = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij") if TORCH_1_10 else torch.meshgrid(ys, xs)
        points.append(torch.stack((grid_x, grid_y), -1).view(-1, 2))
        stride_columns.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(points), torch.cat(stride_columns)
|
377
|
+
|
378
|
+
|
379
|
+
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """
    Convert (left, top, right, bottom) distances around anchor points into boxes.

    Args:
        distance (torch.Tensor): Distances whose `dim` axis holds (l, t, r, b).
        anchor_points (torch.Tensor): Anchor (x, y) coordinates broadcastable against each half of `distance`.
        xywh (bool): Return (cx, cy, w, h) when True, otherwise (x1, y1, x2, y2).
        dim (int): Axis along which the distances are stored and the result is concatenated.

    Returns:
        (torch.Tensor): Boxes in the requested format.
    """
    left_top, right_bottom = distance.chunk(2, dim)
    corner_min = anchor_points - left_top
    corner_max = anchor_points + right_bottom
    if not xywh:
        return torch.cat((corner_min, corner_max), dim)  # xyxy bbox
    center = (corner_min + corner_max) / 2
    size = corner_max - corner_min
    return torch.cat((center, size), dim)  # xywh bbox
|
389
|
+
|
390
|
+
|
391
|
+
def bbox2dist(anchor_points, bbox, reg_max):
    """
    Convert xyxy boxes into (left, top, right, bottom) distances from anchor points.

    Args:
        anchor_points (torch.Tensor): Anchor (x, y) coordinates broadcastable against each half of `bbox`.
        bbox (torch.Tensor): Boxes in (x1, y1, x2, y2) format along the last axis.
        reg_max (int | float): Upper bound for the distances; results are clamped to [0, reg_max - 0.01].

    Returns:
        (torch.Tensor): Distances (l, t, r, b) along the last axis.
    """
    corner_min, corner_max = bbox.chunk(2, -1)
    distances = torch.cat((anchor_points - corner_min, corner_max - anchor_points), -1)
    return distances.clamp_(0, reg_max - 0.01)  # dist (lt, rb)
|
395
|
+
|
396
|
+
|
397
|
+
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode predicted rotated bounding box coordinates from anchor points and distribution.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
        dim (int, optional): Dimension along which to split. Defaults to -1.

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4), as (cx, cy, w, h); the angle is
            not included.
    """
    left_top, right_bottom = pred_dist.split(2, dim=dim)
    cos_a = torch.cos(pred_angle)
    sin_a = torch.sin(pred_angle)
    # Half the difference between the (r, b) and (l, t) distances, one (bs, h*w, 1) tensor per axis
    dx, dy = ((right_bottom - left_top) / 2).split(1, dim=dim)
    # Rotate that offset by the predicted angle before adding it to the anchor
    rot_x = dx * cos_a - dy * sin_a
    rot_y = dx * sin_a + dy * cos_a
    center = torch.cat([rot_x, rot_y], dim=dim) + anchor_points
    size = left_top + right_bottom
    return torch.cat([center, size], dim=dim)
|