ultralytics_opencv_headless-8.3.242-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/engine/trainer.py (new file)
@@ -0,0 +1,974 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+"""
+Train a model on a dataset.
+
+Usage:
+    $ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
+"""
+
+from __future__ import annotations
+
+import gc
+import math
+import os
+import subprocess
+import time
+import warnings
+from copy import copy, deepcopy
+from datetime import datetime, timedelta
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import distributed as dist
+from torch import nn, optim
+
+from ultralytics import __version__
+from ultralytics.cfg import get_cfg, get_save_dir
+from ultralytics.data.utils import check_cls_dataset, check_det_dataset
+from ultralytics.nn.tasks import load_checkpoint
+from ultralytics.utils import (
+    DEFAULT_CFG,
+    GIT,
+    LOCAL_RANK,
+    LOGGER,
+    RANK,
+    TQDM,
+    YAML,
+    callbacks,
+    clean_url,
+    colorstr,
+    emojis,
+)
+from ultralytics.utils.autobatch import check_train_batch_size
+from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
+from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
+from ultralytics.utils.files import get_latest_run
+from ultralytics.utils.plotting import plot_results
+from ultralytics.utils.torch_utils import (
+    TORCH_2_4,
+    EarlyStopping,
+    ModelEMA,
+    attempt_compile,
+    autocast,
+    convert_optimizer_state_dict_to_fp16,
+    init_seeds,
+    one_cycle,
+    select_device,
+    strip_optimizer,
+    torch_distributed_zero_first,
+    unset_deterministic,
+    unwrap_model,
+)
+
+
+class BaseTrainer:
+    """A base class for creating trainers.
+
+    This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
+    and various training utilities. It supports both single-GPU and multi-GPU distributed training.
+
+    Attributes:
+        args (SimpleNamespace): Configuration for the trainer.
+        validator (BaseValidator): Validator instance.
+        model (nn.Module): Model instance.
+        callbacks (defaultdict): Dictionary of callbacks.
+        save_dir (Path): Directory to save results.
+        wdir (Path): Directory to save weights.
+        last (Path): Path to the last checkpoint.
+        best (Path): Path to the best checkpoint.
+        save_period (int): Save checkpoint every x epochs (disabled if < 1).
+        batch_size (int): Batch size for training.
+        epochs (int): Number of epochs to train for.
+        start_epoch (int): Starting epoch for training.
+        device (torch.device): Device to use for training.
+        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
+        scaler (amp.GradScaler): Gradient scaler for AMP.
+        data (str): Path to data.
+        ema (nn.Module): EMA (Exponential Moving Average) of the model.
+        resume (bool): Resume training from a checkpoint.
+        lf (callable): Learning rate schedule lambda applied by the scheduler.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
+        best_fitness (float): The best fitness value achieved.
+        fitness (float): Current fitness value.
+        loss (float): Current loss value.
+        tloss (float): Running mean of loss values over the epoch.
+        loss_names (list): List of loss names.
+        csv (Path): Path to results CSV file.
+        metrics (dict): Dictionary of metrics.
+        plots (dict): Dictionary of plots.
+
+    Methods:
+        train: Execute the training process.
+        validate: Run validation on the test set.
+        save_model: Save model training checkpoints.
+        get_dataset: Get train and validation datasets.
+        setup_model: Load, create, or download model.
+        build_optimizer: Construct an optimizer for the model.
+
+    Examples:
+        Initialize a trainer and start training
+        >>> trainer = BaseTrainer(cfg="config.yaml")
+        >>> trainer.train()
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize the BaseTrainer class.
+
+        Args:
+            cfg (str, optional): Path to a configuration file.
+            overrides (dict, optional): Configuration overrides.
+            _callbacks (list, optional): List of callback functions.
+        """
+        self.hub_session = overrides.pop("session", None)  # HUB
+        self.args = get_cfg(cfg, overrides)
+        self.check_resume(overrides)
+        self.device = select_device(self.args.device)
+        # Update "-1" devices so post-training val does not repeat search
+        self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
+        self.validator = None
+        self.metrics = None
+        self.plots = {}
+        init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
+
+        # Dirs
+        self.save_dir = get_save_dir(self.args)
+        self.args.name = self.save_dir.name  # update name for loggers
+        self.wdir = self.save_dir / "weights"  # weights dir
+        if RANK in {-1, 0}:
+            self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
+            self.args.save_dir = str(self.save_dir)
+            # Save run args, serializing augmentations as reprs for resume compatibility
+            args_dict = vars(self.args).copy()
+            if args_dict.get("augmentations") is not None:
+                # Serialize Albumentations transforms as their repr strings for checkpoint compatibility
+                args_dict["augmentations"] = [repr(t) for t in args_dict["augmentations"]]
+            YAML.save(self.save_dir / "args.yaml", args_dict)  # save run args
+        self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
+        self.save_period = self.args.save_period
+
+        self.batch_size = self.args.batch
+        self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
+        self.start_epoch = 0
+        if RANK == -1:
+            print_args(vars(self.args))
+
+        # Device
+        if self.device.type in {"cpu", "mps"}:
+            self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading
+
+        # Model and Dataset
+        self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo11n -> yolo11n.pt
+        with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
+            self.data = self.get_dataset()
+
+        self.ema = None
+
+        # Optimization utils init
+        self.lf = None
+        self.scheduler = None
+
+        # Epoch level metrics
+        self.best_fitness = None
+        self.fitness = None
+        self.loss = None
+        self.tloss = None
+        self.loss_names = ["Loss"]
+        self.csv = self.save_dir / "results.csv"
+        if self.csv.exists() and not self.args.resume:
+            self.csv.unlink()
+        self.plot_idx = [0, 1, 2]
+        self.nan_recovery_attempts = 0
+
+        # Callbacks
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
+
+        if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
+            world_size = len(self.args.device.split(","))
+        elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
+            world_size = len(self.args.device)
+        elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
+            world_size = 0
+        elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
+            world_size = 1  # default to device 0
+        else:  # i.e. device=None or device=''
+            world_size = 0
+
+        self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
+        self.world_size = world_size
+        # Run subprocess if DDP training, else train normally
+        if RANK in {-1, 0} and not self.ddp:
+            callbacks.add_integration_callbacks(self)
+            # Start console logging immediately at trainer initialization
+            self.run_callbacks("on_pretrain_routine_start")
+
+    def add_callback(self, event: str, callback):
+        """Append the given callback to the event's callback list."""
+        self.callbacks[event].append(callback)
+
+    def set_callback(self, event: str, callback):
+        """Override the existing callbacks with the given callback for the specified event."""
+        self.callbacks[event] = [callback]
+
+    def run_callbacks(self, event: str):
+        """Run all existing callbacks associated with a particular event."""
+        for callback in self.callbacks.get(event, []):
+            callback(self)
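
Note: these callbacks are the supported extension point of the trainer; a minimal sketch of registering one on a concrete subclass (DetectionTrainer from ultralytics.models.yolo.detect is assumed here):

    from ultralytics.models.yolo.detect import DetectionTrainer

    def log_fitness(trainer):
        print(f"epoch {trainer.epoch + 1}: fitness={trainer.fitness}")

    trainer = DetectionTrainer(overrides={"model": "yolo11n.pt", "data": "coco8.yaml", "epochs": 3})
    trainer.add_callback("on_fit_epoch_end", log_fitness)
    trainer.train()
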
+
+    def train(self):
+        """Run the training process, launching a DDP subprocess when multiple GPUs are requested."""
+        # Run subprocess if DDP training, else train normally
+        if self.ddp:
+            # Argument checks
+            if self.args.rect:
+                LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
+                self.args.rect = False
+            if self.args.batch < 1.0:
+                raise ValueError(
+                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
+                    f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
+                )
+
+            # Command
+            cmd, file = generate_ddp_command(self)
+            try:
+                LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
+                subprocess.run(cmd, check=True)
+            except Exception as e:
+                raise e
+            finally:
+                ddp_cleanup(self, str(file))
+
+        else:
+            self._do_train()
+
+    def _setup_scheduler(self):
+        """Initialize training learning rate scheduler."""
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
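
Note: a worked example of the linear schedule above, assuming the defaults lr0=0.01, lrf=0.01 and epochs=100 (the lambda returns a multiplier applied to each param group's initial lr):

    lrf, epochs = 0.01, 100
    lf = lambda x: max(1 - x / epochs, 0) * (1.0 - lrf) + lrf
    lf(0)    # 1.0   -> lr = 0.01 at epoch 0
    lf(50)   # 0.505 -> lr = 0.00505 halfway through
    lf(100)  # 0.01  -> lr = lrf * lr0 = 0.0001 at the end
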
+
+    def _setup_ddp(self):
+        """Initialize and set the DistributedDataParallel parameters for training."""
+        torch.cuda.set_device(RANK)
+        self.device = torch.device("cuda", RANK)
+        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
+        dist.init_process_group(
+            backend="nccl" if dist.is_nccl_available() else "gloo",
+            timeout=timedelta(seconds=10800),  # 3 hours
+            rank=RANK,
+            world_size=self.world_size,
+        )
+
+    def _setup_train(self):
+        """Build dataloaders and optimizer on correct rank process."""
+        ckpt = self.setup_model()
+        self.model = self.model.to(self.device)
+        self.set_model_attributes()
+
+        # Compile model
+        self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
+
+        # Freeze layers
+        freeze_list = (
+            self.args.freeze
+            if isinstance(self.args.freeze, list)
+            else range(self.args.freeze)
+            if isinstance(self.args.freeze, int)
+            else []
+        )
+        always_freeze_names = [".dfl"]  # always freeze these layers
+        freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        self.freeze_layer_names = freeze_layer_names
+        for k, v in self.model.named_parameters():
+            # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
+            if any(x in k for x in freeze_layer_names):
+                LOGGER.info(f"Freezing layer '{k}'")
+                v.requires_grad = False
+            elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
+                LOGGER.warning(
+                    f"setting 'requires_grad=True' for frozen layer '{k}'. "
+                    "See ultralytics.engine.trainer for customization of frozen layers."
+                )
+                v.requires_grad = True
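
Note: the freeze logic above matches parameter names by the "model.{x}." prefix, so an integer freezes that many leading layers; a hypothetical call through the YOLO model API for illustration:

    model.train(data="coco8.yaml", freeze=10)         # int: freezes model.0. through model.9.
    model.train(data="coco8.yaml", freeze=[0, 2, 5])  # list: freezes only those layer indices
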
+
+        # Check AMP
+        self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
+        if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
+            callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
+            self.amp = torch.tensor(check_amp(self.model), device=self.device)
+            callbacks.default_callbacks = callbacks_backup  # restore callbacks
+        if RANK > -1 and self.world_size > 1:  # DDP
+            dist.broadcast(self.amp.int(), src=0)  # broadcast from rank 0 to all other ranks; gloo errors with boolean
+        self.amp = bool(self.amp)  # as boolean
+        self.scaler = (
+            torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
+        )
+        if self.world_size > 1:
+            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
+
+        # Check imgsz
+        gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
+        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
+        self.stride = gs  # for multiscale training
+
+        # Batch size
+        if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
+            self.args.batch = self.batch_size = self.auto_batch()
+
+        # Dataloaders
+        batch_size = self.batch_size // max(self.world_size, 1)
+        self.train_loader = self.get_dataloader(
+            self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
+        )
+        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
+        self.test_loader = self.get_dataloader(
+            self.data.get("val") or self.data.get("test"),
+            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
+            rank=LOCAL_RANK,
+            mode="val",
+        )
+        self.validator = self.get_validator()
+        self.ema = ModelEMA(self.model)
+        if RANK in {-1, 0}:
+            metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
+            if self.args.plots:
+                self.plot_training_labels()
+
+        # Optimizer
+        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
+        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
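
Note: a worked example of the accumulation arithmetic above, assuming the defaults nbs=64 and weight_decay=0.0005 with batch=16:

    accumulate = max(round(64 / 16), 1)   # 4: step the optimizer every 4 batches
    weight_decay = 0.0005 * 16 * 4 / 64   # 0.0005: unchanged, since 16 * 4 matches the nominal batch size
    # a 128-image dataset trained for 100 epochs gives ceil(128 / 64) * 100 = 200 optimizer iterations
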
+        self.optimizer = self.build_optimizer(
+            model=self.model,
+            name=self.args.optimizer,
+            lr=self.args.lr0,
+            momentum=self.args.momentum,
+            decay=weight_decay,
+            iterations=iterations,
+        )
+        # Scheduler
+        self._setup_scheduler()
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
+        self.run_callbacks("on_pretrain_routine_end")
+
+    def _do_train(self):
+        """Train the model with the specified world size."""
+        if self.world_size > 1:
+            self._setup_ddp()
+        self._setup_train()
+
+        nb = len(self.train_loader)  # number of batches
+        nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
+        last_opt_step = -1
+        self.epoch_time = None
+        self.epoch_time_start = time.time()
+        self.train_time_start = time.time()
+        self.run_callbacks("on_train_start")
+        LOGGER.info(
+            f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+            f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
+            f"Logging results to {colorstr('bold', self.save_dir)}\n"
+            f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
+        )
+        if self.args.close_mosaic:
+            base_idx = (self.epochs - self.args.close_mosaic) * nb
+            self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
+        epoch = self.start_epoch
+        self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
+        while True:
+            self.epoch = epoch
+            self.run_callbacks("on_train_epoch_start")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                self.scheduler.step()
+
+            self._model_train()
+            if RANK != -1:
+                self.train_loader.sampler.set_epoch(epoch)
+            pbar = enumerate(self.train_loader)
+            # Update dataloader attributes (optional)
+            if epoch == (self.epochs - self.args.close_mosaic):
+                self._close_dataloader_mosaic()
+                self.train_loader.reset()
+
+            if RANK in {-1, 0}:
+                LOGGER.info(self.progress_string())
+                pbar = TQDM(enumerate(self.train_loader), total=nb)
+            self.tloss = None
+            for i, batch in pbar:
+                self.run_callbacks("on_train_batch_start")
+                # Warmup
+                ni = i + nb * epoch
+                if ni <= nw:
+                    xi = [0, nw]  # x interp
+                    self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
+                    for j, x in enumerate(self.optimizer.param_groups):
+                        # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                        x["lr"] = np.interp(
+                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
+                        )
+                        if "momentum" in x:
+                            x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+
+                # Forward
+                with autocast(self.amp):
+                    batch = self.preprocess_batch(batch)
+                    if self.args.compile:
+                        # Decouple inference and loss calculations for improved compile performance
+                        preds = self.model(batch["img"])
+                        loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
+                    else:
+                        loss, self.loss_items = self.model(batch)
+                    self.loss = loss.sum()
+                    if RANK != -1:
+                        self.loss *= self.world_size
+                    self.tloss = self.loss_items if self.tloss is None else (self.tloss * i + self.loss_items) / (i + 1)
+
+                # Backward
+                self.scaler.scale(self.loss).backward()
+                if ni - last_opt_step >= self.accumulate:
+                    self.optimizer_step()
+                    last_opt_step = ni
+
+                # Timed stopping
+                if self.args.time:
+                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
+                    if RANK != -1:  # if DDP training
+                        broadcast_list = [self.stop if RANK == 0 else None]
+                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                        self.stop = broadcast_list[0]
+                    if self.stop:  # training time exceeded
+                        break
+
+                # Log
+                if RANK in {-1, 0}:
+                    loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
+                    pbar.set_description(
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_length))
+                        % (
+                            f"{epoch + 1}/{self.epochs}",
+                            f"{self._get_memory():.3g}G",  # (GB) GPU memory util
+                            *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)),  # losses
+                            batch["cls"].shape[0],  # batch size, i.e. 8
+                            batch["img"].shape[-1],  # imgsz, i.e. 640
+                        )
+                    )
+                    self.run_callbacks("on_batch_end")
+                    if self.args.plots and ni in self.plot_idx:
+                        self.plot_training_samples(batch, ni)
+
+                self.run_callbacks("on_train_batch_end")
+
+            self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+
+            self.run_callbacks("on_train_epoch_end")
+            if RANK in {-1, 0}:
+                self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
+
+                # Validation
+                final_epoch = epoch + 1 >= self.epochs
+                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                    self._clear_memory(threshold=0.5)  # prevent VRAM spike
+                    self.metrics, self.fitness = self.validate()
+
+            # NaN recovery
+            if self._handle_nan_recovery(epoch):
+                continue
+
+            self.nan_recovery_attempts = 0
+            if RANK in {-1, 0}:
+                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
+                self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
+                if self.args.time:
+                    self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
+
+                # Save model
+                if self.args.save or final_epoch:
+                    self.save_model()
+                    self.run_callbacks("on_model_save")
+
+            # Scheduler
+            t = time.time()
+            self.epoch_time = t - self.epoch_time_start
+            self.epoch_time_start = t
+            if self.args.time:
+                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                self._setup_scheduler()
+                self.scheduler.last_epoch = self.epoch  # do not move
+                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
+            self.run_callbacks("on_fit_epoch_end")
+            self._clear_memory(0.5)  # clear if memory utilization > 50%
+
+            # Early Stopping
+            if RANK != -1:  # if DDP training
+                broadcast_list = [self.stop if RANK == 0 else None]
+                dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                self.stop = broadcast_list[0]
+            if self.stop:
+                break  # must break all DDP ranks
+            epoch += 1
+
+        seconds = time.time() - self.train_time_start
+        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
+        # Do final val with best.pt
+        self.final_eval()
+        if RANK in {-1, 0}:
+            if self.args.plots:
+                self.plot_metrics()
+            self.run_callbacks("on_train_end")
+        self._clear_memory()
+        unset_deterministic()
+        self.run_callbacks("teardown")
+
+    def auto_batch(self, max_num_obj=0):
+        """Calculate optimal batch size based on model and device memory constraints."""
+        return check_train_batch_size(
+            model=self.model,
+            imgsz=self.args.imgsz,
+            amp=self.amp,
+            batch=self.batch_size,
+            max_num_obj=max_num_obj,
+        )  # returns batch size
+
+    def _get_memory(self, fraction=False):
+        """Get accelerator memory utilization in GB or as a fraction of total memory."""
+        memory, total = 0, 0
+        if self.device.type == "mps":
+            memory = torch.mps.driver_allocated_memory()
+            if fraction:
+                return __import__("psutil").virtual_memory().percent / 100
+        elif self.device.type != "cpu":
+            memory = torch.cuda.memory_reserved()
+            if fraction:
+                total = torch.cuda.get_device_properties(self.device).total_memory
+        return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
+
+    def _clear_memory(self, threshold: float | None = None):
+        """Clear accelerator memory by calling garbage collector and emptying cache."""
+        if threshold:
+            assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
+            if self._get_memory(fraction=True) <= threshold:
+                return
+        gc.collect()
+        if self.device.type == "mps":
+            torch.mps.empty_cache()
+        elif self.device.type == "cpu":
+            return
+        else:
+            torch.cuda.empty_cache()
+
+    def read_results_csv(self):
+        """Read results.csv into a dictionary using polars."""
+        import polars as pl  # scope for faster 'import ultralytics'
+
+        try:
+            return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
+        except Exception:
+            return {}
+
+    def _model_train(self):
+        """Set model in training mode."""
+        self.model.train()
+        # Freeze BN stats
+        for n, m in self.model.named_modules():
+            if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
+    def save_model(self):
+        """Save model training checkpoints with additional metadata."""
+        import io
+
+        # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
+        buffer = io.BytesIO()
+        torch.save(
+            {
+                "epoch": self.epoch,
+                "best_fitness": self.best_fitness,
+                "model": None,  # resume and final checkpoints derive from EMA
+                "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
+                "updates": self.ema.updates,
+                "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "scaler": self.scaler.state_dict(),
+                "train_args": vars(self.args),  # save as dict
+                "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
+                "train_results": self.read_results_csv(),
+                "date": datetime.now().isoformat(),
+                "version": __version__,
+                "git": {
+                    "root": str(GIT.root),
+                    "branch": GIT.branch,
+                    "commit": GIT.commit,
+                    "origin": GIT.origin,
+                },
+                "license": "AGPL-3.0 (https://ultralytics.com/license)",
+                "docs": "https://docs.ultralytics.com",
+            },
+            buffer,
+        )
+        serialized_ckpt = buffer.getvalue()  # get the serialized content to save
+
+        # Save checkpoints
+        self.wdir.mkdir(parents=True, exist_ok=True)  # ensure weights directory exists
+        self.last.write_bytes(serialized_ckpt)  # save last.pt
+        if self.best_fitness == self.fitness:
+            self.best.write_bytes(serialized_ckpt)  # save best.pt
+        if (self.save_period > 0) and (self.epoch % self.save_period == 0):
+            (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'
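
Note: a sketch of inspecting a checkpoint written by save_model() above (the run directory path is illustrative; weights_only=False is needed on newer PyTorch to unpickle the EMA module):

    import torch

    ckpt = torch.load("runs/detect/train/weights/last.pt", map_location="cpu", weights_only=False)
    print(ckpt["epoch"], ckpt["best_fitness"], ckpt["version"])
    model = ckpt["ema"].float()  # the EMA weights stand in for the model (the "model" key is saved as None)
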
+
+    def get_dataset(self):
+        """Get train and validation datasets from data dictionary.
+
+        Returns:
+            (dict): A dictionary containing the training/validation/test dataset and category names.
+        """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif str(self.args.data).rsplit(".", 1)[-1] == "ndjson":
+                # Convert NDJSON to YOLO format
+                import asyncio
+
+                from ultralytics.data.converter import convert_ndjson_to_yolo
+
+                yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
+                self.args.data = str(yaml_path)
+                data = check_det_dataset(self.args.data)
+            elif str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            }:
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        if self.args.single_cls:
+            LOGGER.info("Overriding class names with single class.")
+            data["names"] = {0: "item"}
+            data["nc"] = 1
+        return data
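
Note: roughly the shape of the returned dict for a two-class detection dataset (keys per check_det_dataset; paths are illustrative):

    data = {
        "train": "datasets/example/images/train",
        "val": "datasets/example/images/val",
        "names": {0: "person", 1: "bicycle"},
        "nc": 2,
    }
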
+
+    def setup_model(self):
+        """Load, create, or download model for any task.
+
+        Returns:
+            (dict): Optional checkpoint to resume training from.
+        """
+        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
+            return
+
+        cfg, weights = self.model, None
+        ckpt = None
+        if str(self.model).endswith(".pt"):
+            weights, ckpt = load_checkpoint(self.model)
+            cfg = weights.yaml
+        elif isinstance(self.args.pretrained, (str, Path)):
+            weights, _ = load_checkpoint(self.args.pretrained)
+        self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
+        return ckpt
+
+    def optimizer_step(self):
+        """Perform a single step of the training optimizer with gradient clipping and EMA update."""
+        self.scaler.unscale_(self.optimizer)  # unscale gradients
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+        self.optimizer.zero_grad()
+        if self.ema:
+            self.ema.update(self.model)
+
+    def preprocess_batch(self, batch):
+        """Allow custom preprocessing model inputs and ground truths depending on task type."""
+        return batch
+
+    def validate(self):
+        """Run validation on val set using self.validator.
+
+        Returns:
+            metrics (dict): Dictionary of validation metrics.
+            fitness (float): Fitness score for the validation.
+        """
+        if self.ema and self.world_size > 1:
+            # Sync EMA buffers from rank 0 to all ranks
+            for buffer in self.ema.ema.buffers():
+                dist.broadcast(buffer, src=0)
+        metrics = self.validator(self)
+        if metrics is None:
+            return None, None
+        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
+        if not self.best_fitness or self.best_fitness < fitness:
+            self.best_fitness = fitness
+        return metrics, fitness
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Get model and raise NotImplementedError for loading cfg files."""
+        raise NotImplementedError("This task trainer doesn't support loading cfg files")
+
+    def get_validator(self):
+        """Raise NotImplementedError (must be implemented by subclasses)."""
+        raise NotImplementedError("get_validator function not implemented in trainer")
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
+        """Raise NotImplementedError (must return a `torch.utils.data.DataLoader` in subclasses)."""
+        raise NotImplementedError("get_dataloader function not implemented in trainer")
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """Build dataset."""
+        raise NotImplementedError("build_dataset function not implemented in trainer")
+
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        """Return a loss dict with labeled training loss items tensor.
+
+        Notes:
+            This is not needed for classification but necessary for segmentation & detection
+        """
+        return {"loss": loss_items} if loss_items is not None else ["loss"]
+
+    def set_model_attributes(self):
+        """Set or update model parameters before training."""
+        self.model.names = self.data["names"]
+
+    def build_targets(self, preds, targets):
+        """Build target tensors for training YOLO model."""
+        pass
+
+    def progress_string(self):
+        """Return a string describing training progress."""
+        return ""
+
+    # TODO: may need to put these following functions into callback
+    def plot_training_samples(self, batch, ni):
+        """Plot training samples during YOLO training."""
+        pass
+
+    def plot_training_labels(self):
+        """Plot training labels for YOLO model."""
+        pass
+
+    def save_metrics(self, metrics):
+        """Save training metrics to a CSV file."""
+        keys, vals = list(metrics.keys()), list(metrics.values())
+        n = len(metrics) + 2  # number of cols
+        t = time.time() - self.train_time_start
+        self.csv.parent.mkdir(parents=True, exist_ok=True)  # ensure parent directory exists
+        s = "" if self.csv.exists() else ("%s," * n % ("epoch", "time", *keys)).rstrip(",") + "\n"
+        with open(self.csv, "a", encoding="utf-8") as f:
+            f.write(s + ("%.6g," * n % (self.epoch + 1, t, *vals)).rstrip(",") + "\n")
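
Note: the resulting results.csv has one header row plus one row per epoch; column names depend on the task's loss and metric keys, e.g. for detection (values illustrative):

    epoch,time,train/box_loss,...,metrics/mAP50-95(B),lr/pg0,lr/pg1,lr/pg2
    1,25.4,0.912,...,0.402,0.00331,0.00331,0.00331
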
+
+    def plot_metrics(self):
+        """Plot metrics from a CSV file."""
+        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png
+
+    def on_plot(self, name, data=None):
+        """Register plots (e.g. to be consumed in callbacks)."""
+        path = Path(name)
+        self.plots[path] = {"data": data, "timestamp": time.time()}
+
+    def final_eval(self):
+        """Perform final evaluation and validation for object detection YOLO model."""
+        model = self.best if self.best.exists() else None
+        with torch_distributed_zero_first(LOCAL_RANK):  # strip only on GPU 0; other GPUs should wait
+            if RANK in {-1, 0}:
+                ckpt = strip_optimizer(self.last) if self.last.exists() else {}
+                if model:
+                    # update best.pt train_metrics from last.pt
+                    strip_optimizer(self.best, updates={"train_results": ckpt.get("train_results")})
+        if model:
+            LOGGER.info(f"\nValidating {model}...")
+            self.validator.args.plots = self.args.plots
+            self.validator.args.compile = False  # disable final val compile as too slow
+            self.metrics = self.validator(model=model)
+            self.metrics.pop("fitness", None)
+            self.run_callbacks("on_fit_epoch_end")
+
+    def check_resume(self, overrides):
+        """Check if resume checkpoint exists and update arguments accordingly."""
+        resume = self.args.resume
+        if resume:
+            try:
+                exists = isinstance(resume, (str, Path)) and Path(resume).exists()
+                last = Path(check_file(resume) if exists else get_latest_run())
+
+                # Check that resume data YAML exists, otherwise strip to force re-download of dataset
+                ckpt_args = load_checkpoint(last)[0].args
+                if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
+                    ckpt_args["data"] = self.args.data
+
+                resume = True
+                self.args = get_cfg(ckpt_args)
+                self.args.model = self.args.resume = str(last)  # reinstate model
+                for k in (
+                    "imgsz",
+                    "batch",
+                    "device",
+                    "close_mosaic",
+                    "augmentations",
+                    "save_period",
+                    "workers",
+                    "cache",
+                    "patience",
+                    "time",
+                    "freeze",
+                    "val",
+                    "plots",
+                ):  # allow arg updates to reduce memory or update device on resume
+                    if k in overrides:
+                        setattr(self.args, k, overrides[k])
+
+                # Handle augmentations parameter for resume: check if user provided custom augmentations
+                if ckpt_args.get("augmentations") is not None:
+                    # Augmentations were saved in checkpoint as reprs but can't be restored automatically
+                    LOGGER.warning(
+                        "Custom Albumentations transforms were used in the original training run but are not "
+                        "being restored. To preserve custom augmentations when resuming, you need to pass the "
+                        "'augmentations' parameter again to get expected results. Example: \n"
+                        f"model.train(resume=True, augmentations={ckpt_args['augmentations']})"
+                    )
+
+            except Exception as e:
+                raise FileNotFoundError(
+                    "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
+                    "i.e. 'yolo train resume model=path/to/last.pt'"
+                ) from e
+        self.resume = resume
+
+    def _load_checkpoint_state(self, ckpt):
+        """Load optimizer, scaler, EMA, and best_fitness from checkpoint."""
+        if ckpt.get("optimizer") is not None:
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+        if ckpt.get("scaler") is not None:
+            self.scaler.load_state_dict(ckpt["scaler"])
+        if self.ema and ckpt.get("ema"):
+            self.ema = ModelEMA(self.model)  # validation with EMA creates inference tensors that can't be updated
+            self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
+            self.ema.updates = ckpt["updates"]
+        self.best_fitness = ckpt.get("best_fitness", 0.0)
+
+    def _handle_nan_recovery(self, epoch):
+        """Detect and recover from NaN/Inf loss and fitness collapse by loading last checkpoint."""
+        loss_nan = self.loss is not None and not self.loss.isfinite()
+        fitness_nan = self.fitness is not None and not np.isfinite(self.fitness)
+        fitness_collapse = self.best_fitness and self.best_fitness > 0 and self.fitness == 0
+        corrupted = RANK in {-1, 0} and loss_nan and (fitness_nan or fitness_collapse)
+        reason = "Loss NaN/Inf" if loss_nan else "Fitness NaN/Inf" if fitness_nan else "Fitness collapse"
+        if RANK != -1:  # DDP: broadcast to all ranks
+            broadcast_list = [corrupted if RANK == 0 else None]
+            dist.broadcast_object_list(broadcast_list, 0)
+            corrupted = broadcast_list[0]
+        if not corrupted:
+            return False
+        if epoch == self.start_epoch or not self.last.exists():
+            LOGGER.warning(f"{reason} detected but cannot recover from last.pt...")
+            return False  # cannot recover on first epoch, let training continue
+        self.nan_recovery_attempts += 1
+        if self.nan_recovery_attempts > 3:
+            raise RuntimeError(f"Training failed: NaN persisted for {self.nan_recovery_attempts} epochs")
+        LOGGER.warning(f"{reason} detected (attempt {self.nan_recovery_attempts}/3), recovering from last.pt...")
+        self._model_train()  # set model to train mode before loading checkpoint to avoid inference tensor errors
+        _, ckpt = load_checkpoint(self.last)
+        ema_state = ckpt["ema"].float().state_dict()
+        if not all(torch.isfinite(v).all() for v in ema_state.values() if isinstance(v, torch.Tensor)):
+            raise RuntimeError(f"Checkpoint {self.last} is corrupted with NaN/Inf weights")
+        unwrap_model(self.model).load_state_dict(ema_state)  # load EMA weights into model
+        self._load_checkpoint_state(ckpt)  # load optimizer/scaler/EMA/best_fitness
+        del ckpt, ema_state
+        self.scheduler.last_epoch = epoch - 1
+        return True
+
+    def resume_training(self, ckpt):
+        """Resume YOLO training from given epoch and best fitness."""
+        if ckpt is None or not self.resume:
+            return
+        start_epoch = ckpt.get("epoch", -1) + 1
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt["epoch"]  # finetune additional epochs
+        self._load_checkpoint_state(ckpt)
+        self.start_epoch = start_epoch
+        if start_epoch > (self.epochs - self.args.close_mosaic):
+            self._close_dataloader_mosaic()
+
+    def _close_dataloader_mosaic(self):
+        """Update dataloaders to stop using mosaic augmentation."""
+        if hasattr(self.train_loader.dataset, "mosaic"):
+            self.train_loader.dataset.mosaic = False
+        if hasattr(self.train_loader.dataset, "close_mosaic"):
+            LOGGER.info("Closing dataloader mosaic")
+            self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
+
+    def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
+        """Construct an optimizer for the given model.
+
+        Args:
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected based on the
+                number of iterations.
+            lr (float, optional): The learning rate for the optimizer.
+            momentum (float, optional): The momentum factor for the optimizer.
+            decay (float, optional): The weight decay for the optimizer.
+            iterations (float, optional): The number of iterations, which determines the optimizer if name is 'auto'.
+
+        Returns:
+            (torch.optim.Optimizer): The constructed optimizer.
+        """
+        g = [], [], []  # optimizer parameter groups
+        bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == "auto":
+            LOGGER.info(
+                f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
+            )
+            nc = self.data.get("nc", 10)  # number of classes
+            lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
+            name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
+            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
+
+        for module_name, module in model.named_modules():
+            for param_name, param in module.named_parameters(recurse=False):
+                fullname = f"{module_name}.{param_name}" if module_name else param_name
+                if "bias" in fullname:  # bias (no decay)
+                    g[2].append(param)
+                elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
+                    # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
+                    g[1].append(param)
+                else:  # weight (with decay)
+                    g[0].append(param)
+
+        optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
+        name = {x.lower(): x for x in optimizers}.get(name.lower())
+        if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        elif name == "RMSProp":
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
+        elif name == "SGD":
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+        else:
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
+                "Request support for additional optimizers at https://github.com/ultralytics/ultralytics."
+            )
+
+        optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
+        optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
+        )
+        return optimizer
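
Note: a worked example of the 'auto' selection above, assuming an 80-class dataset such as COCO:

    lr_fit = round(0.002 * 5 / (4 + 80), 6)  # 0.000119
    # iterations > 10000  -> SGD(lr=0.01, momentum=0.9), typical for larger datasets
    # iterations <= 10000 -> AdamW(lr=0.000119, momentum=0.9), typical for small datasets
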