dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,9 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
3
7
  import torch
4
8
  import torch.nn as nn
5
9
  import torch.nn.functional as F
@@ -11,15 +15,14 @@ from .ops import HungarianMatcher
11
15
 
12
16
 
13
17
  class DETRLoss(nn.Module):
14
- """
15
- DETR (DEtection TRansformer) Loss class for calculating various loss components.
18
+ """DETR (DEtection TRansformer) Loss class for calculating various loss components.
16
19
 
17
- This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
18
- DETR object detection model.
20
+ This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
21
+ object detection model.
19
22
 
20
23
  Attributes:
21
24
  nc (int): Number of classes.
22
- loss_gain (dict): Coefficients for different loss components.
25
+ loss_gain (dict[str, float]): Coefficients for different loss components.
23
26
  aux_loss (bool): Whether to compute auxiliary losses.
24
27
  use_fl (bool): Whether to use FocalLoss.
25
28
  use_vfl (bool): Whether to use VarifocalLoss.
@@ -33,32 +36,31 @@ class DETRLoss(nn.Module):
33
36
 
34
37
  def __init__(
35
38
  self,
36
- nc=80,
37
- loss_gain=None,
38
- aux_loss=True,
39
- use_fl=True,
40
- use_vfl=False,
41
- use_uni_match=False,
42
- uni_match_ind=0,
43
- gamma=1.5,
44
- alpha=0.25,
39
+ nc: int = 80,
40
+ loss_gain: dict[str, float] | None = None,
41
+ aux_loss: bool = True,
42
+ use_fl: bool = True,
43
+ use_vfl: bool = False,
44
+ use_uni_match: bool = False,
45
+ uni_match_ind: int = 0,
46
+ gamma: float = 1.5,
47
+ alpha: float = 0.25,
45
48
  ):
46
- """
47
- Initialize DETR loss function with customizable components and gains.
49
+ """Initialize DETR loss function with customizable components and gains.
48
50
 
49
51
  Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
50
52
  losses and various loss types.
51
53
 
52
54
  Args:
53
55
  nc (int): Number of classes.
54
- loss_gain (dict): Coefficients for different loss components.
56
+ loss_gain (dict[str, float], optional): Coefficients for different loss components.
55
57
  aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
56
58
  use_fl (bool): Whether to use FocalLoss.
57
59
  use_vfl (bool): Whether to use VarifocalLoss.
58
60
  use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
59
61
  uni_match_ind (int): Index of fixed layer for uni_match.
60
62
  gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
61
- alpha (float | list): The balancing factor used to address class imbalance.
63
+ alpha (float): The balancing factor used to address class imbalance.
62
64
  """
63
65
  super().__init__()
64
66
 
@@ -75,19 +77,20 @@ class DETRLoss(nn.Module):
75
77
  self.uni_match_ind = uni_match_ind
76
78
  self.device = None
77
79
 
78
- def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
79
- """
80
- Compute classification loss based on predictions, target values, and ground truth scores.
80
+ def _get_loss_class(
81
+ self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
82
+ ) -> dict[str, torch.Tensor]:
83
+ """Compute classification loss based on predictions, target values, and ground truth scores.
81
84
 
82
85
  Args:
83
- pred_scores (torch.Tensor): Predicted class scores with shape (batch_size, num_queries, num_classes).
84
- targets (torch.Tensor): Target class indices with shape (batch_size, num_queries).
85
- gt_scores (torch.Tensor): Ground truth confidence scores with shape (batch_size, num_queries).
86
+ pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
87
+ targets (torch.Tensor): Target class indices with shape (B, N).
88
+ gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
86
89
  num_gts (int): Number of ground truth objects.
87
90
  postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.
88
91
 
89
92
  Returns:
90
- loss_cls (torch.Tensor): Classification loss value.
93
+ (dict[str, torch.Tensor]): Dictionary containing classification loss value.
91
94
 
92
95
  Notes:
93
96
  The function supports different classification loss types:
@@ -115,22 +118,20 @@ class DETRLoss(nn.Module):
115
118
 
116
119
  return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
117
120
 
118
- def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
119
- """
120
- Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
121
+ def _get_loss_bbox(
122
+ self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
123
+ ) -> dict[str, torch.Tensor]:
124
+ """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
121
125
 
122
126
  Args:
123
- pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
124
- gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4), where N is the total
125
- number of ground truth boxes.
126
- postfix (str): String to append to the loss names for identification in multi-loss scenarios.
127
+ pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
128
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
129
+ postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.
127
130
 
128
131
  Returns:
129
- loss (dict): Dictionary containing:
130
- - loss_bbox{postfix} (torch.Tensor): L1 loss between predicted and ground truth boxes,
131
- scaled by the bbox loss gain.
132
- - loss_giou{postfix} (torch.Tensor): GIoU loss between predicted and ground truth boxes,
133
- scaled by the giou loss gain.
132
+ (dict[str, torch.Tensor]): Dictionary containing:
133
+ - loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
134
+ - loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.
134
135
 
135
136
  Notes:
136
137
  If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
@@ -184,32 +185,31 @@ class DETRLoss(nn.Module):
184
185
 
185
186
  def _get_loss_aux(
186
187
  self,
187
- pred_bboxes,
188
- pred_scores,
189
- gt_bboxes,
190
- gt_cls,
191
- gt_groups,
192
- match_indices=None,
193
- postfix="",
194
- masks=None,
195
- gt_mask=None,
196
- ):
197
- """
198
- Get auxiliary losses for intermediate decoder layers.
188
+ pred_bboxes: torch.Tensor,
189
+ pred_scores: torch.Tensor,
190
+ gt_bboxes: torch.Tensor,
191
+ gt_cls: torch.Tensor,
192
+ gt_groups: list[int],
193
+ match_indices: list[tuple] | None = None,
194
+ postfix: str = "",
195
+ masks: torch.Tensor | None = None,
196
+ gt_mask: torch.Tensor | None = None,
197
+ ) -> dict[str, torch.Tensor]:
198
+ """Get auxiliary losses for intermediate decoder layers.
199
199
 
200
200
  Args:
201
201
  pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
202
202
  pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
203
203
  gt_bboxes (torch.Tensor): Ground truth bounding boxes.
204
204
  gt_cls (torch.Tensor): Ground truth classes.
205
- gt_groups (List[int]): Number of ground truths per image.
206
- match_indices (List[tuple], optional): Pre-computed matching indices.
207
- postfix (str): String to append to loss names.
205
+ gt_groups (list[int]): Number of ground truths per image.
206
+ match_indices (list[tuple], optional): Pre-computed matching indices.
207
+ postfix (str, optional): String to append to loss names.
208
208
  masks (torch.Tensor, optional): Predicted masks if using segmentation.
209
209
  gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
210
210
 
211
211
  Returns:
212
- (dict): Dictionary of auxiliary losses.
212
+ (dict[str, torch.Tensor]): Dictionary of auxiliary losses.
213
213
  """
