ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +125 -39
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +34 -33
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +33 -47
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +69 -90
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +31 -38
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +21 -26
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +23 -17
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +29 -24
  53. ultralytics/models/nas/predict.py +14 -11
  54. ultralytics/models/nas/val.py +11 -13
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +21 -21
  57. ultralytics/models/rtdetr/train.py +25 -24
  58. ultralytics/models/rtdetr/val.py +47 -14
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +30 -12
  73. ultralytics/models/yolo/classify/train.py +83 -19
  74. ultralytics/models/yolo/classify/val.py +45 -23
  75. ultralytics/models/yolo/detect/predict.py +29 -19
  76. ultralytics/models/yolo/detect/train.py +90 -23
  77. ultralytics/models/yolo/detect/val.py +150 -29
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +18 -13
  80. ultralytics/models/yolo/obb/train.py +12 -8
  81. ultralytics/models/yolo/obb/val.py +35 -22
  82. ultralytics/models/yolo/pose/predict.py +28 -15
  83. ultralytics/models/yolo/pose/train.py +21 -8
  84. ultralytics/models/yolo/pose/val.py +51 -31
  85. ultralytics/models/yolo/segment/predict.py +27 -16
  86. ultralytics/models/yolo/segment/train.py +11 -8
  87. ultralytics/models/yolo/segment/val.py +110 -29
  88. ultralytics/models/yolo/world/train.py +43 -16
  89. ultralytics/models/yolo/world/train_world.py +61 -36
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +12 -12
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +226 -79
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +37 -35
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +139 -68
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +37 -56
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +117 -52
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +65 -61
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +72 -59
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +202 -64
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +13 -25
  149. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.88.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -119,10 +119,10 @@ class BaseModel(torch.nn.Module):
 
         Args:
             x (torch.Tensor): The input tensor to the model.
-            profile (bool): Print the computation time of each layer if True, defaults to False.
-            visualize (bool): Save the feature maps of the model if True, defaults to False.
-            augment (bool): Augment image during prediction, defaults to False.
-            embed (list, optional): A list of feature vectors/embeddings to return.
+            profile (bool): Print the computation time of each layer if True.
+            visualize (bool): Save the feature maps of the model if True.
+            augment (bool): Augment image during prediction.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): The last output of the model.
@@ -137,9 +137,9 @@ class BaseModel(torch.nn.Module):
 
         Args:
             x (torch.Tensor): The input tensor to the model.
-            profile (bool): Print the computation time of each layer if True, defaults to False.
-            visualize (bool): Save the feature maps of the model if True, defaults to False.
-            embed (list, optional): A list of feature vectors/embeddings to return.
+            profile (bool): Print the computation time of each layer if True.
+            visualize (bool): Save the feature maps of the model if True.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): The last output of the model.
@@ -170,13 +170,12 @@ class BaseModel(torch.nn.Module):
 
     def _profile_one_layer(self, m, x, dt):
         """
-        Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to
-        the provided list.
+        Profile the computation time and FLOPs of a single layer of the model on a given input.
 
         Args:
             m (torch.nn.Module): The layer to be profiled.
             x (torch.Tensor): The input data to the layer.
-            dt (list): A list to store the computation time of the layer.
+            dt (List): A list to store the computation time of the layer.
         """
         c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
         flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
@@ -192,8 +191,8 @@ class BaseModel(torch.nn.Module):
 
     def fuse(self, verbose=True):
         """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
-        computation efficiency.
+        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer for improved computation
+        efficiency.
 
         Returns:
             (torch.nn.Module): The fused model is returned.
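The `fuse()` method above and the `info()` method touched in the next hunk are also exposed on the top-level `YOLO` wrapper. A minimal usage sketch, assuming a locally available `yolo11n.pt` checkpoint:

>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt")  # illustrative checkpoint
>>> model.fuse()  # fold Conv2d + BatchNorm2d layers for faster inference
>>> model.info(detailed=False, imgsz=640)  # print a parameter/GFLOPs summary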
@@ -225,7 +224,7 @@ class BaseModel(torch.nn.Module):
         Check if the model has less than a certain threshold of BatchNorm layers.
 
         Args:
-            thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
+            thresh (int, optional): The threshold number of BatchNorm layers.
 
         Returns:
             (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
@@ -235,21 +234,21 @@ class BaseModel(torch.nn.Module):
 
     def info(self, detailed=False, verbose=True, imgsz=640):
         """
-        Prints model information.
+        Print model information.
 
         Args:
-            detailed (bool): if True, prints out detailed information about the model. Defaults to False
-            verbose (bool): if True, prints out the model information. Defaults to False
-            imgsz (int): the size of the image that the model will be trained on. Defaults to 640
+            detailed (bool): If True, prints out detailed information about the model.
+            verbose (bool): If True, prints out the model information.
+            imgsz (int): The size of the image that the model will be trained on.
         """
         return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
 
     def _apply(self, fn):
         """
-        Applies a function to all the tensors in the model that are not parameters or registered buffers.
+        Apply a function to all tensors in the model that are not parameters or registered buffers.
 
         Args:
-            fn (function): the function to apply to the model
+            fn (function): The function to apply to the model.
 
         Returns:
             (BaseModel): An updated BaseModel object.
@@ -264,11 +263,11 @@ class BaseModel(torch.nn.Module):
 
     def load(self, weights, verbose=True):
         """
-        Load the weights into the model.
+        Load weights into the model.
 
         Args:
             weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
-            verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
+            verbose (bool, optional): Whether to log the transfer progress.
         """
         model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
         csd = model.float().state_dict()  # checkpoint state_dict as FP32
@@ -282,8 +281,8 @@ class BaseModel(torch.nn.Module):
         Compute loss.
 
         Args:
-            batch (dict): Batch to compute loss on
-            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+            batch (dict): Batch to compute loss on.
+            preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
         """
         if getattr(self, "criterion", None) is None:
             self.criterion = self.init_criterion()
@@ -300,7 +299,15 @@ class DetectionModel(BaseModel):
     """YOLO detection model."""
 
     def __init__(self, cfg="yolo11n.yaml", ch=3, nc=None, verbose=True):  # model, input channels, number of classes
-        """Initialize the YOLO detection model with the given config and parameters."""
+        """
+        Initialize the YOLO detection model with the given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
         super().__init__()
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
         if self.yaml["backbone"][0][2] == "Silence":
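A minimal sketch of constructing a `DetectionModel` directly with the signature documented above, assuming the bundled `yolo11n.yaml` config:

>>> import torch
>>> from ultralytics.nn.tasks import DetectionModel
>>> model = DetectionModel(cfg="yolo11n.yaml", ch=3, nc=80, verbose=False)
>>> preds = model(torch.zeros(1, 3, 640, 640))  # forward pass on a dummy image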
@@ -327,7 +334,7 @@ class DetectionModel(BaseModel):
             m.inplace = self.inplace
 
             def _forward(x):
-                """Performs a forward pass through the model, handling different Detect subclass types accordingly."""
+                """Perform a forward pass through the model, handling different Detect subclass types accordingly."""
                 if self.end2end:
                     return self.forward(x)["one2many"]
                 return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
@@ -345,7 +352,15 @@ class DetectionModel(BaseModel):
         LOGGER.info("")
 
     def _predict_augment(self, x):
-        """Perform augmentations on input image x and return augmented inference and train outputs."""
+        """
+        Perform augmentations on input image x and return augmented inference and train outputs.
+
+        Args:
+            x (torch.Tensor): Input image tensor.
+
+        Returns:
+            (torch.Tensor): Augmented inference output.
+        """
         if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
             LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
             return self._predict_once(x)
@@ -363,7 +378,19 @@ class DetectionModel(BaseModel):
 
     @staticmethod
     def _descale_pred(p, flips, scale, img_size, dim=1):
-        """De-scale predictions following augmented inference (inverse operation)."""
+        """
+        De-scale predictions following augmented inference (inverse operation).
+
+        Args:
+            p (torch.Tensor): Predictions tensor.
+            flips (int): Flip type (0=none, 2=ud, 3=lr).
+            scale (float): Scale factor.
+            img_size (tuple): Original image size (height, width).
+            dim (int): Dimension to split at.
+
+        Returns:
+            (torch.Tensor): De-scaled predictions.
+        """
         p[:, :4] /= scale  # de-scale
         x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
         if flips == 2:
@@ -373,7 +400,15 @@ class DetectionModel(BaseModel):
         return torch.cat((x, y, wh, cls), dim)
 
     def _clip_augmented(self, y):
-        """Clip YOLO augmented inference tails."""
+        """
+        Clip YOLO augmented inference tails.
+
+        Args:
+            y (List[torch.Tensor]): List of detection tensors.
+
+        Returns:
+            (List[torch.Tensor]): Clipped detection tensors.
+        """
         nl = self.model[-1].nl  # number of detection layers (P3-P5)
         g = sum(4**x for x in range(nl))  # grid points
         e = 1  # exclude layer count
@@ -392,7 +427,15 @@ class OBBModel(DetectionModel):
     """YOLO Oriented Bounding Box (OBB) model."""
 
     def __init__(self, cfg="yolo11n-obb.yaml", ch=3, nc=None, verbose=True):
-        """Initialize YOLO OBB model with given config and parameters."""
+        """
+        Initialize YOLO OBB model with given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def init_criterion(self):
@@ -404,7 +447,15 @@ class SegmentationModel(DetectionModel):
     """YOLO segmentation model."""
 
     def __init__(self, cfg="yolo11n-seg.yaml", ch=3, nc=None, verbose=True):
-        """Initialize YOLOv8 segmentation model with given config and parameters."""
+        """
+        Initialize YOLOv8 segmentation model with given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def init_criterion(self):
@@ -416,7 +467,16 @@ class PoseModel(DetectionModel):
     """YOLO pose model."""
 
     def __init__(self, cfg="yolo11n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
-        """Initialize YOLOv8 Pose model."""
+        """
+        Initialize YOLOv8 Pose model.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            data_kpt_shape (tuple): Shape of keypoints data.
+            verbose (bool): Whether to display model information.
+        """
         if not isinstance(cfg, dict):
             cfg = yaml_model_load(cfg)  # load model YAML
         if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
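Per the new `PoseModel` docstring, `data_kpt_shape` can override the YAML `kpt_shape`. A hedged sketch using the COCO 17-keypoint layout:

>>> from ultralytics.nn.tasks import PoseModel
>>> model = PoseModel(cfg="yolo11n-pose.yaml", ch=3, data_kpt_shape=(17, 3))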
@@ -433,12 +493,28 @@ class ClassificationModel(BaseModel):
     """YOLO classification model."""
 
     def __init__(self, cfg="yolo11n-cls.yaml", ch=3, nc=None, verbose=True):
-        """Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
+        """
+        Initialize ClassificationModel with YAML, channels, number of classes, verbose flag.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
         super().__init__()
         self._from_yaml(cfg, ch, nc, verbose)
 
     def _from_yaml(self, cfg, ch, nc, verbose):
-        """Set YOLOv8 model configurations and define the model architecture."""
+        """
+        Set YOLOv8 model configurations and define the model architecture.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
         self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
 
         # Define model
@@ -455,7 +531,13 @@ class ClassificationModel(BaseModel):
 
     @staticmethod
     def reshape_outputs(model, nc):
-        """Update a TorchVision classification model to class count 'n' if required."""
+        """
+        Update a TorchVision classification model to class count 'n' if required.
+
+        Args:
+            model (torch.nn.Module): Model to update.
+            nc (int): New number of classes.
+        """
         name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1]  # last module
         if isinstance(m, Classify):  # YOLO Classify() head
             if m.linear.out_features != nc:
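`reshape_outputs` also handles plain TorchVision classifiers whose final module is a `Linear` layer. A small sketch, assuming torchvision is installed:

>>> import torchvision
>>> from ultralytics.nn.tasks import ClassificationModel
>>> m = torchvision.models.resnet18(weights=None)
>>> ClassificationModel.reshape_outputs(m, nc=10)  # swap the final Linear head to 10 classes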
@@ -500,10 +582,10 @@ class RTDETRDetectionModel(DetectionModel):
         Initialize the RTDETRDetectionModel.
 
         Args:
-            cfg (str): Configuration file name or path.
+            cfg (str | dict): Configuration file name or path.
             ch (int): Number of input channels.
-            nc (int, optional): Number of classes. Defaults to None.
-            verbose (bool, optional): Print additional information during initialization. Defaults to True.
+            nc (int, optional): Number of classes.
+            verbose (bool): Print additional information during initialization.
         """
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
@@ -519,7 +601,7 @@ class RTDETRDetectionModel(DetectionModel):
 
         Args:
             batch (dict): Dictionary containing image and label data.
-            preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None.
+            preds (torch.Tensor, optional): Precomputed model predictions.
 
         Returns:
             (tuple): A tuple containing the total loss and main three losses in a tensor.
@@ -564,11 +646,11 @@ class RTDETRDetectionModel(DetectionModel):
 
         Args:
             x (torch.Tensor): The input tensor.
-            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
-            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
-            batch (dict, optional): Ground truth data for evaluation. Defaults to None.
-            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
-            embed (list, optional): A list of feature vectors/embeddings to return.
+            profile (bool): If True, profile the computation time for each layer.
+            visualize (bool): If True, save feature maps for visualization.
+            batch (dict, optional): Ground truth data for evaluation.
+            augment (bool): If True, perform data augmentation during inference.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
@@ -596,13 +678,28 @@ class WorldModel(DetectionModel):
     """YOLOv8 World Model."""
 
     def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
-        """Initialize YOLOv8 world model with given config and parameters."""
+        """
+        Initialize YOLOv8 world model with given config and parameters.
+
+        Args:
+            cfg (str | dict): Model configuration file path or dictionary.
+            ch (int): Number of input channels.
+            nc (int, optional): Number of classes.
+            verbose (bool): Whether to display model information.
+        """
         self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
         self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def set_classes(self, text, batch=80, cache_clip_model=True):
-        """Set classes in advance so that model could do offline-inference without clip model."""
+        """
+        Set classes in advance so that model could do offline-inference without clip model.
+
+        Args:
+            text (List[str]): List of class names.
+            batch (int): Batch size for processing text tokens.
+            cache_clip_model (bool): Whether to cache the CLIP model.
+        """
         try:
             import clip
         except ImportError:
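`set_classes` is normally reached through the public `YOLOWorld` wrapper. A minimal sketch (requires the `clip` package; the weight and image names are illustrative):

>>> from ultralytics import YOLOWorld
>>> model = YOLOWorld("yolov8s-world.pt")
>>> model.set_classes(["person", "bus"])  # cache text embeddings for offline inference
>>> results = model.predict("bus.jpg")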
@@ -628,11 +725,11 @@ class WorldModel(DetectionModel):
 
         Args:
             x (torch.Tensor): The input tensor.
-            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
-            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
-            txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
-            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
-            embed (list, optional): A list of feature vectors/embeddings to return.
+            profile (bool): If True, profile the computation time for each layer.
+            visualize (bool): If True, save feature maps for visualization.
+            txt_feats (torch.Tensor, optional): The text features, use it if it's given.
+            augment (bool): If True, perform data augmentation during inference.
+            embed (List, optional): A list of feature vectors/embeddings to return.
 
         Returns:
             (torch.Tensor): Model's output tensor.
@@ -671,7 +768,7 @@ class WorldModel(DetectionModel):
 
         Args:
             batch (dict): Batch to compute loss on.
-            preds (torch.Tensor | List[torch.Tensor]): Predictions.
+            preds (torch.Tensor | List[torch.Tensor], optional): Predictions.
         """
         if not hasattr(self, "criterion"):
             self.criterion = self.init_criterion()
@@ -689,7 +786,18 @@ class Ensemble(torch.nn.ModuleList):
         super().__init__()
 
     def forward(self, x, augment=False, profile=False, visualize=False):
-        """Function generates the YOLO network's final layer."""
+        """
+        Generate the YOLO network's final layer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            augment (bool): Whether to augment the input.
+            profile (bool): Whether to profile the model.
+            visualize (bool): Whether to visualize the features.
+
+        Returns:
+            (tuple): Tuple containing the concatenated predictions and None.
+        """
         y = [module(x, augment, profile, visualize)[0] for module in self]
         # y = torch.stack(y).max(0)[0]  # max ensemble
         # y = torch.stack(y).mean(0)  # mean ensemble
@@ -713,12 +821,10 @@ def temporary_modules(modules=None, attributes=None):
         modules (dict, optional): A dictionary mapping old module paths to new module paths.
         attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.
 
-    Example:
-        ```python
-        with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
-            import old.module  # this will now import new.module
-            from old.module import attribute  # this will now import new.module.attribute
-        ```
+    Examples:
+        >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
+        >>> import old.module  # this will now import new.module
+        >>> from old.module import attribute  # this will now import new.module.attribute
 
     Note:
         The changes are only in effect inside the context manager and are undone once the context manager exits.
@@ -767,7 +873,16 @@ class SafeUnpickler(pickle.Unpickler):
     """Custom Unpickler that replaces unknown classes with SafeClass."""
 
     def find_class(self, module, name):
-        """Attempt to find a class, returning SafeClass if not among safe modules."""
+        """
+        Attempt to find a class, returning SafeClass if not among safe modules.
+
+        Args:
+            module (str): Module name.
+            name (str): Class name.
+
+        Returns:
+            (type): Found class or SafeClass.
+        """
         safe_modules = (
             "torch",
             "collections",
@@ -793,16 +908,13 @@ def torch_safe_load(weight, safe_only=False):
         weight (str): The file path of the PyTorch model.
         safe_only (bool): If True, replace unknown classes with SafeClass during loading.
 
-    Example:
-    ```python
-    from ultralytics.nn.tasks import torch_safe_load
-
-    ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
-    ```
-
     Returns:
         ckpt (dict): The loaded model checkpoint.
-        file (str): The loaded filename
+        file (str): The loaded filename.
+
+    Examples:
+        >>> from ultralytics.nn.tasks import torch_safe_load
+        >>> ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
     """
     from ultralytics.utils.downloads import attempt_download_asset
 
@@ -863,7 +975,18 @@ def torch_safe_load(weight, safe_only=False):
 
 
 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
-    """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
+    """
+    Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
+
+    Args:
+        weights (str | List[str]): Model weights path(s).
+        device (torch.device, optional): Device to load model to.
+        inplace (bool): Whether to do inplace operations.
+        fuse (bool): Whether to fuse model.
+
+    Returns:
+        (torch.nn.Module): Loaded model.
+    """
     ensemble = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt, w = torch_safe_load(w)  # load ckpt
@@ -901,7 +1024,18 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
 
 
 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
-    """Loads a single model weights."""
+    """
+    Load a single model weights.
+
+    Args:
+        weight (str): Model weight path.
+        device (torch.device, optional): Device to load model to.
+        inplace (bool): Whether to do inplace operations.
+        fuse (bool): Whether to fuse model.
+
+    Returns:
+        (tuple): Tuple containing the model and checkpoint.
+    """
     ckpt, weight = torch_safe_load(weight)  # load ckpt
     args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}  # combine model and default args, preferring model args
     model = (ckpt.get("ema") or ckpt["model"]).to(device).float()  # FP32 model
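A usage sketch for the two loaders documented above: `attempt_load_weights` accepts a list and returns an `Ensemble`, while `attempt_load_one_weight` returns the model together with its checkpoint dict (weight paths are illustrative):

>>> from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
>>> model, ckpt = attempt_load_one_weight("yolo11n.pt", device="cpu")
>>> ensemble = attempt_load_weights(["yolo11n.pt", "yolo11s.pt"], device="cpu", fuse=True)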
@@ -927,7 +1061,17 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
-    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
+    """
+    Parse a YOLO model.yaml dictionary into a PyTorch model.
+
+    Args:
+        d (dict): Model dictionary.
+        ch (int): Input channels.
+        verbose (bool): Whether to print model details.
+
+    Returns:
+        (tuple): Tuple containing the PyTorch model and sorted list of output layers.
+    """
     import ast
 
     # Args
@@ -1091,7 +1235,15 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
 
 def yaml_model_load(path):
-    """Load a YOLOv8 model from a YAML file."""
+    """
+    Load a YOLOv8 model from a YAML file.
+
+    Args:
+        path (str | Path): Path to the YAML file.
+
+    Returns:
+        (dict): Model dictionary.
+    """
     path = Path(path)
     if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
         new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
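Together these two functions form the YAML-to-model pipeline: `yaml_model_load` resolves the config file (including the scale suffix in the filename) into a dict, and `parse_model` builds the layers. A hedged sketch; the `deepcopy` mirrors how `DetectionModel` invokes `parse_model` internally:

>>> from copy import deepcopy
>>> from ultralytics.nn.tasks import parse_model, yaml_model_load
>>> d = yaml_model_load("yolo11n.yaml")  # model dict with scale resolved from the filename
>>> model, save = parse_model(deepcopy(d), ch=3, verbose=False)  # nn.Sequential + output layer indices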
@@ -1108,15 +1260,13 @@ def yaml_model_load(path):
 
 def guess_model_scale(model_path):
     """
-    Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The function
-    uses regular expression matching to find the pattern of the model scale in the YAML file name, which is denoted by
-    n, s, m, l, or x. The function returns the size character of the model scale as a string.
+    Extract the size character n, s, m, l, or x of the model's scale from the model path.
 
     Args:
         model_path (str | Path): The path to the YOLO model's YAML file.
 
     Returns:
-        (str): The size character of the model's scale, which can be n, s, m, l, or x.
+        (str): The size character of the model's scale (n, s, m, l, or x).
     """
     try:
         return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1)  # returns n, s, m, l, or x
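The regex in this hunk implies, for example:

>>> from ultralytics.nn.tasks import guess_model_scale
>>> guess_model_scale("yolo11n.yaml")
'n'
>>> guess_model_scale("yolov8x-seg.yaml")
'x'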
@@ -1132,10 +1282,7 @@ def guess_model_task(model):
         model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
 
     Returns:
-        (str): Task of the model ('detect', 'segment', 'classify', 'pose').
-
-    Raises:
-        SyntaxError: If the task of the model could not be determined.
+        (str): Task of the model ('detect', 'segment', 'classify', 'pose', 'obb').
     """
 
     def cfg2task(cfg):
ultralytics/solutions/ai_gym.py CHANGED
@@ -33,7 +33,7 @@ class AIGym(BaseSolution):
 
     def __init__(self, **kwargs):
         """
-        Initializes AIGym for workout monitoring using pose estimation and predefined angles.
+        Initialize AIGym for workout monitoring using pose estimation and predefined angles.
 
         Args:
             **kwargs (Any): Keyword arguments passed to the parent class constructor.
@@ -53,7 +53,7 @@ class AIGym(BaseSolution):
 
     def process(self, im0):
         """
-        Monitors workouts using Ultralytics YOLO Pose Model.
+        Monitor workouts using Ultralytics YOLO Pose Model.
 
         This function processes an input image to track and analyze human poses for workout monitoring. It uses
         the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
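A minimal sketch of the `AIGym` workflow these docstrings describe (the keypoint indices target the left arm in the COCO pose layout; the blank frame stands in for real video input):

>>> import numpy as np
>>> from ultralytics import solutions
>>> gym = solutions.AIGym(model="yolo11n-pose.pt", kpts=[6, 8, 10])
>>> frame = np.zeros((480, 640, 3), dtype=np.uint8)  # illustrative stand-in for a video frame
>>> results = gym.process(frame)  # annotates the frame and updates rep counts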
ultralytics/solutions/analytics.py CHANGED
@@ -37,8 +37,8 @@ class Analytics(BaseSolution):
         color_mapping (Dict[str, str]): Dictionary mapping class labels to colors for consistent visualization.
 
     Methods:
-        process: Processes image data and updates the chart.
-        update_graph: Updates the chart with new data points.
+        process: Process image data and update the chart.
+        update_graph: Update the chart with new data points.
 
     Examples:
         >>> analytics = Analytics(analytics_type="line")
@@ -87,7 +87,7 @@ class Analytics(BaseSolution):
 
     def process(self, im0, frame_number):
         """
-        Processes image data and runs object tracking to update analytics charts.
+        Process image data and run object tracking to update analytics charts.
 
         Args:
             im0 (np.ndarray): Input image for processing.
@@ -127,7 +127,7 @@ class Analytics(BaseSolution):
 
     def update_graph(self, frame_number, count_dict=None, plot="line"):
         """
-        Updates the graph with new data for single or multiple classes.
+        Update the graph with new data for single or multiple classes.
 
         Args:
             frame_number (int): The current frame number.
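A short sketch of the `Analytics` loop, following the `analytics_type="line"` example already present in the class docstring (the blank frame is illustrative):

>>> import numpy as np
>>> from ultralytics import solutions
>>> analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt")
>>> frame = np.zeros((480, 640, 3), dtype=np.uint8)  # illustrative video frame
>>> results = analytics.process(frame, frame_number=1)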
ultralytics/solutions/heatmap.py CHANGED
@@ -21,8 +21,8 @@ class Heatmap(ObjectCounter):
         annotator (SolutionAnnotator): Object for drawing annotations on the image.
 
     Methods:
-        heatmap_effect: Calculates and updates the heatmap effect for a given bounding box.
-        process: Generates and applies the heatmap effect to each frame.
+        heatmap_effect: Calculate and update the heatmap effect for a given bounding box.
+        process: Generate and apply the heatmap effect to each frame.
 
     Examples:
         >>> from ultralytics.solutions import Heatmap
@@ -33,7 +33,7 @@ class Heatmap(ObjectCounter):
 
     def __init__(self, **kwargs):
         """
-        Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks.
+        Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.
 
         Args:
             **kwargs (Any): Keyword arguments passed to the parent ObjectCounter class.
@@ -50,7 +50,7 @@ class Heatmap(ObjectCounter):
 
     def heatmap_effect(self, box):
         """
-        Efficiently calculates heatmap area and effect location for applying colormap.
+        Efficiently calculate heatmap area and effect location for applying colormap.
 
         Args:
             box (List[float]): Bounding box coordinates [x0, y0, x1, y1].
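And a corresponding sketch for `Heatmap`; the `colormap` keyword assumes the OpenCV constant shown in the Ultralytics docs:

>>> import cv2
>>> import numpy as np
>>> from ultralytics import solutions
>>> heatmap = solutions.Heatmap(model="yolo11n.pt", colormap=cv2.COLORMAP_PARULA)
>>> frame = np.zeros((480, 640, 3), dtype=np.uint8)  # illustrative video frame
>>> results = heatmap.process(frame)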