ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
import torch
|
4
4
|
import torch.nn as nn
|
5
5
|
|
6
|
+
from . import LOGGER
|
6
7
|
from .checks import check_version
|
7
8
|
from .metrics import bbox_iou, probiou
|
8
9
|
from .ops import xywhr2xyxyxyxy
|
@@ -58,17 +59,46 @@ class TaskAlignedAssigner(nn.Module):
|
|
58
59
|
"""
|
59
60
|
self.bs = pd_scores.shape[0]
|
60
61
|
self.n_max_boxes = gt_bboxes.shape[1]
|
62
|
+
device = gt_bboxes.device
|
61
63
|
|
62
64
|
if self.n_max_boxes == 0:
|
63
|
-
device = gt_bboxes.device
|
64
65
|
return (
|
65
|
-
torch.full_like(pd_scores[..., 0], self.bg_idx)
|
66
|
-
torch.zeros_like(pd_bboxes)
|
67
|
-
torch.zeros_like(pd_scores)
|
68
|
-
torch.zeros_like(pd_scores[..., 0])
|
69
|
-
torch.zeros_like(pd_scores[..., 0])
|
66
|
+
torch.full_like(pd_scores[..., 0], self.bg_idx),
|
67
|
+
torch.zeros_like(pd_bboxes),
|
68
|
+
torch.zeros_like(pd_scores),
|
69
|
+
torch.zeros_like(pd_scores[..., 0]),
|
70
|
+
torch.zeros_like(pd_scores[..., 0]),
|
70
71
|
)
|
71
72
|
|
73
|
+
try:
|
74
|
+
return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
|
75
|
+
except torch.OutOfMemoryError:
|
76
|
+
# Move tensors to CPU, compute, then move back to original device
|
77
|
+
LOGGER.warning("WARNING: CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
|
78
|
+
cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
|
79
|
+
result = self._forward(*cpu_tensors)
|
80
|
+
return tuple(t.to(device) for t in result)
|
81
|
+
|
82
|
+
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
83
|
+
"""
|
84
|
+
Compute the task-aligned assignment. Reference code is available at
|
85
|
+
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
|
86
|
+
|
87
|
+
Args:
|
88
|
+
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
89
|
+
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
90
|
+
anc_points (Tensor): shape(num_total_anchors, 2)
|
91
|
+
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
|
92
|
+
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
|
93
|
+
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
|
94
|
+
|
95
|
+
Returns:
|
96
|
+
target_labels (Tensor): shape(bs, num_total_anchors)
|
97
|
+
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
98
|
+
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
99
|
+
fg_mask (Tensor): shape(bs, num_total_anchors)
|
100
|
+
target_gt_idx (Tensor): shape(bs, num_total_anchors)
|
101
|
+
"""
|
72
102
|
mask_pos, align_metric, overlaps = self.get_pos_mask(
|
73
103
|
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
|
74
104
|
)
|
@@ -140,7 +170,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
140
170
|
Returns:
|
141
171
|
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
142
172
|
"""
|
143
|
-
|
144
173
|
# (b, max_num_obj, topk)
|
145
174
|
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
|
146
175
|
if topk_mask is None:
|
@@ -184,7 +213,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
184
213
|
for positive anchor points, where num_classes is the number
|
185
214
|
of object classes.
|
186
215
|
"""
|
187
|
-
|
188
216
|
# Assigned target labels, (b, 1)
|
189
217
|
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
190
218
|
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
|
@@ -212,14 +240,19 @@ class TaskAlignedAssigner(nn.Module):
|
|
212
240
|
@staticmethod
|
213
241
|
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
|
214
242
|
"""
|
215
|
-
Select
|
243
|
+
Select positive anchor centers within ground truth bounding boxes.
|
216
244
|
|
217
245
|
Args:
|
218
|
-
xy_centers (Tensor): shape(h*w, 2)
|
219
|
-
gt_bboxes (Tensor): shape(b, n_boxes, 4)
|
246
|
+
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
|
247
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
|
248
|
+
eps (float, optional): Small value for numerical stability. Defaults to 1e-9.
|
220
249
|
|
221
250
|
Returns:
|
222
|
-
(Tensor): shape(b, n_boxes, h*w)
|
251
|
+
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
|
252
|
+
|
253
|
+
Note:
|
254
|
+
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
|
255
|
+
Bounding box format: [x_min, y_min, x_max, y_max].
|
223
256
|
"""
|
224
257
|
n_anchors = xy_centers.shape[0]
|
225
258
|
bs, n_boxes, _ = gt_bboxes.shape
|
@@ -231,18 +264,22 @@ class TaskAlignedAssigner(nn.Module):
|
|
231
264
|
@staticmethod
|
232
265
|
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
233
266
|
"""
|
234
|
-
|
267
|
+
Select anchor boxes with highest IoU when assigned to multiple ground truths.
|
235
268
|
|
236
269
|
Args:
|
237
|
-
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
|
238
|
-
overlaps (Tensor): shape(b, n_max_boxes, h*w)
|
270
|
+
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
|
271
|
+
overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
|
272
|
+
n_max_boxes (int): Maximum number of ground truth boxes.
|
239
273
|
|
240
274
|
Returns:
|
241
|
-
target_gt_idx (Tensor): shape(b, h*w)
|
242
|
-
fg_mask (Tensor): shape(b, h*w)
|
243
|
-
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
|
275
|
+
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
|
276
|
+
fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
|
277
|
+
mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
|
278
|
+
|
279
|
+
Note:
|
280
|
+
b: batch size, h: height, w: width.
|
244
281
|
"""
|
245
|
-
# (b, n_max_boxes, h*w) -> (b, h*w)
|
282
|
+
# Convert (b, n_max_boxes, h*w) -> (b, h*w)
|
246
283
|
fg_mask = mask_pos.sum(-2)
|
247
284
|
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
|
248
285
|
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
|
@@ -259,6 +296,8 @@ class TaskAlignedAssigner(nn.Module):
|
|
259
296
|
|
260
297
|
|
261
298
|
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
299
|
+
"""Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
|
300
|
+
|
262
301
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
263
302
|
"""IoU calculation for rotated bounding boxes."""
|
264
303
|
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
|
@@ -297,7 +336,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
|
|
297
336
|
assert feats is not None
|
298
337
|
dtype, device = feats[0].dtype, feats[0].device
|
299
338
|
for i, stride in enumerate(strides):
|
300
|
-
|
339
|
+
h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
|
301
340
|
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
|
302
341
|
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
|
303
342
|
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
|
@@ -326,14 +365,16 @@ def bbox2dist(anchor_points, bbox, reg_max):
|
|
326
365
|
|
327
366
|
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
328
367
|
"""
|
329
|
-
Decode predicted
|
368
|
+
Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
330
369
|
|
331
370
|
Args:
|
332
|
-
pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
|
333
|
-
pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
|
334
|
-
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
|
371
|
+
pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
|
372
|
+
pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
|
373
|
+
anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
|
374
|
+
dim (int, optional): Dimension along which to split. Defaults to -1.
|
375
|
+
|
335
376
|
Returns:
|
336
|
-
(torch.Tensor): Predicted rotated bounding boxes, (bs, h*w, 4).
|
377
|
+
(torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
|
337
378
|
"""
|
338
379
|
lt, rb = pred_dist.split(2, dim=dim)
|
339
380
|
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
|