ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +118 -30
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +5 -5
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +13 -19
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +67 -88
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +21 -18
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +12 -13
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +20 -11
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +22 -11
  53. ultralytics/models/nas/predict.py +9 -4
  54. ultralytics/models/nas/val.py +5 -5
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +18 -15
  57. ultralytics/models/rtdetr/train.py +20 -16
  58. ultralytics/models/rtdetr/val.py +42 -6
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +24 -3
  73. ultralytics/models/yolo/classify/train.py +77 -10
  74. ultralytics/models/yolo/classify/val.py +40 -15
  75. ultralytics/models/yolo/detect/predict.py +23 -10
  76. ultralytics/models/yolo/detect/train.py +85 -15
  77. ultralytics/models/yolo/detect/val.py +145 -21
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +12 -4
  80. ultralytics/models/yolo/obb/train.py +7 -0
  81. ultralytics/models/yolo/obb/val.py +25 -7
  82. ultralytics/models/yolo/pose/predict.py +22 -6
  83. ultralytics/models/yolo/pose/train.py +17 -1
  84. ultralytics/models/yolo/pose/val.py +46 -21
  85. ultralytics/models/yolo/segment/predict.py +22 -8
  86. ultralytics/models/yolo/segment/train.py +6 -0
  87. ultralytics/models/yolo/segment/val.py +100 -14
  88. ultralytics/models/yolo/world/train.py +38 -8
  89. ultralytics/models/yolo/world/train_world.py +39 -10
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +3 -0
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +221 -69
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +32 -27
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +116 -35
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +13 -9
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +112 -45
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +61 -53
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +64 -45
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +181 -33
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +8 -16
  149. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.89.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
File: ultralytics/models/utils/loss.py
@@ -12,21 +12,22 @@ from .ops import HungarianMatcher
12
12
 
13
13
  class DETRLoss(nn.Module):
14
14
  """
15
- DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
16
- DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
17
- losses.
15
+ DETR (DEtection TRansformer) Loss class for calculating various loss components.
16
+
17
+ This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
18
+ DETR object detection model.
18
19
 
19
20
  Attributes:
20
- nc (int): The number of classes.
21
- loss_gain (dict): Coefficients for different loss components.
21
+ nc (int): Number of classes.
22
+ loss_gain (Dict): Coefficients for different loss components.
22
23
  aux_loss (bool): Whether to compute auxiliary losses.
23
- use_fl (bool): Use FocalLoss or not.
24
- use_vfl (bool): Use VarifocalLoss or not.
25
- use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
26
- uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
24
+ use_fl (bool): Whether to use FocalLoss.
25
+ use_vfl (bool): Whether to use VarifocalLoss.
26
+ use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.
27
+ uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.
27
28
  matcher (HungarianMatcher): Object to compute matching cost and indices.
28
- fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
29
- vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
29
+ fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.
30
+ vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.
30
31
  device (torch.device): Device on which tensors are stored.
31
32
  """
32
33
 
@@ -36,16 +37,16 @@ class DETRLoss(nn.Module):
36
37
  """
37
38
  Initialize DETR loss function with customizable components and gains.
38
39
 
39
- Uses default loss_gain if not provided. Initializes HungarianMatcher with
40
- preset cost gains. Supports auxiliary losses and various loss types.
40
+ Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
41
+ losses and various loss types.
41
42
 
42
43
  Args:
43
44
  nc (int): Number of classes.
44
- loss_gain (dict): Coefficients for different loss components.
45
- aux_loss (bool): Use auxiliary losses from each decoder layer.
46
- use_fl (bool): Use FocalLoss.
47
- use_vfl (bool): Use VarifocalLoss.
48
- use_uni_match (bool): Use fixed layer for auxiliary branch label assignment.
45
+ loss_gain (Dict): Coefficients for different loss components.
46
+ aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
47
+ use_fl (bool): Whether to use FocalLoss.
48
+ use_vfl (bool): Whether to use VarifocalLoss.
49
+ use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
49
50
  uni_match_ind (int): Index of fixed layer for uni_match.
50
51
  """
