ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247) hide show
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,29 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import inspect
4
- import sys
5
4
  from pathlib import Path
6
- from typing import Union
5
+ from typing import Any, Dict, List, Union
7
6
 
8
7
  import numpy as np
9
8
  import torch
9
+ from PIL import Image
10
10
 
11
11
  from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
12
- from ultralytics.hub.utils import HUB_WEB_ROOT
12
+ from ultralytics.engine.results import Results
13
+ from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
13
14
  from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
14
- from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
15
+ from ultralytics.utils import (
16
+ ARGV,
17
+ ASSETS,
18
+ DEFAULT_CFG_DICT,
19
+ LOGGER,
20
+ RANK,
21
+ SETTINGS,
22
+ callbacks,
23
+ checks,
24
+ emojis,
25
+ yaml_load,
26
+ )
15
27
 
16
28
 
17
29
  class Model(nn.Module):
@@ -20,26 +32,18 @@ class Model(nn.Module):
20
32
 
21
33
  This class provides a common interface for various operations related to YOLO models, such as training,
22
34
  validation, prediction, exporting, and benchmarking. It handles different types of models, including those
23
- loaded from local files, Ultralytics HUB, or Triton Server. The class is designed to be flexible and
24
- extendable for different tasks and model configurations.
25
-
26
- Args:
27
- model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file
28
- path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
29
- task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's
30
- application domain, such as object detection, segmentation, etc. Defaults to None.
31
- verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False.
35
+ loaded from local files, Ultralytics HUB, or Triton Server.
32
36
 
33
37
  Attributes:
34
- callbacks (dict): A dictionary of callback functions for various events during model operations.
38
+ callbacks (Dict): A dictionary of callback functions for various events during model operations.
35
39
  predictor (BasePredictor): The predictor object used for making predictions.
36
40
  model (nn.Module): The underlying PyTorch model.
37
41
  trainer (BaseTrainer): The trainer object used for training the model.
38
- ckpt (dict): The checkpoint data if the model is loaded from a *.pt file.
42
+ ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
39
43
  cfg (str): The configuration of the model if loaded from a *.yaml file.
40
44
  ckpt_path (str): The path to the checkpoint file.
41
- overrides (dict): A dictionary of overrides for model configuration.
42
- metrics (dict): The latest training/validation metrics.
45
+ overrides (Dict): A dictionary of overrides for model configuration.
46
+ metrics (Dict): The latest training/validation metrics.
43
47
  session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
44
48
  task (str): The type of task the model is intended for.
45
49
  model_name (str): The name of the model.
@@ -65,120 +69,136 @@ class Model(nn.Module):
65
69
  add_callback: Adds a callback function for an event.
66
70
  clear_callback: Clears all callbacks for an event.
67
71
  reset_callbacks: Resets all callbacks to their default functions.
68
- _get_hub_session: Retrieves or creates an Ultralytics HUB session.
69
- is_triton_model: Checks if a model is a Triton Server model.
70
- is_hub_model: Checks if a model is an Ultralytics HUB model.
71
- _reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
72
- _smart_load: Loads the appropriate module based on the model task.
73
- task_map: Provides a mapping from model tasks to corresponding classes.
74
-
75
- Raises:
76
- FileNotFoundError: If the specified model file does not exist or is inaccessible.
77
- ValueError: If the model file or configuration is invalid or unsupported.
78
- ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
79
- TypeError: If the model is not a PyTorch model when required.
80
- AttributeError: If required attributes or methods are not implemented or available.
81
- NotImplementedError: If a specific model task or mode is not supported.
72
+
73
+ Examples:
74
+ >>> from ultralytics import YOLO
75
+ >>> model = YOLO("yolo11n.pt")
76
+ >>> results = model.predict("image.jpg")
77
+ >>> model.train(data="coco8.yaml", epochs=3)
78
+ >>> metrics = model.val()
79
+ >>> model.export(format="onnx")
82
80
  """
83
81
 
84
82
  def __init__(
85
83
  self,
86
- model: Union[str, Path] = "yolov8n.pt",
84
+ model: Union[str, Path] = "yolo11n.pt",
87
85
  task: str = None,
88
86
  verbose: bool = False,
89
87
  ) -> None:
90
88
  """
91
89
  Initializes a new instance of the YOLO model class.
92
90
 
93
- This constructor sets up the model based on the provided model path or name. It handles various types of model
94
- sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
95
- important attributes of the model and prepares it for operations like training, prediction, or export.
91
+ This constructor sets up the model based on the provided model path or name. It handles various types of
92
+ model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
93
+ initializes several important attributes of the model and prepares it for operations like training,
94
+ prediction, or export.
96
95
 
97
96
  Args:
98
- model (Union[str, Path], optional): The path or model file to load or create. This can be a local
99
- file path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
100
- task (Any, optional): The task type associated with the YOLO model, specifying its application domain.
101
- Defaults to None.
102
- verbose (bool, optional): If True, enables verbose output during the model's initialization and subsequent
103
- operations. Defaults to False.
97
+ model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
98
+ model name from Ultralytics HUB, or a Triton Server model.
99
+ task (str | None): The task type associated with the YOLO model, specifying its application domain.
100
+ verbose (bool): If True, enables verbose output during the model's initialization and subsequent
101
+ operations.
104
102
 
105
103
  Raises:
106
104
  FileNotFoundError: If the specified model file does not exist or is inaccessible.
107
105
  ValueError: If the model file or configuration is invalid or unsupported.
108
106
  ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
107
+
108
+ Examples:
109
+ >>> model = Model("yolo11n.pt")
110
+ >>> model = Model("path/to/model.yaml", task="detect")
111
+ >>> model = Model("hub_model", verbose=True)
109
112
  """
110
113
  super().__init__()
111
114
  self.callbacks = callbacks.get_default_callbacks()
112
115
  self.predictor = None # reuse predictor
113
116
  self.model = None # model object
114
117
  self.trainer = None # trainer object
115
- self.ckpt = None # if loaded from *.pt
118
+ self.ckpt = {} # if loaded from *.pt
116
119
  self.cfg = None # if loaded from *.yaml
117
120
  self.ckpt_path = None
118
121
  self.overrides = {} # overrides for trainer object
119
122
  self.metrics = None # validation/training metrics
120
123
  self.session = None # HUB session
121
124
  self.task = task # task type
122
- self.model_name = model = str(model).strip() # strip spaces
125
+ model = str(model).strip()
123
126
 
124
127
  # Check if Ultralytics HUB model from https://hub.ultralytics.com
125
128
  if self.is_hub_model(model):
126
129
  # Fetch model from HUB
127
- checks.check_requirements("hub-sdk>0.0.2")
128
- self.session = self._get_hub_session(model)
129
- model = self.session.model_file
130
+ checks.check_requirements("hub-sdk>=0.0.12")
131
+ session = HUBTrainingSession.create_session(model)
132
+ model = session.model_file
133
+ if session.train_args: # training sent from HUB
134
+ self.session = session
130
135
 
131
136
  # Check if Triton Server model
132
137
  elif self.is_triton_model(model):
133
- self.model = model
134
- self.task = task
138
+ self.model_name = self.model = model
139
+ self.overrides["task"] = task or "detect" # set `task=detect` if not explicitly set
135
140
  return
136
141
 
137
142
  # Load or create new YOLO model
138
- model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
139
- if Path(model).suffix in (".yaml", ".yml"):
143
+ if Path(model).suffix in {".yaml", ".yml"}:
140
144
  self._new(model, task=task, verbose=verbose)
141
145
  else:
142
146
  self._load(model, task=task)
143
147
 
144
- self.model_name = model
148
+ # Delete super().training for accessing self.model.training
149
+ del self.training
145
150
 
146
151
  def __call__(
147
152
  self,
148
- source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
153
+ source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
149
154
  stream: bool = False,
150
- **kwargs,
155
+ **kwargs: Any,
151
156
  ) -> list:
152
157
  """
