dgenerate_ultralytics_headless-8.3.253-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
@@ -0,0 +1,984 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import functools
import gc
import math
import os
import random
import time
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from ultralytics import __version__
from ultralytics.utils import (
    DEFAULT_CFG_DICT,
    DEFAULT_CFG_KEYS,
    LOGGER,
    NUM_THREADS,
    PYTHON_VERSION,
    TORCH_VERSION,
    TORCHVISION_VERSION,
    WINDOWS,
    colorstr,
)
from ultralytics.utils.checks import check_version
from ultralytics.utils.cpu import CPUInfo
from ultralytics.utils.patches import torch_load

# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(TORCH_VERSION, "1.9.0")
TORCH_1_10 = check_version(TORCH_VERSION, "1.10.0")
TORCH_1_11 = check_version(TORCH_VERSION, "1.11.0")
TORCH_1_13 = check_version(TORCH_VERSION, "1.13.0")
TORCH_2_0 = check_version(TORCH_VERSION, "2.0.0")
TORCH_2_1 = check_version(TORCH_VERSION, "2.1.0")
TORCH_2_4 = check_version(TORCH_VERSION, "2.4.0")
TORCH_2_8 = check_version(TORCH_VERSION, "2.8.0")
TORCH_2_9 = check_version(TORCH_VERSION, "2.9.0")
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
if WINDOWS and check_version(TORCH_VERSION, "==2.4.0"):  # reject version 2.4.0 on Windows
    LOGGER.warning(
        "Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
        "https://github.com/ultralytics/ultralytics/issues/15049"
    )


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first."""
    initialized = dist.is_available() and dist.is_initialized()
    use_ids = initialized and dist.get_backend() == "nccl"

    if initialized and local_rank not in {-1, 0}:
        dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
    yield
    if initialized and local_rank == 0:
        dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
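
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# typical torch_distributed_zero_first() usage in a DDP run, where rank 0
# performs a one-time task and the other ranks wait at the barrier.
# `LOCAL_RANK` is assumed to be set by the DDP launcher environment.
def _example_zero_first_usage():
    local_rank = int(os.getenv("LOCAL_RANK", -1))
    with torch_distributed_zero_first(local_rank):
        # Rank 0 (or single-process runs) executes this first; other ranks
        # reach this point only after rank 0 exits the context manager.
        cache = torch.arange(10)  # stand-in for dataset download/caching work
    return cache
# ---------------------------------------------------------------------------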

def smart_inference_mode():
    """Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""

    def decorate(fn):
        """Apply appropriate torch decorator for inference mode based on torch version."""
        if TORCH_1_9 and torch.is_inference_mode_enabled():
            return fn  # already in inference_mode, act as a pass-through
        else:
            return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)

    return decorate
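
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# smart_inference_mode() is used as a decorator factory; the wrapped call runs
# under torch.inference_mode() on torch>=1.9 and torch.no_grad() otherwise.
@smart_inference_mode()
def _example_forward(model, x):
    return model(x)  # gradients are never tracked inside this call
# ---------------------------------------------------------------------------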

def autocast(enabled: bool, device: str = "cuda"):
    """Get the appropriate autocast context manager based on PyTorch version and AMP setting.

    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.

    Args:
        enabled (bool): Whether to enable automatic mixed precision.
        device (str, optional): The device to use for autocast.

    Returns:
        (torch.amp.autocast): The appropriate autocast context manager.

    Examples:
        >>> with autocast(enabled=True):
        ...     # Your mixed precision operations here
        ...     pass

    Notes:
        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
        - For older versions, it uses `torch.cuda.amp.autocast`.
    """
    if TORCH_1_13:
        return torch.amp.autocast(device, enabled=enabled)
    else:
        return torch.cuda.amp.autocast(enabled)


@functools.lru_cache
def get_cpu_info():
    """Return a string with system CPU information, i.e. 'Apple M2'."""
    from ultralytics.utils import PERSISTENT_CACHE  # avoid circular import error

    if "cpu_info" not in PERSISTENT_CACHE:
        try:
            PERSISTENT_CACHE["cpu_info"] = CPUInfo.name()
        except Exception:
            pass
    return PERSISTENT_CACHE.get("cpu_info", "unknown")


@functools.lru_cache
def get_gpu_info(index):
    """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
    properties = torch.cuda.get_device_properties(index)
    return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"


def select_device(device="", newline=False, verbose=True):
    """Select the appropriate PyTorch device based on the provided arguments.

    The function takes a string specifying the device or a torch.device object and returns a torch.device object
    representing the selected device. The function also validates the number of available devices and raises an
    exception if the requested device(s) are not available.

    Args:
        device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or
            'cuda', or '0' or '0,1,2,3'. Auto-selects the first available GPU, or CPU if no GPU is available.
        newline (bool, optional): If True, adds a newline at the end of the log string.
        verbose (bool, optional): If True, logs the device information.

    Returns:
        (torch.device): Selected device.

    Examples:
        >>> select_device("cuda:0")
        device(type='cuda', index=0)

        >>> select_device("cpu")
        device(type='cpu')

    Notes:
        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
    """
    if isinstance(device, torch.device) or str(device).startswith(("tpu", "intel", "vulkan")):
        return device

    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{TORCH_VERSION} "
    device = str(device).lower()
    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
        device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'

    # Auto-select GPUs
    if "-1" in device:
        from ultralytics.utils.autodevice import GPUInfo

        # Replace each -1 with a selected GPU or remove it
        parts = device.split(",")
        selected = GPUInfo().select_idle_gpu(count=parts.count("-1"), min_memory_fraction=0.2)
        for i in range(len(parts)):
            if parts[i] == "-1":
                parts[i] = str(selected.pop(0)) if selected else ""
        device = ",".join(p for p in parts if p)

    cpu = device == "cpu"
    mps = device in {"mps", "mps:0"}  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        if device == "cuda":
            device = "0"
        if "," in device:
            device = ",".join([x for x in device.split(",") if x])  # remove sequential commas, i.e. "0,,1" -> "0,1"
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before assert is_available()
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
            LOGGER.info(s)
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
        space = " " * len(s)
        for i, d in enumerate(devices):
            s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n"  # bytes to MB
        arg = "cuda:0"
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available
        s += f"MPS ({get_cpu_info()})\n"
        arg = "mps"
    else:  # revert to CPU
        s += f"CPU ({get_cpu_info()})\n"
        arg = "cpu"

    if arg in {"cpu", "mps"}:
        torch.set_num_threads(NUM_THREADS)  # reset OMP_NUM_THREADS for cpu training
    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)


def time_sync():
    """Return PyTorch-accurate time."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def fuse_conv_and_bn(conv, bn):
    """Fuse Conv2d and BatchNorm2d layers for inference optimization.

    Args:
        conv (nn.Conv2d): Convolutional layer to fuse.
        bn (nn.BatchNorm2d): Batch normalization layer to fuse.

    Returns:
        (nn.Conv2d): The fused convolutional layer with gradients disabled.

    Examples:
        >>> conv = nn.Conv2d(3, 16, 3)
        >>> bn = nn.BatchNorm2d(16)
        >>> fused_conv = fuse_conv_and_bn(conv, bn)
    """
    # Compute fused weights
    w_conv = conv.weight.view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    conv.weight.data = torch.mm(w_bn, w_conv).view(conv.weight.shape)

    # Compute fused bias
    b_conv = torch.zeros(conv.out_channels, device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused_bias = torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn

    if conv.bias is None:
        conv.register_parameter("bias", nn.Parameter(fused_bias))
    else:
        conv.bias.data = fused_bias

    return conv.requires_grad_(False)
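
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# fusing folds the BN affine transform into the conv weights, so in eval mode
# the fused layer reproduces conv+bn output up to floating-point error.
def _example_fuse_check():
    conv, bn = nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16)
    bn.eval()  # use running statistics, as at inference time
    x = torch.randn(1, 3, 32, 32)
    y_ref = bn(conv(x))
    fused = fuse_conv_and_bn(deepcopy(conv), bn)  # deepcopy: fusion mutates conv
    assert torch.allclose(fused(x), y_ref, atol=1e-5)
# ---------------------------------------------------------------------------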

def fuse_deconv_and_bn(deconv, bn):
    """Fuse ConvTranspose2d and BatchNorm2d layers for inference optimization.

    Args:
        deconv (nn.ConvTranspose2d): Transposed convolutional layer to fuse.
        bn (nn.BatchNorm2d): Batch normalization layer to fuse.

    Returns:
        (nn.ConvTranspose2d): The fused transposed convolutional layer with gradients disabled.

    Examples:
        >>> deconv = nn.ConvTranspose2d(16, 3, 3)
        >>> bn = nn.BatchNorm2d(3)
        >>> fused_deconv = fuse_deconv_and_bn(deconv, bn)
    """
    # Compute fused weights
    w_deconv = deconv.weight.view(deconv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    deconv.weight.data = torch.mm(w_bn, w_deconv).view(deconv.weight.shape)

    # Compute fused bias
    b_conv = torch.zeros(deconv.out_channels, device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused_bias = torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn

    if deconv.bias is None:
        deconv.register_parameter("bias", nn.Parameter(fused_bias))
    else:
        deconv.bias.data = fused_bias

    return deconv.requires_grad_(False)


def model_info(model, detailed=False, verbose=True, imgsz=640):
    """Print and return detailed model information layer by layer.

    Args:
        model (nn.Module): Model to analyze.
        detailed (bool, optional): Whether to print detailed layer information.
        verbose (bool, optional): Whether to print model information.
        imgsz (int | list, optional): Input image size.

    Returns:
        n_l (int): Number of layers.
        n_p (int): Number of parameters.
        n_g (int): Number of gradients.
        flops (float): GFLOPs.
    """
    if not verbose:
        return
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    layers = __import__("collections").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)
    n_l = len(layers)  # number of layers
    if detailed:
        h = f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}"
        LOGGER.info(h)
        for i, (mn, m) in enumerate(layers.items()):
            mn = mn.replace("module_list.", "")
            mt = m.__class__.__name__
            if len(m._parameters):
                for pn, p in m.named_parameters():
                    LOGGER.info(
                        f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{list(p.shape)!s:>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
                    )
            else:  # layers with no learnable params
                LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{[]!s:>20}{'-':>10}{'-':>10}{'-':>15}")

    flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
    fs = f", {flops:.1f} GFLOPs" if flops else ""
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
    return n_l, n_p, n_g, flops


def get_num_params(model):
    """Return the total number of parameters in a YOLO model."""
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    """Return the total number of parameters with gradients in a YOLO model."""
    return sum(x.numel() for x in model.parameters() if x.requires_grad)


def model_info_for_loggers(trainer):
    """Return model info dict with useful model information.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.

    Returns:
        (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.

    Examples:
        YOLOv8n info for loggers
        >>> results = {
        ...     "model/parameters": 3151904,
        ...     "model/GFLOPs": 8.746,
        ...     "model/speed_ONNX(ms)": 41.244,
        ...     "model/speed_TensorRT(ms)": 3.211,
        ...     "model/speed_PyTorch(ms)": 18.755,
        ... }
    """
    if trainer.args.profile:  # profile ONNX and TensorRT times
        from ultralytics.utils.benchmarks import ProfileModels

        results = ProfileModels([trainer.last], device=trainer.device).run()[0]
        results.pop("model/name")
    else:  # only return PyTorch times from most recent validation
        results = {
            "model/parameters": get_num_params(trainer.model),
            "model/GFLOPs": round(get_flops(trainer.model), 3),
        }
        results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
    return results


def get_flops(model, imgsz=640):
    """Calculate FLOPs (floating point operations) for a model in billions.

    Attempts two calculation methods: first with a stride-based tensor for efficiency, then falls back to full image
    size if needed (e.g., for RTDETR models). Returns 0.0 if thop library is unavailable or calculation fails.

    Args:
        model (nn.Module): The model to calculate FLOPs for.
        imgsz (int | list, optional): Input image size.

    Returns:
        (float): The model FLOPs in billions.
    """
    try:
        import thop
    except ImportError:
        thop = None  # conda support without 'ultralytics-thop' installed

    if not thop:
        return 0.0  # if not installed return 0.0 GFLOPs

    try:
        model = unwrap_model(model)
        p = next(model.parameters())
        if not isinstance(imgsz, list):
            imgsz = [imgsz, imgsz]  # expand if int/float
        try:
            # Method 1: Use stride-based input tensor
            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
        except Exception:
            # Method 2: Use actual image size (required for RTDETR models)
            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
    except Exception:
        return 0.0


def get_flops_with_torch_profiler(model, imgsz=640):
    """Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).

    Args:
        model (nn.Module): The model to calculate FLOPs for.
        imgsz (int | list, optional): Input image size.

    Returns:
        (float): The model's FLOPs in billions.
    """
    if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
        return 0.0
    model = unwrap_model(model)
    p = next(model.parameters())
    if not isinstance(imgsz, list):
        imgsz = [imgsz, imgsz]  # expand if int/float
    try:
        # Use stride size for input tensor
        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
    except Exception:
        # Use actual image size for input tensor (i.e. required for RTDETR models)
        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9
    return flops


def initialize_weights(model):
    """Initialize model weights to random values."""
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
            m.inplace = True


def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple.

    Args:
        img (torch.Tensor): Input image tensor.
        ratio (float, optional): Scaling ratio.
        same_shape (bool, optional): Whether to maintain the same shape.
        gs (int, optional): Grid size for padding.

    Returns:
        (torch.Tensor): Scaled and padded image tensor.
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
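
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# scaling a 640x640 batch by 0.72 resizes to 460x460, then pads up to the next
# multiple of gs=32, i.e. 480x480.
def _example_scale_img():
    im = torch.zeros(1, 3, 640, 640)
    out = scale_img(im, ratio=0.72, same_shape=False, gs=32)
    return tuple(out.shape)  # (1, 3, 480, 480): resized to 460, padded to 480
# ---------------------------------------------------------------------------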

def copy_attr(a, b, include=(), exclude=()):
    """Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes.

    Args:
        a (Any): Destination object to copy attributes to.
        b (Any): Source object to copy attributes from.
        include (tuple, optional): Attributes to include. If empty, all attributes are included.
        exclude (tuple, optional): Attributes to exclude.
    """
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)


def intersect_dicts(da, db, exclude=()):
    """Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.

    Args:
        da (dict): First dictionary.
        db (dict): Second dictionary.
        exclude (tuple, optional): Keys to exclude.

    Returns:
        (dict): Dictionary of intersecting keys with matching shapes.
    """
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
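
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# the common transfer-learning pattern: keep only checkpoint tensors whose
# names and shapes match the new model, then load non-strictly. The "head"
# exclude substring is a hypothetical example key.
def _example_partial_load(model, ckpt_state_dict):
    csd = intersect_dicts(ckpt_state_dict, model.state_dict(), exclude=("head",))
    model.load_state_dict(csd, strict=False)  # unmatched layers keep init weights
    return len(csd)  # number of transferred tensors
# ---------------------------------------------------------------------------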

def is_parallel(model):
    """Return True if model is of type DP or DDP.

    Args:
        model (nn.Module): Model to check.

    Returns:
        (bool): True if model is DataParallel or DistributedDataParallel.
    """
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def unwrap_model(m: nn.Module) -> nn.Module:
    """Unwrap compiled and parallel models to get the base model.

    Args:
        m (nn.Module): A model that may be wrapped by torch.compile (._orig_mod) or parallel wrappers such as
            DataParallel/DistributedDataParallel (.module).

    Returns:
        m (nn.Module): The unwrapped base model without compile or parallel wrappers.
    """
    while True:
        if hasattr(m, "_orig_mod") and isinstance(m._orig_mod, nn.Module):
            m = m._orig_mod
        elif hasattr(m, "module") and isinstance(m.module, nn.Module):
            m = m.module
        else:
            return m


def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.

    Args:
        y1 (float, optional): Initial value.
        y2 (float, optional): Final value.
        steps (int, optional): Number of steps.

    Returns:
        (function): Lambda function for computing the sinusoidal ramp.
    """
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
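
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# the returned lambda maps step x in [0, steps] onto a half-cosine ramp, e.g. a
# learning-rate multiplier decaying from 1.0 to 0.01 over 100 epochs.
def _example_one_cycle():
    lf = one_cycle(y1=1.0, y2=0.01, steps=100)
    return lf(0), lf(50), lf(100)  # 1.0, ~0.505, 0.01
# ---------------------------------------------------------------------------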

def init_seeds(seed=0, deterministic=False):
    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.

    Args:
        seed (int, optional): Random seed.
        deterministic (bool, optional): Whether to set deterministic algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic:
        if TORCH_2_0:
            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
            torch.backends.cudnn.deterministic = True
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
            os.environ["PYTHONHASHSEED"] = str(seed)
        else:
            LOGGER.warning("Upgrade to torch>=2.0.0 for deterministic training.")
    else:
        unset_deterministic()
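
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# seeding makes the Python, NumPy and torch RNG streams repeatable; with
# deterministic=True (torch>=2.0) deterministic algorithms are also enforced.
def _example_reproducible_draws():
    init_seeds(seed=3, deterministic=False)
    a = torch.rand(2)
    init_seeds(seed=3, deterministic=False)
    b = torch.rand(2)
    assert torch.equal(a, b)  # identical draws after re-seeding
# ---------------------------------------------------------------------------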

def unset_deterministic():
    """Unset all the configurations applied for deterministic training."""
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.deterministic = False
    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
    os.environ.pop("PYTHONHASHSEED", None)


class ModelEMA:
    """Updated Exponential Moving Average (EMA) implementation.

    Keeps a moving average of everything in the model state_dict (parameters and buffers). For EMA details see
    References.

    To disable EMA set the `enabled` attribute to `False`.

    Attributes:
        ema (nn.Module): Copy of the model in evaluation mode.
        updates (int): Number of EMA updates.
        decay (function): Decay function that determines the EMA weight.
        enabled (bool): Whether EMA is enabled.

    References:
        - https://github.com/rwightman/pytorch-image-models
        - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """Initialize EMA for 'model' with given arguments.

        Args:
            model (nn.Module): Model to create EMA for.
            decay (float, optional): Maximum EMA decay rate.
            tau (int, optional): EMA decay time constant.
            updates (int, optional): Initial number of updates.
        """
        self.ema = deepcopy(unwrap_model(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.enabled = True

    def update(self, model):
        """Update EMA parameters.

        Args:
            model (nn.Module): Model to update EMA from.
        """
        if self.enabled:
            self.updates += 1
            d = self.decay(self.updates)

            msd = unwrap_model(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # true for FP16 and FP32
                    v *= d
                    v += (1 - d) * msd[k].detach()
                    # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        """Update attributes and save stripped model with optimizer removed.

        Args:
            model (nn.Module): Model to update attributes from.
            include (tuple, optional): Attributes to include.
            exclude (tuple, optional): Attributes to exclude.
        """
        if self.enabled:
            copy_attr(self.ema, model, include, exclude)
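
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# typical training-loop wiring: update the EMA shadow weights after each
# optimizer step, then validate/checkpoint ema.ema instead of the raw model.
def _example_ema_loop(model, optimizer, loader, loss_fn):
    ema = ModelEMA(model, decay=0.9999, tau=2000)
    for x, y in loader:
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.update(model)  # fold current weights into the moving average
    return ema.ema  # smoothed copy used for eval/checkpointing
# ---------------------------------------------------------------------------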

def strip_optimizer(f: str | Path = "best.pt", s: str = "", updates: dict[str, Any] | None = None) -> dict[str, Any]:
    """Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Args:
        f (str | Path): File path to model to strip the optimizer from.
        s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be
            overwritten.
        updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.

    Returns:
        (dict): The combined checkpoint dictionary.

    Examples:
        >>> from pathlib import Path
        >>> from ultralytics.utils.torch_utils import strip_optimizer
        >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
        >>>     strip_optimizer(f)
    """
    try:
        x = torch_load(f, map_location=torch.device("cpu"))
        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
        assert "model" in x, "'model' missing from checkpoint"
    except Exception as e:
        LOGGER.warning(f"Skipping {f}, not a valid Ultralytics model: {e}")
        return {}

    metadata = {
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Update model
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with EMA
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    if hasattr(x["model"], "criterion"):
        x["model"].criterion = None  # strip loss criterion
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False

    # Update other keys
    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
    for k in "optimizer", "best_fitness", "ema", "updates", "scaler":  # keys
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']

    # Save
    combined = {**metadata, **x, **(updates or {})}
    torch.save(combined, s or f)  # combine dicts (prefer to the right)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
    return combined


def convert_optimizer_state_dict_to_fp16(state_dict):
    """Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.

    Args:
        state_dict (dict): Optimizer state dictionary.

    Returns:
        (dict): Converted optimizer state dictionary with FP16 tensors.
    """
    for state in state_dict["state"].values():
        for k, v in state.items():
            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
                state[k] = v.half()

    return state_dict


@contextmanager
def cuda_memory_usage(device=None):
    """Monitor and manage CUDA memory usage.

    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory. It then
    yields a dictionary containing memory usage information, which can be updated by the caller. Finally, it updates
    the dictionary with the amount of memory reserved by CUDA on the specified device.

    Args:
        device (torch.device, optional): The CUDA device to query memory usage for.

    Yields:
        (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
    """
    cuda_info = dict(memory=0)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        try:
            yield cuda_info
        finally:
            cuda_info["memory"] = torch.cuda.memory_reserved(device)
    else:
        yield cuda_info
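
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# the context manager yields a dict whose 'memory' key is filled in with
# torch.cuda.memory_reserved() after the block finishes (stays 0 on CPU-only).
def _example_measure_reserved(model, x):
    with cuda_memory_usage() as info:
        y = model(x)
    return y, info["memory"] / 1e9  # reserved CUDA memory in GB
# ---------------------------------------------------------------------------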

def profile_ops(input, ops, n=10, device=None, max_num_obj=0):
    """Ultralytics speed, memory and FLOPs profiler.

    Args:
        input (torch.Tensor | list): Input tensor(s) to profile.
        ops (nn.Module | list): Model or list of operations to profile.
        n (int, optional): Number of iterations to average.
        device (str | torch.device, optional): Device to profile on.
        max_num_obj (int, optional): Maximum number of objects for simulation.

    Returns:
        (list): Profile results for each operation.

    Examples:
        >>> from ultralytics.utils.torch_utils import profile_ops
        >>> input = torch.randn(16, 3, 640, 640)
        >>> m1 = lambda x: x * torch.sigmoid(x)
        >>> m2 = nn.SiLU()
        >>> profile_ops(input, [m1, m2], n=100)  # profile over 100 iterations
    """
    try:
        import thop
    except ImportError:
        thop = None  # conda support without 'ultralytics-thop' installed

    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(
        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
        f"{'input':>24s}{'output':>24s}"
    )
    gc.collect()  # attempt to free unused memory
    torch.cuda.empty_cache()
    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, "to") else m  # device
            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

            try:
                mem = 0
                for _ in range(n):
                    with cuda_memory_usage(device) as cuda_info:
                        t[0] = time_sync()
                        y = m(x)
                        t[1] = time_sync()
                        try:
                            (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                            t[2] = time_sync()
                        except Exception:  # no backward method
                            # print(e)  # for debug
                            t[2] = float("nan")
                    mem += cuda_info["memory"] / 1e9  # (GB)
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                    if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
                        with cuda_memory_usage(device) as cuda_info:
                            torch.randn(
                                x.shape[0],
                                max_num_obj,
                                int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
                                device=device,
                                dtype=torch.float32,
                            )
                        mem += cuda_info["memory"] / 1e9  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{s_in!s:>24s}{s_out!s:>24s}")
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)
                results.append(None)
            finally:
                gc.collect()  # attempt to free unused memory
                torch.cuda.empty_cache()
    return results


class EarlyStopping:
    """Early stopping class that stops training when a specified number of epochs have passed without improvement.

    Attributes:
        best_fitness (float): Best fitness value observed.
        best_epoch (int): Epoch where best fitness was observed.
        patience (int): Number of epochs to wait after fitness stops improving before stopping.
        possible_stop (bool): Flag indicating if stopping may occur next epoch.
    """

    def __init__(self, patience=50):
        """Initialize early stopping object.

        Args:
            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
        """
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float("inf")  # epochs to wait after fitness stops improving to stop
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        """Check whether to stop training.

        Args:
            epoch (int): Current epoch of training.
            fitness (float): Fitness value of current epoch.

        Returns:
            (bool): True if training should stop, False otherwise.
        """
        if fitness is None:  # check if fitness=None (happens when val=False)
            return False

        if fitness > self.best_fitness or self.best_fitness == 0:  # allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            prefix = colorstr("EarlyStopping: ")
            LOGGER.info(
                f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
                f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
                f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
                f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
            )
        return stop
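
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (annotation, not part of the packaged file):
# the stopper is called once per epoch with a fitness value; it returns True
# once `patience` epochs pass without improvement over the best fitness.
def _example_early_stop():
    stopper = EarlyStopping(patience=2)
    history = [0.50, 0.60, 0.59, 0.58, 0.57]  # fitness stalls after epoch 1
    for epoch, fitness in enumerate(history):
        if stopper(epoch, fitness):
            return epoch  # returns 3: epochs 2 and 3 showed no improvement
# ---------------------------------------------------------------------------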

def attempt_compile(
    model: torch.nn.Module,
    device: torch.device,
    imgsz: int = 640,
    use_autocast: bool = False,
    warmup: bool = False,
    mode: bool | str = "default",
) -> torch.nn.Module:
    """Compile a model with torch.compile and optionally warm up the graph to reduce first-iteration latency.

    This utility attempts to compile the provided model using the inductor backend with dynamic shapes enabled and an
    autotuning mode. If compilation is unavailable or fails, the original model is returned unchanged. An optional
    warmup performs a single forward pass on a dummy input to prime the compiled graph and measure compile/warmup
    time.

    Args:
        model (torch.nn.Module): Model to compile.
        device (torch.device): Inference device used for warmup and autocast decisions.
        imgsz (int, optional): Square input size to create a dummy tensor with shape (1, 3, imgsz, imgsz) for warmup.
        use_autocast (bool, optional): Whether to run warmup under autocast on CUDA or MPS devices.
        warmup (bool, optional): Whether to execute a single dummy forward pass to warm up the compiled model.
        mode (bool | str, optional): torch.compile mode. True → "default", False → no compile, or a string like
            "default", "reduce-overhead", "max-autotune-no-cudagraphs".

    Returns:
        model (torch.nn.Module): Compiled model if compilation succeeds, otherwise the original unmodified model.

    Examples:
        >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        >>> # Try to compile and warm up a model with a 640x640 input
        >>> model = attempt_compile(model, device=device, imgsz=640, use_autocast=True, warmup=True)

    Notes:
        - If the current PyTorch build does not provide torch.compile, the function returns the input model
          immediately.
        - Warmup runs under torch.inference_mode and may use torch.autocast for CUDA/MPS to align compute precision.
        - CUDA devices are synchronized after warmup to account for asynchronous kernel execution.
    """
    if not hasattr(torch, "compile") or not mode:
        return model

    if mode is True:
        mode = "default"
    prefix = colorstr("compile:")
    LOGGER.info(f"{prefix} starting torch.compile with '{mode}' mode...")
    if mode == "max-autotune":
        LOGGER.warning(f"{prefix} mode='{mode}' not recommended, using mode='max-autotune-no-cudagraphs' instead")
        mode = "max-autotune-no-cudagraphs"
    t0 = time.perf_counter()
    try:
        model = torch.compile(model, mode=mode, backend="inductor")
    except Exception as e:
        LOGGER.warning(f"{prefix} torch.compile failed, continuing uncompiled: {e}")
        return model
    t_compile = time.perf_counter() - t0

    t_warm = 0.0
    if warmup:
        # Use a single dummy tensor to build the graph shape state and reduce first-iteration latency
        dummy = torch.zeros(1, 3, imgsz, imgsz, device=device)
        if use_autocast and device.type == "cuda":
            dummy = dummy.half()
        t1 = time.perf_counter()
        with torch.inference_mode():
            if use_autocast and device.type in {"cuda", "mps"}:
                with torch.autocast(device.type):
                    _ = model(dummy)
            else:
                _ = model(dummy)
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        t_warm = time.perf_counter() - t1

    total = t_compile + t_warm
    if warmup:
        LOGGER.info(f"{prefix} complete in {total:.1f}s (compile {t_compile:.1f}s + warmup {t_warm:.1f}s)")
    else:
        LOGGER.info(f"{prefix} compile complete in {t_compile:.1f}s (no warmup)")
    return model