51
52
  super().__init__()
@@ -64,7 +65,7 @@ class DETRLoss(nn.Module):
64
65
  self.device = None
65
66
 
66
67
  def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
67
- """Computes the classification loss based on predictions, target values, and ground truth scores."""
68
+ """Compute classification loss based on predictions, target values, and ground truth scores."""
68
69
  # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
69
70
  name_class = f"loss_class{postfix}"
70
71
  bs, nq = pred_scores.shape[:2]
@@ -86,7 +87,7 @@ class DETRLoss(nn.Module):
86
87
  return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
87
88
 
88
89
  def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
89
- """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes."""
90
+ """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes."""
90
91
  # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
91
92
  name_bbox = f"loss_bbox{postfix}"
92
93
  name_giou = f"loss_giou{postfix}"
@@ -146,7 +147,23 @@ class DETRLoss(nn.Module):
146
147
  masks=None,
147
148
  gt_mask=None,
148
149
  ):
149
- """Get auxiliary losses."""
150
+ """
151
+ Get auxiliary losses for intermediate decoder layers.
152
+
153
+ Args:
154
+ pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
155
+ pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
156
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes.
157
+ gt_cls (torch.Tensor): Ground truth classes.
158
+ gt_groups (List[int]): Number of ground truths per image.
159
+ match_indices (List[tuple], optional): Pre-computed matching indices.
160
+ postfix (str): String to append to loss names.
161
+ masks (torch.Tensor, optional): Predicted masks if using segmentation.
162
+ gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
163
+
164
+ Returns:
165
+ (Dict): Dictionary of auxiliary losses.
166
+ """
150
167
  # NOTE: loss class, bbox, giou, mask, dice
151
168
  loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
152
169
  if match_indices is None and self.use_uni_match:
@@ -192,14 +209,32 @@ class DETRLoss(nn.Module):
192
209
 
193
210
  @staticmethod
194
211
  def _get_index(match_indices):
195
- """Returns batch indices, source indices, and destination indices from provided match indices."""
212
+ """
213
+ Extract batch indices, source indices, and destination indices from match indices.
214
+
215
+ Args:
216
+ match_indices (List[tuple]): List of tuples containing matched indices.
217
+
218
+ Returns:
219
+ (tuple): Tuple containing (batch_idx, src_idx) and dst_idx.
220
+ """
196
221
  batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
197
222
  src_idx = torch.cat([src for (src, _) in match_indices])
198
223
  dst_idx = torch.cat([dst for (_, dst) in match_indices])
199
224
  return (batch_idx, src_idx), dst_idx
200
225
 
201
226
  def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
202
- """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
227
+ """
228
+ Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
229
+
230
+ Args:
231
+ pred_bboxes (torch.Tensor): Predicted bounding boxes.
232
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes.
233
+ match_indices (List[tuple]): List of tuples containing matched indices.
234
+
235
+ Returns:
236
+ (tuple): Tuple containing assigned predictions and ground truths.
237
+ """
203
238
  pred_assigned = torch.cat(
204
239
  [
205
240
  t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
@@ -226,7 +261,23 @@ class DETRLoss(nn.Module):
226
261
  postfix="",
227
262
  match_indices=None,
228
263
  ):
229
- """Get losses."""
264
+ """
265
+ Calculate losses for a single prediction layer.
266
+
267
+ Args:
268
+ pred_bboxes (torch.Tensor): Predicted bounding boxes.
269
+ pred_scores (torch.Tensor): Predicted class scores.
270
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes.
271
+ gt_cls (torch.Tensor): Ground truth classes.
272
+ gt_groups (List[int]): Number of ground truths per image.
273
+ masks (torch.Tensor, optional): Predicted masks if using segmentation.
274
+ gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
275
+ postfix (str): String to append to loss names.
276
+ match_indices (List[tuple], optional): Pre-computed matching indices.
277
+
278
+ Returns:
279
+ (Dict): Dictionary of losses.
280
+ """
230
281
  if match_indices is None:
231
282
  match_indices = self.matcher(
232
283
  pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
@@ -256,7 +307,7 @@ class DETRLoss(nn.Module):
256
307
  Args:
257
308
  pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4].
258
309
  pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes].
259
- batch (dict): Batch information containing:
310
+ batch (Dict): Batch information containing:
260
311
  cls (torch.Tensor): Ground truth classes, shape [num_gts].
261
312
  bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4].
262
313
  gt_groups (List[int]): Number of ground truths for each image in the batch.
@@ -264,9 +315,9 @@ class DETRLoss(nn.Module):
264
315
  **kwargs (Any): Additional arguments, may include 'match_indices'.
265
316
 
266
317
  Returns:
267
- (dict): Computed losses, including main and auxiliary (if enabled).
318
+ (Dict): Computed losses, including main and auxiliary (if enabled).
268
319
 
269
- Note:
320
+ Notes:
270
321
  Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
271
322
  self.aux_loss is True.
272
323
  """