153
- An alias for the predict method, enabling the model instance to be callable.
158
+ Alias for the predict method, enabling the model instance to be callable for predictions.
154
159
 
155
- This method simplifies the process of making predictions by allowing the model instance to be called directly
156
- with the required arguments for prediction.
160
+ This method simplifies the process of making predictions by allowing the model instance to be called
161
+ directly with the required arguments.
157
162
 
158
163
  Args:
159
- source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making
160
- predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays.
161
- Defaults to None.
162
- stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
163
- Defaults to False.
164
- **kwargs (any): Additional keyword arguments for configuring the prediction process.
164
+ source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
165
+ the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
166
+ tensor, or a list/tuple of these.
167
+ stream (bool): If True, treat the input source as a continuous stream for predictions.
168
+ **kwargs: Additional keyword arguments to configure the prediction process.
165
169
 
166
170
  Returns:
167
- (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
171
+ (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
172
+ Results object.
173
+
174
+ Examples:
175
+ >>> model = YOLO("yolo11n.pt")
176
+ >>> results = model("https://ultralytics.com/images/bus.jpg")
177
+ >>> for r in results:
178
+ ... print(f"Detected {len(r)} objects in image")
168
179
  """
169
180
  return self.predict(source, stream, **kwargs)
170
181
 
171
182
  @staticmethod
172
- def _get_hub_session(model: str):
173
- """Creates a session for Hub Training."""
174
- from ultralytics.hub.session import HUBTrainingSession
183
+ def is_triton_model(model: str) -> bool:
184
+ """
185
+ Checks if the given model string is a Triton Server URL.
175
186
 
176
- session = HUBTrainingSession(model)
177
- return session if session.client.authenticated else None
187
+ This static method determines whether the provided model string represents a valid Triton Server URL by
188
+ parsing its components using urllib.parse.urlsplit().
178
189
 
179
- @staticmethod
180
- def is_triton_model(model: str) -> bool:
181
- """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
190
+ Args:
191
+ model (str): The model string to be checked.
192
+
193
+ Returns:
194
+ (bool): True if the model string is a valid Triton Server URL, False otherwise.
195
+
196
+ Examples:
197
+ >>> Model.is_triton_model("http://localhost:8000/v2/models/yolov8n")
198
+ True
199
+ >>> Model.is_triton_model("yolo11n.pt")
200
+ False
201
+ """
182
202
  from urllib.parse import urlsplit
183
203
 
184
204
  url = urlsplit(model)
@@ -186,24 +206,48 @@ class Model(nn.Module):
186
206
 
187
207
  @staticmethod
188
208
  def is_hub_model(model: str) -> bool:
189
- """Check if the provided model is a HUB model."""
190
- return any(
191
- (
192
- model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
193
- [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
194
- len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID
195
- )
196
- )
209
+ """
210
+ Check if the provided model is an Ultralytics HUB model.
211
+
212
+ This static method determines whether the given model string represents a valid Ultralytics HUB model
213
+ identifier.
214
+
215
+ Args:
216
+ model (str): The model string to check.
217
+
218
+ Returns:
219
+ (bool): True if the model is a valid Ultralytics HUB model, False otherwise.
220
+
221
+ Examples:
222
+ >>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
223
+ True
224
+ >>> Model.is_hub_model("yolo11n.pt")
225
+ False
226
+ """
227
+ return model.startswith(f"{HUB_WEB_ROOT}/models/")
197
228
 
198
229
  def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
199
230
  """
200
231
  Initializes a new model and infers the task type from the model definitions.
201
232
 
233
+ This method creates a new model instance based on the provided configuration file. It loads the model
234
+ configuration, infers the task type if not specified, and initializes the model using the appropriate
235
+ class from the task map.
236
+
202
237
  Args:
203
- cfg (str): model configuration file
204
- task (str | None): model task
205
- model (BaseModel): Customized model.
206
- verbose (bool): display model info on load
238
+ cfg (str): Path to the model configuration file in YAML format.
239
+ task (str | None): The specific task for the model. If None, it will be inferred from the config.
240
+ model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
241
+ a new one.
242
+ verbose (bool): If True, displays model information during loading.
243
+
244
+ Raises:
245
+ ValueError: If the configuration file is invalid or the task cannot be inferred.
246
+ ImportError: If the required dependencies for the specified task are not installed.
247
+
248
+ Examples:
249
+ >>> model = Model()
250
+ >>> model._new("yolov8n.yaml", task="detect", verbose=True)
207
251
  """
208
252
  cfg_dict = yaml_model_load(cfg)
209
253
  self.cfg = cfg
@@ -215,31 +259,63 @@ class Model(nn.Module):
215
259
  # Below added to allow export from YAMLs
216
260
  self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
217
261
  self.model.task = self.task
262
+ self.model_name = cfg
218
263
 
219
264
  def _load(self, weights: str, task=None) -> None:
220
265
  """
221
- Initializes a new model and infers the task type from the model head.
266
+ Loads a model from a checkpoint file or initializes it from a weights file.
267
+
268
+ This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
269
+ up the model, task, and related attributes based on the loaded weights.
222
270
 
223
271
  Args:
224
- weights (str): model checkpoint to be loaded
225
- task (str | None): model task
272
+ weights (str): Path to the model weights file to be loaded.
273
+ task (str | None): The task associated with the model. If None, it will be inferred from the model.
274
+
275
+ Raises:
276
+ FileNotFoundError: If the specified weights file does not exist or is inaccessible.
277
+ ValueError: If the weights file format is unsupported or invalid.
278
+
279
+ Examples:
280
+ >>> model = Model()
281
+ >>> model._load("yolo11n.pt")
282
+ >>> model._load("path/to/weights.pth", task="detect")
226
283
  """
227
- suffix = Path(weights).suffix
228
- if suffix == ".pt":
284
+ if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
285
+ weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file
286
+ weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
287
+
288
+ if Path(weights).suffix == ".pt":
229
289
  self.model, self.ckpt = attempt_load_one_weight(weights)
230
290
  self.task = self.model.args["task"]
231
291
  self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
232
292
  self.ckpt_path = self.model.pt_path
233
293
  else:
234
- weights = checks.check_file(weights)
294
+ weights = checks.check_file(weights) # runs in all cases, not redundant with above call
235
295
  self.model, self.ckpt = weights, None
236
296
  self.task = task or guess_model_task(weights)
237
297
  self.ckpt_path = weights
238
298
  self.overrides["model"] = weights
239
299
  self.overrides["task"] = self.task
300
+ self.model_name = weights
240
301
 
241
302
  def _check_is_pytorch_model(self) -> None:
242
- """Raises TypeError is model is not a PyTorch model."""
303
+ """
304
+ Checks if the model is a PyTorch model and raises a TypeError if it's not.
305
+
306
+ This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
307
+ certain operations that require a PyTorch model are only performed on compatible model types.
308
+
309
+ Raises:
310
+ TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
311
+ information about supported model formats and operations.
312
+
313
+ Examples:
314
+ >>> model = Model("yolo11n.pt")
315
+ >>> model._check_is_pytorch_model() # No error raised
316
+ >>> model = Model("yolov8n.onnx")
317
+ >>> model._check_is_pytorch_model() # Raises TypeError
318
+ """
243
319
  pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
244
320
  pt_module = isinstance(self.model, nn.Module)
245
321
  if not (pt_module or pt_str):
@@ -253,17 +329,21 @@ class Model(nn.Module):
253
329
 
254
330
  def reset_weights(self) -> "Model":
255
331
  """
256
- Resets the model parameters to randomly initialized values, effectively discarding all training information.
332
+ Resets the model's weights to their initial state.
257
333
 
258
334
  This method iterates through all modules in the model and resets their parameters if they have a
259
- 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
260
- to be updated during training.
335
+ 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
336
+ enabling them to be updated during training.
261
337
 
262
338
  Returns:
263
- self (ultralytics.engine.model.Model): The instance of the class with reset weights.
339
+ (Model): The instance of the class with reset weights.
264
340
 
265
341
  Raises:
266
342
  AssertionError: If the model is not a PyTorch model.
343
+
344
+ Examples:
345
+ >>> model = Model("yolo11n.pt")
346
+ >>> model.reset_weights()
267
347
  """
268
348
  self._check_is_pytorch_model()
269
349
  for m in self.model.modules():
@@ -273,7 +353,7 @@ class Model(nn.Module):
273
353
  p.requires_grad = True
274
354
  return self
275
355
 
276
- def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model":
356
+ def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model":
277
357
  """
278
358
  Loads parameters from the specified weights file into the model.
279
359
 
@@ -281,73 +361,103 @@ class Model(nn.Module):
281
361
  name and shape and transfers them to the model.
282
362
 
283
363
  Args:
284
- weights (str | Path): Path to the weights file or a weights object. Defaults to 'yolov8n.pt'.
364
+ weights (Union[str, Path]): Path to the weights file or a weights object.
285
365
 
286
366
  Returns:
287
- self (ultralytics.engine.model.Model): The instance of the class with loaded weights.
367
+ (Model): The instance of the class with loaded weights.
288
368
 
289
369
  Raises:
290
370
  AssertionError: If the model is not a PyTorch model.
371
+
372
+ Examples:
373
+ >>> model = Model()
374
+ >>> model.load("yolo11n.pt")
375
+ >>> model.load(Path("path/to/weights.pt"))
291
376
  """
292
377
  self._check_is_pytorch_model()
293
378
  if isinstance(weights, (str, Path)):
379
+ self.overrides["pretrained"] = weights # remember the weights for DDP training
294
380
  weights, self.ckpt = attempt_load_one_weight(weights)
295
381
  self.model.load(weights)
296
382
  return self
297
383
 
298
- def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None:
384
+ def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
299
385
  """
300
386
  Saves the current model state to a file.
301
387
 
302
- This method exports the model's checkpoint (ckpt) to the specified filename.
388
+ This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
389
+ the date, Ultralytics version, license information, and a link to the documentation.
303
390
 
304
391
  Args:
305
- filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
306
- use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
392
+ filename (Union[str, Path]): The name of the file to save the model to.
307
393
 
308
394
  Raises:
309
395
  AssertionError: If the model is not a PyTorch model.
396
+
397
+ Examples:
398
+ >>> model = Model("yolo11n.pt")
399
+ >>> model.save("my_model.pt")
310
400
  """
311
401
  self._check_is_pytorch_model()
312
- from ultralytics import __version__
402
+ from copy import deepcopy
313
403
  from datetime import datetime
314
404
 
405
+ from ultralytics import __version__
406
+
315
407
  updates = {
408
+ "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
316
409
  "date": datetime.now().isoformat(),
317
410
  "version": __version__,
318
411
  "license": "AGPL-3.0 License (https://ultralytics.com/license)",
319
412
  "docs": "https://docs.ultralytics.com",
320
413
  }
321
- torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill)
414
+ torch.save({**self.ckpt, **updates}, filename)
322
415
 
323
416
  def info(self, detailed: bool = False, verbose: bool = True):
324
417
  """
325
418
  Logs or returns model information.
326
419
 
327
- This method provides an overview or detailed information about the model, depending on the arguments passed.
328
- It can control the verbosity of the output.
420
+ This method provides an overview or detailed information about the model, depending on the arguments
421
+ passed. It can control the verbosity of the output and return the information as a list.
329
422
 
330
423
  Args:
331
- detailed (bool): If True, shows detailed information about the model. Defaults to False.
332
- verbose (bool): If True, prints the information. If False, returns the information. Defaults to True.
424
+ detailed (bool): If True, shows detailed information about the model layers and parameters.
425
+ verbose (bool): If True, prints the information. If False, returns the information as a list.
333
426
 
334
427
  Returns:
335
- (list): Various types of information about the model, depending on the 'detailed' and 'verbose' parameters.
428
+ (List[str]): A list of strings containing various types of information about the model, including
429
+ model summary, layer details, and parameter counts. Empty if verbose is True.
336
430
 
337
431
  Raises:
338
- AssertionError: If the model is not a PyTorch model.
432
+ TypeError: If the model is not a PyTorch model.
433
+
434
+ Examples:
435
+ >>> model = Model("yolo11n.pt")
436
+ >>> model.info() # Prints model summary
437
+ >>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list
339
438
  """
340
439
  self._check_is_pytorch_model()
341
440
  return self.model.info(detailed=detailed, verbose=verbose)
342
441
 
343
442
  def fuse(self):
344
443
  """
345
- Fuses Conv2d and BatchNorm2d layers in the model.
444
+ Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.
346
445
 
347
- This method optimizes the model by fusing Conv2d and BatchNorm2d layers, which can improve inference speed.
446
+ This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
447
+ into a single layer. This fusion can significantly improve inference speed by reducing the number of
448
+ operations and memory accesses required during forward passes.
449
+
450
+ The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
451
+ bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
452
+ performs both convolution and normalization in one step.
348
453
 
349
454
  Raises:
350
- AssertionError: If the model is not a PyTorch model.
455
+ TypeError: If the model is not a PyTorch nn.Module.
456
+
457
+ Examples:
458
+ >>> model = Model("yolo11n.pt")
459
+ >>> model.fuse()
460
+ >>> # Model is now fused and ready for optimized inference
351
461
  """
352
462
  self._check_is_pytorch_model()
353
463
  self.model.fuse()
@@ -356,25 +466,31 @@ class Model(nn.Module):
356
466
  self,
357
467
  source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
358
468
  stream: bool = False,
359
- **kwargs,
469
+ **kwargs: Any,
360
470
  ) -> list:
361
471
  """
362
472
  Generates image embeddings based on the provided source.
363
473
 
364
- This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image source.
365
- It allows customization of the embedding process through various keyword arguments.
474
+ This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
475
+ source. It allows customization of the embedding process through various keyword arguments.
366
476
 
367
477
  Args:
368
- source (str | int | PIL.Image | np.ndarray): The source of the image for generating embeddings.
369
- The source can be a file path, URL, PIL image, numpy array, etc. Defaults to None.
370
- stream (bool): If True, predictions are streamed. Defaults to False.
371
- **kwargs (any): Additional keyword arguments for configuring the embedding process.
478
+ source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
479
+ generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
480
+ stream (bool): If True, predictions are streamed.
481
+ **kwargs: Additional keyword arguments for configuring the embedding process.
372
482
 
373
483
  Returns:
374
484
  (List[torch.Tensor]): A list containing the image embeddings.
375
485
 
376
486
  Raises:
377
487
  AssertionError: If the model is not a PyTorch model.
488
+
489
+ Examples:
490
+ >>> model = YOLO("yolo11n.pt")
491
+ >>> image = "https://ultralytics.com/images/bus.jpg"
492
+ >>> embeddings = model.embed(image)
493
+ >>> print(embeddings[0].shape)
378
494
  """
379
495
  if not kwargs.get("embed"):
380
496
  kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
@@ -382,45 +498,48 @@ class Model(nn.Module):
382
498
 
383
499
  def predict(
384
500
  self,
385
- source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
501
+ source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
386
502
  stream: bool = False,
387
503
  predictor=None,
388
- **kwargs,
389
- ) -> list:
504
+ **kwargs: Any,
505
+ ) -> List[Results]:
390
506
  """
391
507
  Performs predictions on the given image source using the YOLO model.
392
508
 
393
509
  This method facilitates the prediction process, allowing various configurations through keyword arguments.
394
510
  It supports predictions with custom predictors or the default predictor method. The method handles different
395
- types of image sources and can operate in a streaming mode. It also provides support for SAM-type models
396
- through 'prompts'.
397
-
398
- The method sets up a new predictor if not already present and updates its arguments with each call.
399
- It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it
400
- is being called from the command line interface and adjusts its behavior accordingly, including setting defaults
401
- for confidence threshold and saving behavior.
511
+ types of image sources and can operate in a streaming mode.
402
512
 
403
513
  Args:
404
- source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
405
- Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to ASSETS.
406
- stream (bool, optional): Treats the input source as a continuous stream for predictions. Defaults to False.
407
- predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
408
- If None, the method uses a default predictor. Defaults to None.
409
- **kwargs (any): Additional keyword arguments for configuring the prediction process. These arguments allow
410
- for further customization of the prediction behavior.
514
+ source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
515
+ of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
516
+ images, numpy arrays, and torch tensors.
517
+ stream (bool): If True, treats the input source as a continuous stream for predictions.
518
+ predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
519
+ If None, the method uses a default predictor.
520
+ **kwargs: Additional keyword arguments for configuring the prediction process.
411
521
 
412
522
  Returns:
413
- (List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
414
-
415
- Raises:
416
- AttributeError: If the predictor is not properly set up.
523
+ (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
524
+ Results object.
525
+
526
+ Examples:
527
+ >>> model = YOLO("yolo11n.pt")
528
+ >>> results = model.predict(source="path/to/image.jpg", conf=0.25)
529
+ >>> for r in results:
530
+ ... print(r.boxes.data) # print detection bounding boxes
531
+
532
+ Notes:
533
+ - If 'source' is not provided, it defaults to the ASSETS constant with a warning.
534
+ - The method sets up a new predictor if not already present and updates its arguments with each call.
535
+ - For SAM-type models, 'prompts' can be passed as a keyword argument.
417
536
  """
