dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/engine/model.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import inspect
|
|
4
6
|
from pathlib import Path
|
|
5
|
-
from typing import Any
|
|
7
|
+
from typing import Any
|
|
6
8
|
|
|
7
9
|
import numpy as np
|
|
8
10
|
import torch
|
|
@@ -10,7 +12,7 @@ from PIL import Image
|
|
|
10
12
|
|
|
11
13
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
|
12
14
|
from ultralytics.engine.results import Results
|
|
13
|
-
from ultralytics.nn.tasks import
|
|
15
|
+
from ultralytics.nn.tasks import guess_model_task, load_checkpoint, yaml_model_load
|
|
14
16
|
from ultralytics.utils import (
|
|
15
17
|
ARGV,
|
|
16
18
|
ASSETS,
|
|
@@ -25,12 +27,11 @@ from ultralytics.utils import (
|
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class Model(torch.nn.Module):
|
|
28
|
-
"""
|
|
29
|
-
A base class for implementing YOLO models, unifying APIs across different model types.
|
|
30
|
+
"""A base class for implementing YOLO models, unifying APIs across different model types.
|
|
30
31
|
|
|
31
|
-
This class provides a common interface for various operations related to YOLO models, such as training,
|
|
32
|
-
|
|
33
|
-
|
|
32
|
+
This class provides a common interface for various operations related to YOLO models, such as training, validation,
|
|
33
|
+
prediction, exporting, and benchmarking. It handles different types of models, including those loaded from local
|
|
34
|
+
files, Ultralytics HUB, or Triton Server.
|
|
34
35
|
|
|
35
36
|
Attributes:
|
|
36
37
|
callbacks (dict): A dictionary of callback functions for various events during model operations.
|
|
@@ -48,25 +49,25 @@ class Model(torch.nn.Module):
|
|
|
48
49
|
|
|
49
50
|
Methods:
|
|
50
51
|
__call__: Alias for the predict method, enabling the model instance to be callable.
|
|
51
|
-
_new:
|
|
52
|
-
_load:
|
|
53
|
-
_check_is_pytorch_model:
|
|
54
|
-
reset_weights:
|
|
55
|
-
load:
|
|
56
|
-
save:
|
|
57
|
-
info:
|
|
58
|
-
fuse:
|
|
59
|
-
predict:
|
|
60
|
-
track:
|
|
61
|
-
val:
|
|
62
|
-
benchmark:
|
|
63
|
-
export:
|
|
64
|
-
train:
|
|
65
|
-
tune:
|
|
66
|
-
_apply:
|
|
67
|
-
add_callback:
|
|
68
|
-
clear_callback:
|
|
69
|
-
reset_callbacks:
|
|
52
|
+
_new: Initialize a new model based on a configuration file.
|
|
53
|
+
_load: Load a model from a checkpoint file.
|
|
54
|
+
_check_is_pytorch_model: Ensure that the model is a PyTorch model.
|
|
55
|
+
reset_weights: Reset the model's weights to their initial state.
|
|
56
|
+
load: Load model weights from a specified file.
|
|
57
|
+
save: Save the current state of the model to a file.
|
|
58
|
+
info: Log or return information about the model.
|
|
59
|
+
fuse: Fuse Conv2d and BatchNorm2d layers for optimized inference.
|
|
60
|
+
predict: Perform object detection predictions.
|
|
61
|
+
track: Perform object tracking.
|
|
62
|
+
val: Validate the model on a dataset.
|
|
63
|
+
benchmark: Benchmark the model on various export formats.
|
|
64
|
+
export: Export the model to different formats.
|
|
65
|
+
train: Train the model on a dataset.
|
|
66
|
+
tune: Perform hyperparameter tuning.
|
|
67
|
+
_apply: Apply a function to the model's tensors.
|
|
68
|
+
add_callback: Add a callback function for an event.
|
|
69
|
+
clear_callback: Clear all callbacks for an event.
|
|
70
|
+
reset_callbacks: Reset all callbacks to their default functions.
|
|
70
71
|
|
|
71
72
|
Examples:
|
|
72
73
|
>>> from ultralytics import YOLO
|
|
@@ -79,24 +80,21 @@ class Model(torch.nn.Module):
|
|
|
79
80
|
|
|
80
81
|
def __init__(
|
|
81
82
|
self,
|
|
82
|
-
model:
|
|
83
|
-
task: str = None,
|
|
83
|
+
model: str | Path | Model = "yolo11n.pt",
|
|
84
|
+
task: str | None = None,
|
|
84
85
|
verbose: bool = False,
|
|
85
86
|
) -> None:
|
|
86
|
-
"""
|
|
87
|
-
Initialize a new instance of the YOLO model class.
|
|
87
|
+
"""Initialize a new instance of the YOLO model class.
|
|
88
88
|
|
|
89
|
-
This constructor sets up the model based on the provided model path or name. It handles various types of
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
prediction, or export.
|
|
89
|
+
This constructor sets up the model based on the provided model path or name. It handles various types of model
|
|
90
|
+
sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
|
|
91
|
+
important attributes of the model and prepares it for operations like training, prediction, or export.
|
|
93
92
|
|
|
94
93
|
Args:
|
|
95
|
-
model (str | Path): Path or name of the model to load or create. Can be a local file path, a
|
|
96
|
-
|
|
97
|
-
task (str
|
|
98
|
-
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
|
|
99
|
-
operations.
|
|
94
|
+
model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a model
|
|
95
|
+
name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
|
|
96
|
+
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
|
97
|
+
verbose (bool): If True, enables verbose output during the model's initialization and subsequent operations.
|
|
100
98
|
|
|
101
99
|
Raises:
|
|
102
100
|
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
|
@@ -108,6 +106,9 @@ class Model(torch.nn.Module):
|
|
|
108
106
|
>>> model = Model("path/to/model.yaml", task="detect")
|
|
109
107
|
>>> model = Model("hub_model", verbose=True)
|
|
110
108
|
"""
|
|
109
|
+
if isinstance(model, Model):
|
|
110
|
+
self.__dict__ = model.__dict__ # accepts an already initialized Model
|
|
111
|
+
return
|
|
111
112
|
super().__init__()
|
|
112
113
|
self.callbacks = callbacks.get_default_callbacks()
|
|
113
114
|
self.predictor = None # reuse predictor
|
|
@@ -152,26 +153,25 @@ class Model(torch.nn.Module):
|
|
|
152
153
|
|
|
153
154
|
def __call__(
|
|
154
155
|
self,
|
|
155
|
-
source:
|
|
156
|
+
source: str | Path | int | Image.Image | list | tuple | np.ndarray | torch.Tensor = None,
|
|
156
157
|
stream: bool = False,
|
|
157
158
|
**kwargs: Any,
|
|
158
159
|
) -> list:
|
|
159
|
-
"""
|
|
160
|
-
Alias for the predict method, enabling the model instance to be callable for predictions.
|
|
160
|
+
"""Alias for the predict method, enabling the model instance to be callable for predictions.
|
|
161
161
|
|
|
162
|
-
This method simplifies the process of making predictions by allowing the model instance to be called
|
|
163
|
-
|
|
162
|
+
This method simplifies the process of making predictions by allowing the model instance to be called directly
|
|
163
|
+
with the required arguments.
|
|
164
164
|
|
|
165
165
|
Args:
|
|
166
|
-
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor |
|
|
167
|
-
|
|
168
|
-
|
|
166
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
|
|
167
|
+
to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch tensor, or a list/tuple
|
|
168
|
+
of these.
|
|
169
169
|
stream (bool): If True, treat the input source as a continuous stream for predictions.
|
|
170
170
|
**kwargs (Any): Additional keyword arguments to configure the prediction process.
|
|
171
171
|
|
|
172
172
|
Returns:
|
|
173
|
-
(
|
|
174
|
-
|
|
173
|
+
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
|
|
174
|
+
object.
|
|
175
175
|
|
|
176
176
|
Examples:
|
|
177
177
|
>>> model = YOLO("yolo11n.pt")
|
|
@@ -183,11 +183,10 @@ class Model(torch.nn.Module):
|
|
|
183
183
|
|
|
184
184
|
@staticmethod
|
|
185
185
|
def is_triton_model(model: str) -> bool:
|
|
186
|
-
"""
|
|
187
|
-
Check if the given model string is a Triton Server URL.
|
|
186
|
+
"""Check if the given model string is a Triton Server URL.
|
|
188
187
|
|
|
189
|
-
This static method determines whether the provided model string represents a valid Triton Server URL by
|
|
190
|
-
|
|
188
|
+
This static method determines whether the provided model string represents a valid Triton Server URL by parsing
|
|
189
|
+
its components using urllib.parse.urlsplit().
|
|
191
190
|
|
|
192
191
|
Args:
|
|
193
192
|
model (str): The model string to be checked.
|
|
@@ -208,8 +207,7 @@ class Model(torch.nn.Module):
|
|
|
208
207
|
|
|
209
208
|
@staticmethod
|
|
210
209
|
def is_hub_model(model: str) -> bool:
|
|
211
|
-
"""
|
|
212
|
-
Check if the provided model is an Ultralytics HUB model.
|
|
210
|
+
"""Check if the provided model is an Ultralytics HUB model.
|
|
213
211
|
|
|
214
212
|
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
|
215
213
|
identifier.
|
|
@@ -231,16 +229,15 @@ class Model(torch.nn.Module):
|
|
|
231
229
|
return model.startswith(f"{HUB_WEB_ROOT}/models/")
|
|
232
230
|
|
|
233
231
|
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
|
234
|
-
"""
|
|
235
|
-
Initialize a new model and infer the task type from model definitions.
|
|
232
|
+
"""Initialize a new model and infer the task type from model definitions.
|
|
236
233
|
|
|
237
|
-
Creates a new model instance based on the provided configuration file. Loads the model configuration, infers
|
|
238
|
-
|
|
234
|
+
Creates a new model instance based on the provided configuration file. Loads the model configuration, infers the
|
|
235
|
+
task type if not specified, and initializes the model using the appropriate class from the task map.
|
|
239
236
|
|
|
240
237
|
Args:
|
|
241
238
|
cfg (str): Path to the model configuration file in YAML format.
|
|
242
|
-
task (str
|
|
243
|
-
model (torch.nn.Module
|
|
239
|
+
task (str, optional): The specific task for the model. If None, it will be inferred from the config.
|
|
240
|
+
model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of creating
|
|
244
241
|
a new one.
|
|
245
242
|
verbose (bool): If True, displays model information during loading.
|
|
246
243
|
|
|
@@ -265,15 +262,14 @@ class Model(torch.nn.Module):
|
|
|
265
262
|
self.model_name = cfg
|
|
266
263
|
|
|
267
264
|
def _load(self, weights: str, task=None) -> None:
|
|
268
|
-
"""
|
|
269
|
-
Load a model from a checkpoint file or initialize it from a weights file.
|
|
265
|
+
"""Load a model from a checkpoint file or initialize it from a weights file.
|
|
270
266
|
|
|
271
|
-
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
|
|
272
|
-
|
|
267
|
+
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets up the
|
|
268
|
+
model, task, and related attributes based on the loaded weights.
|
|
273
269
|
|
|
274
270
|
Args:
|
|
275
271
|
weights (str): Path to the model weights file to be loaded.
|
|
276
|
-
task (str
|
|
272
|
+
task (str, optional): The task associated with the model. If None, it will be inferred from the model.
|
|
277
273
|
|
|
278
274
|
Raises:
|
|
279
275
|
FileNotFoundError: If the specified weights file does not exist or is inaccessible.
|
|
@@ -288,9 +284,9 @@ class Model(torch.nn.Module):
|
|
|
288
284
|
weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file
|
|
289
285
|
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolo11n -> yolo11n.pt
|
|
290
286
|
|
|
291
|
-
if
|
|
292
|
-
self.model, self.ckpt =
|
|
293
|
-
self.task = self.model.
|
|
287
|
+
if str(weights).rpartition(".")[-1] == "pt":
|
|
288
|
+
self.model, self.ckpt = load_checkpoint(weights)
|
|
289
|
+
self.task = self.model.task
|
|
294
290
|
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
|
295
291
|
self.ckpt_path = self.model.pt_path
|
|
296
292
|
else:
|
|
@@ -303,11 +299,10 @@ class Model(torch.nn.Module):
|
|
|
303
299
|
self.model_name = weights
|
|
304
300
|
|
|
305
301
|
def _check_is_pytorch_model(self) -> None:
|
|
306
|
-
"""
|
|
307
|
-
Check if the model is a PyTorch model and raise TypeError if it's not.
|
|
302
|
+
"""Check if the model is a PyTorch model and raise TypeError if it's not.
|
|
308
303
|
|
|
309
|
-
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
|
|
310
|
-
|
|
304
|
+
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that certain
|
|
305
|
+
operations that require a PyTorch model are only performed on compatible model types.
|
|
311
306
|
|
|
312
307
|
Raises:
|
|
313
308
|
TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
|
|
@@ -319,7 +314,7 @@ class Model(torch.nn.Module):
|
|
|
319
314
|
>>> model = Model("yolo11n.onnx")
|
|
320
315
|
>>> model._check_is_pytorch_model() # Raises TypeError
|
|
321
316
|
"""
|
|
322
|
-
pt_str = isinstance(self.model, (str, Path)) and
|
|
317
|
+
pt_str = isinstance(self.model, (str, Path)) and str(self.model).rpartition(".")[-1] == "pt"
|
|
323
318
|
pt_module = isinstance(self.model, torch.nn.Module)
|
|
324
319
|
if not (pt_module or pt_str):
|
|
325
320
|
raise TypeError(
|
|
@@ -330,13 +325,12 @@ class Model(torch.nn.Module):
|
|
|
330
325
|
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
|
|
331
326
|
)
|
|
332
327
|
|
|
333
|
-
def reset_weights(self) ->
|
|
334
|
-
"""
|
|
335
|
-
Reset the model's weights to their initial state.
|
|
328
|
+
def reset_weights(self) -> Model:
|
|
329
|
+
"""Reset the model's weights to their initial state.
|
|
336
330
|
|
|
337
331
|
This method iterates through all modules in the model and resets their parameters if they have a
|
|
338
|
-
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
|
|
339
|
-
|
|
332
|
+
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
|
|
333
|
+
to be updated during training.
|
|
340
334
|
|
|
341
335
|
Returns:
|
|
342
336
|
(Model): The instance of the class with reset weights.
|
|
@@ -356,15 +350,14 @@ class Model(torch.nn.Module):
|
|
|
356
350
|
p.requires_grad = True
|
|
357
351
|
return self
|
|
358
352
|
|
|
359
|
-
def load(self, weights:
|
|
360
|
-
"""
|
|
361
|
-
Load parameters from the specified weights file into the model.
|
|
353
|
+
def load(self, weights: str | Path = "yolo11n.pt") -> Model:
|
|
354
|
+
"""Load parameters from the specified weights file into the model.
|
|
362
355
|
|
|
363
356
|
This method supports loading weights from a file or directly from a weights object. It matches parameters by
|
|
364
357
|
name and shape and transfers them to the model.
|
|
365
358
|
|
|
366
359
|
Args:
|
|
367
|
-
weights (
|
|
360
|
+
weights (str | Path): Path to the weights file or a weights object.
|
|
368
361
|
|
|
369
362
|
Returns:
|
|
370
363
|
(Model): The instance of the class with loaded weights.
|
|
@@ -380,16 +373,15 @@ class Model(torch.nn.Module):
|
|
|
380
373
|
self._check_is_pytorch_model()
|
|
381
374
|
if isinstance(weights, (str, Path)):
|
|
382
375
|
self.overrides["pretrained"] = weights # remember the weights for DDP training
|
|
383
|
-
weights, self.ckpt =
|
|
376
|
+
weights, self.ckpt = load_checkpoint(weights)
|
|
384
377
|
self.model.load(weights)
|
|
385
378
|
return self
|
|
386
379
|
|
|
387
|
-
def save(self, filename:
|
|
388
|
-
"""
|
|
389
|
-
Save the current model state to a file.
|
|
380
|
+
def save(self, filename: str | Path = "saved_model.pt") -> None:
|
|
381
|
+
"""Save the current model state to a file.
|
|
390
382
|
|
|
391
|
-
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
|
|
392
|
-
|
|
383
|
+
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as the
|
|
384
|
+
date, Ultralytics version, license information, and a link to the documentation.
|
|
393
385
|
|
|
394
386
|
Args:
|
|
395
387
|
filename (str | Path): The name of the file to save the model to.
|
|
@@ -417,8 +409,7 @@ class Model(torch.nn.Module):
|
|
|
417
409
|
torch.save({**self.ckpt, **updates}, filename)
|
|
418
410
|
|
|
419
411
|
def info(self, detailed: bool = False, verbose: bool = True):
|
|
420
|
-
"""
|
|
421
|
-
Display model information.
|
|
412
|
+
"""Display model information.
|
|
422
413
|
|
|
423
414
|
This method provides an overview or detailed information about the model, depending on the arguments
|
|
424
415
|
passed. It can control the verbosity of the output and return the information as a list.
|
|
@@ -428,8 +419,8 @@ class Model(torch.nn.Module):
|
|
|
428
419
|
verbose (bool): If True, prints the information. If False, returns the information as a list.
|
|
429
420
|
|
|
430
421
|
Returns:
|
|
431
|
-
(
|
|
432
|
-
|
|
422
|
+
(list[str]): A list of strings containing various types of information about the model, including model
|
|
423
|
+
summary, layer details, and parameter counts. Empty if verbose is True.
|
|
433
424
|
|
|
434
425
|
Examples:
|
|
435
426
|
>>> model = Model("yolo11n.pt")
|
|
@@ -440,12 +431,11 @@ class Model(torch.nn.Module):
|
|
|
440
431
|
return self.model.info(detailed=detailed, verbose=verbose)
|
|
441
432
|
|
|
442
433
|
def fuse(self) -> None:
|
|
443
|
-
"""
|
|
444
|
-
Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
|
|
434
|
+
"""Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
|
|
445
435
|
|
|
446
|
-
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
|
|
447
|
-
|
|
448
|
-
|
|
436
|
+
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers into a
|
|
437
|
+
single layer. This fusion can significantly improve inference speed by reducing the number of operations and
|
|
438
|
+
memory accesses required during forward passes.
|
|
449
439
|
|
|
450
440
|
The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
|
|
451
441
|
bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
|
|
@@ -461,24 +451,23 @@ class Model(torch.nn.Module):
|
|
|
461
451
|
|
|
462
452
|
def embed(
|
|
463
453
|
self,
|
|
464
|
-
source:
|
|
454
|
+
source: str | Path | int | list | tuple | np.ndarray | torch.Tensor = None,
|
|
465
455
|
stream: bool = False,
|
|
466
456
|
**kwargs: Any,
|
|
467
457
|
) -> list:
|
|
468
|
-
"""
|
|
469
|
-
Generate image embeddings based on the provided source.
|
|
458
|
+
"""Generate image embeddings based on the provided source.
|
|
470
459
|
|
|
471
460
|
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
|
|
472
461
|
source. It allows customization of the embedding process through various keyword arguments.
|
|
473
462
|
|
|
474
463
|
Args:
|
|
475
|
-
source (str | Path | int |
|
|
476
|
-
|
|
464
|
+
source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for generating
|
|
465
|
+
embeddings. Can be a file path, URL, PIL image, numpy array, etc.
|
|
477
466
|
stream (bool): If True, predictions are streamed.
|
|
478
467
|
**kwargs (Any): Additional keyword arguments for configuring the embedding process.
|
|
479
468
|
|
|
480
469
|
Returns:
|
|
481
|
-
(
|
|
470
|
+
(list[torch.Tensor]): A list containing the image embeddings.
|
|
482
471
|
|
|
483
472
|
Examples:
|
|
484
473
|
>>> model = YOLO("yolo11n.pt")
|
|
@@ -492,30 +481,29 @@ class Model(torch.nn.Module):
|
|
|
492
481
|
|
|
493
482
|
def predict(
|
|
494
483
|
self,
|
|
495
|
-
source:
|
|
484
|
+
source: str | Path | int | Image.Image | list | tuple | np.ndarray | torch.Tensor = None,
|
|
496
485
|
stream: bool = False,
|
|
497
486
|
predictor=None,
|
|
498
487
|
**kwargs: Any,
|
|
499
|
-
) ->
|
|
500
|
-
"""
|
|
501
|
-
Performs predictions on the given image source using the YOLO model.
|
|
488
|
+
) -> list[Results]:
|
|
489
|
+
"""Perform predictions on the given image source using the YOLO model.
|
|
502
490
|
|
|
503
|
-
This method facilitates the prediction process, allowing various configurations through keyword arguments.
|
|
504
|
-
|
|
505
|
-
|
|
491
|
+
This method facilitates the prediction process, allowing various configurations through keyword arguments. It
|
|
492
|
+
supports predictions with custom predictors or the default predictor method. The method handles different types
|
|
493
|
+
of image sources and can operate in a streaming mode.
|
|
506
494
|
|
|
507
495
|
Args:
|
|
508
|
-
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor |
|
|
509
|
-
|
|
510
|
-
|
|
496
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
|
|
497
|
+
to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
|
|
498
|
+
torch tensors.
|
|
511
499
|
stream (bool): If True, treats the input source as a continuous stream for predictions.
|
|
512
|
-
predictor (BasePredictor
|
|
513
|
-
|
|
500
|
+
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions. If
|
|
501
|
+
None, the method uses a default predictor.
|
|
514
502
|
**kwargs (Any): Additional keyword arguments for configuring the prediction process.
|
|
515
503
|
|
|
516
504
|
Returns:
|
|
517
|
-
(
|
|
518
|
-
|
|
505
|
+
(list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
|
|
506
|
+
object.
|
|
519
507
|
|
|
520
508
|
Examples:
|
|
521
509
|
>>> model = YOLO("yolo11n.pt")
|
|
@@ -553,27 +541,26 @@ class Model(torch.nn.Module):
|
|
|
553
541
|
|
|
554
542
|
def track(
|
|
555
543
|
self,
|
|
556
|
-
source:
|
|
544
|
+
source: str | Path | int | list | tuple | np.ndarray | torch.Tensor = None,
|
|
557
545
|
stream: bool = False,
|
|
558
546
|
persist: bool = False,
|
|
559
547
|
**kwargs: Any,
|
|
560
|
-
) ->
|
|
561
|
-
"""
|
|
562
|
-
Conducts object tracking on the specified input source using the registered trackers.
|
|
548
|
+
) -> list[Results]:
|
|
549
|
+
"""Conduct object tracking on the specified input source using the registered trackers.
|
|
563
550
|
|
|
564
551
|
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
|
|
565
552
|
various input sources such as file paths or video streams, and supports customization through keyword arguments.
|
|
566
553
|
The method registers trackers if not already present and can persist them between calls.
|
|
567
554
|
|
|
568
555
|
Args:
|
|
569
|
-
source (
|
|
556
|
+
source (str | Path | int | list | tuple | np.ndarray | torch.Tensor, optional): Input source for object
|
|
570
557
|
tracking. Can be a file path, URL, or video stream.
|
|
571
558
|
stream (bool): If True, treats the input source as a continuous video stream.
|
|
572
559
|
persist (bool): If True, persists trackers between different calls to this method.
|
|
573
560
|
**kwargs (Any): Additional keyword arguments for configuring the tracking process.
|
|
574
561
|
|
|
575
562
|
Returns:
|
|
576
|
-
(
|
|
563
|
+
(list[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
|
|
577
564
|
|
|
578
565
|
Examples:
|
|
579
566
|
>>> model = YOLO("yolo11n.pt")
|
|
@@ -600,16 +587,15 @@ class Model(torch.nn.Module):
|
|
|
600
587
|
validator=None,
|
|
601
588
|
**kwargs: Any,
|
|
602
589
|
):
|
|
603
|
-
"""
|
|
604
|
-
Validate the model using a specified dataset and validation configuration.
|
|
590
|
+
"""Validate the model using a specified dataset and validation configuration.
|
|
605
591
|
|
|
606
592
|
This method facilitates the model validation process, allowing for customization through various settings. It
|
|
607
593
|
supports validation with a custom validator or the default validation approach. The method combines default
|
|
608
594
|
configurations, method-specific defaults, and user-provided arguments to configure the validation process.
|
|
609
595
|
|
|
610
596
|
Args:
|
|
611
|
-
validator (ultralytics.engine.validator.BaseValidator
|
|
612
|
-
validating the model.
|
|
597
|
+
validator (ultralytics.engine.validator.BaseValidator, optional): An instance of a custom validator class
|
|
598
|
+
for validating the model.
|
|
613
599
|
**kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
|
|
614
600
|
|
|
615
601
|
Returns:
|
|
@@ -631,31 +617,27 @@ class Model(torch.nn.Module):
|
|
|
631
617
|
self.metrics = validator.metrics
|
|
632
618
|
return validator.metrics
|
|
633
619
|
|
|
634
|
-
def benchmark(
|
|
635
|
-
|
|
636
|
-
**kwargs: Any,
|
|
637
|
-
):
|
|
638
|
-
"""
|
|
639
|
-
Benchmark the model across various export formats to evaluate performance.
|
|
620
|
+
def benchmark(self, data=None, format="", verbose=False, **kwargs: Any):
|
|
621
|
+
"""Benchmark the model across various export formats to evaluate performance.
|
|
640
622
|
|
|
641
|
-
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
623
|
+
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. It
|
|
624
|
+
uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured using
|
|
625
|
+
a combination of default configuration values, model-specific arguments, method-specific defaults, and any
|
|
626
|
+
additional user-provided keyword arguments.
|
|
645
627
|
|
|
646
628
|
Args:
|
|
629
|
+
data (str): Path to the dataset for benchmarking.
|
|
630
|
+
verbose (bool): Whether to print detailed benchmark information.
|
|
631
|
+
format (str): Export format name for specific benchmarking.
|
|
647
632
|
**kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. Common options include:
|
|
648
|
-
-
|
|
649
|
-
- imgsz (int | List[int]): Image size for benchmarking.
|
|
633
|
+
- imgsz (int | list[int]): Image size for benchmarking.
|
|
650
634
|
- half (bool): Whether to use half-precision (FP16) mode.
|
|
651
635
|
- int8 (bool): Whether to use int8 precision mode.
|
|
652
636
|
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
|
|
653
|
-
- verbose (bool): Whether to print detailed benchmark information.
|
|
654
|
-
- format (str): Export format name for specific benchmarking.
|
|
655
637
|
|
|
656
638
|
Returns:
|
|
657
|
-
(dict): A dictionary containing the results of the benchmarking process, including metrics for
|
|
658
|
-
|
|
639
|
+
(dict): A dictionary containing the results of the benchmarking process, including metrics for different
|
|
640
|
+
export formats.
|
|
659
641
|
|
|
660
642
|
Raises:
|
|
661
643
|
AssertionError: If the model is not a PyTorch model.
|
|
@@ -668,40 +650,42 @@ class Model(torch.nn.Module):
|
|
|
668
650
|
self._check_is_pytorch_model()
|
|
669
651
|
from ultralytics.utils.benchmarks import benchmark
|
|
670
652
|
|
|
653
|
+
from .exporter import export_formats
|
|
654
|
+
|
|
671
655
|
custom = {"verbose": False} # method defaults
|
|
672
656
|
args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
|
|
657
|
+
fmts = export_formats()
|
|
658
|
+
export_args = set(dict(zip(fmts["Argument"], fmts["Arguments"])).get(format, [])) - {"batch"}
|
|
659
|
+
export_kwargs = {k: v for k, v in args.items() if k in export_args}
|
|
673
660
|
return benchmark(
|
|
674
661
|
model=self,
|
|
675
|
-
data=
|
|
662
|
+
data=data, # if no 'data' argument passed set data=None for default datasets
|
|
676
663
|
imgsz=args["imgsz"],
|
|
677
|
-
half=args["half"],
|
|
678
|
-
int8=args["int8"],
|
|
679
664
|
device=args["device"],
|
|
680
|
-
verbose=
|
|
681
|
-
format=
|
|
665
|
+
verbose=verbose,
|
|
666
|
+
format=format,
|
|
667
|
+
**export_kwargs,
|
|
682
668
|
)
|
|
683
669
|
|
|
684
670
|
def export(
|
|
685
671
|
self,
|
|
686
672
|
**kwargs: Any,
|
|
687
673
|
) -> str:
|
|
688
|
-
"""
|
|
689
|
-
Export the model to a different format suitable for deployment.
|
|
674
|
+
"""Export the model to a different format suitable for deployment.
|
|
690
675
|
|
|
691
676
|
This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
|
|
692
677
|
purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
|
|
693
678
|
defaults, and any additional arguments provided.
|
|
694
679
|
|
|
695
680
|
Args:
|
|
696
|
-
**kwargs (Any): Arbitrary keyword arguments
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
simplify (bool): Simplify ONNX model.
|
|
681
|
+
**kwargs (Any): Arbitrary keyword arguments for export configuration. Common options include:
|
|
682
|
+
- format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
|
|
683
|
+
- half (bool): Export model in half-precision.
|
|
684
|
+
- int8 (bool): Export model in int8 precision.
|
|
685
|
+
- device (str): Device to run the export on.
|
|
686
|
+
- workspace (int): Maximum memory workspace size for TensorRT engines.
|
|
687
|
+
- nms (bool): Add Non-Maximum Suppression (NMS) module to model.
|
|
688
|
+
- simplify (bool): Simplify ONNX model.
|
|
705
689
|
|
|
706
690
|
Returns:
|
|
707
691
|
(str): The path to the exported model file.
|
|
@@ -734,32 +718,31 @@ class Model(torch.nn.Module):
|
|
|
734
718
|
trainer=None,
|
|
735
719
|
**kwargs: Any,
|
|
736
720
|
):
|
|
737
|
-
"""
|
|
738
|
-
Trains the model using the specified dataset and training configuration.
|
|
721
|
+
"""Train the model using the specified dataset and training configuration.
|
|
739
722
|
|
|
740
|
-
This method facilitates model training with a range of customizable settings. It supports training with a
|
|
741
|
-
|
|
742
|
-
|
|
723
|
+
This method facilitates model training with a range of customizable settings. It supports training with a custom
|
|
724
|
+
trainer or the default training approach. The method handles scenarios such as resuming training from a
|
|
725
|
+
checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
|
|
743
726
|
|
|
744
|
-
When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
|
|
745
|
-
|
|
746
|
-
|
|
727
|
+
When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training arguments and
|
|
728
|
+
warns if local arguments are provided. It checks for pip updates and combines default configurations,
|
|
729
|
+
method-specific defaults, and user-provided arguments to configure the training process.
|
|
747
730
|
|
|
748
731
|
Args:
|
|
749
|
-
trainer (BaseTrainer
|
|
732
|
+
trainer (BaseTrainer, optional): Custom trainer instance for model training. If None, uses default.
|
|
750
733
|
**kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
|
|
751
|
-
data (str): Path to dataset configuration file.
|
|
752
|
-
epochs (int): Number of training epochs.
|
|
753
|
-
|
|
754
|
-
imgsz (int): Input image size.
|
|
755
|
-
device (str): Device to run training on (e.g., 'cuda', 'cpu').
|
|
756
|
-
workers (int): Number of worker threads for data loading.
|
|
757
|
-
optimizer (str): Optimizer to use for training.
|
|
758
|
-
lr0 (float): Initial learning rate.
|
|
759
|
-
patience (int): Epochs to wait for no observable improvement for early stopping of training.
|
|
734
|
+
- data (str): Path to dataset configuration file.
|
|
735
|
+
- epochs (int): Number of training epochs.
|
|
736
|
+
- batch (int): Batch size for training.
|
|
737
|
+
- imgsz (int): Input image size.
|
|
738
|
+
- device (str): Device to run training on (e.g., 'cuda', 'cpu').
|
|
739
|
+
- workers (int): Number of worker threads for data loading.
|
|
740
|
+
- optimizer (str): Optimizer to use for training.
|
|
741
|
+
- lr0 (float): Initial learning rate.
|
|
742
|
+
- patience (int): Epochs to wait for no observable improvement for early stopping of training.
|
|
760
743
|
|
|
761
744
|
Returns:
|
|
762
|
-
(
|
|
745
|
+
(dict | None): Training metrics if available and training is successful; otherwise, None.
|
|
763
746
|
|
|
764
747
|
Examples:
|
|
765
748
|
>>> model = YOLO("yolo11n.pt")
|
|
@@ -773,6 +756,8 @@ class Model(torch.nn.Module):
|
|
|
773
756
|
|
|
774
757
|
checks.check_pip_update_available()
|
|
775
758
|
|
|
759
|
+
if isinstance(kwargs.get("pretrained", None), (str, Path)):
|
|
760
|
+
self.load(kwargs["pretrained"]) # load pretrained weights if provided
|
|
776
761
|
overrides = YAML.load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
|
|
777
762
|
custom = {
|
|
778
763
|
# NOTE: handle the case when 'cfg' includes 'data'.
|
|
@@ -780,7 +765,7 @@ class Model(torch.nn.Module):
|
|
|
780
765
|
"model": self.overrides["model"],
|
|
781
766
|
"task": self.task,
|
|
782
767
|
} # method defaults
|
|
783
|
-
args = {**overrides, **custom, **kwargs, "mode": "train"} #
|
|
768
|
+
args = {**overrides, **custom, **kwargs, "mode": "train", "session": self.session} # prioritizes rightmost args
|
|
784
769
|
if args.get("resume"):
|
|
785
770
|
args["resume"] = self.ckpt_path
|
|
786
771
|
|
|
@@ -789,13 +774,12 @@ class Model(torch.nn.Module):
|
|
|
789
774
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
|
790
775
|
self.model = self.trainer.model
|
|
791
776
|
|
|
792
|
-
self.trainer.hub_session = self.session # attach optional HUB session
|
|
793
777
|
self.trainer.train()
|
|
794
778
|
# Update model and cfg after training
|
|
795
779
|
if RANK in {-1, 0}:
|
|
796
780
|
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
|
797
|
-
self.model, self.ckpt =
|
|
798
|
-
self.overrides = self.model.args
|
|
781
|
+
self.model, self.ckpt = load_checkpoint(ckpt)
|
|
782
|
+
self.overrides = self._reset_ckpt_args(self.model.args)
|
|
799
783
|
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
|
800
784
|
return self.metrics
|
|
801
785
|
|
|
@@ -806,13 +790,12 @@ class Model(torch.nn.Module):
|
|
|
806
790
|
*args: Any,
|
|
807
791
|
**kwargs: Any,
|
|
808
792
|
):
|
|
809
|
-
"""
|
|
810
|
-
Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
|
|
793
|
+
"""Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
|
|
811
794
|
|
|
812
|
-
This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
795
|
+
This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. When Ray Tune
|
|
796
|
+
is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. Otherwise, it uses
|
|
797
|
+
the internal 'Tuner' class for tuning. The method combines default, overridden, and custom arguments to
|
|
798
|
+
configure the tuning process.
|
|
816
799
|
|
|
817
800
|
Args:
|
|
818
801
|
use_ray (bool): Whether to use Ray Tune for hyperparameter tuning. If False, uses internal tuning method.
|
|
@@ -847,17 +830,16 @@ class Model(torch.nn.Module):
|
|
|
847
830
|
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
|
848
831
|
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
|
849
832
|
|
|
850
|
-
def _apply(self, fn) ->
|
|
851
|
-
"""
|
|
852
|
-
Apply a function to model tensors that are not parameters or registered buffers.
|
|
833
|
+
def _apply(self, fn) -> Model:
|
|
834
|
+
"""Apply a function to model tensors that are not parameters or registered buffers.
|
|
853
835
|
|
|
854
836
|
This method extends the functionality of the parent class's _apply method by additionally resetting the
|
|
855
|
-
predictor and updating the device in the model's overrides. It's typically used for operations like
|
|
856
|
-
|
|
837
|
+
predictor and updating the device in the model's overrides. It's typically used for operations like moving the
|
|
838
|
+
model to a different device or changing its precision.
|
|
857
839
|
|
|
858
840
|
Args:
|
|
859
|
-
fn (Callable): A function to be applied to the model's tensors. This is typically a method like
|
|
860
|
-
|
|
841
|
+
fn (Callable): A function to be applied to the model's tensors. This is typically a method like to(), cpu(),
|
|
842
|
+
cuda(), half(), or float().
|
|
861
843
|
|
|
862
844
|
Returns:
|
|
863
845
|
(Model): The model instance with the function applied and updated attributes.
|
|
@@ -870,22 +852,21 @@ class Model(torch.nn.Module):
|
|
|
870
852
|
>>> model = model._apply(lambda t: t.cuda()) # Move model to GPU
|
|
871
853
|
"""
|
|
872
854
|
self._check_is_pytorch_model()
|
|
873
|
-
self = super()._apply(fn)
|
|
855
|
+
self = super()._apply(fn)
|
|
874
856
|
self.predictor = None # reset predictor as device may have changed
|
|
875
857
|
self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
|
|
876
858
|
return self
|
|
877
859
|
|
|
878
860
|
@property
|
|
879
|
-
def names(self) ->
|
|
880
|
-
"""
|
|
881
|
-
Retrieves the class names associated with the loaded model.
|
|
861
|
+
def names(self) -> dict[int, str]:
|
|
862
|
+
"""Retrieve the class names associated with the loaded model.
|
|
882
863
|
|
|
883
864
|
This property returns the class names if they are defined in the model. It checks the class names for validity
|
|
884
865
|
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
|
|
885
866
|
initialized, it sets it up before retrieving the names.
|
|
886
867
|
|
|
887
868
|
Returns:
|
|
888
|
-
(
|
|
869
|
+
(dict[int, str]): A dictionary of class names associated with the model, where keys are class indices and
|
|
889
870
|
values are the corresponding class names.
|
|
890
871
|
|
|
891
872
|
Raises:
|
|
@@ -901,14 +882,14 @@ class Model(torch.nn.Module):
|
|
|
901
882
|
if hasattr(self.model, "names"):
|
|
902
883
|
return check_class_names(self.model.names)
|
|
903
884
|
if not self.predictor: # export formats will not have predictor defined until predict() is called
|
|
904
|
-
|
|
905
|
-
|
|
885
|
+
predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
|
|
886
|
+
predictor.setup_model(model=self.model, verbose=False) # do not mess with self.predictor.model args
|
|
887
|
+
return predictor.model.names
|
|
906
888
|
return self.predictor.model.names
|
|
907
889
|
|
|
908
890
|
@property
|
|
909
891
|
def device(self) -> torch.device:
|
|
910
|
-
"""
|
|
911
|
-
Get the device on which the model's parameters are allocated.
|
|
892
|
+
"""Get the device on which the model's parameters are allocated.
|
|
912
893
|
|
|
913
894
|
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
|
|
914
895
|
applicable only to models that are instances of torch.nn.Module.
|
|
@@ -931,12 +912,11 @@ class Model(torch.nn.Module):
|
|
|
931
912
|
|
|
932
913
|
@property
|
|
933
914
|
def transforms(self):
|
|
934
|
-
"""
|
|
935
|
-
Retrieves the transformations applied to the input data of the loaded model.
|
|
915
|
+
"""Retrieve the transformations applied to the input data of the loaded model.
|
|
936
916
|
|
|
937
|
-
This property returns the transformations if they are defined in the model. The transforms
|
|
938
|
-
|
|
939
|
-
|
|
917
|
+
This property returns the transformations if they are defined in the model. The transforms typically include
|
|
918
|
+
preprocessing steps like resizing, normalization, and data augmentation that are applied to input data before it
|
|
919
|
+
is fed into the model.
|
|
940
920
|
|
|
941
921
|
Returns:
|
|
942
922
|
(object | None): The transform object of the model if available, otherwise None.
|
|
@@ -952,18 +932,17 @@ class Model(torch.nn.Module):
|
|
|
952
932
|
return self.model.transforms if hasattr(self.model, "transforms") else None
|
|
953
933
|
|
|
954
934
|
def add_callback(self, event: str, func) -> None:
|
|
955
|
-
"""
|
|
956
|
-
Add a callback function for a specified event.
|
|
935
|
+
"""Add a callback function for a specified event.
|
|
957
936
|
|
|
958
|
-
This method allows registering custom callback functions that are triggered on specific events during
|
|
959
|
-
|
|
960
|
-
|
|
937
|
+
This method allows registering custom callback functions that are triggered on specific events during model
|
|
938
|
+
operations such as training or inference. Callbacks provide a way to extend and customize the behavior of the
|
|
939
|
+
model at various stages of its lifecycle.
|
|
961
940
|
|
|
962
941
|
Args:
|
|
963
|
-
event (str): The name of the event to attach the callback to. Must be a valid event name recognized
|
|
964
|
-
|
|
965
|
-
func (Callable): The callback function to be registered. This function will be called when the
|
|
966
|
-
|
|
942
|
+
event (str): The name of the event to attach the callback to. Must be a valid event name recognized by the
|
|
943
|
+
Ultralytics framework.
|
|
944
|
+
func (Callable): The callback function to be registered. This function will be called when the specified
|
|
945
|
+
event occurs.
|
|
967
946
|
|
|
968
947
|
Raises:
|
|
969
948
|
ValueError: If the event name is not recognized or is invalid.
|
|
@@ -978,12 +957,11 @@ class Model(torch.nn.Module):
|
|
|
978
957
|
self.callbacks[event].append(func)
|
|
979
958
|
|
|
980
959
|
def clear_callback(self, event: str) -> None:
|
|
981
|
-
"""
|
|
982
|
-
Clears all callback functions registered for a specified event.
|
|
960
|
+
"""Clear all callback functions registered for a specified event.
|
|
983
961
|
|
|
984
|
-
This method removes all custom and default callback functions associated with the given event.
|
|
985
|
-
|
|
986
|
-
|
|
962
|
+
This method removes all custom and default callback functions associated with the given event. It resets the
|
|
963
|
+
callback list for the specified event to an empty list, effectively removing all registered callbacks for that
|
|
964
|
+
event.
|
|
987
965
|
|
|
988
966
|
Args:
|
|
989
967
|
event (str): The name of the event for which to clear the callbacks. This should be a valid event name
|
|
@@ -1006,8 +984,7 @@ class Model(torch.nn.Module):
|
|
|
1006
984
|
self.callbacks[event] = []
|
|
1007
985
|
|
|
1008
986
|
def reset_callbacks(self) -> None:
|
|
1009
|
-
"""
|
|
1010
|
-
Reset all callbacks to their default functions.
|
|
987
|
+
"""Reset all callbacks to their default functions.
|
|
1011
988
|
|
|
1012
989
|
This method reinstates the default callback functions for all events, removing any custom callbacks that were
|
|
1013
990
|
previously added. It iterates through all default callback events and replaces the current callbacks with the
|
|
@@ -1029,13 +1006,12 @@ class Model(torch.nn.Module):
|
|
|
1029
1006
|
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
|
1030
1007
|
|
|
1031
1008
|
@staticmethod
|
|
1032
|
-
def _reset_ckpt_args(args: dict) -> dict:
|
|
1033
|
-
"""
|
|
1034
|
-
Reset specific arguments when loading a PyTorch model checkpoint.
|
|
1009
|
+
def _reset_ckpt_args(args: dict[str, Any]) -> dict[str, Any]:
|
|
1010
|
+
"""Reset specific arguments when loading a PyTorch model checkpoint.
|
|
1035
1011
|
|
|
1036
|
-
This method filters the input arguments dictionary to retain only a specific set of keys that are
|
|
1037
|
-
|
|
1038
|
-
|
|
1012
|
+
This method filters the input arguments dictionary to retain only a specific set of keys that are considered
|
|
1013
|
+
important for model loading. It's used to ensure that only relevant arguments are preserved when loading a model
|
|
1014
|
+
from a checkpoint, discarding any unnecessary or potentially conflicting settings.
|
|
1039
1015
|
|
|
1040
1016
|
Args:
|
|
1041
1017
|
args (dict): A dictionary containing various model arguments and settings.
|
|
@@ -1058,12 +1034,11 @@ class Model(torch.nn.Module):
|
|
|
1058
1034
|
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
1059
1035
|
|
|
1060
1036
|
def _smart_load(self, key: str):
|
|
1061
|
-
"""
|
|
1062
|
-
Intelligently loads the appropriate module based on the model task.
|
|
1037
|
+
"""Intelligently load the appropriate module based on the model task.
|
|
1063
1038
|
|
|
1064
|
-
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
|
|
1065
|
-
|
|
1066
|
-
|
|
1039
|
+
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) based
|
|
1040
|
+
on the current task of the model and the provided key. It uses the task_map dictionary to determine the
|
|
1041
|
+
appropriate module to load for the specific task.
|
|
1067
1042
|
|
|
1068
1043
|
Args:
|
|
1069
1044
|
key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
|
|
@@ -1088,21 +1063,20 @@ class Model(torch.nn.Module):
|
|
|
1088
1063
|
|
|
1089
1064
|
@property
|
|
1090
1065
|
def task_map(self) -> dict:
|
|
1091
|
-
"""
|
|
1092
|
-
Provides a mapping from model tasks to corresponding classes for different modes.
|
|
1066
|
+
"""Provide a mapping from model tasks to corresponding classes for different modes.
|
|
1093
1067
|
|
|
1094
|
-
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
|
|
1095
|
-
|
|
1096
|
-
|
|
1068
|
+
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) to a
|
|
1069
|
+
nested dictionary. The nested dictionary contains mappings for different operational modes (model, trainer,
|
|
1070
|
+
validator, predictor) to their respective class implementations.
|
|
1097
1071
|
|
|
1098
|
-
The mapping allows for dynamic loading of appropriate classes based on the model's task and the
|
|
1099
|
-
|
|
1100
|
-
|
|
1072
|
+
The mapping allows for dynamic loading of appropriate classes based on the model's task and the desired
|
|
1073
|
+
operational mode. This facilitates a flexible and extensible architecture for handling various tasks and modes
|
|
1074
|
+
within the Ultralytics framework.
|
|
1101
1075
|
|
|
1102
1076
|
Returns:
|
|
1103
|
-
(
|
|
1104
|
-
|
|
1105
|
-
|
|
1077
|
+
(dict[str, dict[str, Any]]): A dictionary mapping task names to nested dictionaries. Each nested dictionary
|
|
1078
|
+
contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective class
|
|
1079
|
+
implementations for that task.
|
|
1106
1080
|
|
|
1107
1081
|
Examples:
|
|
1108
1082
|
>>> model = Model("yolo11n.pt")
|
|
@@ -1113,8 +1087,7 @@ class Model(torch.nn.Module):
|
|
|
1113
1087
|
raise NotImplementedError("Please provide task map for your model!")
|
|
1114
1088
|
|
|
1115
1089
|
def eval(self):
|
|
1116
|
-
"""
|
|
1117
|
-
Sets the model to evaluation mode.
|
|
1090
|
+
"""Sets the model to evaluation mode.
|
|
1118
1091
|
|
|
1119
1092
|
This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
|
|
1120
1093
|
that behave differently during training and evaluation. In evaluation mode, these layers use running statistics
|
|
@@ -1132,8 +1105,7 @@ class Model(torch.nn.Module):
|
|
|
1132
1105
|
return self
|
|
1133
1106
|
|
|
1134
1107
|
def __getattr__(self, name):
|
|
1135
|
-
"""
|
|
1136
|
-
Enable accessing model attributes directly through the Model class.
|
|
1108
|
+
"""Enable accessing model attributes directly through the Model class.
|
|
1137
1109
|
|
|
1138
1110
|
This method provides a way to access attributes of the underlying model directly through the Model class
|
|
1139
1111
|
instance. It first checks if the requested attribute is 'model', in which case it returns the model from
|