ultralytics 8.3.101__py3-none-any.whl → 8.3.103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tests/test_exports.py +14 -5
  2. tests/test_solutions.py +140 -76
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/cfg/__init__.py +1 -1
  5. ultralytics/engine/exporter.py +23 -8
  6. ultralytics/engine/tuner.py +8 -2
  7. ultralytics/hub/__init__.py +29 -2
  8. ultralytics/hub/google/__init__.py +18 -1
  9. ultralytics/models/fastsam/predict.py +12 -1
  10. ultralytics/models/nas/predict.py +21 -3
  11. ultralytics/models/rtdetr/val.py +26 -2
  12. ultralytics/models/sam/amg.py +22 -1
  13. ultralytics/models/sam/modules/encoders.py +85 -4
  14. ultralytics/models/sam/modules/memory_attention.py +61 -3
  15. ultralytics/models/sam/modules/utils.py +108 -5
  16. ultralytics/models/utils/loss.py +38 -2
  17. ultralytics/models/utils/ops.py +15 -1
  18. ultralytics/models/yolo/classify/predict.py +11 -1
  19. ultralytics/models/yolo/classify/train.py +17 -1
  20. ultralytics/models/yolo/classify/val.py +82 -6
  21. ultralytics/models/yolo/detect/predict.py +20 -1
  22. ultralytics/models/yolo/model.py +55 -4
  23. ultralytics/models/yolo/obb/predict.py +16 -1
  24. ultralytics/models/yolo/obb/train.py +35 -2
  25. ultralytics/models/yolo/obb/val.py +87 -6
  26. ultralytics/models/yolo/pose/predict.py +18 -1
  27. ultralytics/models/yolo/pose/train.py +48 -3
  28. ultralytics/models/yolo/pose/val.py +113 -8
  29. ultralytics/models/yolo/segment/predict.py +27 -2
  30. ultralytics/models/yolo/segment/train.py +61 -3
  31. ultralytics/models/yolo/segment/val.py +10 -1
  32. ultralytics/models/yolo/world/train_world.py +29 -1
  33. ultralytics/models/yolo/yoloe/train.py +47 -3
  34. ultralytics/nn/autobackend.py +9 -8
  35. ultralytics/nn/modules/activation.py +26 -3
  36. ultralytics/nn/modules/block.py +89 -0
  37. ultralytics/nn/modules/head.py +3 -92
  38. ultralytics/nn/modules/utils.py +70 -4
  39. ultralytics/nn/tasks.py +3 -0
  40. ultralytics/nn/text_model.py +93 -17
  41. ultralytics/solutions/instance_segmentation.py +15 -7
  42. ultralytics/solutions/solutions.py +2 -47
  43. ultralytics/utils/benchmarks.py +1 -1
  44. ultralytics/utils/callbacks/base.py +22 -5
  45. ultralytics/utils/callbacks/comet.py +93 -5
  46. ultralytics/utils/callbacks/dvc.py +64 -5
  47. ultralytics/utils/callbacks/neptune.py +25 -2
  48. ultralytics/utils/callbacks/tensorboard.py +30 -2
  49. ultralytics/utils/callbacks/wb.py +16 -1
  50. ultralytics/utils/dist.py +35 -2
  51. ultralytics/utils/errors.py +27 -6
  52. ultralytics/utils/metrics.py +1 -1
  53. ultralytics/utils/patches.py +33 -5
  54. ultralytics/utils/torch_utils.py +14 -6
  55. ultralytics/utils/triton.py +16 -3
  56. ultralytics/utils/tuner.py +17 -9
  57. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/METADATA +3 -4
  58. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/RECORD +62 -62
  59. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/WHEEL +0 -0
  60. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/entry_points.txt +0 -0
  61. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/licenses/LICENSE +0 -0
  62. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,24 @@ class YOLO(Model):
     """YOLO (You Only Look Once) object detection model."""

     def __init__(self, model="yolo11n.pt", task=None, verbose=False):
-        """Initialize YOLO model, switching to YOLOWorld/YOLOE if model filename contains '-world'/'yoloe'."""
+        """
+        Initialize a YOLO model.
+
+        This constructor initializes a YOLO model, automatically switching to specialized model types
+        (YOLOWorld or YOLOE) based on the model filename.
+
+        Args:
+            model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolov8n.yaml'.
+            task (str | None): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
+                Defaults to auto-detection based on model.
+            verbose (bool): Display model info on load.
+
+        Examples:
+            >>> from ultralytics import YOLO
+            >>> model = YOLO("yolov8n.pt")  # load a pretrained YOLOv8n detection model
+            >>> model = YOLO("yolov8n-seg.pt")  # load a pretrained YOLOv8n segmentation model
+            >>> model = YOLO("yolo11n.pt")  # load a pretrained YOLO11n detection model
+        """
         path = Path(model)
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
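The filename-based dispatch above means the object returned by YOLO() may be a different class entirely. A minimal sketch of the observable effect, assuming publicly released '-world' weights and the standard YOLOWorld.set_classes API:

>>> from ultralytics import YOLO
>>> model = YOLO("yolov8s-world.pt")  # stem contains '-world', so construction swaps in YOLOWorld
>>> type(model).__name__
'YOLOWorld'
>>> model.set_classes(["person", "bus"])  # open-vocabulary API only present on the swapped-in class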
@@ -166,12 +183,46 @@ class YOLOE(Model):
         return self.model.get_text_pe(texts)

     def get_visual_pe(self, img, visual):
-        """Get visual positional embeddings for the given image and visual features."""
+        """
+        Get visual positional embeddings for the given image and visual features.
+
+        This method extracts positional embeddings from visual features based on the input image. It requires
+        that the model is an instance of YOLOEModel.
+
+        Args:
+            img (torch.Tensor): Input image tensor.
+            visual (torch.Tensor): Visual features extracted from the image.
+
+        Returns:
+            (torch.Tensor): Visual positional embeddings.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s.pt")
+            >>> img = torch.rand(1, 3, 640, 640)
+            >>> visual_features = model.model.backbone(img)
+            >>> pe = model.get_visual_pe(img, visual_features)
+        """
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_visual_pe(img, visual)

     def set_vocab(self, vocab, names):
-        """Set vocabulary and class names for the model."""
+        """
+        Set vocabulary and class names for the YOLOE model.
+
+        This method configures the vocabulary and class names used by the model for text processing and
+        classification tasks. The model must be an instance of YOLOEModel.
+
+        Args:
+            vocab (list): Vocabulary list containing tokens or words used by the model for text processing.
+            names (list): List of class names that the model can detect or classify.
+
+        Raises:
+            AssertionError: If the model is not an instance of YOLOEModel.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s.pt")
+            >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
+        """
         assert isinstance(self.model, YOLOEModel)
         self.model.set_vocab(vocab, names=names)

@@ -290,7 +341,7 @@ class YOLOE(Model):

         self.predictor.setup_model(model=self.model)

-        if refer_image is None:
+        if refer_image is None and source:
             dataset = load_inference_source(source)
             if dataset.mode in {"video", "stream"}:
                 # NOTE: set the first frame as refer image for videos/streams inference
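The functional change in this last hunk is the added `and source` guard: load_inference_source is only probed for a reference frame when a source is actually supplied. A hedged sketch of the call path this protects, assuming YOLOE's predict accepts a refer_image argument as the surrounding code suggests:

>>> from ultralytics import YOLOE
>>> model = YOLOE("yoloe-v8s.pt")
>>> # With refer_image given, the dataset probe above is skipped entirely,
>>> # so setup no longer touches an empty or missing source
>>> results = model.predict("bus.jpg", refer_image="refer.jpg")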
@@ -27,7 +27,22 @@ class OBBPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initialize OBBPredictor with optional model and data configuration overrides."""
+        """
+        Initialize OBBPredictor with optional model and data configuration overrides.
+
+        This constructor sets up an OBBPredictor instance for oriented bounding box detection tasks.
+
+        Args:
+            cfg (dict, optional): Default configuration for the predictor.
+            overrides (dict, optional): Configuration overrides that take precedence over the default config.
+            _callbacks (list, optional): List of callback functions to be invoked during prediction.
+
+        Examples:
+            >>> from ultralytics.utils import ASSETS
+            >>> from ultralytics.models.yolo.obb import OBBPredictor
+            >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
+            >>> predictor = OBBPredictor(overrides=args)
+        """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"

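The docstring example stops at constructing the predictor. To inspect the oriented boxes themselves, the usual Results API applies; a short sketch, assuming results[0].obb carries boxes in xywhr form as in current releases:

>>> from ultralytics import YOLO
>>> results = YOLO("yolo11n-obb.pt")("https://ultralytics.com/images/boats.jpg")
>>> results[0].obb.xywhr  # (N, 5) tensor: center-x, center-y, width, height, rotation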
@@ -26,14 +26,47 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initialize a OBBTrainer object with given arguments."""
+        """
+        Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
+
+        This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
+        bounding boxes. It automatically sets the task to 'obb' in the configuration.
+
+        Args:
+            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
+                model configuration.
+            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
+                will take precedence over those in cfg.
+            _callbacks (list, optional): List of callback functions to be invoked during training.
+
+        Examples:
+            >>> from ultralytics.models.yolo.obb import OBBTrainer
+            >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
+            >>> trainer = OBBTrainer(overrides=args)
+            >>> trainer.train()
+        """
         if overrides is None:
             overrides = {}
         overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)

     def get_model(self, cfg=None, weights=None, verbose=True):
-        """Return OBBModel initialized with specified config and weights."""
+        """
+        Return OBBModel initialized with specified config and weights.
+
+        Args:
+            cfg (str | dict | None): Model configuration. Can be a path to a YAML config file, a dictionary
+                containing configuration parameters, or None to use default configuration.
+            weights (str | Path | None): Path to pretrained weights file. If None, random initialization is used.
+            verbose (bool): Whether to display model information during initialization.
+
+        Returns:
+            (OBBModel): Initialized OBBModel with the specified configuration and weights.
+
+        Examples:
+            >>> trainer = OBBTrainer()
+            >>> model = trainer.get_model(cfg="yolov8n-obb.yaml", weights="yolov8n-obb.pt")
+        """
         model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
@@ -40,7 +40,19 @@ class OBBValidator(DetectionValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
+        """
+        Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
+
+        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
+        It extends the DetectionValidator class and configures it specifically for the OBB task.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+            save_dir (str | Path, optional): Directory to save results.
+            pbar (bool, optional): Display progress bar during validation.
+            args (dict, optional): Arguments containing validation parameters.
+            _callbacks (list, optional): List of callback functions to be called during validation.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
@@ -79,7 +91,22 @@ class OBBValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)

     def _prepare_batch(self, si, batch):
-        """Prepare batch data for OBB validation with proper scaling and formatting."""
+        """
+        Prepare batch data for OBB validation with proper scaling and formatting.
+
+        Args:
+            si (int): Batch index to process.
+            batch (dict): Dictionary containing batch data with keys:
+                - batch_idx: Tensor of batch indices
+                - cls: Tensor of class labels
+                - bboxes: Tensor of bounding boxes
+                - ori_shape: Original image shapes
+                - img: Batch of images
+                - ratio_pad: Ratio and padding information
+
+        This method filters the batch data for a specific batch index, extracts class labels and bounding boxes,
+        and scales the bounding boxes to the original image dimensions.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -92,7 +119,22 @@ class OBBValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

     def _prepare_pred(self, pred, pbatch):
-        """Prepare predictions by scaling bounding boxes to original image dimensions."""
+        """
+        Prepare predictions by scaling bounding boxes to original image dimensions.
+
+        This method takes prediction tensors containing bounding box coordinates and scales them from the model's
+        input dimensions to the original image dimensions using the provided batch information.
+
+        Args:
+            pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
+            pbatch (dict): Dictionary containing batch information with keys:
+                - imgsz (tuple): Model input image size.
+                - ori_shape (tuple): Original image shape.
+                - ratio_pad (tuple): Ratio and padding information for scaling.
+
+        Returns:
+            (torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
+        """
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
@@ -100,7 +142,20 @@ class OBBValidator(DetectionValidator):
         return predn

     def plot_predictions(self, batch, preds, ni):
-        """Plot predicted bounding boxes on input images and save the result."""
+        """
+        Plot predicted bounding boxes on input images and save the result.
+
+        Args:
+            batch (dict): Batch data containing images, file paths, and other metadata.
+            preds (list): List of prediction tensors for each image in the batch.
+            ni (int): Batch index used for naming the output file.
+
+        Examples:
+            >>> validator = OBBValidator()
+            >>> batch = {"img": images, "im_file": paths}
+            >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
+            >>> validator.plot_predictions(batch, preds, 0)
+        """
         plot_images(
             batch["img"],
             *output_to_rotated_target(preds, max_det=self.args.max_det),
@@ -111,7 +166,19 @@ class OBBValidator(DetectionValidator):
         )  # pred

     def pred_to_json(self, predn, filename):
-        """Convert YOLO predictions to COCO JSON format with rotated bounding box information."""
+        """
+        Convert YOLO predictions to COCO JSON format with rotated bounding box information.
+
+        Args:
+            predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores,
+                class predictions, and rotation angles with shape (N, 6+) where the last column is the angle.
+            filename (str | Path): Path to the image file for which predictions are being processed.
+
+        Notes:
+            This method processes rotated bounding box predictions and converts them to both rbox format
+            (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
+            to the JSON dictionary.
+        """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
@@ -128,7 +195,21 @@ class OBBValidator(DetectionValidator):
         )

     def save_one_txt(self, predn, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates using the Results class."""
+        """
+        Save YOLO OBB (Oriented Bounding Box) detections to a text file in normalized coordinates.
+
+        Args:
+            predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence
+                scores, class predictions, and angles in format (x, y, w, h, conf, cls, angle).
+            save_conf (bool): Whether to save confidence scores in the text file.
+            shape (tuple): Original image shape in format (height, width).
+            file (Path | str): Output file path to save detections.
+
+        Examples:
+            >>> validator = OBBValidator()
+            >>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]])  # One detection: x,y,w,h,conf,cls,angle
+            >>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
+        """
         import numpy as np

         from ultralytics.engine.results import Results
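The pred_to_json Notes above mention converting rboxes to polygons. That conversion is a pure geometry step, and ultralytics.utils.ops exposes xywhr2xyxyxyxy for it in current releases; a small self-contained check of the shapes involved:

>>> import torch
>>> from ultralytics.utils import ops
>>> rbox = torch.tensor([[100.0, 100.0, 50.0, 30.0, 0.5]])  # x, y, w, h, angle (radians)
>>> poly = ops.xywhr2xyxyxyxy(rbox)  # four corner points per box, the polygon form used in pred_to_json
>>> poly.shape
torch.Size([1, 4, 2])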
@@ -27,7 +27,24 @@ class PosePredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initialize PosePredictor, set task to 'pose' and log a warning for using 'mps' as device."""
+        """
+        Initialize PosePredictor, a specialized predictor for pose estimation tasks.
+
+        This initializer sets up a PosePredictor instance, configuring it for pose detection tasks and handling
+        device-specific warnings for Apple MPS.
+
+        Args:
+            cfg (Any): Configuration for the predictor. Default is DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides that take precedence over cfg.
+            _callbacks (list, optional): List of callback functions to be invoked during prediction.
+
+        Examples:
+            >>> from ultralytics.utils import ASSETS
+            >>> from ultralytics.models.yolo.pose import PosePredictor
+            >>> args = dict(model="yolov8n-pose.pt", source=ASSETS)
+            >>> predictor = PosePredictor(overrides=args)
+            >>> predictor.predict_cli()
+        """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "pose"
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
@@ -36,7 +36,27 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initialize a PoseTrainer object with specified configurations and overrides."""
+        """
+        Initialize a PoseTrainer object for training YOLO pose estimation models.
+
+        This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
+        handling specific configurations needed for keypoint detection models.
+
+        Args:
+            cfg (dict, optional): Default configuration dictionary containing training parameters.
+            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list, optional): List of callback functions to be executed during training.
+
+        Notes:
+            This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
+            A warning is issued when using Apple MPS device due to known bugs with pose models.
+
+        Examples:
+            >>> from ultralytics.models.yolo.pose import PoseTrainer
+            >>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
+            >>> trainer = PoseTrainer(overrides=args)
+            >>> trainer.train()
+        """
         if overrides is None:
             overrides = {}
         overrides["task"] = "pose"
@@ -49,7 +69,17 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         )

     def get_model(self, cfg=None, weights=None, verbose=True):
-        """Get pose estimation model with specified configuration and weights."""
+        """
+        Get pose estimation model with specified configuration and weights.
+
+        Args:
+            cfg (str | Path | dict | None): Model configuration file path or dictionary.
+            weights (str | Path | None): Path to the model weights file.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (PoseModel): Initialized pose estimation model.
+        """
         model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
         if weights:
             model.load(weights)
@@ -69,7 +99,22 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         )

     def plot_training_samples(self, batch, ni):
-        """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
+        """
+        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
+
+        Args:
+            batch (dict): Dictionary containing batch data with the following keys:
+                - img (torch.Tensor): Batch of images
+                - keypoints (torch.Tensor): Keypoint coordinates for pose estimation
+                - cls (torch.Tensor): Class labels
+                - bboxes (torch.Tensor): Bounding box coordinates
+                - im_file (list): List of image file paths
+                - batch_idx (torch.Tensor): Batch indices for each instance
+            ni (int): Current training iteration number used for filename
+
+        The function saves the plotted batch as an image in the trainer's save directory with the filename
+        'train_batch{ni}.jpg', where ni is the iteration number.
+        """
         images = batch["img"]
         kpts = batch["keypoints"]
         cls = batch["cls"].squeeze(-1)
@@ -20,7 +20,7 @@ class PoseValidator(DetectionValidator):
     specialized metrics for pose evaluation.

     Attributes:
-        sigma (np.ndarray): Sigma values for OKS calculation, either from OKS_SIGMA or ones divided by number of keypoints.
+        sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
         kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
         args (dict): Arguments for the validator including task set to "pose".
         metrics (PoseMetrics): Metrics object for pose evaluation.
@@ -47,7 +47,30 @@ class PoseValidator(DetectionValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initialize a PoseValidator object with custom parameters and assigned attributes."""
+        """
+        Initialize a PoseValidator object for pose estimation validation.
+
+        This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
+        specialized metrics for pose evaluation.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+            save_dir (Path | str, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (dict, optional): Arguments for the validator including task set to "pose".
+            _callbacks (list, optional): List of callback functions to be executed during validation.
+
+        Examples:
+            >>> from ultralytics.models.yolo.pose import PoseValidator
+            >>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
+            >>> validator = PoseValidator(args=args)
+            >>> validator()
+
+        Notes:
+            This class extends DetectionValidator with pose-specific functionality. It initializes with sigma
+            values for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using
+            Apple MPS due to a known bug with pose models.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
@@ -91,7 +114,20 @@ class PoseValidator(DetectionValidator):
         self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])

     def _prepare_batch(self, si, batch):
-        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions."""
+        """
+        Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
+
+        Args:
+            si (int): Batch index.
+            batch (dict): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
+
+        Returns:
+            pbatch (dict): Prepared batch with keypoints scaled to original image dimensions.
+
+        Notes:
+            This method extends the parent class's _prepare_batch method by adding keypoint processing.
+            Keypoints are scaled from normalized coordinates to original image dimensions.
+        """
         pbatch = super()._prepare_batch(si, batch)
         kpts = batch["keypoints"][batch["batch_idx"] == si]
         h, w = pbatch["imgsz"]
@@ -103,7 +139,23 @@ class PoseValidator(DetectionValidator):
         return pbatch

     def _prepare_pred(self, pred, pbatch):
-        """Prepare and scale keypoints in predictions for pose processing."""
+        """
+        Prepare and scale keypoints in predictions for pose processing.
+
+        This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls
+        the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates
+        to match the original image dimensions.
+
+        Args:
+            pred (torch.Tensor): Raw prediction tensor from the model.
+            pbatch (dict): Processed batch dictionary containing image information including:
+                - imgsz: Image size used for inference
+                - ori_shape: Original image shape
+                - ratio_pad: Ratio and padding information for coordinate scaling
+
+        Returns:
+            predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
+        """
         predn = super()._prepare_pred(pred, pbatch)
         nk = pbatch["kpts"].shape[1]
         pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
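The reshape on that last line is easy to sanity-check numerically: with COCO's kpt_shape of [17, 3], each prediction row carries 6 box/score values plus 17 × 3 = 51 keypoint values, so the view recovers one (17, 3) block per detection. A toy example with illustrative sizes:

>>> import torch
>>> predn = torch.rand(8, 6 + 17 * 3)  # 8 detections, COCO keypoints stored as (x, y, visibility)
>>> pred_kpts = predn[:, 6:].view(len(predn), 17, -1)
>>> pred_kpts.shape
torch.Size([8, 17, 3])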
@@ -204,7 +256,19 @@ class PoseValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)

     def plot_val_samples(self, batch, ni):
-        """Plot and save validation set samples with ground truth bounding boxes and keypoints."""
+        """
+        Plot and save validation set samples with ground truth bounding boxes and keypoints.
+
+        Args:
+            batch (dict): Dictionary containing batch data with keys:
+                - img (torch.Tensor): Batch of images
+                - batch_idx (torch.Tensor): Batch indices for each image
+                - cls (torch.Tensor): Class labels
+                - bboxes (torch.Tensor): Bounding box coordinates
+                - keypoints (torch.Tensor): Keypoint coordinates
+                - im_file (list): List of image file paths
+            ni (int): Batch index used for naming the output file
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -218,7 +282,18 @@ class PoseValidator(DetectionValidator):
         )

     def plot_predictions(self, batch, preds, ni):
-        """Plot and save model predictions with bounding boxes and keypoints."""
+        """
+        Plot and save model predictions with bounding boxes and keypoints.
+
+        Args:
+            batch (dict): Dictionary containing batch data including images, file paths, and other metadata.
+            preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding
+                boxes, confidence scores, class predictions, and keypoints.
+            ni (int): Batch index used for naming the output file.
+
+        The function extracts keypoints from predictions, converts predictions to target format, and plots them
+        on the input images. The resulting visualization is saved to the specified save directory.
+        """
         pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
         plot_images(
             batch["img"],
@@ -231,7 +306,21 @@ class PoseValidator(DetectionValidator):
         )  # pred

     def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        """
+        Save YOLO pose detections to a text file in normalized coordinates.
+
+        Args:
+            predn (torch.Tensor): Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
+            pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
+                and D is the dimension (typically 3 for x, y, visibility).
+            save_conf (bool): Whether to save confidence scores.
+            shape (tuple): Original image shape (height, width).
+            file (Path): Output file path to save detections.
+
+        Notes:
+            The output format is: class_id x_center y_center width height confidence keypoints where keypoints
+            are normalized (x, y, visibility) values for each point.
+        """
         from ultralytics.engine.results import Results

         Results(
@@ -243,7 +332,23 @@ class PoseValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)

     def pred_to_json(self, predn, filename):
-        """Convert YOLO predictions to COCO JSON format."""
+        """
+        Convert YOLO predictions to COCO JSON format.
+
+        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
+        to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
+
+        Args:
+            predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class IDs,
+                and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened
+                keypoints dimension.
+            filename (str | Path): Path to the image file for which predictions are being processed.
+
+        Notes:
+            The method extracts the image ID from the filename stem (either as an integer if numeric, or as a
+            string), converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to
+            top-left corner before saving to the JSON dictionary.
+        """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
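The center-to-corner adjustment described in the Notes is the standard COCO convention: COCO xywh boxes are anchored at the top-left corner, while ops.xyxy2xywh produces center coordinates. A worked example of the shift, whose subtraction mirrors what the method does internally:

>>> import torch
>>> from ultralytics.utils import ops
>>> box = ops.xyxy2xywh(torch.tensor([[10.0, 20.0, 50.0, 80.0]]))  # -> center (30, 50), size (40, 60)
>>> box[:, :2] -= box[:, 2:] / 2  # shift center to top-left corner for COCO JSON
>>> box
tensor([[10., 20., 40., 60.]])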
@@ -31,12 +31,37 @@ class SegmentationPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks."""
+        """
+        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+
+        This class specializes in processing segmentation model outputs, handling both bounding boxes and masks
+        in the prediction results.
+
+        Args:
+            cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides that take precedence over cfg.
+            _callbacks (list, optional): List of callback functions to be invoked during prediction.
+        """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "segment"

     def postprocess(self, preds, img, orig_imgs):
-        """Apply non-max suppression and process detections for each image in the input batch."""
+        """
+        Apply non-max suppression and process segmentation detections for each image in the input batch.
+
+        Args:
+            preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
+            img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
+            orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
+
+        Returns:
+            (list): List of Results objects containing the segmentation predictions for each image in the batch.
+                Each Results object includes both bounding boxes and segmentation masks.
+
+        Examples:
+            >>> predictor = SegmentationPredictor(overrides=dict(model="yolov8n-seg.pt"))
+            >>> results = predictor.postprocess(preds, img, orig_img)
+        """
         # Extract protos - tuple if PyTorch model or array if exported
         protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
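From the caller's side, the protos handling is invisible: the returned Results objects expose masks alongside boxes. A brief usage sketch through the high-level API, assuming the standard Results attributes:

>>> from ultralytics import YOLO
>>> model = YOLO("yolov8n-seg.pt")
>>> results = model("https://ultralytics.com/images/bus.jpg")
>>> masks = results[0].masks  # Masks object; masks.data is an (N, H, W) tensor
>>> boxes = results[0].boxes  # matching Boxes for the same N instances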