214
214
  # NOTE: loss class, bbox, giou, mask, dice
215
215
  loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
@@ -255,32 +255,34 @@ class DETRLoss(nn.Module):
255
255
  return loss
256
256
 
257
257
  @staticmethod
258
- def _get_index(match_indices):
259
- """
260
- Extract batch indices, source indices, and destination indices from match indices.
258
+ def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
259
+ """Extract batch indices, source indices, and destination indices from match indices.
261
260
 
262
261
  Args:
263
- match_indices (List[tuple]): List of tuples containing matched indices.
262
+ match_indices (list[tuple]): List of tuples containing matched indices.
264
263
 
265
264
  Returns:
266
- (tuple): Tuple containing (batch_idx, src_idx) and dst_idx.
265
+ batch_idx (tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).
266
+ dst_idx (torch.Tensor): Destination indices.
267
267
  """
268
268
  batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
269
269
  src_idx = torch.cat([src for (src, _) in match_indices])
270
270
  dst_idx = torch.cat([dst for (_, dst) in match_indices])
271
271
  return (batch_idx, src_idx), dst_idx
272
272
 
273
- def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
274
- """
275
- Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
273
+ def _get_assigned_bboxes(
274
+ self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
275
+ ) -> tuple[torch.Tensor, torch.Tensor]:
276
+ """Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
276
277
 
277
278
  Args:
278
279
  pred_bboxes (torch.Tensor): Predicted bounding boxes.
279
280
  gt_bboxes (torch.Tensor): Ground truth bounding boxes.
280
- match_indices (List[tuple]): List of tuples containing matched indices.
281
+ match_indices (list[tuple]): List of tuples containing matched indices.
281
282
 
282
283
  Returns:
283
- (tuple): Tuple containing assigned predictions and ground truths.
284
+ pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
285
+ gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
284
286
  """
