ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +37 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +111 -41
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +579 -244
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +191 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +226 -82
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +172 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +305 -112
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.63.dist-info/METADATA +370 -0
- ultralytics-8.3.63.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
ultralytics/engine/model.py
CHANGED
@@ -1,17 +1,29 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
import inspect
|
4
|
-
import sys
|
5
4
|
from pathlib import Path
|
6
|
-
from typing import Union
|
5
|
+
from typing import Any, Dict, List, Union
|
7
6
|
|
8
7
|
import numpy as np
|
9
8
|
import torch
|
9
|
+
from PIL import Image
|
10
10
|
|
11
11
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
12
|
-
from ultralytics.
|
12
|
+
from ultralytics.engine.results import Results
|
13
|
+
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
|
13
14
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
14
|
-
from ultralytics.utils import
|
15
|
+
from ultralytics.utils import (
|
16
|
+
ARGV,
|
17
|
+
ASSETS,
|
18
|
+
DEFAULT_CFG_DICT,
|
19
|
+
LOGGER,
|
20
|
+
RANK,
|
21
|
+
SETTINGS,
|
22
|
+
callbacks,
|
23
|
+
checks,
|
24
|
+
emojis,
|
25
|
+
yaml_load,
|
26
|
+
)
|
15
27
|
|
16
28
|
|
17
29
|
class Model(nn.Module):
|
@@ -20,26 +32,18 @@ class Model(nn.Module):
|
|
20
32
|
|
21
33
|
This class provides a common interface for various operations related to YOLO models, such as training,
|
22
34
|
validation, prediction, exporting, and benchmarking. It handles different types of models, including those
|
23
|
-
loaded from local files, Ultralytics HUB, or Triton Server.
|
24
|
-
extendable for different tasks and model configurations.
|
25
|
-
|
26
|
-
Args:
|
27
|
-
model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file
|
28
|
-
path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
|
29
|
-
task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's
|
30
|
-
application domain, such as object detection, segmentation, etc. Defaults to None.
|
31
|
-
verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False.
|
35
|
+
loaded from local files, Ultralytics HUB, or Triton Server.
|
32
36
|
|
33
37
|
Attributes:
|
34
|
-
callbacks (
|
38
|
+
callbacks (Dict): A dictionary of callback functions for various events during model operations.
|
35
39
|
predictor (BasePredictor): The predictor object used for making predictions.
|
36
40
|
model (nn.Module): The underlying PyTorch model.
|
37
41
|
trainer (BaseTrainer): The trainer object used for training the model.
|
38
|
-
ckpt (
|
42
|
+
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
|
39
43
|
cfg (str): The configuration of the model if loaded from a *.yaml file.
|
40
44
|
ckpt_path (str): The path to the checkpoint file.
|
41
|
-
overrides (
|
42
|
-
metrics (
|
45
|
+
overrides (Dict): A dictionary of overrides for model configuration.
|
46
|
+
metrics (Dict): The latest training/validation metrics.
|
43
47
|
session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
|
44
48
|
task (str): The type of task the model is intended for.
|
45
49
|
model_name (str): The name of the model.
|
@@ -65,120 +69,136 @@ class Model(nn.Module):
|
|
65
69
|
add_callback: Adds a callback function for an event.
|
66
70
|
clear_callback: Clears all callbacks for an event.
|
67
71
|
reset_callbacks: Resets all callbacks to their default functions.
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
77
|
-
ValueError: If the model file or configuration is invalid or unsupported.
|
78
|
-
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
|
79
|
-
TypeError: If the model is not a PyTorch model when required.
|
80
|
-
AttributeError: If required attributes or methods are not implemented or available.
|
81
|
-
NotImplementedError: If a specific model task or mode is not supported.
|
72
|
+
|
73
|
+
Examples:
|
74
|
+
>>> from ultralytics import YOLO
|
75
|
+
>>> model = YOLO("yolo11n.pt")
|
76
|
+
>>> results = model.predict("image.jpg")
|
77
|
+
>>> model.train(data="coco8.yaml", epochs=3)
|
78
|
+
>>> metrics = model.val()
|
79
|
+
>>> model.export(format="onnx")
|
82
80
|
"""
|
83
81
|
|
84
82
|
def __init__(
|
85
83
|
self,
|
86
|
-
model: Union[str, Path] = "
|
84
|
+
model: Union[str, Path] = "yolo11n.pt",
|
87
85
|
task: str = None,
|
88
86
|
verbose: bool = False,
|
89
87
|
) -> None:
|
90
88
|
"""
|
91
89
|
Initializes a new instance of the YOLO model class.
|
92
90
|
|
93
|
-
This constructor sets up the model based on the provided model path or name. It handles various types of
|
94
|
-
sources, including local files, Ultralytics HUB models, and Triton Server models. The method
|
95
|
-
important attributes of the model and prepares it for operations like training,
|
91
|
+
This constructor sets up the model based on the provided model path or name. It handles various types of
|
92
|
+
model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
|
93
|
+
initializes several important attributes of the model and prepares it for operations like training,
|
94
|
+
prediction, or export.
|
96
95
|
|
97
96
|
Args:
|
98
|
-
model (Union[str, Path]
|
99
|
-
|
100
|
-
task (
|
101
|
-
|
102
|
-
|
103
|
-
operations. Defaults to False.
|
97
|
+
model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
|
98
|
+
model name from Ultralytics HUB, or a Triton Server model.
|
99
|
+
task (str | None): The task type associated with the YOLO model, specifying its application domain.
|
100
|
+
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
|
101
|
+
operations.
|
104
102
|
|
105
103
|
Raises:
|
106
104
|
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
107
105
|
ValueError: If the model file or configuration is invalid or unsupported.
|
108
106
|
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
|
107
|
+
|
108
|
+
Examples:
|
109
|
+
>>> model = Model("yolo11n.pt")
|
110
|
+
>>> model = Model("path/to/model.yaml", task="detect")
|
111
|
+
>>> model = Model("hub_model", verbose=True)
|
109
112
|
"""
|
110
113
|
super().__init__()
|
111
114
|
self.callbacks = callbacks.get_default_callbacks()
|
112
115
|
self.predictor = None # reuse predictor
|
113
116
|
self.model = None # model object
|
114
117
|
self.trainer = None # trainer object
|
115
|
-
self.ckpt =
|
118
|
+
self.ckpt = {} # if loaded from *.pt
|
116
119
|
self.cfg = None # if loaded from *.yaml
|
117
120
|
self.ckpt_path = None
|
118
121
|
self.overrides = {} # overrides for trainer object
|
119
122
|
self.metrics = None # validation/training metrics
|
120
123
|
self.session = None # HUB session
|
121
124
|
self.task = task # task type
|
122
|
-
|
125
|
+
model = str(model).strip()
|
123
126
|
|
124
127
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
125
128
|
if self.is_hub_model(model):
|
126
129
|
# Fetch model from HUB
|
127
|
-
checks.check_requirements("hub-sdk
|
128
|
-
|
129
|
-
model =
|
130
|
+
checks.check_requirements("hub-sdk>=0.0.12")
|
131
|
+
session = HUBTrainingSession.create_session(model)
|
132
|
+
model = session.model_file
|
133
|
+
if session.train_args: # training sent from HUB
|
134
|
+
self.session = session
|
130
135
|
|
131
136
|
# Check if Triton Server model
|
132
137
|
elif self.is_triton_model(model):
|
133
|
-
self.model = model
|
134
|
-
self.task = task
|
138
|
+
self.model_name = self.model = model
|
139
|
+
self.overrides["task"] = task or "detect" # set `task=detect` if not explicitly set
|
135
140
|
return
|
136
141
|
|
137
142
|
# Load or create new YOLO model
|
138
|
-
|
139
|
-
if Path(model).suffix in (".yaml", ".yml"):
|
143
|
+
if Path(model).suffix in {".yaml", ".yml"}:
|
140
144
|
self._new(model, task=task, verbose=verbose)
|
141
145
|
else:
|
142
146
|
self._load(model, task=task)
|
143
147
|
|
144
|
-
|
148
|
+
# Delete super().training for accessing self.model.training
|
149
|
+
del self.training
|
145
150
|
|
146
151
|
def __call__(
|
147
152
|
self,
|
148
|
-
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
153
|
+
source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
|
149
154
|
stream: bool = False,
|
150
|
-
**kwargs,
|
155
|
+
**kwargs: Any,
|
151
156
|
) -> list:
|
152
157
|
"""
|
153
|
-
|
158
|
+
Alias for the predict method, enabling the model instance to be callable for predictions.
|
154
159
|
|
155
|
-
This method simplifies the process of making predictions by allowing the model instance to be called
|
156
|
-
with the required arguments
|
160
|
+
This method simplifies the process of making predictions by allowing the model instance to be called
|
161
|
+
directly with the required arguments.
|
157
162
|
|
158
163
|
Args:
|
159
|
-
source (str | Path | int | PIL.Image | np.ndarray
|
160
|
-
predictions.
|
161
|
-
|
162
|
-
stream (bool
|
163
|
-
|
164
|
-
**kwargs (any): Additional keyword arguments for configuring the prediction process.
|
164
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
|
165
|
+
the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
|
166
|
+
tensor, or a list/tuple of these.
|
167
|
+
stream (bool): If True, treat the input source as a continuous stream for predictions.
|
168
|
+
**kwargs: Additional keyword arguments to configure the prediction process.
|
165
169
|
|
166
170
|
Returns:
|
167
|
-
(List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in
|
171
|
+
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
172
|
+
Results object.
|
173
|
+
|
174
|
+
Examples:
|
175
|
+
>>> model = YOLO("yolo11n.pt")
|
176
|
+
>>> results = model("https://ultralytics.com/images/bus.jpg")
|
177
|
+
>>> for r in results:
|
178
|
+
... print(f"Detected {len(r)} objects in image")
|
168
179
|
"""
|
169
180
|
return self.predict(source, stream, **kwargs)
|
170
181
|
|
171
182
|
@staticmethod
|
172
|
-
def
|
173
|
-
"""
|
174
|
-
|
183
|
+
def is_triton_model(model: str) -> bool:
|
184
|
+
"""
|
185
|
+
Checks if the given model string is a Triton Server URL.
|
175
186
|
|
176
|
-
|
177
|
-
|
187
|
+
This static method determines whether the provided model string represents a valid Triton Server URL by
|
188
|
+
parsing its components using urllib.parse.urlsplit().
|
178
189
|
|
179
|
-
|
180
|
-
|
181
|
-
|
190
|
+
Args:
|
191
|
+
model (str): The model string to be checked.
|
192
|
+
|
193
|
+
Returns:
|
194
|
+
(bool): True if the model string is a valid Triton Server URL, False otherwise.
|
195
|
+
|
196
|
+
Examples:
|
197
|
+
>>> Model.is_triton_model("http://localhost:8000/v2/models/yolov8n")
|
198
|
+
True
|
199
|
+
>>> Model.is_triton_model("yolo11n.pt")
|
200
|
+
False
|
201
|
+
"""
|
182
202
|
from urllib.parse import urlsplit
|
183
203
|
|
184
204
|
url = urlsplit(model)
|
@@ -186,24 +206,48 @@ class Model(nn.Module):
|
|
186
206
|
|
187
207
|
@staticmethod
|
188
208
|
def is_hub_model(model: str) -> bool:
|
189
|
-
"""
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
209
|
+
"""
|
210
|
+
Check if the provided model is an Ultralytics HUB model.
|
211
|
+
|
212
|
+
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
213
|
+
identifier.
|
214
|
+
|
215
|
+
Args:
|
216
|
+
model (str): The model string to check.
|
217
|
+
|
218
|
+
Returns:
|
219
|
+
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
|
220
|
+
|
221
|
+
Examples:
|
222
|
+
>>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
|
223
|
+
True
|
224
|
+
>>> Model.is_hub_model("yolo11n.pt")
|
225
|
+
False
|
226
|
+
"""
|
227
|
+
return model.startswith(f"{HUB_WEB_ROOT}/models/")
|
197
228
|
|
198
229
|
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
199
230
|
"""
|
200
231
|
Initializes a new model and infers the task type from the model definitions.
|
201
232
|
|
233
|
+
This method creates a new model instance based on the provided configuration file. It loads the model
|
234
|
+
configuration, infers the task type if not specified, and initializes the model using the appropriate
|
235
|
+
class from the task map.
|
236
|
+
|
202
237
|
Args:
|
203
|
-
cfg (str): model configuration file
|
204
|
-
task (str | None): model
|
205
|
-
model (
|
206
|
-
|
238
|
+
cfg (str): Path to the model configuration file in YAML format.
|
239
|
+
task (str | None): The specific task for the model. If None, it will be inferred from the config.
|
240
|
+
model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
|
241
|
+
a new one.
|
242
|
+
verbose (bool): If True, displays model information during loading.
|
243
|
+
|
244
|
+
Raises:
|
245
|
+
ValueError: If the configuration file is invalid or the task cannot be inferred.
|
246
|
+
ImportError: If the required dependencies for the specified task are not installed.
|
247
|
+
|
248
|
+
Examples:
|
249
|
+
>>> model = Model()
|
250
|
+
>>> model._new("yolov8n.yaml", task="detect", verbose=True)
|
207
251
|
"""
|
208
252
|
cfg_dict = yaml_model_load(cfg)
|
209
253
|
self.cfg = cfg
|
@@ -215,31 +259,63 @@ class Model(nn.Module):
|
|
215
259
|
# Below added to allow export from YAMLs
|
216
260
|
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
217
261
|
self.model.task = self.task
|
262
|
+
self.model_name = cfg
|
218
263
|
|
219
264
|
def _load(self, weights: str, task=None) -> None:
|
220
265
|
"""
|
221
|
-
|
266
|
+
Loads a model from a checkpoint file or initializes it from a weights file.
|
267
|
+
|
268
|
+
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
|
269
|
+
up the model, task, and related attributes based on the loaded weights.
|
222
270
|
|
223
271
|
Args:
|
224
|
-
weights (str): model
|
225
|
-
task (str | None): model
|
272
|
+
weights (str): Path to the model weights file to be loaded.
|
273
|
+
task (str | None): The task associated with the model. If None, it will be inferred from the model.
|
274
|
+
|
275
|
+
Raises:
|
276
|
+
FileNotFoundError: If the specified weights file does not exist or is inaccessible.
|
277
|
+
ValueError: If the weights file format is unsupported or invalid.
|
278
|
+
|
279
|
+
Examples:
|
280
|
+
>>> model = Model()
|
281
|
+
>>> model._load("yolo11n.pt")
|
282
|
+
>>> model._load("path/to/weights.pth", task="detect")
|
226
283
|
"""
|
227
|
-
|
228
|
-
|
284
|
+
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
285
|
+
weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file
|
286
|
+
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
|
287
|
+
|
288
|
+
if Path(weights).suffix == ".pt":
|
229
289
|
self.model, self.ckpt = attempt_load_one_weight(weights)
|
230
290
|
self.task = self.model.args["task"]
|
231
291
|
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
232
292
|
self.ckpt_path = self.model.pt_path
|
233
293
|
else:
|
234
|
-
weights = checks.check_file(weights)
|
294
|
+
weights = checks.check_file(weights) # runs in all cases, not redundant with above call
|
235
295
|
self.model, self.ckpt = weights, None
|
236
296
|
self.task = task or guess_model_task(weights)
|
237
297
|
self.ckpt_path = weights
|
238
298
|
self.overrides["model"] = weights
|
239
299
|
self.overrides["task"] = self.task
|
300
|
+
self.model_name = weights
|
240
301
|
|
241
302
|
def _check_is_pytorch_model(self) -> None:
|
242
|
-
"""
|
303
|
+
"""
|
304
|
+
Checks if the model is a PyTorch model and raises a TypeError if it's not.
|
305
|
+
|
306
|
+
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
|
307
|
+
certain operations that require a PyTorch model are only performed on compatible model types.
|
308
|
+
|
309
|
+
Raises:
|
310
|
+
TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
|
311
|
+
information about supported model formats and operations.
|
312
|
+
|
313
|
+
Examples:
|
314
|
+
>>> model = Model("yolo11n.pt")
|
315
|
+
>>> model._check_is_pytorch_model() # No error raised
|
316
|
+
>>> model = Model("yolov8n.onnx")
|
317
|
+
>>> model._check_is_pytorch_model() # Raises TypeError
|
318
|
+
"""
|
243
319
|
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
244
320
|
pt_module = isinstance(self.model, nn.Module)
|
245
321
|
if not (pt_module or pt_str):
|
@@ -253,17 +329,21 @@ class Model(nn.Module):
|
|
253
329
|
|
254
330
|
def reset_weights(self) -> "Model":
|
255
331
|
"""
|
256
|
-
Resets the model
|
332
|
+
Resets the model's weights to their initial state.
|
257
333
|
|
258
334
|
This method iterates through all modules in the model and resets their parameters if they have a
|
259
|
-
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
|
260
|
-
to be updated during training.
|
335
|
+
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
|
336
|
+
enabling them to be updated during training.
|
261
337
|
|
262
338
|
Returns:
|
263
|
-
|
339
|
+
(Model): The instance of the class with reset weights.
|
264
340
|
|
265
341
|
Raises:
|
266
342
|
AssertionError: If the model is not a PyTorch model.
|
343
|
+
|
344
|
+
Examples:
|
345
|
+
>>> model = Model("yolo11n.pt")
|
346
|
+
>>> model.reset_weights()
|
267
347
|
"""
|
268
348
|
self._check_is_pytorch_model()
|
269
349
|
for m in self.model.modules():
|
@@ -273,7 +353,7 @@ class Model(nn.Module):
|
|
273
353
|
p.requires_grad = True
|
274
354
|
return self
|
275
355
|
|
276
|
-
def load(self, weights: Union[str, Path] = "
|
356
|
+
def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model":
|
277
357
|
"""
|
278
358
|
Loads parameters from the specified weights file into the model.
|
279
359
|
|
@@ -281,73 +361,103 @@ class Model(nn.Module):
|
|
281
361
|
name and shape and transfers them to the model.
|
282
362
|
|
283
363
|
Args:
|
284
|
-
weights (str
|
364
|
+
weights (Union[str, Path]): Path to the weights file or a weights object.
|
285
365
|
|
286
366
|
Returns:
|
287
|
-
|
367
|
+
(Model): The instance of the class with loaded weights.
|
288
368
|
|
289
369
|
Raises:
|
290
370
|
AssertionError: If the model is not a PyTorch model.
|
371
|
+
|
372
|
+
Examples:
|
373
|
+
>>> model = Model()
|
374
|
+
>>> model.load("yolo11n.pt")
|
375
|
+
>>> model.load(Path("path/to/weights.pt"))
|
291
376
|
"""
|
292
377
|
self._check_is_pytorch_model()
|
293
378
|
if isinstance(weights, (str, Path)):
|
379
|
+
self.overrides["pretrained"] = weights # remember the weights for DDP training
|
294
380
|
weights, self.ckpt = attempt_load_one_weight(weights)
|
295
381
|
self.model.load(weights)
|
296
382
|
return self
|
297
383
|
|
298
|
-
def save(self, filename: Union[str, Path] = "saved_model.pt"
|
384
|
+
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
|
299
385
|
"""
|
300
386
|
Saves the current model state to a file.
|
301
387
|
|
302
|
-
This method exports the model's checkpoint (ckpt) to the specified filename.
|
388
|
+
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
|
389
|
+
the date, Ultralytics version, license information, and a link to the documentation.
|
303
390
|
|
304
391
|
Args:
|
305
|
-
filename (str
|
306
|
-
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
|
392
|
+
filename (Union[str, Path]): The name of the file to save the model to.
|
307
393
|
|
308
394
|
Raises:
|
309
395
|
AssertionError: If the model is not a PyTorch model.
|
396
|
+
|
397
|
+
Examples:
|
398
|
+
>>> model = Model("yolo11n.pt")
|
399
|
+
>>> model.save("my_model.pt")
|
310
400
|
"""
|
311
401
|
self._check_is_pytorch_model()
|
312
|
-
from
|
402
|
+
from copy import deepcopy
|
313
403
|
from datetime import datetime
|
314
404
|
|
405
|
+
from ultralytics import __version__
|
406
|
+
|
315
407
|
updates = {
|
408
|
+
"model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
|
316
409
|
"date": datetime.now().isoformat(),
|
317
410
|
"version": __version__,
|
318
411
|
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
319
412
|
"docs": "https://docs.ultralytics.com",
|
320
413
|
}
|
321
|
-
torch.save({**self.ckpt, **updates}, filename
|
414
|
+
torch.save({**self.ckpt, **updates}, filename)
|
322
415
|
|
323
416
|
def info(self, detailed: bool = False, verbose: bool = True):
|
324
417
|
"""
|
325
418
|
Logs or returns model information.
|
326
419
|
|
327
|
-
This method provides an overview or detailed information about the model, depending on the arguments
|
328
|
-
It can control the verbosity of the output.
|
420
|
+
This method provides an overview or detailed information about the model, depending on the arguments
|
421
|
+
passed. It can control the verbosity of the output and return the information as a list.
|
329
422
|
|
330
423
|
Args:
|
331
|
-
detailed (bool): If True, shows detailed information about the model
|
332
|
-
verbose (bool): If True, prints the information. If False, returns the information
|
424
|
+
detailed (bool): If True, shows detailed information about the model layers and parameters.
|
425
|
+
verbose (bool): If True, prints the information. If False, returns the information as a list.
|
333
426
|
|
334
427
|
Returns:
|
335
|
-
(
|
428
|
+
(List[str]): A list of strings containing various types of information about the model, including
|
429
|
+
model summary, layer details, and parameter counts. Empty if verbose is True.
|
336
430
|
|
337
431
|
Raises:
|
338
|
-
|
432
|
+
TypeError: If the model is not a PyTorch model.
|
433
|
+
|
434
|
+
Examples:
|
435
|
+
>>> model = Model("yolo11n.pt")
|
436
|
+
>>> model.info() # Prints model summary
|
437
|
+
>>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list
|
339
438
|
"""
|
340
439
|
self._check_is_pytorch_model()
|
341
440
|
return self.model.info(detailed=detailed, verbose=verbose)
|
342
441
|
|
343
442
|
def fuse(self):
|
344
443
|
"""
|
345
|
-
Fuses Conv2d and BatchNorm2d layers in the model.
|
444
|
+
Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.
|
346
445
|
|
347
|
-
This method
|
446
|
+
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
|
447
|
+
into a single layer. This fusion can significantly improve inference speed by reducing the number of
|
448
|
+
operations and memory accesses required during forward passes.
|
449
|
+
|
450
|
+
The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
|
451
|
+
bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
|
452
|
+
performs both convolution and normalization in one step.
|
348
453
|
|
349
454
|
Raises:
|
350
|
-
|
455
|
+
TypeError: If the model is not a PyTorch nn.Module.
|
456
|
+
|
457
|
+
Examples:
|
458
|
+
>>> model = Model("yolo11n.pt")
|
459
|
+
>>> model.fuse()
|
460
|
+
>>> # Model is now fused and ready for optimized inference
|
351
461
|
"""
|
352
462
|
self._check_is_pytorch_model()
|
353
463
|
self.model.fuse()
|
@@ -356,25 +466,31 @@ class Model(nn.Module):
|
|
356
466
|
self,
|
357
467
|
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
358
468
|
stream: bool = False,
|
359
|
-
**kwargs,
|
469
|
+
**kwargs: Any,
|
360
470
|
) -> list:
|
361
471
|
"""
|
362
472
|
Generates image embeddings based on the provided source.
|
363
473
|
|
364
|
-
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
|
365
|
-
It allows customization of the embedding process through various keyword arguments.
|
474
|
+
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
|
475
|
+
source. It allows customization of the embedding process through various keyword arguments.
|
366
476
|
|
367
477
|
Args:
|
368
|
-
source (str | int |
|
369
|
-
|
370
|
-
stream (bool): If True, predictions are streamed.
|
371
|
-
**kwargs
|
478
|
+
source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
|
479
|
+
generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
|
480
|
+
stream (bool): If True, predictions are streamed.
|
481
|
+
**kwargs: Additional keyword arguments for configuring the embedding process.
|
372
482
|
|
373
483
|
Returns:
|
374
484
|
(List[torch.Tensor]): A list containing the image embeddings.
|
375
485
|
|
376
486
|
Raises:
|
377
487
|
AssertionError: If the model is not a PyTorch model.
|
488
|
+
|
489
|
+
Examples:
|
490
|
+
>>> model = YOLO("yolo11n.pt")
|
491
|
+
>>> image = "https://ultralytics.com/images/bus.jpg"
|
492
|
+
>>> embeddings = model.embed(image)
|
493
|
+
>>> print(embeddings[0].shape)
|
378
494
|
"""
|
379
495
|
if not kwargs.get("embed"):
|
380
496
|
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
@@ -382,45 +498,48 @@ class Model(nn.Module):
|
|
382
498
|
|
383
499
|
def predict(
|
384
500
|
self,
|
385
|
-
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
501
|
+
source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
|
386
502
|
stream: bool = False,
|
387
503
|
predictor=None,
|
388
|
-
**kwargs,
|
389
|
-
) ->
|
504
|
+
**kwargs: Any,
|
505
|
+
) -> List[Results]:
|
390
506
|
"""
|
391
507
|
Performs predictions on the given image source using the YOLO model.
|
392
508
|
|
393
509
|
This method facilitates the prediction process, allowing various configurations through keyword arguments.
|
394
510
|
It supports predictions with custom predictors or the default predictor method. The method handles different
|
395
|
-
types of image sources and can operate in a streaming mode.
|
396
|
-
through 'prompts'.
|
397
|
-
|
398
|
-
The method sets up a new predictor if not already present and updates its arguments with each call.
|
399
|
-
It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it
|
400
|
-
is being called from the command line interface and adjusts its behavior accordingly, including setting defaults
|
401
|
-
for confidence threshold and saving behavior.
|
511
|
+
types of image sources and can operate in a streaming mode.
|
402
512
|
|
403
513
|
Args:
|
404
|
-
source (str | int | PIL.Image | np.ndarray
|
405
|
-
Accepts various types
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
514
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
|
515
|
+
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
|
516
|
+
images, numpy arrays, and torch tensors.
|
517
|
+
stream (bool): If True, treats the input source as a continuous stream for predictions.
|
518
|
+
predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
|
519
|
+
If None, the method uses a default predictor.
|
520
|
+
**kwargs: Additional keyword arguments for configuring the prediction process.
|
411
521
|
|
412
522
|
Returns:
|
413
|
-
(List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in
|
414
|
-
|
415
|
-
|
416
|
-
|
523
|
+
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
524
|
+
Results object.
|
525
|
+
|
526
|
+
Examples:
|
527
|
+
>>> model = YOLO("yolo11n.pt")
|
528
|
+
>>> results = model.predict(source="path/to/image.jpg", conf=0.25)
|
529
|
+
>>> for r in results:
|
530
|
+
... print(r.boxes.data) # print detection bounding boxes
|
531
|
+
|
532
|
+
Notes:
|
533
|
+
- If 'source' is not provided, it defaults to the ASSETS constant with a warning.
|
534
|
+
- The method sets up a new predictor if not already present and updates its arguments with each call.
|
535
|
+
- For SAM-type models, 'prompts' can be passed as a keyword argument.
|
417
536
|
"""
|
418
537
|
if source is None:
|
419
538
|
source = ASSETS
|
420
539
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
421
540
|
|
422
|
-
is_cli = (
|
423
|
-
x in
|
541
|
+
is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any(
|
542
|
+
x in ARGV for x in ("predict", "track", "mode=predict", "mode=track")
|
424
543
|
)
|
425
544
|
|
426
545
|
custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults
|
@@ -428,7 +547,7 @@ class Model(nn.Module):
|
|
428
547
|
prompts = args.pop("prompts", None) # for SAM-type models
|
429
548
|
|
430
549
|
if not self.predictor:
|
431
|
-
self.predictor = predictor or self._smart_load("predictor")(overrides=args, _callbacks=self.callbacks)
|
550
|
+
self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks)
|
432
551
|
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
433
552
|
else: # only update args if predictor is already setup
|
434
553
|
self.predictor.args = get_cfg(self.predictor.args, args)
|
@@ -443,31 +562,38 @@ class Model(nn.Module):
|
|
443
562
|
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
444
563
|
stream: bool = False,
|
445
564
|
persist: bool = False,
|
446
|
-
**kwargs,
|
447
|
-
) ->
|
565
|
+
**kwargs: Any,
|
566
|
+
) -> List[Results]:
|
448
567
|
"""
|
449
568
|
Conducts object tracking on the specified input source using the registered trackers.
|
450
569
|
|
451
|
-
This method performs object tracking using the model's predictors and optionally registered trackers. It
|
452
|
-
|
453
|
-
|
454
|
-
already present and optionally persists them based on the 'persist' flag.
|
455
|
-
|
456
|
-
The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low
|
457
|
-
confidence predictions as input. The tracking mode is explicitly set in the keyword arguments.
|
570
|
+
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
|
571
|
+
various input sources such as file paths or video streams, and supports customization through keyword arguments.
|
572
|
+
The method registers trackers if not already present and can persist them between calls.
|
458
573
|
|
459
574
|
Args:
|
460
|
-
source (str,
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
575
|
+
source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
|
576
|
+
tracking. Can be a file path, URL, or video stream.
|
577
|
+
stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
|
578
|
+
persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
|
579
|
+
**kwargs: Additional keyword arguments for configuring the tracking process.
|
465
580
|
|
466
581
|
Returns:
|
467
|
-
(List[ultralytics.engine.results.Results]): A list of tracking results,
|
582
|
+
(List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.
|
468
583
|
|
469
584
|
Raises:
|
470
585
|
AttributeError: If the predictor does not have registered trackers.
|
586
|
+
|
587
|
+
Examples:
|
588
|
+
>>> model = YOLO("yolo11n.pt")
|
589
|
+
>>> results = model.track(source="path/to/video.mp4", show=True)
|
590
|
+
>>> for r in results:
|
591
|
+
... print(r.boxes.id) # print tracking IDs
|
592
|
+
|
593
|
+
Notes:
|
594
|
+
- This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.
|
595
|
+
- The tracking mode is explicitly set in the keyword arguments.
|
596
|
+
- Batch size is set to 1 for tracking in videos.
|
471
597
|
"""
|
472
598
|
if not hasattr(self.predictor, "trackers"):
|
473
599
|
from ultralytics.trackers import register_tracker
|
@@ -481,31 +607,30 @@ class Model(nn.Module):
|
|
481
607
|
def val(
|
482
608
|
self,
|
483
609
|
validator=None,
|
484
|
-
**kwargs,
|
610
|
+
**kwargs: Any,
|
485
611
|
):
|
486
612
|
"""
|
487
613
|
Validates the model using a specified dataset and validation configuration.
|
488
614
|
|
489
|
-
This method facilitates the model validation process, allowing for
|
490
|
-
|
491
|
-
|
492
|
-
the validation process. After validation, it updates the model's metrics with the results obtained from the
|
493
|
-
validator.
|
494
|
-
|
495
|
-
The method supports various arguments that allow customization of the validation process. For a comprehensive
|
496
|
-
list of all configurable options, users should refer to the 'configuration' section in the documentation.
|
615
|
+
This method facilitates the model validation process, allowing for customization through various settings. It
|
616
|
+
supports validation with a custom validator or the default validation approach. The method combines default
|
617
|
+
configurations, method-specific defaults, and user-provided arguments to configure the validation process.
|
497
618
|
|
498
619
|
Args:
|
499
|
-
validator (BaseValidator
|
500
|
-
|
501
|
-
**kwargs
|
502
|
-
used to customize various aspects of the validation process.
|
620
|
+
validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
|
621
|
+
validating the model.
|
622
|
+
**kwargs: Arbitrary keyword arguments for customizing the validation process.
|
503
623
|
|
504
624
|
Returns:
|
505
|
-
(
|
625
|
+
(ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
|
506
626
|
|
507
627
|
Raises:
|
508
628
|
AssertionError: If the model is not a PyTorch model.
|
629
|
+
|
630
|
+
Examples:
|
631
|
+
>>> model = YOLO("yolo11n.pt")
|
632
|
+
>>> results = model.val(data="coco8.yaml", imgsz=640)
|
633
|
+
>>> print(results.box.map) # Print mAP50-95
|
509
634
|
"""
|
510
635
|
custom = {"rect": True} # method defaults
|
511
636
|
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
|
@@ -517,29 +642,37 @@ class Model(nn.Module):
|
|
517
642
|
|
518
643
|
def benchmark(
|
519
644
|
self,
|
520
|
-
**kwargs,
|
645
|
+
**kwargs: Any,
|
521
646
|
):
|
522
647
|
"""
|
523
648
|
Benchmarks the model across various export formats to evaluate performance.
|
524
649
|
|
525
650
|
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
|
526
|
-
It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
|
527
|
-
using a combination of default configuration values, model-specific arguments, method-specific
|
528
|
-
any additional user-provided keyword arguments.
|
529
|
-
|
530
|
-
The method supports various arguments that allow customization of the benchmarking process, such as dataset
|
531
|
-
choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all
|
532
|
-
configurable options, users should refer to the 'configuration' section in the documentation.
|
651
|
+
It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
|
652
|
+
configured using a combination of default configuration values, model-specific arguments, method-specific
|
653
|
+
defaults, and any additional user-provided keyword arguments.
|
533
654
|
|
534
655
|
Args:
|
535
|
-
**kwargs
|
536
|
-
default configurations, model-specific arguments, and method defaults.
|
656
|
+
**kwargs: Arbitrary keyword arguments to customize the benchmarking process. These are combined with
|
657
|
+
default configurations, model-specific arguments, and method defaults. Common options include:
|
658
|
+
- data (str): Path to the dataset for benchmarking.
|
659
|
+
- imgsz (int | List[int]): Image size for benchmarking.
|
660
|
+
- half (bool): Whether to use half-precision (FP16) mode.
|
661
|
+
- int8 (bool): Whether to use int8 precision mode.
|
662
|
+
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
|
663
|
+
- verbose (bool): Whether to print detailed benchmark information.
|
537
664
|
|
538
665
|
Returns:
|
539
|
-
(
|
666
|
+
(Dict): A dictionary containing the results of the benchmarking process, including metrics for
|
667
|
+
different export formats.
|
540
668
|
|
541
669
|
Raises:
|
542
670
|
AssertionError: If the model is not a PyTorch model.
|
671
|
+
|
672
|
+
Examples:
|
673
|
+
>>> model = YOLO("yolo11n.pt")
|
674
|
+
>>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True)
|
675
|
+
>>> print(results)
|
543
676
|
"""
|
544
677
|
self._check_is_pytorch_model()
|
545
678
|
from ultralytics.utils.benchmarks import benchmark
|
@@ -558,66 +691,92 @@ class Model(nn.Module):
|
|
558
691
|
|
559
692
|
def export(
|
560
693
|
self,
|
561
|
-
**kwargs,
|
562
|
-
):
|
694
|
+
**kwargs: Any,
|
695
|
+
) -> str:
|
563
696
|
"""
|
564
697
|
Exports the model to a different format suitable for deployment.
|
565
698
|
|
566
699
|
This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
|
567
700
|
purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
|
568
|
-
defaults, and any additional arguments provided.
|
569
|
-
|
570
|
-
The method supports a wide range of arguments to customize the export process. For a comprehensive list of all
|
571
|
-
possible arguments, refer to the 'configuration' section in the documentation.
|
701
|
+
defaults, and any additional arguments provided.
|
572
702
|
|
573
703
|
Args:
|
574
|
-
**kwargs
|
575
|
-
model's overrides and method defaults.
|
704
|
+
**kwargs: Arbitrary keyword arguments to customize the export process. These are combined with
|
705
|
+
the model's overrides and method defaults. Common arguments include:
|
706
|
+
format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
|
707
|
+
half (bool): Export model in half-precision.
|
708
|
+
int8 (bool): Export model in int8 precision.
|
709
|
+
device (str): Device to run the export on.
|
710
|
+
workspace (int): Maximum memory workspace size for TensorRT engines.
|
711
|
+
nms (bool): Add Non-Maximum Suppression (NMS) module to model.
|
712
|
+
simplify (bool): Simplify ONNX model.
|
576
713
|
|
577
714
|
Returns:
|
578
|
-
(
|
715
|
+
(str): The path to the exported model file.
|
579
716
|
|
580
717
|
Raises:
|
581
718
|
AssertionError: If the model is not a PyTorch model.
|
719
|
+
ValueError: If an unsupported export format is specified.
|
720
|
+
RuntimeError: If the export process fails due to errors.
|
721
|
+
|
722
|
+
Examples:
|
723
|
+
>>> model = YOLO("yolo11n.pt")
|
724
|
+
>>> model.export(format="onnx", dynamic=True, simplify=True)
|
725
|
+
'path/to/exported/model.onnx'
|
582
726
|
"""
|
583
727
|
self._check_is_pytorch_model()
|
584
728
|
from .exporter import Exporter
|
585
729
|
|
586
|
-
custom = {
|
730
|
+
custom = {
|
731
|
+
"imgsz": self.model.args["imgsz"],
|
732
|
+
"batch": 1,
|
733
|
+
"data": None,
|
734
|
+
"device": None, # reset to avoid multi-GPU errors
|
735
|
+
"verbose": False,
|
736
|
+
} # method defaults
|
587
737
|
args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
|
588
738
|
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
589
739
|
|
590
740
|
def train(
|
591
741
|
self,
|
592
742
|
trainer=None,
|
593
|
-
**kwargs,
|
743
|
+
**kwargs: Any,
|
594
744
|
):
|
595
745
|
"""
|
596
746
|
Trains the model using the specified dataset and training configuration.
|
597
747
|
|
598
|
-
This method facilitates model training with a range of customizable settings
|
599
|
-
|
600
|
-
|
601
|
-
updating model and configuration after training.
|
748
|
+
This method facilitates model training with a range of customizable settings. It supports training with a
|
749
|
+
custom trainer or the default training approach. The method handles scenarios such as resuming training
|
750
|
+
from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
|
602
751
|
|
603
|
-
When using Ultralytics HUB, if the session
|
604
|
-
arguments and
|
605
|
-
configurations, method-specific defaults, and user-provided arguments to configure the training process.
|
606
|
-
training, it updates the model and its configurations, and optionally attaches metrics.
|
752
|
+
When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
|
753
|
+
arguments and warns if local arguments are provided. It checks for pip updates and combines default
|
754
|
+
configurations, method-specific defaults, and user-provided arguments to configure the training process.
|
607
755
|
|
608
756
|
Args:
|
609
|
-
trainer (BaseTrainer
|
610
|
-
|
611
|
-
|
612
|
-
|
757
|
+
trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
|
758
|
+
**kwargs: Arbitrary keyword arguments for training configuration. Common options include:
|
759
|
+
data (str): Path to dataset configuration file.
|
760
|
+
epochs (int): Number of training epochs.
|
761
|
+
batch_size (int): Batch size for training.
|
762
|
+
imgsz (int): Input image size.
|
763
|
+
device (str): Device to run training on (e.g., 'cuda', 'cpu').
|
764
|
+
workers (int): Number of worker threads for data loading.
|
765
|
+
optimizer (str): Optimizer to use for training.
|
766
|
+
lr0 (float): Initial learning rate.
|
767
|
+
patience (int): Epochs to wait for no observable improvement for early stopping of training.
|
613
768
|
|
614
769
|
Returns:
|
615
|
-
(
|
770
|
+
(Dict | None): Training metrics if available and training is successful; otherwise, None.
|
616
771
|
|
617
772
|
Raises:
|
618
773
|
AssertionError: If the model is not a PyTorch model.
|
619
774
|
PermissionError: If there is a permission issue with the HUB session.
|
620
775
|
ModuleNotFoundError: If the HUB SDK is not installed.
|
776
|
+
|
777
|
+
Examples:
|
778
|
+
>>> model = YOLO("yolo11n.pt")
|
779
|
+
>>> results = model.train(data="coco8.yaml", epochs=3)
|
621
780
|
"""
|
622
781
|
self._check_is_pytorch_model()
|
623
782
|
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
|
@@ -628,7 +787,12 @@ class Model(nn.Module):
|
|
628
787
|
checks.check_pip_update_available()
|
629
788
|
|
630
789
|
overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
|
631
|
-
custom = {
|
790
|
+
custom = {
|
791
|
+
# NOTE: handle the case when 'cfg' includes 'data'.
|
792
|
+
"data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
|
793
|
+
"model": self.overrides["model"],
|
794
|
+
"task": self.task,
|
795
|
+
} # method defaults
|
632
796
|
args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
633
797
|
if args.get("resume"):
|
634
798
|
args["resume"] = self.ckpt_path
|
@@ -638,25 +802,12 @@ class Model(nn.Module):
             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
             self.model = self.trainer.model

-            if SETTINGS["hub"] is True and not self.session:
-                # Create a model in HUB
-                try:
-                    self.session = self._get_hub_session(self.model_name)
-                    if self.session:
-                        self.session.create_model(args)
-                        # Check model was created
-                        if not getattr(self.session.model, "id", None):
-                            self.session = None
-                except (PermissionError, ModuleNotFoundError):
-                    # Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
-                    pass
-
         self.trainer.hub_session = self.session  # attach optional HUB session
         self.trainer.train()
         # Update model and cfg after training
-        if RANK in
+        if RANK in {-1, 0}:
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
-            self.model,
+            self.model, self.ckpt = attempt_load_one_weight(ckpt)
             self.overrides = self.model.args
             self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
         return self.metrics
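
Taken together, the train() changes above document the common keyword arguments and the return value. A hedged usage sketch (not part of the diff; argument names follow the docstring in this hunk, values and dataset path are placeholders):

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    results = model.train(
        data="coco8.yaml",   # dataset configuration file
        epochs=3,            # number of training epochs
        imgsz=640,           # input image size
        device="cpu",        # or "cuda" / a GPU index if available
        workers=4,           # dataloader worker threads
        optimizer="auto",    # optimizer selection
        lr0=0.01,            # initial learning rate
        patience=10,         # early-stopping patience
    )
    print(results)  # training/validation metrics if available, otherwise None
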
@@ -665,8 +816,8 @@ class Model(nn.Module):
         self,
         use_ray=False,
         iterations=10,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         """
         Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
@@ -679,14 +830,19 @@ class Model(nn.Module):
         Args:
             use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
             iterations (int): The number of tuning iterations to perform. Defaults to 10.
-            *args
-            **kwargs
+            *args: Variable length argument list for additional arguments.
+            **kwargs: Arbitrary keyword arguments. These are combined with the model's overrides and defaults.

         Returns:
-            (
+            (Dict): A dictionary containing the results of the hyperparameter search.

         Raises:
             AssertionError: If the model is not a PyTorch model.
+
+        Examples:
+            >>> model = YOLO("yolo11n.pt")
+            >>> results = model.tune(use_ray=True, iterations=20)
+            >>> print(results)
         """
         self._check_is_pytorch_model()
         if use_ray:
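
Beyond the Ray Tune example in the docstring, the native tuner path (use_ray=False) accepts the same training keyword arguments as train(). A hedged sketch, not part of the diff, with placeholder values:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    # Native tuner: runs `iterations` short trainings while searching hyperparameters around the defaults.
    model.tune(data="coco8.yaml", epochs=5, iterations=3)  # per the docstring, returns a dict of results when available
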
@@ -701,7 +857,27 @@ class Model(nn.Module):
             return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)

     def _apply(self, fn) -> "Model":
-        """
+        """
+        Applies a function to model tensors that are not parameters or registered buffers.
+
+        This method extends the functionality of the parent class's _apply method by additionally resetting the
+        predictor and updating the device in the model's overrides. It's typically used for operations like
+        moving the model to a different device or changing its precision.
+
+        Args:
+            fn (Callable): A function to be applied to the model's tensors. This is typically a method like
+                to(), cpu(), cuda(), half(), or float().
+
+        Returns:
+            (Model): The model instance with the function applied and updated attributes.
+
+        Raises:
+            AssertionError: If the model is not a PyTorch model.
+
+        Examples:
+            >>> model = Model("yolo11n.pt")
+            >>> model = model._apply(lambda t: t.cuda())  # Move model to GPU
+        """
         self._check_is_pytorch_model()
         self = super()._apply(fn)  # noqa
         self.predictor = None  # reset predictor as device may have changed
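
Since _apply() underlies the usual nn.Module device and precision moves, a short hedged sketch (not part of the diff) of the behavior described above:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    model = model.to("cpu")   # routed through _apply(); the cached predictor is reset
    print(model.device)       # device(type='cpu')
    model = model.half()      # precision change also goes through _apply()
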
@@ -709,30 +885,55 @@ class Model(nn.Module):
         return self

     @property
-    def names(self) ->
+    def names(self) -> Dict[int, str]:
         """
         Retrieves the class names associated with the loaded model.

         This property returns the class names if they are defined in the model. It checks the class names for validity
-        using the 'check_class_names' function from the ultralytics.nn.autobackend module.
+        using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
+        initialized, it sets it up before retrieving the names.

         Returns:
-            (
+            (Dict[int, str]): A dict of class names associated with the model.
+
+        Raises:
+            AttributeError: If the model or predictor does not have a 'names' attribute.
+
+        Examples:
+            >>> model = YOLO("yolo11n.pt")
+            >>> print(model.names)
+            {0: 'person', 1: 'bicycle', 2: 'car', ...}
         """
         from ultralytics.nn.autobackend import check_class_names

-
+        if hasattr(self.model, "names"):
+            return check_class_names(self.model.names)
+        if not self.predictor:  # export formats will not have predictor defined until predict() is called
+            self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
+            self.predictor.setup_model(model=self.model, verbose=False)
+        return self.predictor.model.names

     @property
     def device(self) -> torch.device:
         """
         Retrieves the device on which the model's parameters are allocated.

-        This property
-        that are instances of nn.Module.
+        This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
+        applicable only to models that are instances of nn.Module.

         Returns:
-            (torch.device
+            (torch.device): The device (CPU/GPU) of the model.
+
+        Raises:
+            AttributeError: If the model is not a PyTorch nn.Module instance.
+
+        Examples:
+            >>> model = YOLO("yolo11n.pt")
+            >>> print(model.device)
+            device(type='cuda', index=0)  # if CUDA is available
+            >>> model = model.to("cpu")
+            >>> print(model.device)
+            device(type='cpu')
         """
         return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None

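
The `names` mapping documented above is what turns predicted class indices back into labels. A hedged sketch (not part of the diff; "bus.jpg" is a placeholder image path):

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    results = model.predict("bus.jpg", verbose=False)
    for c in results[0].boxes.cls.tolist():  # predicted class indices for the first image
        print(int(c), model.names[int(c)])   # e.g. 0 person, 5 bus
    print(model.device)                      # where the parameters currently live
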
@@ -741,10 +942,20 @@ class Model(nn.Module):
         """
         Retrieves the transformations applied to the input data of the loaded model.

-        This property returns the transformations if they are defined in the model.
+        This property returns the transformations if they are defined in the model. The transforms
+        typically include preprocessing steps like resizing, normalization, and data augmentation
+        that are applied to input data before it is fed into the model.

         Returns:
             (object | None): The transform object of the model if available, otherwise None.
+
+        Examples:
+            >>> model = YOLO("yolo11n.pt")
+            >>> transforms = model.transforms
+            >>> if transforms:
+            ...     print(f"Model transforms: {transforms}")
+            ... else:
+            ...     print("No transforms defined for this model.")
         """
         return self.model.transforms if hasattr(self.model, "transforms") else None

@@ -752,15 +963,25 @@ class Model(nn.Module):
         """
         Adds a callback function for a specified event.

-        This method allows
-        model training or inference.
+        This method allows registering custom callback functions that are triggered on specific events during
+        model operations such as training or inference. Callbacks provide a way to extend and customize the
+        behavior of the model at various stages of its lifecycle.

         Args:
-            event (str): The name of the event to attach the callback to.
-
+            event (str): The name of the event to attach the callback to. Must be a valid event name recognized
+                by the Ultralytics framework.
+            func (Callable): The callback function to be registered. This function will be called when the
+                specified event occurs.

         Raises:
-            ValueError: If the event name is not recognized.
+            ValueError: If the event name is not recognized or is invalid.
+
+        Examples:
+            >>> def on_train_start(trainer):
+            ...     print("Training is starting!")
+            >>> model = YOLO("yolo11n.pt")
+            >>> model.add_callback("on_train_start", on_train_start)
+            >>> model.train(data="coco8.yaml", epochs=1)
         """
         self.callbacks[event].append(func)

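
A hedged sketch of a practical callback along the lines described in this hunk (not part of the diff; "on_train_epoch_end" is a standard Ultralytics event name, and the callback receives the trainer instance):

    from ultralytics import YOLO

    def log_epoch(trainer):
        # assumes the trainer exposes an `epoch` counter during training
        print(f"finished epoch {trainer.epoch + 1}")

    model = YOLO("yolo11n.pt")
    model.add_callback("on_train_epoch_end", log_epoch)
    model.train(data="coco8.yaml", epochs=1)
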
@@ -769,12 +990,26 @@ class Model(nn.Module):
         Clears all callback functions registered for a specified event.

         This method removes all custom and default callback functions associated with the given event.
+        It resets the callback list for the specified event to an empty list, effectively removing all
+        registered callbacks for that event.

         Args:
-            event (str): The name of the event for which to clear the callbacks.
-
-
-
+            event (str): The name of the event for which to clear the callbacks. This should be a valid event name
+                recognized by the Ultralytics callback system.
+
+        Examples:
+            >>> model = YOLO("yolo11n.pt")
+            >>> model.add_callback("on_train_start", lambda: print("Training started"))
+            >>> model.clear_callback("on_train_start")
+            >>> # All callbacks for 'on_train_start' are now removed
+
+        Notes:
+            - This method affects both custom callbacks added by the user and default callbacks
+              provided by the Ultralytics framework.
+            - After calling this method, no callbacks will be executed for the specified event
+              until new ones are added.
+            - Use with caution as it removes all callbacks, including essential ones that might
+              be required for proper functioning of certain operations.
         """
         self.callbacks[event] = []

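
The clear_callback() change above and the reset_callbacks() change in the next hunk form a small lifecycle. A hedged sketch, not part of the diff:

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    model.add_callback("on_train_start", lambda trainer: print("starting"))
    model.clear_callback("on_train_start")  # removes every callback for this event, defaults included
    model.reset_callbacks()                 # reinstates the default callback for all events
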
@@ -783,14 +1018,45 @@ class Model(nn.Module):
         Resets all callbacks to their default functions.

         This method reinstates the default callback functions for all events, removing any custom callbacks that were
-        added
+        previously added. It iterates through all default callback events and replaces the current callbacks with the
+        default ones.
+
+        The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined
+        functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc.
+
+        This method is useful when you want to revert to the original set of callbacks after making custom
+        modifications, ensuring consistent behavior across different runs or experiments.
+
+        Examples:
+            >>> model = YOLO("yolo11n.pt")
+            >>> model.add_callback("on_train_start", custom_function)
+            >>> model.reset_callbacks()
+            # All callbacks are now reset to their default functions
         """
         for event in callbacks.default_callbacks.keys():
             self.callbacks[event] = [callbacks.default_callbacks[event][0]]

     @staticmethod
     def _reset_ckpt_args(args: dict) -> dict:
-        """
+        """
+        Resets specific arguments when loading a PyTorch model checkpoint.
+
+        This static method filters the input arguments dictionary to retain only a specific set of keys that are
+        considered important for model loading. It's used to ensure that only relevant arguments are preserved
+        when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
+
+        Args:
+            args (dict): A dictionary containing various model arguments and settings.
+
+        Returns:
+            (dict): A new dictionary containing only the specified include keys from the input arguments.
+
+        Examples:
+            >>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100}
+            >>> reset_args = Model._reset_ckpt_args(original_args)
+            >>> print(reset_args)
+            {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
+        """
         include = {"imgsz", "data", "task", "single_cls"}  # only remember these arguments when loading a PyTorch model
         return {k: v for k, v in args.items() if k in include}

@@ -800,7 +1066,31 @@ class Model(nn.Module):
         # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

     def _smart_load(self, key: str):
-        """
+        """
+        Loads the appropriate module based on the model task.
+
+        This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
+        based on the current task of the model and the provided key. It uses the task_map attribute to determine
+        the correct module to load.
+
+        Args:
+            key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
+
+        Returns:
+            (object): The loaded module corresponding to the specified key and current task.
+
+        Raises:
+            NotImplementedError: If the specified key is not supported for the current task.
+
+        Examples:
+            >>> model = Model(task="detect")
+            >>> predictor = model._smart_load("predictor")
+            >>> trainer = model._smart_load("trainer")
+
+        Notes:
+            - This method is typically used internally by other methods of the Model class.
+            - The task_map attribute should be properly initialized with the correct mappings for each task.
+        """
         try:
             return self.task_map[self.task][key]
         except Exception as e:
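
To make the task_map lookup in _smart_load() concrete (the property itself is documented in the next hunk), here is a hedged sketch, not part of the diff, of roughly the mapping shape a subclass provides; the detect-task class names follow this package, but the exact mapping should be treated as illustrative:

    from ultralytics.models import yolo
    from ultralytics.nn.tasks import DetectionModel

    task_map = {
        "detect": {
            "model": DetectionModel,
            "trainer": yolo.detect.DetectionTrainer,
            "validator": yolo.detect.DetectionValidator,
            "predictor": yolo.detect.DetectionPredictor,
        }
    }
    trainer_cls = task_map["detect"]["trainer"]  # what _smart_load("trainer") resolves to for a detect model
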
@@ -813,9 +1103,71 @@ class Model(nn.Module):
     @property
     def task_map(self) -> dict:
         """
-
+        Provides a mapping from model tasks to corresponding classes for different modes.
+
+        This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
+        to a nested dictionary. The nested dictionary contains mappings for different operational modes
+        (model, trainer, validator, predictor) to their respective class implementations.
+
+        The mapping allows for dynamic loading of appropriate classes based on the model's task and the
+        desired operational mode. This facilitates a flexible and extensible architecture for handling
+        various tasks and modes within the Ultralytics framework.

         Returns:
-
+            (Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
+                nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
+                'predictor', mapping to their respective class implementations.
+
+        Examples:
+            >>> model = Model()
+            >>> task_map = model.task_map
+            >>> detect_class_map = task_map["detect"]
+            >>> segment_class_map = task_map["segment"]
+
+        Note:
+            The actual implementation of this method may vary depending on the specific tasks and
+            classes supported by the Ultralytics framework. The docstring provides a general
+            description of the expected behavior and structure.
         """
         raise NotImplementedError("Please provide task map for your model!")
+
+    def eval(self):
+        """
+        Sets the model to evaluation mode.
+
+        This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization
+        that behave differently during training and evaluation.
+
+        Returns:
+            (Model): The model instance with evaluation mode set.
+
+        Examples:
+            >> model = YOLO("yolo11n.pt")
+            >> model.eval()
+        """
+        self.model.eval()
+        return self
+
+    def __getattr__(self, name):
+        """
+        Enables accessing model attributes directly through the Model class.
+
+        This method provides a way to access attributes of the underlying model directly through the Model class
+        instance. It first checks if the requested attribute is 'model', in which case it returns the model from
+        the module dictionary. Otherwise, it delegates the attribute lookup to the underlying model.
+
+        Args:
+            name (str): The name of the attribute to retrieve.
+
+        Returns:
+            (Any): The requested attribute value.
+
+        Raises:
+            AttributeError: If the requested attribute does not exist in the model.
+
+        Examples:
+            >>> model = YOLO("yolo11n.pt")
+            >>> print(model.stride)
+            >>> print(model.task)
+        """
+        return self._modules["model"] if name == "model" else getattr(self.model, name)