@@ -298,17 +349,17 @@ class RTDETRDetectionLoss(DETRLoss):
298
349
 
299
350
  def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
300
351
  """
301
- Forward pass to compute the detection loss.
352
+ Forward pass to compute detection loss with optional denoising loss.
302
353
 
303
354
  Args:
304
- preds (tuple): Predicted bounding boxes and scores.
305
- batch (dict): Batch data containing ground truth information.
306
- dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
307
- dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
308
- dn_meta (dict, optional): Metadata for denoising. Default is None.
355
+ preds (tuple): Tuple containing predicted bounding boxes and scores.
356
+ batch (Dict): Batch data containing ground truth information.
357
+ dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
358
+ dn_scores (torch.Tensor, optional): Denoising scores.
359
+ dn_meta (Dict, optional): Metadata for denoising.
309
360
 
310
361
  Returns:
311
- (dict): Dictionary containing the total loss and, if applicable, the denoising loss.
362
+ (Dict): Dictionary containing total loss and denoising loss if applicable.
312
363
  """
313
364
  pred_bboxes, pred_scores = preds
314
365
  total_loss = super().forward(pred_bboxes, pred_scores, batch)
@@ -333,12 +384,12 @@ class RTDETRDetectionLoss(DETRLoss):
333
384
  @staticmethod
334
385
  def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
335
386
  """
336
- Get the match indices for denoising.
387
+ Get match indices for denoising.
337
388
 
338
389
  Args:
339
390
  dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
340
391
  dn_num_group (int): Number of denoising groups.
341
- gt_groups (List[int]): List of integers representing the number of ground truths for each image.
392
+ gt_groups (List[int]): List of integers representing number of ground truths per image.
342
393
 
343
394
  Returns:
344
395
  (List[tuple]): List of tuples containing matched indices for denoising.
File: ultralytics/models/utils/ops.py
@@ -18,7 +18,7 @@ class HungarianMatcher(nn.Module):
18
18
  function that considers classification scores, bounding box coordinates, and optionally, mask predictions.
19
19
 
20
20
  Attributes:
21
- cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
21
+ cost_gain (Dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
22
22
  use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
23
23
  with_mask (bool): Indicates whether the model makes mask predictions.
24
24
  num_sample_points (int): The number of sample points used in mask cost calculation.
@@ -26,13 +26,12 @@ class HungarianMatcher(nn.Module):
26
26
  gamma (float): The gamma factor in Focal Loss calculation.
27
27
 
28
28
  Methods:
29
- forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the
30
- assignment between predictions and ground truths for a batch.
31
- _cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
29
+ forward: Computes the assignment between predictions and ground truths for a batch.
30
+ _cost_mask: Computes the mask cost and dice cost if masks are predicted.
32
31
  """
33
32
 
34
33
  def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
35
- """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
34
+ """Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
36
35
  super().__init__()
37
36
  if cost_gain is None:
