dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/loss.py
@@ -1,5 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from typing import Any, Dict, List, Optional, Tuple
+
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -19,7 +21,7 @@ class DETRLoss(nn.Module):

  Attributes:
  nc (int): Number of classes.
- loss_gain (dict): Coefficients for different loss components.
+ loss_gain (Dict[str, float]): Coefficients for different loss components.
  aux_loss (bool): Whether to compute auxiliary losses.
  use_fl (bool): Whether to use FocalLoss.
  use_vfl (bool): Whether to use VarifocalLoss.
@@ -33,15 +35,15 @@ class DETRLoss(nn.Module):

  def __init__(
  self,
- nc=80,
- loss_gain=None,
- aux_loss=True,
- use_fl=True,
- use_vfl=False,
- use_uni_match=False,
- uni_match_ind=0,
- gamma=1.5,
- alpha=0.25,
+ nc: int = 80,
+ loss_gain: Optional[Dict[str, float]] = None,
+ aux_loss: bool = True,
+ use_fl: bool = True,
+ use_vfl: bool = False,
+ use_uni_match: bool = False,
+ uni_match_ind: int = 0,
+ gamma: float = 1.5,
+ alpha: float = 0.25,
  ):
  """
  Initialize DETR loss function with customizable components and gains.
@@ -51,14 +53,14 @@ class DETRLoss(nn.Module):

  Args:
  nc (int): Number of classes.
- loss_gain (dict): Coefficients for different loss components.
+ loss_gain (Dict[str, float], optional): Coefficients for different loss components.
  aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
  use_fl (bool): Whether to use FocalLoss.
  use_vfl (bool): Whether to use VarifocalLoss.
  use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
  uni_match_ind (int): Index of fixed layer for uni_match.
  gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
- alpha (float | list): The balancing factor used to address class imbalance.
+ alpha (float): The balancing factor used to address class imbalance.
  """
  super().__init__()

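
For orientation, a minimal sketch of constructing the loss with the newly annotated signature; the loss_gain values below are illustrative, not the package defaults:

>>> from ultralytics.models.utils.loss import DETRLoss
>>> criterion = DETRLoss(nc=80, loss_gain={"class": 1, "bbox": 5, "giou": 2}, aux_loss=True)
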
@@ -75,19 +77,21 @@ class DETRLoss(nn.Module):
  self.uni_match_ind = uni_match_ind
  self.device = None

- def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
+ def _get_loss_class(
+ self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
+ ) -> Dict[str, torch.Tensor]:
  """
  Compute classification loss based on predictions, target values, and ground truth scores.

  Args:
- pred_scores (torch.Tensor): Predicted class scores with shape (batch_size, num_queries, num_classes).
- targets (torch.Tensor): Target class indices with shape (batch_size, num_queries).
- gt_scores (torch.Tensor): Ground truth confidence scores with shape (batch_size, num_queries).
+ pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
+ targets (torch.Tensor): Target class indices with shape (B, N).
+ gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
  num_gts (int): Number of ground truth objects.
  postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.

  Returns:
- loss_cls (torch.Tensor): Classification loss value.
+ (Dict[str, torch.Tensor]): Dictionary containing classification loss value.

  Notes:
  The function supports different classification loss types:
@@ -115,22 +119,21 @@ class DETRLoss(nn.Module):

  return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

- def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
+ def _get_loss_bbox(
+ self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
+ ) -> Dict[str, torch.Tensor]:
  """
  Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.

  Args:
- pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
- gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4), where N is the total
- number of ground truth boxes.
- postfix (str): String to append to the loss names for identification in multi-loss scenarios.
+ pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
+ postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.

  Returns:
- loss (dict): Dictionary containing:
- - loss_bbox{postfix} (torch.Tensor): L1 loss between predicted and ground truth boxes,
- scaled by the bbox loss gain.
- - loss_giou{postfix} (torch.Tensor): GIoU loss between predicted and ground truth boxes,
- scaled by the giou loss gain.
+ (Dict[str, torch.Tensor]): Dictionary containing:
+ - loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
+ - loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.

  Notes:
  If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
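
As a rough standalone sketch of the two terms named in that Returns block (illustrative only: it uses torchvision's generalized_box_iou on xyxy boxes and arbitrary gain values, whereas the package computes GIoU with its own bbox_iou helper on xywh boxes):

>>> import torch
>>> import torch.nn.functional as F
>>> from torchvision.ops import generalized_box_iou
>>> pred = torch.tensor([[0.0, 0.0, 10.0, 10.0]])  # matched predicted box, xyxy
>>> gt = torch.tensor([[1.0, 1.0, 11.0, 11.0]])  # matched ground-truth box, xyxy
>>> gain = {"bbox": 5.0, "giou": 2.0}
>>> loss_bbox = gain["bbox"] * F.l1_loss(pred, gt, reduction="sum") / len(gt)
>>> loss_giou = gain["giou"] * (1.0 - generalized_box_iou(pred, gt).diag()).sum() / len(gt)
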
@@ -184,16 +187,16 @@ class DETRLoss(nn.Module):

  def _get_loss_aux(
  self,
- pred_bboxes,
- pred_scores,
- gt_bboxes,
- gt_cls,
- gt_groups,
- match_indices=None,
- postfix="",
- masks=None,
- gt_mask=None,
- ):
+ pred_bboxes: torch.Tensor,
+ pred_scores: torch.Tensor,
+ gt_bboxes: torch.Tensor,
+ gt_cls: torch.Tensor,
+ gt_groups: List[int],
+ match_indices: Optional[List[Tuple]] = None,
+ postfix: str = "",
+ masks: Optional[torch.Tensor] = None,
+ gt_mask: Optional[torch.Tensor] = None,
+ ) -> Dict[str, torch.Tensor]:
  """
  Get auxiliary losses for intermediate decoder layers.

@@ -203,13 +206,13 @@ class DETRLoss(nn.Module):
  gt_bboxes (torch.Tensor): Ground truth bounding boxes.
  gt_cls (torch.Tensor): Ground truth classes.
  gt_groups (List[int]): Number of ground truths per image.
- match_indices (List[tuple], optional): Pre-computed matching indices.
- postfix (str): String to append to loss names.
+ match_indices (List[Tuple], optional): Pre-computed matching indices.
+ postfix (str, optional): String to append to loss names.
  masks (torch.Tensor, optional): Predicted masks if using segmentation.
  gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.

  Returns:
- (dict): Dictionary of auxiliary losses.
+ (Dict[str, torch.Tensor]): Dictionary of auxiliary losses.
  """
  # NOTE: loss class, bbox, giou, mask, dice
  loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
@@ -255,32 +258,36 @@ class DETRLoss(nn.Module):
  return loss

  @staticmethod
- def _get_index(match_indices):
+ def _get_index(match_indices: List[Tuple]) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
  """
  Extract batch indices, source indices, and destination indices from match indices.

  Args:
- match_indices (List[tuple]): List of tuples containing matched indices.
+ match_indices (List[Tuple]): List of tuples containing matched indices.

  Returns:
- (tuple): Tuple containing (batch_idx, src_idx) and dst_idx.
+ batch_idx (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).
+ dst_idx (torch.Tensor): Destination indices.
  """
  batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
  src_idx = torch.cat([src for (src, _) in match_indices])
  dst_idx = torch.cat([dst for (_, dst) in match_indices])
  return (batch_idx, src_idx), dst_idx

- def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
+ def _get_assigned_bboxes(
+ self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: List[Tuple]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
  """
  Assign predicted bounding boxes to ground truth bounding boxes based on match indices.

  Args:
  pred_bboxes (torch.Tensor): Predicted bounding boxes.
  gt_bboxes (torch.Tensor): Ground truth bounding boxes.
- match_indices (List[tuple]): List of tuples containing matched indices.
+ match_indices (List[Tuple]): List of tuples containing matched indices.

  Returns:
- (tuple): Tuple containing assigned predictions and ground truths.
+ pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
+ gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
  """
  pred_assigned = torch.cat(
  [
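
A tiny worked example of what the newly annotated _get_index returns for a two-image batch of match indices (the index values are hypothetical, chosen only to show the flattening):

>>> import torch
>>> from ultralytics.models.utils.loss import DETRLoss
>>> match_indices = [
...     (torch.tensor([0, 3]), torch.tensor([1, 0])),  # image 0: queries 0 and 3 matched to GTs 1 and 0
...     (torch.tensor([2]), torch.tensor([0])),  # image 1: query 2 matched to GT 0
... ]
>>> (batch_idx, src_idx), dst_idx = DETRLoss._get_index(match_indices)
>>> batch_idx.tolist(), src_idx.tolist(), dst_idx.tolist()
([0, 0, 1], [0, 3, 2], [1, 0, 0])
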
@@ -298,16 +305,16 @@ class DETRLoss(nn.Module):

  def _get_loss(
  self,
- pred_bboxes,
- pred_scores,
- gt_bboxes,
- gt_cls,
- gt_groups,
- masks=None,
- gt_mask=None,
- postfix="",
- match_indices=None,
- ):
+ pred_bboxes: torch.Tensor,
+ pred_scores: torch.Tensor,
+ gt_bboxes: torch.Tensor,
+ gt_cls: torch.Tensor,
+ gt_groups: List[int],
+ masks: Optional[torch.Tensor] = None,
+ gt_mask: Optional[torch.Tensor] = None,
+ postfix: str = "",
+ match_indices: Optional[List[Tuple]] = None,
+ ) -> Dict[str, torch.Tensor]:
  """
  Calculate losses for a single prediction layer.

@@ -319,11 +326,11 @@ class DETRLoss(nn.Module):
  gt_groups (List[int]): Number of ground truths per image.
  masks (torch.Tensor, optional): Predicted masks if using segmentation.
  gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
- postfix (str): String to append to loss names.
- match_indices (List[tuple], optional): Pre-computed matching indices.
+ postfix (str, optional): String to append to loss names.
+ match_indices (List[Tuple], optional): Pre-computed matching indices.

  Returns:
- (dict): Dictionary of losses.
+ (Dict[str, torch.Tensor]): Dictionary of losses.
  """
  if match_indices is None:
  match_indices = self.matcher(
@@ -347,22 +354,26 @@ class DETRLoss(nn.Module):
  # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
  }

- def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
+ def forward(
+ self,
+ pred_bboxes: torch.Tensor,
+ pred_scores: torch.Tensor,
+ batch: Dict[str, Any],
+ postfix: str = "",
+ **kwargs: Any,
+ ) -> Dict[str, torch.Tensor]:
  """
  Calculate loss for predicted bounding boxes and scores.

  Args:
- pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
- pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
- batch (dict): Batch information containing:
- cls (torch.Tensor): Ground truth classes, shape [num_gts].
- bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
- gt_groups (List[int]): Number of ground truths for each image in the batch.
- postfix (str): Postfix for loss names.
+ pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
+ pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
+ batch (Dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
+ postfix (str, optional): Postfix for loss names.
  **kwargs (Any): Additional arguments, may include 'match_indices'.

  Returns:
- (dict): Computed losses, including main and auxiliary (if enabled).
+ (Dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).

  Notes:
  Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
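
A hedged usage sketch of the annotated forward signature, following the (L, B, N, 4) and (L, B, N, C) shapes and the batch keys named in the Args above; the tensors are random and the exact loss keys are as described in this diff's docstrings:

>>> import torch
>>> from ultralytics.models.utils.loss import DETRLoss
>>> criterion = DETRLoss(nc=80, aux_loss=False)
>>> pred_bboxes = torch.rand(1, 2, 100, 4)  # (L, B, N, 4): one decoder layer, batch of 2, 100 queries
>>> pred_scores = torch.randn(1, 2, 100, 80)  # (L, B, N, C)
>>> batch = {"cls": torch.tensor([3, 7, 1]), "bboxes": torch.rand(3, 4), "gt_groups": [2, 1]}
>>> losses = criterion(pred_bboxes, pred_scores, batch)  # dict of scalar tensors, e.g. loss_class, loss_bbox, loss_giou
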
@@ -394,19 +405,26 @@ class RTDETRDetectionLoss(DETRLoss):
  an additional denoising training loss when provided with denoising metadata.
  """

- def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
+ def forward(
+ self,
+ preds: Tuple[torch.Tensor, torch.Tensor],
+ batch: Dict[str, Any],
+ dn_bboxes: Optional[torch.Tensor] = None,
+ dn_scores: Optional[torch.Tensor] = None,
+ dn_meta: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, torch.Tensor]:
  """
  Forward pass to compute detection loss with optional denoising loss.

  Args:
- preds (tuple): Tuple containing predicted bounding boxes and scores.
- batch (dict): Batch data containing ground truth information.
+ preds (Tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
+ batch (Dict[str, Any]): Batch data containing ground truth information.
  dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
  dn_scores (torch.Tensor, optional): Denoising scores.
- dn_meta (dict, optional): Metadata for denoising.
+ dn_meta (Dict[str, Any], optional): Metadata for denoising.

  Returns:
- (dict): Dictionary containing total loss and denoising loss if applicable.
+ (Dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.
  """
  pred_bboxes, pred_scores = preds
  total_loss = super().forward(pred_bboxes, pred_scores, batch)
@@ -429,7 +447,9 @@ class RTDETRDetectionLoss(DETRLoss):
  return total_loss

  @staticmethod
- def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
+ def get_dn_match_indices(
+ dn_pos_idx: List[torch.Tensor], dn_num_group: int, gt_groups: List[int]
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
  """
  Get match indices for denoising.

@@ -439,7 +459,7 @@ class RTDETRDetectionLoss(DETRLoss):
  gt_groups (List[int]): List of integers representing number of ground truths per image.

  Returns:
- (List[tuple]): List of tuples containing matched indices for denoising.
+ (List[Tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.
  """
  dn_match_indices = []
  idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
@@ -447,8 +467,9 @@ class RTDETRDetectionLoss(DETRLoss):
  if num_gt > 0:
  gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
  gt_idx = gt_idx.repeat(dn_num_group)
- assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
- f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
+ assert len(dn_pos_idx[i]) == len(gt_idx), (
+ f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
+ )
  dn_match_indices.append((dn_pos_idx[i], gt_idx))
  else:
  dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
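
To make the reworked assertion concrete, a small worked example of get_dn_match_indices with hypothetical index values (two images with 2 and 1 ground truths, two denoising groups):

>>> import torch
>>> from ultralytics.models.utils.loss import RTDETRDetectionLoss
>>> dn_pos_idx = [torch.tensor([0, 1, 6, 7]), torch.tensor([0, 3])]  # positive denoising query indices per image
>>> out = RTDETRDetectionLoss.get_dn_match_indices(dn_pos_idx, dn_num_group=2, gt_groups=[2, 1])
>>> [(p.tolist(), g.tolist()) for p, g in out]
[([0, 1, 6, 7], [0, 1, 0, 1]), ([0, 3], [2, 2])]
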
ultralytics/models/utils/ops.py
@@ -1,5 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from typing import Any, Dict, List, Optional, Tuple
+
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -11,40 +13,58 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

  class HungarianMatcher(nn.Module):
  """
- A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
- end-to-end fashion.
+ A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.

- HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
- function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
+ HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
+ function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
+ used in end-to-end object detection models like DETR.

  Attributes:
- cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
- use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
- with_mask (bool): Indicates whether the model makes mask predictions.
- num_sample_points (int): The number of sample points used in mask cost calculation.
- alpha (float): The alpha factor in Focal Loss calculation.
- gamma (float): The gamma factor in Focal Loss calculation.
+ cost_gain (Dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
+ components.
+ use_fl (bool): Whether to use Focal Loss for classification cost calculation.
+ with_mask (bool): Whether the model makes mask predictions.
+ num_sample_points (int): Number of sample points used in mask cost calculation.
+ alpha (float): Alpha factor in Focal Loss calculation.
+ gamma (float): Gamma factor in Focal Loss calculation.

  Methods:
- forward: Computes the assignment between predictions and ground truths for a batch.
- _cost_mask: Computes the mask cost and dice cost if masks are predicted.
+ forward: Compute optimal assignment between predictions and ground truths for a batch.
+ _cost_mask: Compute mask cost and dice cost if masks are predicted.
+
+ Examples:
+ Initialize a HungarianMatcher with custom cost gains
+ >>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
+
+ Perform matching between predictions and ground truth
+ >>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100
+ >>> pred_scores = torch.rand(2, 100, 80) # 80 classes
+ >>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes
+ >>> gt_classes = torch.randint(0, 80, (10,))
+ >>> gt_groups = [5, 5] # 5 GT boxes per image
+ >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
  """

- def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
+ def __init__(
+ self,
+ cost_gain: Optional[Dict[str, float]] = None,
+ use_fl: bool = True,
+ with_mask: bool = False,
+ num_sample_points: int = 12544,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+ ):
  """
- Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.
-
- The HungarianMatcher uses a cost function that considers classification scores, bounding box coordinates,
- and optionally mask predictions to perform optimal bipartite matching between predictions and ground truths.
+ Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.

  Args:
- cost_gain (dict, optional): Dictionary of cost coefficients for different components of the matching cost.
- Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
- use_fl (bool, optional): Whether to use Focal Loss for the classification cost calculation.
- with_mask (bool, optional): Whether the model makes mask predictions.
- num_sample_points (int, optional): Number of sample points used in mask cost calculation.
- alpha (float, optional): Alpha factor in Focal Loss calculation.
- gamma (float, optional): Gamma factor in Focal Loss calculation.
+ cost_gain (Dict[str, float], optional): Dictionary of cost coefficients for different matching cost
+ components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
+ use_fl (bool): Whether to use Focal Loss for classification cost calculation.
+ with_mask (bool): Whether the model makes mask predictions.
+ num_sample_points (int): Number of sample points used in mask cost calculation.
+ alpha (float): Alpha factor in Focal Loss calculation.
+ gamma (float): Gamma factor in Focal Loss calculation.
  """
  super().__init__()
  if cost_gain is None:
@@ -56,41 +76,49 @@ class HungarianMatcher(nn.Module):
  self.alpha = alpha
  self.gamma = gamma

- def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
+ def forward(
+ self,
+ pred_bboxes: torch.Tensor,
+ pred_scores: torch.Tensor,
+ gt_bboxes: torch.Tensor,
+ gt_cls: torch.Tensor,
+ gt_groups: List[int],
+ masks: Optional[torch.Tensor] = None,
+ gt_mask: Optional[List[torch.Tensor]] = None,
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
  """
- Forward pass for HungarianMatcher. Computes costs based on prediction and ground truth and finds the optimal
- matching between predictions and ground truth based on these costs.
+ Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
+
+ This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
+ mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.

  Args:
  pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
- pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries, num_classes).
- gt_cls (torch.Tensor): Ground truth classes with shape (num_gts, ).
+ pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
+ num_classes).
  gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
- gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
- each image.
+ gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
+ gt_groups (List[int]): Number of ground truth boxes for each image in the batch.
  masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
- gt_mask (List[torch.Tensor], optional): List of ground truth masks, each with shape (num_masks, Height, Width).
+ gt_mask (List[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).

  Returns:
- (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
- - index_i is the tensor of indices of the selected predictions (in order)
- - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
- For each batch element, it holds:
- len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
+ (index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)
+ and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
+ For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
  """
  bs, nq, nc = pred_scores.shape

  if sum(gt_groups) == 0:
  return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

- # We flatten to compute the cost matrices in a batch
- # (batch_size * num_queries, num_classes)
+ # Flatten to compute cost matrices in batch format
  pred_scores = pred_scores.detach().view(-1, nc)
  pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
- # (batch_size * num_queries, 4)
  pred_bboxes = pred_bboxes.detach().view(-1, 4)

- # Compute the classification cost
+ # Compute classification cost
  pred_scores = pred_scores[:, gt_cls]
  if self.use_fl:
  neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
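
The neg_cost_class line at the end of this hunk is one half of the focal-style classification cost; the surrounding code (not shown in this hunk) pairs it with a pos_cost_class term and takes the difference, as is standard for focal-style matching costs. A small arithmetic sketch of that pattern, with alpha and gamma as in the defaults above:

>>> import torch
>>> alpha, gamma, eps = 0.25, 2.0, 1e-8
>>> p = torch.tensor([0.9, 0.1])  # predicted probability of each matched GT class
>>> neg_cost_class = (1 - alpha) * (p**gamma) * (-(1 - p + eps).log())
>>> pos_cost_class = alpha * ((1 - p) ** gamma) * (-(p + eps).log())
>>> cost_class = pos_cost_class - neg_cost_class  # confident correct predictions get the lowest cost
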
@@ -99,23 +127,24 @@ class HungarianMatcher(nn.Module):
  else:
  cost_class = -pred_scores

- # Compute the L1 cost between boxes
+ # Compute L1 cost between boxes
  cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)

- # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
+ # Compute GIoU cost between boxes, (bs*num_queries, num_gt)
  cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

- # Final cost matrix
+ # Combine costs into final cost matrix
  C = (
  self.cost_gain["class"] * cost_class
  + self.cost_gain["bbox"] * cost_bbox
  + self.cost_gain["giou"] * cost_giou
  )
- # Compute the mask cost and dice cost
+
+ # Add mask costs if available
  if self.with_mask:
  C += self._cost_mask(bs, gt_groups, masks, gt_mask)

- # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries)
+ # Set invalid values (NaNs and infinities) to 0
  C[C.isnan() | C.isinf()] = 0.0

  C = C.view(bs, nq, -1).cpu()
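
Once C has been reshaped to (bs, num_queries, total_gt) and moved to the CPU, the per-image assignment itself is solved with SciPy's Hungarian solver (linear_sum_assignment) in the code that follows this hunk. A minimal sketch of that final step for a single image, under illustrative shapes:

>>> import torch
>>> from scipy.optimize import linear_sum_assignment
>>> C_img = torch.rand(100, 7)  # (num_queries, num_gt) cost block for one image
>>> row_ind, col_ind = linear_sum_assignment(C_img.numpy())
>>> len(row_ind) == len(col_ind) == min(100, 7)
True
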
@@ -158,28 +187,49 @@ class HungarianMatcher(nn.Module):


  def get_cdn_group(
- batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
- ):
+ batch: Dict[str, Any],
+ num_classes: int,
+ num_queries: int,
+ class_embed: torch.Tensor,
+ num_dn: int = 100,
+ cls_noise_ratio: float = 0.5,
+ box_noise_scale: float = 1.0,
+ training: bool = False,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[Dict[str, Any]]]:
  """
- Get contrastive denoising training group with positive and negative samples from ground truths.
+ Generate contrastive denoising training group with positive and negative samples from ground truths.
+
+ This function creates denoising queries for contrastive denoising training by adding noise to ground truth
+ bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.

  Args:
- batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape (num_gts, )), 'gt_bboxes'
- (torch.Tensor with shape (num_gts, 4)), 'gt_groups' (List[int]) which is a list of batch size length
- indicating the number of gts of each image.
- num_classes (int): Number of classes.
- num_queries (int): Number of queries.
- class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
- num_dn (int, optional): Number of denoising queries.
- cls_noise_ratio (float, optional): Noise ratio for class labels.
- box_noise_scale (float, optional): Noise scale for bounding box coordinates.
- training (bool, optional): If it's in training mode.
+ batch (Dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
+ 'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (List[int]) indicating number of
+ ground truths per image.
+ num_classes (int): Total number of object classes.
+ num_queries (int): Number of object queries.
+ class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
+ num_dn (int): Number of denoising queries to generate.
+ cls_noise_ratio (float): Noise ratio for class labels.
+ box_noise_scale (float): Noise scale for bounding box coordinates.
+ training (bool): Whether model is in training mode.

  Returns:
- padding_cls (Optional[torch.Tensor]): The modified class embeddings for denoising.
- padding_bbox (Optional[torch.Tensor]): The modified bounding boxes for denoising.
- attn_mask (Optional[torch.Tensor]): The attention mask for denoising.
- dn_meta (Optional[Dict]): Meta information for denoising.
+ padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
+ padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
+ attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
+ dn_meta (Dict[str, Any] | None): Meta information dictionary containing denoising parameters.
+
+ Examples:
+ Generate denoising group for training
+ >>> batch = {
+ ... "cls": torch.tensor([0, 1, 2]),
+ ... "bboxes": torch.rand(3, 4),
+ ... "batch_idx": torch.tensor([0, 0, 1]),
+ ... "gt_groups": [2, 1],
+ ... }
+ >>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim
+ >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
  """
  if (not training) or num_dn <= 0 or batch is None:
  return None, None, None, None
@@ -197,7 +247,7 @@ def get_cdn_group(
  gt_bbox = batch["bboxes"] # bs*num, 4
  b_idx = batch["batch_idx"]

- # Each group has positive and negative queries.
+ # Each group has positive and negative queries
  dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
  dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
  dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
@@ -207,10 +257,10 @@ def get_cdn_group(
  neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

  if cls_noise_ratio > 0:
- # Half of bbox prob
+ # Apply class label noise to half of the samples
  mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
  idx = torch.nonzero(mask).squeeze(-1)
- # Randomly put a new one here
+ # Randomly assign new class labels
  new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
  dn_cls[idx] = new_label

@@ -229,7 +279,6 @@ def get_cdn_group(
  dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid

  num_dn = int(max_nums * 2 * num_group) # total denoising queries
- # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
  dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
  padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
  padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)