ultralytics 8.2.61__py3-none-any.whl → 8.2.63__py3-none-any.whl

Potentially problematic release.


@@ -30,26 +30,18 @@ class Model(nn.Module):
 
  This class provides a common interface for various operations related to YOLO models, such as training,
  validation, prediction, exporting, and benchmarking. It handles different types of models, including those
- loaded from local files, Ultralytics HUB, or Triton Server. The class is designed to be flexible and
- extendable for different tasks and model configurations.
-
- Args:
- model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file
- path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
- task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's
- application domain, such as object detection, segmentation, etc. Defaults to None.
- verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False.
+ loaded from local files, Ultralytics HUB, or Triton Server.
 
  Attributes:
- callbacks (dict): A dictionary of callback functions for various events during model operations.
+ callbacks (Dict): A dictionary of callback functions for various events during model operations.
  predictor (BasePredictor): The predictor object used for making predictions.
  model (nn.Module): The underlying PyTorch model.
  trainer (BaseTrainer): The trainer object used for training the model.
- ckpt (dict): The checkpoint data if the model is loaded from a *.pt file.
+ ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
  cfg (str): The configuration of the model if loaded from a *.yaml file.
  ckpt_path (str): The path to the checkpoint file.
- overrides (dict): A dictionary of overrides for model configuration.
- metrics (dict): The latest training/validation metrics.
+ overrides (Dict): A dictionary of overrides for model configuration.
+ metrics (Dict): The latest training/validation metrics.
  session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
  task (str): The type of task the model is intended for.
  model_name (str): The name of the model.
@@ -75,19 +67,14 @@ class Model(nn.Module):
  add_callback: Adds a callback function for an event.
  clear_callback: Clears all callbacks for an event.
  reset_callbacks: Resets all callbacks to their default functions.
- is_triton_model: Checks if a model is a Triton Server model.
- is_hub_model: Checks if a model is an Ultralytics HUB model.
- _reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
- _smart_load: Loads the appropriate module based on the model task.
- task_map: Provides a mapping from model tasks to corresponding classes.
-
- Raises:
- FileNotFoundError: If the specified model file does not exist or is inaccessible.
- ValueError: If the model file or configuration is invalid or unsupported.
- ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
- TypeError: If the model is not a PyTorch model when required.
- AttributeError: If required attributes or methods are not implemented or available.
- NotImplementedError: If a specific model task or mode is not supported.
+
+ Examples:
+ >>> from ultralytics import YOLO
+ >>> model = YOLO('yolov8n.pt')
+ >>> results = model.predict('image.jpg')
+ >>> model.train(data='coco128.yaml', epochs=3)
+ >>> metrics = model.val()
+ >>> model.export(format='onnx')
  """
 
  def __init__(
@@ -99,22 +86,27 @@ class Model(nn.Module):
  """
  Initializes a new instance of the YOLO model class.
 
- This constructor sets up the model based on the provided model path or name. It handles various types of model
- sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
- important attributes of the model and prepares it for operations like training, prediction, or export.
+ This constructor sets up the model based on the provided model path or name. It handles various types of
+ model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
+ initializes several important attributes of the model and prepares it for operations like training,
+ prediction, or export.
 
  Args:
- model (Union[str, Path], optional): The path or model file to load or create. This can be a local
- file path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
- task (Any, optional): The task type associated with the YOLO model, specifying its application domain.
- Defaults to None.
- verbose (bool, optional): If True, enables verbose output during the model's initialization and subsequent
- operations. Defaults to False.
+ model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
+ model name from Ultralytics HUB, or a Triton Server model.
+ task (str | None): The task type associated with the YOLO model, specifying its application domain.
+ verbose (bool): If True, enables verbose output during the model's initialization and subsequent
+ operations.
 
  Raises:
  FileNotFoundError: If the specified model file does not exist or is inaccessible.
  ValueError: If the model file or configuration is invalid or unsupported.
  ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
+
+ Examples:
+ >>> model = Model("yolov8n.pt")
+ >>> model = Model("path/to/model.yaml", task="detect")
+ >>> model = Model("hub_model", verbose=True)
  """
  super().__init__()
  self.callbacks = callbacks.get_default_callbacks()
@@ -155,27 +147,50 @@ class Model(nn.Module):
  **kwargs,
  ) -> list:
  """
- An alias for the predict method, enabling the model instance to be callable.
+ Alias for the predict method, enabling the model instance to be callable for predictions.
 
- This method simplifies the process of making predictions by allowing the model instance to be called directly
- with the required arguments for prediction.
+ This method simplifies the process of making predictions by allowing the model instance to be called
+ directly with the required arguments.
 
  Args:
- source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making
- predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays.
- Defaults to None.
- stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
- Defaults to False.
- **kwargs (any): Additional keyword arguments for configuring the prediction process.
+ source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
+ the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
+ tensor, or a list/tuple of these.
+ stream (bool): If True, treat the input source as a continuous stream for predictions.
+ **kwargs (Any): Additional keyword arguments to configure the prediction process.
 
  Returns:
- (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
+ (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
+ Results object.
+
+ Examples:
+ >>> model = YOLO('yolov8n.pt')
+ >>> results = model('https://ultralytics.com/images/bus.jpg')
+ >>> for r in results:
+ ... print(f"Detected {len(r)} objects in image")
  """
  return self.predict(source, stream, **kwargs)
 
  @staticmethod
  def is_triton_model(model: str) -> bool:
- """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
+ """
+ Checks if the given model string is a Triton Server URL.
+
+ This static method determines whether the provided model string represents a valid Triton Server URL by
+ parsing its components using urllib.parse.urlsplit().
+
+ Args:
+ model (str): The model string to be checked.
+
+ Returns:
+ (bool): True if the model string is a valid Triton Server URL, False otherwise.
+
+ Examples:
+ >>> Model.is_triton_model('http://localhost:8000/v2/models/yolov8n')
+ True
+ >>> Model.is_triton_model('yolov8n.pt')
+ False
+ """
  from urllib.parse import urlsplit
 
  url = urlsplit(model)
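
The hunk above stops right after the `urlsplit` call, so the actual return condition is not visible in this diff. As a rough illustration of how such a check could work, here is a minimal sketch that assumes a Triton URL needs a host, a path, and an `http` or `grpc` scheme. The helper name `looks_like_triton_url` and the accepted schemes are assumptions, not taken from the diff.

```python
from urllib.parse import urlsplit


def looks_like_triton_url(model: str) -> bool:
    """Hypothetical stand-in for Model.is_triton_model; the exact rule is assumed."""
    url = urlsplit(model)
    # Assumed rule: require a host, a path, and one of the assumed schemes.
    return bool(url.netloc and url.path and url.scheme in {"http", "grpc"})


print(looks_like_triton_url("http://localhost:8000/v2/models/yolov8n"))
print(looks_like_triton_url("yolov8n.pt"))
```

A plain filename such as `yolov8n.pt` parses with an empty scheme and host, so it is rejected under this rule.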
@@ -183,7 +198,30 @@ class Model(nn.Module):
 
  @staticmethod
  def is_hub_model(model: str) -> bool:
- """Check if the provided model is a HUB model."""
+ """
+ Check if the provided model is an Ultralytics HUB model.
+
+ This static method determines whether the given model string represents a valid Ultralytics HUB model
+ identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
+ or a standalone model ID.
+
+ Args:
+ model (str): The model identifier to check. This can be a URL, an API key and model ID
+ combination, or a standalone model ID.
+
+ Returns:
+ (bool): True if the model is a valid Ultralytics HUB model, False otherwise.
+
+ Examples:
+ >>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
+ True
+ >>> Model.is_hub_model("api_key_example_model_id")
+ True
+ >>> Model.is_hub_model("example_model_id")
+ True
+ >>> Model.is_hub_model("not_a_hub_model.pt")
+ False
+ """
  return any(
  (
  model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
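
Only the first of the three checks named in the docstring (the full HUB URL) is visible before the hunk ends. The sketch below shows one way all three formats could be tested. The value of `HUB_WEB_ROOT` is inferred from the inline comment, while the 42-character API key and 20-character model ID lengths and the helper name `looks_like_hub_model` are assumptions that this diff does not confirm.

```python
from pathlib import Path

HUB_WEB_ROOT = "https://hub.ultralytics.com"  # inferred from the inline comment in the hunk


def looks_like_hub_model(model: str) -> bool:
    """Hypothetical stand-in for Model.is_hub_model; the length rules are assumed."""
    return any(
        (
            # Format 1: full HUB URL, as shown in the diff.
            model.startswith(f"{HUB_WEB_ROOT}/models/"),
            # Format 2 (assumed lengths): "<42-char API key>_<20-char model ID>".
            [len(part) for part in model.split("_")] == [42, 20],
            # Format 3 (assumed length): bare 20-char ID that is not a local path.
            len(model) == 20 and not Path(model).exists() and all(c not in model for c in "./\\"),
        )
    )


print(looks_like_hub_model(f"{HUB_WEB_ROOT}/models/example_model"))
print(looks_like_hub_model("not_a_hub_model.pt"))
```

If the length rules assumed here match the real ones, short placeholder strings such as `"example_model_id"` would not pass the check, so the docstring examples are best read as illustrative.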
@@ -196,11 +234,24 @@ class Model(nn.Module):
  """
  Initializes a new model and infers the task type from the model definitions.
 
+ This method creates a new model instance based on the provided configuration file. It loads the model
+ configuration, infers the task type if not specified, and initializes the model using the appropriate
+ class from the task map.
+
  Args:
- cfg (str): model configuration file
- task (str | None): model task
- model (BaseModel): Customized model.
- verbose (bool): display model info on load
+ cfg (str): Path to the model configuration file in YAML format.
+ task (str | None): The specific task for the model. If None, it will be inferred from the config.
+ model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
+ a new one.
+ verbose (bool): If True, displays model information during loading.
+
+ Raises:
+ ValueError: If the configuration file is invalid or the task cannot be inferred.
+ ImportError: If the required dependencies for the specified task are not installed.
+
+ Examples:
+ >>> model = Model()
+ >>> model._new('yolov8n.yaml', task='detect', verbose=True)
  """
  cfg_dict = yaml_model_load(cfg)
  self.cfg = cfg
@@ -216,11 +267,23 @@ class Model(nn.Module):
 
  def _load(self, weights: str, task=None) -> None:
  """
- Initializes a new model and infers the task type from the model head.
+ Loads a model from a checkpoint file or initializes it from a weights file.
+
+ This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
+ up the model, task, and related attributes based on the loaded weights.
 
  Args:
- weights (str): model checkpoint to be loaded
- task (str | None): model task
+ weights (str): Path to the model weights file to be loaded.
+ task (str | None): The task associated with the model. If None, it will be inferred from the model.
+
+ Raises:
+ FileNotFoundError: If the specified weights file does not exist or is inaccessible.
+ ValueError: If the weights file format is unsupported or invalid.
+
+ Examples:
+ >>> model = Model()
+ >>> model._load('yolov8n.pt')
+ >>> model._load('path/to/weights.pth', task='detect')
  """
  if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
  weights = checks.check_file(weights) # automatically download and return local filename
@@ -241,7 +304,22 @@ class Model(nn.Module):
  self.model_name = weights
 
  def _check_is_pytorch_model(self) -> None:
- """Raises TypeError is model is not a PyTorch model."""
+ """
+ Checks if the model is a PyTorch model and raises a TypeError if it's not.
+
+ This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
+ certain operations that require a PyTorch model are only performed on compatible model types.
+
+ Raises:
+ TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
+ information about supported model formats and operations.
+
+ Examples:
+ >>> model = Model("yolov8n.pt")
+ >>> model._check_is_pytorch_model() # No error raised
+ >>> model = Model("yolov8n.onnx")
+ >>> model._check_is_pytorch_model() # Raises TypeError
+ """
  pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
  pt_module = isinstance(self.model, nn.Module)
  if not (pt_module or pt_str):
@@ -255,17 +333,21 @@ class Model(nn.Module):
 
  def reset_weights(self) -> "Model":
  """
- Resets the model parameters to randomly initialized values, effectively discarding all training information.
+ Resets the model's weights to their initial state.
 
  This method iterates through all modules in the model and resets their parameters if they have a
- 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
- to be updated during training.
+ 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
+ enabling them to be updated during training.
 
  Returns:
- self (ultralytics.engine.model.Model): The instance of the class with reset weights.
+ (Model): The instance of the class with reset weights.
 
  Raises:
  AssertionError: If the model is not a PyTorch model.
+
+ Examples:
+ >>> model = Model('yolov8n.pt')
+ >>> model.reset_weights()
  """
  self._check_is_pytorch_model()
  for m in self.model.modules():
@@ -283,13 +365,18 @@ class Model(nn.Module):
  name and shape and transfers them to the model.
 
  Args:
- weights (str | Path): Path to the weights file or a weights object. Defaults to 'yolov8n.pt'.
+ weights (Union[str, Path]): Path to the weights file or a weights object.
 
  Returns:
- self (ultralytics.engine.model.Model): The instance of the class with loaded weights.
+ (Model): The instance of the class with loaded weights.
 
  Raises:
  AssertionError: If the model is not a PyTorch model.
+
+ Examples:
+ >>> model = Model()
+ >>> model.load('yolov8n.pt')
+ >>> model.load(Path('path/to/weights.pt'))
  """
  self._check_is_pytorch_model()
  if isinstance(weights, (str, Path)):
@@ -301,14 +388,19 @@ class Model(nn.Module):
  """
  Saves the current model state to a file.
 
- This method exports the model's checkpoint (ckpt) to the specified filename.
+ This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
+ the date, Ultralytics version, license information, and a link to the documentation.
 
  Args:
- filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
- use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
+ filename (Union[str, Path]): The name of the file to save the model to.
+ use_dill (bool): Whether to try using dill for serialization if available.
 
  Raises:
  AssertionError: If the model is not a PyTorch model.
+
+ Examples:
+ >>> model = Model('yolov8n.pt')
+ >>> model.save('my_model.pt')
  """
  self._check_is_pytorch_model()
  from copy import deepcopy
@@ -329,30 +421,47 @@ class Model(nn.Module):
  """
  Logs or returns model information.
 
- This method provides an overview or detailed information about the model, depending on the arguments passed.
- It can control the verbosity of the output.
+ This method provides an overview or detailed information about the model, depending on the arguments
+ passed. It can control the verbosity of the output and return the information as a list.
 
  Args:
- detailed (bool): If True, shows detailed information about the model. Defaults to False.
- verbose (bool): If True, prints the information. If False, returns the information. Defaults to True.
+ detailed (bool): If True, shows detailed information about the model layers and parameters.
+ verbose (bool): If True, prints the information. If False, returns the information as a list.
 
  Returns:
- (list): Various types of information about the model, depending on the 'detailed' and 'verbose' parameters.
+ (List[str]): A list of strings containing various types of information about the model, including
+ model summary, layer details, and parameter counts. Empty if verbose is True.
 
  Raises:
- AssertionError: If the model is not a PyTorch model.
+ TypeError: If the model is not a PyTorch model.
+
+ Examples:
+ >>> model = Model('yolov8n.pt')
+ >>> model.info() # Prints model summary
+ >>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list
  """
  self._check_is_pytorch_model()
  return self.model.info(detailed=detailed, verbose=verbose)
 
  def fuse(self):
  """
- Fuses Conv2d and BatchNorm2d layers in the model.
+ Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.
+
+ This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
+ into a single layer. This fusion can significantly improve inference speed by reducing the number of
+ operations and memory accesses required during forward passes.
 
- This method optimizes the model by fusing Conv2d and BatchNorm2d layers, which can improve inference speed.
+ The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
+ bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
+ performs both convolution and normalization in one step.
 
  Raises:
- AssertionError: If the model is not a PyTorch model.
+ TypeError: If the model is not a PyTorch nn.Module.
+
+ Examples:
+ >>> model = Model("yolov8n.pt")
+ >>> model.fuse()
+ >>> # Model is now fused and ready for optimized inference
  """
  self._check_is_pytorch_model()
  self.model.fuse()
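
The new `fuse` docstring describes folding the BatchNorm2d mean, variance, weight, and bias into the preceding convolution. The arithmetic behind that folding can be sketched in plain Python by treating a 1x1 convolution at a single pixel as a small linear map. This is a standalone illustration of the standard folding formula, not the code that `self.model.fuse()` runs; all function names below are hypothetical.

```python
import math


def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold per-channel BatchNorm statistics into a layer's weights and bias.

    weight holds one row of weights per output channel; the other arguments
    hold one value per output channel.
    """
    fused_w, fused_b = [], []
    for w_row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)  # per-channel BatchNorm scale
        fused_w.append([w * scale for w in w_row])
        fused_b.append((b - m) * scale + bt)
    return fused_w, fused_b


def linear(weight, bias, x):
    """Apply a 1x1 convolution at a single pixel, i.e. a linear map."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Apply inference-mode batch normalization per channel."""
    return [g * (yi - m) / math.sqrt(v + eps) + bt for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]


# Arbitrary illustrative numbers: 2 output channels, 3 input channels.
W = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
b = [0.1, -0.2]
gamma, beta, mean, var = [1.2, 0.8], [0.05, -0.1], [0.3, -0.4], [0.9, 1.6]
x = [1.0, 2.0, -1.0]

separate = batchnorm(linear(W, b, x), gamma, beta, mean, var)
Wf, bf = fold_batchnorm(W, b, gamma, beta, mean, var)
fused = linear(Wf, bf, x)
print(separate, fused)
```

The two printed lists agree up to floating-point rounding, which is why the fused layer can replace the pair at inference time.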
@@ -366,20 +475,26 @@ class Model(nn.Module):
  """
  Generates image embeddings based on the provided source.
 
- This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image source.
- It allows customization of the embedding process through various keyword arguments.
+ This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
+ source. It allows customization of the embedding process through various keyword arguments.
 
  Args:
- source (str | int | PIL.Image | np.ndarray): The source of the image for generating embeddings.
- The source can be a file path, URL, PIL image, numpy array, etc. Defaults to None.
- stream (bool): If True, predictions are streamed. Defaults to False.
- **kwargs (any): Additional keyword arguments for configuring the embedding process.
+ source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
+ generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
+ stream (bool): If True, predictions are streamed.
+ **kwargs (Any): Additional keyword arguments for configuring the embedding process.
 
  Returns:
  (List[torch.Tensor]): A list containing the image embeddings.
 
  Raises:
  AssertionError: If the model is not a PyTorch model.
+
+ Examples:
+ >>> model = YOLO('yolov8n.pt')
+ >>> image = 'https://ultralytics.com/images/bus.jpg'
+ >>> embeddings = model.embed(image)
+ >>> print(embeddings[0].shape)
  """
  if not kwargs.get("embed"):
  kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
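
The two context lines above pick the second-to-last layer when the caller passes no `embed` indices. A small sketch of that defaulting logic in isolation, where the helper name `with_default_embed` and the layer count of 23 are hypothetical values chosen for illustration:

```python
def with_default_embed(kwargs: dict, num_layers: int) -> dict:
    """Hypothetical helper mirroring the default shown in the diff."""
    if not kwargs.get("embed"):
        # No layer indices passed (missing, None, or empty): use the second-to-last layer.
        kwargs["embed"] = [num_layers - 2]
    return kwargs


print(with_default_embed({}, num_layers=23))
print(with_default_embed({"embed": [10]}, num_layers=23))
```

Because the test is `not kwargs.get("embed")`, an explicit empty list falls back to the default in the same way as a missing key.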
@@ -397,28 +512,31 @@ class Model(nn.Module):
 
  This method facilitates the prediction process, allowing various configurations through keyword arguments.
  It supports predictions with custom predictors or the default predictor method. The method handles different
- types of image sources and can operate in a streaming mode. It also provides support for SAM-type models
- through 'prompts'.
-
- The method sets up a new predictor if not already present and updates its arguments with each call.
- It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it
- is being called from the command line interface and adjusts its behavior accordingly, including setting defaults
- for confidence threshold and saving behavior.
+ types of image sources and can operate in a streaming mode.
 
  Args:
- source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
- Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to ASSETS.
- stream (bool, optional): Treats the input source as a continuous stream for predictions. Defaults to False.
- predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
- If None, the method uses a default predictor. Defaults to None.
- **kwargs (any): Additional keyword arguments for configuring the prediction process. These arguments allow
- for further customization of the prediction behavior.
+ source (str | Path | int | List[str] | List[Path] | List[int] | np.ndarray | torch.Tensor): The source
+ of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
+ images, numpy arrays, and torch tensors.
+ stream (bool): If True, treats the input source as a continuous stream for predictions.
+ predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
+ If None, the method uses a default predictor.
+ **kwargs (Any): Additional keyword arguments for configuring the prediction process.
 
  Returns:
- (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
-
- Raises:
- AttributeError: If the predictor is not properly set up.
+ (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
+ Results object.
+
+ Examples:
+ >>> model = YOLO('yolov8n.pt')
+ >>> results = model.predict(source='path/to/image.jpg', conf=0.25)
+ >>> for r in results:
+ ... print(r.boxes.data) # print detection bounding boxes
+
+ Notes:
+ - If 'source' is not provided, it defaults to the ASSETS constant with a warning.
+ - The method sets up a new predictor if not already present and updates its arguments with each call.
+ - For SAM-type models, 'prompts' can be passed as a keyword argument.
  """
  if source is None:
  source = ASSETS
@@ -453,26 +571,33 @@ class Model(nn.Module):
  """
  Conducts object tracking on the specified input source using the registered trackers.
 
- This method performs object tracking using the model's predictors and optionally registered trackers. It is
- capable of handling different types of input sources such as file paths or video streams. The method supports
- customization of the tracking process through various keyword arguments. It registers trackers if they are not
- already present and optionally persists them based on the 'persist' flag.
-
- The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low
- confidence predictions as input. The tracking mode is explicitly set in the keyword arguments.
+ This method performs object tracking using the model's predictors and optionally registered trackers. It handles
+ various input sources such as file paths or video streams, and supports customization through keyword arguments.
+ The method registers trackers if not already present and can persist them between calls.
 
  Args:
- source (str, optional): The input source for object tracking. It can be a file path, URL, or video stream.
- stream (bool, optional): Treats the input source as a continuous video stream. Defaults to False.
- persist (bool, optional): Persists the trackers between different calls to this method. Defaults to False.
- **kwargs (any): Additional keyword arguments for configuring the tracking process. These arguments allow
- for further customization of the tracking behavior.
+ source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
+ tracking. Can be a file path, URL, or video stream.
+ stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
+ persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
+ **kwargs (Any): Additional keyword arguments for configuring the tracking process.
 
  Returns:
- (List[ultralytics.engine.results.Results]): A list of tracking results, encapsulated in the Results class.
+ (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
 
  Raises:
  AttributeError: If the predictor does not have registered trackers.
+
+ Examples:
+ >>> model = YOLO('yolov8n.pt')
+ >>> results = model.track(source='path/to/video.mp4', show=True)
+ >>> for r in results:
+ ... print(r.boxes.id) # print tracking IDs
+
+ Notes:
+ - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.
+ - The tracking mode is explicitly set in the keyword arguments.
+ - Batch size is set to 1 for tracking in videos.
  """
  if not hasattr(self.predictor, "trackers"):
  from ultralytics.trackers import register_tracker
@@ -491,26 +616,25 @@ class Model(nn.Module):
  """
  Validates the model using a specified dataset and validation configuration.
 
- This method facilitates the model validation process, allowing for a range of customization through various
- settings and configurations. It supports validation with a custom validator or the default validation approach.
- The method combines default configurations, method-specific defaults, and user-provided arguments to configure
- the validation process. After validation, it updates the model's metrics with the results obtained from the
- validator.
-
- The method supports various arguments that allow customization of the validation process. For a comprehensive
- list of all configurable options, users should refer to the 'configuration' section in the documentation.
+ This method facilitates the model validation process, allowing for customization through various settings. It
+ supports validation with a custom validator or the default validation approach. The method combines default
+ configurations, method-specific defaults, and user-provided arguments to configure the validation process.
 
  Args:
- validator (BaseValidator, optional): An instance of a custom validator class for validating the model. If
- None, the method uses a default validator. Defaults to None.
- **kwargs (any): Arbitrary keyword arguments representing the validation configuration. These arguments are
- used to customize various aspects of the validation process.
+ validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
+ validating the model.
+ **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
 
  Returns:
  (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
 
  Raises:
  AssertionError: If the model is not a PyTorch model.
+
+ Examples:
+ >>> model = YOLO('yolov8n.pt')
+ >>> results = model.val(data='coco128.yaml', imgsz=640)
+ >>> print(results.box.map) # Print mAP50-95
  """
  custom = {"rect": True} # method defaults
  args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
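
The `args = {...}` line above relies on Python's dict-unpacking rule that later entries overwrite earlier ones, which is what the "highest priority args on the right" comment refers to. A minimal sketch with made-up values; the contents of `overrides` and `kwargs` here are hypothetical, and only `custom` and the `"mode": "val"` entry come from the diff:

```python
overrides = {"imgsz": 640, "rect": False, "mode": "train"}  # hypothetical stored overrides
custom = {"rect": True}  # method defaults, as in the diff
kwargs = {"imgsz": 320}  # hypothetical user-supplied arguments

# Later mappings win: kwargs beats custom, custom beats overrides, and "mode" is always "val".
args = {**overrides, **custom, **kwargs, "mode": "val"}
print(args)  # {'imgsz': 320, 'rect': True, 'mode': 'val'}
```

With this ordering the user's `imgsz` wins, the method default for `rect` replaces the stored override, and no caller-supplied value can change `mode`.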
@@ -528,23 +652,31 @@ class Model(nn.Module):
  Benchmarks the model across various export formats to evaluate performance.
 
  This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
- It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured
- using a combination of default configuration values, model-specific arguments, method-specific defaults, and
- any additional user-provided keyword arguments.
-
- The method supports various arguments that allow customization of the benchmarking process, such as dataset
- choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all
- configurable options, users should refer to the 'configuration' section in the documentation.
+ It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
+ configured using a combination of default configuration values, model-specific arguments, method-specific
+ defaults, and any additional user-provided keyword arguments.
 
  Args:
- **kwargs (any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
- default configurations, model-specific arguments, and method defaults.
+ **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
+ default configurations, model-specific arguments, and method defaults. Common options include:
+ - data (str): Path to the dataset for benchmarking.
+ - imgsz (int | List[int]): Image size for benchmarking.
+ - half (bool): Whether to use half-precision (FP16) mode.
+ - int8 (bool): Whether to use int8 precision mode.
+ - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
+ - verbose (bool): Whether to print detailed benchmark information.
 
  Returns:
- (dict): A dictionary containing the results of the benchmarking process.
+ (Dict): A dictionary containing the results of the benchmarking process, including metrics for
+ different export formats.
 
  Raises:
  AssertionError: If the model is not a PyTorch model.
+
+ Examples:
+ >>> model = YOLO('yolov8n.pt')
+ >>> results = model.benchmark(data='coco8.yaml', imgsz=640, half=True)
+ >>> print(results)
  """
  self._check_is_pytorch_model()
  from ultralytics.utils.benchmarks import benchmark
@@ -570,20 +702,31 @@ class Model(nn.Module):
570
702
 
571
703
  This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
572
704
  purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
573
- defaults, and any additional arguments provided. The combined arguments are used to configure export settings.
574
-
575
- The method supports a wide range of arguments to customize the export process. For a comprehensive list of all
576
- possible arguments, refer to the 'configuration' section in the documentation.
705
+ defaults, and any additional arguments provided.
577
706
 
578
707
  Args:
579
- **kwargs (any): Arbitrary keyword arguments to customize the export process. These are combined with the
580
- model's overrides and method defaults.
708
+ **kwargs (Any): Arbitrary keyword arguments to customize the export process. These are combined with
709
+ the model's overrides and method defaults. Common arguments include:
710
+ format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
711
+ half (bool): Export model in half-precision.
712
+ int8 (bool): Export model in int8 precision.
713
+ device (str): Device to run the export on.
714
+ workspace (int): Maximum memory workspace size for TensorRT engines.
715
+ nms (bool): Add Non-Maximum Suppression (NMS) module to model.
716
+ simplify (bool): Simplify ONNX model.
581
717
 
582
718
  Returns:
583
- (str): The exported model filename in the specified format, or an object related to the export process.
719
+ (str): The path to the exported model file.
584
720
 
585
721
  Raises:
586
722
  AssertionError: If the model is not a PyTorch model.
723
+ ValueError: If an unsupported export format is specified.
724
+ RuntimeError: If the export process fails due to errors.
725
+
726
+ Examples:
727
+ >>> model = YOLO('yolov8n.pt')
728
+ >>> model.export(format='onnx', dynamic=True, simplify=True)
729
+ 'path/to/exported/model.onnx'
587
730
  """
588
731
  self._check_is_pytorch_model()
589
732
  from .exporter import Exporter
@@ -606,29 +749,38 @@ class Model(nn.Module):
606
749
  """
607
750
  Trains the model using the specified dataset and training configuration.
608
751
 
609
- This method facilitates model training with a range of customizable settings and configurations. It supports
610
- training with a custom trainer or the default training approach defined in the method. The method handles
611
- different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
612
- updating model and configuration after training.
752
+ This method facilitates model training with a range of customizable settings. It supports training with a
753
+ custom trainer or the default training approach. The method handles scenarios such as resuming training
754
+ from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
613
755
 
614
- When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
615
- arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
616
- configurations, method-specific defaults, and user-provided arguments to configure the training process. After
617
- training, it updates the model and its configurations, and optionally attaches metrics.
756
+ When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
757
+ arguments and warns if local arguments are provided. It checks for pip updates and combines default
758
+ configurations, method-specific defaults, and user-provided arguments to configure the training process.
618
759
 
619
760
  Args:
620
- trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
621
- method uses a default trainer. Defaults to None.
622
- **kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
623
- used to customize various aspects of the training process.
761
+ trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
762
+ **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
763
+ data (str): Path to dataset configuration file.
764
+ epochs (int): Number of training epochs.
765
+ batch (int): Batch size for training.
766
+ imgsz (int): Input image size.
767
+ device (str): Device to run training on (e.g., 'cuda', 'cpu').
768
+ workers (int): Number of worker threads for data loading.
769
+ optimizer (str): Optimizer to use for training.
770
+ lr0 (float): Initial learning rate.
771
+ patience (int): Number of epochs without observable improvement to wait before stopping training early.
624
772
 
625
773
  Returns:
626
- (dict | None): Training metrics if available and training is successful; otherwise, None.
774
+ (Dict | None): Training metrics if available and training is successful; otherwise, None.
627
775
 
628
776
  Raises:
629
777
  AssertionError: If the model is not a PyTorch model.
630
778
  PermissionError: If there is a permission issue with the HUB session.
631
779
  ModuleNotFoundError: If the HUB SDK is not installed.
780
+
781
+ Examples:
782
+ >>> model = YOLO('yolov8n.pt')
783
+ >>> results = model.train(data='coco128.yaml', epochs=3)
632
784
  """
633
785
  self._check_is_pytorch_model()
634
786
  if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
@@ -682,14 +834,19 @@ class Model(nn.Module):
682
834
  Args:
683
835
  use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
684
836
  iterations (int): The number of tuning iterations to perform. Defaults to 10.
685
- *args (list): Variable length argument list for additional arguments.
686
- **kwargs (any): Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
837
+ *args (List): Variable length argument list for additional arguments.
838
+ **kwargs (Any): Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
687
839
 
688
840
  Returns:
689
- (dict): A dictionary containing the results of the hyperparameter search.
841
+ (Dict): A dictionary containing the results of the hyperparameter search.
690
842
 
691
843
  Raises:
692
844
  AssertionError: If the model is not a PyTorch model.
845
+
846
+ Examples:
847
+ >>> model = YOLO('yolov8n.pt')
848
+ >>> results = model.tune(use_ray=True, iterations=20)
849
+ >>> print(results)
693
850
  """
694
851
  self._check_is_pytorch_model()
695
852
  if use_ray:
@@ -704,7 +861,27 @@ class Model(nn.Module):
704
861
  return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
705
862
 
706
863
  def _apply(self, fn) -> "Model":
707
- """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
864
+ """
865
+ Applies a function to model tensors that are not parameters or registered buffers.
866
+
867
+ This method extends the functionality of the parent class's _apply method by additionally resetting the
868
+ predictor and updating the device in the model's overrides. It's typically used for operations like
869
+ moving the model to a different device or changing its precision.
870
+
871
+ Args:
872
+ fn (Callable): A function to be applied to the model's tensors. This is typically a method like
873
+ to(), cpu(), cuda(), half(), or float().
874
+
875
+ Returns:
876
+ (Model): The model instance with the function applied and updated attributes.
877
+
878
+ Raises:
879
+ AssertionError: If the model is not a PyTorch model.
880
+
881
+ Examples:
882
+ >>> model = Model("yolov8n.pt")
883
+ >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU
884
+ """
708
885
  self._check_is_pytorch_model()
709
886
  self = super()._apply(fn) # noqa
710
887
  self.predictor = None # reset predictor as device may have changed
@@ -717,10 +894,19 @@ class Model(nn.Module):
717
894
  Retrieves the class names associated with the loaded model.
718
895
 
719
896
  This property returns the class names if they are defined in the model. It checks the class names for validity
720
- using the 'check_class_names' function from the ultralytics.nn.autobackend module.
897
+ using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
898
+ initialized, it sets it up before retrieving the names.
721
899
 
722
900
  Returns:
723
- (list | None): The class names of the model if available, otherwise None.
901
+ (List[str]): A list of class names associated with the model.
902
+
903
+ Raises:
904
+ AttributeError: If the model or predictor does not have a 'names' attribute.
905
+
906
+ Examples:
907
+ >>> model = YOLO('yolov8n.pt')
908
+ >>> print(model.names)
909
+ ['person', 'bicycle', 'car', ...]
724
910
  """
725
911
  from ultralytics.nn.autobackend import check_class_names
726
912
 
@@ -736,11 +922,22 @@ class Model(nn.Module):
736
922
  """
737
923
  Retrieves the device on which the model's parameters are allocated.
738
924
 
739
- This property is used to determine whether the model's parameters are on CPU or GPU. It only applies to models
740
- that are instances of nn.Module.
925
+ This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
926
+ applicable only to models that are instances of nn.Module.
741
927
 
742
928
  Returns:
743
- (torch.device | None): The device (CPU/GPU) of the model if it is a PyTorch model, otherwise None.
929
+ (torch.device): The device (CPU/GPU) of the model.
930
+
931
+ Raises:
932
+ AttributeError: If the model is not a PyTorch nn.Module instance.
933
+
934
+ Examples:
935
+ >>> model = YOLO("yolov8n.pt")
936
+ >>> print(model.device)
937
+ device(type='cuda', index=0) # if CUDA is available
938
+ >>> model = model.to("cpu")
939
+ >>> print(model.device)
940
+ device(type='cpu')
744
941
  """
745
942
  return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
746
943
 
@@ -749,10 +946,20 @@ class Model(nn.Module):
749
946
  """
750
947
  Retrieves the transformations applied to the input data of the loaded model.
751
948
 
752
- This property returns the transformations if they are defined in the model.
949
+ This property returns the transformations if they are defined in the model. The transforms
950
+ typically include preprocessing steps like resizing, normalization, and data augmentation
951
+ that are applied to input data before it is fed into the model.
753
952
 
754
953
  Returns:
755
954
  (object | None): The transform object of the model if available, otherwise None.
955
+
956
+ Examples:
957
+ >>> model = YOLO('yolov8n.pt')
958
+ >>> transforms = model.transforms
959
+ >>> if transforms:
960
+ ... print(f"Model transforms: {transforms}")
961
+ ... else:
962
+ ... print("No transforms defined for this model.")
756
963
  """
757
964
  return self.model.transforms if hasattr(self.model, "transforms") else None
758
965
 
@@ -760,15 +967,25 @@ class Model(nn.Module):
760
967
  """
761
968
  Adds a callback function for a specified event.
762
969
 
763
- This method allows the user to register a custom callback function that is triggered on a specific event during
764
- model training or inference.
970
+ This method allows registering custom callback functions that are triggered on specific events during
971
+ model operations such as training or inference. Callbacks provide a way to extend and customize the
972
+ behavior of the model at various stages of its lifecycle.
765
973
 
766
974
  Args:
767
- event (str): The name of the event to attach the callback to.
768
- func (callable): The callback function to be registered.
975
+ event (str): The name of the event to attach the callback to. Must be a valid event name recognized
976
+ by the Ultralytics framework.
977
+ func (Callable): The callback function to be registered. This function will be called when the
978
+ specified event occurs.
769
979
 
770
980
  Raises:
771
- ValueError: If the event name is not recognized.
981
+ ValueError: If the event name is not recognized or is invalid.
982
+
983
+ Examples:
984
+ >>> def on_train_start(trainer):
985
+ ... print("Training is starting!")
986
+ >>> model = YOLO('yolov8n.pt')
987
+ >>> model.add_callback("on_train_start", on_train_start)
988
+ >>> model.train(data='coco128.yaml', epochs=1)
772
989
  """
773
990
  self.callbacks[event].append(func)
774
991
 
@@ -777,12 +994,26 @@ class Model(nn.Module):
777
994
  Clears all callback functions registered for a specified event.
778
995
 
779
996
  This method removes all custom and default callback functions associated with the given event.
997
+ It resets the callback list for the specified event to an empty list, effectively removing all
998
+ registered callbacks for that event.
780
999
 
781
1000
  Args:
782
- event (str): The name of the event for which to clear the callbacks.
783
-
784
- Raises:
785
- ValueError: If the event name is not recognized.
1001
+ event (str): The name of the event for which to clear the callbacks. This should be a valid event name
1002
+ recognized by the Ultralytics callback system.
1003
+
1004
+ Examples:
1005
+ >>> model = YOLO('yolov8n.pt')
1006
+ >>> model.add_callback('on_train_start', lambda trainer: print('Training started'))
1007
+ >>> model.clear_callback('on_train_start')
1008
+ >>> # All callbacks for 'on_train_start' are now removed
1009
+
1010
+ Notes:
1011
+ - This method affects both custom callbacks added by the user and default callbacks
1012
+ provided by the Ultralytics framework.
1013
+ - After calling this method, no callbacks will be executed for the specified event
1014
+ until new ones are added.
1015
+ - Use with caution as it removes all callbacks, including essential ones that might
1016
+ be required for proper functioning of certain operations.
786
1017
  """
787
1018
  self.callbacks[event] = []
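The add, clear, and reset operations documented in this class can be sketched with a plain dictionary of lists. This is a toy stand-in, not the Ultralytics implementation: `CallbackRegistry`, `default_on_train_start`, and the single-event `default_callbacks` dict are hypothetical names introduced for illustration, and the `defaultdict` backing is an assumption.

```python
from collections import defaultdict


def default_on_train_start(trainer):
    """Hypothetical default callback; receives the trainer-like object."""
    return "default"


# Hypothetical defaults: each event maps to a list whose first item is the default
default_callbacks = {"on_train_start": [default_on_train_start]}


class CallbackRegistry:
    """Toy registry mirroring add_callback, clear_callback, and reset_callbacks."""

    def __init__(self):
        # Copy the lists so later edits never touch the shared defaults
        self.callbacks = defaultdict(list, {k: list(v) for k, v in default_callbacks.items()})

    def add_callback(self, event, func):
        self.callbacks[event].append(func)

    def clear_callback(self, event):
        # Removes custom AND default callbacks for this event
        self.callbacks[event] = []

    def reset_callbacks(self):
        # Restore only the first default callback for every known event
        for event in default_callbacks:
            self.callbacks[event] = [default_callbacks[event][0]]


registry = CallbackRegistry()
registry.add_callback("on_train_start", lambda trainer: "custom")
registry.clear_callback("on_train_start")  # event now has no callbacks at all
registry.reset_callbacks()  # default callback is back, custom one is gone
```

Clearing an event leaves it with no handlers, defaults included, so a reset is the way back to the default behavior.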
788
1019
 
@@ -791,14 +1022,45 @@ class Model(nn.Module):
791
1022
  Resets all callbacks to their default functions.
792
1023
 
793
1024
  This method reinstates the default callback functions for all events, removing any custom callbacks that were
794
- added previously.
1025
+ previously added. It iterates through all default callback events and replaces the current callbacks with the
1026
+ default ones.
1027
+
1028
+ The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined
1029
+ functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc.
1030
+
1031
+ This method is useful when you want to revert to the original set of callbacks after making custom
1032
+ modifications, ensuring consistent behavior across different runs or experiments.
1033
+
1034
+ Examples:
1035
+ >>> model = YOLO('yolov8n.pt')
1036
+ >>> model.add_callback('on_train_start', custom_function)
1037
+ >>> model.reset_callbacks()
1038
+ >>> # All callbacks are now reset to their default functions
795
1039
  """
796
1040
  for event in callbacks.default_callbacks.keys():
797
1041
  self.callbacks[event] = [callbacks.default_callbacks[event][0]]
798
1042
 
799
1043
  @staticmethod
800
1044
  def _reset_ckpt_args(args: dict) -> dict:
801
- """Reset arguments when loading a PyTorch model."""
1045
+ """
1046
+ Resets specific arguments when loading a PyTorch model checkpoint.
1047
+
1048
+ This static method filters the input arguments dictionary to retain only a specific set of keys that are
1049
+ considered important for model loading. It's used to ensure that only relevant arguments are preserved
1050
+ when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
1051
+
1052
+ Args:
1053
+ args (dict): A dictionary containing various model arguments and settings.
1054
+
1055
+ Returns:
1056
+ (dict): A new dictionary containing only the specified include keys from the input arguments.
1057
+
1058
+ Examples:
1059
+ >>> original_args = {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect', 'batch': 16, 'epochs': 100}
1060
+ >>> reset_args = Model._reset_ckpt_args(original_args)
1061
+ >>> print(reset_args)
1062
+ {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
1063
+ """
802
1064
  include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
803
1065
  return {k: v for k, v in args.items() if k in include}
804
1066
 
@@ -808,7 +1070,31 @@ class Model(nn.Module):
808
1070
  # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
809
1071
 
810
1072
  def _smart_load(self, key: str):
811
- """Load model/trainer/validator/predictor."""
1073
+ """
1074
+ Loads the appropriate module based on the model task.
1075
+
1076
+ This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
1077
+ based on the current task of the model and the provided key. It uses the task_map attribute to determine
1078
+ the correct module to load.
1079
+
1080
+ Args:
1081
+ key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
1082
+
1083
+ Returns:
1084
+ (object): The loaded module corresponding to the specified key and current task.
1085
+
1086
+ Raises:
1087
+ NotImplementedError: If the specified key is not supported for the current task.
1088
+
1089
+ Examples:
1090
+ >>> model = YOLO('yolov8n.pt')  # a subclass that defines task_map
1091
+ >>> predictor = model._smart_load('predictor')
1092
+ >>> trainer = model._smart_load('trainer')
1093
+
1094
+ Notes:
1095
+ - This method is typically used internally by other methods of the Model class.
1096
+ - The task_map attribute should be properly initialized with the correct mappings for each task.
1097
+ """
812
1098
  try:
813
1099
  return self.task_map[self.task][key]
814
1100
  except Exception as e:
@@ -821,9 +1107,30 @@ class Model(nn.Module):
821
1107
  @property
822
1108
  def task_map(self) -> dict:
823
1109
  """
824
- Map head to model, trainer, validator, and predictor classes.
1110
+ Provides a mapping from model tasks to corresponding classes for different modes.
1111
+
1112
+ This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
1113
+ to a nested dictionary. The nested dictionary contains mappings for different operational modes
1114
+ (model, trainer, validator, predictor) to their respective class implementations.
1115
+
1116
+ The mapping allows for dynamic loading of appropriate classes based on the model's task and the
1117
+ desired operational mode. This facilitates a flexible and extensible architecture for handling
1118
+ various tasks and modes within the Ultralytics framework.
825
1119
 
826
1120
  Returns:
827
- task_map (dict): The map of model task to mode classes.
1121
+ (Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
1122
+ nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
1123
+ 'predictor', mapping to their respective class implementations.
1124
+
1125
+ Examples:
1126
+ >>> model = YOLO('yolov8n.pt')  # the base Model.task_map raises NotImplementedError
1127
+ >>> task_map = model.task_map
1128
+ >>> detect_class_map = task_map['detect']
1129
+ >>> segment_class_map = task_map['segment']
1130
+
1131
+ Note:
1132
+ The actual implementation of this method may vary depending on the specific tasks and
1133
+ classes supported by the Ultralytics framework. The docstring provides a general
1134
+ description of the expected behavior and structure.
828
1135
  """
829
1136
  raise NotImplementedError("Please provide task map for your model!")
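Because the base `task_map` raises `NotImplementedError`, `_smart_load` only works on a subclass that supplies the mapping. A minimal sketch of that pattern, using toy classes: `BaseModel`, `ToyModel`, and the `Toy*` placeholders are hypothetical, and the error message in `_smart_load` is an assumption because the diff truncates that branch.

```python
class BaseModel:
    """Toy base class: task_map must be provided by subclasses."""

    def __init__(self, task: str):
        self.task = task

    @property
    def task_map(self) -> dict:
        raise NotImplementedError("Please provide task map for your model!")

    def _smart_load(self, key: str):
        try:
            return self.task_map[self.task][key]
        except Exception as e:
            # Assumed message; the original wording is not shown in the diff
            raise NotImplementedError(
                f"'{type(self).__name__}' does not support '{key}' for task '{self.task}'"
            ) from e


# Hypothetical placeholder classes standing in for real implementations
class ToyNet: ...
class ToyTrainer: ...
class ToyValidator: ...
class ToyPredictor: ...


class ToyModel(BaseModel):
    @property
    def task_map(self) -> dict:
        # task name -> mode name -> class
        return {
            "detect": {
                "model": ToyNet,
                "trainer": ToyTrainer,
                "validator": ToyValidator,
                "predictor": ToyPredictor,
            }
        }


trainer_cls = ToyModel("detect")._smart_load("trainer")  # -> ToyTrainer
```

Calling `_smart_load` on `BaseModel` itself, or with a task the subclass does not list, ends in `NotImplementedError` in this sketch.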