dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their public registry.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/models/utils/loss.py

@@ -15,11 +15,10 @@ from .ops import HungarianMatcher
 
 
 class DETRLoss(nn.Module):
-    """
-    DETR (DEtection TRansformer) Loss class for calculating various loss components.
+    """DETR (DEtection TRansformer) Loss class for calculating various loss components.
 
-    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
-    DETR object detection model.
+    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the DETR
+    object detection model.
 
     Attributes:
         nc (int): Number of classes.
@@ -47,8 +46,7 @@ class DETRLoss(nn.Module):
         gamma: float = 1.5,
         alpha: float = 0.25,
     ):
-        """
-        Initialize DETR loss function with customizable components and gains.
+        """Initialize DETR loss function with customizable components and gains.
 
         Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
         losses and various loss types.
@@ -82,8 +80,7 @@ class DETRLoss(nn.Module):
     def _get_loss_class(
         self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
     ) -> dict[str, torch.Tensor]:
-        """
-        Compute classification loss based on predictions, target values, and ground truth scores.
+        """Compute classification loss based on predictions, target values, and ground truth scores.
 
         Args:
             pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
@@ -124,8 +121,7 @@ class DETRLoss(nn.Module):
     def _get_loss_bbox(
         self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
     ) -> dict[str, torch.Tensor]:
-        """
-        Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
+        """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
@@ -199,8 +195,7 @@ class DETRLoss(nn.Module):
         masks: torch.Tensor | None = None,
         gt_mask: torch.Tensor | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Get auxiliary losses for intermediate decoder layers.
+        """Get auxiliary losses for intermediate decoder layers.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
@@ -261,8 +256,7 @@ class DETRLoss(nn.Module):
 
     @staticmethod
     def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
-        """
-        Extract batch indices, source indices, and destination indices from match indices.
+        """Extract batch indices, source indices, and destination indices from match indices.
 
         Args:
             match_indices (list[tuple]): List of tuples containing matched indices.
@@ -279,8 +273,7 @@ class DETRLoss(nn.Module):
     def _get_assigned_bboxes(
         self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
+        """Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes.
@@ -317,8 +310,7 @@ class DETRLoss(nn.Module):
         postfix: str = "",
         match_indices: list[tuple] | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Calculate losses for a single prediction layer.
+        """Calculate losses for a single prediction layer.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes.
@@ -364,8 +356,7 @@ class DETRLoss(nn.Module):
         postfix: str = "",
         **kwargs: Any,
     ) -> dict[str, torch.Tensor]:
-        """
-        Calculate loss for predicted bounding boxes and scores.
+        """Calculate loss for predicted bounding boxes and scores.
 
         Args:
             pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
@@ -400,8 +391,7 @@ class DETRLoss(nn.Module):
 
 
 class RTDETRDetectionLoss(DETRLoss):
-    """
-    Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
+    """Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.
 
     This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
     an additional denoising training loss when provided with denoising metadata.
@@ -415,8 +405,7 @@ class RTDETRDetectionLoss(DETRLoss):
         dn_scores: torch.Tensor | None = None,
         dn_meta: dict[str, Any] | None = None,
     ) -> dict[str, torch.Tensor]:
-        """
-        Forward pass to compute detection loss with optional denoising loss.
+        """Forward pass to compute detection loss with optional denoising loss.
 
         Args:
             preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
@@ -452,8 +441,7 @@ class RTDETRDetectionLoss(DETRLoss):
     def get_dn_match_indices(
         dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
     ) -> list[tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Get match indices for denoising.
+        """Get match indices for denoising.
 
         Args:
             dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
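For orientation, here is a minimal sketch (not the package's implementation) of the bounding-box side of a DETR-style loss: an L1 term plus a GIoU term computed over already-matched prediction/target pairs. The gain values and box format are illustrative assumptions.

import torch
import torch.nn.functional as F
from torchvision.ops import generalized_box_iou  # expects boxes in xyxy format

def bbox_loss(pred_xyxy: torch.Tensor, gt_xyxy: torch.Tensor, l1_gain=5.0, giou_gain=2.0) -> torch.Tensor:
    """Mean L1 + (1 - GIoU) over matched pairs; gain values here are illustrative."""
    l1 = F.l1_loss(pred_xyxy, gt_xyxy, reduction="mean")
    giou = generalized_box_iou(pred_xyxy, gt_xyxy).diag().mean()  # diagonal = matched pairs
    return l1_gain * l1 + giou_gain * (1.0 - giou)

pred = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]])
gt = torch.tensor([[1.0, 1.0, 11.0, 11.0], [4.0, 6.0, 21.0, 19.0]])
print(bbox_loss(pred, gt))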
ultralytics/models/utils/ops.py

@@ -14,8 +14,7 @@ from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
 
 
 class HungarianMatcher(nn.Module):
-    """
-    A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
+    """A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
 
     HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
     function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
@@ -56,8 +55,7 @@ class HungarianMatcher(nn.Module):
         alpha: float = 0.25,
         gamma: float = 2.0,
     ):
-        """
-        Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
+        """Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
 
         Args:
             cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
@@ -88,8 +86,7 @@ class HungarianMatcher(nn.Module):
         masks: torch.Tensor | None = None,
         gt_mask: list[torch.Tensor] | None = None,
     ) -> list[tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
+        """Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
 
         This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
         mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
@@ -105,10 +102,10 @@ class HungarianMatcher(nn.Module):
             gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
 
         Returns:
-            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
-                (index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)
-                and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
-                For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
+            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple (index_i,
+                index_j), where index_i is the tensor of indices of the selected predictions (in order) and index_j is
+                the tensor of indices of the corresponding selected ground truth targets (in order).
+                For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
         """
         bs, nq, nc = pred_scores.shape
 
@@ -198,16 +195,15 @@ def get_cdn_group(
     box_noise_scale: float = 1.0,
     training: bool = False,
 ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
-    """
-    Generate contrastive denoising training group with positive and negative samples from ground truths.
+    """Generate contrastive denoising training group with positive and negative samples from ground truths.
 
-    This function creates denoising queries for contrastive denoising training by adding noise to ground truth
-    bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.
+    This function creates denoising queries for contrastive denoising training by adding noise to ground truth bounding
+    boxes and class labels. It generates both positive and negative samples to improve model robustness.
 
     Args:
-        batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
-            'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of
-            ground truths per image.
+        batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)), 'gt_bboxes'
+            (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of ground truths
+            per image.
         num_classes (int): Total number of object classes.
        num_queries (int): Number of object queries.
        class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
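As a reference point for the HungarianMatcher docstrings above, a hedged sketch of the assignment step for a single image using scipy's linear_sum_assignment. The sigmoid class cost, the L1-only box cost, and the weights are illustrative simplifications, not the package's exact cost function.

import torch
from scipy.optimize import linear_sum_assignment

def match_one_image(pred_scores, pred_bboxes, gt_cls, gt_bboxes, cost_class=1.0, cost_bbox=5.0):
    """Return (pred_idx, gt_idx) minimizing a combined class + L1 box cost."""
    prob = pred_scores.sigmoid()                       # (num_queries, num_classes)
    c_class = -prob[:, gt_cls]                         # (num_queries, num_gt): high prob -> low cost
    c_bbox = torch.cdist(pred_bboxes, gt_bboxes, p=1)  # (num_queries, num_gt): L1 box distance
    cost = cost_class * c_class + cost_bbox * c_bbox
    pred_idx, gt_idx = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.as_tensor(pred_idx), torch.as_tensor(gt_idx)

nq, nc, ngt = 8, 80, 3  # toy sizes
idx_pred, idx_gt = match_one_image(
    torch.randn(nq, nc), torch.rand(nq, 4), torch.randint(0, nc, (ngt,)), torch.rand(ngt, 4)
)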
ultralytics/models/yolo/classify/predict.py

@@ -11,11 +11,10 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class ClassificationPredictor(BasePredictor):
-    """
-    A class extending the BasePredictor class for prediction based on a classification model.
+    """A class extending the BasePredictor class for prediction based on a classification model.
 
-    This predictor handles the specific requirements of classification models, including preprocessing images
-    and postprocessing predictions to generate classification results.
+    This predictor handles the specific requirements of classification models, including preprocessing images and
+    postprocessing predictions to generate classification results.
 
     Attributes:
         args (dict): Configuration arguments for the predictor.
@@ -24,20 +23,19 @@ class ClassificationPredictor(BasePredictor):
         preprocess: Convert input images to model-compatible format.
         postprocess: Process model predictions into Results objects.
 
-    Notes:
-        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
-
 
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.classify import ClassificationPredictor
         >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
         >>> predictor = ClassificationPredictor(overrides=args)
         >>> predictor.predict_cli()
+
+    Notes:
+        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
+        """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
 
         This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
         tasks. It ensures the task is set to 'classify' regardless of input configuration.
@@ -72,8 +70,7 @@ class ClassificationPredictor(BasePredictor):
         return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Process predictions to return Results objects with classification probabilities.
+        """Process predictions to return Results objects with classification probabilities.
 
         Args:
             preds (torch.Tensor): Raw predictions from the model.
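Typical use of this predictor goes through the high-level YOLO API; a hedged sketch in which the weight file and image name are placeholders:

from ultralytics import YOLO

model = YOLO("yolo11n-cls.pt")        # classification checkpoint
results = model.predict("image.jpg")  # runs ClassificationPredictor under the hood
probs = results[0].probs              # classification probabilities
print(probs.top1, probs.top1conf)     # top-1 class index and its confidence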
ultralytics/models/yolo/classify/train.py

@@ -17,8 +17,7 @@ from ultralytics.utils.torch_utils import is_parallel, torch_distributed_zero_fi
 
 
 class ClassificationTrainer(BaseTrainer):
-    """
-    A trainer class extending BaseTrainer for training image classification models.
+    """A trainer class extending BaseTrainer for training image classification models.
 
     This trainer handles the training process for image classification tasks, supporting both YOLO classification models
     and torchvision models with comprehensive dataset handling and validation.
@@ -51,8 +50,7 @@ class ClassificationTrainer(BaseTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a ClassificationTrainer object.
+        """Initialize a ClassificationTrainer object.
 
         Args:
             cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
@@ -71,8 +69,7 @@ class ClassificationTrainer(BaseTrainer):
         self.model.names = self.data["names"]
 
     def get_model(self, cfg=None, weights=None, verbose: bool = True):
-        """
-        Return a modified PyTorch model configured for training YOLO classification.
+        """Return a modified PyTorch model configured for training YOLO classification.
 
         Args:
             cfg (Any, optional): Model configuration.
@@ -96,8 +93,7 @@ class ClassificationTrainer(BaseTrainer):
         return model
 
     def setup_model(self):
-        """
-        Load, create or download model for classification tasks.
+        """Load, create or download model for classification tasks.
 
         Returns:
             (Any): Model checkpoint if applicable, otherwise None.
@@ -115,8 +111,7 @@ class ClassificationTrainer(BaseTrainer):
         return ckpt
 
     def build_dataset(self, img_path: str, mode: str = "train", batch=None):
-        """
-        Create a ClassificationDataset instance given an image path and mode.
+        """Create a ClassificationDataset instance given an image path and mode.
 
         Args:
             img_path (str): Path to the dataset images.
@@ -129,8 +124,7 @@ class ClassificationTrainer(BaseTrainer):
         return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
 
     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
-        """
-        Return PyTorch DataLoader with transforms to preprocess images.
+        """Return PyTorch DataLoader with transforms to preprocess images.
 
         Args:
             dataset_path (str): Path to the dataset.
@@ -177,8 +171,7 @@ class ClassificationTrainer(BaseTrainer):
         )
 
     def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
-        """
-        Return a loss dict with labeled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
         Args:
             loss_items (torch.Tensor, optional): Loss tensor items.
@@ -195,8 +188,7 @@ class ClassificationTrainer(BaseTrainer):
         return dict(zip(keys, loss_items))
 
     def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
-        """
-        Plot training samples with their annotations.
+        """Plot training samples with their annotations.
 
         Args:
             batch (dict[str, torch.Tensor]): Batch containing images and class labels.
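A hedged training sketch for the trainer above, via the high-level API; the dataset name and hyperparameters are illustrative:

from ultralytics import YOLO

model = YOLO("yolo11n-cls.pt")
# ClassificationTrainer handles dataset building, dataloaders, and model setup internally.
model.train(data="mnist160", epochs=3, imgsz=64)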
ultralytics/models/yolo/classify/val.py

@@ -16,11 +16,10 @@ from ultralytics.utils.plotting import plot_images
 
 
 class ClassificationValidator(BaseValidator):
-    """
-    A class extending the BaseValidator class for validation based on a classification model.
+    """A class extending the BaseValidator class for validation based on a classification model.
 
-    This validator handles the validation process for classification models, including metrics calculation,
-    confusion matrix generation, and visualization of results.
+    This validator handles the validation process for classification models, including metrics calculation, confusion
+    matrix generation, and visualization of results.
 
     Attributes:
         targets (list[torch.Tensor]): Ground truth class labels.
@@ -55,8 +54,7 @@ class ClassificationValidator(BaseValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize ClassificationValidator with dataloader, save directory, and other parameters.
+        """Initialize ClassificationValidator with dataloader, save directory, and other parameters.
 
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
@@ -96,8 +94,7 @@ class ClassificationValidator(BaseValidator):
         return batch
 
     def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
-        """
-        Update running metrics with model predictions and batch targets.
+        """Update running metrics with model predictions and batch targets.
 
         Args:
             preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
@@ -112,12 +109,7 @@ class ClassificationValidator(BaseValidator):
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
     def finalize_metrics(self) -> None:
-        """
-        Finalize metrics including confusion matrix and processing speed.
-
-        Notes:
-            This method processes the accumulated predictions and targets to generate the confusion matrix,
-            optionally plots it, and updates the metrics object with speed information.
+        """Finalize metrics including confusion matrix and processing speed.
 
         Examples:
             >>> validator = ClassificationValidator()
@@ -125,6 +117,10 @@ class ClassificationValidator(BaseValidator):
            >>> validator.targets = [torch.tensor([0])]  # Ground truth class
            >>> validator.finalize_metrics()
            >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
+
+        Notes:
+            This method processes the accumulated predictions and targets to generate the confusion matrix,
+            optionally plots it, and updates the metrics object with speed information.
         """
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
@@ -161,8 +157,7 @@ class ClassificationValidator(BaseValidator):
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
-        """
-        Build and return a data loader for classification validation.
+        """Build and return a data loader for classification validation.
 
         Args:
             dataset_path (str | Path): Path to the dataset directory.
@@ -180,8 +175,7 @@ class ClassificationValidator(BaseValidator):
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
     def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot validation image samples with their ground truth labels.
+        """Plot validation image samples with their ground truth labels.
 
         Args:
             batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
@@ -201,8 +195,7 @@ class ClassificationValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
-        """
-        Plot images with their predicted class labels and save the visualization.
+        """Plot images with their predicted class labels and save the visualization.
 
         Args:
             batch (dict[str, Any]): Batch data containing images and other information.
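And a hedged validation sketch matching the validator above; top1/top5 are the accuracy values it reports, and the dataset name is illustrative:

from ultralytics import YOLO

model = YOLO("yolo11n-cls.pt")
metrics = model.val(data="mnist160", imgsz=64)  # runs ClassificationValidator
print(metrics.top1, metrics.top5)               # top-1 / top-5 accuracy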
ultralytics/models/yolo/detect/predict.py

@@ -6,8 +6,7 @@ from ultralytics.utils import nms, ops
 
 
 class DetectionPredictor(BasePredictor):
-    """
-    A class extending the BasePredictor class for prediction based on a detection model.
+    """A class extending the BasePredictor class for prediction based on a detection model.
 
     This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
     with bounding boxes and class predictions.
@@ -32,8 +31,7 @@ class DetectionPredictor(BasePredictor):
     """
 
     def postprocess(self, preds, img, orig_imgs, **kwargs):
-        """
-        Post-process predictions and return a list of Results objects.
+        """Post-process predictions and return a list of Results objects.
 
         This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
         further analysis.
@@ -92,8 +90,7 @@ class DetectionPredictor(BasePredictor):
         return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch
 
     def construct_results(self, preds, img, orig_imgs):
-        """
-        Construct a list of Results objects from model predictions.
+        """Construct a list of Results objects from model predictions.
 
         Args:
             preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
@@ -109,8 +106,7 @@ class DetectionPredictor(BasePredictor):
         ]
 
     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct a single Results object from one image prediction.
+        """Construct a single Results object from one image prediction.
 
         Args:
             pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
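For the detection predictor, a hedged usage sketch showing the Results fields its postprocess step populates (weight and image names are placeholders):

from ultralytics import YOLO

model = YOLO("yolo11n.pt")
results = model.predict("image.jpg", conf=0.25)  # DetectionPredictor applies NMS in postprocess
boxes = results[0].boxes
print(boxes.xyxy, boxes.conf, boxes.cls)         # xyxy boxes, confidences, class indices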
ultralytics/models/yolo/detect/train.py

@@ -22,11 +22,10 @@ from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_m
 
 
 class DetectionTrainer(BaseTrainer):
-    """
-    A class extending the BaseTrainer class for training based on a detection model.
+    """A class extending the BaseTrainer class for training based on a detection model.
 
-    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
-    for object detection including dataset building, data loading, preprocessing, and model configuration.
+    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models for
+    object detection including dataset building, data loading, preprocessing, and model configuration.
 
     Attributes:
         model (DetectionModel): The YOLO detection model being trained.
@@ -54,8 +53,7 @@ class DetectionTrainer(BaseTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a DetectionTrainer object for training YOLO object detection model training.
+        """Initialize a DetectionTrainer object for training YOLO object detection model training.
 
         Args:
             cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -65,8 +63,7 @@ class DetectionTrainer(BaseTrainer):
         super().__init__(cfg, overrides, _callbacks)
 
     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -80,8 +77,7 @@ class DetectionTrainer(BaseTrainer):
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
 
     def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
-        """
-        Construct and return dataloader for the specified mode.
+        """Construct and return dataloader for the specified mode.
 
         Args:
             dataset_path (str): Path to the dataset.
@@ -109,8 +105,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def preprocess_batch(self, batch: dict) -> dict:
-        """
-        Preprocess a batch of images by scaling and converting to float.
+        """Preprocess a batch of images by scaling and converting to float.
 
         Args:
             batch (dict): Dictionary containing batch data with 'img' tensor.
@@ -150,8 +145,7 @@ class DetectionTrainer(BaseTrainer):
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
 
     def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
-        """
-        Return a YOLO detection model.
+        """Return a YOLO detection model.
 
         Args:
             cfg (str, optional): Path to model configuration file.
@@ -174,8 +168,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
-        """
-        Return a loss dict with labeled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
         Args:
             loss_items (list[float], optional): List of loss values.
@@ -202,8 +195,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
-        """
-        Plot training samples with their annotations.
+        """Plot training samples with their annotations.
 
         Args:
             batch (dict[str, Any]): Dictionary containing batch data.
@@ -223,8 +215,7 @@ class DetectionTrainer(BaseTrainer):
         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
 
     def auto_batch(self):
-        """
-        Get optimal batch size by calculating memory occupation of model.
+        """Get optimal batch size by calculating memory occupation of model.
 
         Returns:
             (int): Optimal batch size.
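Finally, a hedged sketch of driving the detection trainer directly, in the style of the upstream docstring examples; coco8.yaml is the small demo dataset shipped with the package, the other values are illustrative:

from ultralytics.models.yolo.detect import DetectionTrainer

args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=1, imgsz=640)
trainer = DetectionTrainer(overrides=args)
trainer.train()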