dgenerate-ultralytics-headless 8.3.135__py3-none-any.whl → 8.3.138__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.135.dist-info → dgenerate_ultralytics_headless-8.3.138.dist-info}/METADATA +1 -2
- {dgenerate_ultralytics_headless-8.3.135.dist-info → dgenerate_ultralytics_headless-8.3.138.dist-info}/RECORD +40 -40
- tests/test_cuda.py +2 -7
- tests/test_exports.py +1 -6
- tests/test_solutions.py +181 -8
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +4 -4
- ultralytics/data/base.py +1 -1
- ultralytics/data/build.py +4 -3
- ultralytics/data/loaders.py +2 -2
- ultralytics/engine/exporter.py +6 -7
- ultralytics/engine/model.py +2 -2
- ultralytics/engine/predictor.py +3 -10
- ultralytics/engine/trainer.py +1 -1
- ultralytics/engine/validator.py +1 -1
- ultralytics/hub/auth.py +2 -2
- ultralytics/hub/utils.py +8 -3
- ultralytics/models/yolo/classify/predict.py +11 -0
- ultralytics/models/yolo/obb/val.py +1 -1
- ultralytics/models/yolo/world/train.py +66 -20
- ultralytics/models/yolo/world/train_world.py +1 -0
- ultralytics/models/yolo/yoloe/train.py +10 -39
- ultralytics/models/yolo/yoloe/val.py +3 -3
- ultralytics/nn/tasks.py +41 -24
- ultralytics/nn/text_model.py +1 -0
- ultralytics/solutions/similarity_search.py +3 -6
- ultralytics/solutions/streamlit_inference.py +1 -1
- ultralytics/utils/__init__.py +1 -1
- ultralytics/utils/callbacks/hub.py +5 -4
- ultralytics/utils/checks.py +16 -13
- ultralytics/utils/downloads.py +7 -5
- ultralytics/utils/export.py +1 -1
- ultralytics/utils/metrics.py +51 -22
- ultralytics/utils/plotting.py +19 -13
- ultralytics/utils/torch_utils.py +3 -0
- ultralytics/utils/triton.py +1 -1
- {dgenerate_ultralytics_headless-8.3.135.dist-info → dgenerate_ultralytics_headless-8.3.138.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.135.dist-info → dgenerate_ultralytics_headless-8.3.138.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.135.dist-info → dgenerate_ultralytics_headless-8.3.138.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.135.dist-info → dgenerate_ultralytics_headless-8.3.138.dist-info}/top_level.txt +0 -0
ultralytics/engine/model.py
CHANGED
@@ -288,7 +288,7 @@ class Model(torch.nn.Module):
         weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"])  # download and return local file
         weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolo11n -> yolo11n.pt

-        if
+        if str(weights).rpartition(".")[-1] == "pt":
             self.model, self.ckpt = attempt_load_one_weight(weights)
             self.task = self.model.args["task"]
             self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
@@ -319,7 +319,7 @@ class Model(torch.nn.Module):
            >>> model = Model("yolo11n.onnx")
            >>> model._check_is_pytorch_model()  # Raises TypeError
        """
-        pt_str = isinstance(self.model, (str, Path)) and
+        pt_str = isinstance(self.model, (str, Path)) and str(self.model).rpartition(".")[-1] == "pt"
         pt_module = isinstance(self.model, torch.nn.Module)
         if not (pt_module or pt_str):
             raise TypeError(
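Both hunks above replace path-suffix parsing with a plain string comparison. A minimal standalone sketch of the same check (the value is illustrative, not the surrounding Ultralytics code):

    weights = "yolo11n.pt"
    is_pt = str(weights).rpartition(".")[-1] == "pt"  # True; "yolo11n" (no dot) yields "yolo11n" and fails the check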
ultralytics/engine/predictor.py
CHANGED
@@ -43,7 +43,7 @@ import torch

 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data import load_inference_source
-from ultralytics.data.augment import LetterBox, classify_transforms
+from ultralytics.data.augment import LetterBox
 from ultralytics.nn.autobackend import AutoBackend
 from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
 from ultralytics.utils.checks import check_imgsz, check_imshow
@@ -247,15 +247,6 @@ class BasePredictor:
             Source for inference.
         """
         self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
-        self.transforms = (
-            getattr(
-                self.model.model,
-                "transforms",
-                classify_transforms(self.imgsz[0]),
-            )
-            if self.args.task == "classify"
-            else None
-        )
         self.dataset = load_inference_source(
             source=source,
             batch=self.args.batch,
@@ -395,6 +386,8 @@ class BasePredictor:

         self.device = self.model.device  # update device
         self.args.half = self.model.fp16  # update half
+        if hasattr(self.model, "imgsz"):
+            self.args.imgsz = self.model.imgsz  # reuse imgsz from export metadata
         self.model.eval()

     def write_results(self, i, p, im, s):
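The new setup_model lines let exported weights drive the inference size. A hedged usage sketch with the public YOLO API (file names are illustrative):

    from ultralytics import YOLO

    model = YOLO("yolo11n.onnx")        # exported model stores imgsz in its metadata
    results = model.predict("bus.jpg")  # predictor adopts the exported imgsz instead of the default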
ultralytics/engine/trainer.py
CHANGED
@@ -578,7 +578,7 @@ class BaseTrainer:
        try:
            if self.args.task == "classify":
                data = check_cls_dataset(self.args.data)
-            elif self.args.data.
+            elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                "detect",
                "segment",
                "pose",
ultralytics/engine/validator.py
CHANGED
@@ -175,7 +175,7 @@ class BaseValidator:
            self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
            LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

-        if str(self.args.data).
+        if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
            self.data = check_det_dataset(self.args.data)
        elif self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data, split=self.args.split)
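Both the trainer and validator hunks switch to an rsplit-based extension check. A standalone illustration (assumed value):

    data = "coco8.yaml"
    ext = data.rsplit(".", 1)[-1]        # "yaml"; maxsplit=1 only splits off the final extension
    is_yaml_file = ext in {"yaml", "yml"}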
ultralytics/hub/auth.py
CHANGED
@@ -37,7 +37,7 @@ class Auth:
            verbose (bool): Enable verbose logging.
        """
        # Split the input API key in case it contains a combined key_model and keep only the API key part
-        api_key = api_key.split("_")[0]
+        api_key = api_key.split("_", 1)[0]

        # Set API key attribute as value passed or SETTINGS API key if none passed
        self.api_key = api_key or SETTINGS.get("api_key", "")
@@ -77,7 +77,7 @@ class Auth:
        for attempts in range(max_attempts):
            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
-            self.api_key = input_key.split("_")[0]  # remove model id if present
+            self.api_key = input_key.split("_", 1)[0]  # remove model id if present
            if self.authenticate():
                return True
        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
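The maxsplit=1 change keeps only the key portion of a combined key, even if the remainder contains more underscores. An assumed standalone example:

    combined = "apikey123_model_abc"
    api_key = combined.split("_", 1)[0]  # "apikey123"; everything after the first "_" is discarded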
ultralytics/hub/utils.py
CHANGED
@@ -1,7 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import os
-import platform
 import random
 import threading
 import time
@@ -18,6 +17,7 @@ from ultralytics.utils import (
     IS_PIP_PACKAGE,
     LOGGER,
     ONLINE,
+    PYTHON_VERSION,
     RANK,
     SETTINGS,
     TESTS_RUNNING,
@@ -27,6 +27,7 @@ from ultralytics.utils import (
     get_git_origin_url,
 )
 from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
+from ultralytics.utils.torch_utils import get_cpu_info

 HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
 HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")
@@ -191,7 +192,9 @@ class Events:
        self.metadata = {
            "cli": Path(ARGV[0]).name == "yolo",
            "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
-            "python": "."
+            "python": PYTHON_VERSION.rsplit(".", 1)[0],  # i.e. 3.13
+            "CPU": get_cpu_info(),
+            # "GPU": get_gpu_info(index=0) if cuda else None,
            "version": __version__,
            "env": ENVIRONMENT,
            "session_id": round(random.random() * 1e15),
@@ -205,12 +208,13 @@ class Events:
            and (IS_PIP_PACKAGE or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
        )

-    def __call__(self, cfg):
+    def __call__(self, cfg, device=None):
        """
        Attempt to add a new event to the events list and send events if the rate limit is reached.

        Args:
            cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
+            device (torch.device | str): The device type (e.g., 'cpu', 'cuda').
        """
        if not self.enabled:
            # Events disabled, do nothing
@@ -222,6 +226,7 @@ class Events:
                **self.metadata,
                "task": cfg.task,
                "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
+                "device": str(device),
            }
            if cfg.mode == "export":
                params["format"] = cfg.format
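A hedged sketch of the new telemetry fields: PYTHON_VERSION and get_cpu_info() are the imports added in the hunks above; the dict is a simplified stand-in for Events.metadata, not the full implementation.

    from ultralytics.utils import PYTHON_VERSION
    from ultralytics.utils.torch_utils import get_cpu_info

    metadata = {
        "python": PYTHON_VERSION.rsplit(".", 1)[0],  # e.g. "3.13" from "3.13.2"
        "CPU": get_cpu_info(),                       # CPU model string
    }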
ultralytics/models/yolo/classify/predict.py
CHANGED
@@ -4,6 +4,7 @@ import cv2
 import torch
 from PIL import Image

+from ultralytics.data.augment import classify_transforms
 from ultralytics.engine.predictor import BasePredictor
 from ultralytics.engine.results import Results
 from ultralytics.utils import DEFAULT_CFG, ops
@@ -51,6 +52,16 @@ class ClassificationPredictor(BasePredictor):
        self.args.task = "classify"
        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"

+    def setup_source(self, source):
+        """Sets up source and inference mode and classify transforms."""
+        super().setup_source(source)
+        updated = (
+            self.model.model.transforms.transforms[0].size != max(self.imgsz)
+            if hasattr(self.model.model, "transforms")
+            else True
+        )
+        self.transforms = self.model.model.transforms if not updated else classify_transforms(self.imgsz)
+
    def preprocess(self, img):
        """Convert input images to model-compatible tensor format with appropriate normalization."""
        if not isinstance(img, torch.Tensor):
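A hedged sketch of the rebuild condition in setup_source: transforms stored on the model are reused only when their size matches the requested imgsz, otherwise a fresh classification pipeline is built. classify_transforms is the helper imported above; the size is illustrative.

    from ultralytics.data.augment import classify_transforms

    transforms = classify_transforms(224)  # classification preprocessing pipeline sized to 224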
ultralytics/models/yolo/obb/val.py
CHANGED
@@ -252,7 +252,7 @@ class OBBValidator(DetectionValidator):
        merged_results = defaultdict(list)
        LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
        for d in data:
-            image_id = d["image_id"].split("__")[0]
+            image_id = d["image_id"].split("__", 1)[0]
            pattern = re.compile(r"\d+___\d+")
            x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
            bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
ultralytics/models/yolo/world/train.py
CHANGED
@@ -1,11 +1,14 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import itertools
+from pathlib import Path
+
+import torch

 from ultralytics.data import build_yolo_dataset
-from ultralytics.models import
+from ultralytics.models.yolo.detect import DetectionTrainer
 from ultralytics.nn.tasks import WorldModel
-from ultralytics.utils import DEFAULT_CFG,
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
 from ultralytics.utils.torch_utils import de_parallel


@@ -13,15 +16,11 @@ def on_pretrain_routine_end(trainer):
    """Callback to set up model classes and text encoder at the end of the pretrain routine."""
    if RANK in {-1, 0}:
        # Set class names for evaluation
-        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+        names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
-        device = next(trainer.model.parameters()).device
-        trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
-        for p in trainer.text_model.parameters():
-            p.requires_grad_(False)


-class WorldTrainer(
+class WorldTrainer(DetectionTrainer):
    """
    A class to fine-tune a world model on a close-set dataset.

@@ -54,14 +53,7 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)
-
-        # Import and assign clip
-        try:
-            import clip
-        except ImportError:
-            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
-            import clip
-        self.clip = clip
+        self.text_embeddings = None

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
@@ -102,18 +94,72 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
            (Dataset): YOLO dataset configured for training or validation.
        """
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
-
+        dataset = build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )
+        if mode == "train":
+            self.set_text_embeddings([dataset], batch)  # cache text embeddings to accelerate training
+        return dataset
+
+    def set_text_embeddings(self, datasets, batch):
+        """
+        Set text embeddings for datasets to accelerate training by caching category names.
+
+        This method collects unique category names from all datasets, then generates and caches text embeddings
+        for these categories to improve training efficiency.
+
+        Args:
+            datasets (List[Dataset]): List of datasets from which to extract category names.
+            batch (int | None): Batch size used for processing.
+
+        Notes:
+            This method collects category names from datasets that have the 'category_names' attribute,
+            then uses the first dataset's image path to determine where to cache the generated text embeddings.
+        """
+        text_embeddings = {}
+        for dataset in datasets:
+            if not hasattr(dataset, "category_names"):
+                continue
+            text_embeddings.update(
+                self.generate_text_embeddings(
+                    list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent
+                )
+            )
+        self.text_embeddings = text_embeddings
+
+    def generate_text_embeddings(self, texts, batch, cache_dir):
+        """
+        Generate text embeddings for a list of text samples.
+
+        Args:
+            texts (List[str]): List of text samples to encode.
+            batch (int): Batch size for processing.
+            cache_dir (Path): Directory to save/load cached embeddings.
+
+        Returns:
+            (dict): Dictionary mapping text samples to their embeddings.
+        """
+        model = "clip:ViT-B/32"
+        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
+        if cache_path.exists():
+            LOGGER.info(f"Reading existed cache from '{cache_path}'")
+            txt_map = torch.load(cache_path)
+            if sorted(txt_map.keys()) == sorted(texts):
+                return txt_map
+        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
+        assert self.model is not None
+        txt_feats = self.model.get_text_pe(texts, batch, cache_clip_model=False)
+        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
+        torch.save(txt_map, cache_path)
+        return txt_map

    def preprocess_batch(self, batch):
        """Preprocess a batch of images and text for YOLOWorld training."""
-        batch =
+        batch = DetectionTrainer.preprocess_batch(self, batch)

        # Add text features
        texts = list(itertools.chain(*batch["texts"]))
-
-        txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
+        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch
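A hedged usage sketch of the caching added above: on the first training run the class-name embeddings are written beside the dataset (a text_embeddings_clip_ViT-B_32.pt file under the image folder's parent), and later runs reuse them when the class list is unchanged. Standard YOLO-World training call; the dataset, epochs, and image size are illustrative.

    from ultralytics import YOLO

    model = YOLO("yolov8s-world.pt")
    model.train(data="coco8.yaml", epochs=1, imgsz=640)  # embeddings cached during dataset build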
ultralytics/models/yolo/world/train_world.py
CHANGED
@@ -100,6 +100,7 @@ class WorldTrainerFromScratch(WorldTrainer):
            else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
            for im_path in img_path
        ]
+        self.set_text_embeddings(datasets, batch)  # cache text embeddings to accelerate training
        return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

    def get_dataset(self):
ultralytics/models/yolo/yoloe/train.py
CHANGED
@@ -2,7 +2,6 @@

 import itertools
 from copy import copy, deepcopy
-from pathlib import Path

 import torch

@@ -157,40 +156,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
        Returns:
            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
        """
-
-        if mode == "train":
-            self.set_text_embeddings(
-                datasets.datasets if hasattr(datasets, "datasets") else [datasets], batch
-            )  # cache text embeddings to accelerate training
-        return datasets
-
-    def set_text_embeddings(self, datasets, batch):
-        """
-        Set text embeddings for datasets to accelerate training by caching category names.
-
-        This method collects unique category names from all datasets, then generates and caches text embeddings
-        for these categories to improve training efficiency.
-
-        Args:
-            datasets (List[Dataset]): List of datasets from which to extract category names.
-            batch (int | None): Batch size used for processing.
-
-        Notes:
-            This method collects category names from datasets that have the 'category_names' attribute,
-            then uses the first dataset's image path to determine where to cache the generated text embeddings.
-        """
-        # TODO: open up an interface to determine whether to do cache
-        category_names = set()
-        for dataset in datasets:
-            if not hasattr(dataset, "category_names"):
-                continue
-            category_names |= dataset.category_names
-
-        # TODO: enable to update the path or use a more general way to get the path
-        img_path = datasets[0].img_path
-        self.text_embeddings = self.generate_text_embeddings(
-            category_names, batch, cache_path=Path(img_path).parent / "text_embeddings.pt"
-        )
+        return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)

    def preprocess_batch(self, batch):
        """Process batch for training, moving text features to the appropriate device."""
@@ -202,23 +168,28 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
        batch["txt_feats"] = txt_feats
        return batch

-    def generate_text_embeddings(self, texts, batch,
+    def generate_text_embeddings(self, texts, batch, cache_dir):
        """
        Generate text embeddings for a list of text samples.

        Args:
            texts (List[str]): List of text samples to encode.
            batch (int): Batch size for processing.
-
+            cache_dir (Path): Directory to save/load cached embeddings.

        Returns:
            (dict): Dictionary mapping text samples to their embeddings.
        """
+        model = "mobileclip:blt"
+        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
        if cache_path.exists():
            LOGGER.info(f"Reading existed cache from '{cache_path}'")
-
+            txt_map = torch.load(cache_path)
+            if sorted(txt_map.keys()) == sorted(texts):
+                return txt_map
+        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
        assert self.model is not None
-        txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True)
+        txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
        torch.save(txt_map, cache_path)
        return txt_map
ultralytics/models/yolo/yoloe/val.py
CHANGED
@@ -47,7 +47,7 @@ class YOLOEDetectValidator(DetectionValidator):
            (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
        """
        assert isinstance(model, YOLOEModel)
-        names = [name.split("/")[0] for name in list(dataloader.dataset.data["names"].values())]
+        names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
        visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
        cls_visual_num = torch.zeros(len(names))

@@ -140,7 +140,7 @@ class YOLOEDetectValidator(DetectionValidator):
        if trainer is not None:
            self.device = trainer.device
            model = trainer.ema.ema
-            names = [name.split("/")[0] for name in list(self.dataloader.dataset.data["names"].values())]
+            names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
@@ -164,7 +164,7 @@ class YOLOEDetectValidator(DetectionValidator):
            model = attempt_load_weights(model, device=self.device, inplace=True)
            model.eval().to(self.device)
            data = check_det_dataset(refer_data or self.args.data)
-            names = [name.split("/")[0] for name in list(data["names"].values())]
+            names = [name.split("/", 1)[0] for name in list(data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
ultralytics/nn/tasks.py
CHANGED
@@ -146,6 +146,8 @@ class BaseModel(torch.nn.Module):
            (torch.Tensor): The last output of the model.
        """
        y, dt, embeddings = [], [], []  # outputs
+        embed = frozenset(embed) if embed is not None else {-1}
+        max_idx = max(embed)
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -155,9 +157,9 @@ class BaseModel(torch.nn.Module):
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
-            if
+            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
-                if m.i ==
+                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

@@ -677,6 +679,8 @@ class RTDETRDetectionModel(DetectionModel):
            (torch.Tensor): Model's output tensor.
        """
        y, dt, embeddings = [], [], []  # outputs
+        embed = frozenset(embed) if embed is not None else {-1}
+        max_idx = max(embed)
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -686,9 +690,9 @@ class RTDETRDetectionModel(DetectionModel):
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
-            if
+            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
-                if m.i ==
+                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
@@ -721,24 +725,33 @@ class WorldModel(DetectionModel):
            batch (int): Batch size for processing text tokens.
            cache_clip_model (bool): Whether to cache the CLIP model.
        """
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self.txt_feats = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)
+        self.model[-1].nc = len(text)
+
+    @smart_inference_mode()
+    def get_text_pe(self, text, batch=80, cache_clip_model=True):
+        """
+        Set classes in advance so that model could do offline-inference without clip model.
+
+        Args:
+            text (List[str]): List of class names.
+            batch (int): Batch size for processing text tokens.
+            cache_clip_model (bool): Whether to cache the CLIP model.
+
+        Returns:
+            (torch.Tensor): Text positional embeddings.
+        """
+        from ultralytics.nn.text_model import build_text_model
+
+        device = next(self.model.parameters()).device
+        if not getattr(self, "clip_model", None) and cache_clip_model:
+            # For backwards compatibility of models lacking clip_model attribute
+            self.clip_model = build_text_model("clip:ViT-B/32", device=device)
+        model = self.clip_model if cache_clip_model else build_text_model("clip:ViT-B/32", device=device)
+        text_token = model.tokenize(text)
        txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
        txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
-
-        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
-        self.model[-1].nc = len(text)
+        return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])

    def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
        """
@@ -760,6 +773,8 @@ class WorldModel(DetectionModel):
        txt_feats = txt_feats.expand(x.shape[0], -1, -1)
        ori_txt_feats = txt_feats.clone()
        y, dt, embeddings = [], [], []  # outputs
+        embed = frozenset(embed) if embed is not None else {-1}
+        max_idx = max(embed)
        for m in self.model:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -777,9 +792,9 @@ class WorldModel(DetectionModel):
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
-            if
+            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
-                if m.i ==
+                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

@@ -976,6 +991,8 @@ class YOLOEModel(DetectionModel):
        """
        y, dt, embeddings = [], [], []  # outputs
        b = x.shape[0]
+        embed = frozenset(embed) if embed is not None else {-1}
+        max_idx = max(embed)
        for m in self.model:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -997,9 +1014,9 @@ class YOLOEModel(DetectionModel):
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
-            if
+            if m.i in embed:
                embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
-                if m.i ==
+                if m.i == max_idx:
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

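The repeated frozenset/max_idx hunks all serve the embedding-extraction path, where a set of layer indices is passed via `embed` and pooled features are returned instead of detections. A hedged sketch using the public API (model and source are illustrative; layer selection defaults are handled internally):

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    embeddings = model.embed("bus.jpg")  # pooled feature tensors, one per image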
ultralytics/nn/text_model.py
CHANGED
ultralytics/solutions/similarity_search.py
CHANGED
@@ -30,12 +30,9 @@ class VisualAISearch(BaseSolution):
        """Initializes the VisualAISearch class with the FAISS index file and CLIP model."""
        super().__init__(**kwargs)
        check_requirements(["git+https://github.com/ultralytics/CLIP.git", "faiss-cpu"])
-        import clip
-        import faiss
-
-        self.faiss = faiss
-        self.clip = clip

+        self.faiss = __import__("faiss")
+        self.clip = __import__("clip")
        self.faiss_index = "faiss.index"
        self.data_path_npy = "paths.npy"
        self.model_name = "ViT-B/32"
@@ -51,7 +48,7 @@ class VisualAISearch(BaseSolution):
            safe_download(url=f"{ASSETS_URL}/images.zip", unzip=True, retry=3)
            self.data_dir = Path("images")

-        self.model, self.preprocess = clip.load(self.model_name, device=self.device)
+        self.model, self.preprocess = self.clip.load(self.model_name, device=self.device)

        self.index = None
        self.image_paths = []
ultralytics/solutions/streamlit_inference.py
CHANGED
@@ -130,7 +130,7 @@ class Inference:
        # Add dropdown menu for model selection
        available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
        if self.model_path:  # If user provided the custom model, insert model without suffix as *.pt is added later
-            available_models.insert(0, self.model_path.split(".pt")[0])
+            available_models.insert(0, self.model_path.split(".pt", 1)[0])
        selected_model = self.st.sidebar.selectbox("Model", available_models)

        with self.st.spinner("Model is downloading..."):
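A hedged note on the __import__ pattern in the VisualAISearch hunk: once check_requirements() has verified the optional packages, __import__ binds a top-level module without a module-level import statement.

    faiss = __import__("faiss")  # equivalent to `import faiss` for a top-level module name
    clip = __import__("clip")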
ultralytics/utils/__init__.py
CHANGED
@@ -1387,7 +1387,7 @@ def deprecation_warn(arg, new_arg=None):
 def clean_url(url):
     """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
     url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows
-    return unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth
+    return unquote(url).split("?", 1)[0]  # '%2F' to '/', split https://url.com/file.txt?auth


 def url2file(url):
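A standalone illustration (assumed URL) of the clean_url behavior with maxsplit:

    url = "https://url.com/file.txt?auth=token"
    clean = url.split("?", 1)[0]  # "https://url.com/file.txt"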
ultralytics/utils/callbacks/hub.py
CHANGED
@@ -73,22 +73,23 @@ def on_train_end(trainer):

 def on_train_start(trainer):
     """Run events on train start."""
-    events(trainer.args)
+    events(trainer.args, trainer.device)


 def on_val_start(validator):
     """Run events on validation start."""
-    events(validator.args)
+    if not validator.training:
+        events(validator.args, validator.device)


 def on_predict_start(predictor):
     """Run events on predict start."""
-    events(predictor.args)
+    events(predictor.args, predictor.device)


 def on_export_start(exporter):
     """Run events on export start."""
-    events(exporter.args)
+    events(exporter.args, exporter.device)


 callbacks = (
|