418
537
  if source is None:
419
538
  source = ASSETS
420
539
  LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
421
540
 
422
- is_cli = (sys.argv[0].endswith("yolo") or sys.argv[0].endswith("ultralytics")) and any(
423
- x in sys.argv for x in ("predict", "track", "mode=predict", "mode=track")
541
+ is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(
542
+ x in ARGV for x in ("predict", "track", "mode=predict", "mode=track")
424
543
  )
425
544
 
426
545
  custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults
@@ -428,7 +547,7 @@ class Model(nn.Module):
428
547
  prompts = args.pop("prompts", None) # for SAM-type models
429
548
 
430
549
  if not self.predictor:
431
- self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
550
+ self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks)
432
551
  self.predictor.setup_model(model=self.model, verbose=is_cli)
433
552
  else: # only update args if predictor is already setup
434
553
  self.predictor.args = get_cfg(self.predictor.args, args)
@@ -443,31 +562,38 @@ class Model(nn.Module):
443
562
  source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
444
563
  stream: bool = False,
445
564
  persist: bool = False,
446
- **kwargs,
447
- ) -> list:
565
+ **kwargs: Any,
566
+ ) -> List[Results]:
448
567
  """
449
568
  Conducts object tracking on the specified input source using the registered trackers.
450
569
 
451
- This method performs object tracking using the model's predictors and optionally registered trackers. It is
452
- capable of handling different types of input sources such as file paths or video streams. The method supports
453
- customization of the tracking process through various keyword arguments. It registers trackers if they are not
454
- already present and optionally persists them based on the 'persist' flag.
455
-
456
- The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low
457
- confidence predictions as input. The tracking mode is explicitly set in the keyword arguments.
570
+ This method performs object tracking using the model's predictors and optionally registered trackers. It handles
571
+ various input sources such as file paths or video streams, and supports customization through keyword arguments.
572
+ The method registers trackers if not already present and can persist them between calls.
458
573
 
459
574
  Args:
460
- source (str, optional): The input source for object tracking. It can be a file path, URL, or video stream.
461
- stream (bool, optional): Treats the input source as a continuous video stream. Defaults to False.
462
- persist (bool, optional): Persists the trackers between different calls to this method. Defaults to False.
463
- **kwargs (any): Additional keyword arguments for configuring the tracking process. These arguments allow
464
- for further customization of the tracking behavior.
575
+ source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
576
+ tracking. Can be a file path, URL, or video stream.
577
+ stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
578
+ persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
579
+ **kwargs: Additional keyword arguments for configuring the tracking process.
465
580
 
466
581
  Returns:
467
- (List[ultralytics.engine.results.Results]): A list of tracking results, encapsulated in the Results class.
582
+ (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
468
583
 
469
584
  Raises:
470
585
  AttributeError: If the predictor does not have registered trackers.
586
+
587
+ Examples:
588
+ >>> model = YOLO("yolo11n.pt")
589
+ >>> results = model.track(source="path/to/video.mp4", show=True)
590
+ >>> for r in results:
591
+ ... print(r.boxes.id) # print tracking IDs
592
+
593
+ Notes:
594
+ - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.
595
+ - The tracking mode is explicitly set in the keyword arguments.
596
+ - Batch size is set to 1 for tracking in videos.
471
597
  """
