hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (282)
  1. hud/__init__.py +27 -7
  2. hud/agents/__init__.py +70 -5
  3. hud/agents/base.py +238 -500
  4. hud/agents/claude.py +236 -247
  5. hud/agents/gateway.py +42 -0
  6. hud/agents/gemini.py +264 -0
  7. hud/agents/gemini_cua.py +324 -0
  8. hud/agents/grounded_openai.py +98 -100
  9. hud/agents/misc/integration_test_agent.py +51 -20
  10. hud/agents/misc/response_agent.py +48 -36
  11. hud/agents/openai.py +282 -296
  12. hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
  13. hud/agents/operator.py +199 -0
  14. hud/agents/resolver.py +70 -0
  15. hud/agents/tests/conftest.py +133 -0
  16. hud/agents/tests/test_base.py +300 -622
  17. hud/agents/tests/test_base_runtime.py +233 -0
  18. hud/agents/tests/test_claude.py +381 -214
  19. hud/agents/tests/test_client.py +9 -10
  20. hud/agents/tests/test_gemini.py +369 -0
  21. hud/agents/tests/test_grounded_openai_agent.py +65 -50
  22. hud/agents/tests/test_openai.py +377 -140
  23. hud/agents/tests/test_operator.py +362 -0
  24. hud/agents/tests/test_resolver.py +192 -0
  25. hud/agents/tests/test_run_eval.py +179 -0
  26. hud/agents/types.py +148 -0
  27. hud/cli/__init__.py +493 -546
  28. hud/cli/analyze.py +43 -5
  29. hud/cli/build.py +699 -113
  30. hud/cli/debug.py +8 -5
  31. hud/cli/dev.py +889 -732
  32. hud/cli/eval.py +793 -667
  33. hud/cli/flows/dev.py +167 -0
  34. hud/cli/flows/init.py +191 -0
  35. hud/cli/flows/tasks.py +153 -56
  36. hud/cli/flows/templates.py +151 -0
  37. hud/cli/flows/tests/__init__.py +1 -0
  38. hud/cli/flows/tests/test_dev.py +126 -0
  39. hud/cli/init.py +60 -58
  40. hud/cli/pull.py +1 -1
  41. hud/cli/push.py +38 -13
  42. hud/cli/rft.py +311 -0
  43. hud/cli/rft_status.py +145 -0
  44. hud/cli/tests/test_analyze.py +5 -5
  45. hud/cli/tests/test_analyze_metadata.py +3 -2
  46. hud/cli/tests/test_analyze_module.py +120 -0
  47. hud/cli/tests/test_build.py +110 -8
  48. hud/cli/tests/test_build_failure.py +41 -0
  49. hud/cli/tests/test_build_module.py +50 -0
  50. hud/cli/tests/test_cli_init.py +6 -1
  51. hud/cli/tests/test_cli_more_wrappers.py +30 -0
  52. hud/cli/tests/test_cli_root.py +140 -0
  53. hud/cli/tests/test_convert.py +361 -0
  54. hud/cli/tests/test_debug.py +12 -10
  55. hud/cli/tests/test_dev.py +197 -0
  56. hud/cli/tests/test_eval.py +251 -0
  57. hud/cli/tests/test_eval_bedrock.py +51 -0
  58. hud/cli/tests/test_init.py +124 -0
  59. hud/cli/tests/test_main_module.py +11 -5
  60. hud/cli/tests/test_mcp_server.py +12 -100
  61. hud/cli/tests/test_push.py +1 -1
  62. hud/cli/tests/test_push_happy.py +74 -0
  63. hud/cli/tests/test_push_wrapper.py +23 -0
  64. hud/cli/tests/test_registry.py +1 -1
  65. hud/cli/tests/test_utils.py +1 -1
  66. hud/cli/{rl → utils}/celebrate.py +14 -12
  67. hud/cli/utils/config.py +18 -1
  68. hud/cli/utils/docker.py +130 -4
  69. hud/cli/utils/env_check.py +9 -9
  70. hud/cli/utils/git.py +136 -0
  71. hud/cli/utils/interactive.py +39 -5
  72. hud/cli/utils/metadata.py +70 -1
  73. hud/cli/utils/runner.py +1 -1
  74. hud/cli/utils/server.py +2 -2
  75. hud/cli/utils/source_hash.py +3 -3
  76. hud/cli/utils/tasks.py +4 -1
  77. hud/cli/utils/tests/__init__.py +0 -0
  78. hud/cli/utils/tests/test_config.py +58 -0
  79. hud/cli/utils/tests/test_docker.py +93 -0
  80. hud/cli/utils/tests/test_docker_hints.py +71 -0
  81. hud/cli/utils/tests/test_env_check.py +74 -0
  82. hud/cli/utils/tests/test_environment.py +42 -0
  83. hud/cli/utils/tests/test_git.py +142 -0
  84. hud/cli/utils/tests/test_interactive_module.py +60 -0
  85. hud/cli/utils/tests/test_local_runner.py +50 -0
  86. hud/cli/utils/tests/test_logging_utils.py +23 -0
  87. hud/cli/utils/tests/test_metadata.py +49 -0
  88. hud/cli/utils/tests/test_package_runner.py +35 -0
  89. hud/cli/utils/tests/test_registry_utils.py +49 -0
  90. hud/cli/utils/tests/test_remote_runner.py +25 -0
  91. hud/cli/utils/tests/test_runner_modules.py +52 -0
  92. hud/cli/utils/tests/test_source_hash.py +36 -0
  93. hud/cli/utils/tests/test_tasks.py +80 -0
  94. hud/cli/utils/version_check.py +258 -0
  95. hud/cli/{rl → utils}/viewer.py +2 -2
  96. hud/clients/README.md +12 -11
  97. hud/clients/__init__.py +4 -3
  98. hud/clients/base.py +166 -26
  99. hud/clients/environment.py +51 -0
  100. hud/clients/fastmcp.py +13 -6
  101. hud/clients/mcp_use.py +45 -15
  102. hud/clients/tests/test_analyze_scenarios.py +206 -0
  103. hud/clients/tests/test_protocol.py +9 -3
  104. hud/datasets/__init__.py +23 -20
  105. hud/datasets/loader.py +326 -0
  106. hud/datasets/runner.py +198 -105
  107. hud/datasets/tests/__init__.py +0 -0
  108. hud/datasets/tests/test_loader.py +221 -0
  109. hud/datasets/tests/test_utils.py +315 -0
  110. hud/datasets/utils.py +270 -90
  111. hud/environment/__init__.py +52 -0
  112. hud/environment/connection.py +258 -0
  113. hud/environment/connectors/__init__.py +33 -0
  114. hud/environment/connectors/base.py +68 -0
  115. hud/environment/connectors/local.py +177 -0
  116. hud/environment/connectors/mcp_config.py +137 -0
  117. hud/environment/connectors/openai.py +101 -0
  118. hud/environment/connectors/remote.py +172 -0
  119. hud/environment/environment.py +835 -0
  120. hud/environment/integrations/__init__.py +45 -0
  121. hud/environment/integrations/adk.py +67 -0
  122. hud/environment/integrations/anthropic.py +196 -0
  123. hud/environment/integrations/gemini.py +92 -0
  124. hud/environment/integrations/langchain.py +82 -0
  125. hud/environment/integrations/llamaindex.py +68 -0
  126. hud/environment/integrations/openai.py +238 -0
  127. hud/environment/mock.py +306 -0
  128. hud/environment/router.py +263 -0
  129. hud/environment/scenarios.py +620 -0
  130. hud/environment/tests/__init__.py +1 -0
  131. hud/environment/tests/test_connection.py +317 -0
  132. hud/environment/tests/test_connectors.py +205 -0
  133. hud/environment/tests/test_environment.py +593 -0
  134. hud/environment/tests/test_integrations.py +257 -0
  135. hud/environment/tests/test_local_connectors.py +242 -0
  136. hud/environment/tests/test_scenarios.py +1086 -0
  137. hud/environment/tests/test_tools.py +208 -0
  138. hud/environment/types.py +23 -0
  139. hud/environment/utils/__init__.py +35 -0
  140. hud/environment/utils/formats.py +215 -0
  141. hud/environment/utils/schema.py +171 -0
  142. hud/environment/utils/tool_wrappers.py +113 -0
  143. hud/eval/__init__.py +67 -0
  144. hud/eval/context.py +727 -0
  145. hud/eval/display.py +299 -0
  146. hud/eval/instrument.py +187 -0
  147. hud/eval/manager.py +533 -0
  148. hud/eval/parallel.py +268 -0
  149. hud/eval/task.py +372 -0
  150. hud/eval/tests/__init__.py +1 -0
  151. hud/eval/tests/test_context.py +178 -0
  152. hud/eval/tests/test_eval.py +210 -0
  153. hud/eval/tests/test_manager.py +152 -0
  154. hud/eval/tests/test_parallel.py +168 -0
  155. hud/eval/tests/test_task.py +291 -0
  156. hud/eval/types.py +65 -0
  157. hud/eval/utils.py +194 -0
  158. hud/patches/__init__.py +19 -0
  159. hud/patches/mcp_patches.py +308 -0
  160. hud/patches/warnings.py +54 -0
  161. hud/samples/browser.py +4 -4
  162. hud/server/__init__.py +2 -1
  163. hud/server/low_level.py +2 -1
  164. hud/server/router.py +164 -0
  165. hud/server/server.py +567 -80
  166. hud/server/tests/test_mcp_server_integration.py +11 -11
  167. hud/server/tests/test_mcp_server_more.py +1 -1
  168. hud/server/tests/test_server_extra.py +2 -0
  169. hud/settings.py +45 -3
  170. hud/shared/exceptions.py +36 -10
  171. hud/shared/hints.py +26 -1
  172. hud/shared/requests.py +15 -3
  173. hud/shared/tests/test_exceptions.py +40 -31
  174. hud/shared/tests/test_hints.py +167 -0
  175. hud/telemetry/__init__.py +20 -19
  176. hud/telemetry/exporter.py +201 -0
  177. hud/telemetry/instrument.py +165 -253
  178. hud/telemetry/tests/test_eval_telemetry.py +356 -0
  179. hud/telemetry/tests/test_exporter.py +258 -0
  180. hud/telemetry/tests/test_instrument.py +401 -0
  181. hud/tools/__init__.py +18 -2
  182. hud/tools/agent.py +223 -0
  183. hud/tools/apply_patch.py +639 -0
  184. hud/tools/base.py +54 -4
  185. hud/tools/bash.py +2 -2
  186. hud/tools/computer/__init__.py +36 -3
  187. hud/tools/computer/anthropic.py +2 -2
  188. hud/tools/computer/gemini.py +385 -0
  189. hud/tools/computer/hud.py +23 -6
  190. hud/tools/computer/openai.py +20 -21
  191. hud/tools/computer/qwen.py +434 -0
  192. hud/tools/computer/settings.py +37 -0
  193. hud/tools/edit.py +3 -7
  194. hud/tools/executors/base.py +4 -2
  195. hud/tools/executors/pyautogui.py +1 -1
  196. hud/tools/grounding/grounded_tool.py +13 -18
  197. hud/tools/grounding/grounder.py +10 -31
  198. hud/tools/grounding/tests/test_grounded_tool.py +26 -44
  199. hud/tools/jupyter.py +330 -0
  200. hud/tools/playwright.py +18 -3
  201. hud/tools/shell.py +308 -0
  202. hud/tools/tests/test_agent_tool.py +355 -0
  203. hud/tools/tests/test_apply_patch.py +718 -0
  204. hud/tools/tests/test_computer.py +4 -9
  205. hud/tools/tests/test_computer_actions.py +24 -2
  206. hud/tools/tests/test_jupyter_tool.py +181 -0
  207. hud/tools/tests/test_shell.py +596 -0
  208. hud/tools/tests/test_submit.py +85 -0
  209. hud/tools/tests/test_types.py +193 -0
  210. hud/tools/types.py +21 -1
  211. hud/types.py +194 -56
  212. hud/utils/__init__.py +2 -0
  213. hud/utils/env.py +67 -0
  214. hud/utils/hud_console.py +89 -18
  215. hud/utils/mcp.py +15 -58
  216. hud/utils/strict_schema.py +162 -0
  217. hud/utils/tests/test_init.py +1 -2
  218. hud/utils/tests/test_mcp.py +1 -28
  219. hud/utils/tests/test_pretty_errors.py +186 -0
  220. hud/utils/tests/test_tool_shorthand.py +154 -0
  221. hud/utils/tests/test_version.py +1 -1
  222. hud/utils/types.py +20 -0
  223. hud/version.py +1 -1
  224. hud_python-0.5.13.dist-info/METADATA +264 -0
  225. hud_python-0.5.13.dist-info/RECORD +305 -0
  226. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
  227. hud/agents/langchain.py +0 -261
  228. hud/agents/lite_llm.py +0 -72
  229. hud/cli/rl/__init__.py +0 -180
  230. hud/cli/rl/config.py +0 -101
  231. hud/cli/rl/display.py +0 -133
  232. hud/cli/rl/gpu.py +0 -63
  233. hud/cli/rl/gpu_utils.py +0 -321
  234. hud/cli/rl/local_runner.py +0 -595
  235. hud/cli/rl/presets.py +0 -96
  236. hud/cli/rl/remote_runner.py +0 -463
  237. hud/cli/rl/rl_api.py +0 -150
  238. hud/cli/rl/vllm.py +0 -177
  239. hud/cli/rl/wait_utils.py +0 -89
  240. hud/datasets/parallel.py +0 -687
  241. hud/misc/__init__.py +0 -1
  242. hud/misc/claude_plays_pokemon.py +0 -292
  243. hud/otel/__init__.py +0 -35
  244. hud/otel/collector.py +0 -142
  245. hud/otel/config.py +0 -181
  246. hud/otel/context.py +0 -570
  247. hud/otel/exporters.py +0 -369
  248. hud/otel/instrumentation.py +0 -135
  249. hud/otel/processors.py +0 -121
  250. hud/otel/tests/__init__.py +0 -1
  251. hud/otel/tests/test_processors.py +0 -197
  252. hud/rl/README.md +0 -30
  253. hud/rl/__init__.py +0 -1
  254. hud/rl/actor.py +0 -176
  255. hud/rl/buffer.py +0 -405
  256. hud/rl/chat_template.jinja +0 -101
  257. hud/rl/config.py +0 -192
  258. hud/rl/distributed.py +0 -132
  259. hud/rl/learner.py +0 -637
  260. hud/rl/tests/__init__.py +0 -1
  261. hud/rl/tests/test_learner.py +0 -186
  262. hud/rl/train.py +0 -382
  263. hud/rl/types.py +0 -101
  264. hud/rl/utils/start_vllm_server.sh +0 -30
  265. hud/rl/utils.py +0 -524
  266. hud/rl/vllm_adapter.py +0 -143
  267. hud/telemetry/job.py +0 -352
  268. hud/telemetry/replay.py +0 -74
  269. hud/telemetry/tests/test_replay.py +0 -40
  270. hud/telemetry/tests/test_trace.py +0 -63
  271. hud/telemetry/trace.py +0 -158
  272. hud/utils/agent_factories.py +0 -86
  273. hud/utils/async_utils.py +0 -65
  274. hud/utils/group_eval.py +0 -223
  275. hud/utils/progress.py +0 -149
  276. hud/utils/tasks.py +0 -127
  277. hud/utils/tests/test_async_utils.py +0 -173
  278. hud/utils/tests/test_progress.py +0 -261
  279. hud_python-0.4.45.dist-info/METADATA +0 -552
  280. hud_python-0.4.45.dist-info/RECORD +0 -228
  281. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
  282. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/rl/learner.py DELETED
