ultralytics-8.3.89-py3-none-any.whl → ultralytics-8.3.91-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
Files changed (156)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py

@@ -93,7 +93,7 @@ class YOLOWorld(Model):
 
     def set_classes(self, classes):
         """
-        Set classes.
+        Set the model's class names for detection.
 
         Args:
             classes (List(str)): A list of categories i.e. ["person"].
@@ -106,6 +106,5 @@ class YOLOWorld(Model):
         self.model.names = classes
 
         # Reset method class names
-        # self.predictor = None  # reset predictor otherwise old names remain
         if self.predictor:
             self.predictor.model.names = classes
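Taken together, these two hunks change YOLOWorld.set_classes so that an already-created predictor has its class names updated in place rather than being discarded. A minimal usage sketch (the checkpoint name is illustrative; any YOLO-World weights should work):

    from ultralytics import YOLOWorld
    from ultralytics.utils import ASSETS

    model = YOLOWorld("yolov8s-world.pt")  # assumed YOLO-World checkpoint
    model.set_classes(["person", "bus"])  # restrict the open-vocabulary detector

    # Any live predictor now reports the new names instead of stale ones
    results = model.predict(ASSETS / "bus.jpg")
    print(results[0].names)  # {0: 'person', 1: 'bus'}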
ultralytics/models/yolo/obb/predict.py

@@ -11,6 +11,13 @@ class OBBPredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
 
+    This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
+    bounding boxes.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO OBB model.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.obb import OBBPredictor
@@ -20,17 +27,18 @@ class OBBPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initializes OBBPredictor with optional model and data configuration overrides."""
+        """Initialize OBBPredictor with optional model and data configuration overrides."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-        Constructs the result object from the prediction.
+        Construct the result object from the prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles.
-            img (torch.Tensor): The image after preprocessing.
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 6) where
+                the last dimension contains [x, y, w, h, confidence, class_id, angle].
+            img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
 
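For context on the documented [x, y, w, h, confidence, class_id, angle] layout, a short sketch of how rotated boxes surface through the public API (checkpoint and asset names follow the docstring examples; results may be empty on non-aerial imagery):

    from ultralytics import YOLO
    from ultralytics.utils import ASSETS

    model = YOLO("yolo11n-obb.pt")
    results = model(ASSETS / "bus.jpg")

    # Rotated boxes are exposed on Results.obb; xywhr is (x_center, y_center, width, height, angle)
    for r in results:
        print(r.obb.xywhr)  # rotated box parameters
        print(r.obb.conf, r.obb.cls)  # confidences and class ids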
ultralytics/models/yolo/obb/train.py

@@ -11,6 +11,13 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
 
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used during training.
+
+    Methods:
+        get_model: Return OBBModel initialized with specified config and weights.
+        get_validator: Return an instance of OBBValidator for validation of YOLO model.
+
     Examples:
         >>> from ultralytics.models.yolo.obb import OBBTrainer
         >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
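The docstring example is truncated in this diff; for completeness, the usual invocation per the pattern used by the other trainers (a sketch, not part of the diff):

    from ultralytics.models.yolo.obb import OBBTrainer

    args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
    trainer = OBBTrainer(overrides=args)
    trainer.train()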
ultralytics/models/yolo/obb/val.py

@@ -14,6 +14,24 @@ class OBBValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
 
+    This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
+    satellite imagery where objects can appear at various orientations.
+
+    Attributes:
+        args (Dict): Configuration arguments for the validator.
+        metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
+        is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
+
+    Methods:
+        init_metrics: Initialize evaluation metrics for YOLO.
+        _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
+        _prepare_batch: Prepare batch data for OBB validation.
+        _prepare_pred: Prepare predictions with scaled and padded bounding boxes.
+        plot_predictions: Plot predicted bounding boxes on input images.
+        pred_to_json: Serialize YOLO predictions to COCO json format.
+        save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
+        eval_json: Evaluate YOLO output in JSON format and return performance statistics.
+
     Examples:
         >>> from ultralytics.models.yolo.obb import OBBValidator
         >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
@@ -31,7 +49,7 @@ class OBBValidator(DetectionValidator):
         """Initialize evaluation metrics for YOLO."""
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
-        self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
 
     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
@@ -61,7 +79,7 @@ class OBBValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def _prepare_batch(self, si, batch):
-        """Prepares and returns a batch for OBB validation."""
+        """Prepare batch data for OBB validation with proper scaling and formatting."""
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -74,7 +92,7 @@ class OBBValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
+        """Prepare predictions by scaling bounding boxes to original image dimensions."""
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
@@ -82,7 +100,7 @@ class OBBValidator(DetectionValidator):
         return predn
 
     def plot_predictions(self, batch, preds, ni):
-        """Plots predicted bounding boxes on input images and saves the result."""
+        """Plot predicted bounding boxes on input images and save the result."""
         plot_images(
             batch["img"],
             *output_to_rotated_target(preds, max_det=self.args.max_det),
@@ -93,7 +111,7 @@ class OBBValidator(DetectionValidator):
         )  # pred
 
     def pred_to_json(self, predn, filename):
-        """Serialize YOLO predictions to COCO json format."""
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information."""
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
@@ -110,7 +128,7 @@ class OBBValidator(DetectionValidator):
             )
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        """Save YOLO detections to a txt file in normalized coordinates using the Results class."""
        import numpy as np

        from ultralytics.engine.results import Results
@@ -126,7 +144,7 @@ class OBBValidator(DetectionValidator):
        ).save_txt(file, save_conf=save_conf)

    def eval_json(self, stats):
-        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
        if self.args.save_json and self.is_dota and len(self.jdict):
            import json
            import re
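The pred_to_json hunk above rebuilds the rotated box by concatenating the xywh columns with the trailing angle column. A standalone sketch of that step, using the xywhr2xyxyxyxy helper from ultralytics.utils.ops to produce the corner points that DOTA-style output needs (the sample tensor is illustrative):

    import torch

    from ultralytics.utils import ops

    # One prediction row: [x, y, w, h, conf, cls, angle] (angle in radians)
    predn = torch.tensor([[100.0, 120.0, 50.0, 30.0, 0.9, 0.0, 0.3]])
    rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)  # (x, y, w, h, r), as in pred_to_json

    # Convert the rotated box to its 4 corner points
    poly = ops.xywhr2xyxyxyxy(rbox)
    print(poly.shape)  # torch.Size([1, 4, 2])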
ultralytics/models/yolo/pose/predict.py

@@ -8,6 +8,16 @@ class PosePredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on a pose model.
 
+    This class specializes in pose estimation, handling keypoints detection alongside standard object detection
+    capabilities inherited from DetectionPredictor.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
+
+    Methods:
+        construct_result: Constructs the result object from the prediction, including keypoints.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.pose import PosePredictor
@@ -17,7 +27,7 @@ class PosePredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
+        """Initialize PosePredictor, set task to 'pose' and log a warning for using 'mps' as device."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "pose"
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
@@ -28,19 +38,25 @@ class PosePredictor(DetectionPredictor):
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-        Constructs the result object from the prediction.
+        Construct the result object from the prediction, including keypoints.
+
+        This method extends the parent class implementation by extracting keypoint data from predictions
+        and adding them to the result object.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints.
-            img (torch.Tensor): The image after preprocessing.
-            orig_img (np.ndarray): The original image before preprocessing.
-            img_path (str): The path to the original image.
+            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
+                the number of detections, K is the number of keypoints, and D is the keypoint dimension.
+            img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
+            orig_img (np.ndarray): The original unprocessed image as a numpy array.
+            img_path (str): The path to the original image file.
 
         Returns:
             (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
+        # Extract keypoints from prediction and reshape according to model's keypoint shape
         pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+        # Scale keypoints coordinates to match the original image dimensions
         pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
         return result
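The new shape documentation for construct_result can be illustrated standalone: the flat keypoint tail of each prediction row is reshaped to (N, K, D) before the coordinates are scaled to the original image. A minimal torch sketch (shapes assume the COCO default kpt_shape of [17, 3]):

    import torch

    kpt_shape = (17, 3)  # K keypoints with D values each
    pred = torch.rand(5, 6 + kpt_shape[0] * kpt_shape[1])  # 5 rows: [x1, y1, x2, y2, conf, cls, kpts...]

    # Reshape the flat keypoint tail to (N, K, D), as construct_result does
    pred_kpts = pred[:, 6:].view(len(pred), *kpt_shape)
    print(pred_kpts.shape)  # torch.Size([5, 17, 3])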
ultralytics/models/yolo/pose/train.py

@@ -10,7 +10,23 @@ from ultralytics.utils.plotting import plot_images, plot_results
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
     """
-    A class extending the DetectionTrainer class for training based on a pose model.
+    A class extending the DetectionTrainer class for training YOLO pose estimation models.
+
+    This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
+    of pose keypoints alongside bounding boxes.
+
+    Attributes:
+        args (Dict): Configuration arguments for training.
+        model (PoseModel): The pose estimation model being trained.
+        data (Dict): Dataset configuration including keypoint shape information.
+        loss_names (Tuple[str]): Names of the loss components used in training.
+
+    Methods:
+        get_model: Retrieves a pose estimation model with specified configuration.
+        set_model_attributes: Sets keypoints shape attribute on the model.
+        get_validator: Creates a validator instance for model evaluation.
+        plot_training_samples: Visualizes training samples with keypoints.
+        plot_metrics: Generates and saves training/validation metric plots.
 
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseTrainer
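The PoseTrainer docstring example is likewise truncated in the diff; the usual invocation, for completeness (a sketch following the other trainer docstrings; set_model_attributes wires data["kpt_shape"] onto the model during setup):

    from ultralytics.models.yolo.pose import PoseTrainer

    args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
    trainer = PoseTrainer(overrides=args)
    trainer.train()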
ultralytics/models/yolo/pose/val.py

@@ -16,6 +16,29 @@ class PoseValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on a pose model.
 
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
+    specialized metrics for pose evaluation.
+
+    Attributes:
+        sigma (np.ndarray): Sigma values for OKS calculation, either from OKS_SIGMA or ones divided by number of keypoints.
+        kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
+        args (Dict): Arguments for the validator including task set to "pose".
+        metrics (PoseMetrics): Metrics object for pose evaluation.
+
+    Methods:
+        preprocess: Preprocesses batch data for pose validation.
+        get_desc: Returns description of evaluation metrics.
+        init_metrics: Initializes pose metrics for the model.
+        _prepare_batch: Prepares a batch for processing.
+        _prepare_pred: Prepares and scales predictions for evaluation.
+        update_metrics: Updates metrics with new predictions.
+        _process_batch: Processes batch to compute IoU between detections and ground truth.
+        plot_val_samples: Plots validation samples with ground truth annotations.
+        plot_predictions: Plots model predictions.
+        save_one_txt: Saves detections to a text file.
+        pred_to_json: Converts predictions to COCO JSON format.
+        eval_json: Evaluates model using COCO JSON format.
+
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseValidator
         >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
@@ -24,7 +47,7 @@ class PoseValidator(DetectionValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
+        """Initialize a PoseValidator object with custom parameters and assigned attributes."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
@@ -37,13 +60,13 @@ class PoseValidator(DetectionValidator):
             )
 
     def preprocess(self, batch):
-        """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
+        """Preprocess batch by converting keypoints data to float and moving it to the device."""
         batch = super().preprocess(batch)
         batch["keypoints"] = batch["keypoints"].to(self.device).float()
         return batch
 
     def get_desc(self):
-        """Returns description of evaluation metrics in string format."""
+        """Return description of evaluation metrics in string format."""
         return ("%22s" + "%11s" * 10) % (
             "Class",
             "Images",
@@ -59,7 +82,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def init_metrics(self, model):
-        """Initiate pose estimation metrics for YOLO model."""
+        """Initialize pose estimation metrics for YOLO model."""
         super().init_metrics(model)
         self.kpt_shape = self.data["kpt_shape"]
         is_pose = self.kpt_shape == [17, 3]
@@ -68,7 +91,7 @@ class PoseValidator(DetectionValidator):
         self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
     def _prepare_batch(self, si, batch):
-        """Prepares a batch for processing by converting keypoints to float and moving to device."""
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions."""
         pbatch = super()._prepare_batch(si, batch)
         kpts = batch["keypoints"][batch["batch_idx"] == si]
         h, w = pbatch["imgsz"]
@@ -80,7 +103,7 @@ class PoseValidator(DetectionValidator):
         return pbatch
 
     def _prepare_pred(self, pred, pbatch):
-        """Prepares and scales keypoints in a batch for pose processing."""
+        """Prepare and scale keypoints in predictions for pose processing."""
         predn = super()._prepare_pred(pred, pbatch)
         nk = pbatch["kpts"].shape[1]
         pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
@@ -88,7 +111,16 @@ class PoseValidator(DetectionValidator):
         return predn, pred_kpts
 
     def update_metrics(self, preds, batch):
-        """Metrics."""
+        """
+        Update metrics with new predictions and ground truth data.
+
+        This method processes each prediction, compares it with ground truth, and updates various statistics
+        for performance evaluation.
+
+        Args:
+            preds (List[torch.Tensor]): List of prediction tensors from the model.
+            batch (Dict): Batch data containing images and ground truth annotations.
+        """
         for si, pred in enumerate(preds):
             self.seen += 1
             npr = len(pred)
@@ -158,16 +190,9 @@ class PoseValidator(DetectionValidator):
            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
                where N is the number of detections.

-        Examples:
-            >>> detections = torch.rand(100, 6)  # 100 predictions: (x1, y1, x2, y2, conf, class)
-            >>> gt_bboxes = torch.rand(50, 4)  # 50 ground truth boxes: (x1, y1, x2, y2)
-            >>> gt_cls = torch.randint(0, 2, (50,))  # 50 ground truth class indices
-            >>> pred_kpts = torch.rand(100, 51)  # 100 predicted keypoints
-            >>> gt_kpts = torch.rand(50, 51)  # 50 ground truth keypoints
-            >>> correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
-
-        Note:
-            `0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
+        Notes:
+            `0.53` scale factor used in area computation is referenced from
+            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
         if pred_kpts is not None and gt_kpts is not None:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
@@ -179,7 +204,7 @@ class PoseValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def plot_val_samples(self, batch, ni):
-        """Plots and saves validation set samples with predicted bounding boxes and keypoints."""
+        """Plot and save validation set samples with ground truth bounding boxes and keypoints."""
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -193,7 +218,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """Plots predictions for YOLO model."""
+        """Plot and save model predictions with bounding boxes and keypoints."""
         pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
         plot_images(
             batch["img"],
@@ -218,7 +243,7 @@ class PoseValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
-        """Converts YOLO predictions to COCO JSON format."""
+        """Convert YOLO predictions to COCO JSON format."""
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
@@ -235,7 +260,7 @@ class PoseValidator(DetectionValidator):
             )
 
     def eval_json(self, stats):
-        """Evaluates object detection model using COCO JSON format."""
+        """Evaluate object detection model using COCO JSON format."""
         if self.args.save_json and self.is_coco and len(self.jdict):
             anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
             pred_json = self.save_dir / "predictions.json"  # predictions
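Following the Examples pattern shown in these docstrings, a standalone validation run looks like the sketch below; with save_json=True on COCO-format data, eval_json then scores the resulting predictions.json against person_keypoints_val2017.json:

    from ultralytics.models.yolo.pose import PoseValidator

    args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
    validator = PoseValidator(args=args)
    validator()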
ultralytics/models/yolo/segment/predict.py

@@ -9,6 +9,19 @@ class SegmentationPredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on a segmentation model.
 
+    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
+    prediction results.
+
+    Attributes:
+        args (Dict): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO segmentation model.
+        batch (List): Current batch of images being processed.
+
+    Methods:
+        postprocess: Applies non-max suppression and processes detections.
+        construct_results: Constructs a list of result objects from predictions.
+        construct_result: Constructs a single result object from a prediction.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.segment import SegmentationPredictor
@@ -18,19 +31,19 @@ class SegmentationPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
-        """Applies non-max suppression and processes detections for each image in an input batch."""
-        # tuple if PyTorch model or array if exported
+        """Apply non-max suppression and process detections for each image in the input batch."""
+        # Extract protos - tuple if PyTorch model or array if exported
         protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
 
     def construct_results(self, preds, img, orig_imgs, protos):
         """
-        Constructs a list of result objects from the predictions.
+        Construct a list of result objects from the predictions.
 
         Args:
             preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -39,7 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
             protos (List[torch.Tensor]): List of prototype masks.
 
         Returns:
-            (list): List of result objects containing the original images, image paths, class names, bounding boxes, and masks.
+            (List[Results]): List of result objects containing the original images, image paths, class names,
+                bounding boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -48,7 +62,7 @@ class SegmentationPredictor(DetectionPredictor):
 
     def construct_result(self, pred, img, orig_img, img_path, proto):
         """
-        Constructs the result object from the prediction.
+        Construct a single result object from the prediction.
 
         Args:
             pred (np.ndarray): The predicted bounding boxes, scores, and masks.
@@ -58,7 +72,7 @@ class SegmentationPredictor(DetectionPredictor):
             proto (torch.Tensor): The prototype masks.
 
         Returns:
-            (Results): The result object containing the original image, image path, class names, bounding boxes, and masks.
+            (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
         """
         if not len(pred):  # save empty boxes
             masks = None
@@ -69,6 +83,6 @@ class SegmentationPredictor(DetectionPredictor):
             masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         if masks is not None:
-            keep = masks.sum((-2, -1)) > 0  # only keep preds with masks
+            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
             pred, masks = pred[keep], masks[keep]
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
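The SegmentationPredictor docstring example is truncated in the diff; per the pattern used elsewhere, the usual invocation is sketched below. As construct_result notes, frames with no detections produce masks=None:

    from ultralytics.utils import ASSETS
    from ultralytics.models.yolo.segment import SegmentationPredictor

    args = dict(model="yolo11n-seg.pt", source=ASSETS)
    predictor = SegmentationPredictor(overrides=args)
    predictor.predict_cli()  # runs inference over the source and prints/saves results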
ultralytics/models/yolo/segment/train.py

@@ -12,6 +12,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on a segmentation model.
 
+    This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
+    functionality including model initialization, validation, and visualization.
+
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used during training.
+
     Examples:
         >>> from ultralytics.models.yolo.segment import SegmentationTrainer
         >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
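And the matching completion for the SegmentationTrainer example (a sketch; segmentation training reports the loss components named by loss_names):

    from ultralytics.models.yolo.segment import SegmentationTrainer

    args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
    trainer = SegmentationTrainer(overrides=args)
    trainer.train()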