ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +125 -39
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +34 -33
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +33 -47
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +69 -90
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +31 -38
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +21 -26
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +23 -17
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +29 -24
  53. ultralytics/models/nas/predict.py +14 -11
  54. ultralytics/models/nas/val.py +11 -13
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +21 -21
  57. ultralytics/models/rtdetr/train.py +25 -24
  58. ultralytics/models/rtdetr/val.py +47 -14
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +30 -12
  73. ultralytics/models/yolo/classify/train.py +83 -19
  74. ultralytics/models/yolo/classify/val.py +45 -23
  75. ultralytics/models/yolo/detect/predict.py +29 -19
  76. ultralytics/models/yolo/detect/train.py +90 -23
  77. ultralytics/models/yolo/detect/val.py +150 -29
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +18 -13
  80. ultralytics/models/yolo/obb/train.py +12 -8
  81. ultralytics/models/yolo/obb/val.py +35 -22
  82. ultralytics/models/yolo/pose/predict.py +28 -15
  83. ultralytics/models/yolo/pose/train.py +21 -8
  84. ultralytics/models/yolo/pose/val.py +51 -31
  85. ultralytics/models/yolo/segment/predict.py +27 -16
  86. ultralytics/models/yolo/segment/train.py +11 -8
  87. ultralytics/models/yolo/segment/val.py +110 -29
  88. ultralytics/models/yolo/world/train.py +43 -16
  89. ultralytics/models/yolo/world/train_world.py +61 -36
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +12 -12
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +226 -79
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +37 -35
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +139 -68
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +37 -56
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +117 -52
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +65 -61
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +72 -59
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +202 -64
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +13 -25
  149. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.88.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
@@ -109,7 +109,7 @@ from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_d
 
 
 def export_formats():
-    """Ultralytics YOLO export formats."""
+    """Return a dictionary of Ultralytics YOLO export formats."""
     x = [
         ["PyTorch", "-", ".pt", True, True, []],
         ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]],
@@ -133,17 +133,16 @@ def export_formats():
 
 def validate_args(format, passed_args, valid_args):
     """
-    Validates arguments based on format.
+    Validate arguments based on the export format.
 
     Args:
         format (str): The export format.
         passed_args (Namespace): The arguments used during export.
-        valid_args (dict): List of valid arguments for the format.
+        valid_args (List): List of valid arguments for the format.
 
     Raises:
-        AssertionError: If an argument that's not supported by the export format is used, or if format doesn't have the supported arguments listed.
+        AssertionError: If an unsupported argument is used, or if the format lacks supported argument listings.
     """
-    # Only check valid usage of these args
     export_args = ["half", "int8", "dynamic", "keras", "nms", "batch"]
 
     assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed."
@@ -156,7 +155,7 @@ def validate_args(format, passed_args, valid_args):
 
 
 def gd_outputs(gd):
-    """TensorFlow GraphDef model output node names."""
+    """Return TensorFlow GraphDef model output node names."""
     name_list, input_list = [], []
     for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
         name_list.append(node.name)
@@ -195,6 +194,7 @@ def arange_patch(args):
     func = torch.arange
 
     def arange(*args, dtype=None, **kwargs):
+        """Return a 1-D tensor of size with values from the interval and common difference."""
         return func(*args, **kwargs).to(dtype)  # cast to dtype instead of passing dtype
 
     torch.arange = arange  # patch
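Note: the hunk above patches torch.arange during export so the requested dtype is applied with .to() after the tensor is created instead of being passed through, which avoids tracer issues with the dtype keyword. A minimal sketch of the same monkey-patch pattern, written here as a standalone context manager; the dtype=None guard and the restore step are safety additions for illustration, not code from the exporter:

    import contextlib

    import torch


    @contextlib.contextmanager
    def arange_dtype_patch():
        """Temporarily route torch.arange through a wrapper that casts after creation."""
        original = torch.arange

        def patched(*args, dtype=None, **kwargs):
            out = original(*args, **kwargs)  # build the tensor without forwarding dtype
            return out if dtype is None else out.to(dtype)  # cast afterwards

        torch.arange = patched
        try:
            yield
        finally:
            torch.arange = original  # always restore the real function


    # Any torch.arange call inside the block is created first, then cast to the requested dtype.
    with arange_dtype_patch():
        t = torch.arange(5, dtype=torch.float16)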
@@ -210,17 +210,17 @@ class Exporter:
 
     Attributes:
         args (SimpleNamespace): Configuration for the exporter.
-        callbacks (list, optional): List of callback functions. Defaults to None.
+        callbacks (List, optional): List of callback functions.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """
-        Initializes the Exporter class.
+        Initialize the Exporter class.
 
         Args:
-            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
-            overrides (dict, optional): Configuration overrides. Defaults to None.
-            _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
+            cfg (str, optional): Path to a configuration file.
+            overrides (Dict, optional): Configuration overrides.
+            _callbacks (Dict, optional): Dictionary of callback functions.
         """
         self.args = get_cfg(cfg, overrides)
         if self.args.format.lower() in {"coreml", "mlmodel"}:  # fix attempt for protobuf<3.20.x errors
@@ -230,7 +230,7 @@ class Exporter:
         callbacks.add_integration_callbacks(self)
 
     def __call__(self, model=None) -> str:
-        """Returns list of exported files/dirs after running callbacks."""
+        """Return list of exported files/dirs after running callbacks."""
         self.run_callbacks("on_export_start")
         t = time.time()
         fmt = self.args.format.lower()  # to lowercase
@@ -293,7 +293,8 @@ class Exporter:
         if rknn:
             if not self.args.name:
                 LOGGER.warning(
-                    "WARNING ⚠️ Rockchip RKNN export requires a missing 'name' arg for processor type. Using default name='rk3588'."
+                    "WARNING ⚠️ Rockchip RKNN export requires a missing 'name' arg for processor type. "
+                    "Using default name='rk3588'."
                 )
                 self.args.name = "rk3588"
             self.args.name = self.args.name.lower()
@@ -481,7 +482,7 @@ class Exporter:
         return f  # return list of exported files/dirs
 
     def get_int8_calibration_dataloader(self, prefix=""):
-        """Build and return a dataloader suitable for calibration of INT8 models."""
+        """Build and return a dataloader for calibration of INT8 models."""
         LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
         data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
         # TensorRT INT8 calibration should use 2x batch size
@@ -497,7 +498,8 @@ class Exporter:
         n = len(dataset)
         if n < self.args.batch:
             raise ValueError(
-                f"The calibration dataset ({n} images) must have at least as many images as the batch size ('batch={self.args.batch}')."
+                f"The calibration dataset ({n} images) must have at least as many images as the batch size "
+                f"('batch={self.args.batch}')."
             )
         elif n < 300:
             LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
@@ -602,7 +604,7 @@ class Exporter:
         )
 
         def serialize(ov_model, file):
-            """Set RT info, serialize and save metadata YAML."""
+            """Set RT info, serialize, and save metadata YAML."""
             ov_model.set_rt_info("YOLO", ["model_info", "model_type"])
             ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"])
             ov_model.set_rt_info(114, ["model_info", "pad_value"])
@@ -1572,7 +1574,7 @@ class NMSModel(torch.nn.Module):
             x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
 
         Returns:
-            out (torch.Tensor): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape).
+            (torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the number of detections after NMS.
         """
         from functools import partial
 
@@ -86,7 +86,7 @@ class Model(torch.nn.Module):
         verbose: bool = False,
     ) -> None:
         """
-        Initializes a new instance of the YOLO model class.
+        Initialize a new instance of the YOLO model class.
 
         This constructor sets up the model based on the provided model path or name. It handles various types of
         model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
@@ -94,7 +94,7 @@ class Model(torch.nn.Module):
         prediction, or export.
 
         Args:
-            model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
+            model (str | Path): Path or name of the model to load or create. Can be a local file path, a
                 model name from Ultralytics HUB, or a Triton Server model.
             task (str | None): The task type associated with the YOLO model, specifying its application domain.
             verbose (bool): If True, enables verbose output during the model's initialization and subsequent
@@ -167,7 +167,7 @@ class Model(torch.nn.Module):
                 the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
                 tensor, or a list/tuple of these.
             stream (bool): If True, treat the input source as a continuous stream for predictions.
-            **kwargs: Additional keyword arguments to configure the prediction process.
+            **kwargs (Any): Additional keyword arguments to configure the prediction process.
 
         Returns:
             (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
@@ -184,7 +184,7 @@ class Model(torch.nn.Module):
     @staticmethod
     def is_triton_model(model: str) -> bool:
         """
-        Checks if the given model string is a Triton Server URL.
+        Check if the given model string is a Triton Server URL.
 
         This static method determines whether the provided model string represents a valid Triton Server URL by
         parsing its components using urllib.parse.urlsplit().
@@ -230,11 +230,10 @@ class Model(torch.nn.Module):
 
     def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
         """
-        Initializes a new model and infers the task type from the model definitions.
+        Initialize a new model and infer the task type from model definitions.
 
-        This method creates a new model instance based on the provided configuration file. It loads the model
-        configuration, infers the task type if not specified, and initializes the model using the appropriate
-        class from the task map.
+        Creates a new model instance based on the provided configuration file. Loads the model configuration, infers
+        the task type if not specified, and initializes the model using the appropriate class from the task map.
 
         Args:
             cfg (str): Path to the model configuration file in YAML format.
@@ -265,7 +264,7 @@ class Model(torch.nn.Module):
 
     def _load(self, weights: str, task=None) -> None:
         """
-        Loads a model from a checkpoint file or initializes it from a weights file.
+        Load a model from a checkpoint file or initialize it from a weights file.
 
         This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
         up the model, task, and related attributes based on the loaded weights.
@@ -303,7 +302,7 @@ class Model(torch.nn.Module):
 
     def _check_is_pytorch_model(self) -> None:
         """
-        Checks if the model is a PyTorch model and raises a TypeError if it's not.
+        Check if the model is a PyTorch model and raise TypeError if it's not.
 
         This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
         certain operations that require a PyTorch model are only performed on compatible model types.
@@ -331,7 +330,7 @@ class Model(torch.nn.Module):
 
     def reset_weights(self) -> "Model":
         """
-        Resets the model's weights to their initial state.
+        Reset the model's weights to their initial state.
 
         This method iterates through all modules in the model and resets their parameters if they have a
         'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
@@ -357,7 +356,7 @@ class Model(torch.nn.Module):
 
     def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model":
         """
-        Loads parameters from the specified weights file into the model.
+        Load parameters from the specified weights file into the model.
 
         This method supports loading weights from a file or directly from a weights object. It matches parameters by
         name and shape and transfers them to the model.
@@ -385,13 +384,13 @@ class Model(torch.nn.Module):
 
     def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
         """
-        Saves the current model state to a file.
+        Save the current model state to a file.
 
         This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
         the date, Ultralytics version, license information, and a link to the documentation.
 
         Args:
-            filename (Union[str, Path]): The name of the file to save the model to.
+            filename (str | Path): The name of the file to save the model to.
 
         Raises:
             AssertionError: If the model is not a PyTorch model.
@@ -417,7 +416,7 @@ class Model(torch.nn.Module):
 
     def info(self, detailed: bool = False, verbose: bool = True):
         """
-        Logs or returns model information.
+        Display model information.
 
         This method provides an overview or detailed information about the model, depending on the arguments
         passed. It can control the verbosity of the output and return the information as a list.
@@ -430,9 +429,6 @@ class Model(torch.nn.Module):
             (List[str]): A list of strings containing various types of information about the model, including
                 model summary, layer details, and parameter counts. Empty if verbose is True.
 
-        Raises:
-            TypeError: If the model is not a PyTorch model.
-
         Examples:
             >>> model = Model("yolo11n.pt")
             >>> model.info()  # Prints model summary
@@ -441,9 +437,9 @@ class Model(torch.nn.Module):
         self._check_is_pytorch_model()
         return self.model.info(detailed=detailed, verbose=verbose)
 
-    def fuse(self):
+    def fuse(self) -> None:
         """
-        Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.
+        Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
 
         This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
         into a single layer. This fusion can significantly improve inference speed by reducing the number of
@@ -453,9 +449,6 @@ class Model(torch.nn.Module):
         bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
         performs both convolution and normalization in one step.
 
-        Raises:
-            TypeError: If the model is not a PyTorch torch.nn.Module.
-
         Examples:
             >>> model = Model("yolo11n.pt")
             >>> model.fuse()
@@ -471,7 +464,7 @@ class Model(torch.nn.Module):
         **kwargs: Any,
     ) -> list:
         """
-        Generates image embeddings based on the provided source.
+        Generate image embeddings based on the provided source.
 
         This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
         source. It allows customization of the embedding process through various keyword arguments.
@@ -480,14 +473,11 @@ class Model(torch.nn.Module):
             source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
                 generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
             stream (bool): If True, predictions are streamed.
-            **kwargs: Additional keyword arguments for configuring the embedding process.
+            **kwargs (Any): Additional keyword arguments for configuring the embedding process.
 
         Returns:
             (List[torch.Tensor]): A list containing the image embeddings.
 
-        Raises:
-            AssertionError: If the model is not a PyTorch model.
-
         Examples:
             >>> model = YOLO("yolo11n.pt")
             >>> image = "https://ultralytics.com/images/bus.jpg"
@@ -519,7 +509,7 @@ class Model(torch.nn.Module):
             stream (bool): If True, treats the input source as a continuous stream for predictions.
             predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
                 If None, the method uses a default predictor.
-            **kwargs: Additional keyword arguments for configuring the prediction process.
+            **kwargs (Any): Additional keyword arguments for configuring the prediction process.
 
         Returns:
             (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
@@ -576,16 +566,13 @@ class Model(torch.nn.Module):
         Args:
             source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
                 tracking. Can be a file path, URL, or video stream.
-            stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
-            persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
-            **kwargs: Additional keyword arguments for configuring the tracking process.
+            stream (bool): If True, treats the input source as a continuous video stream.
+            persist (bool): If True, persists trackers between different calls to this method.
+            **kwargs (Any): Additional keyword arguments for configuring the tracking process.
 
         Returns:
             (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
 
-        Raises:
-            AttributeError: If the predictor does not have registered trackers.
-
         Examples:
             >>> model = YOLO("yolo11n.pt")
             >>> results = model.track(source="path/to/video.mp4", show=True)
@@ -612,7 +599,7 @@ class Model(torch.nn.Module):
         **kwargs: Any,
     ):
         """
-        Validates the model using a specified dataset and validation configuration.
+        Validate the model using a specified dataset and validation configuration.
 
         This method facilitates the model validation process, allowing for customization through various settings. It
         supports validation with a custom validator or the default validation approach. The method combines default
@@ -621,7 +608,7 @@ class Model(torch.nn.Module):
         Args:
             validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
                 validating the model.
-            **kwargs: Arbitrary keyword arguments for customizing the validation process.
+            **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
 
         Returns:
             (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
@@ -647,7 +634,7 @@ class Model(torch.nn.Module):
         **kwargs: Any,
     ):
         """
-        Benchmarks the model across various export formats to evaluate performance.
+        Benchmark the model across various export formats to evaluate performance.
 
         This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
         It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
@@ -655,15 +642,14 @@ class Model(torch.nn.Module):
         defaults, and any additional user-provided keyword arguments.
 
         Args:
-            **kwargs: Arbitrary keyword arguments to customize the benchmarking process. These are combined with
-                default configurations, model-specific arguments, and method defaults. Common options include:
+            **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Common options include:
                 - data (str): Path to the dataset for benchmarking.
                 - imgsz (int | List[int]): Image size for benchmarking.
                 - half (bool): Whether to use half-precision (FP16) mode.
                 - int8 (bool): Whether to use int8 precision mode.
                 - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
                 - verbose (bool): Whether to print detailed benchmark information.
-                - format (str): Export format name for specific benchmarking
+                - format (str): Export format name for specific benchmarking.
 
         Returns:
             (Dict): A dictionary containing the results of the benchmarking process, including metrics for
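Note: the options listed above are the keyword arguments accepted by Model.benchmark(). A hedged sketch of a typical call, using only arguments named in the docstring above and coco8.yaml as an illustrative dataset value:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    # Runs the model through the supported export formats and collects speed/accuracy metrics.
    results = model.benchmark(data="coco8.yaml", imgsz=640, half=False, device="cpu", verbose=True)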
@@ -698,14 +684,14 @@ class Model(torch.nn.Module):
         **kwargs: Any,
     ) -> str:
         """
-        Exports the model to a different format suitable for deployment.
+        Export the model to a different format suitable for deployment.
 
         This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
         purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
         defaults, and any additional arguments provided.
 
         Args:
-            **kwargs: Arbitrary keyword arguments to customize the export process. These are combined with
+            **kwargs (Any): Arbitrary keyword arguments to customize the export process. These are combined with
                 the model's overrides and method defaults. Common arguments include:
                 format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
                 half (bool): Export model in half-precision.
@@ -759,7 +745,7 @@ class Model(torch.nn.Module):
 
         Args:
             trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
-            **kwargs: Arbitrary keyword arguments for training configuration. Common options include:
+            **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
                 data (str): Path to dataset configuration file.
                 epochs (int): Number of training epochs.
                 batch_size (int): Batch size for training.
@@ -773,11 +759,6 @@ class Model(torch.nn.Module):
         Returns:
             (Dict | None): Training metrics if available and training is successful; otherwise, None.
 
-        Raises:
-            AssertionError: If the model is not a PyTorch model.
-            PermissionError: If there is a permission issue with the HUB session.
-            ModuleNotFoundError: If the HUB SDK is not installed.
-
         Examples:
             >>> model = YOLO("yolo11n.pt")
             >>> results = model.train(data="coco8.yaml", epochs=3)
@@ -832,21 +813,25 @@ class Model(torch.nn.Module):
             custom arguments to configure the tuning process.
 
         Args:
-            use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
-            iterations (int): The number of tuning iterations to perform. Defaults to 10.
-            *args: Variable length argument list for additional arguments.
-            **kwargs: Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
+            use_ray (bool): Whether to use Ray Tune for hyperparameter tuning. If False, uses internal tuning method.
+            iterations (int): Number of tuning iterations to perform.
+            *args (Any): Additional positional arguments to pass to the tuner.
+            **kwargs (Any): Additional keyword arguments for tuning configuration. These are combined with model
+                overrides and defaults to configure the tuning process.
 
         Returns:
-            (Dict): A dictionary containing the results of the hyperparameter search.
+            (Dict): Results of the hyperparameter search, including best parameters and performance metrics.
 
         Raises:
-            AssertionError: If the model is not a PyTorch model.
+            TypeError: If the model is not a PyTorch model.
 
         Examples:
             >>> model = YOLO("yolo11n.pt")
-            >>> results = model.tune(use_ray=True, iterations=20)
+            >>> results = model.tune(data="coco8.yaml", iterations=5)
             >>> print(results)
+
+            # Use Ray Tune for more advanced hyperparameter search
+            >>> results = model.tune(use_ray=True, iterations=20, data="coco8.yaml")
         """
         self._check_is_pytorch_model()
         if use_ray:
@@ -862,7 +847,7 @@ class Model(torch.nn.Module):
 
     def _apply(self, fn) -> "Model":
         """
-        Applies a function to model tensors that are not parameters or registered buffers.
+        Apply a function to model tensors that are not parameters or registered buffers.
 
         This method extends the functionality of the parent class's _apply method by additionally resetting the
         predictor and updating the device in the model's overrides. It's typically used for operations like
@@ -898,7 +883,8 @@ class Model(torch.nn.Module):
             initialized, it sets it up before retrieving the names.
 
         Returns:
-            (Dict[int, str]): A dict of class names associated with the model.
+            (Dict[int, str]): A dictionary of class names associated with the model, where keys are class indices and
+                values are the corresponding class names.
 
         Raises:
             AttributeError: If the model or predictor does not have a 'names' attribute.
@@ -920,7 +906,7 @@ class Model(torch.nn.Module):
     @property
     def device(self) -> torch.device:
         """
-        Retrieves the device on which the model's parameters are allocated.
+        Get the device on which the model's parameters are allocated.
 
         This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
         applicable only to models that are instances of torch.nn.Module.
@@ -965,7 +951,7 @@ class Model(torch.nn.Module):
 
     def add_callback(self, event: str, func) -> None:
         """
-        Adds a callback function for a specified event.
+        Add a callback function for a specified event.
 
         This method allows registering custom callback functions that are triggered on specific events during
         model operations such as training or inference. Callbacks provide a way to extend and customize the
@@ -1019,7 +1005,7 @@ class Model(torch.nn.Module):
 
     def reset_callbacks(self) -> None:
         """
-        Resets all callbacks to their default functions.
+        Reset all callbacks to their default functions.
 
         This method reinstates the default callback functions for all events, removing any custom callbacks that were
         previously added. It iterates through all default callback events and replaces the current callbacks with the
@@ -1043,9 +1029,9 @@ class Model(torch.nn.Module):
     @staticmethod
     def _reset_ckpt_args(args: dict) -> dict:
         """
-        Resets specific arguments when loading a PyTorch model checkpoint.
+        Reset specific arguments when loading a PyTorch model checkpoint.
 
-        This static method filters the input arguments dictionary to retain only a specific set of keys that are
+        This method filters the input arguments dictionary to retain only a specific set of keys that are
         considered important for model loading. It's used to ensure that only relevant arguments are preserved
         when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
 
@@ -1071,29 +1057,25 @@ class Model(torch.nn.Module):
 
     def _smart_load(self, key: str):
         """
-        Loads the appropriate module based on the model task.
+        Intelligently loads the appropriate module based on the model task.
 
         This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
-        based on the current task of the model and the provided key. It uses the task_map attribute to determine
-        the correct module to load.
+        based on the current task of the model and the provided key. It uses the task_map dictionary to determine
+        the appropriate module to load for the specific task.
 
         Args:
             key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
 
         Returns:
-            (object): The loaded module corresponding to the specified key and current task.
+            (object): The loaded module class corresponding to the specified key and current task.
 
         Raises:
             NotImplementedError: If the specified key is not supported for the current task.
 
         Examples:
             >>> model = Model(task="detect")
-            >>> predictor = model._smart_load("predictor")
-            >>> trainer = model._smart_load("trainer")
-
-        Notes:
-            - This method is typically used internally by other methods of the Model class.
-            - The task_map attribute should be properly initialized with the correct mappings for each task.
+            >>> predictor_class = model._smart_load("predictor")
+            >>> trainer_class = model._smart_load("trainer")
         """
         try:
             return self.task_map[self.task][key]
@@ -1118,20 +1100,15 @@ class Model(torch.nn.Module):
         various tasks and modes within the Ultralytics framework.
 
         Returns:
-            (Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
-                nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
-                'predictor', mapping to their respective class implementations.
+            (Dict[str, Dict[str, Any]]): A dictionary mapping task names to nested dictionaries. Each nested dictionary
+                contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective class
+                implementations for that task.
 
         Examples:
-            >>> model = Model()
+            >>> model = Model("yolo11n.pt")
             >>> task_map = model.task_map
-            >>> detect_class_map = task_map["detect"]
-            >>> segment_class_map = task_map["segment"]
-
-        Note:
-            The actual implementation of this method may vary depending on the specific tasks and
-            classes supported by the Ultralytics framework. The docstring provides a general
-            description of the expected behavior and structure.
+            >>> detect_predictor = task_map["detect"]["predictor"]
+            >>> segment_trainer = task_map["segment"]["trainer"]
         """
         raise NotImplementedError("Please provide task map for your model!")
 
@@ -1140,21 +1117,23 @@ class Model(torch.nn.Module):
         Sets the model to evaluation mode.
 
         This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
-        that behave differently during training and evaluation.
+        that behave differently during training and evaluation. In evaluation mode, these layers use running statistics
+        rather than computing batch statistics, and dropout layers are disabled.
 
         Returns:
             (Model): The model instance with evaluation mode set.
 
         Examples:
-            >> model = YOLO("yolo11n.pt")
-            >> model.eval()
+            >>> model = YOLO("yolo11n.pt")
+            >>> model.eval()
+            >>> # Model is now in evaluation mode for inference
         """
         self.model.eval()
         return self
 
     def __getattr__(self, name):
         """
-        Enables accessing model attributes directly through the Model class.
+        Enable accessing model attributes directly through the Model class.
 
         This method provides a way to access attributes of the underlying model directly through the Model class
         instance. It first checks if the requested attribute is 'model', in which case it returns the model from
@@ -1171,7 +1150,7 @@ class Model(torch.nn.Module):
 
         Examples:
             >>> model = YOLO("yolo11n.pt")
-            >>> print(model.stride)
-            >>> print(model.task)
+            >>> print(model.stride)  # Access model.stride attribute
+            >>> print(model.names)  # Access model.names attribute
         """
         return self._modules["model"] if name == "model" else getattr(self.model, name)