dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
ultralytics/utils/torch_utils.py
@@ -0,0 +1,990 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import gc
+import math
+import os
+import random
+import time
+from contextlib import contextmanager
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ultralytics import __version__
+from ultralytics.utils import (
+    DEFAULT_CFG_DICT,
+    DEFAULT_CFG_KEYS,
+    LOGGER,
+    NUM_THREADS,
+    PYTHON_VERSION,
+    TORCHVISION_VERSION,
+    WINDOWS,
+    colorstr,
+)
+from ultralytics.utils.checks import check_version
+
+# Version checks (all default to version>=min_version)
+TORCH_1_9 = check_version(torch.__version__, "1.9.0")
+TORCH_1_13 = check_version(torch.__version__, "1.13.0")
+TORCH_2_0 = check_version(torch.__version__, "2.0.0")
+TORCH_2_4 = check_version(torch.__version__, "2.4.0")
+TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
+TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
+TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
+TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
+if WINDOWS and check_version(torch.__version__, "==2.4.0"):  # reject version 2.4.0 on Windows
+    LOGGER.warning(
+        "Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
+        "https://github.com/ultralytics/ultralytics/issues/15049"
+    )
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
+    initialized = dist.is_available() and dist.is_initialized()
+    use_ids = initialized and dist.get_backend() == "nccl"
+
+    if initialized and local_rank not in {-1, 0}:
+        dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
+    yield
+    if initialized and local_rank == 0:
+        dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
+
+
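Note: torch_distributed_zero_first is typically used so that rank 0 prepares or downloads a dataset once while the other ranks wait. A minimal sketch, assuming a torchrun-launched DDP job; prepare_dataset is an illustrative placeholder, not part of this package:

    import os

    from ultralytics.utils.torch_utils import torch_distributed_zero_first

    LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # set by torchrun; -1 means single-process

    with torch_distributed_zero_first(LOCAL_RANK):
        dataset = prepare_dataset()  # placeholder: rank 0 runs first, other ranks reuse its cache afterwards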
+def smart_inference_mode():
+    """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
+
+    def decorate(fn):
+        """Applies appropriate torch decorator for inference mode based on torch version."""
+        if TORCH_1_9 and torch.is_inference_mode_enabled():
+            return fn  # already in inference_mode, act as a pass-through
+        else:
+            return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
+
+    return decorate
+
+
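Note: smart_inference_mode is a decorator factory, so it is applied with a trailing call. A minimal sketch:

    import torch

    from ultralytics.utils.torch_utils import smart_inference_mode

    @smart_inference_mode()
    def predict(model, x):
        # runs under torch.inference_mode() on torch>=1.9, torch.no_grad() otherwise
        return model(x)

    y = predict(torch.nn.Identity(), torch.randn(2, 3))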
+def autocast(enabled: bool, device: str = "cuda"):
+    """
+    Get the appropriate autocast context manager based on PyTorch version and AMP setting.
+
+    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
+    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
+
+    Args:
+        enabled (bool): Whether to enable automatic mixed precision.
+        device (str, optional): The device to use for autocast. Defaults to 'cuda'.
+
+    Returns:
+        (torch.amp.autocast): The appropriate autocast context manager.
+
+    Notes:
+        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
+        - For older versions, it uses `torch.cuda.amp.autocast`.
+
+    Examples:
+        >>> with autocast(enabled=True):
+        ...     # Your mixed precision operations here
+        ...     pass
+    """
+    if TORCH_1_13:
+        return torch.amp.autocast(device, enabled=enabled)
+    else:
+        return torch.cuda.amp.autocast(enabled)
+
+
+def get_cpu_info():
+    """Return a string with system CPU information, i.e. 'Apple M2'."""
+    from ultralytics.utils import PERSISTENT_CACHE  # avoid circular import error
+
+    if "cpu_info" not in PERSISTENT_CACHE:
+        try:
+            import cpuinfo  # pip install py-cpuinfo
+
+            k = "brand_raw", "hardware_raw", "arch_string_raw"  # keys sorted by preference
+            info = cpuinfo.get_cpu_info()  # info dict
+            string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
+            PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
+        except Exception:
+            pass
+    return PERSISTENT_CACHE.get("cpu_info", "unknown")
+
+
+def get_gpu_info(index):
+    """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
+    properties = torch.cuda.get_device_properties(index)
+    return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
+
+
+def select_device(device="", batch=0, newline=False, verbose=True):
+    """
+    Select the appropriate PyTorch device based on the provided arguments.
+
+    The function takes a string specifying the device or a torch.device object and returns a torch.device object
+    representing the selected device. The function also validates the number of available devices and raises an
+    exception if the requested device(s) are not available.
+
+    Args:
+        device (str | torch.device, optional): Device string or torch.device object.
+            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
+            the first available GPU, or CPU if no GPU is available.
+        batch (int, optional): Batch size being used in your model.
+        newline (bool, optional): If True, adds a newline at the end of the log string.
+        verbose (bool, optional): If True, logs the device information.
+
+    Returns:
+        (torch.device): Selected device.
+
+    Raises:
+        ValueError: If the specified device is not available, or if the batch size is not a multiple of the number of
+            devices when using multiple GPUs.
+
+    Examples:
+        >>> select_device("cuda:0")
+        device(type='cuda', index=0)
+
+        >>> select_device("cpu")
+        device(type='cpu')
+
+    Note:
+        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
+    """
+    if isinstance(device, torch.device) or str(device).startswith(("tpu", "intel")):
+        return device
+
+    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
+    device = str(device).lower()
+    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
+        device = device.replace(remove, "")  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
+
+    # Auto-select GPUs
+    if "-1" in device:
+        from ultralytics.utils.autodevice import GPUInfo
+
+        # Replace each -1 with a selected GPU or remove it
+        parts = device.split(",")
+        selected = GPUInfo().select_idle_gpu(count=parts.count("-1"), min_memory_mb=2048)
+        for i in range(len(parts)):
+            if parts[i] == "-1":
+                parts[i] = str(selected.pop(0)) if selected else ""
+        device = ",".join(p for p in parts if p)
+
+    cpu = device == "cpu"
+    mps = device in {"mps", "mps:0"}  # Apple Metal Performance Shaders (MPS)
+    if cpu or mps:
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
+    elif device:  # non-cpu device requested
+        if device == "cuda":
+            device = "0"
+        if "," in device:
+            device = ",".join([x for x in device.split(",") if x])  # remove sequential commas, i.e. "0,,1" -> "0,1"
+        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before assert is_available()
+        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
+            LOGGER.info(s)
+            install = (
+                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
+                "CUDA devices are seen by torch.\n"
+                if torch.cuda.device_count() == 0
+                else ""
+            )
+            raise ValueError(
+                f"Invalid CUDA 'device={device}' requested."
+                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
+                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
+                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
+                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
+                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
+                f"{install}"
+            )
+
+    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
+        devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
+        n = len(devices)  # device count
+        if n > 1:  # multi-GPU
+            if batch < 1:
+                raise ValueError(
+                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
+                    f"please specify a valid batch size multiple of GPU count {n}, i.e. batch={n * 8}."
+                )
+            if batch >= 0 and batch % n != 0:  # check batch_size is divisible by device_count
+                raise ValueError(
+                    f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
+                    f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
+                )
+        space = " " * (len(s) + 1)
+        for i, d in enumerate(devices):
+            s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n"  # bytes to MB
+        arg = "cuda:0"
+    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
+        # Prefer MPS if available
+        s += f"MPS ({get_cpu_info()})\n"
+        arg = "mps"
+    else:  # revert to CPU
+        s += f"CPU ({get_cpu_info()})\n"
+        arg = "cpu"
+
+    if arg in {"cpu", "mps"}:
+        torch.set_num_threads(NUM_THREADS)  # reset OMP_NUM_THREADS for cpu training
+    if verbose:
+        LOGGER.info(s if newline else s.rstrip())
+    return torch.device(arg)
+
+
+def time_sync():
+    """PyTorch-accurate time."""
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return time.time()
+
+
+def fuse_conv_and_bn(conv, bn):
+    """Fuse Conv2d() and BatchNorm2d() layers."""
+    fusedconv = (
+        nn.Conv2d(
+            conv.in_channels,
+            conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            dilation=conv.dilation,
+            groups=conv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(conv.weight.device)
+    )
+
+    # Prepare filters
+    w_conv = conv.weight.view(conv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
+
+    # Prepare spatial bias
+    b_conv = (
+        torch.zeros(conv.weight.shape[0], dtype=conv.weight.dtype, device=conv.weight.device)
+        if conv.bias is None
+        else conv.bias
+    )
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fusedconv
+
+
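Note: the fusion above folds frozen BatchNorm statistics into the convolution, i.e. W' = diag(gamma / sqrt(var + eps)) @ W and b' = beta + (b - mean) * gamma / sqrt(var + eps), so a single fused layer reproduces conv+BN at inference. A quick equivalence check, as a sketch:

    import torch
    import torch.nn as nn

    from ultralytics.utils.torch_utils import fuse_conv_and_bn

    conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16).eval()  # fusion assumes frozen (eval-mode) running statistics
    x = torch.randn(1, 8, 32, 32)
    fused = fuse_conv_and_bn(conv, bn)
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)  # one layer instead of two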
+def fuse_deconv_and_bn(deconv, bn):
+    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
+    fuseddconv = (
+        nn.ConvTranspose2d(
+            deconv.in_channels,
+            deconv.out_channels,
+            kernel_size=deconv.kernel_size,
+            stride=deconv.stride,
+            padding=deconv.padding,
+            output_padding=deconv.output_padding,
+            dilation=deconv.dilation,
+            groups=deconv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(deconv.weight.device)
+    )
+
+    # Prepare filters
+    w_deconv = deconv.weight.view(deconv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
+
+    # Prepare spatial bias
+    b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fuseddconv
+
+
+def model_info(model, detailed=False, verbose=True, imgsz=640):
+    """
+    Print and return detailed model information layer by layer.
+
+    Args:
+        model (nn.Module): Model to analyze.
+        detailed (bool, optional): Whether to print detailed layer information. Defaults to False.
+        verbose (bool, optional): Whether to print model information. Defaults to True.
+        imgsz (int | List, optional): Input image size. Defaults to 640.
+
+    Returns:
+        (Tuple[int, int, int, float]): Number of layers, parameters, gradients, and GFLOPs.
+    """
+    if not verbose:
+        return
+    n_p = get_num_params(model)  # number of parameters
+    n_g = get_num_gradients(model)  # number of gradients
+    layers = __import__("collections").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)
+    n_l = len(layers)  # number of layers
+    if detailed:
+        h = f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}"
+        LOGGER.info(h)
+        for i, (mn, m) in enumerate(layers.items()):
+            mn = mn.replace("module_list.", "")
+            mt = m.__class__.__name__
+            if len(m._parameters):
+                for pn, p in m.named_parameters():
+                    LOGGER.info(
+                        f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
+                    )
+            else:  # layers with no learnable params
+                LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")
+
+    flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
+    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
+    fs = f", {flops:.1f} GFLOPs" if flops else ""
+    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
+    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
+    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
+    return n_l, n_p, n_g, flops
+
+
+def get_num_params(model):
+    """Return the total number of parameters in a YOLO model."""
+    return sum(x.numel() for x in model.parameters())
+
+
+def get_num_gradients(model):
+    """Return the total number of parameters with gradients in a YOLO model."""
+    return sum(x.numel() for x in model.parameters() if x.requires_grad)
+
+
+def model_info_for_loggers(trainer):
+    """
+    Return model info dict with useful model information.
+
+    Args:
+        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.
+
+    Returns:
+        (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.
+
+    Examples:
+        YOLOv8n info for loggers
+        >>> results = {
+        ...     "model/parameters": 3151904,
+        ...     "model/GFLOPs": 8.746,
+        ...     "model/speed_ONNX(ms)": 41.244,
+        ...     "model/speed_TensorRT(ms)": 3.211,
+        ...     "model/speed_PyTorch(ms)": 18.755,
+        ... }
+    """
+    if trainer.args.profile:  # profile ONNX and TensorRT times
+        from ultralytics.utils.benchmarks import ProfileModels
+
+        results = ProfileModels([trainer.last], device=trainer.device).run()[0]
+        results.pop("model/name")
+    else:  # only return PyTorch times from most recent validation
+        results = {
+            "model/parameters": get_num_params(trainer.model),
+            "model/GFLOPs": round(get_flops(trainer.model), 3),
+        }
+    results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
+    return results
+
+
+def get_flops(model, imgsz=640):
+    """
+    Calculate FLOPs (floating point operations) for a model in billions.
+
+    Attempts two calculation methods: first with a stride-based tensor for efficiency,
+    then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0
+    if thop library is unavailable or calculation fails.
+
+    Args:
+        model (nn.Module): The model to calculate FLOPs for.
+        imgsz (int | List[int], optional): Input image size. Defaults to 640.
+
+    Returns:
+        (float): The model FLOPs in billions.
+    """
+    try:
+        import thop
+    except ImportError:
+        thop = None  # conda support without 'ultralytics-thop' installed
+
+    if not thop:
+        return 0.0  # if not installed return 0.0 GFLOPs
+
+    try:
+        model = de_parallel(model)
+        p = next(model.parameters())
+        if not isinstance(imgsz, list):
+            imgsz = [imgsz, imgsz]  # expand if int/float
+        try:
+            # Method 1: Use stride-based input tensor
+            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
+            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
+            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
+        except Exception:
+            # Method 2: Use actual image size (required for RTDETR models)
+            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
+    except Exception:
+        return 0.0
+
+
+def get_flops_with_torch_profiler(model, imgsz=640):
+    """
+    Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
+
+    Args:
+        model (nn.Module): The model to calculate FLOPs for.
+        imgsz (int | List[int], optional): Input image size. Defaults to 640.
+
+    Returns:
+        (float): The model's FLOPs in billions.
+    """
+    if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
+        return 0.0
+    model = de_parallel(model)
+    p = next(model.parameters())
+    if not isinstance(imgsz, list):
+        imgsz = [imgsz, imgsz]  # expand if int/float
+    try:
+        # Use stride size for input tensor
+        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
+        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        with torch.profiler.profile(with_flops=True) as prof:
+            model(im)
+        flops = sum(x.flops for x in prof.key_averages()) / 1e9
+        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
+    except Exception:
+        # Use actual image size for input tensor (i.e. required for RTDETR models)
+        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+        with torch.profiler.profile(with_flops=True) as prof:
+            model(im)
+        flops = sum(x.flops for x in prof.key_averages()) / 1e9
+    return flops
+
+
+def initialize_weights(model):
+    """Initialize model weights to random values."""
+    for m in model.modules():
+        t = type(m)
+        if t is nn.Conv2d:
+            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        elif t is nn.BatchNorm2d:
+            m.eps = 1e-3
+            m.momentum = 0.03
+        elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
+            m.inplace = True
+
+
+def scale_img(img, ratio=1.0, same_shape=False, gs=32):
+    """
+    Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
+
+    Args:
+        img (torch.Tensor): Input image tensor.
+        ratio (float, optional): Scaling ratio. Defaults to 1.0.
+        same_shape (bool, optional): Whether to maintain the same shape. Defaults to False.
+        gs (int, optional): Grid size for padding. Defaults to 32.
+
+    Returns:
+        (torch.Tensor): Scaled and padded image tensor.
+    """
+    if ratio == 1.0:
+        return img
+    h, w = img.shape[2:]
+    s = (int(h * ratio), int(w * ratio))  # new size
+    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
+    if not same_shape:  # pad/crop img
+        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
+    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
+
+
+def copy_attr(a, b, include=(), exclude=()):
+    """
+    Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
+
+    Args:
+        a (object): Destination object to copy attributes to.
+        b (object): Source object to copy attributes from.
+        include (tuple, optional): Attributes to include. If empty, all attributes are included. Defaults to ().
+        exclude (tuple, optional): Attributes to exclude. Defaults to ().
+    """
+    for k, v in b.__dict__.items():
+        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
+            continue
+        else:
+            setattr(a, k, v)
+
+
+def get_latest_opset():
+    """
+    Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.
+
+    Returns:
+        (int): The ONNX opset version.
+    """
+    if TORCH_1_13:
+        # If PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
+        return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
+    # Otherwise for PyTorch<=1.12 return the corresponding predefined opset
+    version = torch.onnx.producer_version.rsplit(".", 1)[0]  # i.e. '2.3'
+    return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)
+
+
+def intersect_dicts(da, db, exclude=()):
+    """
+    Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
+
+    Args:
+        da (dict): First dictionary.
+        db (dict): Second dictionary.
+        exclude (tuple, optional): Keys to exclude. Defaults to ().
+
+    Returns:
+        (dict): Dictionary of intersecting keys with matching shapes.
+    """
+    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
+
+
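Note: intersect_dicts is the usual helper for partially loading pretrained weights when some layers (e.g. a detection head) changed shape. A minimal sketch; ckpt and model are illustrative placeholders for a loaded checkpoint and a freshly built model:

    from ultralytics.utils.torch_utils import intersect_dicts

    csd = ckpt["model"].float().state_dict()  # placeholder: checkpoint state_dict
    csd = intersect_dicts(csd, model.state_dict())  # keep only keys present with matching shapes
    model.load_state_dict(csd, strict=False)  # mismatched layers keep their fresh initialization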
+def is_parallel(model):
+    """
+    Returns True if model is of type DP or DDP.
+
+    Args:
+        model (nn.Module): Model to check.
+
+    Returns:
+        (bool): True if model is DataParallel or DistributedDataParallel.
+    """
+    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
+
+
+def de_parallel(model):
+    """
+    De-parallelize a model: returns single-GPU model if model is of type DP or DDP.
+
+    Args:
+        model (nn.Module): Model to de-parallelize.
+
+    Returns:
+        (nn.Module): De-parallelized model.
+    """
+    return model.module if is_parallel(model) else model
+
+
+def one_cycle(y1=0.0, y2=1.0, steps=100):
+    """
+    Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
+
+    Args:
+        y1 (float, optional): Initial value. Defaults to 0.0.
+        y2 (float, optional): Final value. Defaults to 1.0.
+        steps (int, optional): Number of steps. Defaults to 100.
+
+    Returns:
+        (function): Lambda function for computing the sinusoidal ramp.
+    """
+    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
+
+
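Note: one_cycle returns a multiplier function of the epoch, which slots directly into torch.optim.lr_scheduler.LambdaLR. A minimal sketch ramping the LR multiplier from 1.0 down to 0.01 over 100 epochs:

    import torch

    from ultralytics.utils.torch_utils import one_cycle

    params = [torch.zeros(1, requires_grad=True)]
    optimizer = torch.optim.SGD(params, lr=0.01)
    lf = one_cycle(1.0, 0.01, steps=100)  # cosine ramp: lf(0) = 1.0, lf(100) = 0.01
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    for epoch in range(100):
        optimizer.step()  # training step would go here
        scheduler.step()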
+def init_seeds(seed=0, deterministic=False):
+    """
+    Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
+
+    Args:
+        seed (int, optional): Random seed. Defaults to 0.
+        deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
+    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
+    if deterministic:
+        if TORCH_2_0:
+            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
+            torch.backends.cudnn.deterministic = True
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+            os.environ["PYTHONHASHSEED"] = str(seed)
+        else:
+            LOGGER.warning("Upgrade to torch>=2.0.0 for deterministic training.")
+    else:
+        unset_deterministic()
+
+
+def unset_deterministic():
+    """Unsets all the configurations applied for deterministic training."""
+    torch.use_deterministic_algorithms(False)
+    torch.backends.cudnn.deterministic = False
+    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
+    os.environ.pop("PYTHONHASHSEED", None)
+
+
+class ModelEMA:
+    """
+    Updated Exponential Moving Average (EMA) implementation.
+
+    Keeps a moving average of everything in the model state_dict (parameters and buffers).
+    For EMA details see References.
+
+    To disable EMA set the `enabled` attribute to `False`.
+
+    Attributes:
+        ema (nn.Module): Copy of the model in evaluation mode.
+        updates (int): Number of EMA updates.
+        decay (function): Decay function that determines the EMA weight.
+        enabled (bool): Whether EMA is enabled.
+
+    References:
+        - https://github.com/rwightman/pytorch-image-models
+        - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        """
+        Initialize EMA for 'model' with given arguments.
+
+        Args:
+            model (nn.Module): Model to create EMA for.
+            decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
+            tau (int, optional): EMA decay time constant. Defaults to 2000.
+            updates (int, optional): Initial number of updates. Defaults to 0.
+        """
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+        self.enabled = True
+
+    def update(self, model):
+        """
+        Update EMA parameters.
+
+        Args:
+            model (nn.Module): Model to update EMA from.
+        """
+        if self.enabled:
+            self.updates += 1
+            d = self.decay(self.updates)
+
+            msd = de_parallel(model).state_dict()  # model state_dict
+            for k, v in self.ema.state_dict().items():
+                if v.dtype.is_floating_point:  # true for FP16 and FP32
+                    v *= d
+                    v += (1 - d) * msd[k].detach()
+                    # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
+
+    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
+        """
+        Updates attributes and saves stripped model with optimizer removed.
+
+        Args:
+            model (nn.Module): Model to update attributes from.
+            include (tuple, optional): Attributes to include. Defaults to ().
+            exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer").
+        """
+        if self.enabled:
+            copy_attr(self.ema, model, include, exclude)
+
+
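Note: a typical training-loop shape for ModelEMA, as a sketch; model, loader, optimizer and compute_loss are illustrative placeholders:

    from ultralytics.utils.torch_utils import ModelEMA

    ema = ModelEMA(model, decay=0.9999, tau=2000)
    for batch in loader:
        loss = compute_loss(model(batch))  # placeholder loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.update(model)  # blend current weights into the EMA copy after each optimizer step
    # validate/export ema.ema (the smoothed copy) rather than the raw model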
+def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
+    """
+    Strip optimizer from 'f' to finalize training, optionally save as 's'.
+
+    Args:
+        f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'.
+        s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
+        updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.
+
+    Returns:
+        (dict): The combined checkpoint dictionary.
+
+    Examples:
+        >>> from pathlib import Path
+        >>> from ultralytics.utils.torch_utils import strip_optimizer
+        >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
+        ...     strip_optimizer(f)
+    """
+    try:
+        x = torch.load(f, map_location=torch.device("cpu"))
+        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
+        assert "model" in x, "'model' missing from checkpoint"
+    except Exception as e:
+        LOGGER.warning(f"Skipping {f}, not a valid Ultralytics model: {e}")
+        return {}
+
+    metadata = {
+        "date": datetime.now().isoformat(),
+        "version": __version__,
+        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+        "docs": "https://docs.ultralytics.com",
+    }
+
+    # Update model
+    if x.get("ema"):
+        x["model"] = x["ema"]  # replace model with EMA
+    if hasattr(x["model"], "args"):
+        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
+    if hasattr(x["model"], "criterion"):
+        x["model"].criterion = None  # strip loss criterion
+    x["model"].half()  # to FP16
+    for p in x["model"].parameters():
+        p.requires_grad = False
+
+    # Update other keys
+    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
+    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
+        x[k] = None
+    x["epoch"] = -1
+    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
+    # x['model'].args = x['train_args']
+
+    # Save
+    combined = {**metadata, **x, **(updates or {})}  # combine dicts (prefer to the right)
+    torch.save(combined, s or f)
+    mb = os.path.getsize(s or f) / 1e6  # file size
+    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
+    return combined
+
+
+def convert_optimizer_state_dict_to_fp16(state_dict):
+    """
+    Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
+
+    Args:
+        state_dict (dict): Optimizer state dictionary.
+
+    Returns:
+        (dict): Converted optimizer state dictionary with FP16 tensors.
+    """
+    for state in state_dict["state"].values():
+        for k, v in state.items():
+            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
+                state[k] = v.half()
+
+    return state_dict
+
+
+@contextmanager
+def cuda_memory_usage(device=None):
+    """
+    Monitor and manage CUDA memory usage.
+
+    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
+    It then yields a dictionary containing memory usage information, which can be updated by the caller.
+    Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
+
+    Args:
+        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.
+
+    Yields:
+        (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
+    """
+    cuda_info = dict(memory=0)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        try:
+            yield cuda_info
+        finally:
+            cuda_info["memory"] = torch.cuda.memory_reserved(device)
+    else:
+        yield cuda_info
+
+
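Note: cuda_memory_usage is read after the with-block exits; the yielded dict reports memory reserved during the block (it stays 0 on CPU-only machines). A minimal sketch:

    import torch

    from ultralytics.utils.torch_utils import cuda_memory_usage

    with cuda_memory_usage() as info:
        if torch.cuda.is_available():
            x = torch.randn(64, 3, 640, 640, device="cuda")
    print(f"reserved: {info['memory'] / 1e9:.3f} GB")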
+def profile_ops(input, ops, n=10, device=None, max_num_obj=0):
+    """
+    Ultralytics speed, memory and FLOPs profiler.
+
+    Args:
+        input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile.
+        ops (nn.Module | List[nn.Module]): Model or list of operations to profile.
+        n (int, optional): Number of iterations to average. Defaults to 10.
+        device (str | torch.device, optional): Device to profile on. Defaults to None.
+        max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0.
+
+    Returns:
+        (list): Profile results for each operation.
+
+    Examples:
+        >>> from ultralytics.utils.torch_utils import profile_ops
+        >>> input = torch.randn(16, 3, 640, 640)
+        >>> m1 = lambda x: x * torch.sigmoid(x)
+        >>> m2 = nn.SiLU()
+        >>> profile_ops(input, [m1, m2], n=100)  # profile over 100 iterations
+    """
+    try:
+        import thop
+    except ImportError:
+        thop = None  # conda support without 'ultralytics-thop' installed
+
+    results = []
+    if not isinstance(device, torch.device):
+        device = select_device(device)
+    LOGGER.info(
+        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
+        f"{'input':>24s}{'output':>24s}"
+    )
+    gc.collect()  # attempt to free unused memory
+    torch.cuda.empty_cache()
+    for x in input if isinstance(input, list) else [input]:
+        x = x.to(device)
+        x.requires_grad = True
+        for m in ops if isinstance(ops, list) else [ops]:
+            m = m.to(device) if hasattr(m, "to") else m  # device
+            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
+            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
+            try:
+                flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
+            except Exception:
+                flops = 0
+
+            try:
+                mem = 0
+                for _ in range(n):
+                    with cuda_memory_usage(device) as cuda_info:
+                        t[0] = time_sync()
+                        y = m(x)
+                        t[1] = time_sync()
+                        try:
+                            (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
+                            t[2] = time_sync()
+                        except Exception:  # no backward method
+                            # print(e)  # for debug
+                            t[2] = float("nan")
+                    mem += cuda_info["memory"] / 1e9  # (GB)
+                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
+                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
+                    if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
+                        with cuda_memory_usage(device) as cuda_info:
+                            torch.randn(
+                                x.shape[0],
+                                max_num_obj,
+                                int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
+                                device=device,
+                                dtype=torch.float32,
+                            )
+                        mem += cuda_info["memory"] / 1e9  # (GB)
+                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
+                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
+                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
+                results.append([p, flops, mem, tf, tb, s_in, s_out])
+            except Exception as e:
+                LOGGER.info(e)
+                results.append(None)
+            finally:
+                gc.collect()  # attempt to free unused memory
+                torch.cuda.empty_cache()
+    return results
+
+
+class EarlyStopping:
+    """
+    Early stopping class that stops training when a specified number of epochs have passed without improvement.
+
+    Attributes:
+        best_fitness (float): Best fitness value observed.
+        best_epoch (int): Epoch where best fitness was observed.
+        patience (int): Number of epochs to wait after fitness stops improving before stopping.
+        possible_stop (bool): Flag indicating if stopping may occur next epoch.
+    """
+
+    def __init__(self, patience=50):
+        """
+        Initialize early stopping object.
+
+        Args:
+            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
+        """
+        self.best_fitness = 0.0  # i.e. mAP
+        self.best_epoch = 0
+        self.patience = patience or float("inf")  # epochs to wait after fitness stops improving to stop
+        self.possible_stop = False  # possible stop may occur next epoch
+
+    def __call__(self, epoch, fitness):
+        """
+        Check whether to stop training.
+
+        Args:
+            epoch (int): Current epoch of training
+            fitness (float): Fitness value of current epoch
+
+        Returns:
+            (bool): True if training should stop, False otherwise
+        """
+        if fitness is None:  # check if fitness=None (happens when val=False)
+            return False
+
+        if fitness > self.best_fitness or self.best_fitness == 0:  # allow for early zero-fitness stage of training
+            self.best_epoch = epoch
+            self.best_fitness = fitness
+        delta = epoch - self.best_epoch  # epochs without improvement
+        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
+        stop = delta >= self.patience  # stop training if patience exceeded
+        if stop:
+            prefix = colorstr("EarlyStopping: ")
+            LOGGER.info(
+                f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
+                f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
+                f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
+                f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
+            )
+        return stop
+
+
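Note: EarlyStopping is called once per epoch with the latest fitness (e.g. mAP) and returns True when patience is exhausted. A minimal sketch; model and validate are illustrative placeholders:

    from ultralytics.utils.torch_utils import EarlyStopping

    stopper = EarlyStopping(patience=50)
    for epoch in range(300):
        fitness = validate(model)  # placeholder returning a scalar fitness such as mAP
        if stopper(epoch, fitness):
            break  # no improvement observed in the last `patience` epochs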
+class FXModel(nn.Module):
+    """
+    A custom model class for torch.fx compatibility.
+
+    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
+    manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
+    copying.
+
+    Attributes:
+        model (nn.Module): The original model's layers.
+    """
+
+    def __init__(self, model):
+        """
+        Initialize the FXModel.
+
+        Args:
+            model (nn.Module): The original model to wrap for torch.fx compatibility.
+        """
+        super().__init__()
+        copy_attr(self, model)
+        # Explicitly set `model` since `copy_attr` somehow does not copy it.
+        self.model = model.model
+
+    def forward(self, x):
+        """
+        Forward pass through the model.
+
+        This method performs the forward pass through the model, handling the dependencies between layers and saving
+        intermediate outputs.
+
+        Args:
+            x (torch.Tensor): The input tensor to the model.
+
+        Returns:
+            (torch.Tensor): The output tensor from the model.
+        """
+        y = []  # outputs
+        for m in self.model:
+            if m.f != -1:  # if not from previous layer
+                # from earlier layers
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
+            x = m(x)  # run
+            y.append(x)  # save output
+        return x