472
598
  if not hasattr(self.predictor, "trackers"):
473
599
  from ultralytics.trackers import register_tracker
@@ -481,31 +607,30 @@ class Model(nn.Module):
481
607
  def val(
482
608
  self,
483
609
  validator=None,
484
- **kwargs,
610
+ **kwargs: Any,
485
611
  ):
486
612
  """
487
613
  Validates the model using a specified dataset and validation configuration.
488
614
 
489
- This method facilitates the model validation process, allowing for a range of customization through various
490
- settings and configurations. It supports validation with a custom validator or the default validation approach.
491
- The method combines default configurations, method-specific defaults, and user-provided arguments to configure
492
- the validation process. After validation, it updates the model's metrics with the results obtained from the
493
- validator.
494
-
495
- The method supports various arguments that allow customization of the validation process. For a comprehensive
496
- list of all configurable options, users should refer to the 'configuration' section in the documentation.
615
+ This method facilitates the model validation process, allowing for customization through various settings. It
616
+ supports validation with a custom validator or the default validation approach. The method combines default
617
+ configurations, method-specific defaults, and user-provided arguments to configure the validation process.
497
618
 
498
619
  Args:
499
- validator (BaseValidator, optional): An instance of a custom validator class for validating the model. If
500
- None, the method uses a default validator. Defaults to None.
501
- **kwargs (any): Arbitrary keyword arguments representing the validation configuration. These arguments are
502
- used to customize various aspects of the validation process.
620
+ validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
621
+ validating the model.
622
+ **kwargs: Arbitrary keyword arguments for customizing the validation process.
503
623
 
504
624
  Returns:
505
- (dict): Validation metrics obtained from the validation process.
625
+ (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
506
626
 
507
627
  Raises:
508
628
  AssertionError: If the model is not a PyTorch model.
629
+
630
+ Examples:
631
+ >>> model = YOLO("yolo11n.pt")
632
+ >>> results = model.val(data="coco8.yaml", imgsz=640)
633
+ >>> print(results.box.map) # Print mAP50-95
509
634
  """
