ultralytics 8.2.61__py3-none-any.whl → 8.2.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +154 -103
- ultralytics/data/annotator.py +16 -12
- ultralytics/data/augment.py +1478 -195
- ultralytics/data/explorer/gui/dash.py +41 -26
- ultralytics/data/loaders.py +1 -1
- ultralytics/engine/model.py +483 -176
- ultralytics/engine/results.py +1035 -256
- ultralytics/models/fastsam/predict.py +1 -3
- ultralytics/models/nas/predict.py +1 -3
- ultralytics/models/rtdetr/predict.py +4 -6
- ultralytics/models/sam/predict.py +1 -3
- ultralytics/solutions/streamlit_inference.py +5 -2
- {ultralytics-8.2.61.dist-info → ultralytics-8.2.62.dist-info}/METADATA +1 -1
- {ultralytics-8.2.61.dist-info → ultralytics-8.2.62.dist-info}/RECORD +19 -19
- {ultralytics-8.2.61.dist-info → ultralytics-8.2.62.dist-info}/WHEEL +1 -1
- {ultralytics-8.2.61.dist-info → ultralytics-8.2.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.61.dist-info → ultralytics-8.2.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.61.dist-info → ultralytics-8.2.62.dist-info}/top_level.txt +0 -0
ultralytics/engine/model.py
CHANGED
|
@@ -30,26 +30,18 @@ class Model(nn.Module):
|
|
|
30
30
|
|
|
31
31
|
This class provides a common interface for various operations related to YOLO models, such as training,
|
|
32
32
|
validation, prediction, exporting, and benchmarking. It handles different types of models, including those
|
|
33
|
-
loaded from local files, Ultralytics HUB, or Triton Server.
|
|
34
|
-
extendable for different tasks and model configurations.
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file
|
|
38
|
-
path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
|
|
39
|
-
task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's
|
|
40
|
-
application domain, such as object detection, segmentation, etc. Defaults to None.
|
|
41
|
-
verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False.
|
|
33
|
+
loaded from local files, Ultralytics HUB, or Triton Server.
|
|
42
34
|
|
|
43
35
|
Attributes:
|
|
44
|
-
callbacks (
|
|
36
|
+
callbacks (Dict): A dictionary of callback functions for various events during model operations.
|
|
45
37
|
predictor (BasePredictor): The predictor object used for making predictions.
|
|
46
38
|
model (nn.Module): The underlying PyTorch model.
|
|
47
39
|
trainer (BaseTrainer): The trainer object used for training the model.
|
|
48
|
-
ckpt (
|
|
40
|
+
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
|
|
49
41
|
cfg (str): The configuration of the model if loaded from a *.yaml file.
|
|
50
42
|
ckpt_path (str): The path to the checkpoint file.
|
|
51
|
-
overrides (
|
|
52
|
-
metrics (
|
|
43
|
+
overrides (Dict): A dictionary of overrides for model configuration.
|
|
44
|
+
metrics (Dict): The latest training/validation metrics.
|
|
53
45
|
session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
|
|
54
46
|
task (str): The type of task the model is intended for.
|
|
55
47
|
model_name (str): The name of the model.
|
|
@@ -75,19 +67,14 @@ class Model(nn.Module):
|
|
|
75
67
|
add_callback: Adds a callback function for an event.
|
|
76
68
|
clear_callback: Clears all callbacks for an event.
|
|
77
69
|
reset_callbacks: Resets all callbacks to their default functions.
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
ValueError: If the model file or configuration is invalid or unsupported.
|
|
87
|
-
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
|
|
88
|
-
TypeError: If the model is not a PyTorch model when required.
|
|
89
|
-
AttributeError: If required attributes or methods are not implemented or available.
|
|
90
|
-
NotImplementedError: If a specific model task or mode is not supported.
|
|
70
|
+
|
|
71
|
+
Examples:
|
|
72
|
+
>>> from ultralytics import YOLO
|
|
73
|
+
>>> model = YOLO('yolov8n.pt')
|
|
74
|
+
>>> results = model.predict('image.jpg')
|
|
75
|
+
>>> model.train(data='coco128.yaml', epochs=3)
|
|
76
|
+
>>> metrics = model.val()
|
|
77
|
+
>>> model.export(format='onnx')
|
|
91
78
|
"""
|
|
92
79
|
|
|
93
80
|
def __init__(
|
|
@@ -99,22 +86,27 @@ class Model(nn.Module):
|
|
|
99
86
|
"""
|
|
100
87
|
Initializes a new instance of the YOLO model class.
|
|
101
88
|
|
|
102
|
-
This constructor sets up the model based on the provided model path or name. It handles various types of
|
|
103
|
-
sources, including local files, Ultralytics HUB models, and Triton Server models. The method
|
|
104
|
-
important attributes of the model and prepares it for operations like training,
|
|
89
|
+
This constructor sets up the model based on the provided model path or name. It handles various types of
|
|
90
|
+
model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
|
|
91
|
+
initializes several important attributes of the model and prepares it for operations like training,
|
|
92
|
+
prediction, or export.
|
|
105
93
|
|
|
106
94
|
Args:
|
|
107
|
-
model (Union[str, Path]
|
|
108
|
-
|
|
109
|
-
task (
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
operations. Defaults to False.
|
|
95
|
+
model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
|
|
96
|
+
model name from Ultralytics HUB, or a Triton Server model.
|
|
97
|
+
task (str | None): The task type associated with the YOLO model, specifying its application domain.
|
|
98
|
+
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
|
|
99
|
+
operations.
|
|
113
100
|
|
|
114
101
|
Raises:
|
|
115
102
|
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
|
116
103
|
ValueError: If the model file or configuration is invalid or unsupported.
|
|
117
104
|
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
|
|
105
|
+
|
|
106
|
+
Examples:
|
|
107
|
+
>>> model = Model("yolov8n.pt")
|
|
108
|
+
>>> model = Model("path/to/model.yaml", task="detect")
|
|
109
|
+
>>> model = Model("hub_model", verbose=True)
|
|
118
110
|
"""
|
|
119
111
|
super().__init__()
|
|
120
112
|
self.callbacks = callbacks.get_default_callbacks()
|
|
@@ -155,27 +147,50 @@ class Model(nn.Module):
|
|
|
155
147
|
**kwargs,
|
|
156
148
|
) -> list:
|
|
157
149
|
"""
|
|
158
|
-
|
|
150
|
+
Alias for the predict method, enabling the model instance to be callable for predictions.
|
|
159
151
|
|
|
160
|
-
This method simplifies the process of making predictions by allowing the model instance to be called
|
|
161
|
-
with the required arguments
|
|
152
|
+
This method simplifies the process of making predictions by allowing the model instance to be called
|
|
153
|
+
directly with the required arguments.
|
|
162
154
|
|
|
163
155
|
Args:
|
|
164
|
-
source (str | Path | int | PIL.Image | np.ndarray
|
|
165
|
-
predictions.
|
|
166
|
-
|
|
167
|
-
stream (bool
|
|
168
|
-
|
|
169
|
-
**kwargs (any): Additional keyword arguments for configuring the prediction process.
|
|
156
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
|
|
157
|
+
the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
|
|
158
|
+
tensor, or a list/tuple of these.
|
|
159
|
+
stream (bool): If True, treat the input source as a continuous stream for predictions.
|
|
160
|
+
**kwargs (Any): Additional keyword arguments to configure the prediction process.
|
|
170
161
|
|
|
171
162
|
Returns:
|
|
172
|
-
(List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in
|
|
163
|
+
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
|
164
|
+
Results object.
|
|
165
|
+
|
|
166
|
+
Examples:
|
|
167
|
+
>>> model = YOLO('yolov8n.pt')
|
|
168
|
+
>>> results = model('https://ultralytics.com/images/bus.jpg')
|
|
169
|
+
>>> for r in results:
|
|
170
|
+
... print(f"Detected {len(r)} objects in image")
|
|
173
171
|
"""
|
|
174
172
|
return self.predict(source, stream, **kwargs)
|
|
175
173
|
|
|
176
174
|
@staticmethod
|
|
177
175
|
def is_triton_model(model: str) -> bool:
|
|
178
|
-
"""
|
|
176
|
+
"""
|
|
177
|
+
Checks if the given model string is a Triton Server URL.
|
|
178
|
+
|
|
179
|
+
This static method determines whether the provided model string represents a valid Triton Server URL by
|
|
180
|
+
parsing its components using urllib.parse.urlsplit().
|
|
181
|
+
|
|
182
|
+
Args:
|
|
183
|
+
model (str): The model string to be checked.
|
|
184
|
+
|
|
185
|
+
Returns:
|
|
186
|
+
(bool): True if the model string is a valid Triton Server URL, False otherwise.
|
|
187
|
+
|
|
188
|
+
Examples:
|
|
189
|
+
>>> Model.is_triton_model('http://localhost:8000/v2/models/yolov8n')
|
|
190
|
+
True
|
|
191
|
+
>>> Model.is_triton_model('yolov8n.pt')
|
|
192
|
+
False
|
|
193
|
+
"""
|
|
179
194
|
from urllib.parse import urlsplit
|
|
180
195
|
|
|
181
196
|
url = urlsplit(model)
|
|
@@ -183,7 +198,30 @@ class Model(nn.Module):
|
|
|
183
198
|
|
|
184
199
|
@staticmethod
|
|
185
200
|
def is_hub_model(model: str) -> bool:
|
|
186
|
-
"""
|
|
201
|
+
"""
|
|
202
|
+
Check if the provided model is an Ultralytics HUB model.
|
|
203
|
+
|
|
204
|
+
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
|
205
|
+
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
|
|
206
|
+
or a standalone model ID.
|
|
207
|
+
|
|
208
|
+
Args:
|
|
209
|
+
model (str): The model identifier to check. This can be a URL, an API key and model ID
|
|
210
|
+
combination, or a standalone model ID.
|
|
211
|
+
|
|
212
|
+
Returns:
|
|
213
|
+
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
|
|
214
|
+
|
|
215
|
+
Examples:
|
|
216
|
+
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
|
|
217
|
+
True
|
|
218
|
+
>>> Model.is_hub_model("api_key_example_model_id")
|
|
219
|
+
True
|
|
220
|
+
>>> Model.is_hub_model("example_model_id")
|
|
221
|
+
True
|
|
222
|
+
>>> Model.is_hub_model("not_a_hub_model.pt")
|
|
223
|
+
False
|
|
224
|
+
"""
|
|
187
225
|
return any(
|
|
188
226
|
(
|
|
189
227
|
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
|
@@ -196,11 +234,24 @@ class Model(nn.Module):
|
|
|
196
234
|
"""
|
|
197
235
|
Initializes a new model and infers the task type from the model definitions.
|
|
198
236
|
|
|
237
|
+
This method creates a new model instance based on the provided configuration file. It loads the model
|
|
238
|
+
configuration, infers the task type if not specified, and initializes the model using the appropriate
|
|
239
|
+
class from the task map.
|
|
240
|
+
|
|
199
241
|
Args:
|
|
200
|
-
cfg (str): model configuration file
|
|
201
|
-
task (str | None): model
|
|
202
|
-
model (
|
|
203
|
-
|
|
242
|
+
cfg (str): Path to the model configuration file in YAML format.
|
|
243
|
+
task (str | None): The specific task for the model. If None, it will be inferred from the config.
|
|
244
|
+
model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
|
|
245
|
+
a new one.
|
|
246
|
+
verbose (bool): If True, displays model information during loading.
|
|
247
|
+
|
|
248
|
+
Raises:
|
|
249
|
+
ValueError: If the configuration file is invalid or the task cannot be inferred.
|
|
250
|
+
ImportError: If the required dependencies for the specified task are not installed.
|
|
251
|
+
|
|
252
|
+
Examples:
|
|
253
|
+
>>> model = Model()
|
|
254
|
+
>>> model._new('yolov8n.yaml', task='detect', verbose=True)
|
|
204
255
|
"""
|
|
205
256
|
cfg_dict = yaml_model_load(cfg)
|
|
206
257
|
self.cfg = cfg
|
|
@@ -216,11 +267,23 @@ class Model(nn.Module):
|
|
|
216
267
|
|
|
217
268
|
def _load(self, weights: str, task=None) -> None:
|
|
218
269
|
"""
|
|
219
|
-
|
|
270
|
+
Loads a model from a checkpoint file or initializes it from a weights file.
|
|
271
|
+
|
|
272
|
+
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
|
|
273
|
+
up the model, task, and related attributes based on the loaded weights.
|
|
220
274
|
|
|
221
275
|
Args:
|
|
222
|
-
weights (str): model
|
|
223
|
-
task (str | None): model
|
|
276
|
+
weights (str): Path to the model weights file to be loaded.
|
|
277
|
+
task (str | None): The task associated with the model. If None, it will be inferred from the model.
|
|
278
|
+
|
|
279
|
+
Raises:
|
|
280
|
+
FileNotFoundError: If the specified weights file does not exist or is inaccessible.
|
|
281
|
+
ValueError: If the weights file format is unsupported or invalid.
|
|
282
|
+
|
|
283
|
+
Examples:
|
|
284
|
+
>>> model = Model()
|
|
285
|
+
>>> model._load('yolov8n.pt')
|
|
286
|
+
>>> model._load('path/to/weights.pth', task='detect')
|
|
224
287
|
"""
|
|
225
288
|
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
|
226
289
|
weights = checks.check_file(weights) # automatically download and return local filename
|
|
@@ -241,7 +304,22 @@ class Model(nn.Module):
|
|
|
241
304
|
self.model_name = weights
|
|
242
305
|
|
|
243
306
|
def _check_is_pytorch_model(self) -> None:
|
|
244
|
-
"""
|
|
307
|
+
"""
|
|
308
|
+
Checks if the model is a PyTorch model and raises a TypeError if it's not.
|
|
309
|
+
|
|
310
|
+
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
|
|
311
|
+
certain operations that require a PyTorch model are only performed on compatible model types.
|
|
312
|
+
|
|
313
|
+
Raises:
|
|
314
|
+
TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
|
|
315
|
+
information about supported model formats and operations.
|
|
316
|
+
|
|
317
|
+
Examples:
|
|
318
|
+
>>> model = Model("yolov8n.pt")
|
|
319
|
+
>>> model._check_is_pytorch_model() # No error raised
|
|
320
|
+
>>> model = Model("yolov8n.onnx")
|
|
321
|
+
>>> model._check_is_pytorch_model() # Raises TypeError
|
|
322
|
+
"""
|
|
245
323
|
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
|
246
324
|
pt_module = isinstance(self.model, nn.Module)
|
|
247
325
|
if not (pt_module or pt_str):
|
|
@@ -255,17 +333,21 @@ class Model(nn.Module):
|
|
|
255
333
|
|
|
256
334
|
def reset_weights(self) -> "Model":
|
|
257
335
|
"""
|
|
258
|
-
Resets the model
|
|
336
|
+
Resets the model's weights to their initial state.
|
|
259
337
|
|
|
260
338
|
This method iterates through all modules in the model and resets their parameters if they have a
|
|
261
|
-
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
|
|
262
|
-
to be updated during training.
|
|
339
|
+
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
|
|
340
|
+
enabling them to be updated during training.
|
|
263
341
|
|
|
264
342
|
Returns:
|
|
265
|
-
|
|
343
|
+
(Model): The instance of the class with reset weights.
|
|
266
344
|
|
|
267
345
|
Raises:
|
|
268
346
|
AssertionError: If the model is not a PyTorch model.
|
|
347
|
+
|
|
348
|
+
Examples:
|
|
349
|
+
>>> model = Model('yolov8n.pt')
|
|
350
|
+
>>> model.reset_weights()
|
|
269
351
|
"""
|
|
270
352
|
self._check_is_pytorch_model()
|
|
271
353
|
for m in self.model.modules():
|
|
@@ -283,13 +365,18 @@ class Model(nn.Module):
|
|
|
283
365
|
name and shape and transfers them to the model.
|
|
284
366
|
|
|
285
367
|
Args:
|
|
286
|
-
weights (str
|
|
368
|
+
weights (Union[str, Path]): Path to the weights file or a weights object.
|
|
287
369
|
|
|
288
370
|
Returns:
|
|
289
|
-
|
|
371
|
+
(Model): The instance of the class with loaded weights.
|
|
290
372
|
|
|
291
373
|
Raises:
|
|
292
374
|
AssertionError: If the model is not a PyTorch model.
|
|
375
|
+
|
|
376
|
+
Examples:
|
|
377
|
+
>>> model = Model()
|
|
378
|
+
>>> model.load('yolov8n.pt')
|
|
379
|
+
>>> model.load(Path('path/to/weights.pt'))
|
|
293
380
|
"""
|
|
294
381
|
self._check_is_pytorch_model()
|
|
295
382
|
if isinstance(weights, (str, Path)):
|
|
@@ -301,14 +388,19 @@ class Model(nn.Module):
|
|
|
301
388
|
"""
|
|
302
389
|
Saves the current model state to a file.
|
|
303
390
|
|
|
304
|
-
This method exports the model's checkpoint (ckpt) to the specified filename.
|
|
391
|
+
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
|
|
392
|
+
the date, Ultralytics version, license information, and a link to the documentation.
|
|
305
393
|
|
|
306
394
|
Args:
|
|
307
|
-
filename (str
|
|
308
|
-
use_dill (bool): Whether to try using dill for serialization if available.
|
|
395
|
+
filename (Union[str, Path]): The name of the file to save the model to.
|
|
396
|
+
use_dill (bool): Whether to try using dill for serialization if available.
|
|
309
397
|
|
|
310
398
|
Raises:
|
|
311
399
|
AssertionError: If the model is not a PyTorch model.
|
|
400
|
+
|
|
401
|
+
Examples:
|
|
402
|
+
>>> model = Model('yolov8n.pt')
|
|
403
|
+
>>> model.save('my_model.pt')
|
|
312
404
|
"""
|
|
313
405
|
self._check_is_pytorch_model()
|
|
314
406
|
from copy import deepcopy
|
|
@@ -329,30 +421,47 @@ class Model(nn.Module):
|
|
|
329
421
|
"""
|
|
330
422
|
Logs or returns model information.
|
|
331
423
|
|
|
332
|
-
This method provides an overview or detailed information about the model, depending on the arguments
|
|
333
|
-
It can control the verbosity of the output.
|
|
424
|
+
This method provides an overview or detailed information about the model, depending on the arguments
|
|
425
|
+
passed. It can control the verbosity of the output and return the information as a list.
|
|
334
426
|
|
|
335
427
|
Args:
|
|
336
|
-
detailed (bool): If True, shows detailed information about the model
|
|
337
|
-
verbose (bool): If True, prints the information. If False, returns the information
|
|
428
|
+
detailed (bool): If True, shows detailed information about the model layers and parameters.
|
|
429
|
+
verbose (bool): If True, prints the information. If False, returns the information as a list.
|
|
338
430
|
|
|
339
431
|
Returns:
|
|
340
|
-
(
|
|
432
|
+
(List[str]): A list of strings containing various types of information about the model, including
|
|
433
|
+
model summary, layer details, and parameter counts. Empty if verbose is True.
|
|
341
434
|
|
|
342
435
|
Raises:
|
|
343
|
-
|
|
436
|
+
TypeError: If the model is not a PyTorch model.
|
|
437
|
+
|
|
438
|
+
Examples:
|
|
439
|
+
>>> model = Model('yolov8n.pt')
|
|
440
|
+
>>> model.info() # Prints model summary
|
|
441
|
+
>>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list
|
|
344
442
|
"""
|
|
345
443
|
self._check_is_pytorch_model()
|
|
346
444
|
return self.model.info(detailed=detailed, verbose=verbose)
|
|
347
445
|
|
|
348
446
|
def fuse(self):
|
|
349
447
|
"""
|
|
350
|
-
Fuses Conv2d and BatchNorm2d layers in the model.
|
|
448
|
+
Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.
|
|
449
|
+
|
|
450
|
+
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
|
|
451
|
+
into a single layer. This fusion can significantly improve inference speed by reducing the number of
|
|
452
|
+
operations and memory accesses required during forward passes.
|
|
351
453
|
|
|
352
|
-
|
|
454
|
+
The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
|
|
455
|
+
bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
|
|
456
|
+
performs both convolution and normalization in one step.
|
|
353
457
|
|
|
354
458
|
Raises:
|
|
355
|
-
|
|
459
|
+
TypeError: If the model is not a PyTorch nn.Module.
|
|
460
|
+
|
|
461
|
+
Examples:
|
|
462
|
+
>>> model = Model("yolov8n.pt")
|
|
463
|
+
>>> model.fuse()
|
|
464
|
+
>>> # Model is now fused and ready for optimized inference
|
|
356
465
|
"""
|
|
357
466
|
self._check_is_pytorch_model()
|
|
358
467
|
self.model.fuse()
|
|
@@ -366,20 +475,26 @@ class Model(nn.Module):
|
|
|
366
475
|
"""
|
|
367
476
|
Generates image embeddings based on the provided source.
|
|
368
477
|
|
|
369
|
-
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
|
|
370
|
-
It allows customization of the embedding process through various keyword arguments.
|
|
478
|
+
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
|
|
479
|
+
source. It allows customization of the embedding process through various keyword arguments.
|
|
371
480
|
|
|
372
481
|
Args:
|
|
373
|
-
source (str | int |
|
|
374
|
-
|
|
375
|
-
stream (bool): If True, predictions are streamed.
|
|
376
|
-
**kwargs (
|
|
482
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of the image for
|
|
483
|
+
generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
|
|
484
|
+
stream (bool): If True, predictions are streamed.
|
|
485
|
+
**kwargs (Any): Additional keyword arguments for configuring the embedding process.
|
|
377
486
|
|
|
378
487
|
Returns:
|
|
379
488
|
(List[torch.Tensor]): A list containing the image embeddings.
|
|
380
489
|
|
|
381
490
|
Raises:
|
|
382
491
|
AssertionError: If the model is not a PyTorch model.
|
|
492
|
+
|
|
493
|
+
Examples:
|
|
494
|
+
>>> model = YOLO('yolov8n.pt')
|
|
495
|
+
>>> image = 'https://ultralytics.com/images/bus.jpg'
|
|
496
|
+
>>> embeddings = model.embed(image)
|
|
497
|
+
>>> print(embeddings[0].shape)
|
|
383
498
|
"""
|
|
384
499
|
if not kwargs.get("embed"):
|
|
385
500
|
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
|
@@ -397,28 +512,31 @@ class Model(nn.Module):
|
|
|
397
512
|
|
|
398
513
|
This method facilitates the prediction process, allowing various configurations through keyword arguments.
|
|
399
514
|
It supports predictions with custom predictors or the default predictor method. The method handles different
|
|
400
|
-
types of image sources and can operate in a streaming mode.
|
|
401
|
-
through 'prompts'.
|
|
402
|
-
|
|
403
|
-
The method sets up a new predictor if not already present and updates its arguments with each call.
|
|
404
|
-
It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it
|
|
405
|
-
is being called from the command line interface and adjusts its behavior accordingly, including setting defaults
|
|
406
|
-
for confidence threshold and saving behavior.
|
|
515
|
+
types of image sources and can operate in a streaming mode.
|
|
407
516
|
|
|
408
517
|
Args:
|
|
409
|
-
source (str | int |
|
|
410
|
-
Accepts various types
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
518
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
|
|
519
|
+
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
|
|
520
|
+
images, numpy arrays, and torch tensors.
|
|
521
|
+
stream (bool): If True, treats the input source as a continuous stream for predictions.
|
|
522
|
+
predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
|
|
523
|
+
If None, the method uses a default predictor.
|
|
524
|
+
**kwargs (Any): Additional keyword arguments for configuring the prediction process.
|
|
416
525
|
|
|
417
526
|
Returns:
|
|
418
|
-
(List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
527
|
+
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
|
528
|
+
Results object.
|
|
529
|
+
|
|
530
|
+
Examples:
|
|
531
|
+
>>> model = YOLO('yolov8n.pt')
|
|
532
|
+
>>> results = model.predict(source='path/to/image.jpg', conf=0.25)
|
|
533
|
+
>>> for r in results:
|
|
534
|
+
... print(r.boxes.data) # print detection bounding boxes
|
|
535
|
+
|
|
536
|
+
Notes:
|
|
537
|
+
- If 'source' is not provided, it defaults to the ASSETS constant with a warning.
|
|
538
|
+
- The method sets up a new predictor if not already present and updates its arguments with each call.
|
|
539
|
+
- For SAM-type models, 'prompts' can be passed as a keyword argument.
|
|
422
540
|
"""
|
|
423
541
|
if source is None:
|
|
424
542
|
source = ASSETS
|
|
@@ -453,26 +571,33 @@ class Model(nn.Module):
|
|
|
453
571
|
"""
|
|
454
572
|
Conducts object tracking on the specified input source using the registered trackers.
|
|
455
573
|
|
|
456
|
-
This method performs object tracking using the model's predictors and optionally registered trackers. It
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
already present and optionally persists them based on the 'persist' flag.
|
|
460
|
-
|
|
461
|
-
The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low
|
|
462
|
-
confidence predictions as input. The tracking mode is explicitly set in the keyword arguments.
|
|
574
|
+
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
|
|
575
|
+
various input sources such as file paths or video streams, and supports customization through keyword arguments.
|
|
576
|
+
The method registers trackers if not already present and can persist them between calls.
|
|
463
577
|
|
|
464
578
|
Args:
|
|
465
|
-
source (str,
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
579
|
+
source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
|
|
580
|
+
tracking. Can be a file path, URL, or video stream.
|
|
581
|
+
stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
|
|
582
|
+
persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
|
|
583
|
+
**kwargs (Any): Additional keyword arguments for configuring the tracking process.
|
|
470
584
|
|
|
471
585
|
Returns:
|
|
472
|
-
(List[ultralytics.engine.results.Results]): A list of tracking results,
|
|
586
|
+
(List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
|
|
473
587
|
|
|
474
588
|
Raises:
|
|
475
589
|
AttributeError: If the predictor does not have registered trackers.
|
|
590
|
+
|
|
591
|
+
Examples:
|
|
592
|
+
>>> model = YOLO('yolov8n.pt')
|
|
593
|
+
>>> results = model.track(source='path/to/video.mp4', show=True)
|
|
594
|
+
>>> for r in results:
|
|
595
|
+
... print(r.boxes.id) # print tracking IDs
|
|
596
|
+
|
|
597
|
+
Notes:
|
|
598
|
+
- This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.
|
|
599
|
+
- The tracking mode is explicitly set in the keyword arguments.
|
|
600
|
+
- Batch size is set to 1 for tracking in videos.
|
|
476
601
|
"""
|
|
477
602
|
if not hasattr(self.predictor, "trackers"):
|
|
478
603
|
from ultralytics.trackers import register_tracker
|
|
@@ -491,26 +616,25 @@ class Model(nn.Module):
|
|
|
491
616
|
"""
|
|
492
617
|
Validates the model using a specified dataset and validation configuration.
|
|
493
618
|
|
|
494
|
-
This method facilitates the model validation process, allowing for
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
the validation process. After validation, it updates the model's metrics with the results obtained from the
|
|
498
|
-
validator.
|
|
499
|
-
|
|
500
|
-
The method supports various arguments that allow customization of the validation process. For a comprehensive
|
|
501
|
-
list of all configurable options, users should refer to the 'configuration' section in the documentation.
|
|
619
|
+
This method facilitates the model validation process, allowing for customization through various settings. It
|
|
620
|
+
supports validation with a custom validator or the default validation approach. The method combines default
|
|
621
|
+
configurations, method-specific defaults, and user-provided arguments to configure the validation process.
|
|
502
622
|
|
|
503
623
|
Args:
|
|
504
|
-
validator (BaseValidator
|
|
505
|
-
|
|
506
|
-
**kwargs (
|
|
507
|
-
used to customize various aspects of the validation process.
|
|
624
|
+
validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
|
|
625
|
+
validating the model.
|
|
626
|
+
**kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
|
|
508
627
|
|
|
509
628
|
Returns:
|
|
510
629
|
(ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
|
|
511
630
|
|
|
512
631
|
Raises:
|
|
513
632
|
AssertionError: If the model is not a PyTorch model.
|
|
633
|
+
|
|
634
|
+
Examples:
|
|
635
|
+
>>> model = YOLO('yolov8n.pt')
|
|
636
|
+
>>> results = model.val(data='coco128.yaml', imgsz=640)
|
|
637
|
+
>>> print(results.box.map) # Print mAP50-95
|
|
514
638
|
"""
|
|
515
639
|
custom = {"rect": True} # method defaults
|
|
516
640
|
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
|
|
@@ -528,23 +652,31 @@ class Model(nn.Module):
|
|
|
528
652
|
Benchmarks the model across various export formats to evaluate performance.
|
|
529
653
|
|
|
530
654
|
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
|
|
531
|
-
It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
|
|
532
|
-
using a combination of default configuration values, model-specific arguments, method-specific
|
|
533
|
-
any additional user-provided keyword arguments.
|
|
534
|
-
|
|
535
|
-
The method supports various arguments that allow customization of the benchmarking process, such as dataset
|
|
536
|
-
choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all
|
|
537
|
-
configurable options, users should refer to the 'configuration' section in the documentation.
|
|
655
|
+
It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
|
|
656
|
+
configured using a combination of default configuration values, model-specific arguments, method-specific
|
|
657
|
+
defaults, and any additional user-provided keyword arguments.
|
|
538
658
|
|
|
539
659
|
Args:
|
|
540
|
-
**kwargs (
|
|
541
|
-
default configurations, model-specific arguments, and method defaults.
|
|
660
|
+
**kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
|
|
661
|
+
default configurations, model-specific arguments, and method defaults. Common options include:
|
|
662
|
+
- data (str): Path to the dataset for benchmarking.
|
|
663
|
+
- imgsz (int | List[int]): Image size for benchmarking.
|
|
664
|
+
- half (bool): Whether to use half-precision (FP16) mode.
|
|
665
|
+
- int8 (bool): Whether to use int8 precision mode.
|
|
666
|
+
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
|
|
667
|
+
- verbose (bool): Whether to print detailed benchmark information.
|
|
542
668
|
|
|
543
669
|
Returns:
|
|
544
|
-
(
|
|
670
|
+
(Dict): A dictionary containing the results of the benchmarking process, including metrics for
|
|
671
|
+
different export formats.
|
|
545
672
|
|
|
546
673
|
Raises:
|
|
547
674
|
AssertionError: If the model is not a PyTorch model.
|
|
675
|
+
|
|
676
|
+
Examples:
|
|
677
|
+
>>> model = YOLO('yolov8n.pt')
|
|
678
|
+
>>> results = model.benchmark(data='coco8.yaml', imgsz=640, half=True)
|
|
679
|
+
>>> print(results)
|
|
548
680
|
"""
|
|
549
681
|
self._check_is_pytorch_model()
|
|
550
682
|
from ultralytics.utils.benchmarks import benchmark
|
|
@@ -570,20 +702,31 @@ class Model(nn.Module):
|
|
|
570
702
|
|
|
571
703
|
This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
|
|
572
704
|
purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
|
|
573
|
-
defaults, and any additional arguments provided.
|
|
574
|
-
|
|
575
|
-
The method supports a wide range of arguments to customize the export process. For a comprehensive list of all
|
|
576
|
-
possible arguments, refer to the 'configuration' section in the documentation.
|
|
705
|
+
defaults, and any additional arguments provided.
|
|
577
706
|
|
|
578
707
|
Args:
|
|
579
|
-
**kwargs (
|
|
580
|
-
model's overrides and method defaults.
|
|
708
|
+
**kwargs (Dict): Arbitrary keyword arguments to customize the export process. These are combined with
|
|
709
|
+
the model's overrides and method defaults. Common arguments include:
|
|
710
|
+
format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
|
|
711
|
+
half (bool): Export model in half-precision.
|
|
712
|
+
int8 (bool): Export model in int8 precision.
|
|
713
|
+
device (str): Device to run the export on.
|
|
714
|
+
workspace (int): Maximum memory workspace size for TensorRT engines.
|
|
715
|
+
nms (bool): Add Non-Maximum Suppression (NMS) module to model.
|
|
716
|
+
simplify (bool): Simplify ONNX model.
|
|
581
717
|
|
|
582
718
|
Returns:
|
|
583
|
-
(str): The
|
|
719
|
+
(str): The path to the exported model file.
|
|
584
720
|
|
|
585
721
|
Raises:
|
|
586
722
|
AssertionError: If the model is not a PyTorch model.
|
|
723
|
+
ValueError: If an unsupported export format is specified.
|
|
724
|
+
RuntimeError: If the export process fails due to errors.
|
|
725
|
+
|
|
726
|
+
Examples:
|
|
727
|
+
>>> model = YOLO('yolov8n.pt')
|
|
728
|
+
>>> model.export(format='onnx', dynamic=True, simplify=True)
|
|
729
|
+
'path/to/exported/model.onnx'
|
|
587
730
|
"""
|
|
588
731
|
self._check_is_pytorch_model()
|
|
589
732
|
from .exporter import Exporter
|
|
@@ -606,29 +749,38 @@ class Model(nn.Module):
|
|
|
606
749
|
"""
|
|
607
750
|
Trains the model using the specified dataset and training configuration.
|
|
608
751
|
|
|
609
|
-
This method facilitates model training with a range of customizable settings
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
updating model and configuration after training.
|
|
752
|
+
This method facilitates model training with a range of customizable settings. It supports training with a
|
|
753
|
+
custom trainer or the default training approach. The method handles scenarios such as resuming training
|
|
754
|
+
from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
|
|
613
755
|
|
|
614
|
-
When using Ultralytics HUB, if the session
|
|
615
|
-
arguments and
|
|
616
|
-
configurations, method-specific defaults, and user-provided arguments to configure the training process.
|
|
617
|
-
training, it updates the model and its configurations, and optionally attaches metrics.
|
|
756
|
+
When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
|
|
757
|
+
arguments and warns if local arguments are provided. It checks for pip updates and combines default
|
|
758
|
+
configurations, method-specific defaults, and user-provided arguments to configure the training process.
|
|
618
759
|
|
|
619
760
|
Args:
|
|
620
|
-
trainer (BaseTrainer
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
761
|
+
trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
|
|
762
|
+
**kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
|
|
763
|
+
data (str): Path to dataset configuration file.
|
|
764
|
+
epochs (int): Number of training epochs.
|
|
765
|
+
batch (int): Batch size for training.
|
|
766
|
+
imgsz (int): Input image size.
|
|
767
|
+
device (str): Device to run training on (e.g., 'cuda', 'cpu').
|
|
768
|
+
workers (int): Number of worker threads for data loading.
|
|
769
|
+
optimizer (str): Optimizer to use for training.
|
|
770
|
+
lr0 (float): Initial learning rate.
|
|
771
|
+
patience (int): Epochs to wait for no observable improvement for early stopping of training.
|
|
624
772
|
|
|
625
773
|
Returns:
|
|
626
|
-
(
|
|
774
|
+
(Dict | None): Training metrics if available and training is successful; otherwise, None.
|
|
627
775
|
|
|
628
776
|
Raises:
|
|
629
777
|
AssertionError: If the model is not a PyTorch model.
|
|
630
778
|
PermissionError: If there is a permission issue with the HUB session.
|
|
631
779
|
ModuleNotFoundError: If the HUB SDK is not installed.
|
|
780
|
+
|
|
781
|
+
Examples:
|
|
782
|
+
>>> model = YOLO('yolov8n.pt')
|
|
783
|
+
>>> results = model.train(data='coco128.yaml', epochs=3)
|
|
632
784
|
"""
|
|
633
785
|
self._check_is_pytorch_model()
|
|
634
786
|
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
|
|
@@ -682,14 +834,19 @@ class Model(nn.Module):
|
|
|
682
834
|
Args:
|
|
683
835
|
use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
|
|
684
836
|
iterations (int): The number of tuning iterations to perform. Defaults to 10.
|
|
685
|
-
*args (
|
|
686
|
-
**kwargs (
|
|
837
|
+
*args (List): Variable length argument list for additional arguments.
|
|
838
|
+
**kwargs (Dict): Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
|
|
687
839
|
|
|
688
840
|
Returns:
|
|
689
|
-
(
|
|
841
|
+
(Dict): A dictionary containing the results of the hyperparameter search.
|
|
690
842
|
|
|
691
843
|
Raises:
|
|
692
844
|
AssertionError: If the model is not a PyTorch model.
|
|
845
|
+
|
|
846
|
+
Examples:
|
|
847
|
+
>>> model = YOLO('yolov8n.pt')
|
|
848
|
+
>>> results = model.tune(use_ray=True, iterations=20)
|
|
849
|
+
>>> print(results)
|
|
693
850
|
"""
|
|
694
851
|
self._check_is_pytorch_model()
|
|
695
852
|
if use_ray:
|
|
@@ -704,7 +861,27 @@ class Model(nn.Module):
|
|
|
704
861
|
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
|
705
862
|
|
|
706
863
|
def _apply(self, fn) -> "Model":
|
|
707
|
-
"""
|
|
864
|
+
"""
|
|
865
|
+
Applies a function to model tensors that are not parameters or registered buffers.
|
|
866
|
+
|
|
867
|
+
This method extends the functionality of the parent class's _apply method by additionally resetting the
|
|
868
|
+
predictor and updating the device in the model's overrides. It's typically used for operations like
|
|
869
|
+
moving the model to a different device or changing its precision.
|
|
870
|
+
|
|
871
|
+
Args:
|
|
872
|
+
fn (Callable): A function to be applied to the model's tensors. This is typically a method like
|
|
873
|
+
to(), cpu(), cuda(), half(), or float().
|
|
874
|
+
|
|
875
|
+
Returns:
|
|
876
|
+
(Model): The model instance with the function applied and updated attributes.
|
|
877
|
+
|
|
878
|
+
Raises:
|
|
879
|
+
AssertionError: If the model is not a PyTorch model.
|
|
880
|
+
|
|
881
|
+
Examples:
|
|
882
|
+
>>> model = Model("yolov8n.pt")
|
|
883
|
+
>>> model = model._apply(lambda t: t.cuda()) # Move model to GPU
|
|
884
|
+
"""
|
|
708
885
|
self._check_is_pytorch_model()
|
|
709
886
|
self = super()._apply(fn) # noqa
|
|
710
887
|
self.predictor = None # reset predictor as device may have changed
|
|
@@ -717,10 +894,19 @@ class Model(nn.Module):
|
|
|
717
894
|
Retrieves the class names associated with the loaded model.
|
|
718
895
|
|
|
719
896
|
This property returns the class names if they are defined in the model. It checks the class names for validity
|
|
720
|
-
using the 'check_class_names' function from the ultralytics.nn.autobackend module.
|
|
897
|
+
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
|
|
898
|
+
initialized, it sets it up before retrieving the names.
|
|
721
899
|
|
|
722
900
|
Returns:
|
|
723
|
-
(
|
|
901
|
+
(List[str]): A list of class names associated with the model.
|
|
902
|
+
|
|
903
|
+
Raises:
|
|
904
|
+
AttributeError: If the model or predictor does not have a 'names' attribute.
|
|
905
|
+
|
|
906
|
+
Examples:
|
|
907
|
+
>>> model = YOLO('yolov8n.pt')
|
|
908
|
+
>>> print(model.names)
|
|
909
|
+
['person', 'bicycle', 'car', ...]
|
|
724
910
|
"""
|
|
725
911
|
from ultralytics.nn.autobackend import check_class_names
|
|
726
912
|
|
|
@@ -736,11 +922,22 @@ class Model(nn.Module):
|
|
|
736
922
|
"""
|
|
737
923
|
Retrieves the device on which the model's parameters are allocated.
|
|
738
924
|
|
|
739
|
-
This property
|
|
740
|
-
that are instances of nn.Module.
|
|
925
|
+
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
|
|
926
|
+
applicable only to models that are instances of nn.Module.
|
|
741
927
|
|
|
742
928
|
Returns:
|
|
743
|
-
(torch.device
|
|
929
|
+
(torch.device): The device (CPU/GPU) of the model.
|
|
930
|
+
|
|
931
|
+
Raises:
|
|
932
|
+
AttributeError: If the model is not a PyTorch nn.Module instance.
|
|
933
|
+
|
|
934
|
+
Examples:
|
|
935
|
+
>>> model = YOLO("yolov8n.pt")
|
|
936
|
+
>>> print(model.device)
|
|
937
|
+
device(type='cuda', index=0) # if CUDA is available
|
|
938
|
+
>>> model = model.to("cpu")
|
|
939
|
+
>>> print(model.device)
|
|
940
|
+
device(type='cpu')
|
|
744
941
|
"""
|
|
745
942
|
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
|
746
943
|
|
|
@@ -749,10 +946,20 @@ class Model(nn.Module):
|
|
|
749
946
|
"""
|
|
750
947
|
Retrieves the transformations applied to the input data of the loaded model.
|
|
751
948
|
|
|
752
|
-
This property returns the transformations if they are defined in the model.
|
|
949
|
+
This property returns the transformations if they are defined in the model. The transforms
|
|
950
|
+
typically include preprocessing steps like resizing, normalization, and data augmentation
|
|
951
|
+
that are applied to input data before it is fed into the model.
|
|
753
952
|
|
|
754
953
|
Returns:
|
|
755
954
|
(object | None): The transform object of the model if available, otherwise None.
|
|
955
|
+
|
|
956
|
+
Examples:
|
|
957
|
+
>>> model = YOLO('yolov8n.pt')
|
|
958
|
+
>>> transforms = model.transforms
|
|
959
|
+
>>> if transforms:
|
|
960
|
+
... print(f"Model transforms: {transforms}")
|
|
961
|
+
... else:
|
|
962
|
+
... print("No transforms defined for this model.")
|
|
756
963
|
"""
|
|
757
964
|
return self.model.transforms if hasattr(self.model, "transforms") else None
|
|
758
965
|
|
|
@@ -760,15 +967,25 @@ class Model(nn.Module):
|
|
|
760
967
|
"""
|
|
761
968
|
Adds a callback function for a specified event.
|
|
762
969
|
|
|
763
|
-
This method allows
|
|
764
|
-
model training or inference.
|
|
970
|
+
This method allows registering custom callback functions that are triggered on specific events during
|
|
971
|
+
model operations such as training or inference. Callbacks provide a way to extend and customize the
|
|
972
|
+
behavior of the model at various stages of its lifecycle.
|
|
765
973
|
|
|
766
974
|
Args:
|
|
767
|
-
event (str): The name of the event to attach the callback to.
|
|
768
|
-
|
|
975
|
+
event (str): The name of the event to attach the callback to. Must be a valid event name recognized
|
|
976
|
+
by the Ultralytics framework.
|
|
977
|
+
func (Callable): The callback function to be registered. This function will be called when the
|
|
978
|
+
specified event occurs.
|
|
769
979
|
|
|
770
980
|
Raises:
|
|
771
|
-
ValueError: If the event name is not recognized.
|
|
981
|
+
ValueError: If the event name is not recognized or is invalid.
|
|
982
|
+
|
|
983
|
+
Examples:
|
|
984
|
+
>>> def on_train_start(trainer):
|
|
985
|
+
... print("Training is starting!")
|
|
986
|
+
>>> model = YOLO('yolov8n.pt')
|
|
987
|
+
>>> model.add_callback("on_train_start", on_train_start)
|
|
988
|
+
>>> model.train(data='coco128.yaml', epochs=1)
|
|
772
989
|
"""
|
|
773
990
|
self.callbacks[event].append(func)
|
|
774
991
|
|
|
@@ -777,12 +994,26 @@ class Model(nn.Module):
|
|
|
777
994
|
Clears all callback functions registered for a specified event.
|
|
778
995
|
|
|
779
996
|
This method removes all custom and default callback functions associated with the given event.
|
|
997
|
+
It resets the callback list for the specified event to an empty list, effectively removing all
|
|
998
|
+
registered callbacks for that event.
|
|
780
999
|
|
|
781
1000
|
Args:
|
|
782
|
-
event (str): The name of the event for which to clear the callbacks.
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
1001
|
+
event (str): The name of the event for which to clear the callbacks. This should be a valid event name
|
|
1002
|
+
recognized by the Ultralytics callback system.
|
|
1003
|
+
|
|
1004
|
+
Examples:
|
|
1005
|
+
>>> model = YOLO('yolov8n.pt')
|
|
1006
|
+
>>> model.add_callback('on_train_start', lambda trainer: print('Training started'))
|
|
1007
|
+
>>> model.clear_callback('on_train_start')
|
|
1008
|
+
>>> # All callbacks for 'on_train_start' are now removed
|
|
1009
|
+
|
|
1010
|
+
Notes:
|
|
1011
|
+
- This method affects both custom callbacks added by the user and default callbacks
|
|
1012
|
+
provided by the Ultralytics framework.
|
|
1013
|
+
- After calling this method, no callbacks will be executed for the specified event
|
|
1014
|
+
until new ones are added.
|
|
1015
|
+
- Use with caution as it removes all callbacks, including essential ones that might
|
|
1016
|
+
be required for proper functioning of certain operations.
|
|
786
1017
|
"""
|
|
787
1018
|
self.callbacks[event] = []
|
|
788
1019
|
|
|
@@ -791,14 +1022,45 @@ class Model(nn.Module):
|
|
|
791
1022
|
Resets all callbacks to their default functions.
|
|
792
1023
|
|
|
793
1024
|
This method reinstates the default callback functions for all events, removing any custom callbacks that were
|
|
794
|
-
added
|
|
1025
|
+
previously added. It iterates through all default callback events and replaces the current callbacks with the
|
|
1026
|
+
default ones.
|
|
1027
|
+
|
|
1028
|
+
The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined
|
|
1029
|
+
functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc.
|
|
1030
|
+
|
|
1031
|
+
This method is useful when you want to revert to the original set of callbacks after making custom
|
|
1032
|
+
modifications, ensuring consistent behavior across different runs or experiments.
|
|
1033
|
+
|
|
1034
|
+
Examples:
|
|
1035
|
+
>>> model = YOLO('yolov8n.pt')
|
|
1036
|
+
>>> model.add_callback('on_train_start', custom_function)
|
|
1037
|
+
>>> model.reset_callbacks()
|
|
1038
|
+
>>> # All callbacks are now reset to their default functions
|
|
795
1039
|
"""
|
|
796
1040
|
for event in callbacks.default_callbacks.keys():
|
|
797
1041
|
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
|
798
1042
|
|
|
799
1043
|
@staticmethod
|
|
800
1044
|
def _reset_ckpt_args(args: dict) -> dict:
|
|
801
|
-
"""
|
|
1045
|
+
"""
|
|
1046
|
+
Resets specific arguments when loading a PyTorch model checkpoint.
|
|
1047
|
+
|
|
1048
|
+
This static method filters the input arguments dictionary to retain only a specific set of keys that are
|
|
1049
|
+
considered important for model loading. It's used to ensure that only relevant arguments are preserved
|
|
1050
|
+
when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
|
|
1051
|
+
|
|
1052
|
+
Args:
|
|
1053
|
+
args (dict): A dictionary containing various model arguments and settings.
|
|
1054
|
+
|
|
1055
|
+
Returns:
|
|
1056
|
+
(dict): A new dictionary containing only the specified include keys from the input arguments.
|
|
1057
|
+
|
|
1058
|
+
Examples:
|
|
1059
|
+
>>> original_args = {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect', 'batch': 16, 'epochs': 100}
|
|
1060
|
+
>>> reset_args = Model._reset_ckpt_args(original_args)
|
|
1061
|
+
>>> print(reset_args)
|
|
1062
|
+
{'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
|
|
1063
|
+
"""
|
|
802
1064
|
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
|
|
803
1065
|
return {k: v for k, v in args.items() if k in include}
|
|
804
1066
|
|
|
@@ -808,7 +1070,31 @@ class Model(nn.Module):
|
|
|
808
1070
|
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
809
1071
|
|
|
810
1072
|
def _smart_load(self, key: str):
|
|
811
|
-
"""
|
|
1073
|
+
"""
|
|
1074
|
+
Loads the appropriate module based on the model task.
|
|
1075
|
+
|
|
1076
|
+
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
|
|
1077
|
+
based on the current task of the model and the provided key. It uses the task_map attribute to determine
|
|
1078
|
+
the correct module to load.
|
|
1079
|
+
|
|
1080
|
+
Args:
|
|
1081
|
+
key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
|
|
1082
|
+
|
|
1083
|
+
Returns:
|
|
1084
|
+
(object): The loaded module corresponding to the specified key and current task.
|
|
1085
|
+
|
|
1086
|
+
Raises:
|
|
1087
|
+
NotImplementedError: If the specified key is not supported for the current task.
|
|
1088
|
+
|
|
1089
|
+
Examples:
|
|
1090
|
+
>>> model = Model(task='detect')
|
|
1091
|
+
>>> predictor = model._smart_load('predictor')
|
|
1092
|
+
>>> trainer = model._smart_load('trainer')
|
|
1093
|
+
|
|
1094
|
+
Notes:
|
|
1095
|
+
- This method is typically used internally by other methods of the Model class.
|
|
1096
|
+
- The task_map attribute should be properly initialized with the correct mappings for each task.
|
|
1097
|
+
"""
|
|
812
1098
|
try:
|
|
813
1099
|
return self.task_map[self.task][key]
|
|
814
1100
|
except Exception as e:
|
|
@@ -821,9 +1107,30 @@ class Model(nn.Module):
|
|
|
821
1107
|
@property
|
|
822
1108
|
def task_map(self) -> dict:
|
|
823
1109
|
"""
|
|
824
|
-
|
|
1110
|
+
Provides a mapping from model tasks to corresponding classes for different modes.
|
|
1111
|
+
|
|
1112
|
+
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
|
|
1113
|
+
to a nested dictionary. The nested dictionary contains mappings for different operational modes
|
|
1114
|
+
(model, trainer, validator, predictor) to their respective class implementations.
|
|
1115
|
+
|
|
1116
|
+
The mapping allows for dynamic loading of appropriate classes based on the model's task and the
|
|
1117
|
+
desired operational mode. This facilitates a flexible and extensible architecture for handling
|
|
1118
|
+
various tasks and modes within the Ultralytics framework.
|
|
825
1119
|
|
|
826
1120
|
Returns:
|
|
827
|
-
|
|
1121
|
+
(Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
|
|
1122
|
+
nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
|
|
1123
|
+
'predictor', mapping to their respective class implementations.
|
|
1124
|
+
|
|
1125
|
+
Examples:
|
|
1126
|
+
>>> model = YOLO('yolov8n.pt')  # base Model.task_map raises NotImplementedError; subclasses define it
|
|
1127
|
+
>>> task_map = model.task_map
|
|
1128
|
+
>>> detect_class_map = task_map['detect']
|
|
1129
|
+
>>> segment_class_map = task_map['segment']
|
|
1130
|
+
|
|
1131
|
+
Notes:
|
|
1132
|
+
The actual implementation of this method may vary depending on the specific tasks and
|
|
1133
|
+
classes supported by the Ultralytics framework. The docstring provides a general
|
|
1134
|
+
description of the expected behavior and structure.
|
|
828
1135
|
"""
|
|
829
1136
|
raise NotImplementedError("Please provide task map for your model!")
|