hud-python 0.4.45__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (274)
  1. hud/__init__.py +27 -7
  2. hud/agents/__init__.py +11 -5
  3. hud/agents/base.py +220 -500
  4. hud/agents/claude.py +200 -240
  5. hud/agents/gemini.py +275 -0
  6. hud/agents/gemini_cua.py +335 -0
  7. hud/agents/grounded_openai.py +98 -100
  8. hud/agents/misc/integration_test_agent.py +51 -20
  9. hud/agents/misc/response_agent.py +41 -36
  10. hud/agents/openai.py +291 -292
  11. hud/agents/{openai_chat_generic.py → openai_chat.py} +80 -34
  12. hud/agents/operator.py +211 -0
  13. hud/agents/tests/conftest.py +133 -0
  14. hud/agents/tests/test_base.py +300 -622
  15. hud/agents/tests/test_base_runtime.py +233 -0
  16. hud/agents/tests/test_claude.py +379 -210
  17. hud/agents/tests/test_client.py +9 -10
  18. hud/agents/tests/test_gemini.py +369 -0
  19. hud/agents/tests/test_grounded_openai_agent.py +65 -50
  20. hud/agents/tests/test_openai.py +376 -140
  21. hud/agents/tests/test_operator.py +362 -0
  22. hud/agents/tests/test_run_eval.py +179 -0
  23. hud/cli/__init__.py +461 -545
  24. hud/cli/analyze.py +43 -5
  25. hud/cli/build.py +664 -110
  26. hud/cli/debug.py +8 -5
  27. hud/cli/dev.py +882 -734
  28. hud/cli/eval.py +782 -668
  29. hud/cli/flows/dev.py +167 -0
  30. hud/cli/flows/init.py +191 -0
  31. hud/cli/flows/tasks.py +153 -56
  32. hud/cli/flows/templates.py +151 -0
  33. hud/cli/flows/tests/__init__.py +1 -0
  34. hud/cli/flows/tests/test_dev.py +126 -0
  35. hud/cli/init.py +60 -58
  36. hud/cli/push.py +29 -11
  37. hud/cli/rft.py +311 -0
  38. hud/cli/rft_status.py +145 -0
  39. hud/cli/tests/test_analyze.py +5 -5
  40. hud/cli/tests/test_analyze_metadata.py +3 -2
  41. hud/cli/tests/test_analyze_module.py +120 -0
  42. hud/cli/tests/test_build.py +108 -6
  43. hud/cli/tests/test_build_failure.py +41 -0
  44. hud/cli/tests/test_build_module.py +50 -0
  45. hud/cli/tests/test_cli_init.py +6 -1
  46. hud/cli/tests/test_cli_more_wrappers.py +30 -0
  47. hud/cli/tests/test_cli_root.py +140 -0
  48. hud/cli/tests/test_convert.py +361 -0
  49. hud/cli/tests/test_debug.py +12 -10
  50. hud/cli/tests/test_dev.py +197 -0
  51. hud/cli/tests/test_eval.py +251 -0
  52. hud/cli/tests/test_eval_bedrock.py +51 -0
  53. hud/cli/tests/test_init.py +124 -0
  54. hud/cli/tests/test_main_module.py +11 -5
  55. hud/cli/tests/test_mcp_server.py +12 -100
  56. hud/cli/tests/test_push_happy.py +74 -0
  57. hud/cli/tests/test_push_wrapper.py +23 -0
  58. hud/cli/tests/test_registry.py +1 -1
  59. hud/cli/tests/test_utils.py +1 -1
  60. hud/cli/{rl → utils}/celebrate.py +14 -12
  61. hud/cli/utils/config.py +18 -1
  62. hud/cli/utils/docker.py +130 -4
  63. hud/cli/utils/env_check.py +9 -9
  64. hud/cli/utils/git.py +136 -0
  65. hud/cli/utils/interactive.py +39 -5
  66. hud/cli/utils/metadata.py +69 -0
  67. hud/cli/utils/runner.py +1 -1
  68. hud/cli/utils/server.py +2 -2
  69. hud/cli/utils/source_hash.py +3 -3
  70. hud/cli/utils/tasks.py +4 -1
  71. hud/cli/utils/tests/__init__.py +0 -0
  72. hud/cli/utils/tests/test_config.py +58 -0
  73. hud/cli/utils/tests/test_docker.py +93 -0
  74. hud/cli/utils/tests/test_docker_hints.py +71 -0
  75. hud/cli/utils/tests/test_env_check.py +74 -0
  76. hud/cli/utils/tests/test_environment.py +42 -0
  77. hud/cli/utils/tests/test_git.py +142 -0
  78. hud/cli/utils/tests/test_interactive_module.py +60 -0
  79. hud/cli/utils/tests/test_local_runner.py +50 -0
  80. hud/cli/utils/tests/test_logging_utils.py +23 -0
  81. hud/cli/utils/tests/test_metadata.py +49 -0
  82. hud/cli/utils/tests/test_package_runner.py +35 -0
  83. hud/cli/utils/tests/test_registry_utils.py +49 -0
  84. hud/cli/utils/tests/test_remote_runner.py +25 -0
  85. hud/cli/utils/tests/test_runner_modules.py +52 -0
  86. hud/cli/utils/tests/test_source_hash.py +36 -0
  87. hud/cli/utils/tests/test_tasks.py +80 -0
  88. hud/cli/utils/version_check.py +258 -0
  89. hud/cli/{rl → utils}/viewer.py +2 -2
  90. hud/clients/README.md +12 -11
  91. hud/clients/__init__.py +4 -3
  92. hud/clients/base.py +166 -26
  93. hud/clients/environment.py +51 -0
  94. hud/clients/fastmcp.py +13 -6
  95. hud/clients/mcp_use.py +40 -15
  96. hud/clients/tests/test_analyze_scenarios.py +206 -0
  97. hud/clients/tests/test_protocol.py +9 -3
  98. hud/datasets/__init__.py +23 -20
  99. hud/datasets/loader.py +327 -0
  100. hud/datasets/runner.py +192 -105
  101. hud/datasets/tests/__init__.py +0 -0
  102. hud/datasets/tests/test_loader.py +221 -0
  103. hud/datasets/tests/test_utils.py +315 -0
  104. hud/datasets/utils.py +270 -90
  105. hud/environment/__init__.py +50 -0
  106. hud/environment/connection.py +206 -0
  107. hud/environment/connectors/__init__.py +33 -0
  108. hud/environment/connectors/base.py +68 -0
  109. hud/environment/connectors/local.py +177 -0
  110. hud/environment/connectors/mcp_config.py +109 -0
  111. hud/environment/connectors/openai.py +101 -0
  112. hud/environment/connectors/remote.py +172 -0
  113. hud/environment/environment.py +694 -0
  114. hud/environment/integrations/__init__.py +45 -0
  115. hud/environment/integrations/adk.py +67 -0
  116. hud/environment/integrations/anthropic.py +196 -0
  117. hud/environment/integrations/gemini.py +92 -0
  118. hud/environment/integrations/langchain.py +82 -0
  119. hud/environment/integrations/llamaindex.py +68 -0
  120. hud/environment/integrations/openai.py +238 -0
  121. hud/environment/mock.py +306 -0
  122. hud/environment/router.py +112 -0
  123. hud/environment/scenarios.py +493 -0
  124. hud/environment/tests/__init__.py +1 -0
  125. hud/environment/tests/test_connection.py +317 -0
  126. hud/environment/tests/test_connectors.py +218 -0
  127. hud/environment/tests/test_environment.py +161 -0
  128. hud/environment/tests/test_integrations.py +257 -0
  129. hud/environment/tests/test_local_connectors.py +201 -0
  130. hud/environment/tests/test_scenarios.py +280 -0
  131. hud/environment/tests/test_tools.py +208 -0
  132. hud/environment/types.py +23 -0
  133. hud/environment/utils/__init__.py +35 -0
  134. hud/environment/utils/formats.py +215 -0
  135. hud/environment/utils/schema.py +171 -0
  136. hud/environment/utils/tool_wrappers.py +113 -0
  137. hud/eval/__init__.py +67 -0
  138. hud/eval/context.py +674 -0
  139. hud/eval/display.py +299 -0
  140. hud/eval/instrument.py +185 -0
  141. hud/eval/manager.py +466 -0
  142. hud/eval/parallel.py +268 -0
  143. hud/eval/task.py +340 -0
  144. hud/eval/tests/__init__.py +1 -0
  145. hud/eval/tests/test_context.py +178 -0
  146. hud/eval/tests/test_eval.py +210 -0
  147. hud/eval/tests/test_manager.py +152 -0
  148. hud/eval/tests/test_parallel.py +168 -0
  149. hud/eval/tests/test_task.py +145 -0
  150. hud/eval/types.py +63 -0
  151. hud/eval/utils.py +183 -0
  152. hud/patches/__init__.py +19 -0
  153. hud/patches/mcp_patches.py +151 -0
  154. hud/patches/warnings.py +54 -0
  155. hud/samples/browser.py +4 -4
  156. hud/server/__init__.py +2 -1
  157. hud/server/low_level.py +2 -1
  158. hud/server/router.py +164 -0
  159. hud/server/server.py +567 -80
  160. hud/server/tests/test_mcp_server_integration.py +11 -11
  161. hud/server/tests/test_mcp_server_more.py +1 -1
  162. hud/server/tests/test_server_extra.py +2 -0
  163. hud/settings.py +45 -3
  164. hud/shared/exceptions.py +36 -10
  165. hud/shared/hints.py +26 -1
  166. hud/shared/requests.py +15 -3
  167. hud/shared/tests/test_exceptions.py +40 -31
  168. hud/shared/tests/test_hints.py +167 -0
  169. hud/telemetry/__init__.py +20 -19
  170. hud/telemetry/exporter.py +201 -0
  171. hud/telemetry/instrument.py +158 -253
  172. hud/telemetry/tests/test_eval_telemetry.py +356 -0
  173. hud/telemetry/tests/test_exporter.py +258 -0
  174. hud/telemetry/tests/test_instrument.py +401 -0
  175. hud/tools/__init__.py +16 -2
  176. hud/tools/apply_patch.py +639 -0
  177. hud/tools/base.py +54 -4
  178. hud/tools/bash.py +2 -2
  179. hud/tools/computer/__init__.py +4 -0
  180. hud/tools/computer/anthropic.py +2 -2
  181. hud/tools/computer/gemini.py +385 -0
  182. hud/tools/computer/hud.py +23 -6
  183. hud/tools/computer/openai.py +20 -21
  184. hud/tools/computer/qwen.py +434 -0
  185. hud/tools/computer/settings.py +37 -0
  186. hud/tools/edit.py +3 -7
  187. hud/tools/executors/base.py +4 -2
  188. hud/tools/executors/pyautogui.py +1 -1
  189. hud/tools/grounding/grounded_tool.py +13 -18
  190. hud/tools/grounding/grounder.py +10 -31
  191. hud/tools/grounding/tests/test_grounded_tool.py +26 -44
  192. hud/tools/jupyter.py +330 -0
  193. hud/tools/playwright.py +18 -3
  194. hud/tools/shell.py +308 -0
  195. hud/tools/tests/test_apply_patch.py +718 -0
  196. hud/tools/tests/test_computer.py +4 -9
  197. hud/tools/tests/test_computer_actions.py +24 -2
  198. hud/tools/tests/test_jupyter_tool.py +181 -0
  199. hud/tools/tests/test_shell.py +596 -0
  200. hud/tools/tests/test_submit.py +85 -0
  201. hud/tools/tests/test_types.py +193 -0
  202. hud/tools/types.py +21 -1
  203. hud/types.py +167 -57
  204. hud/utils/__init__.py +2 -0
  205. hud/utils/env.py +67 -0
  206. hud/utils/hud_console.py +61 -3
  207. hud/utils/mcp.py +15 -58
  208. hud/utils/strict_schema.py +162 -0
  209. hud/utils/tests/test_init.py +1 -2
  210. hud/utils/tests/test_mcp.py +1 -28
  211. hud/utils/tests/test_pretty_errors.py +186 -0
  212. hud/utils/tests/test_tool_shorthand.py +154 -0
  213. hud/utils/tests/test_version.py +1 -1
  214. hud/utils/types.py +20 -0
  215. hud/version.py +1 -1
  216. hud_python-0.5.1.dist-info/METADATA +264 -0
  217. hud_python-0.5.1.dist-info/RECORD +299 -0
  218. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/WHEEL +1 -1
  219. hud/agents/langchain.py +0 -261
  220. hud/agents/lite_llm.py +0 -72
  221. hud/cli/rl/__init__.py +0 -180
  222. hud/cli/rl/config.py +0 -101
  223. hud/cli/rl/display.py +0 -133
  224. hud/cli/rl/gpu.py +0 -63
  225. hud/cli/rl/gpu_utils.py +0 -321
  226. hud/cli/rl/local_runner.py +0 -595
  227. hud/cli/rl/presets.py +0 -96
  228. hud/cli/rl/remote_runner.py +0 -463
  229. hud/cli/rl/rl_api.py +0 -150
  230. hud/cli/rl/vllm.py +0 -177
  231. hud/cli/rl/wait_utils.py +0 -89
  232. hud/datasets/parallel.py +0 -687
  233. hud/misc/__init__.py +0 -1
  234. hud/misc/claude_plays_pokemon.py +0 -292
  235. hud/otel/__init__.py +0 -35
  236. hud/otel/collector.py +0 -142
  237. hud/otel/config.py +0 -181
  238. hud/otel/context.py +0 -570
  239. hud/otel/exporters.py +0 -369
  240. hud/otel/instrumentation.py +0 -135
  241. hud/otel/processors.py +0 -121
  242. hud/otel/tests/__init__.py +0 -1
  243. hud/otel/tests/test_processors.py +0 -197
  244. hud/rl/README.md +0 -30
  245. hud/rl/__init__.py +0 -1
  246. hud/rl/actor.py +0 -176
  247. hud/rl/buffer.py +0 -405
  248. hud/rl/chat_template.jinja +0 -101
  249. hud/rl/config.py +0 -192
  250. hud/rl/distributed.py +0 -132
  251. hud/rl/learner.py +0 -637
  252. hud/rl/tests/__init__.py +0 -1
  253. hud/rl/tests/test_learner.py +0 -186
  254. hud/rl/train.py +0 -382
  255. hud/rl/types.py +0 -101
  256. hud/rl/utils/start_vllm_server.sh +0 -30
  257. hud/rl/utils.py +0 -524
  258. hud/rl/vllm_adapter.py +0 -143
  259. hud/telemetry/job.py +0 -352
  260. hud/telemetry/replay.py +0 -74
  261. hud/telemetry/tests/test_replay.py +0 -40
  262. hud/telemetry/tests/test_trace.py +0 -63
  263. hud/telemetry/trace.py +0 -158
  264. hud/utils/agent_factories.py +0 -86
  265. hud/utils/async_utils.py +0 -65
  266. hud/utils/group_eval.py +0 -223
  267. hud/utils/progress.py +0 -149
  268. hud/utils/tasks.py +0 -127
  269. hud/utils/tests/test_async_utils.py +0 -173
  270. hud/utils/tests/test_progress.py +0 -261
  271. hud_python-0.4.45.dist-info/METADATA +0 -552
  272. hud_python-0.4.45.dist-info/RECORD +0 -228
  273. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/entry_points.txt +0 -0
  274. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/licenses/LICENSE +0 -0
hud/rl/utils.py DELETED
@@ -1,524 +0,0 @@
-"""Utility functions for RL training."""
-
-from __future__ import annotations
-
-import base64
-import io
-import logging
-import os
-import random
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers.utils.chat_template_utils import render_jinja_template
-
-from hud.utils.hud_console import HUDConsole
-
-from .types import TrainingSample
-
-if TYPE_CHECKING:
-    from hud.types import Trace
-
-    from .config import Config
-
-logger = logging.getLogger(__name__)
-hud_console = HUDConsole(logger)
-
-
-def set_seed(seed: int) -> None:
-    """Set random seeds for reproducibility."""
-    random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-
-
-def load_chat_template(path: str) -> str:
-    """Load chat template from file."""
-    with open(path) as f:
-        return f.read()
-
-
-def ensure_dir(path: str) -> None:
-    """Create directory if it doesn't exist."""
-    os.makedirs(path, exist_ok=True)
-
-
-def get_memory_usage() -> float:
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
-        return torch.cuda.memory_allocated() / 1024**3
-    return 0.0
-
-
-def get_gpu_utilization() -> float:
-    """Get current GPU utilization percentage (0-100)."""
-    if not torch.cuda.is_available():
-        return 0.0
-
-    try:
-        import nvidia_ml_py as nvml  # type: ignore
-
-        nvml.nvmlInit()
-        device_id = torch.cuda.current_device()
-        handle = nvml.nvmlDeviceGetHandleByIndex(device_id)
-        util = nvml.nvmlDeviceGetUtilizationRates(handle)
-        return float(util.gpu)
-    except Exception:
-        # Fallback: estimate based on memory usage
-        # This is less accurate but works without nvidia-ml-py
-        return min(100.0, (torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()) * 100)
-
-
-def aggregate_metrics_across_ranks(
-    metrics: Any, metrics_to_aggregate: list[str] | None = None
-) -> None:
-    """Aggregate metrics across all ranks for proper distributed statistics.
-
-    Args:
-        metrics: TrainingMetrics object to update in-place
-        metrics_to_aggregate: List of metric names to aggregate. If None, aggregates all numeric metrics.
-
-    This function:
-    1. Gathers metric values from all ranks
-    2. Computes proper mean/std across all GPUs
-    3. Updates the metrics object in-place (only on rank 0)
-    """  # noqa: E501
-    from hud.rl.distributed import get_local_rank, get_world_size, is_main_process
-
-    if get_world_size() <= 1:
-        return  # Nothing to aggregate in single GPU mode
-
-    # Default metrics that typically vary across GPUs
-    if metrics_to_aggregate is None:
-        metrics_to_aggregate = [
-            "training_time",
-            "samples_per_second",
-            "gpu_util",
-            "gpu_memory",
-            "grad_norm",
-            # Include core training scalars
-            "loss",
-            "kl",
-            "entropy",
-            "tokens",
-            "policy_ratio",
-        ]
-
-    # Collect current values from this rank
-    local_values = {}
-    for metric_name in metrics_to_aggregate:
-        if hasattr(metrics, metric_name):
-            metric_obj = getattr(metrics, metric_name)
-            # Get the last value if available, otherwise 0
-            local_values[metric_name] = metric_obj.values[-1] if metric_obj.values else 0.0
-
-    # Convert to tensor for distributed gathering
-    values_tensor = torch.tensor(
-        list(local_values.values()), device=f"cuda:{get_local_rank()}", dtype=torch.float32
-    )
-
-    # Gather from all ranks using NCCL-supported all_gather
-    world_size = get_world_size()
-    gather_list = [torch.zeros_like(values_tensor) for _ in range(world_size)]
-    torch.distributed.all_gather(gather_list, values_tensor)
-
-    # Update metrics on main process only
-    if is_main_process():
-        # Reshape: [num_gpus, num_metrics]
-        all_values = torch.stack(gather_list).cpu().numpy()
-
-        # Update each metric with aggregated values
-        for i, metric_name in enumerate(local_values.keys()):
-            metric_obj = getattr(metrics, metric_name)
-            gpu_values = all_values[:, i].tolist()
-
-            # Replace last value with cross-rank mean for reporting
-            if len(metric_obj.values) == 0:
-                metric_obj.values.append(0.0)
-            metric_obj.values[-1] = float(sum(gpu_values) / len(gpu_values))
-            # Recompute mean/std across history using updated last value
-            metric_obj.mean = float(sum(metric_obj.values) / len(metric_obj.values))
-            variance = sum((x - metric_obj.mean) ** 2 for x in metric_obj.values) / len(
-                metric_obj.values
-            )
-            metric_obj.std = float(variance**0.5)
-
-
-def b64_to_pil(b64_str: str) -> Image.Image:
-    """Convert base64 string to PIL Image."""
-    return Image.open(io.BytesIO(base64.b64decode(b64_str))).convert("RGB")
-
-
-def build_assistant_masks(
-    input_ids: list[list[int]],
-    tokenizer: Any,
-) -> list[list[int]]:
-    """
-    Build assistant masks from token IDs by finding assistant turns.
-
-    Args:
-        input_ids: List of token sequences
-        tokenizer: Tokenizer to decode tokens and get special token IDs
-        verbose: Whether to print verbose information
-
-    Returns:
-        List of binary masks indicating assistant tokens
-    """
-    id_im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
-    id_im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
-    id_assistant = tokenizer.convert_tokens_to_ids("assistant")
-
-    assistant_masks: list[list[int]] = []
-
-    for seq in input_ids:
-        mask = [0] * len(seq)
-        i_tok = 0
-        assistant_turn_count = 0
-
-        while i_tok < len(seq):
-            # Detect start of assistant turn
-            if (
-                seq[i_tok] == id_im_start
-                and i_tok + 1 < len(seq)
-                and seq[i_tok + 1] == id_assistant
-            ):
-                assistant_turn_count += 1
-
-                # Skip '<|im_start|>', 'assistant' and possible newline token
-                i_tok += 2
-                # Check for newline after 'assistant'
-                if i_tok < len(seq) and tokenizer.decode([seq[i_tok]]) == "\n":
-                    i_tok += 1
-
-                # Skip leading spaces after assistant\n
-                while i_tok < len(seq) and tokenizer.decode([seq[i_tok]]).strip() == "":
-                    i_tok += 1
-
-                assistant_content_start = i_tok
-
-                # Mark tokens until we hit <|im_end|>
-                content_end = i_tok
-                while i_tok < len(seq) and seq[i_tok] != id_im_end:
-                    content_end = i_tok + 1  # Track last non-<|im_end|> position
-                    mask[i_tok] = 1
-                    i_tok += 1
-
-                # Remove trailing spaces from the mask
-                while content_end > assistant_content_start:
-                    if (
-                        mask[content_end - 1] == 1
-                        and tokenizer.decode([seq[content_end - 1]]).strip() == ""
-                    ):
-                        mask[content_end - 1] = 0
-                        content_end -= 1
-                    else:
-                        break
-
-                # Skip the <|im_end|> token
-                i_tok += 1
-            else:
-                i_tok += 1
-
-        assistant_masks.append(mask)
-
-    return assistant_masks
-
-
-def prepare_conversation_history(
-    conversation_history: list[dict[str, Any]],
-) -> tuple[list[dict[str, Any]], list[Image.Image]]:
-    """Sanitize conversation history to avoid vLLM errors."""
-    sanitized_messages = []
-    images = []
-    for m in conversation_history:
-        if "tool_calls" in m:
-            m = {
-                "role": m["role"],
-                "content": m.get("content", ""),
-                "tool_calls": [
-                    tc.model_dump() if not isinstance(tc, dict) else tc
-                    for tc in m.get("tool_calls", [])
-                ],
-            }
-        elif m.get("role") == "user":
-            user_content = m.get("content", [])
-            for c in user_content:
-                if isinstance(c, dict) and c.get("type") == "image_url":
-                    image_url = c.get("image_url", {})
-                    url = image_url.get("url", "")
-                    if url.startswith("data:image"):
-                        data = url.split(",", 1)[1] if "," in url else url
-                        images.append(b64_to_pil(data))
-                    elif isinstance(data, bytes | bytearray):
-                        images.append(Image.open(io.BytesIO(data)).convert("RGB"))
-                    c = {"type": "image"}
-            m["content"] = user_content
-        sanitized_messages.append(m)
-    return sanitized_messages, images
-
-
-def prepare_inputs(trace: Trace, processor: Any) -> dict[str, torch.Tensor]:
-    """
-    Prepare inputs from a trace.
-
-    Args:
-        trace: Trace to process
-        processor: Model processor
-
-    Returns:
-        Inputs for the model
-    """
-    if len(trace.messages) == 0:
-        return {}
-
-    # Get images for current turn
-    conversation, images = prepare_conversation_history(trace.messages)
-
-    # Get absolute path to chat template
-    chat_template_path = Path(__file__).parent / "chat_template.jinja"
-
-    # For VL models, processor has a tokenizer attribute; for text models, processor IS tokenizer
-    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
-
-    text_list, _ = render_jinja_template(
-        conversations=[conversation],
-        chat_template=load_chat_template(str(chat_template_path)),
-        tools=trace.info["tool_spec"] if trace.info["tool_spec"] else None,  # mcp_tools
-        return_assistant_tokens_mask=True,
-        **tokenizer.special_tokens_map,
-    )
-    # For text models, don't pass images parameter
-    if hasattr(processor, "tokenizer"):
-        # VL model - processor accepts images
-        inputs = processor(
-            images=images if len(images) > 0 else None,
-            text=text_list,
-            return_offsets_mapping=False,  # we no longer need char offsets
-        )
-    else:
-        # Text model - processor is tokenizer, doesn't accept images
-        inputs = processor(
-            text=text_list,
-            return_offsets_mapping=False,  # we no longer need char offsets
-        )
-
-    assistant_masks = build_assistant_masks(inputs["input_ids"], tokenizer)
-    mask_tensor = torch.tensor(assistant_masks, dtype=torch.long)
-
-    # Ensure mask_tensor is 2D before slicing
-    if mask_tensor.dim() == 1:
-        mask_tensor = mask_tensor.unsqueeze(0)
-
-    # Slice to align with targets [B, T-1]
-    inputs["assistant_mask"] = mask_tensor[:, 1:].bool()
-
-    # Log amount of assistant tokens, and the first 10 tokens that are non 0, decoded
-    # assistant_batches = render_assistant_tokens(mask_tensor, inputs['input_ids'], processor)
-    inputs.convert_to_tensors(tensor_type="pt")
-
-    return inputs
-
-
-def render_assistant_tokens(
-    mask_tensor: torch.Tensor, input_ids: torch.Tensor, processor: Any
-) -> list[str]:
-    """Render assistant tokens as a list of continuous batches."""
-    # Get the mask as a 1D tensor
-    mask_1d = mask_tensor[0]
-
-    # Find continuous sequences of non-zero values
-    batches = []
-    start_idx = None
-
-    for i in range(len(mask_1d)):
-        if mask_1d[i] != 0 and start_idx is None:
-            # Start of a new batch
-            start_idx = i
-        elif mask_1d[i] == 0 and start_idx is not None:
-            # End of current batch
-            # Extract and decode the tokens in this batch
-            batch_token_ids = input_ids[0][start_idx:i].tolist()
-            decoded_batch = processor.decode(batch_token_ids)
-            batches.append(decoded_batch)
-            start_idx = None
-
-    # Handle case where the last batch extends to the end
-    if start_idx is not None:
-        batch_token_ids = input_ids[0][start_idx:].tolist()
-        decoded_batch = processor.decode(batch_token_ids)
-        batches.append(decoded_batch)
-
-    return batches
-
-
-def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
-    """Calculate entropy from logits in a memory-efficient way."""
-    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
-    entropy = -torch.sum(torch.exp(log_probs) * log_probs, dim=-1)
-    return entropy
-
-
-def preprocess_advantages(group: list[Trace], config: Config) -> list[TrainingSample]:
-    """Preprocess a group of traces."""
-    group_size = config.training.group_size
-    if config.training.batch_level == "group":
-        groups = [group[i : i + group_size] for i in range(0, len(group), group_size)]
-    elif config.training.batch_level == "batch":
-        groups = [group]
-    else:
-        raise ValueError(f"Invalid batch level: {config.training.batch_level}")
-
-    all_samples = []
-    for i, group in enumerate(groups):
-        rewards = np.array([trace.reward for trace in group])
-        mean_reward = np.mean(rewards)
-        std_reward = np.std(rewards)
-
-        # Calculate advantages
-        samples = [TrainingSample(**trace.model_dump()) for trace in group]
-        for sample, reward in zip(samples, rewards, strict=True):
-            if sample.isError:
-                sample.advantage = torch.Tensor(np.array([0.0]))
-                continue
-            # No std (non-baseline GRPO)
-            if config.training.no_std:
-                advantage_value = reward - mean_reward
-            else:
-                # Avoid division by zero
-                if std_reward < 1e-6:
-                    advantage_value = torch.Tensor(np.array([0.0]))
-                else:
-                    advantage_value = (reward - mean_reward) / std_reward
-            # Leave one out RLOO/LOOP
-            if config.training.leave_one_out:
-                advantage_value = advantage_value * len(group) / (len(group) - 1)
-            sample.advantage = torch.Tensor(np.array([advantage_value]))
-        hud_console.info_log(
-            f"Advantages for group {i} [{mean_reward:.4f} ± {std_reward:.4f}]:"
-            f"{[round(sample.advantage.item(), 4) for sample in samples if sample.advantage is not None]}"  # noqa: E501
-        )
-
-        all_samples.extend(samples)
-
-    return all_samples
-
-
-def batch_training_samples(samples: list[TrainingSample]) -> list[TrainingSample]:
-    """Create batched model inputs from a list of TrainingSample.
-
-    Pads token sequences to the maximum length in the list and zero-pads
-    images to the maximum H/W when present. Returns a dictionary of batched
-    tensors suitable for a single forward pass. Keeps assistant_masks for
-    masked scoring.
-    """
-    if not samples:
-        hud_console.warning("No samples to batch.")
-        return []
-
-    for s in samples:
-        if (
-            "assistant_mask" not in s.inputs
-            or s.inputs["assistant_mask"].sum() == 0
-            or s.advantage == 0.0
-        ) and len(samples) > 1:
-            hud_console.info("Removing sample with zero advantage.")
-            samples.remove(s)
-
-    if len(samples) == 1:
-        return samples
-
-    import torch.nn.functional as F
-
-    new_samples = [TrainingSample()]
-
-    input_keys_to_expand = ["input_ids", "attention_mask", "assistant_mask"]
-    input_keys_to_cat = ["pixel_values", "image_grid_thw"]
-    updated_inputs: dict[str, list[torch.Tensor]] = {
-        k: [] for k in input_keys_to_expand + input_keys_to_cat
-    }
-
-    # Sanity check dimensions
-    for s in samples:
-        for k in input_keys_to_expand + input_keys_to_cat:
-            val = s.inputs.get(k)
-            if val is not None:
-                if k in input_keys_to_expand:
-                    if val.dim() == 2 and val.size(0) == 1:
-                        val = val[0]
-                    elif val.dim() != 1:
-                        raise ValueError(f"{k} has unexpected dimensions: {val.shape}")
-                updated_inputs[k].append(val)
-
-    # Pad 1D sequences to max length
-    max_len = max(t.size(-1) for t in updated_inputs["input_ids"])
-
-    def pad_1d(x: torch.Tensor, pad_to: int, pad_value: int) -> torch.Tensor:
-        pad = pad_to - x.size(-1)
-        return F.pad(x, (0, pad), value=pad_value) if pad > 0 else x
-
-    stacked_inputs: dict[str, torch.Tensor] = {}
-    # These are 1D sequences that need padding
-    for k in input_keys_to_expand:
-        if updated_inputs[k]:
-            # assistant_mask is T-1, others are T
-            if k == "assistant_mask":
-                stacked_inputs[k] = torch.stack(
-                    [pad_1d(x, max_len - 1, 0) for x in updated_inputs[k]], dim=0
-                )
-            else:
-                stacked_inputs[k] = torch.stack(
-                    [pad_1d(x, max_len, 0) for x in updated_inputs[k]], dim=0
-                )
-
-    for k in input_keys_to_cat:
-        if updated_inputs[k]:
-            # pixel_values and image_grid_thw are concatenated across all images from all samples
-            # Shape of pixel_values: (sum of all patches from all images, feature_dim)
-            # Shape of image_grid_thw: (sum of all images, 3)
-            stacked_inputs[k] = torch.cat(updated_inputs[k], dim=0)
-        else:
-            stacked_inputs.pop(k)
-
-    new_samples[0].inputs = stacked_inputs
-
-    # Pad logprobs to max length before stacking
-    # old_logprobs and ref_logprobs have shape [seq_len] or [1, seq_len] after gathering
-    def pad_logprobs(logprobs: torch.Tensor | None, max_len: int) -> torch.Tensor:
-        # Always work with 1D tensor, squeeze batch dim if present
-        if logprobs is None:
-            return torch.tensor([float("-inf")], dtype=torch.float32)
-        if logprobs.dim() == 2 and logprobs.size(0) == 1:
-            logprobs = logprobs.squeeze(0)
-        elif logprobs.dim() != 1:
-            raise ValueError(
-                f"Expected logprobs to have 1 or 2 dimensions, got {logprobs.dim()} with shape {logprobs.shape}"  # noqa: E501
-            )
-
-        # Now logprobs is [seq_len]
-        seq_len = logprobs.size(0) if logprobs is not None else 0
-        if seq_len < max_len:
-            pad_size = max_len - seq_len
-            # Pad with -inf (log of 0 probability) along sequence dimension
-            return F.pad(logprobs, (0, pad_size), value=float("-inf"))
-        return logprobs
-
-    # Stack padded logprobs (these are T-1 length)
-    old_logprobs_list = [pad_logprobs(s.old_logprobs, max_len - 1) for s in samples]
-    ref_logprobs_list = [pad_logprobs(s.ref_logprobs, max_len - 1) for s in samples]
-
-    new_samples[0].old_logprobs = torch.stack(old_logprobs_list, dim=0)
-    new_samples[0].ref_logprobs = torch.stack(ref_logprobs_list, dim=0)
-
-    # Stack advantages, checking for None values
-    advantages = [s.advantage for s in samples]
-    if any(adv is None for adv in advantages):
-        raise ValueError(
-            "Some samples have None advantages. Make sure advantages are computed before batching."
-        )
-    new_samples[0].advantage = torch.stack(advantages, dim=0)  # type: ignore
-
-    return new_samples
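
For orientation, the math at the heart of the removed preprocess_advantages is standard group-relative (GRPO-style) advantage normalization: each trace's reward is centered on its group mean and, unless no_std is set, divided by the group standard deviation, with an optional leave-one-out rescaling of len(group) / (len(group) - 1). A minimal self-contained sketch of that computation, using hypothetical rewards and plain floats rather than the package's TrainingSample tensors:

# Sketch of the removed GRPO-style advantage normalization.
# Rewards are hypothetical; the real code wrapped results in torch tensors.
import numpy as np

def group_advantages(
    rewards: list[float],
    no_std: bool = False,
    leave_one_out: bool = False,
) -> list[float]:
    r = np.asarray(rewards, dtype=np.float64)
    mean, std = r.mean(), r.std()
    if no_std:
        adv = r - mean                     # non-baseline GRPO: center only
    elif std < 1e-6:
        adv = np.zeros_like(r)             # degenerate group: all rewards equal
    else:
        adv = (r - mean) / std             # standard group normalization
    if leave_one_out:
        adv = adv * len(r) / (len(r) - 1)  # RLOO/LOOP correction
    return adv.tolist()

print(group_advantages([1.0, 0.0, 0.0, 1.0]))  # -> [1.0, -1.0, -1.0, 1.0]
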
hud/rl/vllm_adapter.py DELETED
@@ -1,143 +0,0 @@
-"""vLLM adapter management for LoRA hot-swapping."""
-
-from __future__ import annotations
-
-import json
-import logging
-
-import requests
-
-from hud.utils.hud_console import HUDConsole
-
-hud_console = HUDConsole(logging.getLogger(__name__))
-
-
-class VLLMAdapter:
-    """Manages LoRA adapter loading/unloading in vLLM."""
-
-    def __init__(self, base_url: str, api_key: str) -> None:
-        self.base_url = base_url
-        self.api_key = api_key
-        self.current_adapter = None
-
-    def load_adapter(self, adapter_name: str, adapter_path: str, timeout: int = 30) -> bool:
-        """
-        Hot-load a LoRA adapter to vLLM.
-
-        Args:
-            adapter_name: Name to register the adapter as
-            adapter_path: Path to the adapter checkpoint
-            timeout: Request timeout in seconds
-
-        Returns:
-            True if successful, False otherwise
-        """
-        url = f"{self.base_url}/load_lora_adapter"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        payload = {"lora_name": adapter_name, "lora_path": adapter_path}
-        # Implement exponential backoff for retrying the adapter load request.
-        max_retries = 8
-        backoff_factor = 2
-        delay = 1  # initial delay in seconds
-
-        for attempt in range(1, max_retries + 1):
-            try:
-                response = requests.post(
-                    url, headers=headers, data=json.dumps(payload), timeout=timeout
-                )
-                response.raise_for_status()
-
-                self.current_adapter = adapter_name
-                hud_console.info(f"[VLLMAdapter] Loaded adapter: {adapter_name}")
-                return True
-
-            except requests.exceptions.RequestException as e:
-                if attempt == max_retries:
-                    hud_console.error(
-                        f"[VLLMAdapter] Failed to load adapter {adapter_name} after {attempt} attempts: {e}"  # noqa: E501
-                    )
-                    return False
-                else:
-                    hud_console.warning(
-                        f"[VLLMAdapter] Load adapter {adapter_name} failed (attempt {attempt}/{max_retries}): {e}. Retrying in {delay} seconds...",  # noqa: E501
-                    )
-                    import time
-
-                    time.sleep(delay)
-                    delay *= backoff_factor
-
-        return False
-
-    def unload_adapter(self, adapter_name: str) -> bool:
-        """
-        Unload a LoRA adapter from vLLM.
-
-        Args:
-            adapter_name: Name of the adapter to unload
-
-        Returns:
-            True if successful, False otherwise
-        """
-        url = f"{self.base_url}/unload_lora_adapter"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        payload = {"lora_name": adapter_name}
-
-        try:
-            response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=30)
-            response.raise_for_status()
-
-            if self.current_adapter == adapter_name:
-                self.current_adapter = None
-
-            hud_console.info(f"[VLLMAdapter] Unloaded adapter: {adapter_name}")
-            return True
-
-        except requests.exceptions.RequestException as e:
-            hud_console.error(f"[VLLMAdapter] Failed to unload adapter {adapter_name}: {e}")
-            return False
-
-    def list_adapters(self) -> list | None:
-        """
-        List all loaded LoRA adapters in vLLM.
-
-        Returns:
-            List of adapter names, or None if failed
-        """
-        url = f"{self.base_url}/list_lora_adapters"
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-
-        try:
-            response = requests.get(url, headers=headers, timeout=10)
-            response.raise_for_status()
-            return response.json().get("adapters", [])
-
-        except requests.exceptions.RequestException as e:
-            hud_console.error(f"[VLLMAdapter] Failed to list adapters: {e}")
-            return None
-
-    def get_current(self) -> str | None:
-        """Get the name of the currently loaded adapter."""
-        return self.current_adapter
-
-
-# Convenience function for standalone use
-def hotload_lora(
-    adapter_name: str,
-    adapter_path: str,
-    base_url: str = "http://localhost:8000/v1",
-    api_key: str = "token-abc123",
-) -> bool:
-    """
-    Quick function to hot-load a LoRA adapter.
-
-    Args:
-        adapter_name: Name for the adapter
-        adapter_path: Path to adapter checkpoint
-        base_url: vLLM server URL
-        api_key: API key for vLLM
-
-    Returns:
-        True if successful
-    """
-    adapter = VLLMAdapter(base_url, api_key)
-    return adapter.load_adapter(adapter_name, adapter_path)
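
As a usage note, the removed module was self-contained enough to drive directly against a running vLLM server. A hedged sketch of how it could be invoked, valid only for hud-python <= 0.4.45; the adapter name and checkpoint path are placeholders, and the URL and API key are just the defaults from the source above:

# Hypothetical driver for the removed VLLMAdapter / hotload_lora helpers.
from hud.rl.vllm_adapter import VLLMAdapter, hotload_lora  # removed in 0.5.x

adapter = VLLMAdapter(base_url="http://localhost:8000/v1", api_key="token-abc123")

# Hot-swap a freshly trained LoRA checkpoint into the running server;
# load_adapter retries with exponential backoff on request failures.
if adapter.load_adapter("step-100", "/checkpoints/step-100"):
    print("serving:", adapter.get_current())    # -> "step-100"
    print("loaded:", adapter.list_adapters())

# Or the one-shot convenience wrapper with the same defaults:
hotload_lora("step-100", "/checkpoints/step-100")

adapter.unload_adapter("step-100")
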