510
635
  custom = {"rect": True} # method defaults
511
636
  args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
@@ -517,29 +642,37 @@ class Model(nn.Module):
517
642
 
518
643
  def benchmark(
519
644
  self,
520
- **kwargs,
645
+ **kwargs: Any,
521
646
  ):
522
647
  """
523
648
  Benchmarks the model across various export formats to evaluate performance.
524
649
 
525
650
  This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
526
- It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured
527
- using a combination of default configuration values, model-specific arguments, method-specific defaults, and
528
- any additional user-provided keyword arguments.
529
-
530
- The method supports various arguments that allow customization of the benchmarking process, such as dataset
531
- choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all
532
- configurable options, users should refer to the 'configuration' section in the documentation.
651
+ It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
652
+ configured using a combination of default configuration values, model-specific arguments, method-specific
653
+ defaults, and any additional user-provided keyword arguments.
533
654
 
534
655
  Args:
535
- **kwargs (any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
536
- default configurations, model-specific arguments, and method defaults.
656
+ **kwargs: Arbitrary keyword arguments to customize the benchmarking process. These are combined with
657
+ default configurations, model-specific arguments, and method defaults. Common options include:
658
+ - data (str): Path to the dataset for benchmarking.
659
+ - imgsz (int | List[int]): Image size for benchmarking.
660
+ - half (bool): Whether to use half-precision (FP16) mode.
661
+ - int8 (bool): Whether to use int8 precision mode.
662
+ - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
663
+ - verbose (bool): Whether to print detailed benchmark information.
537
664
 
538
665
  Returns:
539
- (dict): A dictionary containing the results of the benchmarking process.
666
+ (Dict): A dictionary containing the results of the benchmarking process, including metrics for
667
+ different export formats.
540
668
 
541
669
  Raises:
542
670
  AssertionError: If the model is not a PyTorch model.
671
+
672
+ Examples:
673
+ >>> model = YOLO("yolo11n.pt")
674
+ >>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True)
675
+ >>> print(results)
543
676
  """
544
677
  self._check_is_pytorch_model()
545
678
  from ultralytics.utils.benchmarks import benchmark
@@ -558,66 +691,92 @@ class Model(nn.Module):
558
691
 