285
287
  pred_assigned = torch.cat(
286
288
  [
@@ -298,32 +300,31 @@ class DETRLoss(nn.Module):
298
300
 
299
301
  def _get_loss(
300
302
  self,
301
- pred_bboxes,
302
- pred_scores,
303
- gt_bboxes,
304
- gt_cls,
305
- gt_groups,
306
- masks=None,
307
- gt_mask=None,
308
- postfix="",
309
- match_indices=None,
310
- ):
311
- """
312
- Calculate losses for a single prediction layer.
303
+ pred_bboxes: torch.Tensor,
304
+ pred_scores: torch.Tensor,
305
+ gt_bboxes: torch.Tensor,
306
+ gt_cls: torch.Tensor,
307
+ gt_groups: list[int],
308
+ masks: torch.Tensor | None = None,
309
+ gt_mask: torch.Tensor | None = None,
310
+ postfix: str = "",
311
+ match_indices: list[tuple] | None = None,
312
+ ) -> dict[str, torch.Tensor]:
313
+ """Calculate losses for a single prediction layer.
313
314
 
314
315
  Args:
315
316
  pred_bboxes (torch.Tensor): Predicted bounding boxes.
316
317
  pred_scores (torch.Tensor): Predicted class scores.
317
318
  gt_bboxes (torch.Tensor): Ground truth bounding boxes.
318
319
  gt_cls (torch.Tensor): Ground truth classes.
319
- gt_groups (List[int]): Number of ground truths per image.
320
+ gt_groups (list[int]): Number of ground truths per image.
320
321
  masks (torch.Tensor, optional): Predicted masks if using segmentation.
321
322
  gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
322
- postfix (str): String to append to loss names.
323
- match_indices (List[tuple], optional): Pre-computed matching indices.
323
+ postfix (str, optional): String to append to loss names.
324
+ match_indices (list[tuple], optional): Pre-computed matching indices.
324
325
 
325
326
  Returns:
326
- (dict): Dictionary of losses.
327
+ (dict[str, torch.Tensor]): Dictionary of losses.
327
328
  """
328
329
  if match_indices is None:
329
330
  match_indices = self.matcher(
@@ -347,22 +348,25 @@ class DETRLoss(nn.Module):
347
348
  # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
348
349
  }
349
350
 
350
- def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
351
- """
352
- Calculate loss for predicted bounding boxes and scores.
351
+ def forward(
352
+ self,
353
+ pred_bboxes: torch.Tensor,
354
+ pred_scores: torch.Tensor,
355
+ batch: dict[str, Any],
356
+ postfix: str = "",
357
+ **kwargs: Any,
358
+ ) -> dict[str, torch.Tensor]:
359
+ """Calculate loss for predicted bounding boxes and scores.
353
360
 
354
361
  Args:
355
- pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
356
- pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
357
- batch (dict): Batch information containing:
358
- cls (torch.Tensor): Ground truth classes, shape [num_gts].
359
- bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
360
- gt_groups (List[int]): Number of ground truths for each image in the batch.
361
- postfix (str): Postfix for loss names.
362
+ pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
363
+ pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
364
+ batch (dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
365
+ postfix (str, optional): Postfix for loss names.
362
366
  **kwargs (Any): Additional arguments, may include 'match_indices'.
363
367
 
364
368
  Returns:
365
- (dict): Computed losses, including main and auxiliary (if enabled).
369
+ (dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).
366
370
 
367
371
  Notes:
368
372
  Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
@@ -387,26 +391,31 @@ class DETRLoss(nn.Module):
387
391
 
388
392
 
389
393
  class RTDETRDetectionLoss(DETRLoss):
390
- """
391
- Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.
394
+ """Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.
392
395
 
393
396
  This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
394
397
  an additional denoising training loss when provided with denoising metadata.
395
398
  """
396
399
 
397
- def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
398
- """
399
- Forward pass to compute detection loss with optional denoising loss.
400
+ def forward(
401
+ self,
402
+ preds: tuple[torch.Tensor, torch.Tensor],
403
+ batch: dict[str, Any],
404
+ dn_bboxes: torch.Tensor | None = None,
405
+ dn_scores: torch.Tensor | None = None,
406
+ dn_meta: dict[str, Any] | None = None,
407
+ ) -> dict[str, torch.Tensor]:
408
+ """Forward pass to compute detection loss with optional denoising loss.
400
409
 
401
410
  Args:
402
- preds (tuple): Tuple containing predicted bounding boxes and scores.
403
- batch (dict): Batch data containing ground truth information.
411
+ preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
412
+ batch (dict[str, Any]): Batch data containing ground truth information.
404
413
  dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
405
414
  dn_scores (torch.Tensor, optional): Denoising scores.
406
- dn_meta (dict, optional): Metadata for denoising.
415
+ dn_meta (dict[str, Any], optional): Metadata for denoising.
407
416
 
408
417
  Returns:
409
- (dict): Dictionary containing total loss and denoising loss if applicable.
418
+ (dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.
410
419
  """
411
420
  pred_bboxes, pred_scores = preds
412
421
  total_loss = super().forward(pred_bboxes, pred_scores, batch)
@@ -429,17 +438,18 @@ class RTDETRDetectionLoss(DETRLoss):
429
438
  return total_loss
430
439
 
431
440
  @staticmethod
432
- def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
433
- """
434
- Get match indices for denoising.
441
+ def get_dn_match_indices(
442
+ dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
443
+ ) -> list[tuple[torch.Tensor, torch.Tensor]]:
444
+ """Get match indices for denoising.
435
445
 
436
446
  Args:
437
- dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
447
+ dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
438
448
  dn_num_group (int): Number of denoising groups.
439
- gt_groups (List[int]): List of integers representing number of ground truths per image.
449
+ gt_groups (list[int]): List of integers representing number of ground truths per image.
440
450
 
441
451
  Returns:
442
- (List[tuple]): List of tuples containing matched indices for denoising.
452
+ (list[tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.
443
453
  """
444
454
  dn_match_indices = []
445
455
  idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
@@ -447,8 +457,9 @@ class RTDETRDetectionLoss(DETRLoss):
447
457
  if num_gt > 0:
448
458
  gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
449
459
  gt_idx = gt_idx.repeat(dn_num_group)
450
- assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, "
451
- f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
460
+ assert len(dn_pos_idx[i]) == len(gt_idx), (
461
+ f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
462
+ )
452
463
  dn_match_indices.append((dn_pos_idx[i], gt_idx))
453
464
  else:
454
465
  dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))