dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between these versions as they appear in the public registry.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,10 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from __future__ import annotations
+
  import inspect
  from pathlib import Path
- from typing import Any, Dict, List, Union
+ from typing import Any

  import numpy as np
  import torch
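The `from __future__ import annotations` line added above is what allows the rewritten signatures later in this diff to use PEP 604 unions such as `str | Path` on Python versions older than 3.10, since annotations are then stored as strings rather than evaluated at import time. A minimal standalone sketch of the effect (names here are illustrative, not from the package):

    from __future__ import annotations
    from pathlib import Path

    def load(weights: str | Path = "yolo11n.pt") -> None:
        # The union annotation is not evaluated at definition time, so this
        # parses and runs on Python 3.8/3.9 as well as 3.10+.
        print(Path(weights).stem)

    load("model.pt")  # prints "model"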
@@ -10,7 +12,7 @@ from PIL import Image

  from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
  from ultralytics.engine.results import Results
- from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, yaml_model_load
+ from ultralytics.nn.tasks import guess_model_task, load_checkpoint, yaml_model_load
  from ultralytics.utils import (
  ARGV,
  ASSETS,
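The loader import changes from `attempt_load_one_weight` to `load_checkpoint`. Judging from the call sites later in this diff it keeps the same two-value return (the loaded module plus the raw checkpoint dict); a hedged sketch of that calling pattern, with the return values assumed from those call sites rather than documented here:

    from ultralytics.nn.tasks import load_checkpoint

    model, ckpt = load_checkpoint("yolo11n.pt")  # assumed: (nn.Module, checkpoint dict)
    print(type(model).__name__, sorted(ckpt)[:3])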
@@ -25,12 +27,11 @@ from ultralytics.utils import (


  class Model(torch.nn.Module):
- """
- A base class for implementing YOLO models, unifying APIs across different model types.
+ """A base class for implementing YOLO models, unifying APIs across different model types.

- This class provides a common interface for various operations related to YOLO models, such as training,
- validation, prediction, exporting, and benchmarking. It handles different types of models, including those
- loaded from local files, Ultralytics HUB, or Triton Server.
+ This class provides a common interface for various operations related to YOLO models, such as training, validation,
+ prediction, exporting, and benchmarking. It handles different types of models, including those loaded from local
+ files, Ultralytics HUB, or Triton Server.

  Attributes:
  callbacks (dict): A dictionary of callback functions for various events during model operations.
@@ -48,25 +49,25 @@ class Model(torch.nn.Module):

  Methods:
  __call__: Alias for the predict method, enabling the model instance to be callable.
- _new: Initializes a new model based on a configuration file.
- _load: Loads a model from a checkpoint file.
- _check_is_pytorch_model: Ensures that the model is a PyTorch model.
- reset_weights: Resets the model's weights to their initial state.
- load: Loads model weights from a specified file.
- save: Saves the current state of the model to a file.
- info: Logs or returns information about the model.
- fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference.
- predict: Performs object detection predictions.
- track: Performs object tracking.
- val: Validates the model on a dataset.
- benchmark: Benchmarks the model on various export formats.
- export: Exports the model to different formats.
- train: Trains the model on a dataset.
- tune: Performs hyperparameter tuning.
- _apply: Applies a function to the model's tensors.
- add_callback: Adds a callback function for an event.
- clear_callback: Clears all callbacks for an event.
- reset_callbacks: Resets all callbacks to their default functions.
+ _new: Initialize a new model based on a configuration file.
+ _load: Load a model from a checkpoint file.
+ _check_is_pytorch_model: Ensure that the model is a PyTorch model.
+ reset_weights: Reset the model's weights to their initial state.
+ load: Load model weights from a specified file.
+ save: Save the current state of the model to a file.
+ info: Log or return information about the model.
+ fuse: Fuse Conv2d and BatchNorm2d layers for optimized inference.
+ predict: Perform object detection predictions.
+ track: Perform object tracking.
+ val: Validate the model on a dataset.
+ benchmark: Benchmark the model on various export formats.
+ export: Export the model to different formats.
+ train: Train the model on a dataset.
+ tune: Perform hyperparameter tuning.
+ _apply: Apply a function to the model's tensors.
+ add_callback: Add a callback function for an event.
+ clear_callback: Clear all callbacks for an event.
+ reset_callbacks: Reset all callbacks to their default functions.

  Examples:
  >>> from ultralytics import YOLO
@@ -79,24 +80,21 @@ class Model(torch.nn.Module):

  def __init__(
  self,
- model: Union[str, Path] = "yolo11n.pt",
- task: str = None,
+ model: str | Path | Model = "yolo11n.pt",
+ task: str | None = None,
  verbose: bool = False,
  ) -> None:
- """
- Initialize a new instance of the YOLO model class.
+ """Initialize a new instance of the YOLO model class.

- This constructor sets up the model based on the provided model path or name. It handles various types of
- model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
- initializes several important attributes of the model and prepares it for operations like training,
- prediction, or export.
+ This constructor sets up the model based on the provided model path or name. It handles various types of model
+ sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
+ important attributes of the model and prepares it for operations like training, prediction, or export.

  Args:
- model (str | Path): Path or name of the model to load or create. Can be a local file path, a
- model name from Ultralytics HUB, or a Triton Server model.
- task (str | None): The task type associated with the YOLO model, specifying its application domain.
- verbose (bool): If True, enables verbose output during the model's initialization and subsequent
- operations.
+ model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a model
+ name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
+ task (str, optional): The specific task for the model. If None, it will be inferred from the config.
+ verbose (bool): If True, enables verbose output during the model's initialization and subsequent operations.

  Raises:
  FileNotFoundError: If the specified model file does not exist or is inaccessible.
@@ -108,6 +106,9 @@ class Model(torch.nn.Module):
  >>> model = Model("path/to/model.yaml", task="detect")
  >>> model = Model("hub_model", verbose=True)
  """
+ if isinstance(model, Model):
+ self.__dict__ = model.__dict__ # accepts an already initialized Model
+ return
  super().__init__()
  self.callbacks = callbacks.get_default_callbacks()
  self.predictor = None # reuse predictor
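With the new `isinstance(model, Model)` branch, the constructor can be handed an already built Model and simply adopts its `__dict__` instead of loading anything. A hedged illustration, assuming the base class is still constructible directly from a .pt path as in earlier releases:

    from ultralytics.engine.model import Model

    base = Model("yolo11n.pt")
    alias = Model(base)            # copies base.__dict__; no second checkpoint load
    assert alias.ckpt is base.ckpt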
@@ -152,26 +153,25 @@ class Model(torch.nn.Module):

  def __call__(
  self,
- source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
+ source: str | Path | int | Image.Image | list | tuple | np.ndarray | torch.Tensor = None,
  stream: bool = False,
  **kwargs: Any,
  ) -> list:
- """
- Alias for the predict method, enabling the model instance to be callable for predictions.
+ """Alias for the predict method, enabling the model instance to be callable for predictions.

- This method simplifies the process of making predictions by allowing the model instance to be called
- directly with the required arguments.
+ This method simplifies the process of making predictions by allowing the model instance to be called directly
+ with the required arguments.

  Args:
- source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
- the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
- tensor, or a list/tuple of these.
+ source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
+ to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch tensor, or a list/tuple
+ of these.
  stream (bool): If True, treat the input source as a continuous stream for predictions.
  **kwargs (Any): Additional keyword arguments to configure the prediction process.

  Returns:
- (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
- Results object.
+ (list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
+ object.

  Examples:
  >>> model = YOLO("yolo11n.pt")
@@ -183,11 +183,10 @@ class Model(torch.nn.Module):

  @staticmethod
  def is_triton_model(model: str) -> bool:
- """
- Check if the given model string is a Triton Server URL.
+ """Check if the given model string is a Triton Server URL.

- This static method determines whether the provided model string represents a valid Triton Server URL by
- parsing its components using urllib.parse.urlsplit().
+ This static method determines whether the provided model string represents a valid Triton Server URL by parsing
+ its components using urllib.parse.urlsplit().

  Args:
  model (str): The model string to be checked.
@@ -208,8 +207,7 @@ class Model(torch.nn.Module):

  @staticmethod
  def is_hub_model(model: str) -> bool:
- """
- Check if the provided model is an Ultralytics HUB model.
+ """Check if the provided model is an Ultralytics HUB model.

  This static method determines whether the given model string represents a valid Ultralytics HUB model
  identifier.
@@ -231,16 +229,15 @@ class Model(torch.nn.Module):
  return model.startswith(f"{HUB_WEB_ROOT}/models/")

  def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
- """
- Initialize a new model and infer the task type from model definitions.
+ """Initialize a new model and infer the task type from model definitions.

- Creates a new model instance based on the provided configuration file. Loads the model configuration, infers
- the task type if not specified, and initializes the model using the appropriate class from the task map.
+ Creates a new model instance based on the provided configuration file. Loads the model configuration, infers the
+ task type if not specified, and initializes the model using the appropriate class from the task map.

  Args:
  cfg (str): Path to the model configuration file in YAML format.
- task (str | None): The specific task for the model. If None, it will be inferred from the config.
- model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
+ task (str, optional): The specific task for the model. If None, it will be inferred from the config.
+ model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of creating
  a new one.
  verbose (bool): If True, displays model information during loading.

@@ -265,15 +262,14 @@ class Model(torch.nn.Module):
  self.model_name = cfg

  def _load(self, weights: str, task=None) -> None:
- """
- Load a model from a checkpoint file or initialize it from a weights file.
+ """Load a model from a checkpoint file or initialize it from a weights file.

- This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
- up the model, task, and related attributes based on the loaded weights.
+ This method handles loading models from either .pt checkpoint files or other weight file formats. It sets up the
+ model, task, and related attributes based on the loaded weights.

  Args:
  weights (str): Path to the model weights file to be loaded.
- task (str | None): The task associated with the model. If None, it will be inferred from the model.
+ task (str, optional): The task associated with the model. If None, it will be inferred from the model.

  Raises:
  FileNotFoundError: If the specified weights file does not exist or is inaccessible.
@@ -288,9 +284,9 @@ class Model(torch.nn.Module):
  weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file
  weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolo11n -> yolo11n.pt

- if Path(weights).suffix == ".pt":
- self.model, self.ckpt = attempt_load_one_weight(weights)
- self.task = self.model.args["task"]
+ if str(weights).rpartition(".")[-1] == "pt":
+ self.model, self.ckpt = load_checkpoint(weights)
+ self.task = self.model.task
  self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
  self.ckpt_path = self.model.pt_path
  else:
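At the public API level, a .pt file is now routed through `load_checkpoint`, and the task comes from the loaded module's own `task` attribute rather than from its args dict. A hedged sketch of what that looks like from the outside:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")   # goes through Model._load -> load_checkpoint
    print(model.task)            # e.g. "detect", read from the checkpointed model
    print(model.ckpt_path)       # recorded from model.pt_path as before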
@@ -303,11 +299,10 @@ class Model(torch.nn.Module):
  self.model_name = weights

  def _check_is_pytorch_model(self) -> None:
- """
- Check if the model is a PyTorch model and raise TypeError if it's not.
+ """Check if the model is a PyTorch model and raise TypeError if it's not.

- This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
- certain operations that require a PyTorch model are only performed on compatible model types.
+ This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that certain
+ operations that require a PyTorch model are only performed on compatible model types.

  Raises:
  TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
@@ -319,7 +314,7 @@ class Model(torch.nn.Module):
  >>> model = Model("yolo11n.onnx")
  >>> model._check_is_pytorch_model() # Raises TypeError
  """
- pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
+ pt_str = isinstance(self.model, (str, Path)) and str(self.model).rpartition(".")[-1] == "pt"
  pt_module = isinstance(self.model, torch.nn.Module)
  if not (pt_module or pt_str):
  raise TypeError(
@@ -330,13 +325,12 @@ class Model(torch.nn.Module):
  f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
  )

- def reset_weights(self) -> "Model":
- """
- Reset the model's weights to their initial state.
+ def reset_weights(self) -> Model:
+ """Reset the model's weights to their initial state.

  This method iterates through all modules in the model and resets their parameters if they have a
- 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
- enabling them to be updated during training.
+ 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
+ to be updated during training.

  Returns:
  (Model): The instance of the class with reset weights.
@@ -356,15 +350,14 @@ class Model(torch.nn.Module):
  p.requires_grad = True
  return self

- def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model":
- """
- Load parameters from the specified weights file into the model.
+ def load(self, weights: str | Path = "yolo11n.pt") -> Model:
+ """Load parameters from the specified weights file into the model.

  This method supports loading weights from a file or directly from a weights object. It matches parameters by
  name and shape and transfers them to the model.

  Args:
- weights (Union[str, Path]): Path to the weights file or a weights object.
+ weights (str | Path): Path to the weights file or a weights object.

  Returns:
  (Model): The instance of the class with loaded weights.
@@ -380,16 +373,15 @@ class Model(torch.nn.Module):
  self._check_is_pytorch_model()
  if isinstance(weights, (str, Path)):
  self.overrides["pretrained"] = weights # remember the weights for DDP training
- weights, self.ckpt = attempt_load_one_weight(weights)
+ weights, self.ckpt = load_checkpoint(weights)
  self.model.load(weights)
  return self

- def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
- """
- Save the current model state to a file.
+ def save(self, filename: str | Path = "saved_model.pt") -> None:
+ """Save the current model state to a file.

- This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
- the date, Ultralytics version, license information, and a link to the documentation.
+ This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as the
+ date, Ultralytics version, license information, and a link to the documentation.

  Args:
  filename (str | Path): The name of the file to save the model to.
@@ -417,8 +409,7 @@ class Model(torch.nn.Module):
  torch.save({**self.ckpt, **updates}, filename)

  def info(self, detailed: bool = False, verbose: bool = True):
- """
- Display model information.
+ """Display model information.

  This method provides an overview or detailed information about the model, depending on the arguments
  passed. It can control the verbosity of the output and return the information as a list.
@@ -428,8 +419,8 @@ class Model(torch.nn.Module):
  verbose (bool): If True, prints the information. If False, returns the information as a list.

  Returns:
- (List[str]): A list of strings containing various types of information about the model, including
- model summary, layer details, and parameter counts. Empty if verbose is True.
+ (list[str]): A list of strings containing various types of information about the model, including model
+ summary, layer details, and parameter counts. Empty if verbose is True.

  Examples:
  >>> model = Model("yolo11n.pt")
@@ -440,12 +431,11 @@ class Model(torch.nn.Module):
  return self.model.info(detailed=detailed, verbose=verbose)

  def fuse(self) -> None:
- """
- Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
+ """Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.

- This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
- into a single layer. This fusion can significantly improve inference speed by reducing the number of
- operations and memory accesses required during forward passes.
+ This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers into a
+ single layer. This fusion can significantly improve inference speed by reducing the number of operations and
+ memory accesses required during forward passes.

  The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
  bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
@@ -461,24 +451,23 @@ class Model(torch.nn.Module):

  def embed(
  self,
- source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
+ source: str | Path | int | list | tuple | np.ndarray | torch.Tensor = None,
  stream: bool = False,
  **kwargs: Any,
  ) -> list:
- """
- Generate image embeddings based on the provided source.
+ """Generate image embeddings based on the provided source.

  This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
  source. It allows customization of the embedding process through various keyword arguments.

  Args:
- source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
- generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
+ source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for generating
+ embeddings. Can be a file path, URL, PIL image, numpy array, etc.
  stream (bool): If True, predictions are streamed.
  **kwargs (Any): Additional keyword arguments for configuring the embedding process.

  Returns:
- (List[torch.Tensor]): A list containing the image embeddings.
+ (list[torch.Tensor]): A list containing the image embeddings.

  Examples:
  >>> model = YOLO("yolo11n.pt")
@@ -492,30 +481,29 @@ class Model(torch.nn.Module):

  def predict(
  self,
- source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
+ source: str | Path | int | Image.Image | list | tuple | np.ndarray | torch.Tensor = None,
  stream: bool = False,
  predictor=None,
  **kwargs: Any,
- ) -> List[Results]:
- """
- Performs predictions on the given image source using the YOLO model.
+ ) -> list[Results]:
+ """Perform predictions on the given image source using the YOLO model.

- This method facilitates the prediction process, allowing various configurations through keyword arguments.
- It supports predictions with custom predictors or the default predictor method. The method handles different
- types of image sources and can operate in a streaming mode.
+ This method facilitates the prediction process, allowing various configurations through keyword arguments. It
+ supports predictions with custom predictors or the default predictor method. The method handles different types
+ of image sources and can operate in a streaming mode.

  Args:
- source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
- of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
- images, numpy arrays, and torch tensors.
+ source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
+ to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+ torch tensors.
  stream (bool): If True, treats the input source as a continuous stream for predictions.
- predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
- If None, the method uses a default predictor.
+ predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions. If
+ None, the method uses a default predictor.
  **kwargs (Any): Additional keyword arguments for configuring the prediction process.

  Returns:
- (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
- Results object.
+ (list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
+ object.

  Examples:
  >>> model = YOLO("yolo11n.pt")
@@ -553,27 +541,26 @@ class Model(torch.nn.Module):

  def track(
  self,
- source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
+ source: str | Path | int | list | tuple | np.ndarray | torch.Tensor = None,
  stream: bool = False,
  persist: bool = False,
  **kwargs: Any,
- ) -> List[Results]:
- """
- Conducts object tracking on the specified input source using the registered trackers.
+ ) -> list[Results]:
+ """Conduct object tracking on the specified input source using the registered trackers.

  This method performs object tracking using the model's predictors and optionally registered trackers. It handles
  various input sources such as file paths or video streams, and supports customization through keyword arguments.
  The method registers trackers if not already present and can persist them between calls.

  Args:
- source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
+ source (str | Path | int | list | tuple | np.ndarray | torch.Tensor, optional): Input source for object
  tracking. Can be a file path, URL, or video stream.
  stream (bool): If True, treats the input source as a continuous video stream.
  persist (bool): If True, persists trackers between different calls to this method.
  **kwargs (Any): Additional keyword arguments for configuring the tracking process.

  Returns:
- (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
+ (list[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.

  Examples:
  >>> model = YOLO("yolo11n.pt")
@@ -600,16 +587,15 @@ class Model(torch.nn.Module):
  validator=None,
  **kwargs: Any,
  ):
- """
- Validate the model using a specified dataset and validation configuration.
+ """Validate the model using a specified dataset and validation configuration.

  This method facilitates the model validation process, allowing for customization through various settings. It
  supports validation with a custom validator or the default validation approach. The method combines default
  configurations, method-specific defaults, and user-provided arguments to configure the validation process.

  Args:
- validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
- validating the model.
+ validator (ultralytics.engine.validator.BaseValidator, optional): An instance of a custom validator class
+ for validating the model.
  **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.

  Returns:
@@ -631,31 +617,27 @@ class Model(torch.nn.Module):
  self.metrics = validator.metrics
  return validator.metrics

- def benchmark(
- self,
- **kwargs: Any,
- ):
- """
- Benchmark the model across various export formats to evaluate performance.
+ def benchmark(self, data=None, format="", verbose=False, **kwargs: Any):
+ """Benchmark the model across various export formats to evaluate performance.

- This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
- It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
- configured using a combination of default configuration values, model-specific arguments, method-specific
- defaults, and any additional user-provided keyword arguments.
+ This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. It
+ uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured using
+ a combination of default configuration values, model-specific arguments, method-specific defaults, and any
+ additional user-provided keyword arguments.

  Args:
+ data (str): Path to the dataset for benchmarking.
+ verbose (bool): Whether to print detailed benchmark information.
+ format (str): Export format name for specific benchmarking.
  **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Common options include:
- - data (str): Path to the dataset for benchmarking.
- - imgsz (int | List[int]): Image size for benchmarking.
+ - imgsz (int | list[int]): Image size for benchmarking.
  - half (bool): Whether to use half-precision (FP16) mode.
  - int8 (bool): Whether to use int8 precision mode.
  - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
- - verbose (bool): Whether to print detailed benchmark information.
- - format (str): Export format name for specific benchmarking.

  Returns:
- (dict): A dictionary containing the results of the benchmarking process, including metrics for
- different export formats.
+ (dict): A dictionary containing the results of the benchmarking process, including metrics for different
+ export formats.

  Raises:
  AssertionError: If the model is not a PyTorch model.
@@ -668,40 +650,42 @@ class Model(torch.nn.Module):
  self._check_is_pytorch_model()
  from ultralytics.utils.benchmarks import benchmark

+ from .exporter import export_formats
+
  custom = {"verbose": False} # method defaults
  args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
+ fmts = export_formats()
+ export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, [])) - {"batch"}
+ export_kwargs = {k: v for k, v in args.items() if k in export_args}
  return benchmark(
  model=self,
- data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets
+ data=data, # if no 'data' argument passed set data=None for default datasets
  imgsz=args["imgsz"],
- half=args["half"],
- int8=args["int8"],
  device=args["device"],
- verbose=kwargs.get("verbose", False),
- format=kwargs.get("format", ""),
+ verbose=verbose,
+ format=format,
+ **export_kwargs,
  )

  def export(
  self,
  **kwargs: Any,
  ) -> str:
- """
- Export the model to a different format suitable for deployment.
+ """Export the model to a different format suitable for deployment.

  This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
  purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
  defaults, and any additional arguments provided.

  Args:
- **kwargs (Any): Arbitrary keyword arguments to customize the export process. These are combined with
- the model's overrides and method defaults. Common arguments include:
- format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
- half (bool): Export model in half-precision.
- int8 (bool): Export model in int8 precision.
- device (str): Device to run the export on.
- workspace (int): Maximum memory workspace size for TensorRT engines.
- nms (bool): Add Non-Maximum Suppression (NMS) module to model.
- simplify (bool): Simplify ONNX model.
+ **kwargs (Any): Arbitrary keyword arguments for export configuration. Common options include:
+ - format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
+ - half (bool): Export model in half-precision.
+ - int8 (bool): Export model in int8 precision.
+ - device (str): Device to run the export on.
+ - workspace (int): Maximum memory workspace size for TensorRT engines.
+ - nms (bool): Add Non-Maximum Suppression (NMS) module to model.
+ - simplify (bool): Simplify ONNX model.

  Returns:
  (str): The path to the exported model file.
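`benchmark()` now takes `data`, `format` and `verbose` as explicit parameters and, via `export_formats()`, forwards only those export arguments (minus `batch`) that the chosen format accepts. A hedged call sketch; the dataset name is just an example:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    # half=True is passed through to the exporter only if the "onnx" format lists it.
    results = model.benchmark(data="coco8.yaml", format="onnx", half=True, verbose=True)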
@@ -734,32 +718,31 @@ class Model(torch.nn.Module):
  trainer=None,
  **kwargs: Any,
  ):
- """
- Trains the model using the specified dataset and training configuration.
+ """Train the model using the specified dataset and training configuration.

- This method facilitates model training with a range of customizable settings. It supports training with a
- custom trainer or the default training approach. The method handles scenarios such as resuming training
- from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
+ This method facilitates model training with a range of customizable settings. It supports training with a custom
+ trainer or the default training approach. The method handles scenarios such as resuming training from a
+ checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.

- When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
- arguments and warns if local arguments are provided. It checks for pip updates and combines default
- configurations, method-specific defaults, and user-provided arguments to configure the training process.
+ When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training arguments and
+ warns if local arguments are provided. It checks for pip updates and combines default configurations,
+ method-specific defaults, and user-provided arguments to configure the training process.

  Args:
- trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
+ trainer (BaseTrainer, optional): Custom trainer instance for model training. If None, uses default.
  **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
- data (str): Path to dataset configuration file.
- epochs (int): Number of training epochs.
- batch_size (int): Batch size for training.
- imgsz (int): Input image size.
- device (str): Device to run training on (e.g., 'cuda', 'cpu').
- workers (int): Number of worker threads for data loading.
- optimizer (str): Optimizer to use for training.
- lr0 (float): Initial learning rate.
- patience (int): Epochs to wait for no observable improvement for early stopping of training.
+ - data (str): Path to dataset configuration file.
+ - epochs (int): Number of training epochs.
+ - batch (int): Batch size for training.
+ - imgsz (int): Input image size.
+ - device (str): Device to run training on (e.g., 'cuda', 'cpu').
+ - workers (int): Number of worker threads for data loading.
+ - optimizer (str): Optimizer to use for training.
+ - lr0 (float): Initial learning rate.
+ - patience (int): Epochs to wait for no observable improvement for early stopping of training.

  Returns:
- (Dict | None): Training metrics if available and training is successful; otherwise, None.
+ (dict | None): Training metrics if available and training is successful; otherwise, None.

  Examples:
  >>> model = YOLO("yolo11n.pt")
@@ -773,6 +756,8 @@ class Model(torch.nn.Module):

  checks.check_pip_update_available()

+ if isinstance(kwargs.get("pretrained", None), (str, Path)):
+ self.load(kwargs["pretrained"]) # load pretrained weights if provided
  overrides = YAML.load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
  custom = {
  # NOTE: handle the case when 'cfg' includes 'data'.
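`train()` now loads weights itself when `pretrained` is given as a string or Path, before the trainer is built. A hedged sketch of the intended flow:

    from ultralytics import YOLO

    model = YOLO("yolo11n.yaml")  # architecture only, no weights yet
    # The new isinstance(kwargs.get("pretrained"), (str, Path)) branch calls
    # self.load("yolo11n.pt") before training starts.
    model.train(data="coco8.yaml", epochs=1, pretrained="yolo11n.pt")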
@@ -780,7 +765,7 @@ class Model(torch.nn.Module):
  "model": self.overrides["model"],
  "task": self.task,
  } # method defaults
- args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
+ args = {**overrides, **custom, **kwargs, "mode": "train", "session": self.session} # prioritizes rightmost args
  if args.get("resume"):
  args["resume"] = self.ckpt_path

@@ -789,13 +774,12 @@ class Model(torch.nn.Module):
  self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
  self.model = self.trainer.model

- self.trainer.hub_session = self.session # attach optional HUB session
  self.trainer.train()
  # Update model and cfg after training
  if RANK in {-1, 0}:
  ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
- self.model, self.ckpt = attempt_load_one_weight(ckpt)
- self.overrides = self.model.args
+ self.model, self.ckpt = load_checkpoint(ckpt)
+ self.overrides = self._reset_ckpt_args(self.model.args)
  self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
  return self.metrics

@@ -806,13 +790,12 @@ class Model(torch.nn.Module):
  *args: Any,
  **kwargs: Any,
  ):
- """
- Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
+ """Conduct hyperparameter tuning for the model, with an option to use Ray Tune.

- This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
- When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module.
- Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and
- custom arguments to configure the tuning process.
+ This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. When Ray Tune
+ is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. Otherwise, it uses
+ the internal 'Tuner' class for tuning. The method combines default, overridden, and custom arguments to
+ configure the tuning process.

  Args:
  use_ray (bool): Whether to use Ray Tune for hyperparameter tuning. If False, uses internal tuning method.
@@ -847,17 +830,16 @@ class Model(torch.nn.Module):
  args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
  return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)

- def _apply(self, fn) -> "Model":
- """
- Apply a function to model tensors that are not parameters or registered buffers.
+ def _apply(self, fn) -> Model:
+ """Apply a function to model tensors that are not parameters or registered buffers.

  This method extends the functionality of the parent class's _apply method by additionally resetting the
- predictor and updating the device in the model's overrides. It's typically used for operations like
- moving the model to a different device or changing its precision.
+ predictor and updating the device in the model's overrides. It's typically used for operations like moving the
+ model to a different device or changing its precision.

  Args:
- fn (Callable): A function to be applied to the model's tensors. This is typically a method like
- to(), cpu(), cuda(), half(), or float().
+ fn (Callable): A function to be applied to the model's tensors. This is typically a method like to(), cpu(),
+ cuda(), half(), or float().

  Returns:
  (Model): The model instance with the function applied and updated attributes.
@@ -870,22 +852,21 @@ class Model(torch.nn.Module):
  >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU
  """
  self._check_is_pytorch_model()
- self = super()._apply(fn) # noqa
+ self = super()._apply(fn)
  self.predictor = None # reset predictor as device may have changed
  self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
  return self

  @property
- def names(self) -> Dict[int, str]:
- """
- Retrieves the class names associated with the loaded model.
+ def names(self) -> dict[int, str]:
+ """Retrieve the class names associated with the loaded model.

  This property returns the class names if they are defined in the model. It checks the class names for validity
  using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
  initialized, it sets it up before retrieving the names.

  Returns:
- (Dict[int, str]): A dictionary of class names associated with the model, where keys are class indices and
+ (dict[int, str]): A dictionary of class names associated with the model, where keys are class indices and
  values are the corresponding class names.

  Raises:
@@ -901,14 +882,14 @@ class Model(torch.nn.Module):
  if hasattr(self.model, "names"):
  return check_class_names(self.model.names)
  if not self.predictor: # export formats will not have predictor defined until predict() is called
- self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
- self.predictor.setup_model(model=self.model, verbose=False)
+ predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
+ predictor.setup_model(model=self.model, verbose=False) # do not mess with self.predictor.model args
+ return predictor.model.names
  return self.predictor.model.names

  @property
  def device(self) -> torch.device:
- """
- Get the device on which the model's parameters are allocated.
+ """Get the device on which the model's parameters are allocated.

  This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
  applicable only to models that are instances of torch.nn.Module.
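For exported (non-PyTorch) weights, reading `names` now builds a throwaway predictor instead of storing it on `self.predictor`, so a later `predict()` call still sets up its own predictor from the caller's arguments. A hedged illustration, assuming a yolo11n.onnx export already exists locally:

    from ultralytics import YOLO

    model = YOLO("yolo11n.onnx")
    print(model.names[0])              # temporary predictor used only to read names
    assert model.predictor is None     # predict() below still builds its own predictor
    model.predict("https://ultralytics.com/images/bus.jpg", imgsz=320)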
@@ -931,12 +912,11 @@ class Model(torch.nn.Module):

  @property
  def transforms(self):
- """
- Retrieves the transformations applied to the input data of the loaded model.
+ """Retrieve the transformations applied to the input data of the loaded model.

- This property returns the transformations if they are defined in the model. The transforms
- typically include preprocessing steps like resizing, normalization, and data augmentation
- that are applied to input data before it is fed into the model.
+ This property returns the transformations if they are defined in the model. The transforms typically include
+ preprocessing steps like resizing, normalization, and data augmentation that are applied to input data before it
+ is fed into the model.

  Returns:
  (object | None): The transform object of the model if available, otherwise None.
@@ -952,18 +932,17 @@ class Model(torch.nn.Module):
  return self.model.transforms if hasattr(self.model, "transforms") else None

  def add_callback(self, event: str, func) -> None:
- """
- Add a callback function for a specified event.
+ """Add a callback function for a specified event.

- This method allows registering custom callback functions that are triggered on specific events during
- model operations such as training or inference. Callbacks provide a way to extend and customize the
- behavior of the model at various stages of its lifecycle.
+ This method allows registering custom callback functions that are triggered on specific events during model
+ operations such as training or inference. Callbacks provide a way to extend and customize the behavior of the
+ model at various stages of its lifecycle.

  Args:
- event (str): The name of the event to attach the callback to. Must be a valid event name recognized
- by the Ultralytics framework.
- func (Callable): The callback function to be registered. This function will be called when the
- specified event occurs.
+ event (str): The name of the event to attach the callback to. Must be a valid event name recognized by the
+ Ultralytics framework.
+ func (Callable): The callback function to be registered. This function will be called when the specified
+ event occurs.

  Raises:
  ValueError: If the event name is not recognized or is invalid.
@@ -978,12 +957,11 @@ class Model(torch.nn.Module):
  self.callbacks[event].append(func)

  def clear_callback(self, event: str) -> None:
- """
- Clears all callback functions registered for a specified event.
+ """Clear all callback functions registered for a specified event.

- This method removes all custom and default callback functions associated with the given event.
- It resets the callback list for the specified event to an empty list, effectively removing all
- registered callbacks for that event.
+ This method removes all custom and default callback functions associated with the given event. It resets the
+ callback list for the specified event to an empty list, effectively removing all registered callbacks for that
+ event.

  Args:
  event (str): The name of the event for which to clear the callbacks. This should be a valid event name
@@ -1006,8 +984,7 @@ class Model(torch.nn.Module):
  self.callbacks[event] = []

  def reset_callbacks(self) -> None:
- """
- Reset all callbacks to their default functions.
+ """Reset all callbacks to their default functions.

  This method reinstates the default callback functions for all events, removing any custom callbacks that were
  previously added. It iterates through all default callback events and replaces the current callbacks with the
@@ -1029,13 +1006,12 @@ class Model(torch.nn.Module):
  self.callbacks[event] = [callbacks.default_callbacks[event][0]]

  @staticmethod
- def _reset_ckpt_args(args: dict) -> dict:
- """
- Reset specific arguments when loading a PyTorch model checkpoint.
+ def _reset_ckpt_args(args: dict[str, Any]) -> dict[str, Any]:
+ """Reset specific arguments when loading a PyTorch model checkpoint.

- This method filters the input arguments dictionary to retain only a specific set of keys that are
- considered important for model loading. It's used to ensure that only relevant arguments are preserved
- when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
+ This method filters the input arguments dictionary to retain only a specific set of keys that are considered
+ important for model loading. It's used to ensure that only relevant arguments are preserved when loading a model
+ from a checkpoint, discarding any unnecessary or potentially conflicting settings.

  Args:
  args (dict): A dictionary containing various model arguments and settings.
@@ -1058,12 +1034,11 @@ class Model(torch.nn.Module):
  # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

  def _smart_load(self, key: str):
- """
- Intelligently loads the appropriate module based on the model task.
+ """Intelligently load the appropriate module based on the model task.

- This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
- based on the current task of the model and the provided key. It uses the task_map dictionary to determine
- the appropriate module to load for the specific task.
+ This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) based
+ on the current task of the model and the provided key. It uses the task_map dictionary to determine the
+ appropriate module to load for the specific task.

  Args:
  key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
@@ -1088,21 +1063,20 @@ class Model(torch.nn.Module):

  @property
  def task_map(self) -> dict:
- """
- Provides a mapping from model tasks to corresponding classes for different modes.
+ """Provide a mapping from model tasks to corresponding classes for different modes.

- This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
- to a nested dictionary. The nested dictionary contains mappings for different operational modes
- (model, trainer, validator, predictor) to their respective class implementations.
+ This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) to a
+ nested dictionary. The nested dictionary contains mappings for different operational modes (model, trainer,
+ validator, predictor) to their respective class implementations.

- The mapping allows for dynamic loading of appropriate classes based on the model's task and the
- desired operational mode. This facilitates a flexible and extensible architecture for handling
- various tasks and modes within the Ultralytics framework.
+ The mapping allows for dynamic loading of appropriate classes based on the model's task and the desired
+ operational mode. This facilitates a flexible and extensible architecture for handling various tasks and modes
+ within the Ultralytics framework.

  Returns:
- (Dict[str, Dict[str, Any]]): A dictionary mapping task names to nested dictionaries. Each nested dictionary
- contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective class
- implementations for that task.
+ (dict[str, dict[str, Any]]): A dictionary mapping task names to nested dictionaries. Each nested dictionary
+ contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective class
+ implementations for that task.

  Examples:
  >>> model = Model("yolo11n.pt")
@@ -1113,8 +1087,7 @@ class Model(torch.nn.Module):
  raise NotImplementedError("Please provide task map for your model!")

  def eval(self):
- """
- Sets the model to evaluation mode.
+ """Sets the model to evaluation mode.

  This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
  that behave differently during training and evaluation. In evaluation mode, these layers use running statistics
@@ -1132,8 +1105,7 @@ class Model(torch.nn.Module):
  return self

  def __getattr__(self, name):
- """
- Enable accessing model attributes directly through the Model class.
+ """Enable accessing model attributes directly through the Model class.

  This method provides a way to access attributes of the underlying model directly through the Model class
  instance. It first checks if the requested attribute is 'model', in which case it returns the model from