38
37
  cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
@@ -45,24 +44,21 @@ class HungarianMatcher(nn.Module):
45
44
 
46
45
  def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
47
46
  """
48
- Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth
49
- (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between
50
- predictions and ground truth based on these costs.
47
+ Forward pass for HungarianMatcher. Computes costs based on prediction and ground truth and finds the optimal
48
+ matching between predictions and ground truth based on these costs.
51
49
 
52
50
  Args:
53
- pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4].
54
- pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes].
55
- gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ].
56
- gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4].
51
+ pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
52
+ pred_scores (torch.Tensor): Predicted scores with shape (batch_size, num_queries, num_classes).
53
+ gt_cls (torch.Tensor): Ground truth classes with shape (num_gts, ).
54
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
57
55
  gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for
58
56
  each image.
59
- masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width].
60
- Defaults to None.
61
- gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width].
62
- Defaults to None.
57
+ masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
58
+ gt_mask (List[torch.Tensor], optional): List of ground truth masks, each with shape (num_masks, Height, Width).
63
59
 
64
60
  Returns:
65
- (List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
61
+ (List[Tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where:
66
62
  - index_i is the tensor of indices of the selected predictions (in order)
67
63
  - index_j is the tensor of indices of the corresponding selected ground truth targets (in order)
68
64
  For each batch element, it holds:
@@ -74,10 +70,10 @@ class HungarianMatcher(nn.Module):
74
70
  return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
75
71
 
76
72
  # We flatten to compute the cost matrices in a batch
77
- # [batch_size * num_queries, num_classes]
73
+ # (batch_size * num_queries, num_classes)
78
74
  pred_scores = pred_scores.detach().view(-1, nc)
79
75
  pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
80
- # [batch_size * num_queries, 4]
76
+ # (batch_size * num_queries, 4)
81
77
  pred_bboxes = pred_bboxes.detach().view(-1, 4)
82
78
 
83
79
  # Compute the classification cost
@@ -151,26 +147,25 @@ def get_cdn_group(
151
147
  batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
152
148
  ):
153
149
  """
154
- Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
155
- and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
156
- and returns the modified labels, bounding boxes, attention mask and meta information.
150
+ Get contrastive denoising training group with positive and negative samples from ground truths.
157
151
 
158
152
  Args:
