ultralytics 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +52 -51
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +191 -161
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +4 -6
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +2 -2
- ultralytics/solutions/instance_segmentation.py +7 -4
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -11
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +189 -79
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +45 -29
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/METADATA +2 -2
- ultralytics-8.3.145.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/loss.py
CHANGED
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from typing import Any, Dict, List, Optional, Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -19,7 +21,7 @@ class DETRLoss(nn.Module):

     Attributes:
         nc (int): Number of classes.
-        loss_gain (
+        loss_gain (Dict[str, float]): Coefficients for different loss components.
         aux_loss (bool): Whether to compute auxiliary losses.
         use_fl (bool): Whether to use FocalLoss.
         use_vfl (bool): Whether to use VarifocalLoss.
@@ -33,15 +35,15 @@ class DETRLoss(nn.Module):

     def __init__(
         self,
-        nc=80,
-        loss_gain=None,
-        aux_loss=True,
-        use_fl=True,
-        use_vfl=False,
-        use_uni_match=False,
-        uni_match_ind=0,
-        gamma=1.5,
-        alpha=0.25,
+        nc: int = 80,
+        loss_gain: Optional[Dict[str, float]] = None,
+        aux_loss: bool = True,
+        use_fl: bool = True,
+        use_vfl: bool = False,
+        use_uni_match: bool = False,
+        uni_match_ind: int = 0,
+        gamma: float = 1.5,
+        alpha: float = 0.25,
     ):
         """
         Initialize DETR loss function with customizable components and gains.
@@ -51,14 +53,14 @@ class DETRLoss(nn.Module):

         Args:
             nc (int): Number of classes.
-            loss_gain (
+            loss_gain (Dict[str, float], optional): Coefficients for different loss components.
             aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
             use_fl (bool): Whether to use FocalLoss.
             use_vfl (bool): Whether to use VarifocalLoss.
             use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
             uni_match_ind (int): Index of fixed layer for uni_match.
             gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
-            alpha (float
+            alpha (float): The balancing factor used to address class imbalance.
         """
         super().__init__()

@@ -75,19 +77,21 @@ class DETRLoss(nn.Module):
         self.uni_match_ind = uni_match_ind
         self.device = None

-    def _get_loss_class(
+    def _get_loss_class(
+        self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
+    ) -> Dict[str, torch.Tensor]:
         """
         Compute classification loss based on predictions, target values, and ground truth scores.

         Args:
-            pred_scores (torch.Tensor): Predicted class scores with shape (
-            targets (torch.Tensor): Target class indices with shape (
-            gt_scores (torch.Tensor): Ground truth confidence scores with shape (
+            pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
+            targets (torch.Tensor): Target class indices with shape (B, N).
+            gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
             num_gts (int): Number of ground truth objects.
             postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.

         Returns:
-
+            (Dict[str, torch.Tensor]): Dictionary containing classification loss value.

         Notes:
             The function supports different classification loss types:
@@ -115,22 +119,21 @@ class DETRLoss(nn.Module):

         return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

-    def _get_loss_bbox(
+    def _get_loss_bbox(
+        self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
+    ) -> Dict[str, torch.Tensor]:
         """
         Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.

         Args:
-            pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (
-            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4)
-
-            postfix (str): String to append to the loss names for identification in multi-loss scenarios.
+            pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
+            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
+            postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.

         Returns:
-
-                - loss_bbox{postfix}
-
-                - loss_giou{postfix} (torch.Tensor): GIoU loss between predicted and ground truth boxes,
-                    scaled by the giou loss gain.
+            (Dict[str, torch.Tensor]): Dictionary containing:
+                - loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
+                - loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.

         Notes:
             If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
@@ -184,16 +187,16 @@ class DETRLoss(nn.Module):

     def _get_loss_aux(
         self,
-        pred_bboxes,
-        pred_scores,
-        gt_bboxes,
-        gt_cls,
-        gt_groups,
-        match_indices=None,
-        postfix="",
-        masks=None,
-        gt_mask=None,
-    ):
+        pred_bboxes: torch.Tensor,
+        pred_scores: torch.Tensor,
+        gt_bboxes: torch.Tensor,
+        gt_cls: torch.Tensor,
+        gt_groups: List[int],
+        match_indices: Optional[List[Tuple]] = None,
+        postfix: str = "",
+        masks: Optional[torch.Tensor] = None,
+        gt_mask: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
         """
         Get auxiliary losses for intermediate decoder layers.

@@ -203,13 +206,13 @@ class DETRLoss(nn.Module):
             gt_bboxes (torch.Tensor): Ground truth bounding boxes.
             gt_cls (torch.Tensor): Ground truth classes.
             gt_groups (List[int]): Number of ground truths per image.
-            match_indices (List[
-            postfix (str): String to append to loss names.
+            match_indices (List[Tuple], optional): Pre-computed matching indices.
+            postfix (str, optional): String to append to loss names.
             masks (torch.Tensor, optional): Predicted masks if using segmentation.
             gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.

         Returns:
-            (
+            (Dict[str, torch.Tensor]): Dictionary of auxiliary losses.
         """
         # NOTE: loss class, bbox, giou, mask, dice
         loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
@@ -255,32 +258,36 @@ class DETRLoss(nn.Module):
         return loss

     @staticmethod
-    def _get_index(match_indices):
+    def _get_index(match_indices: List[Tuple]) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
         """
         Extract batch indices, source indices, and destination indices from match indices.

         Args:
-            match_indices (List[
+            match_indices (List[Tuple]): List of tuples containing matched indices.

         Returns:
-            (
+            batch_idx (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).
+            dst_idx (torch.Tensor): Destination indices.
         """
         batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
         src_idx = torch.cat([src for (src, _) in match_indices])
         dst_idx = torch.cat([dst for (_, dst) in match_indices])
         return (batch_idx, src_idx), dst_idx

-    def _get_assigned_bboxes(
+    def _get_assigned_bboxes(
+        self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: List[Tuple]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Assign predicted bounding boxes to ground truth bounding boxes based on match indices.

         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes.
             gt_bboxes (torch.Tensor): Ground truth bounding boxes.
-            match_indices (List[
+            match_indices (List[Tuple]): List of tuples containing matched indices.

         Returns:
-            (
+            pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
+            gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
         """
         pred_assigned = torch.cat(
             [
@@ -298,16 +305,16 @@ class DETRLoss(nn.Module):

     def _get_loss(
         self,
-        pred_bboxes,
-        pred_scores,
-        gt_bboxes,
-        gt_cls,
-        gt_groups,
-        masks=None,
-        gt_mask=None,
-        postfix="",
-        match_indices=None,
-    ):
+        pred_bboxes: torch.Tensor,
+        pred_scores: torch.Tensor,
+        gt_bboxes: torch.Tensor,
+        gt_cls: torch.Tensor,
+        gt_groups: List[int],
+        masks: Optional[torch.Tensor] = None,
+        gt_mask: Optional[torch.Tensor] = None,
+        postfix: str = "",
+        match_indices: Optional[List[Tuple]] = None,
+    ) -> Dict[str, torch.Tensor]:
         """
         Calculate losses for a single prediction layer.

@@ -319,11 +326,11 @@ class DETRLoss(nn.Module):
             gt_groups (List[int]): Number of ground truths per image.
             masks (torch.Tensor, optional): Predicted masks if using segmentation.
             gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
-            postfix (str): String to append to loss names.
-            match_indices (List[
+            postfix (str, optional): String to append to loss names.
+            match_indices (List[Tuple], optional): Pre-computed matching indices.

         Returns:
-            (
+            (Dict[str, torch.Tensor]): Dictionary of losses.
         """
         if match_indices is None:
             match_indices = self.matcher(
@@ -347,22 +354,26 @@ class DETRLoss(nn.Module):
             # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
         }

-    def forward(
+    def forward(
+        self,
+        pred_bboxes: torch.Tensor,
+        pred_scores: torch.Tensor,
+        batch: Dict[str, Any],
+        postfix: str = "",
+        **kwargs: Any,
+    ) -> Dict[str, torch.Tensor]:
         """
         Calculate loss for predicted bounding boxes and scores.

         Args:
-            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape
-            pred_scores (torch.Tensor): Predicted class scores, shape
-            batch (
-
-                bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
-                gt_groups (List[int]): Number of ground truths for each image in the batch.
-            postfix (str): Postfix for loss names.
+            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
+            pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
+            batch (Dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
+            postfix (str, optional): Postfix for loss names.
             **kwargs (Any): Additional arguments, may include 'match_indices'.

         Returns:
-            (
+            (Dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).

         Notes:
             Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
@@ -394,19 +405,26 @@ class RTDETRDetectionLoss(DETRLoss):
    an additional denoising training loss when provided with denoising metadata.
    """

-    def forward(
+    def forward(
+        self,
+        preds: Tuple[torch.Tensor, torch.Tensor],
+        batch: Dict[str, Any],
+        dn_bboxes: Optional[torch.Tensor] = None,
+        dn_scores: Optional[torch.Tensor] = None,
+        dn_meta: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, torch.Tensor]:
         """
         Forward pass to compute detection loss with optional denoising loss.

         Args:
-            preds (
-            batch (
+            preds (Tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
+            batch (Dict[str, Any]): Batch data containing ground truth information.
             dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
             dn_scores (torch.Tensor, optional): Denoising scores.
-            dn_meta (
+            dn_meta (Dict[str, Any], optional): Metadata for denoising.

         Returns:
-            (
+            (Dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.
         """
         pred_bboxes, pred_scores = preds
         total_loss = super().forward(pred_bboxes, pred_scores, batch)
@@ -429,7 +447,9 @@ class RTDETRDetectionLoss(DETRLoss):
         return total_loss

     @staticmethod
-    def get_dn_match_indices(
+    def get_dn_match_indices(
+        dn_pos_idx: List[torch.Tensor], dn_num_group: int, gt_groups: List[int]
+    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
         """
         Get match indices for denoising.

@@ -439,7 +459,7 @@ class RTDETRDetectionLoss(DETRLoss):
             gt_groups (List[int]): List of integers representing number of ground truths per image.

         Returns:
-            (List[
+            (List[Tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.
         """
         dn_match_indices = []
         idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
@@ -447,8 +467,9 @@ class RTDETRDetectionLoss(DETRLoss):
             if num_gt > 0:
                 gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                 gt_idx = gt_idx.repeat(dn_num_group)
-                assert len(dn_pos_idx[i]) == len(gt_idx),
-
+                assert len(dn_pos_idx[i]) == len(gt_idx), (
+                    f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
+                )
                 dn_match_indices.append((dn_pos_idx[i], gt_idx))
             else:
                 dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
ultralytics/models/utils/ops.py
CHANGED
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from typing import Any, Dict, List, Optional, Tuple
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -11,40 +13,58 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

 class HungarianMatcher(nn.Module):
     """
-    A module implementing the HungarianMatcher
-    end-to-end fashion.
+    A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.

-    HungarianMatcher performs optimal assignment over
-    function that considers classification scores, bounding box coordinates, and optionally
+    HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
+    function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
+    used in end-to-end object detection models like DETR.

     Attributes:
-        cost_gain (
-
-
-
-
-
+        cost_gain (Dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
+            components.
+        use_fl (bool): Whether to use Focal Loss for classification cost calculation.
+        with_mask (bool): Whether the model makes mask predictions.
+        num_sample_points (int): Number of sample points used in mask cost calculation.
+        alpha (float): Alpha factor in Focal Loss calculation.
+        gamma (float): Gamma factor in Focal Loss calculation.

     Methods:
-        forward:
-        _cost_mask:
+        forward: Compute optimal assignment between predictions and ground truths for a batch.
+        _cost_mask: Compute mask cost and dice cost if masks are predicted.
+
+    Examples:
+        Initialize a HungarianMatcher with custom cost gains
+        >>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
+
+        Perform matching between predictions and ground truth
+        >>> pred_boxes = torch.rand(2, 100, 4)  # batch_size=2, num_queries=100
+        >>> pred_scores = torch.rand(2, 100, 80)  # 80 classes
+        >>> gt_boxes = torch.rand(10, 4)  # 10 ground truth boxes
+        >>> gt_classes = torch.randint(0, 80, (10,))
+        >>> gt_groups = [5, 5]  # 5 GT boxes per image
+        >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
     """

-    def __init__(
+    def __init__(
+        self,
+        cost_gain: Optional[Dict[str, float]] = None,
+        use_fl: bool = True,
+        with_mask: bool = False,
+        num_sample_points: int = 12544,
+        alpha: float = 0.25,
+        gamma: float = 2.0,
+    ):
         """
-        Initialize
-
-        The HungarianMatcher uses a cost function that considers classification scores, bounding box coordinates,
-        and optionally mask predictions to perform optimal bipartite matching between predictions and ground truths.
+        Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.

         Args:
-            cost_gain (
-                Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
-            use_fl (bool
-            with_mask (bool
-            num_sample_points (int
-            alpha (float
-            gamma (float
+            cost_gain (Dict[str, float], optional): Dictionary of cost coefficients for different matching cost
+                components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
+            use_fl (bool): Whether to use Focal Loss for classification cost calculation.
+            with_mask (bool): Whether the model makes mask predictions.
+            num_sample_points (int): Number of sample points used in mask cost calculation.
+            alpha (float): Alpha factor in Focal Loss calculation.
+            gamma (float): Gamma factor in Focal Loss calculation.
         """
         super().__init__()
         if cost_gain is None:
@@ -56,41 +76,49 @@ class HungarianMatcher(nn.Module):
         self.alpha = alpha
         self.gamma = gamma

-    def forward(
+    def forward(
+        self,
+        pred_bboxes: torch.Tensor,
+        pred_scores: torch.Tensor,
+        gt_bboxes: torch.Tensor,
+        gt_cls: torch.Tensor,
+        gt_groups: List[int],
+        masks: Optional[torch.Tensor] = None,
+        gt_mask: Optional[List[torch.Tensor]] = None,
+    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
         """
-
-
+        Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
+
+        This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
+        mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.

         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
-            pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries,
-
+            pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
+                num_classes).
             gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
-
-
+            gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
+            gt_groups (List[int]): Number of ground truth boxes for each image in the batch.
             masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
-            gt_mask (List[torch.Tensor], optional):
+            gt_mask (List[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).

         Returns:
-            (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
-
-
-                For each batch element, it holds:
-                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+            (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
+                (index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)
+                and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
+                For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
         """
         bs, nq, nc = pred_scores.shape

         if sum(gt_groups) == 0:
             return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

-        #
-        # (batch_size * num_queries, num_classes)
+        # Flatten to compute cost matrices in batch format
         pred_scores = pred_scores.detach().view(-1, nc)
         pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
-        # (batch_size * num_queries, 4)
         pred_bboxes = pred_bboxes.detach().view(-1, 4)

-        # Compute
+        # Compute classification cost
         pred_scores = pred_scores[:, gt_cls]
         if self.use_fl:
             neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
@@ -99,23 +127,24 @@ class HungarianMatcher(nn.Module):
         else:
             cost_class = -pred_scores

-        # Compute
+        # Compute L1 cost between boxes
         cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)

-        # Compute
+        # Compute GIoU cost between boxes, (bs*num_queries, num_gt)
         cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

-        #
+        # Combine costs into final cost matrix
         C = (
             self.cost_gain["class"] * cost_class
             + self.cost_gain["bbox"] * cost_bbox
             + self.cost_gain["giou"] * cost_giou
         )
-
+
+        # Add mask costs if available
         if self.with_mask:
             C += self._cost_mask(bs, gt_groups, masks, gt_mask)

-        # Set invalid values (NaNs and infinities) to 0
+        # Set invalid values (NaNs and infinities) to 0
         C[C.isnan() | C.isinf()] = 0.0

         C = C.view(bs, nq, -1).cpu()
@@ -158,28 +187,49 @@


 def get_cdn_group(
-    batch
-
+    batch: Dict[str, Any],
+    num_classes: int,
+    num_queries: int,
+    class_embed: torch.Tensor,
+    num_dn: int = 100,
+    cls_noise_ratio: float = 0.5,
+    box_noise_scale: float = 1.0,
+    training: bool = False,
+) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[Dict[str, Any]]]:
     """
-
+    Generate contrastive denoising training group with positive and negative samples from ground truths.
+
+    This function creates denoising queries for contrastive denoising training by adding noise to ground truth
+    bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.

     Args:
-        batch (
-            (torch.Tensor with shape (num_gts, 4)), 'gt_groups' (List[int])
-
-        num_classes (int):
-        num_queries (int): Number of queries.
-        class_embed (torch.Tensor):
-        num_dn (int
-        cls_noise_ratio (float
-        box_noise_scale (float
-        training (bool
+        batch (Dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
+            'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (List[int]) indicating number of
+            ground truths per image.
+        num_classes (int): Total number of object classes.
+        num_queries (int): Number of object queries.
+        class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
+        num_dn (int): Number of denoising queries to generate.
+        cls_noise_ratio (float): Noise ratio for class labels.
+        box_noise_scale (float): Noise scale for bounding box coordinates.
+        training (bool): Whether model is in training mode.

     Returns:
-        padding_cls (
-        padding_bbox (
-        attn_mask (
-        dn_meta (
+        padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
+        padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
+        attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
+        dn_meta (Dict[str, Any] | None): Meta information dictionary containing denoising parameters.
+
+    Examples:
+        Generate denoising group for training
+        >>> batch = {
+        ...     "cls": torch.tensor([0, 1, 2]),
+        ...     "bboxes": torch.rand(3, 4),
+        ...     "batch_idx": torch.tensor([0, 0, 1]),
+        ...     "gt_groups": [2, 1],
+        ... }
+        >>> class_embed = torch.rand(80, 256)  # 80 classes, 256 embedding dim
+        >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
     """
     if (not training) or num_dn <= 0 or batch is None:
         return None, None, None, None
@@ -197,7 +247,7 @@ def get_cdn_group(
     gt_bbox = batch["bboxes"]  # bs*num, 4
     b_idx = batch["batch_idx"]

-    # Each group has positive and negative queries
+    # Each group has positive and negative queries
     dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
     dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
     dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
@@ -207,10 +257,10 @@ def get_cdn_group(
     neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

     if cls_noise_ratio > 0:
-        #
+        # Apply class label noise to half of the samples
         mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
         idx = torch.nonzero(mask).squeeze(-1)
-        # Randomly
+        # Randomly assign new class labels
         new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
         dn_cls[idx] = new_label

@@ -229,7 +279,6 @@ def get_cdn_group(
     dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid

     num_dn = int(max_nums * 2 * num_group)  # total denoising queries
-    # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
     dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
     padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
     padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)