hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. hud/__init__.py +27 -7
  2. hud/agents/__init__.py +70 -5
  3. hud/agents/base.py +238 -500
  4. hud/agents/claude.py +236 -247
  5. hud/agents/gateway.py +42 -0
  6. hud/agents/gemini.py +264 -0
  7. hud/agents/gemini_cua.py +324 -0
  8. hud/agents/grounded_openai.py +98 -100
  9. hud/agents/misc/integration_test_agent.py +51 -20
  10. hud/agents/misc/response_agent.py +48 -36
  11. hud/agents/openai.py +282 -296
  12. hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
  13. hud/agents/operator.py +199 -0
  14. hud/agents/resolver.py +70 -0
  15. hud/agents/tests/conftest.py +133 -0
  16. hud/agents/tests/test_base.py +300 -622
  17. hud/agents/tests/test_base_runtime.py +233 -0
  18. hud/agents/tests/test_claude.py +381 -214
  19. hud/agents/tests/test_client.py +9 -10
  20. hud/agents/tests/test_gemini.py +369 -0
  21. hud/agents/tests/test_grounded_openai_agent.py +65 -50
  22. hud/agents/tests/test_openai.py +377 -140
  23. hud/agents/tests/test_operator.py +362 -0
  24. hud/agents/tests/test_resolver.py +192 -0
  25. hud/agents/tests/test_run_eval.py +179 -0
  26. hud/agents/types.py +148 -0
  27. hud/cli/__init__.py +493 -546
  28. hud/cli/analyze.py +43 -5
  29. hud/cli/build.py +699 -113
  30. hud/cli/debug.py +8 -5
  31. hud/cli/dev.py +889 -732
  32. hud/cli/eval.py +793 -667
  33. hud/cli/flows/dev.py +167 -0
  34. hud/cli/flows/init.py +191 -0
  35. hud/cli/flows/tasks.py +153 -56
  36. hud/cli/flows/templates.py +151 -0
  37. hud/cli/flows/tests/__init__.py +1 -0
  38. hud/cli/flows/tests/test_dev.py +126 -0
  39. hud/cli/init.py +60 -58
  40. hud/cli/pull.py +1 -1
  41. hud/cli/push.py +38 -13
  42. hud/cli/rft.py +311 -0
  43. hud/cli/rft_status.py +145 -0
  44. hud/cli/tests/test_analyze.py +5 -5
  45. hud/cli/tests/test_analyze_metadata.py +3 -2
  46. hud/cli/tests/test_analyze_module.py +120 -0
  47. hud/cli/tests/test_build.py +110 -8
  48. hud/cli/tests/test_build_failure.py +41 -0
  49. hud/cli/tests/test_build_module.py +50 -0
  50. hud/cli/tests/test_cli_init.py +6 -1
  51. hud/cli/tests/test_cli_more_wrappers.py +30 -0
  52. hud/cli/tests/test_cli_root.py +140 -0
  53. hud/cli/tests/test_convert.py +361 -0
  54. hud/cli/tests/test_debug.py +12 -10
  55. hud/cli/tests/test_dev.py +197 -0
  56. hud/cli/tests/test_eval.py +251 -0
  57. hud/cli/tests/test_eval_bedrock.py +51 -0
  58. hud/cli/tests/test_init.py +124 -0
  59. hud/cli/tests/test_main_module.py +11 -5
  60. hud/cli/tests/test_mcp_server.py +12 -100
  61. hud/cli/tests/test_push.py +1 -1
  62. hud/cli/tests/test_push_happy.py +74 -0
  63. hud/cli/tests/test_push_wrapper.py +23 -0
  64. hud/cli/tests/test_registry.py +1 -1
  65. hud/cli/tests/test_utils.py +1 -1
  66. hud/cli/{rl → utils}/celebrate.py +14 -12
  67. hud/cli/utils/config.py +18 -1
  68. hud/cli/utils/docker.py +130 -4
  69. hud/cli/utils/env_check.py +9 -9
  70. hud/cli/utils/git.py +136 -0
  71. hud/cli/utils/interactive.py +39 -5
  72. hud/cli/utils/metadata.py +70 -1
  73. hud/cli/utils/runner.py +1 -1
  74. hud/cli/utils/server.py +2 -2
  75. hud/cli/utils/source_hash.py +3 -3
  76. hud/cli/utils/tasks.py +4 -1
  77. hud/cli/utils/tests/__init__.py +0 -0
  78. hud/cli/utils/tests/test_config.py +58 -0
  79. hud/cli/utils/tests/test_docker.py +93 -0
  80. hud/cli/utils/tests/test_docker_hints.py +71 -0
  81. hud/cli/utils/tests/test_env_check.py +74 -0
  82. hud/cli/utils/tests/test_environment.py +42 -0
  83. hud/cli/utils/tests/test_git.py +142 -0
  84. hud/cli/utils/tests/test_interactive_module.py +60 -0
  85. hud/cli/utils/tests/test_local_runner.py +50 -0
  86. hud/cli/utils/tests/test_logging_utils.py +23 -0
  87. hud/cli/utils/tests/test_metadata.py +49 -0
  88. hud/cli/utils/tests/test_package_runner.py +35 -0
  89. hud/cli/utils/tests/test_registry_utils.py +49 -0
  90. hud/cli/utils/tests/test_remote_runner.py +25 -0
  91. hud/cli/utils/tests/test_runner_modules.py +52 -0
  92. hud/cli/utils/tests/test_source_hash.py +36 -0
  93. hud/cli/utils/tests/test_tasks.py +80 -0
  94. hud/cli/utils/version_check.py +258 -0
  95. hud/cli/{rl → utils}/viewer.py +2 -2
  96. hud/clients/README.md +12 -11
  97. hud/clients/__init__.py +4 -3
  98. hud/clients/base.py +166 -26
  99. hud/clients/environment.py +51 -0
  100. hud/clients/fastmcp.py +13 -6
  101. hud/clients/mcp_use.py +45 -15
  102. hud/clients/tests/test_analyze_scenarios.py +206 -0
  103. hud/clients/tests/test_protocol.py +9 -3
  104. hud/datasets/__init__.py +23 -20
  105. hud/datasets/loader.py +326 -0
  106. hud/datasets/runner.py +198 -105
  107. hud/datasets/tests/__init__.py +0 -0
  108. hud/datasets/tests/test_loader.py +221 -0
  109. hud/datasets/tests/test_utils.py +315 -0
  110. hud/datasets/utils.py +270 -90
  111. hud/environment/__init__.py +52 -0
  112. hud/environment/connection.py +258 -0
  113. hud/environment/connectors/__init__.py +33 -0
  114. hud/environment/connectors/base.py +68 -0
  115. hud/environment/connectors/local.py +177 -0
  116. hud/environment/connectors/mcp_config.py +137 -0
  117. hud/environment/connectors/openai.py +101 -0
  118. hud/environment/connectors/remote.py +172 -0
  119. hud/environment/environment.py +835 -0
  120. hud/environment/integrations/__init__.py +45 -0
  121. hud/environment/integrations/adk.py +67 -0
  122. hud/environment/integrations/anthropic.py +196 -0
  123. hud/environment/integrations/gemini.py +92 -0
  124. hud/environment/integrations/langchain.py +82 -0
  125. hud/environment/integrations/llamaindex.py +68 -0
  126. hud/environment/integrations/openai.py +238 -0
  127. hud/environment/mock.py +306 -0
  128. hud/environment/router.py +263 -0
  129. hud/environment/scenarios.py +620 -0
  130. hud/environment/tests/__init__.py +1 -0
  131. hud/environment/tests/test_connection.py +317 -0
  132. hud/environment/tests/test_connectors.py +205 -0
  133. hud/environment/tests/test_environment.py +593 -0
  134. hud/environment/tests/test_integrations.py +257 -0
  135. hud/environment/tests/test_local_connectors.py +242 -0
  136. hud/environment/tests/test_scenarios.py +1086 -0
  137. hud/environment/tests/test_tools.py +208 -0
  138. hud/environment/types.py +23 -0
  139. hud/environment/utils/__init__.py +35 -0
  140. hud/environment/utils/formats.py +215 -0
  141. hud/environment/utils/schema.py +171 -0
  142. hud/environment/utils/tool_wrappers.py +113 -0
  143. hud/eval/__init__.py +67 -0
  144. hud/eval/context.py +727 -0
  145. hud/eval/display.py +299 -0
  146. hud/eval/instrument.py +187 -0
  147. hud/eval/manager.py +533 -0
  148. hud/eval/parallel.py +268 -0
  149. hud/eval/task.py +372 -0
  150. hud/eval/tests/__init__.py +1 -0
  151. hud/eval/tests/test_context.py +178 -0
  152. hud/eval/tests/test_eval.py +210 -0
  153. hud/eval/tests/test_manager.py +152 -0
  154. hud/eval/tests/test_parallel.py +168 -0
  155. hud/eval/tests/test_task.py +291 -0
  156. hud/eval/types.py +65 -0
  157. hud/eval/utils.py +194 -0
  158. hud/patches/__init__.py +19 -0
  159. hud/patches/mcp_patches.py +308 -0
  160. hud/patches/warnings.py +54 -0
  161. hud/samples/browser.py +4 -4
  162. hud/server/__init__.py +2 -1
  163. hud/server/low_level.py +2 -1
  164. hud/server/router.py +164 -0
  165. hud/server/server.py +567 -80
  166. hud/server/tests/test_mcp_server_integration.py +11 -11
  167. hud/server/tests/test_mcp_server_more.py +1 -1
  168. hud/server/tests/test_server_extra.py +2 -0
  169. hud/settings.py +45 -3
  170. hud/shared/exceptions.py +36 -10
  171. hud/shared/hints.py +26 -1
  172. hud/shared/requests.py +15 -3
  173. hud/shared/tests/test_exceptions.py +40 -31
  174. hud/shared/tests/test_hints.py +167 -0
  175. hud/telemetry/__init__.py +20 -19
  176. hud/telemetry/exporter.py +201 -0
  177. hud/telemetry/instrument.py +165 -253
  178. hud/telemetry/tests/test_eval_telemetry.py +356 -0
  179. hud/telemetry/tests/test_exporter.py +258 -0
  180. hud/telemetry/tests/test_instrument.py +401 -0
  181. hud/tools/__init__.py +18 -2
  182. hud/tools/agent.py +223 -0
  183. hud/tools/apply_patch.py +639 -0
  184. hud/tools/base.py +54 -4
  185. hud/tools/bash.py +2 -2
  186. hud/tools/computer/__init__.py +36 -3
  187. hud/tools/computer/anthropic.py +2 -2
  188. hud/tools/computer/gemini.py +385 -0
  189. hud/tools/computer/hud.py +23 -6
  190. hud/tools/computer/openai.py +20 -21
  191. hud/tools/computer/qwen.py +434 -0
  192. hud/tools/computer/settings.py +37 -0
  193. hud/tools/edit.py +3 -7
  194. hud/tools/executors/base.py +4 -2
  195. hud/tools/executors/pyautogui.py +1 -1
  196. hud/tools/grounding/grounded_tool.py +13 -18
  197. hud/tools/grounding/grounder.py +10 -31
  198. hud/tools/grounding/tests/test_grounded_tool.py +26 -44
  199. hud/tools/jupyter.py +330 -0
  200. hud/tools/playwright.py +18 -3
  201. hud/tools/shell.py +308 -0
  202. hud/tools/tests/test_agent_tool.py +355 -0
  203. hud/tools/tests/test_apply_patch.py +718 -0
  204. hud/tools/tests/test_computer.py +4 -9
  205. hud/tools/tests/test_computer_actions.py +24 -2
  206. hud/tools/tests/test_jupyter_tool.py +181 -0
  207. hud/tools/tests/test_shell.py +596 -0
  208. hud/tools/tests/test_submit.py +85 -0
  209. hud/tools/tests/test_types.py +193 -0
  210. hud/tools/types.py +21 -1
  211. hud/types.py +194 -56
  212. hud/utils/__init__.py +2 -0
  213. hud/utils/env.py +67 -0
  214. hud/utils/hud_console.py +89 -18
  215. hud/utils/mcp.py +15 -58
  216. hud/utils/strict_schema.py +162 -0
  217. hud/utils/tests/test_init.py +1 -2
  218. hud/utils/tests/test_mcp.py +1 -28
  219. hud/utils/tests/test_pretty_errors.py +186 -0
  220. hud/utils/tests/test_tool_shorthand.py +154 -0
  221. hud/utils/tests/test_version.py +1 -1
  222. hud/utils/types.py +20 -0
  223. hud/version.py +1 -1
  224. hud_python-0.5.13.dist-info/METADATA +264 -0
  225. hud_python-0.5.13.dist-info/RECORD +305 -0
  226. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
  227. hud/agents/langchain.py +0 -261
  228. hud/agents/lite_llm.py +0 -72
  229. hud/cli/rl/__init__.py +0 -180
  230. hud/cli/rl/config.py +0 -101
  231. hud/cli/rl/display.py +0 -133
  232. hud/cli/rl/gpu.py +0 -63
  233. hud/cli/rl/gpu_utils.py +0 -321
  234. hud/cli/rl/local_runner.py +0 -595
  235. hud/cli/rl/presets.py +0 -96
  236. hud/cli/rl/remote_runner.py +0 -463
  237. hud/cli/rl/rl_api.py +0 -150
  238. hud/cli/rl/vllm.py +0 -177
  239. hud/cli/rl/wait_utils.py +0 -89
  240. hud/datasets/parallel.py +0 -687
  241. hud/misc/__init__.py +0 -1
  242. hud/misc/claude_plays_pokemon.py +0 -292
  243. hud/otel/__init__.py +0 -35
  244. hud/otel/collector.py +0 -142
  245. hud/otel/config.py +0 -181
  246. hud/otel/context.py +0 -570
  247. hud/otel/exporters.py +0 -369
  248. hud/otel/instrumentation.py +0 -135
  249. hud/otel/processors.py +0 -121
  250. hud/otel/tests/__init__.py +0 -1
  251. hud/otel/tests/test_processors.py +0 -197
  252. hud/rl/README.md +0 -30
  253. hud/rl/__init__.py +0 -1
  254. hud/rl/actor.py +0 -176
  255. hud/rl/buffer.py +0 -405
  256. hud/rl/chat_template.jinja +0 -101
  257. hud/rl/config.py +0 -192
  258. hud/rl/distributed.py +0 -132
  259. hud/rl/learner.py +0 -637
  260. hud/rl/tests/__init__.py +0 -1
  261. hud/rl/tests/test_learner.py +0 -186
  262. hud/rl/train.py +0 -382
  263. hud/rl/types.py +0 -101
  264. hud/rl/utils/start_vllm_server.sh +0 -30
  265. hud/rl/utils.py +0 -524
  266. hud/rl/vllm_adapter.py +0 -143
  267. hud/telemetry/job.py +0 -352
  268. hud/telemetry/replay.py +0 -74
  269. hud/telemetry/tests/test_replay.py +0 -40
  270. hud/telemetry/tests/test_trace.py +0 -63
  271. hud/telemetry/trace.py +0 -158
  272. hud/utils/agent_factories.py +0 -86
  273. hud/utils/async_utils.py +0 -65
  274. hud/utils/group_eval.py +0 -223
  275. hud/utils/progress.py +0 -149
  276. hud/utils/tasks.py +0 -127
  277. hud/utils/tests/test_async_utils.py +0 -173
  278. hud/utils/tests/test_progress.py +0 -261
  279. hud_python-0.4.45.dist-info/METADATA +0 -552
  280. hud_python-0.4.45.dist-info/RECORD +0 -228
  281. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
  282. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/cli/rl/gpu_utils.py DELETED
