hud-python 0.4.28__py3-none-any.whl → 0.4.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

Files changed (77) hide show
  1. hud/__init__.py +2 -1
  2. hud/agents/base.py +81 -45
  3. hud/agents/claude.py +8 -4
  4. hud/agents/openai_chat_generic.py +66 -40
  5. hud/agents/tests/test_base.py +0 -4
  6. hud/agents/tests/test_openai.py +1 -1
  7. hud/cli/__init__.py +182 -52
  8. hud/cli/dev.py +8 -9
  9. hud/cli/eval.py +317 -119
  10. hud/cli/flows/__init__.py +0 -0
  11. hud/cli/flows/tasks.py +0 -0
  12. hud/cli/get.py +160 -0
  13. hud/cli/rl/__init__.py +567 -71
  14. hud/cli/rl/config.py +94 -0
  15. hud/cli/rl/display.py +133 -0
  16. hud/cli/rl/gpu.py +63 -0
  17. hud/cli/rl/gpu_utils.py +318 -0
  18. hud/cli/rl/presets.py +96 -0
  19. hud/cli/rl/remote_runner.py +347 -0
  20. hud/cli/rl/rl_api.py +150 -0
  21. hud/cli/rl/vllm.py +177 -0
  22. hud/cli/tests/test_analyze_metadata.py +0 -1
  23. hud/cli/utils/tasks.py +26 -0
  24. hud/clients/base.py +21 -23
  25. hud/clients/mcp_use.py +36 -44
  26. hud/clients/tests/test_mcp_use_retry.py +10 -10
  27. hud/datasets/__init__.py +4 -3
  28. hud/datasets/{execution/parallel.py → parallel.py} +1 -1
  29. hud/datasets/{execution/runner.py → runner.py} +1 -1
  30. hud/datasets/utils.py +1 -1
  31. hud/native/comparator.py +6 -6
  32. hud/native/tests/test_comparator.py +8 -8
  33. hud/native/tests/test_native_init.py +13 -11
  34. hud/otel/config.py +1 -1
  35. hud/otel/instrumentation.py +35 -0
  36. hud/rl/README.md +30 -0
  37. hud/rl/__init__.py +1 -0
  38. hud/rl/actor.py +174 -0
  39. hud/rl/buffer.py +371 -0
  40. hud/rl/chat_template.jinja +101 -0
  41. hud/rl/config.py +184 -0
  42. hud/rl/distributed.py +95 -0
  43. hud/rl/learner.py +589 -0
  44. hud/rl/tests/__init__.py +1 -0
  45. hud/rl/tests/test_learner.py +171 -0
  46. hud/rl/train.py +354 -0
  47. hud/rl/types.py +101 -0
  48. hud/rl/utils/start_vllm_server.sh +30 -0
  49. hud/rl/utils.py +524 -0
  50. hud/rl/vllm_adapter.py +125 -0
  51. hud/settings.py +6 -0
  52. hud/telemetry/__init__.py +2 -1
  53. hud/telemetry/job.py +46 -3
  54. hud/telemetry/tests/test_trace.py +3 -3
  55. hud/telemetry/trace.py +85 -13
  56. hud/tools/tests/test_computer.py +3 -3
  57. hud/tools/tests/test_computer_actions.py +1 -1
  58. hud/types.py +123 -2
  59. hud/utils/group_eval.py +223 -0
  60. hud/utils/hud_console.py +113 -13
  61. hud/utils/tasks.py +119 -0
  62. hud/utils/tests/test_version.py +1 -1
  63. hud/version.py +1 -1
  64. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/METADATA +20 -2
  65. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/RECORD +68 -48
  66. hud/cli/hf.py +0 -406
  67. hud/cli/rl/README.md +0 -243
  68. hud/cli/rl/init.py +0 -370
  69. hud/cli/rl/pod.py +0 -501
  70. hud/cli/rl/ssh.py +0 -322
  71. hud/cli/rl/train.py +0 -562
  72. hud/cli/rl/utils.py +0 -165
  73. hud/datasets/execution/__init__.py +0 -13
  74. hud/datasets/task.py +0 -116
  75. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/WHEEL +0 -0
  76. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/entry_points.txt +0 -0
  77. {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/licenses/LICENSE +0 -0
hud/rl/buffer.py ADDED
@@ -0,0 +1,371 @@
1
+ """Replay buffer for storing and sampling episodes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import random
7
+ from collections import deque
8
+ from typing import TYPE_CHECKING, Generic, TypeVar
9
+
10
+ from hud.types import Task, Trace
11
+ from hud.utils.hud_console import HUDConsole
12
+
13
+ logger = logging.getLogger(__name__)
14
+ hud_console = HUDConsole(logger=logger)
15
+
16
+ T = TypeVar("T")
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import Callable
20
+
21
+ from hud.rl.config import Config
22
+
23
+
24
class Buffer(Generic[T]):
    """Bounded buffer for a list of tasks, traces or episodes.

    Items live in a ``deque`` created with ``maxlen=max_size``; appending to a
    full buffer therefore drops the oldest item.
    """

    def __init__(self, max_size: int = 10000) -> None:
        """Create an empty buffer holding at most ``max_size`` items."""
        self.max_size = max_size
        self.buffer: deque[T] = deque(maxlen=max_size)

    def add(self, items: list[T] | T, shuffle: bool = False) -> None:
        """Append one item, or every item of a list, to the buffer.

        Args:
            items: A single item or a list of items (a list is unpacked).
            shuffle: When true, shuffle the whole buffer after adding.
        """
        if isinstance(items, list):
            self.buffer.extend(items)
        else:
            self.buffer.append(items)
        if shuffle:
            # Shuffle a list copy and write it back: indexing into the middle
            # of a deque is O(n), which makes shuffling it in place quadratic.
            shuffled = list(self.buffer)
            random.shuffle(shuffled)
            self.buffer.clear()
            self.buffer.extend(shuffled)

    def add_fill(self, items: list[T] | T, target_size: int, shuffle: bool = False) -> None:
        """Repeatedly add ``items`` until the buffer holds at least ``target_size`` items.

        Raises:
            ValueError: If the target can never be reached, either because
                ``items`` is an empty list or because ``target_size`` exceeds
                the buffer capacity. (Previously both cases looped forever.)
        """
        if len(self.buffer) >= target_size:
            return
        if isinstance(items, list) and not items:
            raise ValueError("Cannot fill buffer from an empty list of items")
        capacity = self.buffer.maxlen
        if capacity is not None and target_size > capacity:
            raise ValueError(
                f"Target size {target_size} exceeds buffer capacity {capacity}"
            )
        while len(self.buffer) < target_size:
            self.add(items, shuffle)

    def get(self, n: int = 0) -> list[T]:
        """Return items without removing them.

        Args:
            n: Number of most recently added items to return, oldest first.
                ``0`` returns the whole buffer.

        Raises:
            ValueError: If ``n`` is larger than the number of buffered items.
        """
        if n == 0:
            return list(self.buffer)
        if n > len(self.buffer):
            raise ValueError("Not enough items in buffer")
        return list(self.buffer)[-n:]

    def consume(self, n: int = 0) -> list[T]:
        """Remove items from the buffer and return them.

        Args:
            n: Number of items to pop from the newest end; they are returned
                newest first. ``0`` drains the whole buffer and returns the
                items oldest first.

        Raises:
            ValueError: If ``n`` is larger than the number of buffered items.
        """
        if n == 0:
            # Actually drain the buffer; returning a copy without removing
            # anything would make consume() indistinguishable from get().
            drained = list(self.buffer)
            self.buffer.clear()
            return drained
        if n > len(self.buffer):
            raise ValueError("Not enough items in buffer")

        return [self.buffer.pop() for _ in range(n)]

    def get_filtered(
        self, n: int = 0, filter_fn: Callable[[T], bool] | None = None, consume: bool = False
    ) -> list[T]:
        """Return (and optionally remove) the items accepted by ``filter_fn``.

        Args:
            n: Number of most recent matching items to return; ``0`` means all
                matching items.
            filter_fn: Predicate selecting items. ``None`` accepts everything,
                in which case this is equivalent to ``get``/``consume``.
            consume: When true, the returned items are removed from the buffer
                and, for ``n > 0``, returned newest first like ``consume``.

        Raises:
            ValueError: If fewer than ``n`` items match.
        """
        if filter_fn is None:
            return self.consume(n) if consume else self.get(n)

        snapshot = list(self.buffer)
        matching = [idx for idx, item in enumerate(snapshot) if filter_fn(item)]
        if n > len(matching):
            raise ValueError("Not enough items in buffer")

        chosen = matching[-n:] if n else matching
        selected = [snapshot[idx] for idx in chosen]
        if consume:
            # Remove by position, not by value: the same object may be
            # buffered more than once.
            dropped = set(chosen)
            self.buffer.clear()
            self.buffer.extend(
                item for idx, item in enumerate(snapshot) if idx not in dropped
            )
            if n:
                selected.reverse()
        return selected

    def sample(
        self,
        batch_size: int,
        n: int = 0,
        filter_fn: Callable[[T], bool] | None = None,
        consume: bool = False,
    ) -> list[T]:
        """Sample a batch of items with optional filtering.

        The candidate set is ``get_filtered(n, filter_fn, consume)``. If it
        holds fewer than ``batch_size`` items a warning is logged and all
        candidates are returned; otherwise ``batch_size`` of them are drawn
        without replacement.
        """
        items = self.get_filtered(n, filter_fn, consume)

        if len(items) < batch_size:
            hud_console.warning(f"Buffer has {len(items)} items, requested {batch_size}")
            return items

        return random.sample(items, batch_size)

    def clear(self) -> None:
        """Remove every item from the buffer."""
        self.buffer.clear()

    def __len__(self) -> int:
        """Return the number of buffered items, so ``len(buffer)`` works."""
        return len(self.buffer)
97
+
98
+
99
class DatasetBuffer(Buffer[Task]):
    """
    Buffer for a dataset.
    Loads in individual tasks that will be trained for a specified number of training steps.

    The buffer is sized to ``training_steps * (batch_size // group_size)`` tasks and
    is filled by repeating the dataset until that size is reached.
    """

    def __init__(
        self,
        dataset: list[Task] | Task,
        config: Config,
    ) -> None:
        """Validate the batching configuration and pre-fill the buffer.

        Args:
            dataset: One task or a list of tasks.
            config: Training configuration; reads ``config.training``'s
                ``group_size``, ``batch_size``, ``training_steps``,
                ``mini_batch_size`` and ``shuffle_dataset``.

        Raises:
            ValueError: If the group/batch/mini-batch sizes are inconsistent, if no
                task would be scheduled, or if ``dataset`` is empty.
            TypeError: If any dataset entry is not a ``Task``.
        """
        self.config = config

        self.group_size = config.training.group_size
        self.batch_size = config.training.batch_size
        self.training_steps = config.training.training_steps

        if self.group_size > self.batch_size:
            raise ValueError(
                f"Group size is greater than batch size, {self.group_size} > {self.batch_size}"
            )

        if self.batch_size % self.group_size != 0:
            # Report the operands in the order the check actually uses them.
            raise ValueError(
                f"A batch cannot have irregular groups, {self.batch_size} % {self.group_size} != 0"
            )

        if self.group_size % config.training.mini_batch_size != 0:
            raise ValueError(
                f"Group size is not a multiple of mini batch size, {self.group_size} % {config.training.mini_batch_size} != 0"  # noqa: E501
            )

        self.groups_per_batch = self.batch_size // self.group_size
        self.number_of_tasks = self.training_steps * self.groups_per_batch

        if self.number_of_tasks == 0:
            # Would otherwise surface below as a ZeroDivisionError on dataset_size.
            raise ValueError(
                f"Training steps ({self.training_steps}) must schedule at least one task"
            )

        super().__init__(self.number_of_tasks)

        dataset = dataset if isinstance(dataset, list) else [dataset]
        tasks = self._validate_tasks(dataset)
        if config.training.shuffle_dataset:
            random.shuffle(tasks)
        if len(tasks) > self.number_of_tasks:
            leftovers = len(tasks) - self.number_of_tasks
            hud_console.warning(
                f"Training steps ({self.training_steps}) will lead to {leftovers} tasks not being trained"  # noqa: E501
            )
            tasks = tasks[: self.number_of_tasks]

        # Check if the dataset is imbalanced. The buffer holds number_of_tasks
        # slots, so that (not training_steps) is the quantity that must be a
        # multiple of the dataset size for every task to repeat equally often.
        self.dataset_size = len(tasks)
        leftovers = self.number_of_tasks % self.dataset_size
        if leftovers != 0:
            hud_console.warning(
                f"Dataset imbalanced ({leftovers} tasks will be trained 1 more time)"
            )
            hud_console.warning(
                f"This is because the number of scheduled tasks ({self.number_of_tasks}) is not a multiple of the dataset size ({self.dataset_size})"  # noqa: E501
            )

        self.add_fill(tasks, self.number_of_tasks, config.training.shuffle_dataset)

    def _validate_tasks(self, tasks: list[Task]) -> list[Task]:
        """Return a new list of ``tasks`` after checking each one is a HUD ``Task``.

        Raises:
            ValueError: If ``tasks`` is empty.
            TypeError: If an entry is not a ``Task`` instance.
        """
        if not tasks:
            raise ValueError("No tasks provided to DatasetBuffer")

        for i, task in enumerate(tasks):
            if not isinstance(task, Task):
                raise TypeError(f"Task at index {i} is not a HUD Task object, got {type(task)}")

        # Copy so that shuffling in __init__ never reorders the caller's list.
        return list(tasks)

    @property
    def info(self) -> dict[str, int | float | str]:
        """Summary counts describing the buffer and its batching layout."""
        return {
            "total_items": len(self),
            "total_traces": self.number_of_tasks * self.group_size,
            "total_batches": self.training_steps,
            "task_repeats": self.number_of_tasks // self.dataset_size,
            "dataset_size": self.dataset_size,
            "group_size": self.group_size,
            "batch_size": self.batch_size,
        }

    def get_tasks(self, consume: bool = True) -> list[Task]:
        """Return the tasks for one batch, each repeated ``group_size`` times in a row.

        Args:
            consume: When true (default) the ``groups_per_batch`` tasks are removed
                from the buffer; otherwise they are only read.

        Raises:
            ValueError: If fewer than ``groups_per_batch`` tasks remain.
        """
        tasks = self.consume(self.groups_per_batch) if consume else self.get(self.groups_per_batch)
        # Each group is group_size references to the same task object.
        return [task for task in tasks for _ in range(self.group_size)]
194
+
195
+
196
class ReplayBuffer(Buffer[Trace]):
    """Buffer for traces.

    Capacity is ``buffer_steps * batch_size`` traces; ``sample_traces`` picks one
    batch according to ``config.training.select_strategy``.
    """

    def __init__(self, config: Config) -> None:
        """Size the buffer from ``config.training`` and remember the sampling settings."""
        self.config = config

        self.buffer_steps = config.training.buffer_steps
        self.select_strategy = config.training.select_strategy
        self.group_size = config.training.group_size
        self.batch_size = config.training.batch_size

        buffer_size = self.buffer_steps * self.batch_size

        super().__init__(buffer_size)

    def sample_traces(self) -> list[Trace]:
        """Sample traces for a batch.

        ``"recent"`` returns the latest ``batch_size`` traces, ``"random"`` draws
        ``batch_size`` traces at random and ``"variance"`` uses the group-wise
        high-variance sampler.

        Raises:
            ValueError: If the select strategy is not one of the three above.
        """
        if self.select_strategy == "recent":
            return self.get(self.batch_size)
        if self.select_strategy == "random":
            return self.sample(self.batch_size)
        if self.select_strategy == "variance":
            return self._sample_high_variance_traces()
        raise ValueError(f"Invalid select strategy: {self.select_strategy}")

    @staticmethod
    def _task_id(trace: Trace) -> str:
        """Return the id of the trace's task, or ``"NA"`` when the task has none."""
        return getattr(trace.task, "id", "NA")

    @staticmethod
    def _reward(trace: Trace) -> float:
        """Return the trace reward as a float, treating a missing/None reward as 0."""
        return float(getattr(trace, "reward", 0.0) or 0.0)

    def _sample_high_variance_traces(self) -> list[Trace]:
        """Build a batch from the most recent traces, widened with older ones.

        The latest ``batch_size`` traces are cut into consecutive groups of
        ``group_size``; every group must consist of traces of a single task id.
        For each group, up to a quarter of its traces (at least one) that lie
        closest to the group's mean reward are swapped for older traces of the
        same task whose reward lies farthest from that mean.

        Raises:
            ValueError: If the buffer is empty while a non-empty batch is requested.
            RuntimeError: If a group mixes traces of different task ids.
        """
        from collections import Counter, defaultdict

        task_id = self._task_id
        reward = self._reward

        # Expect recent window to already be grouped by task id
        buf_list = list(self.buffer)
        if len(buf_list) < self.batch_size:
            if not buf_list:
                # Nothing to pad from: the padding loop below could never make
                # progress and would spin forever.
                raise ValueError("Not enough items in buffer")
            hud_console.warning(
                f"[group-sampler] Buffer has only {len(buf_list)} traces, need {self.batch_size}"
            )
            # Pad by repeating traces from the front until one batch is available.
            while len(buf_list) < self.batch_size:
                take = min(len(buf_list), self.batch_size - len(buf_list))
                buf_list.extend(buf_list[:take])
        recent_traces = buf_list[-self.batch_size :]
        hud_console.info(
            f"[group-sampler] recent-window histogram: {Counter(task_id(t) for t in recent_traces)}"  # noqa: E501
        )

        hud_console.info(
            f"[group-sampler] Building earlier traces lookup, buffer size: {len(buf_list)}"
        )
        # Traces older than the recent window, keyed by task id, oldest first.
        earlier_traces_by_task: dict[str, list[Trace]] = defaultdict(list)
        for tr in buf_list[: -self.batch_size]:
            earlier_traces_by_task[task_id(tr)].append(tr)

        final_traces: list[Trace] = []
        groups_per_batch = self.batch_size // self.group_size
        hud_console.info(f"[group-sampler] Processing {groups_per_batch} groups")
        for g_idx in range(groups_per_batch):
            start = g_idx * self.group_size
            group = recent_traces[start : start + self.group_size]

            # Assert homogeneity: every trace in a group must share the same task id
            cnt = Counter(task_id(t) for t in group)
            if len(cnt) != 1:
                raise RuntimeError(f"Group {g_idx} is not homogeneous: {dict(cnt)}")
            target_tid = next(iter(cnt))

            # The slice above always yields a full group and the check guarantees
            # a single task id, so the group needs no top-up before the swap step.
            homogeneous: list[Trace] = list(group)

            # Replacement step: swap in earlier traces to increase reward spread
            pool = earlier_traces_by_task.get(target_tid, [])
            if pool:
                # Log pool stats
                pool_vals = [reward(tr) for tr in pool]
                pool_mean = sum(pool_vals) / len(pool_vals)
                pool_var = sum((v - pool_mean) * (v - pool_mean) for v in pool_vals) / len(
                    pool_vals
                )
                hud_console.info(
                    f"[group-sampler] Group {g_idx}: earlier-pool size={len(pool_vals)} mean={pool_mean:.4f} std={(pool_var**0.5):.4f}"  # noqa: E501
                )

                # Decide how many to replace (up to 1/4 of group, at least 1)
                replace_k = min(max(1, self.group_size // 4), len(pool), self.group_size)

                if replace_k > 0:
                    mu = sum(reward(t) for t in homogeneous) / len(homogeneous)

                    # Pool entries farthest from the current mean come first
                    # (stable sort: ties keep their oldest-first order).
                    by_spread = sorted(
                        range(len(pool)),
                        key=lambda i: abs(reward(pool[i]) - mu),
                        reverse=True,
                    )
                    chosen_idx = by_spread[:replace_k]
                    chosen_set = set(chosen_idx)
                    replacements = [pool[i] for i in chosen_idx]

                    # A pool trace is used at most once per batch.
                    earlier_traces_by_task[target_tid] = [
                        tr for i, tr in enumerate(pool) if i not in chosen_set
                    ]

                    # Group positions closest to the mean are the ones replaced.
                    closest_first = sorted(
                        range(len(homogeneous)),
                        key=lambda i: abs(reward(homogeneous[i]) - mu),
                    )
                    for pos, new_tr in zip(closest_first[:replace_k], replacements, strict=False):
                        homogeneous[pos] = new_tr

            # Validate homogeneity
            if any(task_id(t) != target_tid for t in homogeneous):
                raise RuntimeError(f"Group {g_idx} is not homogeneous after sampling")
            final_traces.extend(homogeneous)

        for i in range(0, len(final_traces), self.group_size):
            block = final_traces[i : i + self.group_size]
            if len({task_id(t) for t in block}) != 1:
                raise RuntimeError(f"Homogeneity validation failed for block starting at index {i}")

        hud_console.info(
            f"[group-sampler] final histogram: {Counter(task_id(t) for t in final_traces)}"  # noqa: E501
        )
        return final_traces
370
+
371
+ # --------------------------------------------------------------------
@@ -0,0 +1,101 @@
1
+ {% set image_count = namespace(value=0) %}
2
+ {% set video_count = namespace(value=0) %}
3
+ {{- '<|im_start|>system\n' }}
4
+ {%- if messages[0]['role'] == 'system' -%}
5
+ {%- if messages[0]['content'] is string -%}
6
+ {{ messages[0]['content'] }}
7
+ {%- else -%}
8
+ {%- for content in messages[0]['content'] -%}
9
+ {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
10
+ {%- set image_count.value = image_count.value + 1 -%}
11
+ {%- if add_vision_id -%}
12
+ {{ 'Picture ' ~ image_count.value ~ ': ' }}
13
+ {%- endif -%}
14
+ {{ '<|vision_start|><|image_pad|><|vision_end|>' }}
15
+ {%- elif content['type'] == 'video' or 'video' in content -%}
16
+ {%- set video_count.value = video_count.value + 1 -%}
17
+ {%- if add_vision_id -%}
18
+ {{ 'Video ' ~ video_count.value ~ ': ' }}
19
+ {%- endif -%}
20
+ {{ '<|vision_start|><|video_pad|><|vision_end|>' }}
21
+ {%- elif 'text' in content -%}
22
+ {{ content['text'] }}
23
+ {%- endif -%}
24
+ {%- endfor -%}
25
+ {%- endif -%}
26
+ {%- else -%}
27
+ {{ 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
28
+ {%- endif -%}
29
+ {%- if tools -%}
30
+ {{ '\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n' }}
31
+ {{- tools | map('tojson') | join('\n') -}}
32
+ {{ '\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call>' }}
33
+ {%- endif -%}
34
+ {{ '<|im_end|>\n' }}
35
+ {%- for message in messages -%}
36
+ {# Skip the first system message as it was already rendered. #}
37
+ {%- if loop.first and message.role == 'system' %}{% continue %}{% endif -%}
38
+
39
+ {# Render tool messages. The logic is slightly different with other messages. #}
40
+ {%- if message['role'] == 'tool' -%}
41
+ {%- if loop.first or messages[loop.index0 - 1]['role'] != 'tool' -%}
42
+ {{ '<|im_start|>user' }}
43
+ {%- endif -%}
44
+ {{ '\n<tool_response>\n' }}
45
+ {%- else -%}
46
+ {{ '<|im_start|>' ~ message['role'] ~ '\n' }}
47
+ {%- endif -%}
48
+
49
+ {%- if message['content'] is string -%}
50
+ {{ message['content'] }}
51
+ {%- else -%}
52
+ {%- for content in message['content'] -%}
53
+ {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
54
+ {%- set image_count.value = image_count.value + 1 -%}
55
+ {%- if add_vision_id -%}
56
+ {{ 'Picture ' ~ image_count.value ~ ': ' }}
57
+ {%- endif -%}
58
+ {{ '<|vision_start|><|image_pad|><|vision_end|>' }}
59
+ {%- elif content['type'] == 'video' or 'video' in content -%}
60
+ {%- set video_count.value = video_count.value + 1 -%}
61
+ {%- if add_vision_id -%}
62
+ {{ 'Video ' ~ video_count.value ~ ': ' }}
63
+ {%- endif -%}
64
+ {{ '<|vision_start|><|video_pad|><|vision_end|>' }}
65
+ {%- elif 'text' in content and message['role'] == 'assistant' -%}
66
+ {% generation %} {{ content['text'] }} {% endgeneration %}
67
+ {%- elif 'text' in content -%}
68
+ {{ content['text'] }}
69
+ {%- endif -%}
70
+ {%- endfor -%}
71
+ {%- endif -%}
72
+ {# Render tool_calls in AI messages. #}
73
+ {%- if message['role'] == 'assistant' and 'tool_calls' in message -%}
74
+ {# It will be cleaner if I can use some map function and join them with '\n' #}
75
+ {%- for tool_call in message['tool_calls'] -%}
76
+ {%- if tool_call['function'] is defined -%}
77
+ {%- set tool_call = tool_call['function'] -%}
78
+ {%- endif -%}
79
+ {# Handle the case where arguments is already a JSON string (OpenAI format) #}
80
+ {%- if tool_call.arguments is string -%}
81
+ {% generation %} {{ '<tool_call>\n{"name": "' }}{{ tool_call.name }}{{ '", "arguments": ' }}{{ tool_call.arguments }}{{ '}\n</tool_call>' }} {% endgeneration %}
82
+ {%- else -%}
83
+ {% generation %} {{ '<tool_call>\n' }}{{ tool_call | tojson }}{{ '\n</tool_call>' }} {% endgeneration %}
84
+ {%- endif -%}
85
+ {%- if not loop.last -%}
86
+ {% generation %} {{ '\n' }} {% endgeneration %}
87
+ {%- endif -%}
88
+ {%- endfor -%}
89
+ {%- endif -%}
90
+ {%- if message['role'] == 'tool' -%}
91
+ {{ '\n</tool_response>' }}
92
+ {%- if loop.last or messages[loop.index0 + 1]['role'] != 'tool' -%}
93
+ {{ '<|im_end|>\n' }}
94
+ {%- endif -%}
95
+ {%- else -%}
96
+ {{ '<|im_end|>\n' }}
97
+ {%- endif -%}
98
+ {%- endfor -%}
99
+ {%- if add_generation_prompt -%}
100
+ {{ '<|im_start|>assistant\n' }}
101
+ {%- endif -%}
hud/rl/config.py ADDED
@@ -0,0 +1,184 @@
1
+ """Configuration for RL training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Literal
7
+
8
# Model name prefixes accepted for RL training. Most entries are Qwen2.5-VL
# (Vision-Language) checkpoints; note the last entry has no "-VL" infix.
SUPPORTED_MODELS = [
    "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2.5-VL-7B-Instruct",
    "Qwen/Qwen2.5-VL-14B-Instruct",
    "Qwen/Qwen2.5-VL-32B-Instruct",
    "Qwen/Qwen2.5-VL-72B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
]


def validate_vl_model(model_name: str) -> None:
    """Check that ``model_name`` starts with one of ``SUPPORTED_MODELS``.

    Matching is by prefix, so a supported name followed by a suffix is accepted.

    Args:
        model_name: The model name to validate

    Raises:
        ValueError: If the name does not start with any supported model name
    """
    # str.startswith accepts a tuple of candidate prefixes.
    if model_name.startswith(tuple(SUPPORTED_MODELS)):
        return

    supported_listing = ", ".join(SUPPORTED_MODELS)
    raise ValueError(
        f"Model '{model_name}' is not a supported VL model. "
        f"Only VL (Vision-Language) models are supported for RL training.\n"
        f"Supported models: {supported_listing}\n"
        f"Note: '{model_name}' appears to be a text-only model."
    )
35
+
36
+
37
@dataclass
class ModelConfig:
    """Settings for the base model and its LoRA adapter."""

    # Name of the base checkpoint to load.
    base_model: str = "Qwen/Qwen2.5-VL-3B-Instruct"

    # LoRA hyperparameters and the module names the adapter is attached to.
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: tuple[str, ...] = (
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    )

    # Pixel budget bounds, expressed as multiples of 28 * 28
    # (presumably consumed by the image processor; confirm against the loader).
    min_pixels: int = 256 * 28 * 28
    max_pixels: int = 512 * 28 * 28

    # Model loading / memory options.
    attn_implementation: str = "flash_attention_2"
    use_liger: bool = True
    gradient_checkpointing: bool = True
59
+
60
+
61
@dataclass
class TrainingConfig:
    """Hyperparameters controlling the training run."""

    # --- Run length and checkpointing ---
    training_steps: int = 100
    shuffle_dataset: bool = False
    save_every_batches: int = 1

    # --- Batching ---
    epochs: int = 2
    batch_size: int = 24
    group_size: int = 4
    mini_batch_size: int = 1
    # Update the policy after each task group.
    update_after_group: bool = True
    # Accumulate over minibatches.
    accumulate_over_minibatches: bool = False

    # --- Advantage calculation ---
    batch_level: Literal["group", "batch"] = "group"
    no_std: bool = False
    leave_one_out: bool = True

    # --- Replay buffer ---
    buffer_steps: int = 4
    select_strategy: Literal["recent", "variance", "random"] = "variance"

    # --- Aggregation ---
    ppo_mode: Literal["per_token", "per_trace"] = "per_token"
    token_agg: Literal["mean", "sum"] = "mean"  # noqa: S105

    # --- Regularization ---
    kl_beta: float = 0.0
    entropy_beta: float = 0.0
    top_eps: float = 0.2
    bottom_eps: float = 0.1

    # --- Optimization ---
    lr: float = 3e-5
    grad_clip: float = 1.0

    # --- Adam ---
    use_8bit_optimizer: bool = True
    adam_betas: tuple[float, float] = (0.9, 0.999)
    adam_eps: float = 1e-8
105
+
106
+
107
@dataclass
class ActorConfig:
    """Settings for the actor that collects episodes."""

    # --- Execution limits ---
    max_steps_per_episode: int = 5
    max_parallel_episodes: int = 48
    max_new_tokens: int = 1024
    force_tool_choice: bool = True
    allowed_tools: list[str] | None = None

    # --- Sampling ---
    temperature: float = 0.7

    # --- HUD agent / vLLM endpoint ---
    system_prompt: str = "You are an expert agent. Complete the task efficiently."
    vllm_base_url: str = "http://localhost:8000/v1"
    vllm_api_key: str = "token-abc123"

    # Time limit for executing a single episode, in seconds.
    episode_timeout_sec: int = 600
128
+
129
+
130
@dataclass
class Config:
    """Main configuration combining all sub-configs."""

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    actor: ActorConfig = field(default_factory=ActorConfig)

    # Telemetry configuration
    job_name: str = "RL Training"
    job_id: str | None = None  # Use existing job ID if provided
    stats_interval: int = 1
    verbose: bool = False

    # Paths
    out_dir: str = "./checkpoints"
    adapter_prefix: str = "cua-grpo-step"

    # Misc
    seed: int = 1234

    @classmethod
    def from_dict(cls, d: dict) -> Config:
        """Create a config from a (possibly partial) dictionary.

        Args:
            d: Mapping with optional ``"model"``, ``"training"`` and ``"actor"``
                sub-dictionaries plus optional top-level scalar keys. Anything
                missing falls back to the dataclass defaults; a section that is
                present but ``None`` is treated as empty.

        Raises:
            TypeError: If a sub-dictionary contains a key that is not a field of
                the corresponding sub-config.
        """
        # `or {}` also covers a section explicitly set to None (e.g. an empty
        # section in a config file), which `**None` would reject.
        model = ModelConfig(**(d.get("model") or {}))
        training = TrainingConfig(**(d.get("training") or {}))
        actor = ActorConfig(**(d.get("actor") or {}))

        # Forward only the keys that are present so the field defaults declared
        # on the dataclass remain the single source of truth.
        top_level_keys = (
            "job_name",
            "job_id",
            "stats_interval",
            "verbose",
            "out_dir",
            "adapter_prefix",
            "seed",
        )
        overrides = {key: d[key] for key in top_level_keys if key in d}

        return cls(model=model, training=training, actor=actor, **overrides)

    def to_dict(self) -> dict:
        """Convert the config to a plain dictionary.

        The sub-config dictionaries are independent copies: mutating the result
        does not modify this config.
        """
        from dataclasses import asdict

        return {
            "model": asdict(self.model),
            "training": asdict(self.training),
            "actor": asdict(self.actor),
            "job_name": self.job_name,
            "job_id": self.job_id,
            "stats_interval": self.stats_interval,
            "verbose": self.verbose,
            "out_dir": self.out_dir,
            "adapter_prefix": self.adapter_prefix,
            "seed": self.seed,
        }