dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
@@ -6,6 +6,8 @@ Usage:
     $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
 """
 
+from __future__ import annotations
+
 import gc
 import math
 import os
@@ -61,8 +63,7 @@ from ultralytics.utils.torch_utils import (
 
 
 class BaseTrainer:
-    """
-    A base class for creating trainers.
+    """A base class for creating trainers.
 
     This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
     and various training utilities. It supports both single-GPU and multi-GPU distributed training.
@@ -112,8 +113,7 @@ class BaseTrainer:
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the BaseTrainer class.
+        """Initialize the BaseTrainer class.
 
         Args:
            cfg (str, optional): Path to a configuration file.
@@ -138,7 +138,12 @@ class BaseTrainer:
         if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
-            YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
+            # Save run args, serializing augmentations as reprs for resume compatibility
+            args_dict = vars(self.args).copy()
+            if args_dict.get("augmentations") is not None:
+                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
+                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
+            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
         self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
         self.save_period = self.args.save_period
 
@@ -318,18 +323,18 @@ class BaseTrainer:
         self.train_loader = self.get_dataloader(
             self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
         )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
         if RANK in {-1, 0}:
-            # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
-            self.test_loader = self.get_dataloader(
-                self.data.get("val") or self.data.get("test"),
-                batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
-                rank=-1,
-                mode="val",
-            )
-            self.validator = self.get_validator()
             metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
             self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
-            self.ema = ModelEMA(self.model)
             if self.args.plots:
                 self.plot_training_labels()
 
@@ -464,13 +469,13 @@ class BaseTrainer:
 
             self.run_callbacks("on_train_epoch_end")
             if RANK in {-1, 0}:
-                final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
 
-                # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
-                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
-                    self.metrics, self.fitness = self.validate()
+            # Validation
+            final_epoch = epoch + 1 >= self.epochs
+            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                self.metrics, self.fitness = self.validate()
 
             # NaN recovery
             if self._handle_nan_recovery(epoch):
@@ -510,11 +515,11 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1
 
+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
         if RANK in {-1, 0}:
-            # Do final val with best.pt
-            seconds = time.time() - self.train_time_start
-            LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
-            self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
@@ -545,7 +550,7 @@ class BaseTrainer:
             total = torch.cuda.get_device_properties(self.device).total_memory
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
 
-    def _clear_memory(self, threshold: float = None):
+    def _clear_memory(self, threshold: float | None = None):
         """Clear accelerator memory by calling garbage collector and emptying cache."""
         if threshold:
             assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
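As an aside (illustration, not part of the diff): the `from __future__ import annotations` added at the top of the module is what lets the new PEP 604 `float | None` annotation above parse on the older Python versions Ultralytics still supports, because annotations are no longer evaluated at runtime. A minimal standalone sketch with a hypothetical function name:

from __future__ import annotations


def clear_memory_signature(threshold: float | None = None) -> None:
    """Hypothetical standalone mirror of the updated _clear_memory signature."""
    if threshold is not None:
        assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."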
@@ -618,8 +623,7 @@ class BaseTrainer:
             (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
 
     def get_dataset(self):
-        """
-        Get train and validation datasets from data dictionary.
+        """Get train and validation datasets from data dictionary.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
@@ -627,7 +631,7 @@ class BaseTrainer:
         try:
             if self.args.task == "classify":
                 data = check_cls_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
+            elif str(self.args.data).rsplit(".", 1)[-1] == "ndjson":
                 # Convert NDJSON to YOLO format
                 import asyncio
 
@@ -636,7 +640,7 @@ class BaseTrainer:
                 yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
                 self.args.data = str(yaml_path)
                 data = check_det_dataset(self.args.data)
-            elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
+            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
                 "detect",
                 "segment",
                 "pose",
@@ -654,8 +658,7 @@ class BaseTrainer:
         return data
 
     def setup_model(self):
-        """
-        Load, create, or download model for any task.
+        """Load, create, or download model for any task.
 
         Returns:
             (dict): Optional checkpoint to resume training from.
@@ -688,14 +691,19 @@ class BaseTrainer:
         return batch
 
     def validate(self):
-        """
-        Run validation on val set using self.validator.
+        """Run validation on val set using self.validator.
 
         Returns:
             metrics (dict): Dictionary of validation metrics.
             fitness (float): Fitness score for the validation.
         """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
         metrics = self.validator(self)
+        if metrics is None:
+            return None, None
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = fitness
@@ -706,11 +714,11 @@ class BaseTrainer:
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
-        """Return a NotImplementedError when the get_validator function is called."""
+        """Raise NotImplementedError (must be implemented by subclasses)."""
         raise NotImplementedError("get_validator function not implemented in trainer")
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """Return dataloader derived from torch.data.Dataloader."""
+        """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
        raise NotImplementedError("get_dataloader function not implemented in trainer")
 
     def build_dataset(self, img_path, mode="train", batch=None):
@@ -718,10 +726,9 @@ class BaseTrainer:
         raise NotImplementedError("build_dataset function not implemented in trainer")
 
     def label_loss_items(self, loss_items=None, prefix="train"):
-        """
-        Return a loss dict with labelled training loss items tensor.
+        """Return a loss dict with labeled training loss items tensor.
 
-        Note:
+        Notes:
            This is not needed for classification but necessary for segmentation & detection
        """
        return {"loss": loss_items} if loss_items is not None else ["loss"]
@@ -753,9 +760,9 @@ class BaseTrainer:
         n = len(metrics) + 2  # number of cols
         t = time.time() - self.train_time_start
         self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
-        s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
+        s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
         with open(self.csv, "a", encoding="utf-8") as f:
-            f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
+            f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
 
     def plot_metrics(self):
         """Plot metrics from a CSV file."""
@@ -768,20 +775,20 @@ class BaseTrainer:
 
     def final_eval(self):
         """Perform final evaluation and validation for object detection YOLO model."""
-        ckpt = {}
-        for f in self.last, self.best:
-            if f.exists():
-                if f is self.last:
-                    ckpt = strip_optimizer(f)
-                elif f is self.best:
-                    k = "train_results"  # update best.pt train_metrics from last.pt
-                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
-                LOGGER.info(f"\nValidating {f}...")
-                self.validator.args.plots = self.args.plots
-                self.validator.args.compile = False  # disable final val compile as too slow
-                self.metrics = self.validator(model=f)
-                self.metrics.pop("fitness", None)
-                self.run_callbacks("on_fit_epoch_end")
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")
 
     def check_resume(self, overrides):
         """Check if resume checkpoint exists and update arguments accordingly."""
@@ -804,10 +811,29 @@ class BaseTrainer:
                     "batch",
                     "device",
                     "close_mosaic",
+                    "augmentations",
+                    "save_period",
+                    "workers",
+                    "cache",
+                    "patience",
+                    "time",
+                    "freeze",
+                    "val",
+                    "plots",
                 ):  # allow arg updates to reduce memory or update device on resume
                     if k in overrides:
                         setattr(self.args, k, overrides[k])
 
+                # Handle augmentations parameter for resume: check if user provided custom augmentations
+                if ckpt_args.get("augmentations") is not None:
+                    # Augmentations were saved in checkpoint as reprs but can't be restored automatically
+                    LOGGER.warning(
+                        "Custom Albumentations transforms were used in the original training run but are not "
+                        "being restored. To preserve custom augmentations when resuming, you need to pass the "
+                        "'augmentations' parameter again to get expected results. Example: \n"
+                        f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
+                    )
+
             except Exception as e:
                 raise FileNotFoundError(
                     "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
@@ -887,18 +913,16 @@ class BaseTrainer:
             self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
 
     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
-        """
-        Construct an optimizer for the given model.
+        """Construct an optimizer for the given model.
 
         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
-            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
-                based on the number of iterations.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
             lr (float, optional): The learning rate for the optimizer.
             momentum (float, optional): The momentum factor for the optimizer.
             decay (float, optional): The weight decay for the optimizer.
-            iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
 
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
ultralytics/engine/tuner.py
@@ -8,7 +8,7 @@ that yield the best model performance. This is particularly crucial in deep lear
 where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
 
 Examples:
-    Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
+    Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
    >>> from ultralytics import YOLO
    >>> model = YOLO("yolo11n.pt")
    >>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
@@ -34,12 +34,11 @@ from ultralytics.utils.plotting import plot_tune_results
 
 
 class Tuner:
-    """
-    A class for hyperparameter tuning of YOLO models.
+    """A class for hyperparameter tuning of YOLO models.
 
     The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
-    search space and retraining the model to evaluate their performance. Supports both local CSV storage and
-    distributed MongoDB Atlas coordination for multi-machine hyperparameter optimization.
+    search space and retraining the model to evaluate their performance. Supports both local CSV storage and distributed
+    MongoDB Atlas coordination for multi-machine hyperparameter optimization.
 
     Attributes:
         space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
@@ -56,7 +55,7 @@ class Tuner:
         __call__: Execute the hyperparameter evolution across multiple iterations.
 
     Examples:
-        Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
+        Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
         >>> from ultralytics import YOLO
         >>> model = YOLO("yolo11n.pt")
         >>> model.tune(
@@ -83,8 +82,7 @@ class Tuner:
     """
 
    def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
-        """
-        Initialize the Tuner with configurations.
+        """Initialize the Tuner with configurations.
 
         Args:
             args (dict): Configuration for hyperparameter evolution.
@@ -142,8 +140,7 @@ class Tuner:
         )
 
     def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
-        """
-        Create MongoDB client with exponential backoff retry on connection failures.
+        """Create MongoDB client with exponential backoff retry on connection failures.
 
         Args:
             uri (str): MongoDB connection string with credentials and cluster information.
@@ -183,12 +180,10 @@ class Tuner:
                time.sleep(wait_time)
 
     def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
-        """
-        Initialize MongoDB connection for distributed tuning.
+        """Initialize MongoDB connection for distributed tuning.
 
-        Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines.
-        Each worker saves results to a shared collection and reads the latest best hyperparameters
-        from all workers for evolution.
+        Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines. Each worker
+        saves results to a shared collection and reads the latest best hyperparameters from all workers for evolution.
 
         Args:
             mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
@@ -206,8 +201,7 @@ class Tuner:
             LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
 
     def _get_mongodb_results(self, n: int = 5) -> list:
-        """
-        Get top N results from MongoDB sorted by fitness.
+        """Get top N results from MongoDB sorted by fitness.
 
         Args:
             n (int): Number of top results to retrieve.
@@ -221,8 +215,7 @@ class Tuner:
            return []
 
     def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
-        """
-        Save results to MongoDB with proper type conversion.
+        """Save results to MongoDB with proper type conversion.
 
         Args:
             fitness (float): Fitness score achieved with these hyperparameters.
@@ -233,7 +226,7 @@ class Tuner:
         try:
             self.collection.insert_one(
                 {
-                    "fitness": float(fitness),
+                    "fitness": fitness,
                     "hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
                     "metrics": metrics,
                     "timestamp": datetime.now(),
@@ -244,8 +237,7 @@ class Tuner:
            LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
 
     def _sync_mongodb_to_csv(self):
-        """
-        Sync MongoDB results to CSV for plotting compatibility.
+        """Sync MongoDB results to CSV for plotting compatibility.
 
         Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
         the existing plotting functions to work seamlessly with distributed MongoDB data.
@@ -257,19 +249,20 @@ class Tuner:
                return
 
            # Write to CSV
-            headers = ",".join(["fitness"] + list(self.space.keys())) + "\n"
+            headers = ",".join(["fitness", *list(self.space.keys())]) + "\n"
            with open(self.tune_csv, "w", encoding="utf-8") as f:
                f.write(headers)
                for result in all_results:
                    fitness = result["fitness"]
                    hyp_values = [result["hyperparameters"][k] for k in self.space.keys()]
-                    log_row = [round(fitness, 5)] + hyp_values
+                    log_row = [round(fitness, 5), *hyp_values]
                    f.write(",".join(map(str, log_row)) + "\n")
 
        except Exception as e:
            LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
 
-    def _crossover(self, x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
+    @staticmethod
+    def _crossover(x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
        """BLX-α crossover from up to top-k parents (x[:,0]=fitness, rest=genes)."""
        k = min(k, len(x))
        # fitness weights (shifted to >0); fallback to uniform if degenerate
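For readers unfamiliar with BLX-α: a generic, simplified sketch of the crossover idea (the real `_crossover` above additionally weights the top-k parents by fitness, which this standalone illustration omits; the parent values are hypothetical):

import numpy as np


def blx_alpha(genes: np.ndarray, alpha: float = 0.2) -> np.ndarray:
    """Sample each child gene uniformly from the parents' [min, max] range widened by alpha on each side."""
    lo, hi = genes.min(axis=0), genes.max(axis=0)
    span = hi - lo
    return np.random.uniform(lo - alpha * span, hi + alpha * span)


parents = np.array([[0.010, 0.90], [0.012, 0.94], [0.008, 0.92]])  # hypothetical [lr0, momentum] rows
child = blx_alpha(parents)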
@@ -288,11 +281,9 @@ class Tuner:
         mutation: float = 0.5,
         sigma: float = 0.2,
     ) -> dict[str, float]:
-        """
-        Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
+        """Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
 
         Args:
-            parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
             n (int): Number of top parents to consider.
             mutation (float): Probability of a parameter mutation in any given iteration.
             sigma (float): Standard deviation for Gaussian random number generator.
@@ -304,8 +295,7 @@ class Tuner:
 
         # Try MongoDB first if available
         if self.mongodb:
-            results = self._get_mongodb_results(n)
-            if results:
+            if results := self._get_mongodb_results(n):
                 # MongoDB already sorted by fitness DESC, so results[0] is best
                 x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
             elif self.collection.name in self.collection.database.list_collection_names():  # Tuner started elsewhere
@@ -344,13 +334,12 @@ class Tuner:
 
         # Update types
         if "close_mosaic" in hyp:
-            hyp["close_mosaic"] = int(round(hyp["close_mosaic"]))
+            hyp["close_mosaic"] = round(hyp["close_mosaic"])
 
         return hyp
 
     def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
-        """
-        Execute the hyperparameter evolution process when the Tuner instance is called.
+        """Execute the hyperparameter evolution process when the Tuner instance is called.
 
         This method iterates through the specified number of iterations, performing the following steps:
         1. Sync MongoDB results to CSV (if using distributed mode)
@@ -421,7 +410,7 @@ class Tuner:
             else:
                 # Save to CSV only if no MongoDB
                 log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
-                headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
+                headers = "" if self.tune_csv.exists() else (",".join(["fitness", *list(self.space.keys())]) + "\n")
                 with open(self.tune_csv, "a", encoding="utf-8") as f:
                     f.write(headers + ",".join(map(str, log_row)) + "\n")
 
ultralytics/engine/validator.py
@@ -29,26 +29,26 @@ from pathlib import Path
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.autobackend import AutoBackend
-from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
 from ultralytics.utils.checks import check_imgsz
 from ultralytics.utils.ops import Profile
 from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
 
 
 class BaseValidator:
-    """
-    A base class for creating validators.
+    """A base class for creating validators.
 
     This class provides the foundation for validation processes, including model evaluation, metric computation, and
     result visualization.
 
     Attributes:
         args (SimpleNamespace): Configuration for the validator.
-        dataloader (DataLoader): Dataloader to use for validation.
+        dataloader (DataLoader): DataLoader to use for validation.
         model (nn.Module): Model to validate.
         data (dict): Data dictionary containing dataset information.
         device (torch.device): Device to use for validation.
@@ -61,8 +61,8 @@ class BaseValidator:
         nc (int): Number of classes.
         iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
         jdict (list): List to store JSON validation results.
-        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
-            batch processing times in milliseconds.
+        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
+            processing times in milliseconds.
         save_dir (Path): Directory to save results.
         plots (dict): Dictionary to store plots for visualization.
         callbacks (dict): Dictionary to store various callback functions.
@@ -92,11 +92,10 @@ class BaseValidator:
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
-        """
-        Initialize a BaseValidator instance.
+        """Initialize a BaseValidator instance.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
             save_dir (Path, optional): Directory to save results.
             args (SimpleNamespace, optional): Configuration for the validator.
             _callbacks (dict, optional): Dictionary to store various callback functions.
@@ -130,8 +129,7 @@ class BaseValidator:
 
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
-        """
-        Execute validation process, running inference on dataloader and computing performance metrics.
+        """Execute validation process, running inference on dataloader and computing performance metrics.
 
         Args:
             trainer (object, optional): Trainer object that contains the model to validate.
@@ -160,7 +158,7 @@ class BaseValidator:
             callbacks.add_integration_callbacks(self)
             model = AutoBackend(
                 model=model or self.args.model,
-                device=select_device(self.args.device),
+                device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
                 dnn=self.args.dnn,
                 data=self.args.data,
                 fp16=self.args.half,
@@ -223,21 +221,34 @@ class BaseValidator:
             preds = self.postprocess(preds)
 
             self.update_metrics(preds, batch)
-            if self.args.plots and batch_i < 3:
+            if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
                 self.plot_val_samples(batch, batch_i)
                 self.plot_predictions(batch, preds, batch_i)
 
             self.run_callbacks("on_val_batch_end")
-        stats = self.get_stats()
-        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
-        self.finalize_metrics()
-        self.print_results()
-        self.run_callbacks("on_val_end")
+
+        stats = {}
+        self.gather_stats()
+        if RANK in {-1, 0}:
+            stats = self.get_stats()
+            self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
+            self.finalize_metrics()
+            self.print_results()
+            self.run_callbacks("on_val_end")
+
         if self.training:
             model.float()
-            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
+            # Reduce loss across all GPUs
+            loss = self.loss.clone().detach()
+            if trainer.world_size > 1:
+                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
+            if RANK > 0:
+                return
+            results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
+            if RANK > 0:
+                return stats
             LOGGER.info(
                 "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                     *tuple(self.speed.values())
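The loss averaging added above follows the usual torch.distributed pattern. A standalone, hedged sketch that is a no-op outside an initialized process group (the tensor values are illustrative only):

import torch
import torch.distributed as dist

loss = torch.tensor([0.9, 1.4, 0.7])  # hypothetical per-rank loss components
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
    dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)  # rank 0 ends up with the average across ranks
print(loss)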
@@ -255,8 +266,7 @@ class BaseValidator:
     def match_predictions(
         self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
     ) -> torch.Tensor:
-        """
-        Match predictions to ground truth objects using IoU.
+        """Match predictions to ground truth objects using IoU.
 
         Args:
             pred_classes (torch.Tensor): Predicted class indices of shape (N,).
@@ -336,6 +346,10 @@ class BaseValidator:
         """Return statistics about the model's performance."""
         return {}
 
+    def gather_stats(self):
+        """Gather statistics from all the GPUs during DDP training to GPU 0."""
+        pass
+
     def print_results(self):
         """Print the results of the model's predictions."""
         pass
@@ -350,7 +364,10 @@ class BaseValidator:
         return []
 
     def on_plot(self, name, data=None):
-        """Register plots for visualization."""
+        """Register plots for visualization, deduplicating by type."""
+        plot_type = data.get("type") if data else None
+        if plot_type and any((v.get("data") or {}).get("type") == plot_type for v in self.plots.values()):
+            return  # Skip duplicate plot types
         self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
 
     def plot_val_samples(self, batch, ni):
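A standalone mimic (illustration only, not the real class) of the deduplication that on_plot now performs, keeping only the first plot registered per "type":

import time
from pathlib import Path

plots = {}


def on_plot(name, data=None):
    """Register a plot unless one of the same type is already registered."""
    plot_type = data.get("type") if data else None
    if plot_type and any((v.get("data") or {}).get("type") == plot_type for v in plots.values()):
        return  # skip duplicate plot types
    plots[Path(name)] = {"data": data, "timestamp": time.time()}


on_plot("confusion_matrix.png", {"type": "confusion_matrix"})
on_plot("confusion_matrix_normalized.png", {"type": "confusion_matrix"})  # skipped as a duplicate type
assert len(plots) == 1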