ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +125 -39
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +34 -33
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +33 -47
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +69 -90
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +31 -38
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +21 -26
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +23 -17
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +29 -24
  53. ultralytics/models/nas/predict.py +14 -11
  54. ultralytics/models/nas/val.py +11 -13
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +21 -21
  57. ultralytics/models/rtdetr/train.py +25 -24
  58. ultralytics/models/rtdetr/val.py +47 -14
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +30 -12
  73. ultralytics/models/yolo/classify/train.py +83 -19
  74. ultralytics/models/yolo/classify/val.py +45 -23
  75. ultralytics/models/yolo/detect/predict.py +29 -19
  76. ultralytics/models/yolo/detect/train.py +90 -23
  77. ultralytics/models/yolo/detect/val.py +150 -29
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +18 -13
  80. ultralytics/models/yolo/obb/train.py +12 -8
  81. ultralytics/models/yolo/obb/val.py +35 -22
  82. ultralytics/models/yolo/pose/predict.py +28 -15
  83. ultralytics/models/yolo/pose/train.py +21 -8
  84. ultralytics/models/yolo/pose/val.py +51 -31
  85. ultralytics/models/yolo/segment/predict.py +27 -16
  86. ultralytics/models/yolo/segment/train.py +11 -8
  87. ultralytics/models/yolo/segment/val.py +110 -29
  88. ultralytics/models/yolo/world/train.py +43 -16
  89. ultralytics/models/yolo/world/train_world.py +61 -36
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +12 -12
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +226 -79
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +37 -35
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +139 -68
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +37 -56
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +117 -52
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +65 -61
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +72 -59
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +202 -64
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +13 -25
  149. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.88.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/detect/val.py
@@ -18,18 +18,40 @@ class DetectionValidator(BaseValidator):
     """
     A class extending the BaseValidator class for validation based on a detection model.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.detect import DetectionValidator
-
-        args = dict(model="yolo11n.pt", data="coco8.yaml")
-        validator = DetectionValidator(args=args)
-        validator()
-        ```
+    This class implements validation functionality specific to object detection tasks, including metrics calculation,
+    prediction processing, and visualization of results.
+
+    Attributes:
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
+        is_coco (bool): Whether the dataset is COCO.
+        is_lvis (bool): Whether the dataset is LVIS.
+        class_map (List): Mapping from model class indices to dataset class indices.
+        metrics (DetMetrics): Object detection metrics calculator.
+        iouv (torch.Tensor): IoU thresholds for mAP calculation.
+        niou (int): Number of IoU thresholds.
+        lb (List): List for storing ground truth labels for hybrid saving.
+        jdict (List): List for storing JSON detection results.
+        stats (Dict): Dictionary for storing statistics during validation.
+
+    Examples:
+        >>> from ultralytics.models.yolo.detect import DetectionValidator
+        >>> args = dict(model="yolo11n.pt", data="coco8.yaml")
+        >>> validator = DetectionValidator(args=args)
+        >>> validator()
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initialize detection model with necessary variables and settings."""
+        """
+        Initialize detection validator with necessary variables and settings.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (Dict, optional): Arguments for the validator.
+            _callbacks (List, optional): List of callback functions.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
         self.nt_per_image = None
@@ -48,7 +70,15 @@ class DetectionValidator(BaseValidator):
         )
 
     def preprocess(self, batch):
-        """Preprocesses batch of images for YOLO training."""
+        """
+        Preprocess batch of images for YOLO validation.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+
+        Returns:
+            (Dict): Preprocessed batch.
+        """
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
         for k in ["batch_idx", "cls", "bboxes"]:
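The preprocess hunk above boils down to an asynchronous device transfer plus a uint8-to-float rescale. A minimal standalone sketch of those two steps (the helper name preprocess_images is ours, not from the package, and a CUDA device is assumed):

    import torch

    def preprocess_images(imgs: torch.Tensor, device: str = "cuda", half: bool = False) -> torch.Tensor:
        """Move a uint8 image batch to the target device and rescale to [0, 1]."""
        imgs = imgs.to(device, non_blocking=True)  # asynchronous host-to-device copy
        return (imgs.half() if half else imgs.float()) / 255  # uint8 -> fp16/fp32 in [0, 1]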
@@ -66,7 +96,12 @@ class DetectionValidator(BaseValidator):
         return batch
 
     def init_metrics(self, model):
-        """Initialize evaluation metrics for YOLO."""
+        """
+        Initialize evaluation metrics for YOLO detection validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         val = self.data.get(self.args.split, "")  # validation path
         self.is_coco = (
             isinstance(val, str)
@@ -91,7 +126,15 @@ class DetectionValidator(BaseValidator):
         return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
 
     def postprocess(self, preds):
-        """Apply Non-maximum suppression to prediction outputs."""
+        """
+        Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[torch.Tensor]): Processed predictions after NMS.
+        """
         return ops.non_max_suppression(
             preds,
             self.args.conf,
@@ -106,7 +149,16 @@ class DetectionValidator(BaseValidator):
         )
 
     def _prepare_batch(self, si, batch):
-        """Prepares a batch of images and annotations for validation."""
+        """
+        Prepare a batch of images and annotations for validation.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and annotations.
+
+        Returns:
+            (Dict): Prepared batch with processed annotations.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
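The postprocess hunks above document the NMS call that turns raw head output into per-image detection tensors. A rough usage sketch of ops.non_max_suppression; the input shape is illustrative, not taken from this diff:

    import torch
    from ultralytics.utils import ops

    raw = torch.randn(1, 84, 8400)  # illustrative detect-head output: (batch, 4 box + 80 class, anchors)
    dets = ops.non_max_suppression(raw, conf_thres=0.25, iou_thres=0.7, max_det=300)
    print(dets[0].shape)  # one (M, 6) tensor per image: (x1, y1, x2, y2, conf, cls)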
@@ -119,7 +171,16 @@ class DetectionValidator(BaseValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """Prepares a batch of images and annotations for validation."""
+        """
+        Prepare predictions for evaluation against ground truth.
+
+        Args:
+            pred (torch.Tensor): Model predictions.
+            pbatch (Dict): Prepared batch information.
+
+        Returns:
+            (torch.Tensor): Prepared predictions in native space.
+        """
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
@@ -127,7 +188,13 @@ class DetectionValidator(BaseValidator):
         return predn
 
     def update_metrics(self, preds, batch):
-        """Metrics."""
+        """
+        Update metrics with new predictions and ground truth.
+
+        Args:
+            preds (List[torch.Tensor]): List of predictions from the model.
+            batch (Dict): Batch data containing ground truth.
+        """
         for si, pred in enumerate(preds):
             self.seen += 1
             npr = len(pred)
@@ -176,12 +243,23 @@ class DetectionValidator(BaseValidator):
         )
 
     def finalize_metrics(self, *args, **kwargs):
-        """Set final values for metrics speed and confusion matrix."""
+        """
+        Set final values for metrics speed and confusion matrix.
+
+        Args:
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+        """
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
     def get_stats(self):
-        """Returns metrics statistics and results dictionary."""
+        """
+        Calculate and return metrics statistics.
+
+        Returns:
+            (Dict): Dictionary containing metrics results.
+        """
         stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}  # to numpy
         self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
         self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
@@ -191,7 +269,7 @@ class DetectionValidator(BaseValidator):
         return self.metrics.results_dict
 
     def print_results(self):
-        """Prints training/validation set metrics per class."""
+        """Print training/validation set metrics per class."""
         pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
         if self.nt_per_class.sum() == 0:
@@ -223,10 +301,6 @@ class DetectionValidator(BaseValidator):
 
         Returns:
             (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.
-
-        Note:
-            The function does not return any value directly usable for metrics calculation. Instead, it provides an
-            intermediate representation used for evaluating predictions against ground truth.
         """
         iou = box_iou(gt_bboxes, detections[:, :4])
         return self.match_predictions(detections[:, 5], gt_cls, iou)
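_process_batch builds its correct-prediction matrix from a plain pairwise IoU matrix. A quick sanity check of box_iou from ultralytics.utils.metrics, with toy boxes:

    import torch
    from ultralytics.utils.metrics import box_iou

    gt = torch.tensor([[0.0, 0.0, 50.0, 50.0]])     # one ground-truth box (x1, y1, x2, y2)
    det = torch.tensor([[10.0, 10.0, 60.0, 60.0]])  # one detection
    print(box_iou(gt, det))  # (1, 1) IoU matrix; here 1600 / 3400 ≈ 0.47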
@@ -238,17 +312,35 @@ class DetectionValidator(BaseValidator):
         Args:
             img_path (str): Path to the folder containing images.
             mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (Dataset): YOLO dataset.
         """
         return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
 
     def get_dataloader(self, dataset_path, batch_size):
-        """Construct and return dataloader."""
+        """
+        Construct and return dataloader.
+
+        Args:
+            dataset_path (str): Path to the dataset.
+            batch_size (int): Size of each batch.
+
+        Returns:
+            (torch.utils.data.DataLoader): Dataloader for validation.
+        """
         dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
         return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader
 
     def plot_val_samples(self, batch, ni):
-        """Plot validation image samples."""
+        """
+        Plot validation image samples.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -261,7 +353,14 @@ class DetectionValidator(BaseValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """Plots predicted bounding boxes on input images and saves the result."""
+        """
+        Plot predicted bounding boxes on input images and save the result.
+
+        Args:
+            batch (Dict): Batch containing images and annotations.
+            preds (List[torch.Tensor]): List of predictions from the model.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             *output_to_target(preds, max_det=self.args.max_det),
@@ -272,7 +371,15 @@ class DetectionValidator(BaseValidator):
         )  # pred
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        """
+        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            save_conf (bool): Whether to save confidence scores.
+            shape (tuple): Shape of the original image.
+            file (Path): File path to save the detections.
+        """
         from ultralytics.engine.results import Results
 
         Results(
@@ -283,7 +390,13 @@ class DetectionValidator(BaseValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
-        """Serialize YOLO predictions to COCO json format."""
+        """
+        Serialize YOLO predictions to COCO json format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
+            filename (str): Image filename.
+        """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
@@ -299,7 +412,15 @@ class DetectionValidator(BaseValidator):
         )
 
     def eval_json(self, stats):
-        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        """
+        Evaluate YOLO output in JSON format and return performance statistics.
+
+        Args:
+            stats (Dict): Current statistics dictionary.
+
+        Returns:
+            (Dict): Updated statistics dictionary with COCO/LVIS evaluation results.
+        """
         if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
             anno_json = (
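pred_to_json converts boxes from xyxy to COCO's top-left xywh convention, the two-step transform visible in the hunk above. Worked through on a single box:

    import torch
    from ultralytics.utils import ops

    xyxy = torch.tensor([[10.0, 20.0, 110.0, 220.0]])
    box = ops.xyxy2xywh(xyxy)     # -> (center-x, center-y, w, h) = (60, 120, 100, 200)
    box[:, :2] -= box[:, 2:] / 2  # COCO expects the top-left corner, not the center
    print(box)                    # tensor([[ 10.,  20., 100., 200.]])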
ultralytics/models/yolo/model.py
@@ -93,7 +93,7 @@ class YOLOWorld(Model):
 
     def set_classes(self, classes):
         """
-        Set classes.
+        Set the model's class names for detection.
 
         Args:
             classes (List(str)): A list of categories i.e. ["person"].
@@ -106,6 +106,5 @@ class YOLOWorld(Model):
         self.model.names = classes
 
         # Reset method class names
-        # self.predictor = None  # reset predictor otherwise old names remain
         if self.predictor:
             self.predictor.model.names = classes
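The dropped comment reflects that set_classes now updates the live predictor's names in place instead of discarding the predictor. Typical usage, sketched from the documented API (weights and sample URL are standard Ultralytics assets):

    from ultralytics import YOLOWorld

    model = YOLOWorld("yolov8s-world.pt")  # weights are downloaded on first use
    model.set_classes(["person", "bus"])   # restrict the open-vocabulary detector to these names
    results = model.predict("https://ultralytics.com/images/bus.jpg")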
ultralytics/models/yolo/obb/predict.py
@@ -11,29 +11,34 @@ class OBBPredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
 
-    Example:
-        ```python
-        from ultralytics.utils import ASSETS
-        from ultralytics.models.yolo.obb import OBBPredictor
-
-        args = dict(model="yolo11n-obb.pt", source=ASSETS)
-        predictor = OBBPredictor(overrides=args)
-        predictor.predict_cli()
-        ```
+    This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
+    bounding boxes.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO OBB model.
+
+    Examples:
+        >>> from ultralytics.utils import ASSETS
+        >>> from ultralytics.models.yolo.obb import OBBPredictor
+        >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
+        >>> predictor = OBBPredictor(overrides=args)
+        >>> predictor.predict_cli()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initializes OBBPredictor with optional model and data configuration overrides."""
+        """Initialize OBBPredictor with optional model and data configuration overrides."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-        Constructs the result object from the prediction.
+        Construct the result object from the prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles.
-            img (torch.Tensor): The image after preprocessing.
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 6) where
+                the last dimension contains [x, y, w, h, confidence, class_id, angle].
+            img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
 
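For orientation, the rotated boxes this predictor documents end up on the Results object under the obb attribute. A short, hedged usage sketch (the image URL is a standard Ultralytics sample):

    from ultralytics import YOLO

    model = YOLO("yolo11n-obb.pt")
    for r in model("https://ultralytics.com/images/boats.jpg"):
        print(r.obb.xywhr)  # rotated boxes: (center-x, center-y, w, h, angle in radians)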
ultralytics/models/yolo/obb/train.py
@@ -11,14 +11,18 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.obb import OBBTrainer
-
-        args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
-        trainer = OBBTrainer(overrides=args)
-        trainer.train()
-        ```
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used during training.
+
+    Methods:
+        get_model: Return OBBModel initialized with specified config and weights.
+        get_validator: Return an instance of OBBValidator for validation of YOLO model.
+
+    Examples:
+        >>> from ultralytics.models.yolo.obb import OBBTrainer
+        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
+        >>> trainer = OBBTrainer(overrides=args)
+        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
ultralytics/models/yolo/obb/val.py
@@ -14,14 +14,29 @@ class OBBValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.obb import OBBValidator
-
-        args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
-        validator = OBBValidator(args=args)
-        validator(model=args["model"])
-        ```
+    This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
+    satellite imagery where objects can appear at various orientations.
+
+    Attributes:
+        args (Dict): Configuration arguments for the validator.
+        metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
+        is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
+
+    Methods:
+        init_metrics: Initialize evaluation metrics for YOLO.
+        _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
+        _prepare_batch: Prepare batch data for OBB validation.
+        _prepare_pred: Prepare predictions with scaled and padded bounding boxes.
+        plot_predictions: Plot predicted bounding boxes on input images.
+        pred_to_json: Serialize YOLO predictions to COCO json format.
+        save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
+        eval_json: Evaluate YOLO output in JSON format and return performance statistics.
+
+    Examples:
+        >>> from ultralytics.models.yolo.obb import OBBValidator
+        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
+        >>> validator = OBBValidator(args=args)
+        >>> validator(model=args["model"])
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
@@ -34,7 +49,7 @@ class OBBValidator(DetectionValidator):
         """Initialize evaluation metrics for YOLO."""
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
-        self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
 
     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
@@ -51,13 +66,11 @@ class OBBValidator(DetectionValidator):
             (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
                 Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
 
-        Example:
-            ```python
-            detections = torch.rand(100, 7)  # 100 sample detections
-            gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
-            gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
-            correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
-            ```
+        Examples:
+            >>> detections = torch.rand(100, 7)  # 100 sample detections
+            >>> gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
+            >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
+            >>> correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
 
         Note:
             This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
@@ -66,7 +79,7 @@ class OBBValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def _prepare_batch(self, si, batch):
-        """Prepares and returns a batch for OBB validation."""
+        """Prepare batch data for OBB validation with proper scaling and formatting."""
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -79,7 +92,7 @@ class OBBValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
+        """Prepare predictions by scaling bounding boxes to original image dimensions."""
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
@@ -87,7 +100,7 @@ class OBBValidator(DetectionValidator):
         return predn
 
     def plot_predictions(self, batch, preds, ni):
-        """Plots predicted bounding boxes on input images and saves the result."""
+        """Plot predicted bounding boxes on input images and save the result."""
         plot_images(
             batch["img"],
             *output_to_rotated_target(preds, max_det=self.args.max_det),
@@ -98,7 +111,7 @@ class OBBValidator(DetectionValidator):
         )  # pred
 
     def pred_to_json(self, predn, filename):
-        """Serialize YOLO predictions to COCO json format."""
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information."""
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
@@ -115,7 +128,7 @@ class OBBValidator(DetectionValidator):
         )
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        """Save YOLO detections to a txt file in normalized coordinates using the Results class."""
         import numpy as np
 
         from ultralytics.engine.results import Results
@@ -131,7 +144,7 @@ class OBBValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def eval_json(self, stats):
-        """Evaluates YOLO output in JSON format and returns performance statistics."""
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
             import re
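The Note retained in the _process_batch hunk points at batch_probiou, which replaces axis-aligned IoU for rotated boxes. A minimal check with toy xywhr values (the N x M return shape is our reading of the function, not stated in this diff):

    import torch
    from ultralytics.utils.metrics import batch_probiou

    gt = torch.tensor([[50.0, 50.0, 40.0, 20.0, 0.0]])   # (cx, cy, w, h, angle) ground truth
    det = torch.tensor([[50.0, 50.0, 40.0, 20.0, 0.3]])  # same box rotated ~17 degrees
    print(batch_probiou(gt, det))  # (1, 1) probabilistic IoU matrix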
ultralytics/models/yolo/pose/predict.py
@@ -8,19 +8,26 @@ class PosePredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on a pose model.
 
-    Example:
-        ```python
-        from ultralytics.utils import ASSETS
-        from ultralytics.models.yolo.pose import PosePredictor
-
-        args = dict(model="yolo11n-pose.pt", source=ASSETS)
-        predictor = PosePredictor(overrides=args)
-        predictor.predict_cli()
-        ```
+    This class specializes in pose estimation, handling keypoints detection alongside standard object detection
+    capabilities inherited from DetectionPredictor.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
+
+    Methods:
+        construct_result: Constructs the result object from the prediction, including keypoints.
+
+    Examples:
+        >>> from ultralytics.utils import ASSETS
+        >>> from ultralytics.models.yolo.pose import PosePredictor
+        >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
+        >>> predictor = PosePredictor(overrides=args)
+        >>> predictor.predict_cli()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
+        """Initialize PosePredictor, set task to 'pose' and log a warning for using 'mps' as device."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "pose"
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
@@ -31,19 +38,25 @@ class PosePredictor(DetectionPredictor):
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-        Constructs the result object from the prediction.
+        Construct the result object from the prediction, including keypoints.
+
+        This method extends the parent class implementation by extracting keypoint data from predictions
+        and adding them to the result object.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints.
-            img (torch.Tensor): The image after preprocessing.
-            orig_img (np.ndarray): The original image before preprocessing.
-            img_path (str): The path to the original image.
+            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
+                the number of detections, K is the number of keypoints, and D is the keypoint dimension.
+            img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
+            orig_img (np.ndarray): The original unprocessed image as a numpy array.
+            img_path (str): The path to the original image file.
 
         Returns:
            (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
+        # Extract keypoints from prediction and reshape according to model's keypoint shape
         pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+        # Scale keypoints coordinates to match the original image dimensions
        pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
         return result
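The two new inline comments describe a reshape-then-rescale of the keypoint block. Spelled out with hypothetical shapes, assuming COCO-style 17 keypoints with (x, y, conf) each:

    import torch
    from ultralytics.utils import ops

    pred = torch.rand(2, 6 + 17 * 3)                # 2 detections: box(4) + conf + cls + 17*3 keypoint values
    pred_kpts = pred[:, 6:].view(len(pred), 17, 3)  # -> (N, K, D)
    # map from a 640x640 letterboxed input back to a 480x640 original image
    pred_kpts = ops.scale_coords((640, 640), pred_kpts, (480, 640, 3))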
ultralytics/models/yolo/pose/train.py
@@ -10,16 +10,29 @@ from ultralytics.utils.plotting import plot_images, plot_results
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
     """
-    A class extending the DetectionTrainer class for training based on a pose model.
+    A class extending the DetectionTrainer class for training YOLO pose estimation models.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.pose import PoseTrainer
+    This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
+    of pose keypoints alongside bounding boxes.
 
-        args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
-        trainer = PoseTrainer(overrides=args)
-        trainer.train()
-        ```
+    Attributes:
+        args (Dict): Configuration arguments for training.
+        model (PoseModel): The pose estimation model being trained.
+        data (Dict): Dataset configuration including keypoint shape information.
+        loss_names (Tuple[str]): Names of the loss components used in training.
+
+    Methods:
+        get_model: Retrieves a pose estimation model with specified configuration.
+        set_model_attributes: Sets keypoints shape attribute on the model.
+        get_validator: Creates a validator instance for model evaluation.
+        plot_training_samples: Visualizes training samples with keypoints.
+        plot_metrics: Generates and saves training/validation metric plots.
+
+    Examples:
+        >>> from ultralytics.models.yolo.pose import PoseTrainer
+        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
+        >>> trainer = PoseTrainer(overrides=args)
+        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):