559
692
  def export(
560
693
  self,
561
- **kwargs,
562
- ):
694
+ **kwargs: Any,
695
+ ) -> str:
563
696
  """
564
697
  Exports the model to a different format suitable for deployment.
565
698
 
566
699
  This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
567
700
  purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
568
- defaults, and any additional arguments provided. The combined arguments are used to configure export settings.
569
-
570
- The method supports a wide range of arguments to customize the export process. For a comprehensive list of all
571
- possible arguments, refer to the 'configuration' section in the documentation.
701
+ defaults, and any additional arguments provided.
572
702
 
573
703
  Args:
574
- **kwargs (any): Arbitrary keyword arguments to customize the export process. These are combined with the
575
- model's overrides and method defaults.
704
+ **kwargs: Arbitrary keyword arguments to customize the export process. These are combined with
705
+ the model's overrides and method defaults. Common arguments include:
706
+ format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
707
+ half (bool): Export model in half-precision.
708
+ int8 (bool): Export model in int8 precision.
709
+ device (str): Device to run the export on.
710
+ workspace (int): Maximum memory workspace size for TensorRT engines.
711
+ nms (bool): Add Non-Maximum Suppression (NMS) module to model.
712
+ simplify (bool): Simplify ONNX model.
576
713
 
577
714
  Returns:
578
- (object): The exported model in the specified format, or an object related to the export process.
715
+ (str): The path to the exported model file.
579
716
 
580
717
  Raises:
581
718
  AssertionError: If the model is not a PyTorch model.
719
+ ValueError: If an unsupported export format is specified.
720
+ RuntimeError: If the export process fails due to errors.
721
+
722
+ Examples:
723
+ >>> model = YOLO("yolo11n.pt")
724
+ >>> model.export(format="onnx", dynamic=True, simplify=True)
725
+ 'path/to/exported/model.onnx'
582
726
  """
583
727
  self._check_is_pytorch_model()
584
728
  from .exporter import Exporter
585
729
 
586
- custom = {"imgsz": self.model.args["imgsz"], "batch": 1, "data": None, "verbose": False} # method defaults
730
+ custom = {
731
+ "imgsz": self.model.args["imgsz"],
732
+ "batch": 1,
733
+ "data": None,
734
+ "device": None, # reset to avoid multi-GPU errors
735
+ "verbose": False,
736
+ } # method defaults
587
737
  args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
588
738
  return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
589
739
 
590
740
  def train(
591
741
  self,
592
742
  trainer=None,
593
- **kwargs,
743
+ **kwargs: Any,
594
744
  ):
595
745
  """
596
746
  Trains the model using the specified dataset and training configuration.
597
747
 
598
- This method facilitates model training with a range of customizable settings and configurations. It supports
599
- training with a custom trainer or the default training approach defined in the method. The method handles
600
- different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
601
- updating model and configuration after training.
748
+ This method facilitates model training with a range of customizable settings. It supports training with a
749
+ custom trainer or the default training approach. The method handles scenarios such as resuming training
750
+ from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
602
751
 
603
- When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
604
- arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
605
- configurations, method-specific defaults, and user-provided arguments to configure the training process. After
606
- training, it updates the model and its configurations, and optionally attaches metrics.
752
+ When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
753
+ arguments and warns if local arguments are provided. It checks for pip updates and combines default
754
+ configurations, method-specific defaults, and user-provided arguments to configure the training process.
607
755
 
608
756
  Args:
609
- trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
610
- method uses a default trainer. Defaults to None.
611
- **kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
612
- used to customize various aspects of the training process.
757
+ trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
758
+ **kwargs: Arbitrary keyword arguments for training configuration. Common options include:
759
+ data (str): Path to dataset configuration file.
760
+ epochs (int): Number of training epochs.
761
+ batch_size (int): Batch size for training.
762
+ imgsz (int): Input image size.
763
+ device (str): Device to run training on (e.g., 'cuda', 'cpu').
764
+ workers (int): Number of worker threads for data loading.
765
+ optimizer (str): Optimizer to use for training.
766
+ lr0 (float): Initial learning rate.
767
+ patience (int): Epochs to wait for no observable improvement for early stopping of training.
613
768
 
614
769
  Returns:
615
- (dict | None): Training metrics if available and training is successful; otherwise, None.
770
+ (Dict | None): Training metrics if available and training is successful; otherwise, None.
616
771
 
617
772
  Raises:
618
773
  AssertionError: If the model is not a PyTorch model.
619
774
  PermissionError: If there is a permission issue with the HUB session.
620
775
  ModuleNotFoundError: If the HUB SDK is not installed.
776
+
777
+ Examples:
778
+ >>> model = YOLO("yolo11n.pt")
779
+ >>> results = model.train(data="coco8.yaml", epochs=3)
621
780
  """
622
781
  self._check_is_pytorch_model()
623
782
  if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
@@ -628,7 +787,12 @@ class Model(nn.Module):
628
787
  checks.check_pip_update_available()
629
788
 
630
789
  overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
631
- custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
790
+ custom = {
791
+ # NOTE: handle the case when 'cfg' includes 'data'.
792
+ "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
793
+ "model": self.overrides["model"],
794
+ "task": self.task,
795
+ } # method defaults
632
796
  args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
633
797
  if args.get("resume"):
634
798
  args["resume"] = self.ckpt_path
@@ -638,25 +802,12 @@ class Model(nn.Module):
638
802
  self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
639
803
  self.model = self.trainer.model
640
804
 
641
- if SETTINGS["hub"] is True and not self.session:
642
- # Create a model in HUB
643
- try:
644
- self.session = self._get_hub_session(self.model_name)
645
- if self.session:
646
- self.session.create_model(args)
647
- # Check model was created
648
- if not getattr(self.session.model, "id", None):
649
- self.session = None
650
- except (PermissionError, ModuleNotFoundError):
651
- # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
652
- pass
653
-
654
805
  self.trainer.hub_session = self.session # attach optional HUB session
655
806
  self.trainer.train()
656
807
  # Update model and cfg after training
657
- if RANK in (-1, 0):
808
+ if RANK in {-1, 0}:
658
809
  ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
659
- self.model, _ = attempt_load_one_weight(ckpt)
810
+ self.model, self.ckpt = attempt_load_one_weight(ckpt)
660
811
  self.overrides = self.model.args
661
812
  self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
662
813
  return self.metrics
@@ -665,8 +816,8 @@ class Model(nn.Module):
665
816
  self,
666
817
  use_ray=False,
667
818
  iterations=10,
668
- *args,
669
- **kwargs,
819
+ *args: Any,
820
+ **kwargs: Any,
670
821
  ):
671
822
  """
672
823
  Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
@@ -679,14 +830,19 @@ class Model(nn.Module):
679
830
  Args:
680
831
  use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
681
832
  iterations (int): The number of tuning iterations to perform. Defaults to 10.
682
- *args (list): Variable length argument list for additional arguments.
683
- **kwargs (any): Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
833
+ *args: Variable length argument list for additional arguments.
834
+ **kwargs: Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
684
835
 
685
836
  Returns:
686
- (dict): A dictionary containing the results of the hyperparameter search.
837
+ (Dict): A dictionary containing the results of the hyperparameter search.
687
838
 
688
839
  Raises:
689
840
  AssertionError: If the model is not a PyTorch model.
841
+
842
+ Examples:
843
+ >>> model = YOLO("yolo11n.pt")
844
+ >>> results = model.tune(use_ray=True, iterations=20)
845
+ >>> print(results)
690
846
  """
