dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
CHANGED
@@ -1,6 +1,11 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 from pathlib import Path
+from typing import Any
+
+import torch

 from ultralytics.data.build import load_inference_source
 from ultralytics.engine.model import Model
@@ -19,19 +24,42 @@ from ultralytics.utils import ROOT, YAML


 class YOLO(Model):
-    """YOLO (You Only Look Once) object detection model.
+    """YOLO (You Only Look Once) object detection model.

-
-
-
+    This class provides a unified interface for YOLO models, automatically switching to specialized model types
+    (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
+    detection, segmentation, classification, pose estimation, and oriented bounding box detection.
+
+    Attributes:
+        model: The loaded YOLO model instance.
+        task: The task type (detect, segment, classify, pose, obb).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize a YOLO model with automatic type detection.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+
+    Examples:
+        Load a pretrained YOLOv11n detection model
+        >>> model = YOLO("yolo11n.pt")
+
+        Load a pretrained YOLO11n segmentation model
+        >>> model = YOLO("yolo11n-seg.pt")
+
+        Initialize from a YAML configuration
+        >>> model = YOLO("yolo11n.yaml")
+    """

-
-
+    def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
+        """Initialize a YOLO model.
+
+        This constructor initializes a YOLO model, automatically switching to specialized model types (YOLOWorld or
+        YOLOE) based on the model filename.

         Args:
             model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
-            task (str
-
+            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'. Defaults
+                to auto-detection based on model.
             verbose (bool): Display model info on load.

         Examples:
@@ -39,7 +67,7 @@ class YOLO(Model):
             >>> model = YOLO("yolo11n.pt")  # load a pretrained YOLOv11n detection model
             >>> model = YOLO("yolo11n-seg.pt")  # load a pretrained YOLO11n segmentation model
         """
-        path = Path(model)
+        path = Path(model if isinstance(model, (str, Path)) else "")
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
             self.__class__ = type(new_instance)
@@ -51,9 +79,15 @@ class YOLO(Model):
         else:
             # Continue with default YOLO initialization
             super().__init__(model=model, task=task, verbose=verbose)
+            if hasattr(self.model, "model") and "RTDETR" in self.model.model[-1]._get_name():  # if RTDETR head
+                from ultralytics import RTDETR
+
+                new_instance = RTDETR(self)
+                self.__class__ = type(new_instance)
+                self.__dict__ = new_instance.__dict__

     @property
-    def task_map(self):
+    def task_map(self) -> dict[str, dict[str, Any]]:
         """Map head to model, trainer, validator, and predictor classes."""
         return {
             "classify": {
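The hunks above change the constructor dispatch: non-path inputs no longer break `Path()`, and a loaded checkpoint whose detection head is an RTDETR module is transparently re-classed to the RTDETR model family. A minimal sketch of the resulting behavior, assuming the weight files are available locally or auto-downloadable:

```python
from ultralytics import YOLO

# Filename-based dispatch shown in the diff: "-world" weights become a
# YOLOWorld instance, and (new in this release) a checkpoint with an RTDETR
# head is re-classed to RTDETR after the standard YOLO initialization.
detector = YOLO("yolo11n.pt")  # stays a plain YOLO detection model
world = YOLO("yolov8s-world.pt")  # instance becomes YOLOWorld
print(type(detector).__name__, type(world).__name__)  # -> YOLO YOLOWorld
```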
@@ -90,14 +124,35 @@ class YOLO(Model):


 class YOLOWorld(Model):
-    """YOLO-World object detection model.
+    """YOLO-World object detection model.

-
-
-
+    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions without
+    requiring training on specific classes. It extends the YOLO architecture to support real-time open-vocabulary
+    detection.
+
+    Attributes:
+        model: The loaded YOLO-World model instance.
+        task: Always set to 'detect' for object detection.
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOv8-World model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        set_classes: Set the model's class names for detection.
+
+    Examples:
+        Load a YOLOv8-World model
+        >>> model = YOLOWorld("yolov8s-world.pt")
+
+        Set custom classes for detection
+        >>> model.set_classes(["person", "car", "bicycle"])
+    """
+
+    def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
+        """Initialize YOLOv8-World model with a pre-trained model file.

-        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
-
+        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default COCO
+        class names.

         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -110,7 +165,7 @@ class YOLOWorld(Model):
             self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")

     @property
-    def task_map(self):
+    def task_map(self) -> dict[str, dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -121,9 +176,8 @@ class YOLOWorld(Model):
             }
         }

-    def set_classes(self, classes):
-        """
-        Set the model's class names for detection.
+    def set_classes(self, classes: list[str]) -> None:
+        """Set the model's class names for detection.

         Args:
             classes (list[str]): A list of categories i.e. ["person"].
@@ -141,11 +195,41 @@ class YOLOWorld(Model):


 class YOLOE(Model):
-    """YOLOE object detection and segmentation model.
-
-
-
-
+    """YOLOE object detection and segmentation model.
+
+    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with improved
+    performance and additional features like visual and text positional embeddings.
+
+    Attributes:
+        model: The loaded YOLOE model instance.
+        task: The task type (detect or segment).
+        overrides: Configuration overrides for the model.
+
+    Methods:
+        __init__: Initialize YOLOE model with a pre-trained model file.
+        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
+        get_text_pe: Get text positional embeddings for the given texts.
+        get_visual_pe: Get visual positional embeddings for the given image and visual features.
+        set_vocab: Set vocabulary and class names for the YOLOE model.
+        get_vocab: Get vocabulary for the given class names.
+        set_classes: Set the model's class names and embeddings for detection.
+        val: Validate the model using text or visual prompts.
+        predict: Run prediction on images, videos, directories, streams, etc.
+
+    Examples:
+        Load a YOLOE detection model
+        >>> model = YOLOE("yoloe-11s-seg.pt")
+
+        Set vocabulary and class names
+        >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
+
+        Predict with visual prompts
+        >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+        >>> results = model.predict("image.jpg", visual_prompts=prompts)
+    """
+
+    def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
+        """Initialize YOLOE model with a pre-trained model file.

         Args:
             model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
@@ -154,12 +238,8 @@ class YOLOE(Model):
         """
         super().__init__(model=model, task=task, verbose=verbose)

-        # Assign default COCO class names when there are no custom names
-        if not hasattr(self.model, "names"):
-            self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
-
     @property
-    def task_map(self):
+    def task_map(self) -> dict[str, dict[str, Any]]:
         """Map head to model, validator, and predictor classes."""
         return {
             "detect": {
@@ -182,11 +262,10 @@ class YOLOE(Model):
         return self.model.get_text_pe(texts)

     def get_visual_pe(self, img, visual):
-        """
-        Get visual positional embeddings for the given image and visual features.
+        """Get visual positional embeddings for the given image and visual features.

-        This method extracts positional embeddings from visual features based on the input image. It requires
-
+        This method extracts positional embeddings from visual features based on the input image. It requires that the
+        model is an instance of YOLOEModel.

         Args:
             img (torch.Tensor): Input image tensor.
@@ -198,22 +277,21 @@ class YOLOE(Model):
         Examples:
             >>> model = YOLOE("yoloe-11s-seg.pt")
             >>> img = torch.rand(1, 3, 640, 640)
-            >>> visual_features =
+            >>> visual_features = torch.rand(1, 1, 80, 80)
             >>> pe = model.get_visual_pe(img, visual_features)
         """
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_visual_pe(img, visual)

-    def set_vocab(self, vocab, names):
-        """
-        Set vocabulary and class names for the YOLOE model.
+    def set_vocab(self, vocab: list[str], names: list[str]) -> None:
+        """Set vocabulary and class names for the YOLOE model.

-        This method configures the vocabulary and class names used by the model for text processing and
-
+        This method configures the vocabulary and class names used by the model for text processing and classification
+        tasks. The model must be an instance of YOLOEModel.

         Args:
-            vocab (list): Vocabulary list containing tokens or words used by the model for text processing.
-            names (list): List of class names that the model can detect or classify.
+            vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
+            names (list[str]): List of class names that the model can detect or classify.

         Raises:
             AssertionError: If the model is not an instance of YOLOEModel.
@@ -230,15 +308,16 @@ class YOLOE(Model):
         assert isinstance(self.model, YOLOEModel)
         return self.model.get_vocab(names)

-    def set_classes(self, classes, embeddings):
-        """
-        Set the model's class names and embeddings for detection.
+    def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
+        """Set the model's class names and embeddings for detection.

         Args:
             classes (list[str]): A list of categories i.e. ["person"].
             embeddings (torch.Tensor): Embeddings corresponding to the classes.
         """
         assert isinstance(self.model, YOLOEModel)
+        if embeddings is None:
+            embeddings = self.get_text_pe(classes)  # generate text embeddings if not provided
         self.model.set_classes(classes, embeddings)
         # Verify no background class is present
         assert " " not in classes
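The `set_classes` change above makes `embeddings` optional, falling back to `get_text_pe` when it is omitted. A small usage sketch, assuming a local `yoloe-11s-seg.pt`:

```python
from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")

# Previously the text embeddings had to be supplied explicitly:
model.set_classes(["person", "bus"], model.get_text_pe(["person", "bus"]))

# After this change they are generated internally when omitted:
model.set_classes(["person", "bus"])
```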
@@ -251,12 +330,11 @@ class YOLOE(Model):
     def val(
         self,
         validator=None,
-        load_vp=False,
-        refer_data=None,
+        load_vp: bool = False,
+        refer_data: str | None = None,
         **kwargs,
     ):
-        """
-        Validate the model using text or visual prompts.
+        """Validate the model using text or visual prompts.

         Args:
             validator (callable, optional): A callable validator function. If None, a default validator is loaded.
@@ -279,28 +357,27 @@ class YOLOE(Model):
         self,
         source=None,
         stream: bool = False,
-        visual_prompts: dict = {},
+        visual_prompts: dict[str, list] = {},
         refer_image=None,
-        predictor=
+        predictor=yolo.yoloe.YOLOEVPDetectPredictor,
         **kwargs,
     ):
-        """
-        Run prediction on images, videos, directories, streams, etc.
+        """Run prediction on images, videos, directories, streams, etc.

         Args:
-            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
-
-            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
-
-            visual_prompts (dict): Dictionary containing visual prompts for the model. Must include 'bboxes'
-                'cls' keys when non-empty.
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths, directory
+                paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a generator as they
+                are computed.
+            visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include 'bboxes'
+                and 'cls' keys when non-empty.
             refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
-            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
-
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically loaded
+                based on the task.
             **kwargs (Any): Additional keyword arguments passed to the predictor.

         Returns:
-            (
+            (list | generator): List of Results objects or generator of Results objects if stream=True.

         Examples:
             >>> model = YOLOE("yoloe-11s-seg.pt")
@@ -317,18 +394,21 @@ class YOLOE(Model):
                 f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
                 f"{len(visual_prompts['cls'])} respectively"
             )
-
-
-
-
-
-
-
-
-
-
+            if type(self.predictor) is not predictor:
+                self.predictor = predictor(
+                    overrides={
+                        "task": self.model.task,
+                        "mode": "predict",
+                        "save": False,
+                        "verbose": refer_image is None,
+                        "batch": 1,
+                        "device": kwargs.get("device", None),
+                        "half": kwargs.get("half", False),
+                        "imgsz": kwargs.get("imgsz", self.overrides["imgsz"]),
+                    },
+                    _callbacks=self.callbacks,
+                )

-        if len(visual_prompts):
             num_cls = (
                 max(len(set(c)) for c in visual_prompts["cls"])
                 if isinstance(source, list) and refer_image is None  # means multiple images
@@ -337,18 +417,19 @@ class YOLOE(Model):
             self.model.model[-1].nc = num_cls
             self.model.names = [f"object{i}" for i in range(num_cls)]
             self.predictor.set_prompts(visual_prompts.copy())
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self.predictor.setup_model(model=self.model)
+
+            if refer_image is None and source is not None:
+                dataset = load_inference_source(source)
+                if dataset.mode in {"video", "stream"}:
+                    # NOTE: set the first frame as refer image for videos/streams inference
+                    refer_image = next(iter(dataset))[1][0]
+            if refer_image is not None:
+                vpe = self.predictor.get_vpe(refer_image)
+                self.model.set_classes(self.model.names, vpe)
+                self.task = "segment" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else "detect"
+                self.predictor = None  # reset predictor
+        elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
+            self.predictor = None  # reset predictor if no visual prompts

         return super().predict(source, stream, **kwargs)
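The rebuilt `predict` above instantiates a `YOLOEVPDetectPredictor` for visual prompts, derives the class count from the prompt labels, and for video/stream sources takes the first frame as the reference image. A minimal sketch following the docstring's own example values:

```python
from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")
prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}  # equal-length keys required

results = model.predict("image.jpg", visual_prompts=prompts)

# For a video or stream source, the first frame is used as the reference image
# unless refer_image is passed explicitly.
results = model.predict("video.mp4", visual_prompts=prompts)
```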
ultralytics/models/yolo/obb/predict.py
CHANGED
@@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops


 class OBBPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.

     This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
     bounding boxes.
@@ -27,10 +26,7 @@ class OBBPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize OBBPredictor with optional model and data configuration overrides.
-
-        This constructor sets up an OBBPredictor instance for oriented bounding box detection tasks.
+        """Initialize OBBPredictor with optional model and data configuration overrides.

         Args:
             cfg (dict, optional): Default configuration for the predictor.
@@ -47,18 +43,18 @@
         self.args.task = "obb"

     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction.
+        """Construct the result object from the prediction.

         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N,
-
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
+                last dimension contains [x, y, w, h, confidence, class_id, angle].
             img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.

         Returns:
-            (Results): The result object containing the original image, image path, class names, and oriented bounding
+            (Results): The result object containing the original image, image path, class names, and oriented bounding
+                boxes.
         """
         rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
         rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
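`construct_result` now documents the (N, 7) prediction layout [x, y, w, h, confidence, class_id, angle]. OBB prediction is normally driven through the model interface; a brief sketch, assuming `yolo11n-obb.pt` and a local image are available:

```python
from ultralytics import YOLO

model = YOLO("yolo11n-obb.pt")  # the obb task routes predictions through OBBPredictor
results = model("boats.jpg")
for r in results:
    print(r.obb.xywhr)  # rotated boxes as (x_center, y_center, width, height, rotation)
```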
ultralytics/models/yolo/obb/train.py
CHANGED
@@ -1,6 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 from copy import copy
+from pathlib import Path
+from typing import Any

 from ultralytics.models import yolo
 from ultralytics.nn.tasks import OBBModel
@@ -8,11 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK


 class OBBTrainer(yolo.detect.DetectionTrainer):
-    """
-
+    """A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
+
+    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
+    objects at arbitrary angles rather than just axis-aligned rectangles.

     Attributes:
-        loss_names (
+        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
+            dfl_loss.

     Methods:
         get_model: Return OBBModel initialized with specified config and weights.
@@ -25,39 +32,30 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         >>> trainer.train()
     """

-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
-
-        This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
-        bounding boxes. It automatically sets the task to 'obb' in the configuration.
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
+        """Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.

         Args:
-            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
-
-            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
-
-            _callbacks (list, optional): List of callback functions to be invoked during training.
-
-        Examples:
-            >>> from ultralytics.models.yolo.obb import OBBTrainer
-            >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
-            >>> trainer = OBBTrainer(overrides=args)
-            >>> trainer.train()
+            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
+                configuration.
+            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
+                take precedence over those in cfg.
+            _callbacks (list[Any], optional): List of callback functions to be invoked during training.
         """
         if overrides is None:
             overrides = {}
         overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)

-    def get_model(
-
-
+    def get_model(
+        self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
+    ) -> OBBModel:
+        """Return OBBModel initialized with specified config and weights.

         Args:
-            cfg (str | dict
+            cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
                 containing configuration parameters, or None to use default configuration.
-            weights (str | Path
+            weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
             verbose (bool): Whether to display model information during initialization.

         Returns: