ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/classify/val.py

@@ -13,18 +13,43 @@ class ClassificationValidator(BaseValidator):
     """
     A class extending the BaseValidator class for validation based on a classification model.
 
-    Notes:
-        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
+    This validator handles the validation process for classification models, including metrics calculation,
+    confusion matrix generation, and visualization of results.
+
+    Attributes:
+        targets (List[torch.Tensor]): Ground truth class labels.
+        pred (List[torch.Tensor]): Model predictions.
+        metrics (ClassifyMetrics): Object to calculate and store classification metrics.
+        names (Dict): Mapping of class indices to class names.
+        nc (int): Number of classes.
+        confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
+
+    Methods:
+        get_desc: Return a formatted string summarizing classification metrics.
+        init_metrics: Initialize confusion matrix, class names, and tracking containers.
+        preprocess: Preprocess input batch by moving data to device.
+        update_metrics: Update running metrics with model predictions and batch targets.
+        finalize_metrics: Finalize metrics including confusion matrix and processing speed.
+        postprocess: Extract the primary prediction from model output.
+        get_stats: Calculate and return a dictionary of metrics.
+        build_dataset: Create a ClassificationDataset instance for validation.
+        get_dataloader: Build and return a data loader for classification validation.
+        print_results: Print evaluation metrics for the classification model.
+        plot_val_samples: Plot validation image samples with their ground truth labels.
+        plot_predictions: Plot images with their predicted class labels.
 
     Examples:
         >>> from ultralytics.models.yolo.classify import ClassificationValidator
         >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
         >>> validator = ClassificationValidator(args=args)
         >>> validator()
+
+    Notes:
+        Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
+        """Initialize ClassificationValidator with dataloader, save directory, and other parameters."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.targets = None
         self.pred = None
@@ -32,11 +57,11 @@ class ClassificationValidator(BaseValidator):
         self.metrics = ClassifyMetrics()
 
     def get_desc(self):
-        """Returns a formatted string summarizing classification metrics."""
+        """Return a formatted string summarizing classification metrics."""
        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
 
     def init_metrics(self, model):
-        """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
+        """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
         self.names = model.names
         self.nc = len(model.names)
         self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
@@ -44,20 +69,20 @@ class ClassificationValidator(BaseValidator):
         self.targets = []
 
     def preprocess(self, batch):
-        """Preprocesses input batch and returns it."""
+        """Preprocess input batch by moving data to device and converting to appropriate dtype."""
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
         batch["cls"] = batch["cls"].to(self.device)
         return batch
 
     def update_metrics(self, preds, batch):
-        """Updates running metrics with model predictions and batch targets."""
+        """Update running metrics with model predictions and batch targets."""
         n5 = min(len(self.names), 5)
         self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
         self.targets.append(batch["cls"].type(torch.int32).cpu())
 
     def finalize_metrics(self, *args, **kwargs):
-        """Finalizes metrics of the model such as confusion_matrix and speed."""
+        """Finalize metrics including confusion matrix and processing speed."""
         self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         if self.args.plots:
             for normalize in True, False:
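
The `update_metrics` hunk above keeps per-batch top-5 class indices via `argsort`. A minimal sketch of that extraction in isolation, using a random logits tensor in place of real model output:

```python
import torch

logits = torch.randn(4, 1000)  # fake model output: 4 images, 1000 classes
n5 = min(logits.shape[1], 5)
# Sort class indices by descending score and keep the top-5 per image,
# matching the tensor stored by ClassificationValidator.update_metrics
top5 = logits.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu()
print(top5.shape)  # torch.Size([4, 5])
```
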
@@ -69,30 +94,30 @@ class ClassificationValidator(BaseValidator):
         self.metrics.save_dir = self.save_dir
 
     def postprocess(self, preds):
-        """Preprocesses the classification predictions."""
+        """Extract the primary prediction from model output if it's in a list or tuple format."""
         return preds[0] if isinstance(preds, (list, tuple)) else preds
 
     def get_stats(self):
-        """Returns a dictionary of metrics obtained by processing targets and predictions."""
+        """Calculate and return a dictionary of metrics by processing targets and predictions."""
         self.metrics.process(self.targets, self.pred)
         return self.metrics.results_dict
 
     def build_dataset(self, img_path):
-        """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters."""
+        """Create a ClassificationDataset instance for validation."""
         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path, batch_size):
-        """Builds and returns a data loader for classification tasks with given parameters."""
+        """Build and return a data loader for classification validation."""
         dataset = self.build_dataset(dataset_path)
         return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
 
     def print_results(self):
-        """Prints evaluation metrics for YOLO object detection model."""
+        """Print evaluation metrics for the classification model."""
         pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
 
     def plot_val_samples(self, batch, ni):
-        """Plot validation image samples."""
+        """Plot validation image samples with their ground truth labels."""
         plot_images(
             images=batch["img"],
             batch_idx=torch.arange(len(batch["img"])),
@@ -103,7 +128,7 @@ class ClassificationValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """Plots predicted bounding boxes on input images and saves the result."""
+        """Plot images with their predicted class labels and save the visualization."""
         plot_images(
             batch["img"],
             batch_idx=torch.arange(len(batch["img"])),
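
Taken together, the class is driven as its docstring's Examples block shows; a minimal standalone run (model and dataset names from the docstring, `top1`/`top5` read from the same metrics object that `print_results` logs):

```python
from ultralytics.models.yolo.classify import ClassificationValidator

# Run validation exactly as the class docstring's Examples section shows
args = dict(model="yolo11n-cls.pt", data="imagenet10")
validator = ClassificationValidator(args=args)
validator()

# print_results() has logged "all  top1  top5"; the same values
# remain available on the metrics object afterwards
print(validator.metrics.top1, validator.metrics.top5)
```
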
ultralytics/models/yolo/detect/predict.py

@@ -9,6 +9,19 @@ class DetectionPredictor(BasePredictor):
     """
     A class extending the BasePredictor class for prediction based on a detection model.
 
+    This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
+    with bounding boxes and class predictions.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (nn.Module): The detection model used for inference.
+        batch (List): Batch of images and metadata for processing.
+
+    Methods:
+        postprocess: Process raw model predictions into detection results.
+        construct_results: Build Results objects from processed predictions.
+        construct_result: Create a single Result object from a prediction.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.detect import DetectionPredictor
@@ -38,15 +51,15 @@ class DetectionPredictor(BasePredictor):
 
     def construct_results(self, preds, img, orig_imgs):
         """
-        Constructs a list of result objects from the predictions.
+        Construct a list of Results objects from model predictions.
 
         Args:
-            preds (List[torch.Tensor]): List of predicted bounding boxes and scores.
-            img (torch.Tensor): The image after preprocessing.
+            preds (List[torch.Tensor]): List of predicted bounding boxes and scores for each image.
+            img (torch.Tensor): Batch of preprocessed images used for inference.
             orig_imgs (List[np.ndarray]): List of original images before preprocessing.
 
         Returns:
-            (list): List of result objects containing the original images, image paths, class names, and bounding boxes.
+            (List[Results]): List of Results objects containing detection information for each image.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path)
@@ -55,16 +68,16 @@ class DetectionPredictor(BasePredictor):
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-        Constructs the result object from the prediction.
+        Construct a single Results object from one image prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes and scores.
-            img (torch.Tensor): The image after preprocessing.
-            orig_img (np.ndarray): The original image before preprocessing.
-            img_path (str): The path to the original image.
+            pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
+            img (torch.Tensor): Preprocessed image tensor used for inference.
+            orig_img (np.ndarray): Original image before preprocessing.
+            img_path (str): Path to the original image file.
 
         Returns:
-            (Results): The result object containing the original image, image path, class names, and bounding boxes.
+            (Results): Results object containing the original image, image path, class names, and scaled bounding boxes.
         """
         pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
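
`construct_result` maps boxes from the letterboxed inference size back into original-image pixel space. A sketch of that rescaling step in isolation (the box values and shapes are invented for illustration; `ops.scale_boxes` is the real helper the method calls):

```python
import torch
from ultralytics.utils import ops

# One fake detection in 640x640 inference space: x1, y1, x2, y2, conf, cls
pred = torch.tensor([[100.0, 120.0, 300.0, 360.0, 0.91, 0.0]])

# scale_boxes(img1_shape, boxes, img0_shape) undoes the resize/letterbox padding,
# returning boxes in the coordinates of the original (480, 720) image
pred[:, :4] = ops.scale_boxes((640, 640), pred[:, :4], (480, 720))
print(pred[0, :4])
```
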
ultralytics/models/yolo/detect/train.py

@@ -20,6 +20,28 @@ class DetectionTrainer(BaseTrainer):
     """
     A class extending the BaseTrainer class for training based on a detection model.
 
+    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
+    for object detection.
+
+    Attributes:
+        model (DetectionModel): The YOLO detection model being trained.
+        data (Dict): Dictionary containing dataset information including class names and number of classes.
+        loss_names (Tuple[str]): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
+
+    Methods:
+        build_dataset: Build YOLO dataset for training or validation.
+        get_dataloader: Construct and return dataloader for the specified mode.
+        preprocess_batch: Preprocess a batch of images by scaling and converting to float.
+        set_model_attributes: Set model attributes based on dataset information.
+        get_model: Return a YOLO detection model.
+        get_validator: Return a validator for model evaluation.
+        label_loss_items: Return a loss dictionary with labeled training loss items.
+        progress_string: Return a formatted string of training progress.
+        plot_training_samples: Plot training samples with their annotations.
+        plot_metrics: Plot metrics from a CSV file.
+        plot_training_labels: Create a labeled training plot of the YOLO model.
+        auto_batch: Calculate optimal batch size based on model memory requirements.
+
     Examples:
         >>> from ultralytics.models.yolo.detect import DetectionTrainer
         >>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
@@ -29,18 +51,32 @@ class DetectionTrainer(BaseTrainer):
 
     def build_dataset(self, img_path, mode="train", batch=None):
         """
-        Build YOLO Dataset.
+        Build YOLO Dataset for training or validation.
 
         Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (Dataset): YOLO dataset object configured for the specified mode.
         """
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """Construct and return dataloader."""
+        """
+        Construct and return dataloader for the specified mode.
+
+        Args:
+            dataset_path (str): Path to the dataset.
+            batch_size (int): Number of images per batch.
+            rank (int): Process rank for distributed training.
+            mode (str): 'train' for training dataloader, 'val' for validation dataloader.
+
+        Returns:
+            (DataLoader): PyTorch dataloader object.
+        """
         assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
             dataset = self.build_dataset(dataset_path, mode, batch_size)
@@ -52,7 +88,15 @@ class DetectionTrainer(BaseTrainer):
         return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
 
     def preprocess_batch(self, batch):
-        """Preprocesses a batch of images by scaling and converting to float."""
+        """
+        Preprocess a batch of images by scaling and converting to float.
+
+        Args:
+            batch (Dict): Dictionary containing batch data with 'img' tensor.
+
+        Returns:
+            (Dict): Preprocessed batch with normalized images.
+        """
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         if self.args.multi_scale:
             imgs = batch["img"]
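
The multi-scale branch that follows this hunk resizes each batch to a random, stride-aligned size. A self-contained sketch of the idea (the 0.5x–1.5x range and stride rounding follow standard YOLO multi-scale training; the exact in-repo arithmetic may differ slightly):

```python
import random
import torch
import torch.nn.functional as F

imgs = torch.rand(8, 3, 640, 640)  # already scaled to [0, 1] by the /255 above
imgsz, gs = 640, 32  # training image size and max model stride

# Pick a random size in [0.5*imgsz, 1.5*imgsz], rounded down to a stride multiple
sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs
if sz != imgs.shape[-1]:
    imgs = F.interpolate(imgs, size=sz, mode="bilinear", align_corners=False)
print(imgs.shape)  # e.g. torch.Size([8, 3, 416, 416])
```
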
@@ -71,7 +115,8 @@ class DetectionTrainer(BaseTrainer):
         return batch
 
     def set_model_attributes(self):
-        """Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)."""
+        """Set model attributes based on dataset information."""
+        # Nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
         # self.args.box *= 3 / nl  # scale to layers
         # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
         # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
@@ -81,14 +126,24 @@ class DetectionTrainer(BaseTrainer):
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
 
     def get_model(self, cfg=None, weights=None, verbose=True):
-        """Return a YOLO detection model."""
+        """
+        Return a YOLO detection model.
+
+        Args:
+            cfg (str, optional): Path to model configuration file.
+            weights (str, optional): Path to model weights.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (DetectionModel): YOLO detection model.
+        """
         model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
         return model
 
     def get_validator(self):
-        """Returns a DetectionValidator for YOLO model validation."""
+        """Return a DetectionValidator for YOLO model validation."""
         self.loss_names = "box_loss", "cls_loss", "dfl_loss"
         return yolo.detect.DetectionValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
@@ -96,9 +151,14 @@ class DetectionTrainer(BaseTrainer):
 
     def label_loss_items(self, loss_items=None, prefix="train"):
         """
-        Returns a loss dict with labelled training loss items tensor.
+        Return a loss dict with labeled training loss items tensor.
 
-        Not needed for classification but necessary for segmentation & detection
+        Args:
+            loss_items (List[float], optional): List of loss values.
+            prefix (str): Prefix for keys in the returned dictionary.
+
+        Returns:
+            (Dict | List): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
         """
         keys = [f"{prefix}/{x}" for x in self.loss_names]
         if loss_items is not None:
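
The labeling scheme is a simple prefix join over `loss_names`. A sketch of the round-trip outside the trainer (loss values invented; the rounding mirrors how logged losses are typically truncated, an assumption here):

```python
loss_names = ("box_loss", "cls_loss", "dfl_loss")
loss_items = [0.052, 0.731, 0.914]  # invented per-component loss values

# keys -> ["train/box_loss", "train/cls_loss", "train/dfl_loss"]
keys = [f"train/{x}" for x in loss_names]
labeled = dict(zip(keys, [round(float(x), 5) for x in loss_items]))
print(labeled)  # {'train/box_loss': 0.052, 'train/cls_loss': 0.731, 'train/dfl_loss': 0.914}
```
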
@@ -108,7 +168,7 @@ class DetectionTrainer(BaseTrainer):
         return keys
 
     def progress_string(self):
-        """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
+        """Return a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
         return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
             "Epoch",
             "GPU_mem",
@@ -118,7 +178,13 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def plot_training_samples(self, batch, ni):
-        """Plots training samples with their annotations."""
+        """
+        Plot training samples with their annotations.
+
+        Args:
+            batch (Dict): Dictionary containing batch data.
+            ni (int): Number of iterations.
+        """
         plot_images(
             images=batch["img"],
             batch_idx=batch["batch_idx"],
@@ -130,7 +196,7 @@ class DetectionTrainer(BaseTrainer):
         )
 
     def plot_metrics(self):
-        """Plots metrics from a CSV file."""
+        """Plot metrics from a CSV file."""
         plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
 
     def plot_training_labels(self):
@@ -140,8 +206,12 @@ class DetectionTrainer(BaseTrainer):
         plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
 
     def auto_batch(self):
-        """Get batch size by calculating memory occupation of model."""
+        """
+        Get optimal batch size by calculating memory occupation of model.
+
+        Returns:
+            (int): Optimal batch size.
+        """
         train_dataset = self.build_dataset(self.trainset, mode="train", batch=16)
-        # 4 for mosaic augmentation
-        max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4
+        max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4  # 4 for mosaic augmentation
         return super().auto_batch(max_num_obj)
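
End to end, the trainer is driven exactly as its docstring example shows; a minimal run (model and dataset names from the docstring):

```python
from ultralytics.models.yolo.detect import DetectionTrainer

# Train for a few epochs on the small coco8 sample dataset,
# mirroring the Examples block in the class docstring
args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()
```
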
ultralytics/models/yolo/detect/val.py

@@ -18,6 +18,22 @@ class DetectionValidator(BaseValidator):
     """
     A class extending the BaseValidator class for validation based on a detection model.
 
+    This class implements validation functionality specific to object detection tasks, including metrics calculation,
+    prediction processing, and visualization of results.
+
+    Attributes:
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
+        is_coco (bool): Whether the dataset is COCO.
+        is_lvis (bool): Whether the dataset is LVIS.
+        class_map (List): Mapping from model class indices to dataset class indices.
+        metrics (DetMetrics): Object detection metrics calculator.
+        iouv (torch.Tensor): IoU thresholds for mAP calculation.
+        niou (int): Number of IoU thresholds.
+        lb (List): List for storing ground truth labels for hybrid saving.
+        jdict (List): List for storing JSON detection results.
+        stats (Dict): Dictionary for storing statistics during validation.
+
     Examples:
         >>> from ultralytics.models.yolo.detect import DetectionValidator
         >>> args = dict(model="yolo11n.pt", data="coco8.yaml")
@@ -26,7 +42,16 @@ class DetectionValidator(BaseValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initialize detection model with necessary variables and settings."""
+        """
+        Initialize detection validator with necessary variables and settings.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (Dict, optional): Arguments for the validator.
+            _callbacks (List, optional): List of callback functions.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
         self.nt_per_image = None
@@ -45,7 +70,15 @@ class DetectionValidator(BaseValidator):
         )
 
     def preprocess(self, batch):
-        """Preprocesses batch of images for YOLO training."""
+        """
+        Preprocess batch of images for YOLO validation.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+
+        Returns:
+            (Dict): Preprocessed batch.
+        """
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
         for k in ["batch_idx", "cls", "bboxes"]:
@@ -63,7 +96,12 @@ class DetectionValidator(BaseValidator):
         return batch
 
     def init_metrics(self, model):
-        """Initialize evaluation metrics for YOLO."""
+        """
+        Initialize evaluation metrics for YOLO detection validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         val = self.data.get(self.args.split, "")  # validation path
         self.is_coco = (
             isinstance(val, str)
@@ -88,7 +126,15 @@ class DetectionValidator(BaseValidator):
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
     def postprocess(self, preds):
-        """Apply Non-maximum suppression to prediction outputs."""
+        """
+        Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[torch.Tensor]): Processed predictions after NMS.
+        """
         return ops.non_max_suppression(
             preds,
             self.args.conf,
@@ -103,7 +149,16 @@ class DetectionValidator(BaseValidator):
         )
 
     def _prepare_batch(self, si, batch):
-        """Prepares a batch of images and annotations for validation."""
+        """
+        Prepare a batch of images and annotations for validation.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and annotations.
+
+        Returns:
+            (Dict): Prepared batch with processed annotations.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -116,7 +171,16 @@ class DetectionValidator(BaseValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """Prepares a batch of images and annotations for validation."""
+        """
+        Prepare predictions for evaluation against ground truth.
+
+        Args:
+            pred (torch.Tensor): Model predictions.
+            pbatch (Dict): Prepared batch information.
+
+        Returns:
+            (torch.Tensor): Prepared predictions in native space.
+        """
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
@@ -124,7 +188,13 @@ class DetectionValidator(BaseValidator):
         return predn
 
     def update_metrics(self, preds, batch):
-        """Metrics."""
+        """
+        Update metrics with new predictions and ground truth.
+
+        Args:
+            preds (List[torch.Tensor]): List of predictions from the model.
+            batch (Dict): Batch data containing ground truth.
+        """
         for si, pred in enumerate(preds):
             self.seen += 1
             npr = len(pred)
@@ -173,12 +243,23 @@ class DetectionValidator(BaseValidator):
         )
 
     def finalize_metrics(self, *args, **kwargs):
-        """Set final values for metrics speed and confusion matrix."""
+        """
+        Set final values for metrics speed and confusion matrix.
+
+        Args:
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+        """
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
     def get_stats(self):
-        """Returns metrics statistics and results dictionary."""
+        """
+        Calculate and return metrics statistics.
+
+        Returns:
+            (Dict): Dictionary containing metrics results.
+        """
         stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
         self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
         self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
@@ -188,7 +269,7 @@ class DetectionValidator(BaseValidator):
         return self.metrics.results_dict
 
     def print_results(self):
-        """Prints training/validation set metrics per class."""
+        """Print training/validation set metrics per class."""
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
         if self.nt_per_class.sum() == 0:
@@ -220,10 +301,6 @@ class DetectionValidator(BaseValidator):
 
         Returns:
             (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
-
-        Note:
-            The function does not return any value directly usable for metrics calculation. Instead, it provides an
-            intermediate representation used for evaluating predictions against ground truth.
         """
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
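
`_process_batch` builds the (N, 10) correct-prediction matrix from pairwise IoU at the 10 standard mAP thresholds. A sketch of the threshold grid and the IoU comparison (boxes invented; `box_iou` is the real ultralytics helper used above):

```python
import torch
from ultralytics.utils.metrics import box_iou

iouv = torch.linspace(0.5, 0.95, 10)  # the 10 IoU thresholds behind mAP50-95

gt = torch.tensor([[0.0, 0.0, 100.0, 100.0]])     # one ground-truth box (xyxy)
det = torch.tensor([[10.0, 10.0, 110.0, 110.0]])  # one detection (xyxy)

iou = box_iou(gt, det)               # pairwise IoU matrix, shape (num_gt, num_det)
passed = iou > iouv[:, None, None]   # which thresholds each gt/det pair clears
print(iou.item(), passed.squeeze())  # ~0.68 -> True for 0.5/0.55/0.6/0.65, False above
```
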
@@ -235,17 +312,35 @@ class DetectionValidator(BaseValidator):
         Args:
             img_path (str): Path to the folder containing images.
             mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (Dataset): YOLO dataset.
         """
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
 
     def get_dataloader(self, dataset_path, batch_size):
-        """Construct and return dataloader."""
+        """
+        Construct and return dataloader.
+
+        Args:
+            dataset_path (str): Path to the dataset.
+            batch_size (int): Size of each batch.
+
+        Returns:
+            (torch.utils.data.DataLoader): Dataloader for validation.
+        """
         dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
         return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
 
     def plot_val_samples(self, batch, ni):
-        """Plot validation image samples."""
+        """
+        Plot validation image samples.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -258,7 +353,14 @@ class DetectionValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """Plots predicted bounding boxes on input images and saves the result."""
+        """
+        Plot predicted bounding boxes on input images and save the result.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+            preds (List[torch.Tensor]): List of predictions from the model.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             *output_to_target(preds, max_det=self.args.max_det),
@@ -269,7 +371,15 @@ class DetectionValidator(BaseValidator):
         )  # pred
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        """
+        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            save_conf (bool): Whether to save confidence scores.
+            shape (tuple): Shape of the original image.
+            file (Path): File path to save the detections.
+        """
         from ultralytics.engine.results import Results
 
         Results(
@@ -280,7 +390,13 @@ class DetectionValidator(BaseValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
-        """Serialize YOLO predictions to COCO json format."""
+        """
+        Serialize YOLO predictions to COCO json format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            filename (str): Image filename.
+        """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
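
COCO's JSON format stores boxes as [x, y, w, h] anchored at the top-left corner, while `xyxy2xywh` yields a center-based box, so the serialization shifts the center by half the size. A sketch of one record (values and image id invented; the real code derives `image_id` from the filename stem):

```python
import torch
from ultralytics.utils import ops

predn = torch.tensor([[50.0, 40.0, 150.0, 140.0, 0.88, 2.0]])  # x1, y1, x2, y2, conf, cls

box = ops.xyxy2xywh(predn[:, :4])  # center-x, center-y, width, height
box[:, :2] -= box[:, 2:] / 2       # shift center to top-left corner for COCO

record = {
    "image_id": 42,  # invented id for illustration
    "category_id": int(predn[0, 5]),
    "bbox": [round(float(v), 3) for v in box[0]],  # [50.0, 40.0, 100.0, 100.0]
    "score": round(float(predn[0, 4]), 5),
}
print(record)
```
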
@@ -296,7 +412,15 @@ class DetectionValidator(BaseValidator):
         )
 
     def eval_json(self, stats):
-        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        """
+        Evaluate YOLO output in JSON format and return performance statistics.
+
+        Args:
+            stats (Dict): Current statistics dictionary.
+
+        Returns:
+            (Dict): Updated statistics dictionary with COCO/LVIS evaluation results.
+        """
         if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
             anno_json = (
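
In practice this JSON path is exercised through the high-level API; a minimal sketch (dataset name assumed for illustration; `save_json=True` triggers `pred_to_json`, and on COCO/LVIS data the `eval_json` step above then scores the file):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
# save_json=True writes predictions.json into the run directory;
# on a COCO/LVIS dataset eval_json() then evaluates it with the official tools
metrics = model.val(data="coco8.yaml", save_json=True)
print(metrics.box.map)  # mAP50-95
```
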