@@ -1,637 +0,0 @@
-"""GRPO learner for vision-language and text models."""
-
-from __future__ import annotations
-
-import logging
-import os
-from typing import TYPE_CHECKING, Any
-
-import torch
-from peft import LoraConfig, get_peft_model
-from torch.nn.parallel import DistributedDataParallel as DDP
-from transformers import (
-    AutoModelForCausalLM,
-    AutoProcessor,
-    AutoTokenizer,
-    Qwen2_5_VLForConditionalGeneration,
-)
-
-try:
-    from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl  # type: ignore
-
-    LIGER_AVAILABLE = True
-except ImportError:
-    LIGER_AVAILABLE = False
-
-try:
-    import bitsandbytes as bnb  # type: ignore
-
-    BNB_AVAILABLE = True
-except ImportError:
-    BNB_AVAILABLE = False
-
-from contextlib import nullcontext
-
-from hud.rl.distributed import (
-    get_local_rank,
-    get_world_size,
-    is_main_process,
-)
-from hud.rl.utils import (
-    batch_training_samples,
-    entropy_from_logits,
-    get_gpu_utilization,
-    get_memory_usage,
-    prepare_inputs,
-)
-from hud.utils.hud_console import HUDConsole
-
-from .types import TrainingMetrics, TrainingSample
-
-logger = logging.getLogger(__name__)
-hud_console = HUDConsole(logger)
-
-if TYPE_CHECKING:
-    from .config import Config
-
-
-class GRPOLearner:
-    """GRPO learning algorithm for Vision-Language Models (VLMs) and Text Models."""
-
-    def __init__(self, config: Config) -> None:
-        self.config = config
-        self.local_rank = get_local_rank()
-        self.world_size = get_world_size()
-        self.device = torch.device(
-            f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu"
-        )
-
-        # Detect model type
-        self.is_vl_model = "VL" in config.model.base_model
-
-        # Load models and processor
-        self.processor, self.policy, self.ref, self.optimizer = self._load_models()
-        self.metrics: list[TrainingMetrics] = []
-
-    def log(self, message: str) -> None:
-        hud_console.info_log(f"[{self.local_rank}] {message}")
-
-    def _load_models(self) -> tuple[Any, Any, Any, Any]:
-        """Load policy, reference models and optimizer."""
-        model_cfg = self.config.model
-
-        # Detect if this is a VL model or standard text model
-        is_vl_model = "VL" in model_cfg.base_model
-        model_type = "Vision-Language" if is_vl_model else "Text"
-        self.log(f"Loading {model_type} model: {model_cfg.base_model}")
-
-        # Apply Liger kernel optimizations if available and enabled
-        if model_cfg.use_liger and LIGER_AVAILABLE:
-            if is_vl_model:
-                self.log("Applying Liger kernel optimizations to Qwen2.5-VL")
-                apply_liger_kernel_to_qwen2_5_vl(
-                    rope=True,  # Optimized RoPE
-                    rms_norm=True,  # Optimized RMSNorm
-                    swiglu=True,  # Optimized SwiGLU
-                    fused_linear_cross_entropy=True,  # Fused Linear+CrossEntropy for memory
-                )
-        elif model_cfg.use_liger and not LIGER_AVAILABLE:
-            self.log(
-                "Liger kernel requested but not installed. Install with: pip install liger-kernel"
-            )
-
-        # Load processor/tokenizer based on model type
-        if is_vl_model:
-            # Some environments require remote code for Qwen2.5-VL processors
-            processor = AutoProcessor.from_pretrained(
-                model_cfg.base_model,
-                min_pixels=model_cfg.min_pixels,
-                max_pixels=model_cfg.max_pixels,
-                trust_remote_code=True,
-            )
-        else:
-            processor = AutoTokenizer.from_pretrained(model_cfg.base_model)
-
-        # Load policy model with LoRA
-        # Use attention implementation from config
-        attn_implementation = model_cfg.attn_implementation
-
-        # Choose the appropriate model class
-        model_class = Qwen2_5_VLForConditionalGeneration if is_vl_model else AutoModelForCausalLM
-
-        try:
-            policy = model_class.from_pretrained(
-                model_cfg.base_model,
-                torch_dtype=torch.bfloat16,
-                attn_implementation=attn_implementation,
-                trust_remote_code=True,
-            )
-            self.log(f"Using {attn_implementation} for attention")
-        except (ImportError, ValueError) as e:
-            # Only fallback if explicitly using flash_attention_2 and it's not available
-            if attn_implementation == "flash_attention_2":
-                self.log(f"Flash Attention 2 not available ({e}), using eager attention")
-                policy = model_class.from_pretrained(
-                    model_cfg.base_model,
-                    torch_dtype=torch.bfloat16,
-                    attn_implementation="eager",
-                )
-            else:
-                raise  # Re-raise if it's a different error
-
-        # Move model to device
-        policy = policy.to(self.device)  # type: ignore
-        # Enable gradient checkpointing for memory efficiency
-        if model_cfg.gradient_checkpointing:
-            policy.gradient_checkpointing_enable()
-            self.log("Gradient checkpointing enabled for memory efficiency")
-
-        # Add LoRA adapters
-        lora_config = LoraConfig(
-            r=model_cfg.lora_r,
-            lora_alpha=model_cfg.lora_alpha,
-            lora_dropout=model_cfg.lora_dropout,
-            task_type="CAUSAL_LM",
-            bias="none",
-            target_modules=list(model_cfg.target_modules),
-        )
-        policy.config.use_cache = False
-        policy = get_peft_model(policy, lora_config)
-
-        # Wrap with DDP if in distributed mode
-        if self.world_size > 1:
-            policy = DDP(
-                policy,
-                device_ids=[self.local_rank],
-                output_device=self.local_rank,
-                broadcast_buffers=False,
-                find_unused_parameters=True,
-            )
-            self.log("Wrapped model (find_unused_parameters=True)")
-
-        # Create optimizer - need to access underlying model if DDP
-        base_model = policy.module if hasattr(policy, "module") else policy
-        trainable_params = [p for _, p in base_model.named_parameters() if p.requires_grad]  # type: ignore
-
-        # Use 8-bit optimizer if configured
-        if self.config.training.use_8bit_optimizer and BNB_AVAILABLE:
-            hud_console.info("Using 8-bit AdamW optimizer from bitsandbytes")
-            optimizer = bnb.optim.AdamW8bit(
-                trainable_params,
-                lr=self.config.training.lr,
-                betas=self.config.training.adam_betas,
-                eps=self.config.training.adam_eps,
-            )
-        else:
-            self.log("Using standard FP32 AdamW optimizer")
-            optimizer = torch.optim.AdamW(
-                trainable_params,
-                lr=self.config.training.lr,
-                betas=self.config.training.adam_betas,
-                eps=self.config.training.adam_eps,
-            )
-
-        # Log optimizer info
-        self.log(f"Optimizer: {type(optimizer).__name__}")
-        num_params = sum(p.numel() for p in trainable_params)
-        self.log(f"Number of trainable parameters: {num_params:,}")
-
-        return processor, policy, None, optimizer
-
-    def prepare_groups(
-        self,
-        samples: list[TrainingSample],
-    ) -> list[list[TrainingSample]]:
-        """Prepare groups of samples for training."""
-        # Prepare inputs with messages
-        batch = []
-        for sample in samples:
-            inputs = prepare_inputs(sample, self.processor)
-            # If inputs are invalid, create dummy inputs to maintain batch size
-            if (
-                not inputs
-                or "input_ids" not in inputs
-                or inputs.get("input_ids", torch.tensor([])).numel() == 0
-            ):
-                hud_console.warning_log("Sample has invalid inputs, using dummy values")
-                # Create minimal dummy inputs to keep batch size consistent
-                inputs = {
-                    "input_ids": torch.zeros(1, 2, dtype=torch.long),  # Minimal sequence
-                    "attention_mask": torch.ones(1, 2, dtype=torch.long),
-                    "assistant_mask": torch.zeros(1, 1, dtype=torch.bool),  # T-1 length
-                }
-            elif "assistant_mask" not in inputs:
-                hud_console.warning_log("Sample missing assistant_mask, creating zero mask")
-                seq_len = inputs["input_ids"].shape[-1]
-                inputs["assistant_mask"] = torch.zeros(
-                    inputs["input_ids"].shape[0], seq_len - 1, dtype=torch.bool
-                )
-
-            new_sample = TrainingSample(**sample.model_dump())
-            new_sample.inputs = inputs
-            new_sample.advantage = sample.advantage
-            batch.append(new_sample)
-
-        with hud_console.progress("Processing batch of traces...") as progress, torch.no_grad():
-            for i, sample in enumerate(batch):
-                if is_main_process():
-                    progress.update(f"Processing batch of traces... {i}/{len(batch)}")
-                if sample.inputs:
-                    sample = sample.to_device(self.device)
-                    sample.old_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
-                    # Free GPU memory for this sample immediately
-                    sample.to_device(torch.device("cpu"))
-
-            policy_module = self.policy.module if hasattr(self.policy, "module") else self.policy
-            with policy_module.disable_adapter():
-                for i, sample in enumerate(batch):
-                    if is_main_process():
-                        progress.update(f"Processing batch of traces... {i}/{len(batch)}")
-                    if sample.inputs:
-                        # Move back to GPU for reference computation, then free
-                        sample = sample.to_device(self.device)
-                        sample.ref_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs)
-                        sample.to_device(torch.device("cpu"))
-
-        hud_console.info_log("Creating mini-batches...")
-        group_size = self.config.training.group_size
-        processed_batch = []
-        if not self.config.training.accumulate_over_minibatches:
-            # Find minibatches and group them via batch_training_samples
-            # Minibatches control the batch size of the forward pass to the model
-            mb_size = self.config.training.mini_batch_size
-            group_size = group_size // mb_size
-            for i in range(0, len(batch), mb_size):
-                processed_batch.extend(batch_training_samples(batch[i : i + mb_size]))
-        else:
-            processed_batch = batch
-
-        for sample in processed_batch:
-            sample.to_device(torch.device("cpu"))
-
-        # Convert to grouped batches (if updating the model after each task group)
-        if self.config.training.update_after_group:
-            return [
-                processed_batch[i : i + group_size]
-                for i in range(0, len(processed_batch), group_size)
-            ]
-        else:
-            return [processed_batch]
-
-    def update(self, samples: list[TrainingSample]) -> TrainingMetrics:
-        """Perform a gradient update on a batch."""
-        import time
-
-        training_start_time = time.time()
-
-        # Always create metrics for synchronization
-        self.metrics.append(TrainingMetrics())
-        metrics = self.metrics[-1]
-
-        # Prepare groups for GRPO training
-        groups = self.prepare_groups(samples)
-        self.log(f"Updating over {len(groups)} groups")
-
-        # Update over mini batch size
-        with hud_console.progress("Gradient update...") as progress:
-            for epoch in range(self.config.training.epochs):  # Do not accumulate across epochs
-                progress.update(f"Training epoch {epoch + 1}/{self.config.training.epochs}")
-                for group_idx, group in enumerate(groups):  # Do not accumulate across "groups"
-                    self.optimizer.zero_grad(set_to_none=True)
-
-                    debug_per_group = ""
-                    grad_accum_steps = len(group)
-                    # Tensor for distributed sync
-                    global_skip = torch.zeros(1, device=self.device)
-
-                    for s_idx, sample_minibatch in enumerate(group):
-                        # self.log(f"{group_idx} {sample_minibatch.inputs['assistant_mask'].sum()}")
-                        # mini_updated = sample_minibatch.inputs["assistant_mask"].sum() > 0
-
-                        # Update mini_updated globally
-                        # self.log(f"{group_idx} Mini updated: {mini_updated}")
-
-                        # Do not sync until the last minibatch
-                        if s_idx < len(group) - 1 and self.world_size > 1:
-                            ddp_ctx = self.policy.no_sync()
-                        else:
-                            ddp_ctx = nullcontext()
-
-                        with ddp_ctx, torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-                            try:
-                                # if mini_updated:
-                                loss = self.compute_loss(sample_minibatch) / grad_accum_steps
-                                debug_per_group += f"l{s_idx}:{round(loss.item(), 3)!s} "
-                                loss.backward()
-                                # else:  # Dummy backward that touches all params, produces zero g
-                                # dummy = sum(p.sum() for p in self.policy.parameters()) * 0.0
-                                # debug_per_group += f"d{s_idx}:{str(round(dummy.item(), 3))} "
-                                # dummy.backward()
-                                # self.log(f"{group_idx} GPU Backward: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")  # noqa: E501
-                            except torch.cuda.OutOfMemoryError:
-                                hud_console.warning_log(
-                                    f"{group_idx} CUDA OOM for {sample_minibatch.inputs['input_ids'].numel()} tokens; skipping minibatch"  # noqa: E501
-                                )
-                                # Dummy backward to keep DDP happy
-                                dummy = torch.sum(p.sum() for p in self.policy.parameters()) * 0.0  # type: ignore
-                                debug_per_group += f"o{s_idx}:{round(dummy.item(), 3)!s} "
-                                dummy.backward()
-                                # mark global skip if OOM
-                                global_skip.fill_(1)
-                                continue
-
-                        if torch.cuda.is_available():
-                            torch.cuda.empty_cache()
-
-                    # After minibatches loop, sync skip across ranks
-                    if torch.distributed.is_initialized():
-                        torch.distributed.all_reduce(global_skip, op=torch.distributed.ReduceOp.MAX)
-                    skip_any = bool(global_skip.item())
-
-                    if skip_any:
-                        self.log(f"G[{group_idx}] {debug_per_group} N/A (skipped)")
-                        continue
-
-                    grad_norm = torch.nn.utils.clip_grad_norm_(
-                        self.policy.parameters(),
-                        self.config.training.grad_clip,
-                        error_if_nonfinite=True,
-                    )
-                    self.optimizer.step()
-
-                    debug_per_group += f"g:{round(grad_norm.item(), 3)!s}"
-                    self.log(f"G[{group_idx}] {debug_per_group}")
-
-                    metrics.update(
-                        {
-                            "grad_norm": grad_norm.item()
-                            if isinstance(grad_norm, torch.Tensor)
-                            else float(grad_norm),
-                        }
-                    )
-
-        # Calculate training time and throughput
-        training_time = time.time() - training_start_time
-        total_samples = (
-            len(groups) * self.config.training.group_size * self.config.training.mini_batch_size
-        )
-        samples_per_second = total_samples / training_time if training_time > 0 else 0.0
-
-        metrics.update(
-            {
-                "training_time": training_time,
-                "samples_per_second": samples_per_second,
-            }
-        )
-
-        return metrics
-
-    def compute_loss(self, sample: TrainingSample) -> torch.Tensor:
-        """Compute GRPO loss for a batch of samples."""
-        training_cfg = self.config.training
-        metrics = self.metrics[-1] if len(self.metrics) > 0 else TrainingMetrics()
-
-        sample.to_device(self.device)
-
-        pol_logp, pol_entropy = self.compute_logprobs(
-            self.policy,
-            sample.inputs,
-        )
-
-        sanity_check(sample, pol_logp, sample.old_logprobs, sample.ref_logprobs)
-
-        metrics.update(
-            {
-                "gpu_util": get_gpu_utilization(),  # Track peak utilization
-                "gpu_memory": get_memory_usage(),  # Track memory usage
-            }
-        )
-        self.log(f"GPU Util: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB")
-
-        old_logp = sample.old_logprobs
-        ref_logp = sample.ref_logprobs
-
-        if old_logp is None or ref_logp is None or sample.advantage is None:
-            raise ValueError("old_logp, ref_logp, or sample.advantage is None")
-
-        # Use assistant mask to remove non-assistant tokens
-        m = sample.inputs["assistant_mask"]
-
-        # Aggregate per trace or per token
-        if training_cfg.ppo_mode == "per_trace":
-            counts = m.sum(dim=1).clamp_min(1.0)
-            pol_logp = (pol_logp * m.float()).sum(dim=1) / counts
-            pol_entropy = (pol_entropy * m.float()).sum(dim=1) / counts
-            old_logp = (old_logp * m.float()).sum(dim=1) / counts
-            ref_logp = (ref_logp * m.float()).sum(dim=1) / counts
-
-        # Clip log probability differences
-        log_ratio = torch.where(m, pol_logp - old_logp, torch.zeros_like(pol_logp))
-        ratio_tok = torch.exp(log_ratio.clamp(-20.0, 20.0))
-
-        # Ensure advantage shape matches ratio_tok for broadcasting
-        advantage = (
-            sample.advantage.view(-1, 1) if ratio_tok.dim() == 2 else sample.advantage.squeeze(-1)
-        )
-
-        unclipped = ratio_tok * advantage
-        clipped = (
-            torch.clamp(ratio_tok, 1 - training_cfg.top_eps, 1 + training_cfg.bottom_eps)
-            * advantage
-        )
-
-        policy_term = -torch.minimum(unclipped, clipped)
-
-        # Clip log probability differences in KL
-        log_rho = torch.where(m, pol_logp - ref_logp, torch.zeros_like(pol_logp))
-        rho_tok = torch.exp(log_rho.clamp(-20.0, 20.0))
-        kl_approx = rho_tok - torch.log(rho_tok) - 1
-
-        total_loss = (
-            policy_term + training_cfg.kl_beta * kl_approx + training_cfg.entropy_beta * pol_entropy
-        )
-
-        # Aggregate loss
-        if training_cfg.ppo_mode == "per_trace":
-            total_loss = total_loss.mean() if training_cfg.token_agg == "mean" else total_loss.sum()  # noqa: S105
-        else:
-            if training_cfg.token_agg == "mean":  # noqa: S105
-                total_loss = (total_loss * m).sum() / m.sum().clamp_min(1.0)
-            else:
-                total_loss = (total_loss * m).sum()
-
-        # Compute metrics only over masked (assistant) tokens
-        mask_count = m.sum().clamp_min(1.0)
-        metrics.update(
-            {
-                "policy_ratio": (ratio_tok * m).sum().item() / mask_count.item()
-                if mask_count.item() > 0
-                else 1.0,
-                "kl": (kl_approx * m).sum().item() / mask_count.item()
-                if mask_count.item() > 0
-                else 0.0,
-                "entropy": (pol_entropy * m).sum().item() / mask_count.item()
-                if mask_count.item() > 0
-                else 0.0,
-                "tokens": sample.inputs["input_ids"].numel(),
-                "loss": total_loss.item(),
-            }
-        )
-
-        sample.to_device(torch.device("cpu"))
-
-        return total_loss
-
-    def compute_logprobs(self, model: Any, inputs: Any) -> tuple[torch.Tensor, torch.Tensor]:
-        """Compute masked per-token log probabilities via the model.
-
-        Returns log probabilities for the actual next tokens.
-        """
-        try:
-            model_inputs = {k: v for k, v in inputs.items() if k != "assistant_mask"}
-            out = model(**model_inputs)
-
-            logits = out.logits / self.config.actor.temperature
-
-            targets = inputs["input_ids"][:, 1:]
-
-            # Align logits to predict next token: use logits[:, :-1, :]
-            next_logits = logits[:, :-1, :]
-
-            token_log_probs = _selective_log_softmax(next_logits, targets)
-
-            # Compute entropy only for assistant tokens to save memory
-            assistant_mask = inputs["assistant_mask"]
-            entropy = torch.zeros_like(token_log_probs)
-            if assistant_mask.any():
-                entropy[assistant_mask] = entropy_from_logits(logits[:, :-1][assistant_mask])
-
-            return token_log_probs, entropy
-        except (IndexError, RuntimeError) as e:
-            # Handle empty inputs or DDP errors
-            hud_console.warning_log(f"Error in compute_logprobs: {e}. Returning dummy values.")
-            # Return dummy values that match expected shapes
-            seq_len = inputs["input_ids"].shape[1] - 1 if "input_ids" in inputs else 0
-            batch_size = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
-            # Create dummy tensors that still participate in autograd so backward doesn't fail
-            try:
-                # Touch params to build a graph
-                param_sum = torch.sum(next(self.policy.parameters()))
-                base = param_sum * 0.0
-            except StopIteration:
-                base = torch.tensor(0.0, device=self.device)
-            dummy_logprobs = (
-                base + torch.zeros(batch_size, seq_len, device=self.device)
-            ).requires_grad_(True)
-            dummy_entropy = (
-                base + torch.zeros(batch_size, seq_len, device=self.device)
-            ).requires_grad_(True)
-            return dummy_logprobs, dummy_entropy
-
-    def save(self, path: str) -> None:
-        """Save the current policy checkpoint (only on rank 0)."""
-        if is_main_process():
-            os.makedirs(path, exist_ok=True)
-            # Unwrap DDP model if needed
-            model_to_save = self.policy.module if hasattr(self.policy, "module") else self.policy
-            model_to_save.save_pretrained(path)
-            self.log(f"Saved checkpoint to {path}")
-
-    def load(self, path: str) -> None:
-        """Load a policy checkpoint."""
-        # Would need to reload LoRA weights
-        self.log(f"Loading checkpoint from {path}")
-        # Implementation depends on PEFT version
-
-
-def sanity_check(
-    sample: TrainingSample,
-    pol_logp: torch.Tensor,
-    old_logp: torch.Tensor | None,
-    ref_logp: torch.Tensor | None,
-) -> None:
-    assert "assistant_mask" in sample.inputs  # noqa: S101
-    m = sample.inputs["assistant_mask"]
-    if old_logp is None or ref_logp is None:
-        return
-    with torch.no_grad():
-        B, K = pol_logp.shape
-        assert old_logp.shape == (B, K), "old_logp shape mismatch"  # noqa: S101
-        assert ref_logp.shape == (B, K), "ref_logp shape mismatch"  # noqa: S101
-        assert m.shape == (B, K), "assistant_mask shape mismatch"  # noqa: S101
-
-        # Check mask is subset of attention_mask[:, 1:]
-        att = sample.inputs.get("attention_mask", None)
-        if att is not None and att.dim() == 2:
-            att_shift = att[:, 1:].bool()
-            bad = (m & ~att_shift).sum().item()
-            if bad > 0:
-                hud_console.warning_log(f"assistant_mask overlaps padding: {bad} tokens")
-
-        # Finiteness on masked entries only
-        def _stats(name: str, t: torch.Tensor) -> None:
-            sel = t[m]
-            if sel.numel() == 0:
-                hud_console.warning_log(f"{name} empty under mask")
-                return
-            finite = torch.isfinite(sel)
-            if finite.sum() < sel.numel():
-                hud_console.warning_log(
-                    f"{name} non-finite: {((~finite).sum().item())}/{sel.numel()}"
-                )
-            sel = sel[finite].float()
-
-        _stats("pol_logp", pol_logp)
-        _stats("old_logp", old_logp)
-        _stats("ref_logp", ref_logp)
-
-        # Log-probabilities should be <= 0 (log-softmax)
-        if (pol_logp[m] > 1e-6).any():
-            hud_console.warning_log("pol_logp has positive values under mask")
-
-        # Precompute masked deltas and ratios for diagnostics (before exp)
-        masked_log_ratio = torch.zeros_like(pol_logp)
-        masked_log_ratio[m] = (pol_logp - old_logp)[m]
-        masked_log_rho = torch.zeros_like(pol_logp)
-        masked_log_rho[m] = (pol_logp - ref_logp)[m]
-
-        _stats("log_ratio(masked)", masked_log_ratio)
-        _stats("log_rho(masked)", masked_log_rho)
-
-        # Ratios after clamp (diagnostic only)
-        ratio_diag = torch.zeros_like(pol_logp)
-        rho_diag = torch.zeros_like(pol_logp)
-        ratio_diag[m] = torch.exp(masked_log_ratio[m].clamp(-20.0, 20.0))
-        rho_diag[m] = torch.exp(masked_log_rho[m].clamp(-20.0, 20.0))
-        _stats("ratio_tok(masked)", ratio_diag)
-        _stats("rho_tok(masked)", rho_diag)
-
-
-def _selective_log_softmax(
-    logits_bt_v: torch.Tensor,
-    index_bt: torch.Tensor,
-) -> torch.Tensor:
-    """Gather log softmax for selected indices with reduced peak memory.
-
-    Uses logsumexp subtraction for float32/64; falls back to per-row
-    log_softmax for bf16/fp16.
-    logits_bt_v: [B, T, V]
-    index_bt: [B, T]
-    Returns: [B, T]
-    """
-    if logits_bt_v.dtype in (torch.float32, torch.float64):
-        # Compute logsumexp per [B, T] in a loop over batch to reduce
-        # peak from B*T*V to T*V
-        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits_bt_v])
-        selected_logits = torch.gather(logits_bt_v, dim=-1, index=index_bt.unsqueeze(-1)).squeeze(
-            -1
-        )
-        return selected_logits - logsumexp_values
-    # Reduced precision: numerically stable route using per-row log_softmax
-    token_logprobs_rows: list[torch.Tensor] = []
-    for logits_row, index_row in zip(logits_bt_v, index_bt, strict=True):
-        logprobs_row = logits_row.log_softmax(dim=-1)
-        token_logprobs_rows.append(
-            torch.gather(logprobs_row, dim=-1, index=index_row.unsqueeze(-1)).squeeze(-1)
-        )
-    return torch.stack(token_logprobs_rows)
hud/rl/tests/__init__.py DELETED
@@ -1 +0,0 @@
-"""Tests for RL module."""