691
847
  self._check_is_pytorch_model()
692
848
  if use_ray:
@@ -701,7 +857,27 @@ class Model(nn.Module):
701
857
  return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
702
858
 
703
859
  def _apply(self, fn) -> "Model":
704
- """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
860
+ """
861
+ Applies a function to model tensors that are not parameters or registered buffers.
862
+
863
+ This method extends the functionality of the parent class's _apply method by additionally resetting the
864
+ predictor and updating the device in the model's overrides. It's typically used for operations like
865
+ moving the model to a different device or changing its precision.
866
+
867
+ Args:
868
+ fn (Callable): A function to be applied to the model's tensors. This is typically a method like
869
+ to(), cpu(), cuda(), half(), or float().
870
+
871
+ Returns:
872
+ (Model): The model instance with the function applied and updated attributes.
873
+
874
+ Raises:
875
+ AssertionError: If the model is not a PyTorch model.
876
+
877
+ Examples:
878
+ >>> model = Model("yolo11n.pt")
879
+ >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU
880
+ """
705
881
  self._check_is_pytorch_model()
706
882
  self = super()._apply(fn) # noqa
707
883
  self.predictor = None # reset predictor as device may have changed
@@ -709,30 +885,55 @@ class Model(nn.Module):
709
885
  return self
710
886
 
711
887
  @property
712
- def names(self) -> list:
888
+ def names(self) -> Dict[int, str]:
713
889
  """
714
890
  Retrieves the class names associated with the loaded model.
715
891
 
716
892
  This property returns the class names if they are defined in the model. It checks the class names for validity
717
- using the 'check_class_names' function from the ultralytics.nn.autobackend module.
893
+ using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
894
+ initialized, it sets it up before retrieving the names.
718
895
 
719
896
  Returns:
720
- (list | None): The class names of the model if available, otherwise None.
897
+ (Dict[int, str]): A dict of class names associated with the model.
898
+
899
+ Raises:
900
+ AttributeError: If the model or predictor does not have a 'names' attribute.
901
+
902
+ Examples:
903
+ >>> model = YOLO("yolo11n.pt")
904
+ >>> print(model.names)
905
+ {0: 'person', 1: 'bicycle', 2: 'car', ...}
721
906
  """
722
907
  from ultralytics.nn.autobackend import check_class_names
723
908
 
724
- return check_class_names(self.model.names) if hasattr(self.model, "names") else None
909
+ if hasattr(self.model, "names"):
910
+ return check_class_names(self.model.names)
911
+ if not self.predictor: # export formats will not have predictor defined until predict() is called
912
+ self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
913
+ self.predictor.setup_model(model=self.model, verbose=False)
914
+ return self.predictor.model.names
725
915
 
726
916
  @property
727
917
  def device(self) -> torch.device:
728
918
  """
729
919
  Retrieves the device on which the model's parameters are allocated.
730
920
 
731
- This property is used to determine whether the model's parameters are on CPU or GPU. It only applies to models
732
- that are instances of nn.Module.
921
+ This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
922
+ applicable only to models that are instances of nn.Module.
733
923
 
734
924
  Returns:
735
- (torch.device | None): The device (CPU/GPU) of the model if it is a PyTorch model, otherwise None.
925
+ (torch.device): The device (CPU/GPU) of the model.
926
+
927
+ Raises:
928
+ AttributeError: If the model is not a PyTorch nn.Module instance.
929
+
930
+ Examples:
931
+ >>> model = YOLO("yolo11n.pt")
932
+ >>> print(model.device)
933
+ device(type='cuda', index=0) # if CUDA is available
934
+ >>> model = model.to("cpu")
935
+ >>> print(model.device)
936
+ device(type='cpu')
736
937
  """
737
938
  return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
738
939
 
@@ -741,10 +942,20 @@ class Model(nn.Module):
741
942
  """
742
943
  Retrieves the transformations applied to the input data of the loaded model.
743
944
 
744
- This property returns the transformations if they are defined in the model.
945
+ This property returns the transformations if they are defined in the model. The transforms
946
+ typically include preprocessing steps like resizing, normalization, and data augmentation
947
+ that are applied to input data before it is fed into the model.
745
948
 
746
949
  Returns:
747
950
  (object | None): The transform object of the model if available, otherwise None.
951
+
952
+ Examples:
953
+ >>> model = YOLO("yolo11n.pt")
954
+ >>> transforms = model.transforms
955
+ >>> if transforms:
956
+ ... print(f"Model transforms: {transforms}")
957
+ ... else:
958
+ ... print("No transforms defined for this model.")
748
959
  """
749
960
  return self.model.transforms if hasattr(self.model, "transforms") else None
750
961
 
@@ -752,15 +963,25 @@ class Model(nn.Module):
752
963
  """
753
964
  Adds a callback function for a specified event.
754
965
 
755
- This method allows the user to register a custom callback function that is triggered on a specific event during
756
- model training or inference.
966
+ This method allows registering custom callback functions that are triggered on specific events during
967
+ model operations such as training or inference. Callbacks provide a way to extend and customize the
968
+ behavior of the model at various stages of its lifecycle.
757
969
 
758
970
  Args:
759
- event (str): The name of the event to attach the callback to.
760
- func (callable): The callback function to be registered.
971
+ event (str): The name of the event to attach the callback to. Must be a valid event name recognized
972
+ by the Ultralytics framework.
973
+ func (Callable): The callback function to be registered. This function will be called when the
974
+ specified event occurs.
761
975
 
762
976
  Raises:
763
- ValueError: If the event name is not recognized.
977
+ ValueError: If the event name is not recognized or is invalid.
978
+
979
+ Examples:
980
+ >>> def on_train_start(trainer):
981
+ ... print("Training is starting!")
982
+ >>> model = YOLO("yolo11n.pt")
983
+ >>> model.add_callback("on_train_start", on_train_start)
984
+ >>> model.train(data="coco8.yaml", epochs=1)
764
985
  """
765
986
  self.callbacks[event].append(func)
766
987
 
@@ -769,12 +990,26 @@ class Model(nn.Module):
769
990
  Clears all callback functions registered for a specified event.
770
991
 
771
992
  This method removes all custom and default callback functions associated with the given event.
993
+ It resets the callback list for the specified event to an empty list, effectively removing all
994
+ registered callbacks for that event.
772
995
 
773
996
  Args:
774
- event (str): The name of the event for which to clear the callbacks.
775
-
776
- Raises:
777
- ValueError: If the event name is not recognized.
997
+ event (str): The name of the event for which to clear the callbacks. This should be a valid event name
998
+ recognized by the Ultralytics callback system.
999
+
1000
+ Examples:
1001
+ >>> model = YOLO("yolo11n.pt")
1002
+ >>> model.add_callback("on_train_start", lambda: print("Training started"))
1003
+ >>> model.clear_callback("on_train_start")
1004
+ >>> # All callbacks for 'on_train_start' are now removed
1005
+
1006
+ Notes:
1007
+ - This method affects both custom callbacks added by the user and default callbacks
1008
+ provided by the Ultralytics framework.
1009
+ - After calling this method, no callbacks will be executed for the specified event
1010
+ until new ones are added.
1011
+ - Use with caution as it removes all callbacks, including essential ones that might
1012
+ be required for proper functioning of certain operations.
778
1013
  """
