ultralytics-opencv-headless 8.3.246__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298) hide show
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1578 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +313 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +1006 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +501 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1563 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.246.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.246.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.246.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.246.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.246.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.246.dist-info/top_level.txt +1 -0
@@ -0,0 +1,118 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+ """Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import os
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
13
+ from ultralytics.utils.torch_utils import autocast, profile_ops
14
+
15
+
16
def check_train_batch_size(
    model: torch.nn.Module,
    imgsz: int = 640,
    amp: bool = True,
    batch: int | float = -1,
    max_num_obj: int = 1,
) -> int:
    """Estimate the best YOLO training batch size by delegating to `autobatch()`.

    Args:
        model (torch.nn.Module): YOLO model to size. A deep copy is profiled in training mode, so the caller's model
            is not modified.
        imgsz (int, optional): Image size forwarded to `autobatch()`.
        amp (bool, optional): Whether profiling runs inside `autocast` (automatic mixed precision).
        batch (int | float, optional): Requested fraction of GPU memory. Only values strictly between 0.0 and 1.0 are
            honored; anything else (including the default -1) selects 0.6.
        max_num_obj (int, optional): The maximum number of objects from dataset, forwarded to `autobatch()`.

    Returns:
        (int): Batch size chosen by `autobatch()`.
    """
    # A value inside the open interval (0, 1) is a memory fraction; integers such as -1 fall back to 60%
    memory_fraction = batch if 0.0 < batch < 1.0 else 0.6
    with autocast(enabled=amp):
        train_model = deepcopy(model).train()  # profile a training-mode copy, never the caller's instance
        return autobatch(train_model, imgsz, fraction=memory_fraction, max_num_obj=max_num_obj)
43
+
44
+
45
def autobatch(
    model: torch.nn.Module,
    imgsz: int = 640,
    fraction: float = 0.60,
    batch_size: int = DEFAULT_CFG.batch,
    max_num_obj: int = 1,
) -> int:
    """Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.

    The model is profiled at several batch sizes, a straight line (memory vs. batch size) is fitted to the usable
    measurements, and the line is solved for the batch size whose predicted memory equals `fraction` of the free
    CUDA memory.

    Args:
        model (torch.nn.Module): YOLO model to compute batch size for.
        imgsz (int, optional): The image size used as input for the YOLO model.
        fraction (float, optional): The fraction of available CUDA memory to use.
        batch_size (int, optional): The default batch size returned when no estimate can be made.
        max_num_obj (int, optional): The maximum number of objects from dataset, forwarded to `profile_ops`.

    Returns:
        (int): The optimal batch size, or `batch_size` when the model is on CPU/MPS, cudnn benchmark mode is enabled,
            profiling raises or yields fewer than two usable points, or the estimate falls outside [1, 1024].
    """
    # Check device
    prefix = colorstr("AutoBatch: ")
    LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
    device = next(model.parameters()).device  # get model device
    if device.type in {"cpu", "mps"}:
        LOGGER.warning(f"{prefix}intended for CUDA devices, using default batch-size {batch_size}")
        return batch_size
    if torch.backends.cudnn.benchmark:
        LOGGER.warning(f"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
        return batch_size

    # Inspect CUDA memory
    gb = 1 << 30  # bytes to GiB (1024 ** 3)
    # Display label only: first CUDA_VISIBLE_DEVICES entry. Splitting on ',' keeps multi-digit IDs such as '10' whole,
    # and an empty/blank value falls back to '0' instead of raising IndexError outside the try block below.
    first_visible = os.getenv("CUDA_VISIBLE_DEVICES", "0").split(",")[0].strip() or "0"
    d = f"CUDA:{first_visible}"  # 'CUDA:0'
    properties = torch.cuda.get_device_properties(device)  # device properties
    t = properties.total_memory / gb  # GiB total
    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
    # NOTE(review): PyTorch counts allocated memory inside the reserved pool, so r + a may double-count allocated
    # memory; kept as-is because it only makes the free-memory estimate more conservative.
    f = t - (r + a)  # GiB free
    LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")

    # Profile batch sizes
    batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
    try:
        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
        results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)

        # Keep (batch size, memory) pairs whose memory value (y[2], compared against total GiB) is a number inside the
        # GPU limit and grows relative to the previous measurement
        xy = [
            [x, y[2]]
            for i, (x, y) in enumerate(zip(batch_sizes, results))
            if y  # valid result
            and isinstance(y[2], (int, float))  # is numeric
            and 0 < y[2] < t  # between 0 and GPU limit
            and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2])  # first item or increasing memory
        ]
        if len(xy) < 2:  # a first-degree fit is underdetermined with fewer than two points
            raise ValueError(f"only {len(xy)} usable memory measurement(s), at least 2 are required for a fit")
        fit_x, fit_y = zip(*xy)
        p = np.polyfit(fit_x, fit_y, deg=1)  # linear fit: memory (GiB) ~= p[0] * batch + p[1]
        if p[0] <= 0:  # memory must grow with batch size for the solve below to be meaningful
            raise ValueError(f"non-increasing memory fit (slope={p[0]:.4g})")
        # Solve the fitted line for the target budget; the budget is NOT rounded to whole GiB, which previously let
        # the estimate overshoot the requested fraction or collapse to zero on small GPUs
        b = int((f * fraction - p[1]) / p[0])
        if None in results:  # some sizes failed
            i = results.index(None)  # first fail index
            if b >= batch_sizes[i]:  # y intercept above failure point
                b = batch_sizes[max(i - 1, 0)]  # select prior safe point
        if b < 1 or b > 1024:  # b outside of safe range
            LOGGER.warning(f"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.")
            b = batch_size

        used_fraction = (np.polyval(p, b) + r + a) / t  # predicted fraction of total memory
        LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * used_fraction:.2f}G/{t:.2f}G ({used_fraction * 100:.0f}%) ✅")
        return b
    except Exception as e:
        LOGGER.warning(f"{prefix}error detected: {e}, using default batch-size {batch_size}.")
        return batch_size
    finally:
        torch.cuda.empty_cache()
@@ -0,0 +1,205 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from ultralytics.utils import LOGGER
8
+ from ultralytics.utils.checks import check_requirements
9
+
10
+
11
class GPUInfo:
    """Manage NVIDIA GPU information via pynvml with robust error handling.

    Provides methods to query GPU statistics (utilization, memory, temperature, power) and to select the most idle
    GPUs based on configurable criteria. If the pynvml library is missing or fails to initialize, a warning is logged
    and NVML-backed features are disabled instead of raising, so the application keeps running.

    NVML is initialized in `__init__` and shut down by `shutdown()`, which `__del__` also calls. There is no
    alternative backend: when NVML is unavailable, `gpu_stats` stays empty and `select_idle_gpu()` returns [].

    Attributes:
        pynvml (module | None): The `pynvml` module if the import succeeded (it stays set even if `nvmlInit()` then
            failed), otherwise `None`.
        nvml_available (bool): True if import and `nvmlInit()` succeeded and `shutdown()` has not been called.
        gpu_stats (list[dict[str, Any]]): One dictionary per GPU, populated on initialization and by
            `refresh_stats()`. Keys: 'index', 'name', 'utilization' (%), 'memory_used' (MiB), 'memory_total' (MiB),
            'memory_free' (MiB), 'temperature' (C), 'power_draw' (W), 'power_limit' (W). Numeric values that could
            not be queried are -1. Empty if NVML is unavailable or device enumeration fails.

    Methods:
        refresh_stats: Refresh the internal gpu_stats list by querying NVML.
        print_status: Log GPU status in a compact table format using current stats.
        select_idle_gpu: Select the most idle GPUs based on utilization and free memory.
        shutdown: Shut down NVML if it was initialized.

    Examples:
        Initialize GPUInfo and print status
        >>> gpu_info = GPUInfo()
        >>> gpu_info.print_status()

        Select idle GPUs with minimum memory requirements
        >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)
        >>> print(f"Selected GPU indices: {selected}")
    """

    def __init__(self):
        """Initialize GPUInfo, attempting to import and initialize pynvml and collect initial stats."""
        self.pynvml: Any | None = None
        self.nvml_available: bool = False
        self.gpu_stats: list[dict[str, Any]] = []

        try:
            check_requirements("nvidia-ml-py>=12.0.0")
            self.pynvml = __import__("pynvml")
            self.pynvml.nvmlInit()
            self.nvml_available = True
            self.refresh_stats()
        except Exception as e:
            LOGGER.warning(f"Failed to initialize pynvml, GPU stats disabled: {e}")

    def __del__(self):
        """Ensure NVML is shut down when the object is garbage collected."""
        self.shutdown()

    def shutdown(self):
        """Shut down NVML if it was initialized; errors during shutdown are ignored (best-effort teardown)."""
        if self.nvml_available and self.pynvml:
            try:
                self.pynvml.nvmlShutdown()
            except Exception:
                pass  # deliberate: shutdown also runs from __del__, where raising would only produce noise
        self.nvml_available = False

    def refresh_stats(self):
        """Refresh the internal gpu_stats list by querying NVML; leaves it empty if enumeration fails."""
        self.gpu_stats = []
        if not self.nvml_available or not self.pynvml:
            return

        try:
            device_count = self.pynvml.nvmlDeviceGetCount()
            self.gpu_stats.extend(self._get_device_stats(i) for i in range(device_count))
        except Exception as e:
            LOGGER.warning(f"Error during device query: {e}")
            self.gpu_stats = []

    def _get_device_stats(self, index: int) -> dict[str, Any]:
        """Get stats for a single GPU device.

        Only the device handle lookup may raise. Every other query is best-effort, so one unsupported NVML call
        (e.g. utilization on a particular device) no longer discards the stats of all GPUs; failed numeric queries
        are reported as -1.
        """
        handle = self.pynvml.nvmlDeviceGetHandleByIndex(index)

        def safe_get(func, *args, default=-1, divisor=1):
            """Call an NVML query, returning `default` on any error and floor-dividing numeric results by `divisor`."""
            try:
                val = func(*args)
                return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val
            except Exception:
                return default

        memory = safe_get(self.pynvml.nvmlDeviceGetMemoryInfo, handle, default=None)
        util = safe_get(self.pynvml.nvmlDeviceGetUtilizationRates, handle, default=None)
        name = safe_get(self.pynvml.nvmlDeviceGetName, handle, default="N/A")
        if isinstance(name, bytes):  # normalize to str so print_status can width-format it
            name = name.decode("utf-8", errors="replace")
        temp_type = getattr(self.pynvml, "NVML_TEMPERATURE_GPU", -1)

        return {
            "index": index,
            "name": name,
            "utilization": util.gpu if util is not None else -1,
            "memory_used": memory.used >> 20 if memory is not None else -1,  # Convert bytes to MiB
            "memory_total": memory.total >> 20 if memory is not None else -1,
            "memory_free": memory.free >> 20 if memory is not None else -1,
            "temperature": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type),
            "power_draw": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000),  # Convert mW to W
            "power_limit": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),
        }

    def print_status(self):
        """Log GPU status in a compact table format after refreshing stats."""
        self.refresh_stats()
        if not self.gpu_stats:
            LOGGER.warning("No GPU stats available.")
            return

        stats = self.gpu_stats
        name_len = max(len(gpu.get("name", "N/A")) for gpu in stats)
        hdr = f"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}"
        LOGGER.info(f"\n--- GPU Status ---\n{hdr}\n{'-' * len(hdr)}")

        for gpu in stats:
            u = f"{gpu['utilization']:>5}%" if gpu["utilization"] >= 0 else " N/A "
            m = f"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}" if gpu["memory_used"] >= 0 else " N/A / N/A "
            t = f"{gpu['temperature']}C" if gpu["temperature"] >= 0 else " N/A "
            limit = gpu["power_limit"] if gpu["power_limit"] >= 0 else "N/A"  # show N/A instead of the -1 sentinel
            p = f"{gpu['power_draw']:>3}/{limit:<3}" if gpu["power_draw"] >= 0 else " N/A "

            LOGGER.info(f"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}")

        LOGGER.info(f"{'-' * len(hdr)}\n")

    def select_idle_gpu(
        self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0
    ) -> list[int]:
        """Select the most idle GPUs based on utilization and free memory.

        Args:
            count (int): The number of idle GPUs to select.
            min_memory_fraction (float): Minimum free memory required as a fraction of total memory.
            min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0.

        Returns:
            (list[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first, then most free
                memory).

        Notes:
            Returns fewer than 'count' if not enough qualify or exist.
            Returns an empty list if count <= 0, NVML stats are unavailable, or no GPU qualifies.
            GPUs whose utilization or memory could not be queried are never selected, since the criteria cannot be
            verified for them.
        """
        assert min_memory_fraction <= 1.0, f"min_memory_fraction must be <= 1.0, got {min_memory_fraction}"
        assert min_util_fraction <= 1.0, f"min_util_fraction must be <= 1.0, got {min_util_fraction}"
        if count <= 0:
            return []

        criteria = (
            f"free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%"
        )
        LOGGER.info(f"Searching for {count} idle GPUs with {criteria}...")

        self.refresh_stats()
        if not self.gpu_stats:
            LOGGER.warning("NVML stats unavailable.")
            return []

        def is_eligible(gpu: dict[str, Any]) -> bool:
            """Return True if the GPU's known stats satisfy both idle criteria."""
            free, total, util = gpu.get("memory_free", -1), gpu.get("memory_total", -1), gpu.get("utilization", -1)
            if free < 0 or total <= 0 or util < 0:
                return False  # unknown (-1) stats must not be mistaken for an idle GPU; also avoids dividing by 0
            return free / total >= min_memory_fraction and (100 - util) >= min_util_fraction * 100

        # Filter and sort eligible GPUs
        eligible_gpus = [gpu for gpu in self.gpu_stats if is_eligible(gpu)]
        eligible_gpus.sort(key=lambda x: (x["utilization"], -x["memory_free"]))

        # Select top 'count' indices
        selected = [gpu["index"] for gpu in eligible_gpus[:count]]

        if selected:
            if len(selected) < count:
                LOGGER.warning(f"Requested {count} GPUs but only {len(selected)} met the idle criteria.")
            LOGGER.info(f"Selected idle CUDA devices {selected}")
        else:
            LOGGER.warning(f"No GPUs met criteria ({criteria}).")

        return selected
188
+
189
+
190
if __name__ == "__main__":
    # Demo: log every GPU's status, then ask for the single most idle GPU that has at least 20% of its VRAM free and
    # at least 20% of its utilization headroom free.
    required_free_mem_fraction = 0.2  # Require 20% free VRAM
    required_free_util_fraction = 0.2  # Require 20% free utilization
    num_gpus_to_select = 1

    gpu_info = GPUInfo()
    gpu_info.print_status()

    selected = gpu_info.select_idle_gpu(
        count=num_gpus_to_select,
        min_memory_fraction=required_free_mem_fraction,
        min_util_fraction=required_free_util_fraction,
    )
    if selected:
        print(f"\n==> Using selected GPU indices: {selected}")
        target_devices = [f"cuda:{gpu_index}" for gpu_index in selected]
        print(f" Target devices: {target_devices}")