dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions exactly as they appear in their public registry.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
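For readers who want to reproduce this file-level comparison locally, the following is a minimal sketch using only the Python standard library. It assumes both wheels have already been downloaded into the working directory; the filenames below follow standard wheel naming and may need adjusting to the actual downloaded files.

```python
# Minimal sketch: compare the file lists of the two wheels locally.
# Assumes both .whl files have been downloaded; the paths below are
# illustrative and should be adjusted to where the wheels were saved.
from zipfile import ZipFile

OLD_WHL = "dgenerate_ultralytics_headless-8.3.214-py3-none-any.whl"
NEW_WHL = "dgenerate_ultralytics_headless-8.4.7-py3-none-any.whl"

with ZipFile(OLD_WHL) as old_whl, ZipFile(NEW_WHL) as new_whl:
    old_files, new_files = set(old_whl.namelist()), set(new_whl.namelist())

print("Added files:")
for name in sorted(new_files - old_files):
    print(f"  + {name}")
print("Removed files:")
for name in sorted(old_files - new_files):
    print(f"  - {name}")
print(f"Files present in both wheels: {len(old_files & new_files)}")
```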
ultralytics/engine/model.py
CHANGED
@@ -27,12 +27,11 @@ from ultralytics.utils import (
 
 
 class Model(torch.nn.Module):
-    """
-    A base class for implementing YOLO models, unifying APIs across different model types.
+    """A base class for implementing YOLO models, unifying APIs across different model types.
 
-    This class provides a common interface for various operations related to YOLO models, such as training,
-    validation, prediction, exporting, and benchmarking. It handles different types of models, including those
-    loaded from local files, Ultralytics HUB, or Triton Server.
+    This class provides a common interface for various operations related to YOLO models, such as training, validation,
+    prediction, exporting, and benchmarking. It handles different types of models, including those loaded from local
+    files, Ultralytics HUB, or Triton Server.
 
     Attributes:
         callbacks (dict): A dictionary of callback functions for various events during model operations.
@@ -72,7 +71,7 @@ class Model(torch.nn.Module):
 
     Examples:
         >>> from ultralytics import YOLO
-        >>> model = YOLO("yolo11n.pt")
+        >>> model = YOLO("yolo26n.pt")
         >>> results = model.predict("image.jpg")
         >>> model.train(data="coco8.yaml", epochs=3)
         >>> metrics = model.val()
@@ -81,34 +80,26 @@ class Model(torch.nn.Module):
 
     def __init__(
        self,
-        model: str | Path | Model = "yolo11n.pt",
-        task: str = None,
+        model: str | Path | Model = "yolo26n.pt",
+        task: str | None = None,
         verbose: bool = False,
     ) -> None:
-        """
-        Initialize a new instance of the YOLO model class.
+        """Initialize a new instance of the YOLO model class.
 
-        This constructor sets up the model based on the provided model path or name. It handles various types of
-        model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
-        initializes several important attributes of the model and prepares it for operations like training,
-        prediction, or export.
+        This constructor sets up the model based on the provided model path or name. It handles various types of model
+        sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
+        important attributes of the model and prepares it for operations like training, prediction, or export.
 
         Args:
-            model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a
-                model name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
+            model (str | Path | Model): Path or name of the model to load or create. Can be a local file path, a model
+                name from Ultralytics HUB, a Triton Server model, or an already initialized Model instance.
             task (str, optional): The specific task for the model. If None, it will be inferred from the config.
-            verbose (bool): If True, enables verbose output during the model's initialization and subsequent
-                operations.
+            verbose (bool): If True, enables verbose output during the model's initialization and subsequent operations.
 
         Raises:
             FileNotFoundError: If the specified model file does not exist or is inaccessible.
             ValueError: If the model file or configuration is invalid or unsupported.
             ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
-
-        Examples:
-            >>> model = Model("yolo11n.pt")
-            >>> model = Model("path/to/model.yaml", task="detect")
-            >>> model = Model("hub_model", verbose=True)
         """
         if isinstance(model, Model):
             self.__dict__ = model.__dict__  # accepts an already initialized Model
@@ -161,25 +152,24 @@ class Model(torch.nn.Module):
         stream: bool = False,
         **kwargs: Any,
     ) -> list:
-        """
-        Alias for the predict method, enabling the model instance to be callable for predictions.
+        """Alias for the predict method, enabling the model instance to be callable for predictions.
 
-        This method simplifies the process of making predictions by allowing the model instance to be called
-        directly with the required arguments.
+        This method simplifies the process of making predictions by allowing the model instance to be called directly
+        with the required arguments.
 
         Args:
-            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of
-                the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
-                tensor, or a list/tuple of these.
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
+                to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch tensor, or a list/tuple
+                of these.
             stream (bool): If True, treat the input source as a continuous stream for predictions.
             **kwargs (Any): Additional keyword arguments to configure the prediction process.
 
         Returns:
-            (list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
-                Results object.
+            (list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
+                object.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
            >>> results = model("https://ultralytics.com/images/bus.jpg")
            >>> for r in results:
            ...     print(f"Detected {len(r)} objects in image")
@@ -188,11 +178,10 @@ class Model(torch.nn.Module):
 
     @staticmethod
     def is_triton_model(model: str) -> bool:
-        """
-        Check if the given model string is a Triton Server URL.
+        """Check if the given model string is a Triton Server URL.
 
-        This static method determines whether the provided model string represents a valid Triton Server URL by
-        parsing its components using urllib.parse.urlsplit().
+        This static method determines whether the provided model string represents a valid Triton Server URL by parsing
+        its components using urllib.parse.urlsplit().
 
         Args:
             model (str): The model string to be checked.
@@ -203,7 +192,7 @@ class Model(torch.nn.Module):
         Examples:
             >>> Model.is_triton_model("http://localhost:8000/v2/models/yolo11n")
             True
-            >>> Model.is_triton_model("yolo11n.pt")
+            >>> Model.is_triton_model("yolo26n.pt")
             False
         """
         from urllib.parse import urlsplit
@@ -213,8 +202,7 @@ class Model(torch.nn.Module):
 
     @staticmethod
     def is_hub_model(model: str) -> bool:
-        """
-        Check if the provided model is an Ultralytics HUB model.
+        """Check if the provided model is an Ultralytics HUB model.
 
         This static method determines whether the given model string represents a valid Ultralytics HUB model
         identifier.
@@ -228,7 +216,7 @@ class Model(torch.nn.Module):
         Examples:
             >>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
             True
-            >>> Model.is_hub_model("yolo11n.pt")
+            >>> Model.is_hub_model("yolo26n.pt")
             False
         """
         from ultralytics.hub import HUB_WEB_ROOT
@@ -236,17 +224,16 @@ class Model(torch.nn.Module):
         return model.startswith(f"{HUB_WEB_ROOT}/models/")
 
     def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
-        """
-        Initialize a new model and infer the task type from model definitions.
+        """Initialize a new model and infer the task type from model definitions.
 
-        Creates a new model instance based on the provided configuration file. Loads the model configuration, infers
-        the task type if not specified, and initializes the model using the appropriate class from the task map.
+        Creates a new model instance based on the provided configuration file. Loads the model configuration, infers the
+        task type if not specified, and initializes the model using the appropriate class from the task map.
 
         Args:
             cfg (str): Path to the model configuration file in YAML format.
             task (str, optional): The specific task for the model. If None, it will be inferred from the config.
-            model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of
-                creating a new one.
+            model (torch.nn.Module, optional): A custom model instance. If provided, it will be used instead of creating
+                a new one.
             verbose (bool): If True, displays model information during loading.
 
         Raises:
@@ -255,7 +242,7 @@ class Model(torch.nn.Module):
 
         Examples:
             >>> model = Model()
-            >>> model._new("yolo11n.yaml", task="detect", verbose=True)
+            >>> model._new("yolo26n.yaml", task="detect", verbose=True)
         """
         cfg_dict = yaml_model_load(cfg)
         self.cfg = cfg
@@ -270,11 +257,10 @@ class Model(torch.nn.Module):
         self.model_name = cfg
 
     def _load(self, weights: str, task=None) -> None:
-        """
-        Load a model from a checkpoint file or initialize it from a weights file.
+        """Load a model from a checkpoint file or initialize it from a weights file.
 
-        This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
-        up the model, task, and related attributes based on the loaded weights.
+        This method handles loading models from either .pt checkpoint files or other weight file formats. It sets up the
+        model, task, and related attributes based on the loaded weights.
 
         Args:
             weights (str): Path to the model weights file to be loaded.
@@ -286,12 +272,12 @@ class Model(torch.nn.Module):
 
         Examples:
             >>> model = Model()
-            >>> model._load("yolo11n.pt")
+            >>> model._load("yolo26n.pt")
             >>> model._load("path/to/weights.pth", task="detect")
         """
-        if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
+        if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://", "ul://")):
             weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"])  # download and return local file
-        weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolo11n -> yolo11n.pt
+        weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolo26 -> yolo26n.pt
 
         if str(weights).rpartition(".")[-1] == "pt":
             self.model, self.ckpt = load_checkpoint(weights)
@@ -308,18 +294,17 @@ class Model(torch.nn.Module):
         self.model_name = weights
 
     def _check_is_pytorch_model(self) -> None:
-        """
-        Check if the model is a PyTorch model and raise TypeError if it's not.
+        """Check if the model is a PyTorch model and raise TypeError if it's not.
 
-        This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
-        certain operations that require a PyTorch model are only performed on compatible model types.
+        This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that certain
+        operations that require a PyTorch model are only performed on compatible model types.
 
         Raises:
             TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
                 information about supported model formats and operations.
 
         Examples:
-            >>> model = Model("yolo11n.pt")
+            >>> model = Model("yolo26n.pt")
             >>> model._check_is_pytorch_model()  # No error raised
             >>> model = Model("yolo11n.onnx")
             >>> model._check_is_pytorch_model()  # Raises TypeError
@@ -336,12 +321,11 @@ class Model(torch.nn.Module):
         )
 
     def reset_weights(self) -> Model:
-        """
-        Reset the model's weights to their initial state.
+        """Reset the model's weights to their initial state.
 
         This method iterates through all modules in the model and resets their parameters if they have a
-        'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
-        enabling them to be updated during training.
+        'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
+        to be updated during training.
 
         Returns:
             (Model): The instance of the class with reset weights.
@@ -350,7 +334,7 @@ class Model(torch.nn.Module):
             AssertionError: If the model is not a PyTorch model.
 
         Examples:
-            >>> model = Model("yolo11n.pt")
+            >>> model = Model("yolo26n.pt")
             >>> model.reset_weights()
         """
         self._check_is_pytorch_model()
@@ -361,9 +345,8 @@ class Model(torch.nn.Module):
             p.requires_grad = True
         return self
 
-    def load(self, weights: str | Path = "yolo11n.pt") -> Model:
-        """
-        Load parameters from the specified weights file into the model.
+    def load(self, weights: str | Path = "yolo26n.pt") -> Model:
+        """Load parameters from the specified weights file into the model.
 
         This method supports loading weights from a file or directly from a weights object. It matches parameters by
         name and shape and transfers them to the model.
@@ -379,7 +362,7 @@ class Model(torch.nn.Module):
 
         Examples:
             >>> model = Model()
-            >>> model.load("yolo11n.pt")
+            >>> model.load("yolo26n.pt")
             >>> model.load(Path("path/to/weights.pt"))
         """
         self._check_is_pytorch_model()
@@ -390,11 +373,10 @@ class Model(torch.nn.Module):
         return self
 
     def save(self, filename: str | Path = "saved_model.pt") -> None:
-        """
-        Save the current model state to a file.
+        """Save the current model state to a file.
 
-        This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
-        the date, Ultralytics version, license information, and a link to the documentation.
+        This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as the
+        date, Ultralytics version, license information, and a link to the documentation.
 
         Args:
             filename (str | Path): The name of the file to save the model to.
@@ -403,7 +385,7 @@ class Model(torch.nn.Module):
             AssertionError: If the model is not a PyTorch model.
 
         Examples:
-            >>> model = Model("yolo11n.pt")
+            >>> model = Model("yolo26n.pt")
             >>> model.save("my_model.pt")
         """
         self._check_is_pytorch_model()
@@ -421,9 +403,8 @@ class Model(torch.nn.Module):
         }
         torch.save({**self.ckpt, **updates}, filename)
 
-    def info(self, detailed: bool = False, verbose: bool = True):
-        """
-        Display model information.
+    def info(self, detailed: bool = False, verbose: bool = True, imgsz: int | list[int, int] = 640):
+        """Display model information.
 
         This method provides an overview or detailed information about the model, depending on the arguments
         passed. It can control the verbosity of the output and return the information as a list.
@@ -431,33 +412,33 @@ class Model(torch.nn.Module):
         Args:
             detailed (bool): If True, shows detailed information about the model layers and parameters.
             verbose (bool): If True, prints the information. If False, returns the information as a list.
+            imgsz (int | list[int, int]): Input image size used for FLOPs calculation.
 
         Returns:
-            (list[str]): A list of strings containing various types of information about the model, including
-                model summary, layer details, and parameter counts. Empty if verbose is True.
+            (list[str]): A list of strings containing various types of information about the model, including model
+                summary, layer details, and parameter counts. Empty if verbose is True.
 
         Examples:
-            >>> model = Model("yolo11n.pt")
+            >>> model = Model("yolo26n.pt")
             >>> model.info()  # Prints model summary
             >>> info_list = model.info(detailed=True, verbose=False)  # Returns detailed info as a list
         """
         self._check_is_pytorch_model()
-        return self.model.info(detailed=detailed, verbose=verbose)
+        return self.model.info(detailed=detailed, verbose=verbose, imgsz=imgsz)
 
     def fuse(self) -> None:
-        """
-        Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
+        """Fuse Conv2d and BatchNorm2d layers in the model for optimized inference.
 
-        This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
-        into a single layer. This fusion can significantly improve inference speed by reducing the number of
-        operations and memory accesses required during forward passes.
+        This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers into a
+        single layer. This fusion can significantly improve inference speed by reducing the number of operations and
+        memory accesses required during forward passes.
 
         The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
         bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
        performs both convolution and normalization in one step.
 
         Examples:
-            >>> model = Model("yolo11n.pt")
+            >>> model = Model("yolo26n.pt")
             >>> model.fuse()
             >>> # Model is now fused and ready for optimized inference
         """
@@ -470,15 +451,14 @@ class Model(torch.nn.Module):
         stream: bool = False,
         **kwargs: Any,
     ) -> list:
-        """
-        Generate image embeddings based on the provided source.
+        """Generate image embeddings based on the provided source.
 
         This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
         source. It allows customization of the embedding process through various keyword arguments.
 
         Args:
-            source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for
-                generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
+            source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the image for generating
+                embeddings. Can be a file path, URL, PIL image, numpy array, etc.
             stream (bool): If True, predictions are streamed.
             **kwargs (Any): Additional keyword arguments for configuring the embedding process.
 
@@ -486,7 +466,7 @@ class Model(torch.nn.Module):
             (list[torch.Tensor]): A list containing the image embeddings.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> image = "https://ultralytics.com/images/bus.jpg"
             >>> embeddings = model.embed(image)
             >>> print(embeddings[0].shape)
@@ -502,28 +482,27 @@ class Model(torch.nn.Module):
         predictor=None,
         **kwargs: Any,
     ) -> list[Results]:
-        """
-        Perform predictions on the given image source using the YOLO model.
+        """Perform predictions on the given image source using the YOLO model.
 
-        This method facilitates the prediction process, allowing various configurations through keyword arguments.
-        It supports predictions with custom predictors or the default predictor method. The method handles
-        different types of image sources and can operate in a streaming mode.
+        This method facilitates the prediction process, allowing various configurations through keyword arguments. It
+        supports predictions with custom predictors or the default predictor method. The method handles different types
+        of image sources and can operate in a streaming mode.
 
         Args:
-            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
-                of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
-                images, numpy arrays, and torch tensors.
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image(s)
+                to make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+                torch tensors.
             stream (bool): If True, treats the input source as a continuous stream for predictions.
-            predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
-                If None, the method uses a default predictor.
+            predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions. If
+                None, the method uses a default predictor.
             **kwargs (Any): Additional keyword arguments for configuring the prediction process.
 
         Returns:
-            (list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
-                Results object.
+            (list[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a Results
+                object.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> results = model.predict(source="path/to/image.jpg", conf=0.25)
             >>> for r in results:
             ...     print(r.boxes.data)  # print detection bounding boxes
@@ -545,7 +524,7 @@ class Model(torch.nn.Module):
         args = {**self.overrides, **custom, **kwargs}  # highest priority args on the right
         prompts = args.pop("prompts", None)  # for SAM-type models
 
-        if not self.predictor:
+        if not self.predictor or self.predictor.args.device != args.get("device", self.predictor.args.device):
             self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks)
             self.predictor.setup_model(model=self.model, verbose=is_cli)
         else:  # only update args if predictor is already setup
@@ -563,8 +542,7 @@ class Model(torch.nn.Module):
         persist: bool = False,
         **kwargs: Any,
     ) -> list[Results]:
-        """
-        Conduct object tracking on the specified input source using the registered trackers.
+        """Conduct object tracking on the specified input source using the registered trackers.
 
         This method performs object tracking using the model's predictors and optionally registered trackers. It handles
         various input sources such as file paths or video streams, and supports customization through keyword arguments.
@@ -581,7 +559,7 @@ class Model(torch.nn.Module):
             (list[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> results = model.track(source="path/to/video.mp4", show=True)
             >>> for r in results:
             ...     print(r.boxes.id)  # print tracking IDs
@@ -605,8 +583,7 @@ class Model(torch.nn.Module):
         validator=None,
         **kwargs: Any,
     ):
-        """
-        Validate the model using a specified dataset and validation configuration.
+        """Validate the model using a specified dataset and validation configuration.
 
         This method facilitates the model validation process, allowing for customization through various settings. It
         supports validation with a custom validator or the default validation approach. The method combines default
@@ -624,7 +601,7 @@ class Model(torch.nn.Module):
             AssertionError: If the model is not a PyTorch model.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> results = model.val(data="coco8.yaml", imgsz=640)
             >>> print(results.box.map)  # Print mAP50-95
         """
@@ -637,13 +614,12 @@ class Model(torch.nn.Module):
         return validator.metrics
 
     def benchmark(self, data=None, format="", verbose=False, **kwargs: Any):
-        """
-        Benchmark the model across various export formats to evaluate performance.
+        """Benchmark the model across various export formats to evaluate performance.
 
-        This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
-        It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
-        configured using a combination of default configuration values, model-specific arguments, method-specific
-        defaults, and any additional user-provided keyword arguments.
+        This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. It
+        uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured using
+        a combination of default configuration values, model-specific arguments, method-specific defaults, and any
+        additional user-provided keyword arguments.
 
         Args:
             data (str): Path to the dataset for benchmarking.
@@ -656,14 +632,14 @@ class Model(torch.nn.Module):
                 - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
 
         Returns:
-            (dict): A dictionary containing the results of the benchmarking process, including metrics for
-                different export formats.
+            (dict): A dictionary containing the results of the benchmarking process, including metrics for different
+                export formats.
 
         Raises:
             AssertionError: If the model is not a PyTorch model.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True)
             >>> print(results)
         """
@@ -691,23 +667,21 @@ class Model(torch.nn.Module):
         self,
         **kwargs: Any,
     ) -> str:
-        """
-        Export the model to a different format suitable for deployment.
+        """Export the model to a different format suitable for deployment.
 
         This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
         purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
         defaults, and any additional arguments provided.
 
         Args:
-            **kwargs (Any): Arbitrary keyword arguments to customize the export process. These are combined with
-                the model's overrides and method defaults. Common arguments include:
-                format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
-                half (bool): Export model in half-precision.
-                int8 (bool): Export model in int8 precision.
-                device (str): Device to run the export on.
-                workspace (int): Maximum memory workspace size for TensorRT engines.
-                nms (bool): Add Non-Maximum Suppression (NMS) module to model.
-                simplify (bool): Simplify ONNX model.
+            **kwargs (Any): Arbitrary keyword arguments for export configuration. Common options include:
+                - format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
+                - half (bool): Export model in half-precision.
+                - int8 (bool): Export model in int8 precision.
+                - device (str): Device to run the export on.
+                - workspace (int): Maximum memory workspace size for TensorRT engines.
+                - nms (bool): Add Non-Maximum Suppression (NMS) module to model.
+                - simplify (bool): Simplify ONNX model.
 
         Returns:
             (str): The path to the exported model file.
@@ -718,7 +692,7 @@ class Model(torch.nn.Module):
             RuntimeError: If the export process fails due to errors.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> model.export(format="onnx", dynamic=True, simplify=True)
             'path/to/exported/model.onnx'
         """
@@ -740,35 +714,35 @@ class Model(torch.nn.Module):
         trainer=None,
         **kwargs: Any,
     ):
-        """
-        Train the model using the specified dataset and training configuration.
+        """Train the model using the specified dataset and training configuration.
 
-        This method facilitates model training with a range of customizable settings. It supports training with a
-        custom trainer or the default training approach. The method handles scenarios such as resuming training
-        from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
+        This method facilitates model training with a range of customizable settings. It supports training with a custom
+        trainer or the default training approach. The method handles scenarios such as resuming training from a
+        checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
 
-        When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
-        arguments and warns if local arguments are provided. It checks for pip updates and combines default
-        configurations, method-specific defaults, and user-provided arguments to configure the training process.
+        When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training arguments and
+        warns if local arguments are provided. It checks for pip updates and combines default configurations,
+        method-specific defaults, and user-provided arguments to configure the training process.
 
         Args:
             trainer (BaseTrainer, optional): Custom trainer instance for model training. If None, uses default.
             **kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
-                data (str): Path to dataset configuration file.
-                epochs (int): Number of training epochs.
-                batch (int): Batch size for training.
-                imgsz (int): Input image size.
-                device (str): Device to run training on (e.g., 'cuda', 'cpu').
-                workers (int): Number of worker threads for data loading.
-                optimizer (str): Optimizer to use for training.
-                lr0 (float): Initial learning rate.
-                patience (int): Epochs to wait for no observable improvement for early stopping of training.
+                - data (str): Path to dataset configuration file.
+                - epochs (int): Number of training epochs.
+                - batch (int): Batch size for training.
+                - imgsz (int): Input image size.
+                - device (str): Device to run training on (e.g., 'cuda', 'cpu').
+                - workers (int): Number of worker threads for data loading.
+                - optimizer (str): Optimizer to use for training.
+                - lr0 (float): Initial learning rate.
+                - patience (int): Epochs to wait for no observable improvement for early stopping of training.
+                - augmentations (list[Callable]): List of augmentation functions to apply during training.
 
         Returns:
             (dict | None): Training metrics if available and training is successful; otherwise, None.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> results = model.train(data="coco8.yaml", epochs=3)
         """
         self._check_is_pytorch_model()
@@ -813,13 +787,12 @@ class Model(torch.nn.Module):
         *args: Any,
         **kwargs: Any,
     ):
-        """
-        Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
+        """Conduct hyperparameter tuning for the model, with an option to use Ray Tune.
 
-        This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method.
-        When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner
-        module. Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default,
-        overridden, and custom arguments to configure the tuning process.
+        This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. When Ray Tune
+        is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. Otherwise, it uses
+        the internal 'Tuner' class for tuning. The method combines default, overridden, and custom arguments to
+        configure the tuning process.
 
         Args:
             use_ray (bool): Whether to use Ray Tune for hyperparameter tuning. If False, uses internal tuning method.
@@ -835,7 +808,7 @@ class Model(torch.nn.Module):
             TypeError: If the model is not a PyTorch model.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> results = model.tune(data="coco8.yaml", iterations=5)
             >>> print(results)
 
@@ -852,19 +825,18 @@ class Model(torch.nn.Module):
 
             custom = {}  # method defaults
             args = {**self.overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
-            return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
+            return Tuner(args=args, _callbacks=self.callbacks)(iterations=iterations)
 
     def _apply(self, fn) -> Model:
-        """
-        Apply a function to model tensors that are not parameters or registered buffers.
+        """Apply a function to model tensors that are not parameters or registered buffers.
 
         This method extends the functionality of the parent class's _apply method by additionally resetting the
-        predictor and updating the device in the model's overrides. It's typically used for operations like
-        moving the model to a different device or changing its precision.
+        predictor and updating the device in the model's overrides. It's typically used for operations like moving the
+        model to a different device or changing its precision.
 
         Args:
-            fn (Callable): A function to be applied to the model's tensors. This is typically a method like
-                to(), cpu(), cuda(), half(), or float().
+            fn (Callable): A function to be applied to the model's tensors. This is typically a method like to(), cpu(),
+                cuda(), half(), or float().
 
         Returns:
             (Model): The model instance with the function applied and updated attributes.
@@ -873,19 +845,18 @@ class Model(torch.nn.Module):
             AssertionError: If the model is not a PyTorch model.
 
         Examples:
-            >>> model = Model("yolo11n.pt")
+            >>> model = Model("yolo26n.pt")
             >>> model = model._apply(lambda t: t.cuda())  # Move model to GPU
         """
         self._check_is_pytorch_model()
-        self = super()._apply(fn)  # noqa
+        self = super()._apply(fn)
         self.predictor = None  # reset predictor as device may have changed
         self.overrides["device"] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
         return self
 
     @property
     def names(self) -> dict[int, str]:
-        """
-        Retrieve the class names associated with the loaded model.
+        """Retrieve the class names associated with the loaded model.
 
         This property returns the class names if they are defined in the model. It checks the class names for validity
        using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
@@ -899,7 +870,7 @@ class Model(torch.nn.Module):
             AttributeError: If the model or predictor does not have a 'names' attribute.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> print(model.names)
             {0: 'person', 1: 'bicycle', 2: 'car', ...}
         """
@@ -915,8 +886,7 @@ class Model(torch.nn.Module):
 
     @property
     def device(self) -> torch.device:
-        """
-        Get the device on which the model's parameters are allocated.
+        """Get the device on which the model's parameters are allocated.
 
         This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
        applicable only to models that are instances of torch.nn.Module.
@@ -928,7 +898,7 @@ class Model(torch.nn.Module):
             AttributeError: If the model is not a torch.nn.Module instance.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> print(model.device)
             device(type='cuda', index=0)  # if CUDA is available
             >>> model = model.to("cpu")
@@ -939,18 +909,17 @@ class Model(torch.nn.Module):
 
     @property
     def transforms(self):
-        """
-        Retrieve the transformations applied to the input data of the loaded model.
+        """Retrieve the transformations applied to the input data of the loaded model.
 
-        This property returns the transformations if they are defined in the model. The transforms
-        typically include preprocessing steps like resizing, normalization, and data augmentation
-        that are applied to input data before it is fed into the model.
+        This property returns the transformations if they are defined in the model. The transforms typically include
+        preprocessing steps like resizing, normalization, and data augmentation that are applied to input data before it
+        is fed into the model.
 
         Returns:
             (object | None): The transform object of the model if available, otherwise None.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> transforms = model.transforms
             >>> if transforms:
             ...     print(f"Model transforms: {transforms}")
@@ -960,18 +929,17 @@ class Model(torch.nn.Module):
         return self.model.transforms if hasattr(self.model, "transforms") else None
 
     def add_callback(self, event: str, func) -> None:
-        """
-        Add a callback function for a specified event.
+        """Add a callback function for a specified event.
 
-        This method allows registering custom callback functions that are triggered on specific events during
-        model operations such as training or inference. Callbacks provide a way to extend and customize the
-        behavior of the model at various stages of its lifecycle.
+        This method allows registering custom callback functions that are triggered on specific events during model
+        operations such as training or inference. Callbacks provide a way to extend and customize the behavior of the
+        model at various stages of its lifecycle.
 
         Args:
-            event (str): The name of the event to attach the callback to. Must be a valid event name recognized
-                by the Ultralytics framework.
-            func (Callable): The callback function to be registered. This function will be called when the
-                specified event occurs.
+            event (str): The name of the event to attach the callback to. Must be a valid event name recognized by the
+                Ultralytics framework.
+            func (Callable): The callback function to be registered. This function will be called when the specified
+                event occurs.
 
         Raises:
             ValueError: If the event name is not recognized or is invalid.
@@ -979,26 +947,25 @@ class Model(torch.nn.Module):
         Examples:
             >>> def on_train_start(trainer):
             ...     print("Training is starting!")
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> model.add_callback("on_train_start", on_train_start)
             >>> model.train(data="coco8.yaml", epochs=1)
         """
         self.callbacks[event].append(func)
 
     def clear_callback(self, event: str) -> None:
-        """
-        Clear all callback functions registered for a specified event.
+        """Clear all callback functions registered for a specified event.
 
-        This method removes all custom and default callback functions associated with the given event.
-        It resets the callback list for the specified event to an empty list, effectively removing all
-        registered callbacks for that event.
+        This method removes all custom and default callback functions associated with the given event. It resets the
+        callback list for the specified event to an empty list, effectively removing all registered callbacks for that
+        event.
 
         Args:
             event (str): The name of the event for which to clear the callbacks. This should be a valid event name
                 recognized by the Ultralytics callback system.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> model.add_callback("on_train_start", lambda: print("Training started"))
             >>> model.clear_callback("on_train_start")
             >>> # All callbacks for 'on_train_start' are now removed
@@ -1014,8 +981,7 @@ class Model(torch.nn.Module):
         self.callbacks[event] = []
 
     def reset_callbacks(self) -> None:
-        """
-        Reset all callbacks to their default functions.
+        """Reset all callbacks to their default functions.
 
         This method reinstates the default callback functions for all events, removing any custom callbacks that were
        previously added. It iterates through all default callback events and replaces the current callbacks with the
@@ -1028,7 +994,7 @@ class Model(torch.nn.Module):
         modifications, ensuring consistent behavior across different runs or experiments.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> model.add_callback("on_train_start", custom_function)
             >>> model.reset_callbacks()
             # All callbacks are now reset to their default functions
@@ -1038,12 +1004,11 @@ class Model(torch.nn.Module):
 
     @staticmethod
     def _reset_ckpt_args(args: dict[str, Any]) -> dict[str, Any]:
-        """
-        Reset specific arguments when loading a PyTorch model checkpoint.
+        """Reset specific arguments when loading a PyTorch model checkpoint.
 
-        This method filters the input arguments dictionary to retain only a specific set of keys that are
-        considered important for model loading. It's used to ensure that only relevant arguments are preserved
-        when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
+        This method filters the input arguments dictionary to retain only a specific set of keys that are considered
+        important for model loading. It's used to ensure that only relevant arguments are preserved when loading a model
+        from a checkpoint, discarding any unnecessary or potentially conflicting settings.
 
         Args:
             args (dict): A dictionary containing various model arguments and settings.
@@ -1066,12 +1031,11 @@ class Model(torch.nn.Module):
         # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
 
     def _smart_load(self, key: str):
-        """
-        Intelligently load the appropriate module based on the model task.
+        """Intelligently load the appropriate module based on the model task.
 
-        This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
-        based on the current task of the model and the provided key. It uses the task_map dictionary to
-        determine the appropriate module to load for the specific task.
+        This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) based
+        on the current task of the model and the provided key. It uses the task_map dictionary to determine the
+        appropriate module to load for the specific task.
 
         Args:
             key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
@@ -1096,24 +1060,23 @@ class Model(torch.nn.Module):
 
     @property
     def task_map(self) -> dict:
-        """
-        Provide a mapping from model tasks to corresponding classes for different modes.
+        """Provide a mapping from model tasks to corresponding classes for different modes.
 
-        This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
-        to a nested dictionary. The nested dictionary contains mappings for different operational modes (model,
-        trainer, validator, predictor) to their respective class implementations.
+        This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) to a
+        nested dictionary. The nested dictionary contains mappings for different operational modes (model, trainer,
+        validator, predictor) to their respective class implementations.
 
-        The mapping allows for dynamic loading of appropriate classes based on the model's task and the
-        desired operational mode. This facilitates a flexible and extensible architecture for handling various
-        tasks and modes within the Ultralytics framework.
+        The mapping allows for dynamic loading of appropriate classes based on the model's task and the desired
+        operational mode. This facilitates a flexible and extensible architecture for handling various tasks and modes
+        within the Ultralytics framework.
 
         Returns:
             (dict[str, dict[str, Any]]): A dictionary mapping task names to nested dictionaries. Each nested dictionary
-                contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective
-                class implementations for that task.
+                contains mappings for 'model', 'trainer', 'validator', and 'predictor' keys to their respective class
+                implementations for that task.
 
         Examples:
-            >>> model = Model("yolo11n.pt")
+            >>> model = Model("yolo26n.pt")
             >>> task_map = model.task_map
             >>> detect_predictor = task_map["detect"]["predictor"]
             >>> segment_trainer = task_map["segment"]["trainer"]
@@ -1121,8 +1084,7 @@ class Model(torch.nn.Module):
         raise NotImplementedError("Please provide task map for your model!")
 
     def eval(self):
-        """
-        Sets the model to evaluation mode.
+        """Sets the model to evaluation mode.
 
         This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
        that behave differently during training and evaluation. In evaluation mode, these layers use running statistics
@@ -1132,7 +1094,7 @@ class Model(torch.nn.Module):
             (Model): The model instance with evaluation mode set.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> model.eval()
             >>> # Model is now in evaluation mode for inference
         """
@@ -1140,8 +1102,7 @@ class Model(torch.nn.Module):
         return self
 
     def __getattr__(self, name):
-        """
-        Enable accessing model attributes directly through the Model class.
+        """Enable accessing model attributes directly through the Model class.
 
         This method provides a way to access attributes of the underlying model directly through the Model class
        instance. It first checks if the requested attribute is 'model', in which case it returns the model from
@@ -1157,7 +1118,7 @@ class Model(torch.nn.Module):
             AttributeError: If the requested attribute does not exist in the model.
 
         Examples:
-            >>> model = YOLO("yolo11n.pt")
+            >>> model = YOLO("yolo26n.pt")
             >>> print(model.stride)  # Access model.stride attribute
             >>> print(model.names)  # Access model.names attribute
         """