@@ -1,321 +0,0 @@
1
- """GPU utilities for DDP training."""
2
-
3
- from __future__ import annotations
4
-
5
- import logging
6
- import subprocess
7
- import time
8
- from typing import TYPE_CHECKING, Any
9
-
10
- from hud.utils.hud_console import HUDConsole
11
-
12
- if TYPE_CHECKING:
13
- from hud.rl.config import Config
14
- hud_console = HUDConsole(logging.getLogger(__name__))
15
-
16
-
17
- def get_gpu_memory_info() -> dict[int, dict[str, Any]]:
18
- """Get memory usage information for all GPUs."""
19
-
20
- gpu_memory = {}
21
- try:
22
- # Get memory info for all GPUs
23
- cmd = [
24
- "nvidia-smi",
25
- "--query-gpu=index,memory.used,memory.total,memory.free",
26
- "--format=csv,noheader,nounits",
27
- ]
28
- result = subprocess.run(cmd, capture_output=True, text=True, check=True) # noqa: S603
29
-
30
- for line in result.stdout.strip().split("\n"):
31
- if not line:
32
- continue
33
- parts = line.split(", ")
34
- if len(parts) >= 4:
35
- gpu_idx = int(parts[0])
36
- memory_used = float(parts[1])
37
- memory_total = float(parts[2])
38
- memory_free = float(parts[3])
39
- gpu_memory[gpu_idx] = {
40
- "used_mb": memory_used,
41
- "total_mb": memory_total,
42
- "free_mb": memory_free,
43
- "used_pct": (memory_used / memory_total) * 100,
44
- }
45
-
46
- # Get process information per GPU
47
- for gpu_idx in gpu_memory: # noqa: PLC0206
48
- cmd = [
49
- "nvidia-smi",
50
- "-i",
51
- str(gpu_idx),
52
- "--query-compute-apps=pid,used_memory",
53
- "--format=csv,noheader,nounits",
54
- ]
55
- try:
56
- result = subprocess.run(cmd, capture_output=True, text=True, check=True) # noqa: S603
57
- processes = []
58
- for line in result.stdout.strip().split("\n"):
59
- if not line:
60
- continue
61
- parts = line.split(", ")
62
- if len(parts) >= 2:
63
- pid = int(parts[0])
64
- memory_mb = float(parts[1])
65
- processes.append({"pid": pid, "memory_mb": memory_mb})
66
- gpu_memory[gpu_idx]["processes"] = processes
67
- except Exception as e:
68
- hud_console.error(f"Failed to get process info for GPU {gpu_idx}: {e}")
69
- gpu_memory[gpu_idx]["processes"] = []
70
-
71
- except Exception as e:
72
- hud_console.error(f"Failed to get GPU memory info {e}")
73
- return {}
74
-
75
- return gpu_memory
76
-
77
-
78
- def health_check_gpus(gpu_indices: list[int]) -> dict[str, Any]:
79
- """Perform health check on specified GPUs including memory status.
80
-
81
- Returns:
82
- Dict with:
83
- - healthy_gpus: List of healthy GPU indices
84
- - unhealthy_gpus: Dict of unhealthy GPU index -> error message
85
- - all_healthy: Boolean indicating if all GPUs are healthy
86
- - memory_issues: Boolean indicating if there are memory issues
87
- """
88
- import torch
89
- from rich.console import Console
90
- from rich.table import Table
91
-
92
- console = Console()
93
-
94
- console.print("\n[bold cyan]🏥 GPU Health Check[/bold cyan]")
95
-
96
- # First get memory info
97
- memory_info = get_gpu_memory_info()
98
-
99
- healthy_gpus = []
100
- unhealthy_gpus = {}
101
- memory_issues = []
102
-
103
- # Create a table for results
104
- table = Table(title="GPU Health Status")
105
- table.add_column("GPU", style="cyan")
106
- table.add_column("Memory Usage", style="yellow")
107
- table.add_column("Status", style="green")
108
- table.add_column("Details", style="yellow")
109
-
110
- for gpu_idx in gpu_indices:
111
- # Memory info
112
- mem_str = "Unknown"
113
- if gpu_idx in memory_info:
114
- mem = memory_info[gpu_idx]
115
- used_gb = mem["used_mb"] / 1024
116
- total_gb = mem["total_mb"] / 1024
117
- mem_str = f"{used_gb:.1f}/{total_gb:.1f} GB ({mem['used_pct']:.0f}%)"
118
-
119
- # Check for high memory usage
120
- if mem["used_pct"] > 70:
121
- memory_issues.append(gpu_idx)
122
- proc_info = f" ({len(mem['processes'])} processes)" if mem["processes"] else ""
123
- unhealthy_gpus[gpu_idx] = f"High memory usage{proc_info}"
124
- table.add_row(
125
- f"GPU {gpu_idx}", mem_str, "❌ Unhealthy", f"High memory usage{proc_info}"
126
- )
127
- continue
128
-
129
- # If no severe memory issue, do accessibility test
130
- try:
131
- # Try to allocate a small tensor on the GPU
132
- torch.cuda.set_device(gpu_idx)
133
- device = torch.device(f"cuda:{gpu_idx}")
134
-
135
- # Test basic allocation
136
- test_tensor = torch.zeros(100, 100, device=device)
137
-
138
- # Test computation
139
- result = torch.matmul(test_tensor, test_tensor)
140
-
141
- # Force synchronization
142
- torch.cuda.synchronize(device)
143
-
144
- # Clean up
145
- del test_tensor, result
146
- torch.cuda.empty_cache()
147
-
148
- healthy_gpus.append(gpu_idx)
149
- table.add_row(f"GPU {gpu_idx}", mem_str, "✅ Healthy", "Passed all tests")
150
-
151
- except Exception as e:
152
- error_msg = str(e)
153
- if "busy or unavailable" in error_msg:
154
- short_msg = "Device busy or unavailable"
155
- elif "out of memory" in error_msg:
156
- short_msg = "Insufficient memory"
157
- else:
158
- short_msg = error_msg[:50] + "..." if len(error_msg) > 50 else error_msg
159
-
160
- unhealthy_gpus[gpu_idx] = short_msg
161
- table.add_row(f"GPU {gpu_idx}", mem_str, "❌ Unhealthy", short_msg)
162
-
163
- # Small delay between GPU checks
164
- time.sleep(0.1)
165
-
166
- console.print(table)
167
-
168
- return {
169
- "healthy_gpus": healthy_gpus,
170
- "unhealthy_gpus": unhealthy_gpus,
171
- "all_healthy": len(unhealthy_gpus) == 0,
172
- "memory_issues": memory_issues,
173
- }
174
-
175
-
176
- def calculate_optimal_gpu_allocation(gpu_info: dict[str, Any], config: Config) -> dict[str, Any]:
177
- """Calculate optimal GPU allocation for DDP GRPO training.
178
-
179
- Key insight: In GRPO, we want to process groups in parallel.
180
- Optimal case: num_gpus = num_groups (each GPU processes 1 group).
181
- """
182
- devices = gpu_info["devices"]
183
- available_gpus = [device["index"] for device in devices]
184
-
185
- # Need at least 2 GPUs (1 for training, 1 for vLLM)
186
- if len(available_gpus) < 2:
187
- return {"use_ddp": False, "reason": "Need at least 2 GPUs"}
188
-
189
- # Reserve last GPU for vLLM
190
- vllm_gpu = available_gpus[-1]
191
- training_gpus = available_gpus[:-1]
192
-
193
- # Calculate number of groups
194
- batch_size = config.training.batch_size
195
- group_size = config.training.group_size
196
- num_groups = batch_size // group_size
197
-
198
- if num_groups == 0:
199
- num_groups = 1
200
-
201
- # Optimal: Use exactly num_groups GPUs (each processes 1 group in parallel)
202
- # But cap at available training GPUs
203
- optimal_gpu_count = min(len(training_gpus), num_groups)
204
-
205
- # Only use DDP if we have more than 1 group and more than 1 GPU
206
- use_ddp = optimal_gpu_count > 1 and num_groups > 1
207
-
208
- if not use_ddp:
209
- # Single GPU training
210
- return {
211
- "use_ddp": False,
212
- "reason": f"Single GPU sufficient for {num_groups} group(s)",
213
- "training_gpus": [training_gpus[0]],
214
- "vllm_gpu": vllm_gpu,
215
- "num_groups": num_groups,
216
- }
217
-
218
- # Use optimal number of GPUs for DDP
219
- training_gpus = training_gpus[:optimal_gpu_count]
220
-
221
- return {
222
- "use_ddp": True,
223
- "training_gpus": training_gpus,
224
- "vllm_gpu": vllm_gpu,
225
- "num_groups": num_groups,
226
- "groups_per_gpu": num_groups / len(training_gpus),
227
- "parallel_efficiency": min(
228
- 1.0, num_groups / len(training_gpus)
229
- ), # 1.0 = perfect load balance
230
- }
231
-
232
-
233
- def adjust_config_for_ddp(config: Config, num_gpus: int) -> Config:
234
- """Adjust configuration for optimal DDP performance.
235
-
236
- Scaling rule:
237
- - For 1 GPU: batch_size = 2 * group_size
238
- - For N GPUs (N > 1): batch_size = N * group_size
239
-
240
- This ensures each GPU processes exactly 1 group in parallel for optimal performance.
241
- """
242
- group_size = config.training.group_size
243
-
244
- # Apply scaling rule
245
- if num_gpus == 1:
246
- # Special case: 2 groups for single GPU
247
- groups_per_gpu = 2
248
- config.training.batch_size = 2 * group_size
249
- else:
250
- groups_per_gpu = config.training.batch_size // group_size
251
- # Multi-GPU: each GPU processes groups_per_gpu groups
252
- config.training.batch_size = num_gpus * group_size * groups_per_gpu
253
-
254
- # Update max_parallel_episodes to match
255
- config.actor.max_parallel_episodes = config.training.batch_size
256
-
257
- config.training.num_gpus = num_gpus
258
-
259
- # Log the adjustment
260
- from rich.console import Console
261
-
262
- console = Console()
263
- console.print(
264
- f"\n[cyan]📊 Adjusted batch_size to {config.training.batch_size} ({config.training.batch_size // group_size} groups)[/cyan]" # noqa: E501
265
- )
266
- console.print(
267
- f"[cyan] Each of the {num_gpus} GPU(s) will process {groups_per_gpu} group(s) in parallel[/cyan]" # noqa: E501
268
- )
269
-
270
- return config
271
-
272
-
273
- def kill_high_memory_processes(memory_threshold: float = 70.0) -> int:
274
- """Kill all GPU processes using more than threshold% memory.
275
-
276
- Returns:
277
- Number of processes killed
278
- """
279
- from rich.console import Console
280
-
281
- console = Console()
282
-
283
- memory_info = get_gpu_memory_info()
284
- killed_count = 0
285
-
286
- for gpu_idx, info in memory_info.items():
287
- if info["used_pct"] > memory_threshold:
288
- for proc in info.get("processes", []):
289
- pid = proc["pid"]
290
- try:
291
- # Try graceful termination first
292
- subprocess.run(["kill", "-TERM", str(pid)], check=False, capture_output=True) # noqa: S603, S607
293
- killed_count += 1
294
- console.print(
295
- f"[yellow]Terminating PID {pid} on GPU {gpu_idx} ({proc['memory_mb'] / 1024:.1f} GB)[/yellow]" # noqa: E501
296
- )
297
- except Exception as e:
298
- console.print(f"[red]Failed to kill PID {pid}: {e}[/red]")
299
-
300
- if killed_count > 0:
301
- console.print(f"\n[yellow]Sent termination signal to {killed_count} processes...[/yellow]")
302
- time.sleep(3)
303
-
304
- # Force kill any remaining
305
- for info in memory_info.values():
306
- for proc in info.get("processes", []):
307
- pid = proc["pid"]
308
- try:
309
- # Check if still running
310
- subprocess.run( # noqa: S603
311
- ["kill", "-0", str(pid)], # noqa: S607
312
- check=True,
313
- capture_output=True,
314
- )
315
- # If no error, process is still running, force kill
316
- subprocess.run(["kill", "-KILL", str(pid)], check=False) # noqa: S603, S607
317
- console.print(f"[red]Force killed PID {pid}[/red]")
318
- except Exception:
319
- hud_console.error(f"Failed to kill PID {pid}")
320
-
321
- return killed_count