dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,8 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
  """Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
3
3
 
4
+ from __future__ import annotations
5
+
4
6
  import os
5
7
  from copy import deepcopy
6
8
 
@@ -11,15 +13,20 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
11
13
  from ultralytics.utils.torch_utils import autocast, profile_ops
12
14
 
13
15
 
14
- def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
15
- """
16
- Compute optimal YOLO training batch size using the autobatch() function.
16
+ def check_train_batch_size(
17
+ model: torch.nn.Module,
18
+ imgsz: int = 640,
19
+ amp: bool = True,
20
+ batch: int | float = -1,
21
+ max_num_obj: int = 1,
22
+ ) -> int:
23
+ """Compute optimal YOLO training batch size using the autobatch() function.
17
24
 
18
25
  Args:
19
26
  model (torch.nn.Module): YOLO model to check batch size for.
20
27
  imgsz (int, optional): Image size used for training.
21
28
  amp (bool, optional): Use automatic mixed precision if True.
22
- batch (float, optional): Fraction of GPU memory to use. If -1, use default.
29
+ batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.
23
30
  max_num_obj (int, optional): The maximum number of objects from dataset.
24
31
 
25
32
  Returns:
@@ -35,9 +42,14 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
35
42
  )
36
43
 
37
44
 
38
- def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1):
39
- """
40
- Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
45
+ def autobatch(
46
+ model: torch.nn.Module,
47
+ imgsz: int = 640,
48
+ fraction: float = 0.60,
49
+ batch_size: int = DEFAULT_CFG.batch,
50
+ max_num_obj: int = 1,
51
+ ) -> int:
52
+ """Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
41
53
 
42
54
  Args:
43
55
  model (torch.nn.Module): YOLO model to compute batch size for.
@@ -1,38 +1,56 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
3
7
  from ultralytics.utils import LOGGER
4
8
  from ultralytics.utils.checks import check_requirements
5
9
 
6
10
 
7
11
  class GPUInfo:
8
- """
9
- Manages NVIDIA GPU information via pynvml with robust error handling.
12
+ """Manages NVIDIA GPU information via pynvml with robust error handling.
10
13
 
11
- Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle
12
- GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml
13
- library by logging warnings and disabling related features, preventing application crashes.
14
+ Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle GPUs
15
+ based on configurable criteria. It safely handles the absence or initialization failure of the pynvml library by
16
+ logging warnings and disabling related features, preventing application crashes.
14
17
 
15
18
  Includes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU
16
19
  selection. Manages NVML initialization and shutdown internally.
17
20
 
18
21
  Attributes:
19
22
  pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.
20
- nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,
21
- False otherwise.
22
- gpu_stats (list[dict]): A list of dictionaries, each holding stats for one GPU. Populated on initialization
23
- and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%), 'memory_used' (MiB),
24
- 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),
25
- 'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail.
23
+ nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded, False
24
+ otherwise.
25
+ gpu_stats (list[dict[str, Any]]): A list of dictionaries, each holding stats for one GPU, populated on
26
+ initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%), 'memory_used' (MiB),
27
+ 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W), 'power_limit' (W or 'N/A').
28
+ Empty if NVML is unavailable or queries fail.
29
+
30
+ Methods:
31
+ refresh_stats: Refresh the internal gpu_stats list by querying NVML.
32
+ print_status: Print GPU status in a compact table format using current stats.
33
+ select_idle_gpu: Select the most idle GPUs based on utilization and free memory.
34
+ shutdown: Shut down NVML if it was initialized.
35
+
36
+ Examples:
37
+ Initialize GPUInfo and print status
38
+ >>> gpu_info = GPUInfo()
39
+ >>> gpu_info.print_status()
40
+
41
+ Select idle GPUs with minimum memory requirements
42
+ >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)
43
+ >>> print(f"Selected GPU indices: {selected}")
26
44
  """
27
45
 
28
46
  def __init__(self):
29
- """Initializes GPUInfo, attempting to import and initialize pynvml."""
30
- self.pynvml = None
31
- self.nvml_available = False
32
- self.gpu_stats = []
47
+ """Initialize GPUInfo, attempting to import and initialize pynvml."""
48
+ self.pynvml: Any | None = None
49
+ self.nvml_available: bool = False
50
+ self.gpu_stats: list[dict[str, Any]] = []
33
51
 
34
52
  try:
35
- check_requirements("pynvml>=12.0.0")
53
+ check_requirements("nvidia-ml-py>=12.0.0")
36
54
  self.pynvml = __import__("pynvml")
37
55
  self.pynvml.nvmlInit()
38
56
  self.nvml_available = True
@@ -41,11 +59,11 @@ class GPUInfo:
41
59
  LOGGER.warning(f"Failed to initialize pynvml, GPU stats disabled: {e}")
42
60
 
43
61
  def __del__(self):
44
- """Ensures NVML is shut down when the object is garbage collected."""
62
+ """Ensure NVML is shut down when the object is garbage collected."""
45
63
  self.shutdown()
46
64
 
47
65
  def shutdown(self):
48
- """Shuts down NVML if it was initialized."""
66
+ """Shut down NVML if it was initialized."""
49
67
  if self.nvml_available and self.pynvml:
50
68
  try:
51
69
  self.pynvml.nvmlShutdown()
@@ -54,21 +72,20 @@ class GPUInfo:
54
72
  self.nvml_available = False
55
73
 
56
74
  def refresh_stats(self):
57
- """Refreshes the internal gpu_stats list by querying NVML."""
75
+ """Refresh the internal gpu_stats list by querying NVML."""
58
76
  self.gpu_stats = []
59
77
  if not self.nvml_available or not self.pynvml:
60
78
  return
61
79
 
62
80
  try:
63
81
  device_count = self.pynvml.nvmlDeviceGetCount()
