hud-python 0.4.45__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (274)
  1. hud/__init__.py +27 -7
  2. hud/agents/__init__.py +11 -5
  3. hud/agents/base.py +220 -500
  4. hud/agents/claude.py +200 -240
  5. hud/agents/gemini.py +275 -0
  6. hud/agents/gemini_cua.py +335 -0
  7. hud/agents/grounded_openai.py +98 -100
  8. hud/agents/misc/integration_test_agent.py +51 -20
  9. hud/agents/misc/response_agent.py +41 -36
  10. hud/agents/openai.py +291 -292
  11. hud/agents/{openai_chat_generic.py → openai_chat.py} +80 -34
  12. hud/agents/operator.py +211 -0
  13. hud/agents/tests/conftest.py +133 -0
  14. hud/agents/tests/test_base.py +300 -622
  15. hud/agents/tests/test_base_runtime.py +233 -0
  16. hud/agents/tests/test_claude.py +379 -210
  17. hud/agents/tests/test_client.py +9 -10
  18. hud/agents/tests/test_gemini.py +369 -0
  19. hud/agents/tests/test_grounded_openai_agent.py +65 -50
  20. hud/agents/tests/test_openai.py +376 -140
  21. hud/agents/tests/test_operator.py +362 -0
  22. hud/agents/tests/test_run_eval.py +179 -0
  23. hud/cli/__init__.py +461 -545
  24. hud/cli/analyze.py +43 -5
  25. hud/cli/build.py +664 -110
  26. hud/cli/debug.py +8 -5
  27. hud/cli/dev.py +882 -734
  28. hud/cli/eval.py +782 -668
  29. hud/cli/flows/dev.py +167 -0
  30. hud/cli/flows/init.py +191 -0
  31. hud/cli/flows/tasks.py +153 -56
  32. hud/cli/flows/templates.py +151 -0
  33. hud/cli/flows/tests/__init__.py +1 -0
  34. hud/cli/flows/tests/test_dev.py +126 -0
  35. hud/cli/init.py +60 -58
  36. hud/cli/push.py +29 -11
  37. hud/cli/rft.py +311 -0
  38. hud/cli/rft_status.py +145 -0
  39. hud/cli/tests/test_analyze.py +5 -5
  40. hud/cli/tests/test_analyze_metadata.py +3 -2
  41. hud/cli/tests/test_analyze_module.py +120 -0
  42. hud/cli/tests/test_build.py +108 -6
  43. hud/cli/tests/test_build_failure.py +41 -0
  44. hud/cli/tests/test_build_module.py +50 -0
  45. hud/cli/tests/test_cli_init.py +6 -1
  46. hud/cli/tests/test_cli_more_wrappers.py +30 -0
  47. hud/cli/tests/test_cli_root.py +140 -0
  48. hud/cli/tests/test_convert.py +361 -0
  49. hud/cli/tests/test_debug.py +12 -10
  50. hud/cli/tests/test_dev.py +197 -0
  51. hud/cli/tests/test_eval.py +251 -0
  52. hud/cli/tests/test_eval_bedrock.py +51 -0
  53. hud/cli/tests/test_init.py +124 -0
  54. hud/cli/tests/test_main_module.py +11 -5
  55. hud/cli/tests/test_mcp_server.py +12 -100
  56. hud/cli/tests/test_push_happy.py +74 -0
  57. hud/cli/tests/test_push_wrapper.py +23 -0
  58. hud/cli/tests/test_registry.py +1 -1
  59. hud/cli/tests/test_utils.py +1 -1
  60. hud/cli/{rl → utils}/celebrate.py +14 -12
  61. hud/cli/utils/config.py +18 -1
  62. hud/cli/utils/docker.py +130 -4
  63. hud/cli/utils/env_check.py +9 -9
  64. hud/cli/utils/git.py +136 -0
  65. hud/cli/utils/interactive.py +39 -5
  66. hud/cli/utils/metadata.py +69 -0
  67. hud/cli/utils/runner.py +1 -1
  68. hud/cli/utils/server.py +2 -2
  69. hud/cli/utils/source_hash.py +3 -3
  70. hud/cli/utils/tasks.py +4 -1
  71. hud/cli/utils/tests/__init__.py +0 -0
  72. hud/cli/utils/tests/test_config.py +58 -0
  73. hud/cli/utils/tests/test_docker.py +93 -0
  74. hud/cli/utils/tests/test_docker_hints.py +71 -0
  75. hud/cli/utils/tests/test_env_check.py +74 -0
  76. hud/cli/utils/tests/test_environment.py +42 -0
  77. hud/cli/utils/tests/test_git.py +142 -0
  78. hud/cli/utils/tests/test_interactive_module.py +60 -0
  79. hud/cli/utils/tests/test_local_runner.py +50 -0
  80. hud/cli/utils/tests/test_logging_utils.py +23 -0
  81. hud/cli/utils/tests/test_metadata.py +49 -0
  82. hud/cli/utils/tests/test_package_runner.py +35 -0
  83. hud/cli/utils/tests/test_registry_utils.py +49 -0
  84. hud/cli/utils/tests/test_remote_runner.py +25 -0
  85. hud/cli/utils/tests/test_runner_modules.py +52 -0
  86. hud/cli/utils/tests/test_source_hash.py +36 -0
  87. hud/cli/utils/tests/test_tasks.py +80 -0
  88. hud/cli/utils/version_check.py +258 -0
  89. hud/cli/{rl → utils}/viewer.py +2 -2
  90. hud/clients/README.md +12 -11
  91. hud/clients/__init__.py +4 -3
  92. hud/clients/base.py +166 -26
  93. hud/clients/environment.py +51 -0
  94. hud/clients/fastmcp.py +13 -6
  95. hud/clients/mcp_use.py +40 -15
  96. hud/clients/tests/test_analyze_scenarios.py +206 -0
  97. hud/clients/tests/test_protocol.py +9 -3
  98. hud/datasets/__init__.py +23 -20
  99. hud/datasets/loader.py +327 -0
  100. hud/datasets/runner.py +192 -105
  101. hud/datasets/tests/__init__.py +0 -0
  102. hud/datasets/tests/test_loader.py +221 -0
  103. hud/datasets/tests/test_utils.py +315 -0
  104. hud/datasets/utils.py +270 -90
  105. hud/environment/__init__.py +50 -0
  106. hud/environment/connection.py +206 -0
  107. hud/environment/connectors/__init__.py +33 -0
  108. hud/environment/connectors/base.py +68 -0
  109. hud/environment/connectors/local.py +177 -0
  110. hud/environment/connectors/mcp_config.py +109 -0
  111. hud/environment/connectors/openai.py +101 -0
  112. hud/environment/connectors/remote.py +172 -0
  113. hud/environment/environment.py +694 -0
  114. hud/environment/integrations/__init__.py +45 -0
  115. hud/environment/integrations/adk.py +67 -0
  116. hud/environment/integrations/anthropic.py +196 -0
  117. hud/environment/integrations/gemini.py +92 -0
  118. hud/environment/integrations/langchain.py +82 -0
  119. hud/environment/integrations/llamaindex.py +68 -0
  120. hud/environment/integrations/openai.py +238 -0
  121. hud/environment/mock.py +306 -0
  122. hud/environment/router.py +112 -0
  123. hud/environment/scenarios.py +493 -0
  124. hud/environment/tests/__init__.py +1 -0
  125. hud/environment/tests/test_connection.py +317 -0
  126. hud/environment/tests/test_connectors.py +218 -0
  127. hud/environment/tests/test_environment.py +161 -0
  128. hud/environment/tests/test_integrations.py +257 -0
  129. hud/environment/tests/test_local_connectors.py +201 -0
  130. hud/environment/tests/test_scenarios.py +280 -0
  131. hud/environment/tests/test_tools.py +208 -0
  132. hud/environment/types.py +23 -0
  133. hud/environment/utils/__init__.py +35 -0
  134. hud/environment/utils/formats.py +215 -0
  135. hud/environment/utils/schema.py +171 -0
  136. hud/environment/utils/tool_wrappers.py +113 -0
  137. hud/eval/__init__.py +67 -0
  138. hud/eval/context.py +674 -0
  139. hud/eval/display.py +299 -0
  140. hud/eval/instrument.py +185 -0
  141. hud/eval/manager.py +466 -0
  142. hud/eval/parallel.py +268 -0
  143. hud/eval/task.py +340 -0
  144. hud/eval/tests/__init__.py +1 -0
  145. hud/eval/tests/test_context.py +178 -0
  146. hud/eval/tests/test_eval.py +210 -0
  147. hud/eval/tests/test_manager.py +152 -0
  148. hud/eval/tests/test_parallel.py +168 -0
  149. hud/eval/tests/test_task.py +145 -0
  150. hud/eval/types.py +63 -0
  151. hud/eval/utils.py +183 -0
  152. hud/patches/__init__.py +19 -0
  153. hud/patches/mcp_patches.py +151 -0
  154. hud/patches/warnings.py +54 -0
  155. hud/samples/browser.py +4 -4
  156. hud/server/__init__.py +2 -1
  157. hud/server/low_level.py +2 -1
  158. hud/server/router.py +164 -0
  159. hud/server/server.py +567 -80
  160. hud/server/tests/test_mcp_server_integration.py +11 -11
  161. hud/server/tests/test_mcp_server_more.py +1 -1
  162. hud/server/tests/test_server_extra.py +2 -0
  163. hud/settings.py +45 -3
  164. hud/shared/exceptions.py +36 -10
  165. hud/shared/hints.py +26 -1
  166. hud/shared/requests.py +15 -3
  167. hud/shared/tests/test_exceptions.py +40 -31
  168. hud/shared/tests/test_hints.py +167 -0
  169. hud/telemetry/__init__.py +20 -19
  170. hud/telemetry/exporter.py +201 -0
  171. hud/telemetry/instrument.py +158 -253
  172. hud/telemetry/tests/test_eval_telemetry.py +356 -0
  173. hud/telemetry/tests/test_exporter.py +258 -0
  174. hud/telemetry/tests/test_instrument.py +401 -0
  175. hud/tools/__init__.py +16 -2
  176. hud/tools/apply_patch.py +639 -0
  177. hud/tools/base.py +54 -4
  178. hud/tools/bash.py +2 -2
  179. hud/tools/computer/__init__.py +4 -0
  180. hud/tools/computer/anthropic.py +2 -2
  181. hud/tools/computer/gemini.py +385 -0
  182. hud/tools/computer/hud.py +23 -6
  183. hud/tools/computer/openai.py +20 -21
  184. hud/tools/computer/qwen.py +434 -0
  185. hud/tools/computer/settings.py +37 -0
  186. hud/tools/edit.py +3 -7
  187. hud/tools/executors/base.py +4 -2
  188. hud/tools/executors/pyautogui.py +1 -1
  189. hud/tools/grounding/grounded_tool.py +13 -18
  190. hud/tools/grounding/grounder.py +10 -31
  191. hud/tools/grounding/tests/test_grounded_tool.py +26 -44
  192. hud/tools/jupyter.py +330 -0
  193. hud/tools/playwright.py +18 -3
  194. hud/tools/shell.py +308 -0
  195. hud/tools/tests/test_apply_patch.py +718 -0
  196. hud/tools/tests/test_computer.py +4 -9
  197. hud/tools/tests/test_computer_actions.py +24 -2
  198. hud/tools/tests/test_jupyter_tool.py +181 -0
  199. hud/tools/tests/test_shell.py +596 -0
  200. hud/tools/tests/test_submit.py +85 -0
  201. hud/tools/tests/test_types.py +193 -0
  202. hud/tools/types.py +21 -1
  203. hud/types.py +167 -57
  204. hud/utils/__init__.py +2 -0
  205. hud/utils/env.py +67 -0
  206. hud/utils/hud_console.py +61 -3
  207. hud/utils/mcp.py +15 -58
  208. hud/utils/strict_schema.py +162 -0
  209. hud/utils/tests/test_init.py +1 -2
  210. hud/utils/tests/test_mcp.py +1 -28
  211. hud/utils/tests/test_pretty_errors.py +186 -0
  212. hud/utils/tests/test_tool_shorthand.py +154 -0
  213. hud/utils/tests/test_version.py +1 -1
  214. hud/utils/types.py +20 -0
  215. hud/version.py +1 -1
  216. hud_python-0.5.1.dist-info/METADATA +264 -0
  217. hud_python-0.5.1.dist-info/RECORD +299 -0
  218. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/WHEEL +1 -1
  219. hud/agents/langchain.py +0 -261
  220. hud/agents/lite_llm.py +0 -72
  221. hud/cli/rl/__init__.py +0 -180
  222. hud/cli/rl/config.py +0 -101
  223. hud/cli/rl/display.py +0 -133
  224. hud/cli/rl/gpu.py +0 -63
  225. hud/cli/rl/gpu_utils.py +0 -321
  226. hud/cli/rl/local_runner.py +0 -595
  227. hud/cli/rl/presets.py +0 -96
  228. hud/cli/rl/remote_runner.py +0 -463
  229. hud/cli/rl/rl_api.py +0 -150
  230. hud/cli/rl/vllm.py +0 -177
  231. hud/cli/rl/wait_utils.py +0 -89
  232. hud/datasets/parallel.py +0 -687
  233. hud/misc/__init__.py +0 -1
  234. hud/misc/claude_plays_pokemon.py +0 -292
  235. hud/otel/__init__.py +0 -35
  236. hud/otel/collector.py +0 -142
  237. hud/otel/config.py +0 -181
  238. hud/otel/context.py +0 -570
  239. hud/otel/exporters.py +0 -369
  240. hud/otel/instrumentation.py +0 -135
  241. hud/otel/processors.py +0 -121
  242. hud/otel/tests/__init__.py +0 -1
  243. hud/otel/tests/test_processors.py +0 -197
  244. hud/rl/README.md +0 -30
  245. hud/rl/__init__.py +0 -1
  246. hud/rl/actor.py +0 -176
  247. hud/rl/buffer.py +0 -405
  248. hud/rl/chat_template.jinja +0 -101
  249. hud/rl/config.py +0 -192
  250. hud/rl/distributed.py +0 -132
  251. hud/rl/learner.py +0 -637
  252. hud/rl/tests/__init__.py +0 -1
  253. hud/rl/tests/test_learner.py +0 -186
  254. hud/rl/train.py +0 -382
  255. hud/rl/types.py +0 -101
  256. hud/rl/utils/start_vllm_server.sh +0 -30
  257. hud/rl/utils.py +0 -524
  258. hud/rl/vllm_adapter.py +0 -143
  259. hud/telemetry/job.py +0 -352
  260. hud/telemetry/replay.py +0 -74
  261. hud/telemetry/tests/test_replay.py +0 -40
  262. hud/telemetry/tests/test_trace.py +0 -63
  263. hud/telemetry/trace.py +0 -158
  264. hud/utils/agent_factories.py +0 -86
  265. hud/utils/async_utils.py +0 -65
  266. hud/utils/group_eval.py +0 -223
  267. hud/utils/progress.py +0 -149
  268. hud/utils/tasks.py +0 -127
  269. hud/utils/tests/test_async_utils.py +0 -173
  270. hud/utils/tests/test_progress.py +0 -261
  271. hud_python-0.4.45.dist-info/METADATA +0 -552
  272. hud_python-0.4.45.dist-info/RECORD +0 -228
  273. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/entry_points.txt +0 -0
  274. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/licenses/LICENSE +0 -0
hud/rl/buffer.py DELETED
@@ -1,405 +0,0 @@
- """Replay buffer for storing and sampling episodes."""
-
- from __future__ import annotations
-
- import logging
- import random
- from collections import deque
- from typing import TYPE_CHECKING, Generic, TypeVar
-
- from hud.types import Task, Trace
- from hud.utils.hud_console import HUDConsole
-
- logger = logging.getLogger(__name__)
- hud_console = HUDConsole(logger=logger)
-
- T = TypeVar("T")
-
- if TYPE_CHECKING:
-     from collections.abc import Callable
-
-     from hud.rl.config import Config
-
-
- class Buffer(Generic[T]):
-     """Simple buffer for a list of tasks, traces or episodes."""
-
-     def __init__(self, max_size: int = 10000) -> None:
-         self.max_size = max_size
-         self.buffer: deque[T] = deque(maxlen=max_size)
-
-     def add(self, items: list[T] | T, shuffle: bool = False) -> None:
-         """Add items to buffer."""
-         if isinstance(items, list):
-             for item in items:
-                 self.buffer.append(item)
-         else:
-             self.buffer.append(items)
-         if shuffle:
-             random.shuffle(self.buffer)
-
-     def add_fill(self, items: list[T] | T, target_size: int, shuffle: bool = False) -> None:
-         """Add items to buffer until the buffer is at least the target size."""
-         while len(self.buffer) < target_size:
-             self.add(items, shuffle)
-
-     def get(self, n: int = 0) -> list[T]:
-         """Get items from the buffer."""
-         if n == 0:
-             return list(self.buffer)
-         if n > len(self.buffer):
-             raise ValueError("Not enough items in buffer")
-         return list(self.buffer)[-n:]
-
-     def consume(self, n: int = 0) -> list[T]:
-         """Consume items from the buffer."""
-         if n == 0:
-             return list(self.buffer)
-         if n > len(self.buffer):
-             raise ValueError("Not enough items in buffer")
-
-         return [self.buffer.pop() for _ in range(n)]
-
-     def get_filtered(
-         self, n: int = 0, filter_fn: Callable[[T], bool] | None = None, consume: bool = False
-     ) -> list[T]:
-         """Filter the buffer by a filter function."""
-         filtered = (
-             [item for item in self.buffer if filter_fn(item)] if filter_fn else list(self.buffer)
-         )
-         if n == 0:
-             return filtered
-         return self.consume(n) if consume else self.get(n)
-
-     def sample(
-         self,
-         batch_size: int,
-         n: int = 0,
-         filter_fn: Callable[[T], bool] | None = None,
-         consume: bool = False,
-     ) -> list[T]:
-         """Sample a batch of items with optional filtering."""
-         items = self.get_filtered(n, filter_fn, consume)
-
-         if len(items) < batch_size:
-             hud_console.warning(f"Buffer has {len(items)} items, requested {batch_size}")
-             return items
-
-         return random.sample(items, batch_size)
-
-     def clear(self) -> None:
-         """Clear the buffer."""
-         self.buffer.clear()
-
-     def __len__(self) -> int:
-         """Use len() directly on Buffer instances."""
-         return len(self.buffer)
-
-
- class DatasetBuffer(Buffer[Task]):
-     """
-     Buffer for a dataset.
-     Loads in individual tasks that will be trained for a specified number of training steps.
-     """
-
-     def __init__(
-         self,
-         dataset: list[Task] | Task,
-         config: Config,
-     ) -> None:
-         self.config = config
-
-         self.group_size = config.training.group_size
-         self.batch_size = config.training.batch_size
-         self.training_steps = config.training.training_steps
-
-         if self.group_size > self.batch_size:
-             raise ValueError(
-                 f"Group size is greater than batch size, {self.group_size} > {self.batch_size}"
-             )
-
-         if self.batch_size % self.group_size != 0:
-             raise ValueError(
-                 f"A batch cannot have irregular groups, {self.group_size} % {self.batch_size} != 0"
-             )
-
-         if self.group_size % config.training.mini_batch_size != 0:
-             raise ValueError(
-                 f"Group size is not a multiple of mini batch size, {self.group_size} % {config.training.mini_batch_size} != 0"  # noqa: E501
-             )
-
-         self.groups_per_batch = self.batch_size // self.group_size
-         self.number_of_tasks = self.training_steps * self.groups_per_batch
-
-         super().__init__(self.number_of_tasks)
-
-         dataset = dataset if isinstance(dataset, list) else [dataset]
-         tasks = self._validate_tasks(dataset)
-         if config.training.shuffle_dataset:
-             random.shuffle(tasks)
-         if len(tasks) > self.number_of_tasks:
-             leftovers = len(tasks) - self.number_of_tasks
-             hud_console.warning(
-                 f"Training steps ({self.training_steps}) will lead to {leftovers} tasks not being trained"  # noqa: E501
-             )
-             tasks = tasks[: self.number_of_tasks]
-
-         # Check if the dataset is imbalanced
-         self.dataset_size = len(tasks)
-         if self.training_steps % self.dataset_size != 0:
-             leftovers = self.number_of_tasks % self.dataset_size
-             hud_console.warning(
-                 f"Dataset imbalanced ({leftovers} tasks will be trained 1 more time)"
-             )
-             hud_console.warning(
-                 f"This is because the number of training steps ({self.training_steps}) is not a multiple of the dataset size ({self.dataset_size})"  # noqa: E501
-             )
-
-         if config.verbose:
-             hud_console.info(f"Sample task: {tasks[0]}")
-
-         self.add_fill(tasks, self.number_of_tasks, config.training.shuffle_dataset)
-
-     def _validate_tasks(self, tasks: list[Task]) -> list[Task]:
-         """Validate that all tasks are proper HUD Task objects."""
-         if not tasks:
-             raise ValueError("No tasks provided to DatasetBuffer")
-
-         validated_tasks = []
-         for i, task in enumerate(tasks):
-             if not isinstance(task, Task):
-                 raise TypeError(f"Task at index {i} is not a HUD Task object, got {type(task)}")
-             validated_tasks.append(task)
-
-         return validated_tasks
-
-     @property
-     def info(self) -> dict[str, int | float | str]:
-         """Get the info of the buffer."""
-         return {
-             "total_items": len(self),
-             "total_traces": self.number_of_tasks * self.group_size,
-             "total_batches": self.training_steps,
-             "task_repeats": self.number_of_tasks // self.dataset_size,
-             "dataset_size": self.dataset_size,
-             "group_size": self.group_size,
-             "batch_size": self.batch_size,
-         }
-
-     def get_tasks(self, consume: bool = True) -> list[Task]:
-         """Get tasks for a batch."""
-         tasks = self.consume(self.groups_per_batch) if consume else self.get(self.groups_per_batch)
-         # Create groups where each group contains group_size copies of the same task
-         result = []
-         for task in tasks:
-             result.extend([task] * self.group_size)
-         return result
-
-
- class ReplayBuffer(Buffer[Trace]):
-     """Buffer for traces."""
-
-     def __init__(self, config: Config) -> None:
-         self.config = config
-
-         self.buffer_steps = config.training.buffer_steps
-         self.select_strategy = config.training.select_strategy
-         self.group_size = config.training.group_size
-         self.batch_size = config.training.batch_size
-
-         buffer_size = self.buffer_steps * self.batch_size
-
-         super().__init__(buffer_size)
-
-     def sample_traces(self) -> list[Trace]:
-         """Sample traces for a batch."""
-         if self.select_strategy == "recent":
-             return self.get(self.batch_size)
-         elif self.select_strategy == "random":
-             return self.sample(self.batch_size)
-         elif self.select_strategy == "variance":
-             return self._sample_high_variance_traces()
-         else:
-             raise ValueError(f"Invalid select strategy: {self.select_strategy}")
-
-     def _extract_group_key(self, trace: Trace) -> tuple[str, str]:
-         """Return a stable grouping key for a trace.
-
-         Preference order:
-         1) task.id when present (kind='id')
-         2) task.prompt exact string (kind='prompt') when id is None
-         3) 'NA' for missing/errored entries (kind='NA')
-         """
-         if getattr(trace, "isError", False):
-             return ("NA", "NA")
-
-         task = getattr(trace, "task", None)
-         if task is None:
-             return ("NA", "NA")
-
-         tid = getattr(task, "id", None)
-         if tid is not None:
-             return ("id", str(tid))
-
-         prompt = getattr(task, "prompt", None)
-         if prompt:
-             return ("prompt", str(prompt))
-
-         return ("NA", "NA")
-
-     def _validate_and_split_groups(
-         self, recent_traces: list[Trace]
-     ) -> tuple[list[list[Trace]], list[tuple[str, str]]]:
-         """Validate and split recent traces into homogeneous groups by id or prompt.
-
-         - Uses id when present; otherwise falls back to prompt equality.
-         - Any NA/error traces are excluded and the group is filled by duplicating
-           existing valid members in that group.
-         - Always returns len == groups_per_batch groups of size == group_size.
-         """
-         from collections import Counter
-
-         groups_per_batch = self.batch_size // self.group_size
-
-         window_keys = [self._extract_group_key(t) for t in recent_traces]
-         window_counter = Counter(k for k in window_keys if k[0] != "NA")
-
-         validated_groups: list[list[Trace]] = []
-         selected_keys: list[tuple[str, str]] = []
-
-         for g_idx in range(groups_per_batch):
-             start = g_idx * self.group_size
-             end = start + self.group_size
-             chunk = recent_traces[start:end]
-
-             key_counts = Counter()
-             per_item_keys: list[tuple[str, str]] = []
-             for tr in chunk:
-                 k = self._extract_group_key(tr)
-                 per_item_keys.append(k)
-                 if k[0] != "NA":
-                     key_counts[k] += 1
-
-             if key_counts:
-                 best_key = key_counts.most_common(1)[0][0]
-             elif window_counter:
-                 best_key = window_counter.most_common(1)[0][0]
-             else:
-                 best_key = ("NA", "NA")
-
-             homogeneous = [tr for tr, k in zip(chunk, per_item_keys, strict=False) if k == best_key]
-
-             while len(homogeneous) < self.group_size:
-                 if homogeneous:
-                     homogeneous.append(homogeneous[-1])
-                 else:
-                     idx = next((i for i, wk in enumerate(window_keys) if wk[0] != "NA"), None)
-                     if idx is not None:
-                         homogeneous.append(recent_traces[idx])
-                     elif chunk:
-                         homogeneous.append(chunk[0])
-                     else:
-                         homogeneous.append(recent_traces[0])
-
-             validated_groups.append(homogeneous)
-             selected_keys.append(best_key)
-
-         return validated_groups, selected_keys
-
-     def _sample_high_variance_traces(self) -> list[Trace]:
-         from collections import Counter, defaultdict, deque
-
-         buf_list = list(self.buffer)
-         if len(buf_list) < self.batch_size:
-             hud_console.warning(
-                 f"[group-sampler] Buffer has only {len(buf_list)} traces, need {self.batch_size}"
-             )
-             while len(buf_list) < self.batch_size:
-                 take = min(len(buf_list) or 1, self.batch_size - len(buf_list))
-                 buf_list.extend(buf_list[:take])
-         recent_traces = buf_list[-self.batch_size :]
-
-         recent_keys = [self._extract_group_key(t) for t in recent_traces]
-         hud_console.info(f"[group-sampler] recent-window histogram: {Counter(recent_keys)}")
-
-         hud_console.info(
-             f"[group-sampler] Building earlier traces lookup, buffer size: {len(buf_list)}"
-         )
-         earlier_traces_by_key: dict[tuple[str, str], deque[Trace]] = defaultdict(deque)
-         for tr in buf_list[: -self.batch_size]:
-             k = self._extract_group_key(tr)
-             if k[0] != "NA":
-                 earlier_traces_by_key[k].append(tr)
-
-         groups, group_keys = self._validate_and_split_groups(recent_traces)
-
-         final_traces: list[Trace] = []
-         for g_idx, (homogeneous, target_key) in enumerate(zip(groups, group_keys, strict=False)):
-
-             def current_mean(h: list[Trace]) -> float:
-                 if not h:
-                     return 0.0
-                 vals = [float(getattr(t, "reward", 0.0) or 0.0) for t in h]
-                 return sum(vals) / len(vals)
-
-             pool = earlier_traces_by_key.get(target_key, deque())
-             if pool:
-                 pool_vals = [float(getattr(tr, "reward", 0.0) or 0.0) for tr in list(pool)]
-                 if pool_vals:
-                     pool_mean = sum(pool_vals) / len(pool_vals)
-                     pool_var = sum((v - pool_mean) * (v - pool_mean) for v in pool_vals) / len(
-                         pool_vals
-                     )
-                     hud_console.info(
-                         f"[group-sampler] Group {g_idx}: earlier-pool size={len(pool_vals)} "
-                         f"mean={pool_mean:.4f} std={(pool_var**0.5):.4f}"
-                     )
-
-             replace_k = max(1, self.group_size // 4)
-             replace_k = min(replace_k, len(pool), self.group_size)
-
-             if replace_k > 0:
-                 mu = current_mean(homogeneous)
-                 pool_list = list(pool)
-                 pool_indices = list(range(len(pool_list)))
-                 pool_indices.sort(
-                     key=lambda i: abs(
-                         (float(getattr(pool_list[i], "reward", 0.0) or 0.0)) - mu
-                     ),
-                     reverse=True,
-                 )
-                 chosen_pool_idx = set(pool_indices[:replace_k])
-                 replacements = [pool_list[i] for i in pool_indices[:replace_k]]
-
-                 remaining = [tr for i, tr in enumerate(pool_list) if i not in chosen_pool_idx]
-                 earlier_traces_by_key[target_key] = deque(remaining)
-
-                 group_indices = list(range(len(homogeneous)))
-                 mu = current_mean(homogeneous)
-                 group_indices.sort(
-                     key=lambda i: abs(
-                         (float(getattr(homogeneous[i], "reward", 0.0) or 0.0)) - mu
-                     )
-                 )
-                 target_positions = group_indices[:replace_k]
-
-                 for pos, new_tr in zip(target_positions, replacements, strict=False):
-                     homogeneous[pos] = new_tr
-
-             if any(self._extract_group_key(t) != target_key for t in homogeneous):
-                 raise RuntimeError(f"Group {g_idx} is not homogeneous after sampling")
-             final_traces.extend(homogeneous)
-
-         for i in range(0, len(final_traces), self.group_size):
-             block = final_traces[i : i + self.group_size]
-             keys = {self._extract_group_key(t) for t in block}
-             if len(keys) != 1:
-                 raise RuntimeError(f"Homogeneity validation failed for block starting at index {i}")
-
-         hud_console.info(
-             f"[group-sampler] final histogram: "
-             f"{Counter(self._extract_group_key(t) for t in final_traces)}"
-         )
-         return final_traces
-
- # --------------------------------------------------------------------
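Note: the removed Buffer class above is a thin generic wrapper around collections.deque. A minimal usage sketch, assuming hud-python 0.4.45 (the last release that ships hud.rl) is installed; the items and sizes below are illustrative only, not taken from this diff:

# Hedged sketch: exercising the removed generic Buffer API shown in the diff above.
# Assumes hud-python 0.4.45 is installed; values are illustrative.
from hud.rl.buffer import Buffer

buf: Buffer[int] = Buffer(max_size=100)
buf.add([1, 2, 3])               # add a list of items
buf.add(4)                       # or a single item
print(len(buf))                  # 4
print(buf.get(2))                # two most recent items -> [3, 4]
print(buf.sample(batch_size=2))  # random sample; warns and returns all if too few items
print(buf.consume(2))            # pops from the right -> [4, 3]
buf.clear()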
hud/rl/chat_template.jinja DELETED
@@ -1,101 +0,0 @@
- {% set image_count = namespace(value=0) %}
- {% set video_count = namespace(value=0) %}
- {{- '<|im_start|>system\n' }}
- {%- if messages[0]['role'] == 'system' -%}
- {%- if messages[0]['content'] is string -%}
- {{ messages[0]['content'] }}
- {%- else -%}
- {%- for content in messages[0]['content'] -%}
- {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
- {%- set image_count.value = image_count.value + 1 -%}
- {%- if add_vision_id -%}
- {{ 'Picture ' ~ image_count.value ~ ': ' }}
- {%- endif -%}
- {{ '<|vision_start|><|image_pad|><|vision_end|>' }}
- {%- elif content['type'] == 'video' or 'video' in content -%}
- {%- set video_count.value = video_count.value + 1 -%}
- {%- if add_vision_id -%}
- {{ 'Video ' ~ video_count.value ~ ': ' }}
- {%- endif -%}
- {{ '<|vision_start|><|video_pad|><|vision_end|>' }}
- {%- elif 'text' in content -%}
- {{ content['text'] }}
- {%- endif -%}
- {%- endfor -%}
- {%- endif -%}
- {%- else -%}
- {{ 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
- {%- endif -%}
- {%- if tools -%}
- {{ '\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n' }}
- {{- tools | map('tojson') | join('\n') -}}
- {{ '\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call>' }}
- {%- endif -%}
- {{ '<|im_end|>\n' }}
- {%- for message in messages -%}
- {# Skip the first system message as it was already rendered. #}
- {%- if loop.first and message.role == 'system' %}{% continue %}{% endif -%}
-
- {# Render tool messages. The logic is slightly different with other messages. #}
- {%- if message['role'] == 'tool' -%}
- {%- if loop.first or messages[loop.index0 - 1]['role'] != 'tool' -%}
- {{ '<|im_start|>user' }}
- {%- endif -%}
- {{ '\n<tool_response>\n' }}
- {%- else -%}
- {{ '<|im_start|>' ~ message['role'] ~ '\n' }}
- {%- endif -%}
-
- {%- if message['content'] is string -%}
- {{ message['content'] }}
- {%- else -%}
- {%- for content in message['content'] -%}
- {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
- {%- set image_count.value = image_count.value + 1 -%}
- {%- if add_vision_id -%}
- {{ 'Picture ' ~ image_count.value ~ ': ' }}
- {%- endif -%}
- {{ '<|vision_start|><|image_pad|><|vision_end|>' }}
- {%- elif content['type'] == 'video' or 'video' in content -%}
- {%- set video_count.value = video_count.value + 1 -%}
- {%- if add_vision_id -%}
- {{ 'Video ' ~ video_count.value ~ ': ' }}
- {%- endif -%}
- {{ '<|vision_start|><|video_pad|><|vision_end|>' }}
- {%- elif 'text' in content and message['role'] == 'assistant' -%}
- {% generation %} {{ content['text'] }} {% endgeneration %}
- {%- elif 'text' in content -%}
- {{ content['text'] }}
- {%- endif -%}
- {%- endfor -%}
- {%- endif -%}
- {# Render tool_calls in AI messages. #}
- {%- if message['role'] == 'assistant' and 'tool_calls' in message -%}
- {# It will be cleaner if I can use some map function and join them with '\n' #}
- {%- for tool_call in message['tool_calls'] -%}
- {%- if tool_call['function'] is defined -%}
- {%- set tool_call = tool_call['function'] -%}
- {%- endif -%}
- {# Handle the case where arguments is already a JSON string (OpenAI format) #}
- {%- if tool_call.arguments is string -%}
- {% generation %} {{ '<tool_call>\n{"name": "' }}{{ tool_call.name }}{{ '", "arguments": ' }}{{ tool_call.arguments }}{{ '}\n</tool_call>' }} {% endgeneration %}
- {%- else -%}
- {% generation %} {{ '<tool_call>\n' }}{{ tool_call | tojson }}{{ '\n</tool_call>' }} {% endgeneration %}
- {%- endif -%}
- {%- if not loop.last -%}
- {% generation %} {{ '\n' }} {% endgeneration %}
- {%- endif -%}
- {%- endfor -%}
- {%- endif -%}
- {%- if message['role'] == 'tool' -%}
- {{ '\n</tool_response>' }}
- {%- if loop.last or messages[loop.index0 + 1]['role'] != 'tool' -%}
- {{ '<|im_end|>\n' }}
- {%- endif -%}
- {%- else -%}
- {{ '<|im_end|>\n' }}
- {%- endif -%}
- {%- endfor -%}
- {%- if add_generation_prompt -%}
- {{ '<|im_start|>assistant\n' }}
- {%- endif -%}
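Note: the removed file above is a Qwen-style chat template that wraps assistant text in {% generation %} blocks and renders tool calls inside <tool_call> tags. A minimal rendering sketch, assuming a recent transformers release that understands the {% generation %} and {% continue %} keywords and that the template text has been saved locally as chat_template.jinja (both assumptions, not part of this diff):

# Hedged sketch: render a conversation with the removed Qwen-style template.
# Assumes transformers supports the {% generation %} keyword used above and that
# the template was saved to ./chat_template.jinja; the messages are illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
with open("chat_template.jinja") as f:
    tokenizer.chat_template = f.read()  # override the model's built-in template

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": [{"type": "text", "text": "Click the submit button."}]},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # ends with '<|im_start|>assistant\n'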
hud/rl/config.py DELETED
@@ -1,192 +0,0 @@
- """Configuration for RL training."""
-
- from __future__ import annotations
-
- from dataclasses import dataclass, field
- from typing import Literal
-
- # List of supported VL (Vision-Language) models
- SUPPORTED_MODELS = [
-     "Qwen/Qwen2.5-VL-3B-Instruct",
-     "Qwen/Qwen2.5-VL-7B-Instruct",
-     "Qwen/Qwen2.5-VL-14B-Instruct",
-     "Qwen/Qwen2.5-VL-32B-Instruct",
-     "Qwen/Qwen2.5-VL-72B-Instruct",
-     "Qwen/Qwen2.5-7B-Instruct",
-     "Qwen/Qwen2.5-3B-Instruct",
- ]
-
-
- def validate_vl_model(model_name: str) -> None:
-     """Validate that the model is a supported VL model.
-
-     Args:
-         model_name: The model name to validate
-
-     Raises:
-         ValueError: If the model is not a supported VL model
-     """
-     if not any(model_name.startswith(supported) for supported in SUPPORTED_MODELS):
-         raise ValueError(
-             f"Model '{model_name}' is not a supported VL model. "
-             f"Only VL (Vision-Language) models are supported for RL training.\n"
-             f"Supported models: {', '.join(SUPPORTED_MODELS)}\n"
-             f"Note: '{model_name}' appears to be a text-only model."
-         )
-
-
- @dataclass
- class ModelConfig:
-     """Model and LoRA configuration."""
-
-     base_model: str = "Qwen/Qwen2.5-VL-3B-Instruct"
-     lora_r: int = 16
-     lora_alpha: int = 32
-     lora_dropout: float = 0.1
-     target_modules: tuple[str, ...] = (
-         "q_proj",
-         "k_proj",
-         "v_proj",
-         "o_proj",
-         "gate_proj",
-         "up_proj",
-         "down_proj",
-     )
-     min_pixels: int = 256 * 28 * 28
-     max_pixels: int = 512 * 28 * 28
-     attn_implementation: str = "flash_attention_2"
-     use_liger: bool = True
-     gradient_checkpointing: bool = True
-
-
- @dataclass
- class TrainingConfig:
-     """Training hyperparameters."""
-
-     # GPU parameters
-     gpu_type: str = "A100"
-     num_gpus: int = 2
-
-     # Training parameters
-     training_steps: int = 100
-     shuffle_dataset: bool = False
-     save_every_batches: int = 1
-
-     # Batching parameters
-     epochs: int = 1
-     batch_size: int = 16
-     group_size: int = 8
-     mini_batch_size: int = 1
-     update_after_group: bool = True  # Whether to update the policy after each task group
-     accumulate_over_minibatches: bool = False  # Whether to accumulate over minibatches
-
-     # Advantage calculation parameters
-     batch_level: Literal["group", "batch"] = "group"
-     no_std: bool = False
-     leave_one_out: bool = True
-
-     # Replay buffer parameters
-     buffer_steps: int = 8
-     select_strategy: Literal["recent", "variance", "random"] = "variance"
-
-     # Aggregation parameters
-     ppo_mode: Literal["per_token", "per_trace"] = "per_token"
-     token_agg: Literal["mean", "sum"] = "mean"  # noqa: S105
-
-     # Regularization parameters
-     kl_beta: float = 0.001
-     entropy_beta: float = 0.001
-     top_eps: float = 0.2
-     bottom_eps: float = 0.1
-
-     # Training hyperparameters
-     lr: float = 3e-5
-     grad_clip: float = 1.0
-
-     # Adam hyperparameters
-     use_8bit_optimizer: bool = True
-     adam_betas: tuple[float, float] = (0.9, 0.999)
-     adam_eps: float = 1e-8
-
-
- @dataclass
- class ActorConfig:
-     """Actor/episode collection configuration."""
-
-     # Execution parameters
-     max_steps_per_episode: int = 5
-     max_parallel_episodes: int = 48
-     max_new_tokens: int = 1024
-     force_tool_choice: bool = True
-     allowed_tools: list[str] | None = None
-
-     # Model parameters
-     temperature: float = 0.7
-
-     # Hud agent parameters
-     system_prompt: str = "You are an expert agent. Complete the task efficiently."
-     vllm_base_url: str = "http://localhost:8000/v1"
-     vllm_api_key: str = "token-abc123"
-
-     # Episode execution timeout (seconds)
-     episode_timeout_sec: int = 600
-
-
- @dataclass
- class Config:
-     """Main configuration combining all sub-configs."""
-
-     model: ModelConfig = field(default_factory=ModelConfig)
-     training: TrainingConfig = field(default_factory=TrainingConfig)
-     actor: ActorConfig = field(default_factory=ActorConfig)
-
-     # Telemetry configuration
-     job_name: str = "RL Training"
-     job_id: str | None = None  # Use existing job ID if provided
-     stats_interval: int = 1
-     verbose: bool = False
-     very_verbose: bool = False
-
-     # Paths
-     out_dir: str = "./checkpoints"
-     adapter_prefix: str = "cua-grpo-step"
-
-     # Misc
-     seed: int = 1234
-
-     @classmethod
-     def from_dict(cls, d: dict) -> Config:
-         """Create config from dictionary."""
-         model = ModelConfig(**d.get("model", {}))
-         training = TrainingConfig(**d.get("training", {}))
-         actor = ActorConfig(**d.get("actor", {}))
-
-         return cls(
-             model=model,
-             training=training,
-             actor=actor,
-             job_name=d.get("job_name", "RL Training"),
-             job_id=d.get("job_id"),
-             stats_interval=d.get("stats_interval", 1),
-             verbose=d.get("verbose", False),
-             very_verbose=d.get("very_verbose", False),
-             out_dir=d.get("out_dir", "./checkpoints"),
-             adapter_prefix=d.get("adapter_prefix", "cua-grpo-step"),
-             seed=d.get("seed", 1234),
-         )
-
-     def to_dict(self) -> dict:
-         """Convert config to dictionary."""
-         return {
-             "model": self.model.__dict__,
-             "training": self.training.__dict__,
-             "actor": self.actor.__dict__,
-             "job_name": self.job_name,
-             "job_id": self.job_id,
-             "stats_interval": self.stats_interval,
-             "verbose": self.verbose,
-             "very_verbose": self.very_verbose,
-             "out_dir": self.out_dir,
-             "adapter_prefix": self.adapter_prefix,
-             "seed": self.seed,
-         }
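Note: a minimal sketch of building the removed RL configuration from a plain dictionary via Config.from_dict, assuming hud-python 0.4.45 is installed; the override values below are illustrative, not taken from this diff:

# Hedged sketch: construct the removed Config from a dict of overrides.
# Field names come from the dataclasses above; the concrete values are made up.
from hud.rl.config import Config, validate_vl_model

cfg = Config.from_dict(
    {
        "model": {"base_model": "Qwen/Qwen2.5-VL-3B-Instruct", "lora_r": 8},
        "training": {"training_steps": 50, "batch_size": 16, "group_size": 8},
        "actor": {"max_steps_per_episode": 5, "temperature": 0.7},
        "job_name": "RL Training",
        "verbose": True,
    }
)
validate_vl_model(cfg.model.base_model)  # raises ValueError for unsupported models
print(cfg.to_dict()["training"]["batch_size"])  # 16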