dgenerate-ultralytics-headless 8.3.191__py3-none-any.whl → 8.3.193__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/RECORD +34 -34
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/cfg/__init__.py +7 -5
  5. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  6. ultralytics/cfg/datasets/xView.yaml +1 -1
  7. ultralytics/data/utils.py +1 -1
  8. ultralytics/engine/exporter.py +5 -4
  9. ultralytics/engine/model.py +4 -4
  10. ultralytics/engine/predictor.py +7 -3
  11. ultralytics/engine/trainer.py +5 -5
  12. ultralytics/engine/tuner.py +227 -40
  13. ultralytics/models/yolo/classify/train.py +2 -2
  14. ultralytics/models/yolo/classify/val.py +1 -1
  15. ultralytics/models/yolo/detect/val.py +1 -1
  16. ultralytics/models/yolo/pose/val.py +1 -1
  17. ultralytics/models/yolo/segment/val.py +14 -14
  18. ultralytics/models/yolo/world/train.py +1 -1
  19. ultralytics/models/yolo/yoloe/train.py +3 -4
  20. ultralytics/models/yolo/yoloe/val.py +3 -3
  21. ultralytics/nn/__init__.py +2 -4
  22. ultralytics/nn/autobackend.py +2 -2
  23. ultralytics/nn/tasks.py +2 -51
  24. ultralytics/utils/__init__.py +5 -1
  25. ultralytics/utils/checks.py +2 -1
  26. ultralytics/utils/plotting.py +2 -2
  27. ultralytics/utils/tal.py +2 -2
  28. ultralytics/utils/torch_utils.py +7 -6
  29. ultralytics/utils/tqdm.py +50 -74
  30. ultralytics/utils/tuner.py +1 -1
  31. {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/WHEEL +0 -0
  32. {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/entry_points.txt +0 -0
  33. {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/licenses/LICENSE +0 -0
  34. {dgenerate_ultralytics_headless-8.3.191.dist-info → dgenerate_ultralytics_headless-8.3.193.dist-info}/top_level.txt +0 -0
ultralytics/engine/tuner.py CHANGED
@@ -20,11 +20,13 @@ import random
  import shutil
  import subprocess
  import time
+ from datetime import datetime

  import numpy as np

  from ultralytics.cfg import get_cfg, get_save_dir
  from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
+ from ultralytics.utils.checks import check_requirements
  from ultralytics.utils.patches import torch_load
  from ultralytics.utils.plotting import plot_tune_results

@@ -34,15 +36,18 @@ class Tuner:
  A class for hyperparameter tuning of YOLO models.

  The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
- search space and retraining the model to evaluate their performance.
+ search space and retraining the model to evaluate their performance. Supports both local CSV storage and
+ distributed MongoDB Atlas coordination for multi-machine hyperparameter optimization.

  Attributes:
- space (Dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
+ space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
  tune_dir (Path): Directory where evolution logs and results will be saved.
  tune_csv (Path): Path to the CSV file where evolution logs are saved.
  args (dict): Configuration arguments for the tuning process.
  callbacks (list): Callback functions to be executed during tuning.
  prefix (str): Prefix string for logging messages.
+ mongodb (MongoClient): Optional MongoDB client for distributed tuning.
+ collection (Collection): MongoDB collection for storing tuning results.

  Methods:
  _mutate: Mutate hyperparameters based on bounds and scaling factors.
@@ -53,11 +58,26 @@ class Tuner:
  >>> from ultralytics import YOLO
  >>> model = YOLO("yolo11n.pt")
  >>> model.tune(
- ... data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False
- ... )
-
- Tune with custom search space.
- >>> model.tune(space={key1: val1, key2: val2}) # custom search space dictionary
+ >>> data="coco8.yaml",
+ >>> epochs=10,
+ >>> iterations=300,
+ >>> plots=False,
+ >>> save=False,
+ >>> val=False
+ >>> )
+
+ Tune with distributed MongoDB Atlas coordination across multiple machines:
+ >>> model.tune(
+ >>> data="coco8.yaml",
+ >>> epochs=10,
+ >>> iterations=300,
+ >>> mongodb_uri="mongodb+srv://user:pass@cluster.mongodb.net/",
+ >>> mongodb_db="ultralytics",
+ >>> mongodb_collection="tune_results"
+ >>> )
+
+ Tune with custom search space:
+ >>> model.tune(space={"lr0": (1e-5, 1e-1), "momentum": (0.6, 0.98)})
  """

  def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
@@ -66,7 +86,7 @@ class Tuner:

  Args:
  args (dict): Configuration for hyperparameter evolution.
- _callbacks (List, optional): Callback functions to be executed during tuning.
+ _callbacks (list | None, optional): Callback functions to be executed during tuning.
  """
  self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
  # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
@@ -95,6 +115,10 @@ class Tuner:
  "cutmix": (0.0, 1.0), # image cutmix (probability)
  "copy_paste": (0.0, 1.0), # segment copy-paste (probability)
  }
+ mongodb_uri = args.pop("mongodb_uri", None)
+ mongodb_db = args.pop("mongodb_db", "ultralytics")
+ mongodb_collection = args.pop("mongodb_collection", "tuner_results")
+
  self.args = get_cfg(overrides=args)
  self.args.exist_ok = self.args.resume # resume w/ same tune_dir
  self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
@@ -103,13 +127,151 @@ class Tuner:
  self.callbacks = _callbacks or callbacks.get_default_callbacks()
  self.prefix = colorstr("Tuner: ")
  callbacks.add_integration_callbacks(self)
+
+ # MongoDB Atlas support (optional)
+ self.mongodb = None
+ if mongodb_uri:
+ self._init_mongodb(mongodb_uri, mongodb_db, mongodb_collection)
+
  LOGGER.info(
  f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
  f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
  )

+ def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
+ """
+ Create MongoDB client with exponential backoff retry on connection failures.
+
+ Args:
+ uri (str): MongoDB connection string with credentials and cluster information.
+ max_retries (int): Maximum number of connection attempts before giving up.
+
+ Returns:
+ (MongoClient): Connected MongoDB client instance.
+ """
+ check_requirements("pymongo")
+
+ from pymongo import MongoClient
+ from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
+
+ for attempt in range(max_retries):
+ try:
+ client = MongoClient(
+ uri,
+ serverSelectionTimeoutMS=30000,
+ connectTimeoutMS=20000,
+ socketTimeoutMS=40000,
+ retryWrites=True,
+ retryReads=True,
+ maxPoolSize=30,
+ minPoolSize=3,
+ maxIdleTimeMS=60000,
+ )
+ client.admin.command("ping") # Test connection
+ LOGGER.info(f"{self.prefix}Connected to MongoDB Atlas (attempt {attempt + 1})")
+ return client
+ except (ConnectionFailure, ServerSelectionTimeoutError):
+ if attempt == max_retries - 1:
+ raise
+ wait_time = 2**attempt
+ LOGGER.warning(
+ f"{self.prefix}MongoDB connection failed (attempt {attempt + 1}), retrying in {wait_time}s..."
+ )
+ time.sleep(wait_time)
+
+ def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
+ """
+ Initialize MongoDB connection for distributed tuning.
+
+ Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines.
+ Each worker saves results to a shared collection and reads the latest best hyperparameters
+ from all workers for evolution.
+
+ Args:
+ mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
+ mongodb_db (str, optional): Database name.
+ mongodb_collection (str, optional): Collection name.
+
+ Notes:
+ - Creates a fitness index for fast queries of top results
+ - Falls back to CSV-only mode if connection fails
+ - Uses connection pooling and retry logic for production reliability
+ """
+ self.mongodb = self._connect(mongodb_uri)
+ self.collection = self.mongodb[mongodb_db][mongodb_collection]
+ self.collection.create_index([("fitness", -1)], background=True)
+ LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
+
+ def _get_mongodb_results(self, n: int = 5) -> list:
+ """
+ Get top N results from MongoDB sorted by fitness.
+
+ Args:
+ n (int): Number of top results to retrieve.
+
+ Returns:
+ (list[dict]): List of result documents with fitness scores and hyperparameters.
+ """
+ try:
+ return list(self.collection.find().sort("fitness", -1).limit(n))
+ except Exception:
+ return []
+
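Outside the Tuner, the shared collection can be inspected with plain pymongo using the same query that _get_mongodb_results issues. A minimal sketch, assuming the default database and collection names seen in __init__ above and a placeholder URI:

    from pymongo import MongoClient

    client = MongoClient("mongodb+srv://user:pass@cluster.mongodb.net/")  # placeholder URI
    collection = client["ultralytics"]["tuner_results"]  # default db/collection names
    for doc in collection.find().sort("fitness", -1).limit(5):  # top 5 results by fitness
        print(doc["iteration"], round(doc["fitness"], 5))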
+ def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
+ """
+ Save results to MongoDB with proper type conversion.
+
+ Args:
+ fitness (float): Fitness score achieved with these hyperparameters.
+ hyperparameters (dict[str, float]): Dictionary of hyperparameter values.
+ metrics (dict): Complete training metrics dictionary (mAP, precision, recall, losses, etc.).
+ iteration (int): Current iteration number.
+ """
+ try:
+ self.collection.insert_one(
+ {
+ "fitness": float(fitness),
+ "hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
+ "metrics": metrics,
+ "timestamp": datetime.now(),
+ "iteration": iteration,
+ }
+ )
+ except Exception as e:
+ LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
+
+ def _sync_mongodb_to_csv(self):
+ """
+ Sync MongoDB results to CSV for plotting compatibility.
+
+ Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
+ the existing plotting functions to work seamlessly with distributed MongoDB data.
+ """
+ try:
+ # Get all results from MongoDB
+ all_results = list(self.collection.find().sort("iteration", 1))
+ if not all_results:
+ return
+
+ # Write to CSV
+ headers = ",".join(["fitness"] + list(self.space.keys())) + "\n"
+ with open(self.tune_csv, "w", encoding="utf-8") as f:
+ f.write(headers)
+ for result in all_results:
+ fitness = result["fitness"]
+ hyp_values = [result["hyperparameters"][k] for k in self.space.keys()]
+ log_row = [round(fitness, 5)] + hyp_values
+ f.write(",".join(map(str, log_row)) + "\n")
+
+ except Exception as e:
+ LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
+
  def _mutate(
- self, parent: str = "single", n: int = 5, mutation: float = 0.8, sigma: float = 0.2
+ self,
+ parent: str = "single",
+ n: int = 5,
+ mutation: float = 0.8,
+ sigma: float = 0.2,
  ) -> dict[str, float]:
  """
  Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
@@ -121,23 +283,36 @@ class Tuner:
  sigma (float): Standard deviation for Gaussian random number generator.

  Returns:
- (Dict[str, float]): A dictionary containing mutated hyperparameters.
+ (dict[str, float]): A dictionary containing mutated hyperparameters.
  """
- if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
- # Select parent(s)
- x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
- fitness = x[:, 0] # first column
- n = min(n, len(x)) # number of previous results to consider
- x = x[np.argsort(-fitness)][:n] # top n mutations
+ x = None
+
+ # Try MongoDB first if available
+ if self.mongodb:
+ results = self._get_mongodb_results(n)
+ if results:
+ # MongoDB already sorted by fitness DESC, so results[0] is best
+ x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
+ n = min(n, len(x))
+
+ # Fall back to CSV if MongoDB unavailable or empty
+ if x is None and self.tune_csv.exists():
+ csv_data = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
+ if len(csv_data) > 0:
+ fitness = csv_data[:, 0] # first column
+ n = min(n, len(csv_data))
+ x = csv_data[np.argsort(-fitness)][:n] # top n sorted by fitness DESC
+
+ # Mutate if we have data, otherwise use defaults
+ if x is not None:
  w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0)
- if parent == "single" or len(x) == 1:
- # x = x[random.randint(0, n - 1)] # random selection
+ if parent == "single" or len(x) <= 1:
  x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  elif parent == "weighted":
  x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination

  # Mutate
- r = np.random # method
+ r = np.random
  r.seed(int(time.time()))
  g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1
  ng = len(self.space)
@@ -149,9 +324,9 @@ class Tuner:
  hyp = {k: getattr(self.args, k) for k in self.space.keys()}

  # Constrain to limits
- for k, v in self.space.items():
- hyp[k] = max(hyp[k], v[0]) # lower limit
- hyp[k] = min(hyp[k], v[1]) # upper limit
+ for k, bounds in self.space.items():
+ hyp[k] = max(hyp[k], bounds[0]) # lower limit
+ hyp[k] = min(hyp[k], bounds[1]) # upper limit
  hyp[k] = round(hyp[k], 5) # significant digits

  return hyp
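The diff elides the mutation body between these hunks; the following standalone sketch illustrates the same idea (Gaussian noise scaled by per-key gains, then clamped to bounds) with a hypothetical two-key space, and is not the library's exact formula:

    import numpy as np

    space = {"lr0": (1e-5, 1e-1, 1.0), "momentum": (0.6, 0.98, 0.3)}  # (min, max, gain)
    parent = {"lr0": 0.01, "momentum": 0.937}
    mutation, sigma = 0.8, 0.2  # mutation probability and Gaussian std, matching _mutate defaults

    rng = np.random.default_rng(0)
    gains = np.array([v[2] for v in space.values()])
    factors = np.ones(len(space))
    while (factors == 1.0).all():  # retry until at least one factor actually changes
        mask = rng.random(len(space)) < mutation
        factors = np.clip(rng.standard_normal(len(space)) * mask * gains * sigma + 1.0, 0.25, 3.0)

    mutated = {
        k: round(float(np.clip(parent[k] * f, lo, hi)), 5)  # scale, clamp to bounds, round
        for (k, (lo, hi, _)), f in zip(space.items(), factors)
    }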
@@ -160,25 +335,26 @@ class Tuner:
  """
  Execute the hyperparameter evolution process when the Tuner instance is called.

- This method iterates through the number of iterations, performing the following steps in each iteration:
-
- 1. Load the existing hyperparameters or initialize new ones.
- 2. Mutate the hyperparameters using the `_mutate` method.
- 3. Train a YOLO model with the mutated hyperparameters.
- 4. Log the fitness score and mutated hyperparameters to a CSV file.
+ This method iterates through the specified number of iterations, performing the following steps:
+ 1. Sync MongoDB results to CSV (if using distributed mode)
+ 2. Mutate hyperparameters using the best previous results or defaults
+ 3. Train a YOLO model with the mutated hyperparameters
+ 4. Log fitness scores and hyperparameters to MongoDB and/or CSV
+ 5. Track the best performing configuration across all iterations

  Args:
- model (Model): A pre-initialized YOLO model to be used for training.
+ model (Model | None, optional): A pre-initialized YOLO model to be used for training.
  iterations (int): The number of generations to run the evolution for.
- cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
-
- Note:
- The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
- Ensure this path is set correctly in the Tuner instance.
+ cleanup (bool): Whether to delete iteration weights to reduce storage space during tuning.
  """
  t0 = time.time()
  best_save_dir, best_metrics = None, None
  (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
+
+ # Sync MongoDB to CSV at startup for proper resume logic
+ if self.mongodb:
+ self._sync_mongodb_to_csv()
+
  start = 0
  if self.tune_csv.exists():
  x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
@@ -205,12 +381,23 @@ class Tuner:
  except Exception as e:
  LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")

- # Save results and mutated_hyp to CSV
+ # Save results - MongoDB takes precedence
  fitness = metrics.get("fitness", 0.0)
- log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
- headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
- with open(self.tune_csv, "a", encoding="utf-8") as f:
- f.write(headers + ",".join(map(str, log_row)) + "\n")
+ if self.mongodb:
+ self._save_to_mongodb(fitness, mutated_hyp, metrics, i + 1)
+ self._sync_mongodb_to_csv()
+ total_mongo_iterations = self.collection.count_documents({})
+ if total_mongo_iterations >= iterations:
+ LOGGER.info(
+ f"{self.prefix}Target iterations ({iterations}) reached in MongoDB ({total_mongo_iterations}). Stopping."
+ )
+ break
+ else:
+ # Save to CSV only if no MongoDB
+ log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
+ headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
+ with open(self.tune_csv, "a", encoding="utf-8") as f:
+ f.write(headers + ",".join(map(str, log_row)) + "\n")

  # Get best results
  x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
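With MongoDB configured, iterations acts as a shared global budget: every worker inserts its result, re-counts the collection, and stops once the target is reached. A sketch of the multi-machine pattern (placeholder URI; run the same script on each machine):

    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    model.tune(
        data="coco8.yaml",
        epochs=10,
        iterations=300,  # global budget across all workers, enforced via count_documents
        mongodb_uri="mongodb+srv://user:pass@cluster.mongodb.net/",  # same URI on every machine
        mongodb_db="ultralytics",
        mongodb_collection="tune_results",
    )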
@@ -226,7 +413,7 @@ class Tuner:
  shutil.rmtree(weights_dir, ignore_errors=True) # remove iteration weights/ dir to reduce storage space

  # Plot tune results
- plot_tune_results(self.tune_csv)
+ plot_tune_results(str(self.tune_csv))

  # Save and print tune results
  header = (
ultralytics/models/yolo/classify/train.py CHANGED
@@ -166,8 +166,8 @@ class ClassificationTrainer(BaseTrainer):

  def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
  """Preprocess a batch of images and classes."""
- batch["img"] = batch["img"].to(self.device)
- batch["cls"] = batch["cls"].to(self.device)
+ batch["img"] = batch["img"].to(self.device, non_blocking=True)
+ batch["cls"] = batch["cls"].to(self.device, non_blocking=True)
  return batch

  def progress_string(self) -> str:
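This and the analogous changes below switch host-to-device copies to non_blocking=True so they can overlap with computation. The copy is only truly asynchronous when the source tensor sits in pinned (page-locked) host memory, typically arranged via DataLoader(pin_memory=True); a minimal standalone illustration:

    import torch

    if torch.cuda.is_available():
        x = torch.randn(8, 3, 640, 640).pin_memory()  # page-locked host memory
        y = x.to("cuda", non_blocking=True)  # queues an async copy and returns immediately
        torch.cuda.synchronize()  # wait for the copy before consuming y outside the stream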
ultralytics/models/yolo/classify/val.py CHANGED
@@ -91,7 +91,7 @@ class ClassificationValidator(BaseValidator):
  """Preprocess input batch by moving data to device and converting to appropriate dtype."""
  batch["img"] = batch["img"].to(self.device, non_blocking=True)
  batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
- batch["cls"] = batch["cls"].to(self.device)
+ batch["cls"] = batch["cls"].to(self.device, non_blocking=True)
  return batch

  def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
ultralytics/models/yolo/detect/val.py CHANGED
@@ -74,7 +74,7 @@ class DetectionValidator(BaseValidator):
  batch["img"] = batch["img"].to(self.device, non_blocking=True)
  batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
  for k in {"batch_idx", "cls", "bboxes"}:
- batch[k] = batch[k].to(self.device)
+ batch[k] = batch[k].to(self.device, non_blocking=True)

  return batch

ultralytics/models/yolo/pose/val.py CHANGED
@@ -86,7 +86,7 @@ class PoseValidator(DetectionValidator):
  def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
  """Preprocess batch by converting keypoints data to float and moving it to the device."""
  batch = super().preprocess(batch)
- batch["keypoints"] = batch["keypoints"].to(self.device).float()
+ batch["keypoints"] = batch["keypoints"].to(self.device, non_blocking=True).float()
  return batch

  def get_desc(self) -> str:
ultralytics/models/yolo/segment/val.py CHANGED
@@ -63,7 +63,7 @@ class SegmentationValidator(DetectionValidator):
  (Dict[str, Any]): Preprocessed batch.
  """
  batch = super().preprocess(batch)
- batch["masks"] = batch["masks"].to(self.device).float()
+ batch["masks"] = batch["masks"].to(self.device, non_blocking=True).float()
  return batch

  def init_metrics(self, model: torch.nn.Module) -> None:
@@ -133,8 +133,17 @@ class SegmentationValidator(DetectionValidator):
  (Dict[str, Any]): Prepared batch with processed annotations.
  """
  prepared_batch = super()._prepare_batch(si, batch)
- midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
- prepared_batch["masks"] = batch["masks"][midx]
+ nl = len(prepared_batch["cls"])
+ if self.args.overlap_mask:
+ masks = batch["masks"][si]
+ index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
+ masks = (masks == index).float()
+ else:
+ masks = batch["masks"][batch["batch_idx"] == si]
+ if nl and self.process is ops.process_mask_native:
+ masks = F.interpolate(masks[None], prepared_batch["imgsz"], mode="bilinear", align_corners=False)[0]
+ masks = masks.gt_(0.5)
+ prepared_batch["masks"] = masks
  return prepared_batch

  def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
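The rewritten _prepare_batch decodes overlap-encoded ground truth (a single mask whose pixel value i marks instance i) into per-instance binary masks by comparing against a broadcast index tensor. A self-contained illustration of the trick:

    import torch

    masks = torch.tensor([[0, 1, 1, 2], [2, 2, 0, 1]])  # (H, W); 0 = background, i = instance i
    nl = 2  # number of instances
    index = torch.arange(1, nl + 1).view(nl, 1, 1)  # shape (nl, 1, 1)
    per_instance = (masks == index).float()  # broadcasts to (nl, H, W) binary masks
    print(per_instance.shape)  # torch.Size([2, 2, 4])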
@@ -158,20 +167,11 @@ class SegmentationValidator(DetectionValidator):
  >>> correct_preds = validator._process_batch(preds, batch)
  """
  tp = super()._process_batch(preds, batch)
- gt_cls, gt_masks = batch["cls"], batch["masks"]
+ gt_cls = batch["cls"]
  if len(gt_cls) == 0 or len(preds["cls"]) == 0:
  tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
  else:
- pred_masks = preds["masks"]
- if self.args.overlap_mask:
- nl = len(gt_cls)
- index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
- gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
- gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
- if gt_masks.shape[1:] != pred_masks.shape[1:]:
- gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
- gt_masks = gt_masks.gt_(0.5)
- iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
+ iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
  tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
  tp.update({"tp_m": tp_m}) # update tp with mask IoU
  return tp
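Because masks are now prepared once per batch, _process_batch reduces to a single mask_iou call on flattened masks. A minimal sketch of what pairwise mask IoU computes (the library ships its own mask_iou; this just shows the math):

    import torch

    def pairwise_mask_iou(gt: torch.Tensor, pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        """IoU for every GT/prediction pair; gt is (M, H*W), pred is (N, H*W), both binary."""
        inter = gt @ pred.T  # (M, N) counts of overlapping pixels
        union = gt.sum(1)[:, None] + pred.sum(1)[None, :] - inter
        return inter / (union + eps)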
ultralytics/models/yolo/world/train.py CHANGED
@@ -171,7 +171,7 @@ class WorldTrainer(DetectionTrainer):

  # Add text features
  texts = list(itertools.chain(*batch["texts"]))
- txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
+ txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device, non_blocking=True)
  txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
  batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
  return batch
ultralytics/models/yolo/yoloe/train.py CHANGED
@@ -197,7 +197,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
  batch = DetectionTrainer.preprocess_batch(self, batch)

  texts = list(itertools.chain(*batch["texts"]))
- txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
+ txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device, non_blocking=True)
  txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
  batch["txt_feats"] = txt_feats
  return batch
@@ -251,8 +251,7 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):

  def preprocess_batch(self, batch):
  """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
- batch = DetectionTrainer.preprocess_batch(self, batch)
- return batch
+ return DetectionTrainer.preprocess_batch(self, batch)

  def set_text_embeddings(self, datasets, batch: int):
  """
@@ -318,5 +317,5 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
  def preprocess_batch(self, batch):
  """Preprocess a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
  batch = super().preprocess_batch(batch)
- batch["visuals"] = batch["visuals"].to(self.device)
+ batch["visuals"] = batch["visuals"].to(self.device, non_blocking=True)
  return batch
ultralytics/models/yolo/yoloe/val.py CHANGED
@@ -102,7 +102,7 @@ class YOLOEDetectValidator(DetectionValidator):
  """Preprocess batch data, ensuring visuals are on the same device as images."""
  batch = super().preprocess(batch)
  if "visuals" in batch:
- batch["visuals"] = batch["visuals"].to(batch["img"].device)
+ batch["visuals"] = batch["visuals"].to(batch["img"].device, non_blocking=True)
  return batch

  def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
@@ -186,9 +186,9 @@ class YOLOEDetectValidator(DetectionValidator):
  self.device = select_device(self.args.device, verbose=False)

  if isinstance(model, (str, Path)):
- from ultralytics.nn.tasks import attempt_load_weights
+ from ultralytics.nn.tasks import load_checkpoint

- model = attempt_load_weights(model, device=self.device)
+ model, _ = load_checkpoint(model, device=self.device) # model, ckpt
  model.eval().to(self.device)
  data = check_det_dataset(refer_data or self.args.data)
  names = [name.split("/", 1)[0] for name in list(data["names"].values())]
ultralytics/nn/__init__.py CHANGED
@@ -5,18 +5,16 @@ from .tasks import (
  ClassificationModel,
  DetectionModel,
  SegmentationModel,
- attempt_load_one_weight,
- attempt_load_weights,
  guess_model_scale,
  guess_model_task,
+ load_checkpoint,
  parse_model,
  torch_safe_load,
  yaml_model_load,
  )

  __all__ = (
- "attempt_load_one_weight",
- "attempt_load_weights",
+ "load_checkpoint",
  "parse_model",
  "yaml_model_load",
  "guess_model_task",
ultralytics/nn/autobackend.py CHANGED
@@ -203,9 +203,9 @@ class AutoBackend(nn.Module):
  model = model.fuse(verbose=verbose)
  model = model.to(device)
  else: # pt file
- from ultralytics.nn.tasks import attempt_load_one_weight
+ from ultralytics.nn.tasks import load_checkpoint

- model, _ = attempt_load_one_weight(model, device=device, fuse=fuse) # load model, ckpt
+ model, _ = load_checkpoint(model, device=device, fuse=fuse) # load model, ckpt

  # Common PyTorch model processing
  if hasattr(model, "kpt_shape"):
ultralytics/nn/tasks.py CHANGED
@@ -1483,61 +1483,12 @@ def torch_safe_load(weight, safe_only=False):
  return ckpt, file


- def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
- """
- Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
-
- Args:
- weights (str | List[str]): Model weights path(s).
- device (torch.device, optional): Device to load model to.
- inplace (bool): Whether to do inplace operations.
- fuse (bool): Whether to fuse model.
-
- Returns:
- (torch.nn.Module): Loaded model.
- """
- ensemble = Ensemble()
- for w in weights if isinstance(weights, list) else [weights]:
- ckpt, w = torch_safe_load(w) # load ckpt
- args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
- model = (ckpt.get("ema") or ckpt["model"]).float() # FP32 model
-
- # Model compatibility updates
- model.args = args # attach args to model
- model.pt_path = w # attach *.pt file path to model
- model.task = getattr(model, "task", guess_model_task(model))
- if not hasattr(model, "stride"):
- model.stride = torch.tensor([32.0])
-
- # Append
- ensemble.append((model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()).to(device))
-
- # Module updates
- for m in ensemble.modules():
- if hasattr(m, "inplace"):
- m.inplace = inplace
- elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
- m.recompute_scale_factor = None # torch 1.11.0 compatibility
-
- # Return model
- if len(ensemble) == 1:
- return ensemble[-1]
-
- # Return ensemble
- LOGGER.info(f"Ensemble created with {weights}\n")
- for k in "names", "nc", "yaml":
- setattr(ensemble, k, getattr(ensemble[0], k))
- ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride
- assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
- return ensemble
-
-
- def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
+ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
  """
  Load a single model weights.

  Args:
- weight (str): Model weight path.
+ weight (str | Path): Model weight path.
  device (torch.device, optional): Device to load model to.
  inplace (bool): Whether to do inplace operations.
  fuse (bool): Whether to fuse model.
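For downstream code the rename is mechanical, but note that the ensemble loader attempt_load_weights is removed outright in this release; only the single-checkpoint path survives as load_checkpoint. A migration sketch:

    # 8.3.191 and earlier:
    # from ultralytics.nn.tasks import attempt_load_one_weight
    # model, ckpt = attempt_load_one_weight("yolo11n.pt", device="cpu")

    # 8.3.193:
    from ultralytics.nn.tasks import load_checkpoint

    model, ckpt = load_checkpoint("yolo11n.pt", device="cpu")  # same signature, new name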
ultralytics/utils/__init__.py CHANGED
@@ -49,7 +49,7 @@ MACOS_VERSION = platform.mac_ver()[0] if MACOS else None
  NOT_MACOS14 = not (MACOS and MACOS_VERSION.startswith("14."))
  ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
  PYTHON_VERSION = platform.python_version()
- TORCH_VERSION = torch.__version__
+ TORCH_VERSION = str(torch.__version__) # Normalize torch.__version__ (PyTorch>1.9 returns TorchVersion objects)
  TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
  IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode"
  RKNN_CHIPS = frozenset(
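The str() wrapper matters because recent PyTorch exposes torch.__version__ as a TorchVersion object (a str subclass), which can trip code that serializes the value or compares types strictly:

    import torch

    print(type(torch.__version__).__name__)  # 'TorchVersion' on recent PyTorch builds
    TORCH_VERSION = str(torch.__version__)  # plain str, e.g. '2.4.1+cu121'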
@@ -132,6 +132,10 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress verbose TF compiler warning
  os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings
  os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output when computing FLOPs

+ # Precompiled type tuples for faster isinstance() checks
+ FLOAT_OR_INT = (float, int)
+ STR_OR_PATH = (str, Path)
+

  class DataExportMixin:
  """