dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272) hide show
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,106 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+ """Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
3
+
4
+ import os
5
+ from copy import deepcopy
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
11
+ from ultralytics.utils.torch_utils import autocast, profile_ops
12
+
13
+
14
def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
    """
    Estimate the optimal YOLO training batch size by delegating to autobatch().

    A deep copy of ``model`` is switched to train mode and profiled, so the caller's model is left untouched.

    Args:
        model (torch.nn.Module): YOLO model whose training batch size should be estimated.
        imgsz (int, optional): Image size used for training.
        amp (bool, optional): Profile under automatic mixed precision when True.
        batch (float, optional): Requested fraction of GPU memory. Only values strictly between 0.0 and 1.0 are
            honoured; anything else (including the default -1) falls back to a fraction of 0.6.
        max_num_obj (int, optional): The maximum number of objects from dataset, forwarded to autobatch().

    Returns:
        (int): Batch size estimated by autobatch().
    """
    with autocast(enabled=amp):
        training_copy = deepcopy(model).train()  # profile a copy so the caller's model state is not altered
        use_requested = 0.0 < batch < 1.0
        memory_fraction = batch if use_requested else 0.6
        return autobatch(training_copy, imgsz, fraction=memory_fraction, max_num_obj=max_num_obj)
36
+
37
+
38
def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1):
    """
    Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.

    The model is profiled at a ladder of batch sizes, a straight line (memory vs. batch size) is fitted to the valid
    measurements, and that line is solved for the batch size whose predicted memory equals ``fraction`` of the free
    CUDA memory.

    Args:
        model (torch.nn.Module): YOLO model to compute batch size for.
        imgsz (int, optional): The image size used as input for the YOLO model.
        fraction (float, optional): The fraction of available CUDA memory to use.
        batch_size (int, optional): The default batch size to use if an error is detected.
        max_num_obj (int, optional): The maximum number of objects from dataset.

    Returns:
        (int): The optimal batch size. ``batch_size`` is returned instead when the model is not on a CUDA device,
            when ``torch.backends.cudnn.benchmark`` is enabled, when the estimate falls outside [1, 1024], or when
            any error occurs while profiling or fitting.
    """
    # Check device
    prefix = colorstr("AutoBatch: ")
    LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
    device = next(model.parameters()).device  # get model device
    if device.type != "cuda":  # cpu, mps and any other backend cannot be inspected through torch.cuda below
        LOGGER.warning(f"{prefix}intended for CUDA devices, using default batch-size {batch_size}")
        return batch_size
    if torch.backends.cudnn.benchmark:
        LOGGER.warning(f"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
        return batch_size

    # Inspect CUDA memory
    gb = 1 << 30  # bytes to GiB (1024 ** 3)
    visible = os.getenv("CUDA_VISIBLE_DEVICES", "0").strip() or "0"  # guard against an empty variable
    d = f"CUDA:{visible[0]}"  # label used only in log messages, e.g. 'CUDA:0'
    properties = torch.cuda.get_device_properties(device)  # device properties
    t = properties.total_memory / gb  # GiB total
    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
    # NOTE(review): PyTorch's reserved memory is documented to include allocated memory, so this free estimate is
    # conservative (allocated is subtracted twice); kept as-is to preserve existing batch-size behaviour.
    f = t - (r + a)  # GiB free
    LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")

    # Profile batch sizes
    batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
    try:
        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
        results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)

        # Fit a solution. y[2] is compared against the total GiB `t`, so it is presumably the profiled memory in GiB.
        xy = [
            [x, y[2]]
            for i, (x, y) in enumerate(zip(batch_sizes, results))
            if y  # valid result
            and isinstance(y[2], (int, float))  # is numeric
            and 0 < y[2] < t  # between 0 and GPU limit
            and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2])  # first item or increasing memory
        ]
        fit_x, fit_y = zip(*xy) if xy else ([], [])  # empty input makes polyfit raise -> default batch size below
        p = np.polyfit(fit_x, fit_y, deg=1)  # linear fit in linear space: memory = p[0] * batch + p[1]
        # Solve the fit for the target memory. The target is NOT rounded to whole GiB: rounding skewed the estimate by
        # up to 0.5 GiB and could collapse small targets to 0, forcing the default-batch fallback.
        b = int((f * fraction - p[1]) / p[0])  # optimal batch size
        if None in results:  # some sizes failed
            i = results.index(None)  # first fail index
            if b >= batch_sizes[i]:  # y intercept above failure point
                b = batch_sizes[max(i - 1, 0)]  # select prior safe point
        if b < 1 or b > 1024:  # b outside of safe range
            LOGGER.warning(f"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.")
            b = batch_size

        fraction = (np.polyval(p, b) + r + a) / t  # predicted fraction
        LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
        return b
    except Exception as e:
        LOGGER.warning(f"{prefix}error detected: {e}, using default batch-size {batch_size}.")
        return batch_size
    finally:
        torch.cuda.empty_cache()
@@ -0,0 +1,174 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from ultralytics.utils import LOGGER
4
+ from ultralytics.utils.checks import check_requirements
5
+
6
+
7
class GPUInfo:
    """
    Manages NVIDIA GPU information via pynvml with robust error handling.

    Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle
    GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml
    library by logging warnings and disabling related features, preventing application crashes.

    Each per-device NVML query is guarded individually. A metric that cannot be read is stored as -1 (the name as
    'N/A') instead of discarding the device. A device whose handle cannot be obtained is skipped with a warning, while
    the remaining devices are still reported. There is no `torch.cuda` fallback: if NVML is unavailable the stats
    list is empty and `select_idle_gpu()` returns an empty list. NVML initialization and shutdown are managed
    internally.

    Attributes:
        pynvml (module | None): The `pynvml` module if successfully imported, otherwise `None`.
        nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,
            False otherwise (and after `shutdown()`).
        gpu_stats (list[dict]): A list of dictionaries, each holding stats for one GPU. Populated on initialization
            and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%), 'memory_used' (MiB),
            'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W), 'power_limit' (W).
            Numeric values are -1 when the corresponding query fails. Empty if NVML is unavailable.
    """

    def __init__(self):
        """Initialize GPUInfo, attempting to import and initialize pynvml; failures are logged, never raised."""
        self.pynvml = None
        self.nvml_available = False
        self.gpu_stats = []

        try:
            check_requirements("pynvml>=12.0.0")
            self.pynvml = __import__("pynvml")
            self.pynvml.nvmlInit()
            self.nvml_available = True
            self.refresh_stats()
        except Exception as e:
            LOGGER.warning(f"Failed to initialize pynvml, GPU stats disabled: {e}")

    def __del__(self):
        """Ensure NVML is shut down when the object is garbage collected."""
        self.shutdown()

    def shutdown(self):
        """Shut down NVML if it was initialized; safe to call repeatedly or on a partially-initialized instance."""
        if getattr(self, "nvml_available", False) and getattr(self, "pynvml", None):
            try:
                self.pynvml.nvmlShutdown()
            except Exception:
                pass  # best-effort cleanup (this also runs from __del__), so errors are deliberately ignored
            self.nvml_available = False

    def refresh_stats(self):
        """Refresh the internal gpu_stats list by querying NVML, skipping only devices that cannot be queried."""
        self.gpu_stats = []
        if not self.nvml_available or not self.pynvml:
            return

        try:
            device_count = self.pynvml.nvmlDeviceGetCount()
        except Exception as e:
            LOGGER.warning(f"Error during device query: {e}")
            return

        for i in range(device_count):
            # One bad device must not hide the stats of the others
            try:
                self.gpu_stats.append(self._get_device_stats(i))
            except Exception as e:
                LOGGER.warning(f"Error during device query for GPU {i}: {e}")

    def _get_device_stats(self, index):
        """
        Get stats for a single GPU device.

        Args:
            index (int): NVML device index.

        Returns:
            (dict): Stats dictionary (see class docstring); individual metrics are -1 when unavailable.

        Raises:
            Exception: Propagates the pynvml error only if the device handle itself cannot be obtained.
        """
        nvml = self.pynvml
        handle = nvml.nvmlDeviceGetHandleByIndex(index)

        def safe_get(func, *args, default=-1, divisor=1):
            """Call an NVML getter, returning `default` on any error and integer-dividing numeric results."""
            try:
                val = func(*args)
                return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val
            except Exception:
                return default

        memory = safe_get(nvml.nvmlDeviceGetMemoryInfo, handle, default=None)
        util = safe_get(nvml.nvmlDeviceGetUtilizationRates, handle, default=None)
        name = safe_get(nvml.nvmlDeviceGetName, handle, default="N/A")
        if isinstance(name, bytes):  # some pynvml versions return bytes; keep it a str for formatting
            name = name.decode("utf-8", errors="replace")
        temp_type = getattr(nvml, "NVML_TEMPERATURE_GPU", -1)

        return {
            "index": index,
            "name": name,
            "utilization": util.gpu if util is not None else -1,
            "memory_used": memory.used >> 20 if memory is not None else -1,  # bytes -> MiB
            "memory_total": memory.total >> 20 if memory is not None else -1,
            "memory_free": memory.free >> 20 if memory is not None else -1,
            "temperature": safe_get(nvml.nvmlDeviceGetTemperature, handle, temp_type),
            "power_draw": safe_get(nvml.nvmlDeviceGetPowerUsage, handle, divisor=1000),  # mW -> W
            "power_limit": safe_get(nvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),  # mW -> W
        }

    def print_status(self):
        """Print GPU status in a compact table format using freshly refreshed stats."""
        self.refresh_stats()
        if not self.gpu_stats:
            LOGGER.warning("No GPU stats available.")
            return

        stats = self.gpu_stats
        name_len = max(len(gpu.get("name", "N/A")) for gpu in stats)
        hdr = f"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}"
        LOGGER.info(f"\n--- GPU Status ---\n{hdr}\n{'-' * len(hdr)}")

        for gpu in stats:
            u = f"{gpu['utilization']:>5}%" if gpu["utilization"] >= 0 else " N/A "
            m = f"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}" if gpu["memory_used"] >= 0 else " N/A / N/A "
            t = f"{gpu['temperature']}C" if gpu["temperature"] >= 0 else " N/A "
            p = f"{gpu['power_draw']:>3}/{gpu['power_limit']:<3}" if gpu["power_draw"] >= 0 else " N/A "

            LOGGER.info(f"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}")

        LOGGER.info(f"{'-' * len(hdr)}\n")

    def select_idle_gpu(self, count=1, min_memory_mb=0):
        """
        Select the 'count' most idle GPUs based on utilization and free memory.

        Args:
            count (int): The number of idle GPUs to select. Defaults to 1.
            min_memory_mb (int): Minimum free memory required (MiB). Defaults to 0.

        Returns:
            (list[int]): Indices of the selected GPUs, sorted by lowest utilization then most free memory.

        Notes:
            Returns fewer than 'count' if not enough qualify or exist. GPUs whose utilization could not be read are
            never selected. Returns an empty list if count <= 0 or NVML stats are unavailable.
        """
        LOGGER.info(f"Searching for {count} idle GPUs with >= {min_memory_mb} MiB free memory...")

        if count <= 0:
            return []

        self.refresh_stats()
        if not self.gpu_stats:
            LOGGER.warning("NVML stats unavailable.")
            return []

        # Filter and sort eligible GPUs
        eligible_gpus = [
            gpu
            for gpu in self.gpu_stats
            if gpu.get("memory_free", -1) >= min_memory_mb and gpu.get("utilization", -1) != -1
        ]
        eligible_gpus.sort(key=lambda x: (x.get("utilization", 101), -x.get("memory_free", 0)))

        # Select top 'count' indices
        selected = [gpu["index"] for gpu in eligible_gpus[:count]]

        if selected:
            LOGGER.info(f"Selected idle CUDA devices {selected}")
        else:
            LOGGER.warning(f"No GPUs met criteria (Util != -1, Free Mem >= {min_memory_mb} MiB).")

        return selected
161
+
162
+
163
if __name__ == "__main__":
    # Demo: print a status table, then pick the single most idle GPU with at least 2048 MiB (2 GiB) of free VRAM.
    MIN_FREE_MIB = 2048
    WANTED_GPUS = 1

    info = GPUInfo()
    info.print_status()

    chosen = info.select_idle_gpu(count=WANTED_GPUS, min_memory_mb=MIN_FREE_MIB)
    if chosen:
        print(f"\n==> Using selected GPU indices: {chosen}")
        torch_devices = list(map("cuda:{}".format, chosen))
        print(f" Target devices: {torch_devices}")