779
1014
  self.callbacks[event] = []
780
1015
 
@@ -783,14 +1018,45 @@ class Model(nn.Module):
783
1018
  Resets all callbacks to their default functions.
784
1019
 
785
1020
  This method reinstates the default callback functions for all events, removing any custom callbacks that were
786
- added previously.
1021
+ previously added. It iterates through all default callback events and replaces the current callbacks with the
1022
+ default ones.
1023
+
1024
+ The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined
1025
+ functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc.
1026
+
1027
+ This method is useful when you want to revert to the original set of callbacks after making custom
1028
+ modifications, ensuring consistent behavior across different runs or experiments.
1029
+
1030
+ Examples:
1031
+ >>> model = YOLO("yolo11n.pt")
1032
+ >>> model.add_callback("on_train_start", custom_function)
1033
+ >>> model.reset_callbacks()
1034
+ # All callbacks are now reset to their default functions
787
1035
  """
788
1036
  for event in callbacks.default_callbacks.keys():
789
1037
  self.callbacks[event] = [callbacks.default_callbacks[event][0]]
790
1038
 
791
1039
  @staticmethod
792
1040
  def _reset_ckpt_args(args: dict) -> dict:
793
- """Reset arguments when loading a PyTorch model."""
1041
+ """
1042
+ Resets specific arguments when loading a PyTorch model checkpoint.
1043
+
1044
+ This static method filters the input arguments dictionary to retain only a specific set of keys that are
1045
+ considered important for model loading. It's used to ensure that only relevant arguments are preserved
1046
+ when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
1047
+
1048
+ Args:
1049
+ args (dict): A dictionary containing various model arguments and settings.
1050
+
1051
+ Returns:
1052
+ (dict): A new dictionary containing only the specified include keys from the input arguments.
1053
+
1054
+ Examples:
1055
+ >>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100}
1056
+ >>> reset_args = Model._reset_ckpt_args(original_args)
1057
+ >>> print(reset_args)
1058
+ {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
1059
+ """
794
1060
  include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
795
1061
  return {k: v for k, v in args.items() if k in include}
796
1062
 
@@ -800,7 +1066,31 @@ class Model(nn.Module):
800
1066
  # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
801
1067
 
802
1068
  def _smart_load(self, key: str):
803
- """Load model/trainer/validator/predictor."""
1069
+ """
1070
+ Loads the appropriate module based on the model task.
1071
+
1072
+ This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
1073
+ based on the current task of the model and the provided key. It uses the task_map attribute to determine
1074
+ the correct module to load.
1075
+
1076
+ Args:
1077
+ key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
1078
+
1079
+ Returns:
1080
+ (object): The loaded module corresponding to the specified key and current task.
1081
+
1082
+ Raises:
1083
+ NotImplementedError: If the specified key is not supported for the current task.
1084
+
1085
+ Examples:
1086
+ >>> model = Model(task="detect")
1087
+ >>> predictor = model._smart_load("predictor")
1088
+ >>> trainer = model._smart_load("trainer")
1089
+
1090
+ Notes:
1091
+ - This method is typically used internally by other methods of the Model class.
1092
+ - The task_map attribute should be properly initialized with the correct mappings for each task.
1093
+ """
804
1094
  try:
805
1095
  return self.task_map[self.task][key]
806
1096
  except Exception as e:
@@ -813,9 +1103,71 @@ class Model(nn.Module):
813
1103
  @property
814
1104
  def task_map(self) -> dict:
815
1105
  """
816
- Map head to model, trainer, validator, and predictor classes.
1106
+ Provides a mapping from model tasks to corresponding classes for different modes.
1107
+
1108
+ This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
1109
+ to a nested dictionary. The nested dictionary contains mappings for different operational modes
1110
+ (model, trainer, validator, predictor) to their respective class implementations.
1111
+
1112
+ The mapping allows for dynamic loading of appropriate classes based on the model's task and the
1113
+ desired operational mode. This facilitates a flexible and extensible architecture for handling
1114
+ various tasks and modes within the Ultralytics framework.
817
1115
 
818
1116
  Returns:
819
- task_map (dict): The map of model task to mode classes.
1117
+ (Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
1118
+ nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
1119
+ 'predictor', mapping to their respective class implementations.
1120
+
1121
+ Examples:
1122
+ >>> model = Model()
1123
+ >>> task_map = model.task_map
1124
+ >>> detect_class_map = task_map["detect"]
1125
+ >>> segment_class_map = task_map["segment"]
1126
+
1127
+ Note:
1128
+ The actual implementation of this method may vary depending on the specific tasks and
1129
+ classes supported by the Ultralytics framework. The docstring provides a general
1130
+ description of the expected behavior and structure.
820
1131
  """
821
1132
  raise NotImplementedError("Please provide task map for your model!")
1133
+
1134
+ def eval(self):
1135
+ """
1136
+ Sets the model to evaluation mode.
1137
+
1138
+ This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
1139
+ that behave differently during training and evaluation.
1140
+
1141
+ Returns:
1142
+ (Model): The model instance with evaluation mode set.
1143
+
1144
+ Examples:
1145
+ >> model = YOLO("yolo11n.pt")
1146
+ >> model.eval()
1147
+ """
1148
+ self.model.eval()
1149
+ return self
1150
+
1151
+ def __getattr__(self, name):
1152
+ """
1153
+ Enables accessing model attributes directly through the Model class.
1154
+
1155
+ This method provides a way to access attributes of the underlying model directly through the Model class
1156
+ instance. It first checks if the requested attribute is 'model', in which case it returns the model from
1157
+ the module dictionary. Otherwise, it delegates the attribute lookup to the underlying model.
1158
+
1159
+ Args:
1160
+ name (str): The name of the attribute to retrieve.
1161
+
1162
+ Returns:
1163
+ (Any): The requested attribute value.
1164
+
1165
+ Raises:
1166
+ AttributeError: If the requested attribute does not exist in the model.
1167
+
1168
+ Examples:
1169
+ >>> model = YOLO("yolo11n.pt")
1170
+ >>> print(model.stride)
1171
+ >>> print(model.task)
1172
+ """
1173
+ return self._modules["model"] if name == "model" else getattr(self.model, name)