dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py
CHANGED
```diff
@@ -4,16 +4,13 @@ import torch
 import torch.nn as nn
 
 from . import LOGGER
-from .checks import check_version
 from .metrics import bbox_iou, probiou
 from .ops import xywhr2xyxyxyxy
-
-TORCH_1_10 = check_version(torch.__version__, "1.10.0")
+from .torch_utils import TORCH_1_11
 
 
 class TaskAlignedAssigner(nn.Module):
-    """
-    A task-aligned assigner for object detection.
+    """A task-aligned assigner for object detection.
 
     This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
     classification and localization information.
@@ -21,26 +18,31 @@ class TaskAlignedAssigner(nn.Module):
     Attributes:
         topk (int): The number of top candidates to consider.
         num_classes (int): The number of object classes.
-        bg_idx (int): Background class index.
         alpha (float): The alpha parameter for the classification component of the task-aligned metric.
         beta (float): The beta parameter for the localization component of the task-aligned metric.
         eps (float): A small value to prevent division by zero.
     """
 
-    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
-        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
+    def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
+        """Initialize a TaskAlignedAssigner object with customizable hyperparameters.
+
+        Args:
+            topk (int, optional): The number of top candidates to consider.
+            num_classes (int, optional): The number of object classes.
+            alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
+            beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
+            eps (float, optional): A small value to prevent division by zero.
+        """
         super().__init__()
         self.topk = topk
         self.num_classes = num_classes
-        self.bg_idx = num_classes
        self.alpha = alpha
         self.beta = beta
         self.eps = eps
 
     @torch.no_grad()
     def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute the task-aligned assignment.
+        """Compute the task-aligned assignment.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -66,7 +68,7 @@ class TaskAlignedAssigner(nn.Module):
 
         if self.n_max_boxes == 0:
             return (
-                torch.full_like(pd_scores[..., 0], self.bg_idx),
+                torch.full_like(pd_scores[..., 0], self.num_classes),
                 torch.zeros_like(pd_bboxes),
                 torch.zeros_like(pd_scores),
                 torch.zeros_like(pd_scores[..., 0]),
@@ -83,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
         return tuple(t.to(device) for t in result)
 
     def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute the task-aligned assignment.
+        """Compute the task-aligned assignment.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -120,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
         return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
 
     def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
-        """
-        Get positive mask for each ground truth box.
+        """Get positive mask for each ground truth box.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -134,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
         Returns:
             mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
             align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
-            overlaps (torch.Tensor): Overlaps between predicted
+            overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
         """
         mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
         # Get anchor_align metric, (b, max_num_obj, h*w)
@@ -147,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
         return mask_pos, align_metric, overlaps
 
     def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
-        """
-        Compute alignment metric given predicted and ground truth bounding boxes.
+        """Compute alignment metric given predicted and ground truth bounding boxes.
 
         Args:
             pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -181,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
         return align_metric, overlaps
 
     def iou_calculation(self, gt_bboxes, pd_bboxes):
-        """
-        Calculate IoU for horizontal bounding boxes.
+        """Calculate IoU for horizontal bounding boxes.
 
         Args:
             gt_bboxes (torch.Tensor): Ground truth boxes.
@@ -193,24 +191,21 @@ class TaskAlignedAssigner(nn.Module):
         """
         return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
 
-    def select_topk_candidates(self, metrics,
-        """
-        Select the top-k candidates based on the given metrics.
+    def select_topk_candidates(self, metrics, topk_mask=None):
+        """Select the top-k candidates based on the given metrics.
 
         Args:
-            metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
-
-
-
-
-                topk is the number of top candidates to consider. If not provided,
-                the top-k values are automatically computed based on the given metrics.
+            metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
+                the maximum number of objects, and h*w represents the total number of anchor points.
+            topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
+                is the number of top candidates to consider. If not provided, the top-k values are automatically
+                computed based on the given metrics.
 
         Returns:
             (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
         """
         # (b, max_num_obj, topk)
-        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
         if topk_mask is None:
             topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
         # (b, max_num_obj, topk)
@@ -228,25 +223,21 @@ class TaskAlignedAssigner(nn.Module):
         return count_tensor.to(metrics.dtype)
 
     def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-        """
-        Compute target labels, target bounding boxes, and target scores for the positive anchor points.
+        """Compute target labels, target bounding boxes, and target scores for the positive anchor points.
 
         Args:
-            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
-
+            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
+                max_num_obj is the maximum number of objects.
             gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
-            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
-
-
-
-                (foreground) anchor points.
+            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
+                shape (b, h*w), where h*w is the total number of anchor points.
+            fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
+                points.
 
         Returns:
-            target_labels (torch.Tensor):
-            target_bboxes (torch.Tensor):
-
-            target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
-                anchor points.
+            target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
+            target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
+            target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
         """
         # Assigned target labels, (b, 1)
         batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
@@ -274,20 +265,19 @@ class TaskAlignedAssigner(nn.Module):
 
     @staticmethod
     def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-        """
-        Select positive anchor centers within ground truth bounding boxes.
+        """Select positive anchor centers within ground truth bounding boxes.
 
         Args:
             xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
             gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
-            eps (float, optional): Small value for numerical stability.
+            eps (float, optional): Small value for numerical stability.
 
         Returns:
             (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
 
-
-            b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
-            Bounding box format: [x_min, y_min, x_max, y_max].
+        Notes:
+            - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
+            - Bounding box format: [x_min, y_min, x_max, y_max].
         """
         n_anchors = xy_centers.shape[0]
         bs, n_boxes, _ = gt_bboxes.shape
@@ -297,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
 
     @staticmethod
     def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-        """
-        Select anchor boxes with highest IoU when assigned to multiple ground truths.
+        """Select anchor boxes with highest IoU when assigned to multiple ground truths.
 
         Args:
             mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
@@ -335,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
 
     @staticmethod
     def select_candidates_in_gts(xy_centers, gt_bboxes):
-        """
-        Select the positive anchor center in gt for rotated bounding boxes.
+        """Select the positive anchor center in gt for rotated bounding boxes.
 
         Args:
             xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
@@ -370,7 +358,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
         h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
         sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
         sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
-        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
+        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
         anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
         stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
     return torch.cat(anchor_points), torch.cat(stride_tensor)
@@ -384,7 +372,7 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
     if xywh:
         c_xy = (x1y1 + x2y2) / 2
         wh = x2y2 - x1y1
-        return torch.cat((c_xy, wh), dim)  # xywh bbox
+        return torch.cat([c_xy, wh], dim)  # xywh bbox
     return torch.cat((x1y1, x2y2), dim)  # xyxy bbox
 
 
@@ -395,14 +383,13 @@ def bbox2dist(anchor_points, bbox, reg_max):
 
 
 def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
-    """
-    Decode predicted rotated bounding box coordinates from anchor points and distribution.
+    """Decode predicted rotated bounding box coordinates from anchor points and distribution.
 
     Args:
         pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
         pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
         anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
-        dim (int, optional): Dimension along which to split.
+        dim (int, optional): Dimension along which to split.
 
     Returns:
         (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
```
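To put the `tal.py` changes in context, the snippet below is a minimal usage sketch assembled only from the signatures and docstring shapes visible in the diff above. The feature-map sizes, the random placeholder tensors, and the `anchor_points * stride_tensor` scaling are illustrative assumptions, not code taken from the package's own tests or training loop.

```python
import torch

from ultralytics.utils.tal import TaskAlignedAssigner, make_anchors

# Hypothetical multi-scale feature maps for a 640x640 input (sizes are assumptions for illustration).
feats = [torch.zeros(2, 64, 80, 80), torch.zeros(2, 64, 40, 40), torch.zeros(2, 64, 20, 20)]
strides = [8, 16, 32]

# make_anchors yields grid-cell anchor centers plus a per-anchor stride column (8400 anchors here).
anchor_points, stride_tensor = make_anchors(feats, strides, grid_cell_offset=0.5)

bs, num_anchors, num_classes, max_num_obj = 2, anchor_points.shape[0], 80, 4

# Random stand-ins with the shapes documented in the diff: scores (bs, A, nc), boxes in xyxy image space.
pd_scores = torch.rand(bs, num_anchors, num_classes)
p_xy = torch.rand(bs, num_anchors, 2) * 320
p_wh = torch.rand(bs, num_anchors, 2) * 320 + 1
pd_bboxes = torch.cat([p_xy, p_xy + p_wh], dim=-1)

gt_labels = torch.randint(0, num_classes, (bs, max_num_obj, 1))
g_xy = torch.rand(bs, max_num_obj, 2) * 320
g_wh = torch.rand(bs, max_num_obj, 2) * 320 + 1
gt_bboxes = torch.cat([g_xy, g_xy + g_wh], dim=-1)
mask_gt = torch.ones(bs, max_num_obj, 1)  # mark every ground-truth slot as valid

# Typed constructor from the diff; the bg_idx attribute is gone, background is simply num_classes.
assigner = TaskAlignedAssigner(topk=13, num_classes=num_classes)

# Assumption: anchor centers are scaled to image space so they share coordinates with the boxes.
target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = assigner(
    pd_scores, pd_bboxes, anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt
)
print(target_scores.shape, int(fg_mask.sum()))  # torch.Size([2, 8400, 80]) and the number of positive anchors
```

Because `bg_idx` was removed from the class, any caller that previously read `assigner.bg_idx` would now use `assigner.num_classes` instead, which is exactly the substitution the early-return hunk makes (`torch.full_like(pd_scores[..., 0], self.num_classes)`).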