64
- for i in range(device_count):
65
- self.gpu_stats.append(self._get_device_stats(i))
82
+ self.gpu_stats.extend(self._get_device_stats(i) for i in range(device_count))
66
83
  except Exception as e:
67
84
  LOGGER.warning(f"Error during device query: {e}")
68
85
  self.gpu_stats = []
69
86
 
70
- def _get_device_stats(self, index):
71
- """Gets stats for a single GPU device."""
87
+ def _get_device_stats(self, index: int) -> dict[str, Any]:
88
+ """Get stats for a single GPU device."""
72
89
  handle = self.pynvml.nvmlDeviceGetHandleByIndex(index)
73
90
  memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)
74
91
  util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)
@@ -86,16 +103,16 @@ class GPUInfo:
86
103
  "index": index,
87
104
  "name": self.pynvml.nvmlDeviceGetName(handle),
88
105
  "utilization": util.gpu if util else -1,
89
- "memory_used": memory.used >> 20 if memory else -1,
106
+ "memory_used": memory.used >> 20 if memory else -1, # Convert bytes to MiB
90
107
  "memory_total": memory.total >> 20 if memory else -1,
91
108
  "memory_free": memory.free >> 20 if memory else -1,
92
109
  "temperature": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type),
93
- "power_draw": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000),
110
+ "power_draw": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000), # Convert mW to W
94
111
  "power_limit": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),
95
112
  }
96
113
 
97
114
  def print_status(self):
98
- """Prints GPU status in a compact table format using current stats."""
115
+ """Print GPU status in a compact table format using current stats."""
99
116
  self.refresh_stats()
100
117
  if not self.gpu_stats:
101
118
  LOGGER.warning("No GPU stats available.")
@@ -116,22 +133,28 @@ class GPUInfo:
116
133
 
117
134
  LOGGER.info(f"{'-' * len(hdr)}\n")
118
135
 
119
- def select_idle_gpu(self, count=1, min_memory_mb=0):
120
- """
121
- Selects the 'count' most idle GPUs based on utilization and free memory.
136
+ def select_idle_gpu(
137
+ self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0
138
+ ) -> list[int]:
139
+ """Select the most idle GPUs based on utilization and free memory.
122
140
 
123
141
  Args:
124
- count (int): The number of idle GPUs to select. Defaults to 1.
125
- min_memory_mb (int): Minimum free memory required (MiB). Defaults to 0.
142
+ count (int): The number of idle GPUs to select.
143
+ min_memory_fraction (float): Minimum free memory required as a fraction of total memory.
144
+ min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0.
126
145
 
127
146
  Returns:
128
- (list[int]): Indices of the selected GPUs, sorted by idleness.
147
+ (list[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first).
129
148
 
130
149
  Notes:
131
150
  Returns fewer than 'count' if not enough qualify or exist.
132
151
  Returns basic CUDA indices if NVML fails. Empty list if no GPUs found.
133
152
  """
134
- LOGGER.info(f"Searching for {count} idle GPUs with >= {min_memory_mb} MiB free memory...")
153
+ assert min_memory_fraction <= 1.0, f"min_memory_fraction must be <= 1.0, got {min_memory_fraction}"
154
+ assert min_util_fraction <= 1.0, f"min_util_fraction must be <= 1.0, got {min_util_fraction}"
155
+ LOGGER.info(
156
+ f"Searching for {count} idle GPUs with free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%..."
157
+ )
135
158
 
136
159
  if count <= 0:
137
160
  return []
@@ -145,7 +168,8 @@ class GPUInfo:
145
168
  eligible_gpus = [
146
169
  gpu
147
170
  for gpu in self.gpu_stats
148
- if gpu.get("memory_free", -1) >= min_memory_mb and gpu.get("utilization", -1) != -1
171
+ if gpu.get("memory_free", 0) / gpu.get("memory_total", 1) >= min_memory_fraction
172
+ and (100 - gpu.get("utilization", 100)) >= min_util_fraction * 100
149
173
  ]
150
174
  eligible_gpus.sort(key=lambda x: (x.get("utilization", 101), -x.get("memory_free", 0)))
151
175
 
@@ -155,20 +179,26 @@ class GPUInfo:
155
179
  if selected:
156
180
  LOGGER.info(f"Selected idle CUDA devices {selected}")
157
181
  else:
158
- LOGGER.warning(f"No GPUs met criteria (Util != -1, Free Mem >= {min_memory_mb} MiB).")
182
+ LOGGER.warning(
183
+ f"No GPUs met criteria (Free Mem >= {min_memory_fraction * 100:.1f}% and Free Util >= {min_util_fraction * 100:.1f}%)."
184
+ )
159
185
 
160
186
  return selected
161
187
 
162
188
 
163
189
  if __name__ == "__main__":
164
- required_free_mem = 2048 # Require 2GB free VRAM
190
+ required_free_mem_fraction = 0.2 # Require 20% free VRAM
191
+ required_free_util_fraction = 0.2 # Require 20% free utilization
165
192
  num_gpus_to_select = 1
166
193
 
167
194
  gpu_info = GPUInfo()
168
195
  gpu_info.print_status()
169
196
 
170
- selected = gpu_info.select_idle_gpu(count=num_gpus_to_select, min_memory_mb=required_free_mem)
171
- if selected:
197
+ if selected := gpu_info.select_idle_gpu(
198
+ count=num_gpus_to_select,
199
+ min_memory_fraction=required_free_mem_fraction,
200
+ min_util_fraction=required_free_util_fraction,
201
+ ):
172
202
  print(f"\n==> Using selected GPU indices: {selected}")
173
203
  devices = [f"cuda:{idx}" for idx in selected]
174
204
  print(f" Target devices: {devices}")