159
- batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
160
- (torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length
153
+ batch (Dict): A dict that includes 'gt_cls' (torch.Tensor with shape (num_gts, )), 'gt_bboxes'
154
+ (torch.Tensor with shape (num_gts, 4)), 'gt_groups' (List[int]) which is a list of batch size length
161
155
  indicating the number of gts of each image.
162
156
  num_classes (int): Number of classes.
163
157
  num_queries (int): Number of queries.
164
158
  class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
165
- num_dn (int, optional): Number of denoising. Defaults to 100.
166
- cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
167
- box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
168
- training (bool, optional): If it's in training mode. Defaults to False.
159
+ num_dn (int, optional): Number of denoising queries.
160
+ cls_noise_ratio (float, optional): Noise ratio for class labels.
161
+ box_noise_scale (float, optional): Noise scale for bounding box coordinates.
162
+ training (bool, optional): If it's in training mode.
169
163
 
170
164
  Returns:
171
- (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
172
- bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
173
- is less than or equal to 0, the function returns None for all elements in the tuple.
165
+ padding_cls (Optional[torch.Tensor]): The modified class embeddings for denoising.
166
+ padding_bbox (Optional[torch.Tensor]): The modified bounding boxes for denoising.
167
+ attn_mask (Optional[torch.Tensor]): The attention mask for denoising.
168
+ dn_meta (Optional[Dict]): Meta information for denoising.
174
169
  """
175
170
  if (not training) or num_dn <= 0 or batch is None:
176
171
  return None, None, None, None
File: ultralytics/models/yolo/classify/predict.py
@@ -13,6 +13,17 @@ class ClassificationPredictor(BasePredictor):
13
13
  """
14
14
  A class extending the BasePredictor class for prediction based on a classification model.
15
15
 
16
+ This predictor handles the specific requirements of classification models, including preprocessing images
17
+ and postprocessing predictions to generate classification results.
18
+
19
+ Attributes:
20
+ args (Dict): Configuration arguments for the predictor.
21
+ _legacy_transform_name (str): Name of the legacy transform class for backward compatibility.
22
+
23
+ Methods:
24
+ preprocess: Convert input images to model-compatible format.
25
+ postprocess: Process model predictions into Results objects.
26
+
16
27
  Notes:
17
28
  - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
18
29
 
@@ -25,13 +36,13 @@ class ClassificationPredictor(BasePredictor):
25
36
  """
26
37
 
27
38
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
28
- """Initializes ClassificationPredictor setting the task to 'classify'."""
39
+ """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'."""
29
40
  super().__init__(cfg, overrides, _callbacks)
30
41
  self.args.task = "classify"
31
42
  self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
32
43
 
33
44
  def preprocess(self, img):
34
- """Converts input image to model-compatible data type."""
45
+ """Convert input images to model-compatible tensor format with appropriate normalization."""
35
46
  if not isinstance(img, torch.Tensor):
36
47
  is_legacy_transform = any(
37
48
  self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
@@ -46,7 +57,17 @@ class ClassificationPredictor(BasePredictor):
46
57
  return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
47
58
 
48
59
  def postprocess(self, preds, img, orig_imgs):
49
- """Post-processes predictions to return Results objects."""
60
+ """
61
+ Process predictions to return Results objects with classification probabilities.
62
+
63
+ Args:
64
+ preds (torch.Tensor): Raw predictions from the model.
65
+ img (torch.Tensor): Input images after preprocessing.
66
+ orig_imgs (List[np.ndarray] | torch.Tensor): Original images before preprocessing.
67
+
68
+ Returns:
69
+ (List[Results]): List of Results objects containing classification results for each image.
70
+ """
50
71
  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
51
72
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
52
73
 
File: ultralytics/models/yolo/classify/train.py
@@ -17,8 +17,28 @@ class ClassificationTrainer(BaseTrainer):
17
17
  """
18
18
  A class extending the BaseTrainer class for training based on a classification model.
19
19
 
20
- Notes:
21
- - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
20
+ This trainer handles the training process for image classification tasks, supporting both YOLO classification models
21
+ and torchvision models.
22
+
23
+ Attributes:
24
+ model (ClassificationModel): The classification model to be trained.
25
+ data (Dict): Dictionary containing dataset information including class names and number of classes.
26
+ loss_names (List[str]): Names of the loss functions used during training.
27
+ validator (ClassificationValidator): Validator instance for model evaluation.
28
+
29
+ Methods:
30
+ set_model_attributes: Set the model's class names from the loaded dataset.
31
+ get_model: Return a modified PyTorch model configured for training.
32
+ setup_model: Load, create or download model for classification.
33
+ build_dataset: Create a ClassificationDataset instance.
34
+ get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
35
+ preprocess_batch: Preprocess a batch of images and classes.
36
+ progress_string: Return a formatted string showing training progress.
37
+ get_validator: Return an instance of ClassificationValidator.
38
+ label_loss_items: Return a loss dict with labelled training loss items.
39
+ plot_metrics: Plot metrics from a CSV file.
40
+ final_eval: Evaluate trained model and save validation results.
41
+ plot_training_samples: Plot training samples with their annotations.
22
42
 
23
43
  Examples:
24
44
  >>> from ultralytics.models.yolo.classify import ClassificationTrainer
@@ -41,7 +61,17 @@ class ClassificationTrainer(BaseTrainer):
41
61
  self.model.names = self.data["names"]
42
62
 
43
63
  def get_model(self, cfg=None, weights=None, verbose=True):
44
- """Returns a modified PyTorch model configured for training YOLO."""
64
+ """
65
+ Return a modified PyTorch model configured for training YOLO.
66
+
67
+ Args:
68
+ cfg (Any): Model configuration.
69
+ weights (Any): Pre-trained model weights.
70
+ verbose (bool): Whether to display model information.
71
+
72
+ Returns:
73
+ (ClassificationModel): Configured PyTorch model for classification.
74
+ """
45
75
  model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
46
76
  if weights:
47
77
  model.load(weights)
@@ -56,7 +86,12 @@ class ClassificationTrainer(BaseTrainer):
56
86
  return model
57
87
 
58
88
  def setup_model(self):
59
- """Load, create or download model for any task."""
89
+ """
90
+ Load, create or download model for classification tasks.
91
+
92
+ Returns:
93
+ (Any): Model checkpoint if applicable, otherwise None.
94
+ """
60
95
  import torchvision # scope for faster 'import ultralytics'
61
96
 
62
97
  if str(self.model) in torchvision.models.__dict__:
@@ -70,11 +105,32 @@ class ClassificationTrainer(BaseTrainer):
70
105
  return ckpt
71
106
 
72
107
  def build_dataset(self, img_path, mode="train", batch=None):
73
- """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.)."""
108
+ """
109
+ Create a ClassificationDataset instance given an image path and mode.
110
+
111
+ Args:
112
+ img_path (str): Path to the dataset images.
113
+ mode (str): Dataset mode ('train', 'val', or 'test').
114
+ batch (Any): Batch information (unused in this implementation).
115
+
116
+ Returns:
117
+ (ClassificationDataset): Dataset for the specified mode.
118
+ """
74
119
  return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
75
120
 
76
121
  def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
77
- """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
122
+ """
123
+ Return PyTorch DataLoader with transforms to preprocess images.
124
+
125
+ Args:
126
+ dataset_path (str): Path to the dataset.
127
+ batch_size (int): Number of images per batch.
128
+ rank (int): Process rank for distributed training.
129
+ mode (str): 'train', 'val', or 'test' mode.
130
+
131
+ Returns:
132
+ (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
133
+ """
78
134
  with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
79
135
  dataset = self.build_dataset(dataset_path, mode)
80
136
 
@@ -112,9 +168,14 @@ class ClassificationTrainer(BaseTrainer):
112
168
 
113
169
  def label_loss_items(self, loss_items=None, prefix="train"):
114
170
  """
115
- Returns a loss dict with labelled training loss items tensor.
171
+ Return a loss dict with labelled training loss items tensor.
116
172
 
117
- Not needed for classification but necessary for segmentation & detection
173
+ Args:
174
+ loss_items (torch.Tensor, optional): Loss tensor items.
175
+ prefix (str): Prefix to prepend to loss names.
176
+
177
+ Returns:
178
+ (Dict[str, float] | List[str]): Dictionary of loss items or list of loss keys if loss_items is None.
118
179
  """
119
180
  keys = [f"{prefix}/{x}" for x in self.loss_names]
120
181
  if loss_items is None:
@@ -123,7 +184,7 @@ class ClassificationTrainer(BaseTrainer):
123
184
  return dict(zip(keys, loss_items))
124
185
 
125
186
  def plot_metrics(self):
126
- """Plots metrics from a CSV file."""
187
+ """Plot metrics from a CSV file."""
127
188
  plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
128
189
 
129
190
  def final_eval(self):
@@ -140,7 +201,13 @@ class ClassificationTrainer(BaseTrainer):
140
201
  self.run_callbacks("on_fit_epoch_end")
141
202
 
142
203
  def plot_training_samples(self, batch, ni):
143
- """Plots training samples with their annotations."""
204
+ """
205
+ Plot training samples with their annotations.
206
+
207
+ Args:
208
+ batch (Dict[str, torch.Tensor]): Batch containing images and class labels.
209
+ ni (int): Number of iterations.
210
+ """
144
211
  plot_images(
145
212
  images=batch["img"],
146
213
  batch_idx=torch.arange(len(batch["img"])),