ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +125 -39
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +34 -33
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +33 -47
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +69 -90
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +31 -38
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +21 -26
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +23 -17
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +29 -24
- ultralytics/models/nas/predict.py +14 -11
- ultralytics/models/nas/val.py +11 -13
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +21 -21
- ultralytics/models/rtdetr/train.py +25 -24
- ultralytics/models/rtdetr/val.py +47 -14
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +30 -12
- ultralytics/models/yolo/classify/train.py +83 -19
- ultralytics/models/yolo/classify/val.py +45 -23
- ultralytics/models/yolo/detect/predict.py +29 -19
- ultralytics/models/yolo/detect/train.py +90 -23
- ultralytics/models/yolo/detect/val.py +150 -29
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +18 -13
- ultralytics/models/yolo/obb/train.py +12 -8
- ultralytics/models/yolo/obb/val.py +35 -22
- ultralytics/models/yolo/pose/predict.py +28 -15
- ultralytics/models/yolo/pose/train.py +21 -8
- ultralytics/models/yolo/pose/val.py +51 -31
- ultralytics/models/yolo/segment/predict.py +27 -16
- ultralytics/models/yolo/segment/train.py +11 -8
- ultralytics/models/yolo/segment/val.py +110 -29
- ultralytics/models/yolo/world/train.py +43 -16
- ultralytics/models/yolo/world/train_world.py +61 -36
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +12 -12
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +226 -79
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +37 -35
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +139 -68
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +37 -56
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +117 -52
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +65 -61
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +72 -59
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +202 -64
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +13 -25
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.88.dist-info/RECORD +0 -250
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py
CHANGED
@@ -21,6 +21,7 @@ class TaskAlignedAssigner(nn.Module):
|
|
21
21
|
Attributes:
|
22
22
|
topk (int): The number of top candidates to consider.
|
23
23
|
num_classes (int): The number of object classes.
|
24
|
+
bg_idx (int): Background class index.
|
24
25
|
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
|
25
26
|
beta (float): The beta parameter for the localization component of the task-aligned metric.
|
26
27
|
eps (float): A small value to prevent division by zero.
|
@@ -39,23 +40,25 @@ class TaskAlignedAssigner(nn.Module):
|
|
39
40
|
@torch.no_grad()
|
40
41
|
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
41
42
|
"""
|
42
|
-
Compute the task-aligned assignment.
|
43
|
-
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
|
43
|
+
Compute the task-aligned assignment.
|
44
44
|
|
45
45
|
Args:
|
46
|
-
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
47
|
-
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
48
|
-
anc_points (Tensor): shape(num_total_anchors, 2)
|
49
|
-
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
|
50
|
-
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
|
51
|
-
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
|
46
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
47
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
48
|
+
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
|
49
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
50
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
51
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
|
52
52
|
|
53
53
|
Returns:
|
54
|
-
target_labels (Tensor): shape(bs, num_total_anchors)
|
55
|
-
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
56
|
-
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
57
|
-
fg_mask (Tensor): shape(bs, num_total_anchors)
|
58
|
-
target_gt_idx (Tensor): shape(bs, num_total_anchors)
|
54
|
+
target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
|
55
|
+
target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
|
56
|
+
target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
|
57
|
+
fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
|
58
|
+
target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
|
59
|
+
|
60
|
+
References:
|
61
|
+
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
|
59
62
|
"""
|
60
63
|
self.bs = pd_scores.shape[0]
|
61
64
|
self.n_max_boxes = gt_bboxes.shape[1]
|
@@ -81,23 +84,22 @@ class TaskAlignedAssigner(nn.Module):
|
|
81
84
|
|
82
85
|
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
83
86
|
"""
|
84
|
-
Compute the task-aligned assignment.
|
85
|
-
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
|
87
|
+
Compute the task-aligned assignment.
|
86
88
|
|
87
89
|
Args:
|
88
|
-
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
89
|
-
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
90
|
-
anc_points (Tensor): shape(num_total_anchors, 2)
|
91
|
-
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
|
92
|
-
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
|
93
|
-
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
|
90
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
91
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
92
|
+
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
|
93
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
94
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
95
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
|
94
96
|
|
95
97
|
Returns:
|
96
|
-
target_labels (Tensor): shape(bs, num_total_anchors)
|
97
|
-
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
98
|
-
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
99
|
-
fg_mask (Tensor): shape(bs, num_total_anchors)
|
100
|
-
target_gt_idx (Tensor): shape(bs, num_total_anchors)
|
98
|
+
target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
|
99
|
+
target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
|
100
|
+
target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
|
101
|
+
fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
|
102
|
+
target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
|
101
103
|
"""
|
102
104
|
mask_pos, align_metric, overlaps = self.get_pos_mask(
|
103
105
|
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
|
@@ -118,7 +120,22 @@ class TaskAlignedAssigner(nn.Module):
|
|
118
120
|
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
119
121
|
|
120
122
|
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
121
|
-
"""
|
123
|
+
"""
|
124
|
+
Get positive mask for each ground truth box.
|
125
|
+
|
126
|
+
Args:
|
127
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
128
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
129
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
130
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
131
|
+
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
|
132
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
|
133
|
+
|
134
|
+
Returns:
|
135
|
+
mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
|
136
|
+
align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
|
137
|
+
overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
|
138
|
+
"""
|
122
139
|
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
|
123
140
|
# Get anchor_align metric, (b, max_num_obj, h*w)
|
124
141
|
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
|
@@ -130,7 +147,20 @@ class TaskAlignedAssigner(nn.Module):
|
|
130
147
|
return mask_pos, align_metric, overlaps
|
131
148
|
|
132
149
|
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
133
|
-
"""
|
150
|
+
"""
|
151
|
+
Compute alignment metric given predicted and ground truth bounding boxes.
|
152
|
+
|
153
|
+
Args:
|
154
|
+
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
|
155
|
+
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
|
156
|
+
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
|
157
|
+
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
|
158
|
+
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).
|
159
|
+
|
160
|
+
Returns:
|
161
|
+
align_metric (torch.Tensor): Alignment metric combining classification and localization.
|
162
|
+
overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
|
163
|
+
"""
|
134
164
|
na = pd_bboxes.shape[-2]
|
135
165
|
mask_gt = mask_gt.bool() # b, max_num_obj, h*w
|
136
166
|
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
|
@@ -151,7 +181,16 @@ class TaskAlignedAssigner(nn.Module):
|
|
151
181
|
return align_metric, overlaps
|
152
182
|
|
153
183
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
154
|
-
"""
|
184
|
+
"""
|
185
|
+
Calculate IoU for horizontal bounding boxes.
|
186
|
+
|
187
|
+
Args:
|
188
|
+
gt_bboxes (torch.Tensor): Ground truth boxes.
|
189
|
+
pd_bboxes (torch.Tensor): Predicted boxes.
|
190
|
+
|
191
|
+
Returns:
|
192
|
+
(torch.Tensor): IoU values between each pair of boxes.
|
193
|
+
"""
|
155
194
|
return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
156
195
|
|
157
196
|
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
|
@@ -159,16 +198,16 @@ class TaskAlignedAssigner(nn.Module):
|
|
159
198
|
Select the top-k candidates based on the given metrics.
|
160
199
|
|
161
200
|
Args:
|
162
|
-
metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
|
201
|
+
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
|
163
202
|
max_num_obj is the maximum number of objects, and h*w represents the
|
164
203
|
total number of anchor points.
|
165
204
|
largest (bool): If True, select the largest values; otherwise, select the smallest values.
|
166
|
-
topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
205
|
+
topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
167
206
|
topk is the number of top candidates to consider. If not provided,
|
168
207
|
the top-k values are automatically computed based on the given metrics.
|
169
208
|
|
170
209
|
Returns:
|
171
|
-
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
210
|
+
(torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
172
211
|
"""
|
173
212
|
# (b, max_num_obj, topk)
|
174
213
|
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
|
@@ -183,7 +222,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
183
222
|
for k in range(self.topk):
|
184
223
|
# Expand topk_idxs for each value of k and add 1 at the specified positions
|
185
224
|
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
|
186
|
-
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
|
187
225
|
# Filter invalid bboxes
|
188
226
|
count_tensor.masked_fill_(count_tensor > 1, 0)
|
189
227
|
|
@@ -194,24 +232,21 @@ class TaskAlignedAssigner(nn.Module):
|
|
194
232
|
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
195
233
|
|
196
234
|
Args:
|
197
|
-
gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
235
|
+
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
198
236
|
batch size and max_num_obj is the maximum number of objects.
|
199
|
-
gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
200
|
-
target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
|
237
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
238
|
+
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
|
201
239
|
anchor points, with shape (b, h*w), where h*w is the total
|
202
240
|
number of anchor points.
|
203
|
-
fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
|
241
|
+
fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
|
204
242
|
(foreground) anchor points.
|
205
243
|
|
206
244
|
Returns:
|
207
|
-
(
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
|
213
|
-
for positive anchor points, where num_classes is the number
|
214
|
-
of object classes.
|
245
|
+
target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
|
246
|
+
target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
|
247
|
+
anchor points.
|
248
|
+
target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
|
249
|
+
anchor points.
|
215
250
|
"""
|
216
251
|
# Assigned target labels, (b, 1)
|
217
252
|
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
@@ -258,7 +293,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
258
293
|
bs, n_boxes, _ = gt_bboxes.shape
|
259
294
|
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
|
260
295
|
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
|
261
|
-
# return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
|
262
296
|
return bbox_deltas.amin(3).gt_(eps)
|
263
297
|
|
264
298
|
@staticmethod
|
@@ -275,9 +309,6 @@ class TaskAlignedAssigner(nn.Module):
|
|
275
309
|
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
|
276
310
|
fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
|
277
311
|
mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
|
278
|
-
|
279
|
-
Note:
|
280
|
-
b: batch size, h: height, w: width.
|
281
312
|
"""
|
282
313
|
# Convert (b, n_max_boxes, h*w) -> (b, h*w)
|
283
314
|
fg_mask = mask_pos.sum(-2)
|
@@ -299,7 +330,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
|
299
330
|
"""Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
|
300
331
|
|
301
332
|
def iou_calculation(self, gt_bboxes, pd_bboxes):
|
302
|
-
"""IoU
|
333
|
+
"""Calculate IoU for rotated bounding boxes."""
|
303
334
|
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
|
304
335
|
|
305
336
|
@staticmethod
|
@@ -308,11 +339,11 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
|
|
308
339
|
Select the positive anchor center in gt for rotated bounding boxes.
|
309
340
|
|
310
341
|
Args:
|
311
|
-
xy_centers (Tensor): shape(h*w, 2)
|
312
|
-
gt_bboxes (Tensor): shape(b, n_boxes, 5)
|
342
|
+
xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
|
343
|
+
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
|
313
344
|
|
314
345
|
Returns:
|
315
|
-
(Tensor): shape(b, n_boxes, h*w)
|
346
|
+
(torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
|
316
347
|
"""
|
317
348
|
# (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
|
318
349
|
corners = xywhr2xyxyxyxy(gt_bboxes)
|
@@ -368,13 +399,13 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
|
|
368
399
|
Decode predicted rotated bounding box coordinates from anchor points and distribution.
|
369
400
|
|
370
401
|
Args:
|
371
|
-
pred_dist (torch.Tensor): Predicted rotated distance
|
372
|
-
pred_angle (torch.Tensor): Predicted angle
|
373
|
-
anchor_points (torch.Tensor): Anchor points
|
402
|
+
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
|
403
|
+
pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
|
404
|
+
anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
|
374
405
|
dim (int, optional): Dimension along which to split. Defaults to -1.
|
375
406
|
|
376
407
|
Returns:
|
377
|
-
(torch.Tensor): Predicted rotated bounding boxes
|
408
|
+
(torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
|
378
409
|
"""
|
379
410
|
lt, rb = pred_dist.split(2, dim=dim)
|
380
411
|
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
|