dgenerate-ultralytics-headless 8.3.191__py3-none-any.whl → 8.3.193__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/RECORD +34 -34
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +7 -5
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/xView.yaml +1 -1
- ultralytics/data/utils.py +1 -1
- ultralytics/engine/exporter.py +5 -4
- ultralytics/engine/model.py +4 -4
- ultralytics/engine/predictor.py +7 -3
- ultralytics/engine/trainer.py +5 -5
- ultralytics/engine/tuner.py +227 -40
- ultralytics/models/yolo/classify/train.py +2 -2
- ultralytics/models/yolo/classify/val.py +1 -1
- ultralytics/models/yolo/detect/val.py +1 -1
- ultralytics/models/yolo/pose/val.py +1 -1
- ultralytics/models/yolo/segment/val.py +14 -14
- ultralytics/models/yolo/world/train.py +1 -1
- ultralytics/models/yolo/yoloe/train.py +3 -4
- ultralytics/models/yolo/yoloe/val.py +3 -3
- ultralytics/nn/__init__.py +2 -4
- ultralytics/nn/autobackend.py +2 -2
- ultralytics/nn/tasks.py +2 -51
- ultralytics/utils/__init__.py +5 -1
- ultralytics/utils/checks.py +2 -1
- ultralytics/utils/plotting.py +2 -2
- ultralytics/utils/tal.py +2 -2
- ultralytics/utils/torch_utils.py +7 -6
- ultralytics/utils/tqdm.py +50 -74
- ultralytics/utils/tuner.py +1 -1
- {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/top_level.txt +0 -0
ultralytics/engine/tuner.py
CHANGED
@@ -20,11 +20,13 @@ import random
 import shutil
 import subprocess
 import time
+from datetime import datetime

 import numpy as np

 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
+from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.patches import torch_load
 from ultralytics.utils.plotting import plot_tune_results

@@ -34,15 +36,18 @@ class Tuner:
     A class for hyperparameter tuning of YOLO models.

     The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
-    search space and retraining the model to evaluate their performance.
+    search space and retraining the model to evaluate their performance. Supports both local CSV storage and
+    distributed MongoDB Atlas coordination for multi-machine hyperparameter optimization.

     Attributes:
-        space (
+        space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
         tune_dir (Path): Directory where evolution logs and results will be saved.
         tune_csv (Path): Path to the CSV file where evolution logs are saved.
         args (dict): Configuration arguments for the tuning process.
         callbacks (list): Callback functions to be executed during tuning.
         prefix (str): Prefix string for logging messages.
+        mongodb (MongoClient): Optional MongoDB client for distributed tuning.
+        collection (Collection): MongoDB collection for storing tuning results.

     Methods:
         _mutate: Mutate hyperparameters based on bounds and scaling factors.
@@ -53,11 +58,26 @@ class Tuner:
         >>> from ultralytics import YOLO
         >>> model = YOLO("yolo11n.pt")
         >>> model.tune(
-
-
-
-
-        >>>
+        >>>     data="coco8.yaml",
+        >>>     epochs=10,
+        >>>     iterations=300,
+        >>>     plots=False,
+        >>>     save=False,
+        >>>     val=False
+        >>> )
+
+        Tune with distributed MongoDB Atlas coordination across multiple machines:
+        >>> model.tune(
+        >>>     data="coco8.yaml",
+        >>>     epochs=10,
+        >>>     iterations=300,
+        >>>     mongodb_uri="mongodb+srv://user:pass@cluster.mongodb.net/",
+        >>>     mongodb_db="ultralytics",
+        >>>     mongodb_collection="tune_results"
+        >>> )
+
+        Tune with custom search space:
+        >>> model.tune(space={"lr0": (1e-5, 1e-1), "momentum": (0.6, 0.98)})
         """

     def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
@@ -66,7 +86,7 @@ class Tuner:

         Args:
             args (dict): Configuration for hyperparameter evolution.
-            _callbacks (
+            _callbacks (list | None, optional): Callback functions to be executed during tuning.
         """
         self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
@@ -95,6 +115,10 @@ class Tuner:
             "cutmix": (0.0, 1.0),  # image cutmix (probability)
             "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
         }
+        mongodb_uri = args.pop("mongodb_uri", None)
+        mongodb_db = args.pop("mongodb_db", "ultralytics")
+        mongodb_collection = args.pop("mongodb_collection", "tuner_results")
+
         self.args = get_cfg(overrides=args)
         self.args.exist_ok = self.args.resume  # resume w/ same tune_dir
         self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
@@ -103,13 +127,151 @@ class Tuner:
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         self.prefix = colorstr("Tuner: ")
         callbacks.add_integration_callbacks(self)
+
+        # MongoDB Atlas support (optional)
+        self.mongodb = None
+        if mongodb_uri:
+            self._init_mongodb(mongodb_uri, mongodb_db, mongodb_collection)
+
         LOGGER.info(
             f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
             f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
         )

+    def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
+        """
+        Create MongoDB client with exponential backoff retry on connection failures.
+
+        Args:
+            uri (str): MongoDB connection string with credentials and cluster information.
+            max_retries (int): Maximum number of connection attempts before giving up.
+
+        Returns:
+            (MongoClient): Connected MongoDB client instance.
+        """
+        check_requirements("pymongo")
+
+        from pymongo import MongoClient
+        from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
+
+        for attempt in range(max_retries):
+            try:
+                client = MongoClient(
+                    uri,
+                    serverSelectionTimeoutMS=30000,
+                    connectTimeoutMS=20000,
+                    socketTimeoutMS=40000,
+                    retryWrites=True,
+                    retryReads=True,
+                    maxPoolSize=30,
+                    minPoolSize=3,
+                    maxIdleTimeMS=60000,
+                )
+                client.admin.command("ping")  # Test connection
+                LOGGER.info(f"{self.prefix}Connected to MongoDB Atlas (attempt {attempt + 1})")
+                return client
+            except (ConnectionFailure, ServerSelectionTimeoutError):
+                if attempt == max_retries - 1:
+                    raise
+                wait_time = 2**attempt
+                LOGGER.warning(
+                    f"{self.prefix}MongoDB connection failed (attempt {attempt + 1}), retrying in {wait_time}s..."
+                )
+                time.sleep(wait_time)
+
+    def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
+        """
+        Initialize MongoDB connection for distributed tuning.
+
+        Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines.
+        Each worker saves results to a shared collection and reads the latest best hyperparameters
+        from all workers for evolution.
+
+        Args:
+            mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
+            mongodb_db (str, optional): Database name.
+            mongodb_collection (str, optional): Collection name.
+
+        Notes:
+            - Creates a fitness index for fast queries of top results
+            - Falls back to CSV-only mode if connection fails
+            - Uses connection pooling and retry logic for production reliability
+        """
+        self.mongodb = self._connect(mongodb_uri)
+        self.collection = self.mongodb[mongodb_db][mongodb_collection]
+        self.collection.create_index([("fitness", -1)], background=True)
+        LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
+
+    def _get_mongodb_results(self, n: int = 5) -> list:
+        """
+        Get top N results from MongoDB sorted by fitness.
+
+        Args:
+            n (int): Number of top results to retrieve.
+
+        Returns:
+            (list[dict]): List of result documents with fitness scores and hyperparameters.
+        """
+        try:
+            return list(self.collection.find().sort("fitness", -1).limit(n))
+        except Exception:
+            return []
+
+    def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
+        """
+        Save results to MongoDB with proper type conversion.
+
+        Args:
+            fitness (float): Fitness score achieved with these hyperparameters.
+            hyperparameters (dict[str, float]): Dictionary of hyperparameter values.
+            metrics (dict): Complete training metrics dictionary (mAP, precision, recall, losses, etc.).
+            iteration (int): Current iteration number.
+        """
+        try:
+            self.collection.insert_one(
+                {
+                    "fitness": float(fitness),
+                    "hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
+                    "metrics": metrics,
+                    "timestamp": datetime.now(),
+                    "iteration": iteration,
+                }
+            )
+        except Exception as e:
+            LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
+
+    def _sync_mongodb_to_csv(self):
+        """
+        Sync MongoDB results to CSV for plotting compatibility.
+
+        Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
+        the existing plotting functions to work seamlessly with distributed MongoDB data.
+        """
+        try:
+            # Get all results from MongoDB
+            all_results = list(self.collection.find().sort("iteration", 1))
+            if not all_results:
+                return
+
+            # Write to CSV
+            headers = ",".join(["fitness"] + list(self.space.keys())) + "\n"
+            with open(self.tune_csv, "w", encoding="utf-8") as f:
+                f.write(headers)
+                for result in all_results:
+                    fitness = result["fitness"]
+                    hyp_values = [result["hyperparameters"][k] for k in self.space.keys()]
+                    log_row = [round(fitness, 5)] + hyp_values
+                    f.write(",".join(map(str, log_row)) + "\n")
+
+        except Exception as e:
+            LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
+
     def _mutate(
-        self,
+        self,
+        parent: str = "single",
+        n: int = 5,
+        mutation: float = 0.8,
+        sigma: float = 0.2,
     ) -> dict[str, float]:
         """
         Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
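The `_connect` helper above retries failed connections with exponential backoff, waiting `2**attempt` seconds between tries. A minimal standalone sketch of the same retry pattern, with a hypothetical `with_backoff` helper and a simulated flaky connection standing in for `MongoClient`:

```python
import time

def with_backoff(fn, max_retries: int = 3):
    """Call fn(), retrying on ConnectionError with exponential backoff (1s, 2s, 4s, ...)."""
    for attempt in range(max_retries):
        try:
            return fn()
        except ConnectionError:
            if attempt == max_retries - 1:
                raise  # out of retries: surface the original error
            wait = 2**attempt
            print(f"attempt {attempt + 1} failed, retrying in {wait}s")
            time.sleep(wait)

calls = {"n": 0}

def flaky_connect():
    """Simulated connection that fails twice, then succeeds."""
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient network error")
    return "connected"

print(with_backoff(flaky_connect))  # retries twice, then prints 'connected'
```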
@@ -121,23 +283,36 @@ class Tuner:
             sigma (float): Standard deviation for Gaussian random number generator.

         Returns:
-            (
+            (dict[str, float]): A dictionary containing mutated hyperparameters.
         """
-
-
-
-
-
-
+        x = None
+
+        # Try MongoDB first if available
+        if self.mongodb:
+            results = self._get_mongodb_results(n)
+            if results:
+                # MongoDB already sorted by fitness DESC, so results[0] is best
+                x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
+                n = min(n, len(x))
+
+        # Fall back to CSV if MongoDB unavailable or empty
+        if x is None and self.tune_csv.exists():
+            csv_data = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
+            if len(csv_data) > 0:
+                fitness = csv_data[:, 0]  # first column
+                n = min(n, len(csv_data))
+                x = csv_data[np.argsort(-fitness)][:n]  # top n sorted by fitness DESC
+
+        # Mutate if we have data, otherwise use defaults
+        if x is not None:
             w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
-            if parent == "single" or len(x)
-                # x = x[random.randint(0, n - 1)]  # random selection
+            if parent == "single" or len(x) <= 1:
                 x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
             elif parent == "weighted":
                 x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

             # Mutate
-            r = np.random
+            r = np.random
             r.seed(int(time.time()))
             g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()])  # gains 0-1
             ng = len(self.space)
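The new `_mutate` body prefers MongoDB results and falls back to ranking the CSV history by fitness. A small NumPy illustration of the top-`n` parent selection and the fitness-based weights it computes (toy values, not real tuning data):

```python
import numpy as np

# Toy history: column 0 is fitness, column 1 a single hyperparameter (e.g. lr0)
csv_data = np.array([[0.41, 0.010], [0.55, 0.008], [0.47, 0.012], [0.60, 0.009]])

n = 3
top = csv_data[np.argsort(-csv_data[:, 0])][:n]  # top-n rows, fitness descending
w = top[:, 0] - top[:, 0].min() + 1e-6           # selection weights (sum > 0)

print(top[:, 0])    # [0.6  0.55 0.47]
print(w / w.sum())  # probability of each row being chosen as the parent
```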
@@ -149,9 +324,9 @@ class Tuner:
             hyp = {k: getattr(self.args, k) for k in self.space.keys()}

         # Constrain to limits
-        for k,
-            hyp[k] = max(hyp[k],
-            hyp[k] = min(hyp[k],
+        for k, bounds in self.space.items():
+            hyp[k] = max(hyp[k], bounds[0])  # lower limit
+            hyp[k] = min(hyp[k], bounds[1])  # upper limit
             hyp[k] = round(hyp[k], 5)  # significant digits

         return hyp
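The constrain step clamps every mutated value into its `(min, max)` bounds and rounds to five decimals. A simplified, hypothetical sketch of the mutate-then-clamp idea; note it uses plain multiplicative Gaussian noise, not the exact Ultralytics noise model:

```python
import numpy as np

# Hypothetical search space (min, max) and a parent from a previous iteration
space = {"lr0": (1e-5, 1e-1), "momentum": (0.6, 0.98)}
parent = {"lr0": 0.01, "momentum": 0.9}

rng = np.random.default_rng(0)
sigma = 0.2  # std-dev of the multiplicative Gaussian noise

hyp = {}
for k, bounds in space.items():
    value = parent[k] * (1 + rng.normal(0, sigma))  # jitter the parent value
    value = max(value, bounds[0])                   # lower limit
    value = min(value, bounds[1])                   # upper limit
    hyp[k] = round(value, 5)                        # significant digits

print(hyp)
```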
@@ -160,25 +335,26 @@ class Tuner:
         """
         Execute the hyperparameter evolution process when the Tuner instance is called.

-        This method iterates through the number of iterations, performing the following steps
-
-
-
-
-
+        This method iterates through the specified number of iterations, performing the following steps:
+        1. Sync MongoDB results to CSV (if using distributed mode)
+        2. Mutate hyperparameters using the best previous results or defaults
+        3. Train a YOLO model with the mutated hyperparameters
+        4. Log fitness scores and hyperparameters to MongoDB and/or CSV
+        5. Track the best performing configuration across all iterations

         Args:
-            model (Model): A pre-initialized YOLO model to be used for training.
+            model (Model | None, optional): A pre-initialized YOLO model to be used for training.
             iterations (int): The number of generations to run the evolution for.
-            cleanup (bool): Whether to delete iteration weights to reduce storage space
-
-        Note:
-            The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
-            Ensure this path is set correctly in the Tuner instance.
+            cleanup (bool): Whether to delete iteration weights to reduce storage space during tuning.
         """
         t0 = time.time()
         best_save_dir, best_metrics = None, None
         (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
+
+        # Sync MongoDB to CSV at startup for proper resume logic
+        if self.mongodb:
+            self._sync_mongodb_to_csv()
+
         start = 0
         if self.tune_csv.exists():
             x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
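To make the five-step loop in that docstring concrete, here is a toy, fully self-contained version of the evolve-train-log cycle, with a stub `train_one_iteration` standing in for a real YOLO training run (all names hypothetical):

```python
import random

random.seed(0)
history = []  # stand-in for the CSV file / MongoDB collection

def mutate():
    """Jitter the best previous hyperparameters (toy version of Tuner._mutate)."""
    base = max(history, key=lambda r: r["fitness"])["hyp"] if history else {"lr0": 0.01}
    return {k: round(v * random.uniform(0.9, 1.1), 5) for k, v in base.items()}

def train_one_iteration(hyp):
    """Hypothetical stand-in for a full training run returning metrics."""
    return {"fitness": random.random()}

for i in range(5):
    hyp = mutate()                      # step 2: evolve from best prior results
    metrics = train_one_iteration(hyp)  # step 3: train with mutated hyperparameters
    history.append({"fitness": metrics["fitness"], "hyp": hyp, "iteration": i + 1})  # step 4: log

print(max(history, key=lambda r: r["fitness"]))  # step 5: best configuration overall
```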
@@ -205,12 +381,23 @@ class Tuner:
             except Exception as e:
                 LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")

-            # Save results
+            # Save results - MongoDB takes precedence
             fitness = metrics.get("fitness", 0.0)
-
-
-
-
+            if self.mongodb:
+                self._save_to_mongodb(fitness, mutated_hyp, metrics, i + 1)
+                self._sync_mongodb_to_csv()
+                total_mongo_iterations = self.collection.count_documents({})
+                if total_mongo_iterations >= iterations:
+                    LOGGER.info(
+                        f"{self.prefix}Target iterations ({iterations}) reached in MongoDB ({total_mongo_iterations}). Stopping."
+                    )
+                    break
+            else:
+                # Save to CSV only if no MongoDB
+                log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
+                headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
+                with open(self.tune_csv, "a", encoding="utf-8") as f:
+                    f.write(headers + ",".join(map(str, log_row)) + "\n")

             # Get best results
             x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
@@ -226,7 +413,7 @@ class Tuner:
             shutil.rmtree(weights_dir, ignore_errors=True)  # remove iteration weights/ dir to reduce storage space

         # Plot tune results
-        plot_tune_results(self.tune_csv)
+        plot_tune_results(str(self.tune_csv))

         # Save and print tune results
         header = (
ultralytics/models/yolo/classify/train.py
CHANGED
@@ -166,8 +166,8 @@ class ClassificationTrainer(BaseTrainer):

     def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         """Preprocess a batch of images and classes."""
-        batch["img"] = batch["img"].to(self.device)
-        batch["cls"] = batch["cls"].to(self.device)
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["cls"] = batch["cls"].to(self.device, non_blocking=True)
         return batch

     def progress_string(self) -> str:
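This and the following hunks add `non_blocking=True` to the host-to-device transfers in the various `preprocess`/`preprocess_batch` methods. The flag lets the copy overlap with already-queued GPU work, but only pays off when the source tensor lives in pinned host memory (e.g. a DataLoader created with `pin_memory=True`). A brief illustration:

```python
import torch

# non_blocking=True makes host-to-GPU copies asynchronous, provided the source
# tensor is in pinned (page-locked) host memory.
if torch.cuda.is_available():
    device = torch.device("cuda")
    batch = torch.randn(8, 3, 640, 640).pin_memory()  # pinned host memory
    batch = batch.to(device, non_blocking=True)       # copy overlaps with queued compute
    print(batch.mean().item())                        # .item() synchronizes with the stream
else:
    print("CUDA not available; non_blocking is a harmless no-op on CPU")
```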
ultralytics/models/yolo/classify/val.py
CHANGED
@@ -91,7 +91,7 @@ class ClassificationValidator(BaseValidator):
         """Preprocess input batch by moving data to device and converting to appropriate dtype."""
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
-        batch["cls"] = batch["cls"].to(self.device)
+        batch["cls"] = batch["cls"].to(self.device, non_blocking=True)
         return batch

     def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
ultralytics/models/yolo/detect/val.py
CHANGED
@@ -74,7 +74,7 @@ class DetectionValidator(BaseValidator):
         batch["img"] = batch["img"].to(self.device, non_blocking=True)
         batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
         for k in {"batch_idx", "cls", "bboxes"}:
-            batch[k] = batch[k].to(self.device)
+            batch[k] = batch[k].to(self.device, non_blocking=True)

         return batch

ultralytics/models/yolo/pose/val.py
CHANGED
@@ -86,7 +86,7 @@ class PoseValidator(DetectionValidator):
     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Preprocess batch by converting keypoints data to float and moving it to the device."""
         batch = super().preprocess(batch)
-        batch["keypoints"] = batch["keypoints"].to(self.device).float()
+        batch["keypoints"] = batch["keypoints"].to(self.device, non_blocking=True).float()
         return batch

     def get_desc(self) -> str:
ultralytics/models/yolo/segment/val.py
CHANGED
@@ -63,7 +63,7 @@ class SegmentationValidator(DetectionValidator):
             (Dict[str, Any]): Preprocessed batch.
         """
         batch = super().preprocess(batch)
-        batch["masks"] = batch["masks"].to(self.device).float()
+        batch["masks"] = batch["masks"].to(self.device, non_blocking=True).float()
         return batch

     def init_metrics(self, model: torch.nn.Module) -> None:
@@ -133,8 +133,17 @@ class SegmentationValidator(DetectionValidator):
             (Dict[str, Any]): Prepared batch with processed annotations.
         """
         prepared_batch = super()._prepare_batch(si, batch)
-
-
+        nl = len(prepared_batch["cls"])
+        if self.args.overlap_mask:
+            masks = batch["masks"][si]
+            index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
+            masks = (masks == index).float()
+        else:
+            masks = batch["masks"][batch["batch_idx"] == si]
+        if nl and self.process is ops.process_mask_native:
+            masks = F.interpolate(masks[None], prepared_batch["imgsz"], mode="bilinear", align_corners=False)[0]
+            masks = masks.gt_(0.5)
+        prepared_batch["masks"] = masks
         return prepared_batch

     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
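The rewritten `_prepare_batch` decodes overlap-encoded ground-truth masks with a single broadcast comparison instead of `repeat` plus `torch.where`. A toy demonstration of that decoding step:

```python
import torch

# With overlap_mask=True, ground truth is one mask per image whose pixel values are
# instance ids (0 = background, 1..N = instance). Broadcasting against an index tensor
# expands it into N binary per-instance masks.
nl = 3                                           # number of instances
encoded = torch.tensor([[0, 1, 1, 2, 3, 3]])     # toy 1x6 "image" of instance ids
index = torch.arange(1, nl + 1).view(nl, 1, 1)   # shape (N, 1, 1)
masks = (encoded == index).float()               # shape (N, 1, 6): one mask per instance
print(masks[0])  # tensor([[0., 1., 1., 0., 0., 0.]])
```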
@@ -158,20 +167,11 @@ class SegmentationValidator(DetectionValidator):
         >>> correct_preds = validator._process_batch(preds, batch)
         """
         tp = super()._process_batch(preds, batch)
-        gt_cls
+        gt_cls = batch["cls"]
         if len(gt_cls) == 0 or len(preds["cls"]) == 0:
             tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
         else:
-
-            if self.args.overlap_mask:
-                nl = len(gt_cls)
-                index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
-                gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
-                gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
-            if gt_masks.shape[1:] != pred_masks.shape[1:]:
-                gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
-                gt_masks = gt_masks.gt_(0.5)
-            iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
+            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
             tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
         tp.update({"tp_m": tp_m})  # update tp with mask IoU
         return tp
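With the masks now prepared up front, `_process_batch` reduces to a single `mask_iou` call on masks flattened to `(N, H*W)`. A self-contained sketch of that IoU computation (the real `mask_iou` lives in `ultralytics.utils.metrics`):

```python
import torch

def mask_iou_sketch(gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """IoU between binary masks flattened to (N, H*W) and (M, H*W); returns (N, M)."""
    inter = gt @ pred.T  # intersection pixel counts
    union = gt.sum(1)[:, None] + pred.sum(1)[None, :] - inter
    return inter / union.clamp(min=1e-7)

gt = torch.tensor([[1.0, 1.0, 0.0, 0.0]])                          # one GT mask
pred = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]])  # two predictions
print(mask_iou_sketch(gt, pred))  # tensor([[0.5000, 0.6667]])
```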
ultralytics/models/yolo/world/train.py
CHANGED
@@ -171,7 +171,7 @@ class WorldTrainer(DetectionTrainer):

         # Add text features
         texts = list(itertools.chain(*batch["texts"]))
-        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
+        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device, non_blocking=True)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
         batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
         return batch
ultralytics/models/yolo/yoloe/train.py
CHANGED
@@ -197,7 +197,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
         batch = DetectionTrainer.preprocess_batch(self, batch)

         texts = list(itertools.chain(*batch["texts"]))
-        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
+        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device, non_blocking=True)
         txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
         batch["txt_feats"] = txt_feats
         return batch
@@ -251,8 +251,7 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):

     def preprocess_batch(self, batch):
         """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
-
-        return batch
+        return DetectionTrainer.preprocess_batch(self, batch)

     def set_text_embeddings(self, datasets, batch: int):
         """
@@ -318,5 +317,5 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
     def preprocess_batch(self, batch):
         """Preprocess a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
         batch = super().preprocess_batch(batch)
-        batch["visuals"] = batch["visuals"].to(self.device)
+        batch["visuals"] = batch["visuals"].to(self.device, non_blocking=True)
         return batch
ultralytics/models/yolo/yoloe/val.py
CHANGED
@@ -102,7 +102,7 @@ class YOLOEDetectValidator(DetectionValidator):
         """Preprocess batch data, ensuring visuals are on the same device as images."""
         batch = super().preprocess(batch)
         if "visuals" in batch:
-            batch["visuals"] = batch["visuals"].to(batch["img"].device)
+            batch["visuals"] = batch["visuals"].to(batch["img"].device, non_blocking=True)
         return batch

     def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
@@ -186,9 +186,9 @@ class YOLOEDetectValidator(DetectionValidator):
         self.device = select_device(self.args.device, verbose=False)

         if isinstance(model, (str, Path)):
-            from ultralytics.nn.tasks import
+            from ultralytics.nn.tasks import load_checkpoint

-            model =
+            model, _ = load_checkpoint(model, device=self.device)  # model, ckpt
             model.eval().to(self.device)
             data = check_det_dataset(refer_data or self.args.data)
             names = [name.split("/", 1)[0] for name in list(data["names"].values())]
ultralytics/nn/__init__.py
CHANGED
@@ -5,18 +5,16 @@ from .tasks import (
     ClassificationModel,
     DetectionModel,
     SegmentationModel,
-    attempt_load_one_weight,
-    attempt_load_weights,
     guess_model_scale,
     guess_model_task,
+    load_checkpoint,
     parse_model,
     torch_safe_load,
     yaml_model_load,
 )

 __all__ = (
-    "
-    "attempt_load_weights",
+    "load_checkpoint",
     "parse_model",
     "yaml_model_load",
     "guess_model_task",
ultralytics/nn/autobackend.py
CHANGED
@@ -203,9 +203,9 @@ class AutoBackend(nn.Module):
                 model = model.fuse(verbose=verbose)
             model = model.to(device)
         else:  # pt file
-            from ultralytics.nn.tasks import
+            from ultralytics.nn.tasks import load_checkpoint

-            model, _ =
+            model, _ = load_checkpoint(model, device=device, fuse=fuse)  # load model, ckpt

         # Common PyTorch model processing
         if hasattr(model, "kpt_shape"):
ultralytics/nn/tasks.py
CHANGED
@@ -1483,61 +1483,12 @@ def torch_safe_load(weight, safe_only=False):
     return ckpt, file


-def
-    """
-    Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
-
-    Args:
-        weights (str | List[str]): Model weights path(s).
-        device (torch.device, optional): Device to load model to.
-        inplace (bool): Whether to do inplace operations.
-        fuse (bool): Whether to fuse model.
-
-    Returns:
-        (torch.nn.Module): Loaded model.
-    """
-    ensemble = Ensemble()
-    for w in weights if isinstance(weights, list) else [weights]:
-        ckpt, w = torch_safe_load(w)  # load ckpt
-        args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None  # combined args
-        model = (ckpt.get("ema") or ckpt["model"]).float()  # FP32 model
-
-        # Model compatibility updates
-        model.args = args  # attach args to model
-        model.pt_path = w  # attach *.pt file path to model
-        model.task = getattr(model, "task", guess_model_task(model))
-        if not hasattr(model, "stride"):
-            model.stride = torch.tensor([32.0])
-
-        # Append
-        ensemble.append((model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()).to(device))
-
-    # Module updates
-    for m in ensemble.modules():
-        if hasattr(m, "inplace"):
-            m.inplace = inplace
-        elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
-            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
-
-    # Return model
-    if len(ensemble) == 1:
-        return ensemble[-1]
-
-    # Return ensemble
-    LOGGER.info(f"Ensemble created with {weights}\n")
-    for k in "names", "nc", "yaml":
-        setattr(ensemble, k, getattr(ensemble[0], k))
-    ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride
-    assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
-    return ensemble
-
-
-def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
+def load_checkpoint(weight, device=None, inplace=True, fuse=False):
     """
     Load a single model weights.

     Args:
-        weight (str): Model weight path.
+        weight (str | Path): Model weight path.
         device (torch.device, optional): Device to load model to.
         inplace (bool): Whether to do inplace operations.
         fuse (bool): Whether to fuse model.
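The ensemble loader (`attempt_load_weights`) is removed and `attempt_load_one_weight` is renamed to `load_checkpoint`; per the updated call sites above, it returns a `(model, ckpt)` tuple. A hedged usage sketch, assuming a local `yolo11n.pt` checkpoint is available:

```python
from ultralytics.nn.tasks import load_checkpoint  # renamed from attempt_load_one_weight in 8.3.193

# Returns (model, ckpt): the FP32 model and its raw checkpoint dict
model, ckpt = load_checkpoint("yolo11n.pt", device="cpu")
model.eval()
print(type(model).__name__)  # e.g. DetectionModel
```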
ultralytics/utils/__init__.py
CHANGED
@@ -49,7 +49,7 @@ MACOS_VERSION = platform.mac_ver()[0] if MACOS else None
 NOT_MACOS14 = not (MACOS and MACOS_VERSION.startswith("14."))
 ARM64 = platform.machine() in {"arm64", "aarch64"}  # ARM64 booleans
 PYTHON_VERSION = platform.python_version()
-TORCH_VERSION = torch.__version__
+TORCH_VERSION = str(torch.__version__)  # Normalize torch.__version__ (PyTorch>1.9 returns TorchVersion objects)
 TORCHVISION_VERSION = importlib.metadata.version("torchvision")  # faster than importing torchvision
 IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode"
 RKNN_CHIPS = frozenset(
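The `str()` wrapper matters because recent PyTorch returns a `TorchVersion` object (a `str` subclass with version-comparison semantics), which can surprise code that expects a plain string, e.g. when serializing or type-checking:

```python
import torch

print(type(torch.__version__).__name__)  # TorchVersion on PyTorch > 1.9
v = str(torch.__version__)               # plain str, e.g. '2.4.1+cu121'
print(v.split("+")[0])                   # version without the local build tag
```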
@@ -132,6 +132,10 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # suppress verbose TF compiler warnings
 os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"  # suppress "NNPACK.cpp could not initialize NNPACK" warnings
 os.environ["KINETO_LOG_LEVEL"] = "5"  # suppress verbose PyTorch profiler output when computing FLOPs

+# Precompiled type tuples for faster isinstance() checks
+FLOAT_OR_INT = (float, int)
+STR_OR_PATH = (str, Path)
+

 class DataExportMixin:
     """