hud-python 0.4.46-py3-none-any.whl → 0.4.48-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of hud-python has been flagged as potentially problematic.

@@ -0,0 +1,525 @@
+"""Tests for hud.cli.eval module."""
+
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
+
+import pytest
+from mcp import types
+
+from hud.cli.eval import build_agent, eval_command, get_available_models, run_full_dataset, run_single_task
+from hud.types import Task, Trace
+
+class TestBuildAgent:
+    """Test the build_agent function."""
+
+    def test_builds_integration_test_agent(self) -> None:
+        """
+        Test building an integration test agent.
+        """
+        with patch("hud.agents.misc.integration_test_agent.IntegrationTestRunner") as mock_runner:
+            mock_instance = Mock()
+            mock_runner.return_value = mock_instance
+
+            # Test with verbose=False
+            result = build_agent("integration_test", verbose=False)
+
+            mock_runner.assert_called_once_with(verbose=False)
+            assert result == mock_instance
+
+    def test_builds_claude_agent(self) -> None:
+        """
+        Test building a Claude agent with default model.
+        """
+        with patch("hud.agents.ClaudeAgent") as mock_runner:
+            mock_instance = Mock()
+            mock_runner.return_value = mock_instance
+
+            # Test with verbose=False
+            result = build_agent("claude", verbose=False)
+
+            mock_runner.assert_called_once_with(
+                model="claude-sonnet-4-20250514",
+                verbose=False
+            )
+            assert result == mock_instance
+
+    def test_builds_claude_agent_with_custom_model_and_allowed_tools(self) -> None:
+        """
+        Test building a Claude agent with custom model name and allowed tools.
+        """
+        with patch("hud.agents.ClaudeAgent") as mock_runner:
+            mock_instance = Mock()
+            mock_runner.return_value = mock_instance
+
+            # Test with verbose=False
+            result = build_agent(
+                "claude",
+                model="claude-sonnet-4-20250514",
+                allowed_tools=["act"],
+                verbose=True,
+            )
+
+            mock_runner.assert_called_once_with(
+                model="claude-sonnet-4-20250514",
+                allowed_tools=["act"],
+                verbose=True,
+            )
+            assert result == mock_instance
+
+
+class TestRunSingleTask:
+    """Test the run_single_task function."""
+
+    @pytest.mark.asyncio
+    async def test_applies_agent_config_from_task(self) -> None:
+        """Test that task.agent_config is applied during agent initialization."""
+        mock_task = Task(
+            prompt="Test",
+            mcp_config={"local": {"url": "http://localhost:8765/mcp"}},
+            agent_config={
+                "system_prompt": "Custom instructions",
+                "allowed_tools": ["tool1", "tool2"],
+                "append_setup_output": False,
+            }
+        )
+        mock_agent = AsyncMock(
+            initialize=AsyncMock(),
+            run=AsyncMock(return_value=Trace(reward=1.0, done=True))
+        )
+
+        with patch("hud.utils.tasks.load_tasks", return_value=[mock_task]), \
+             patch("hud.agents.misc.integration_test_agent.IntegrationTestRunner", return_value=mock_agent), \
+             patch("hud.cli.eval.find_environment_dir", return_value=None), \
+             patch("hud.cli.eval.hud.trace"):
+            await run_single_task("test.json", agent_type="integration_test", max_steps=10)
+
+        # Verify agent.run was called with the task containing agent_config
+        mock_agent.run.assert_called_once()
+        called_task = mock_agent.run.call_args[0][0]
+        assert called_task.agent_config == mock_task.agent_config
+
+    @pytest.mark.asyncio
+    async def test_runs_with_group_size_greater_than_one(self) -> None:
+        """Test that group_size > 1 triggers run_tasks_grouped instead of agent.run."""
+        mock_task = Task(prompt="Test", mcp_config={"local": {"url": "http://localhost:8765/mcp"}})
+
+        with patch("hud.utils.tasks.load_tasks", return_value=[mock_task]), \
+             patch("hud.cli.eval.run_tasks_grouped", new_callable=AsyncMock) as mock_grouped, \
+             patch("hud.cli.eval.display_group_statistics"), \
+             patch("hud.cli.eval.find_environment_dir", return_value=None), \
+             patch("hud.cli.eval.hud.trace"):
+
+            mock_grouped.return_value = [{"task": mock_task, "rewards": [1.0, 0.5]}]
+
+            await run_single_task("test.json", agent_type="integration_test", group_size=3, max_steps=10)
+
+        # Verify run_tasks_grouped was called with correct group_size
+        mock_grouped.assert_called_once()
+        assert mock_grouped.call_args.kwargs["group_size"] == 3
+        assert mock_grouped.call_args.kwargs["max_steps"] == 10
+
+
+class TestToolFiltering:
+    """Test wildcard tool filtering via agent_config in tasks."""
+
+    @pytest.fixture
+    def mock_mcp_client(self):
+        """Fixture for mock MCP client."""
+        client = MagicMock()
+        client.initialize = AsyncMock()
+        client.mcp_config = {"local": {"url": "http://localhost"}}
+        return client
+
+    @pytest.fixture
+    def mock_model_client(self):
+        """Fixture for mock Anthropic client."""
+        return MagicMock()
+
+    async def _run_agent_with_tools(
+        self,
+        mock_mcp_client: MagicMock,
+        mock_model_client: MagicMock,
+        tools: list[types.Tool],
+        agent_config: dict | None = None,
+    ) -> list[types.Tool]:
+        """Helper to create agent, initialize with tools and config, return filtered tools."""
+        from hud.agents import ClaudeAgent
+
+        mock_mcp_client.list_tools = AsyncMock(return_value=tools)
+
+        task = Task(
+            prompt="Test",
+            mcp_config={"local": {"url": "http://localhost"}},
+            agent_config=agent_config or {}
+        )
+
+        agent = ClaudeAgent(
+            mcp_client=mock_mcp_client,
+            model_client=mock_model_client,
+            model="test",
+            validate_api_key=False
+        )
+        await agent.initialize(task)
+        return agent.get_available_tools()
+
+    @pytest.mark.asyncio
+    async def test_no_filters_returns_all_tools(self, mock_mcp_client, mock_model_client) -> None:
+        """Test that no filters in agent_config returns all tools."""
+        tools = [
+            types.Tool(name="tool1", description="Tool 1", inputSchema={}),
+            types.Tool(name="tool2", description="Tool 2", inputSchema={}),
+            types.Tool(name="debug_tool", description="Debug", inputSchema={}),
+        ]
+
+        result = await self._run_agent_with_tools(mock_mcp_client, mock_model_client, tools)
+
+        assert len(result) == 3
+
+    @pytest.mark.asyncio
+    async def test_allowed_tools_filters_correctly(self, mock_mcp_client, mock_model_client) -> None:
+        """Test that allowed_tools in agent_config filters to matching patterns."""
+        tools = [
+            types.Tool(name="screenshot_take", description="Tool 1", inputSchema={}),
+            types.Tool(name="screenshot_full", description="Tool 2", inputSchema={}),
+            types.Tool(name="click", description="Tool 3", inputSchema={}),
+        ]
+        agent_config = {"allowed_tools": ["screenshot_*"]}
+
+        result = await self._run_agent_with_tools(mock_mcp_client, mock_model_client, tools, agent_config)
+
+        assert len(result) == 2
+        assert all("screenshot" in t.name for t in result)
+
+    @pytest.mark.asyncio
+    async def test_disallowed_tools_excludes_correctly(self, mock_mcp_client, mock_model_client) -> None:
+        """Test that disallowed_tools in agent_config excludes matching patterns."""
+        tools = [
+            types.Tool(name="tool1", description="Tool 1", inputSchema={}),
+            types.Tool(name="debug_tool", description="Tool 2", inputSchema={}),
+            types.Tool(name="internal_secret", description="Tool 3", inputSchema={}),
+        ]
+        agent_config = {"disallowed_tools": ["debug_*", "internal_*"]}
+
+        result = await self._run_agent_with_tools(mock_mcp_client, mock_model_client, tools, agent_config)
+
+        assert len(result) == 1
+        assert result[0].name == "tool1"
+
+    @pytest.mark.asyncio
+    async def test_both_filters_applies_allowed_then_disallowed(self, mock_mcp_client, mock_model_client) -> None:
+        """Test that both filters in agent_config work together (disallowed takes precedence)."""
+        tools = [
+            types.Tool(name="browser_click", description="Tool 1", inputSchema={}),
+            types.Tool(name="browser_debug", description="Tool 2", inputSchema={}),
+            types.Tool(name="system_click", description="Tool 3", inputSchema={}),
+        ]
+        agent_config = {
+            "allowed_tools": ["browser_*"],
+            "disallowed_tools": ["*_debug"]
+        }
+
+        result = await self._run_agent_with_tools(mock_mcp_client, mock_model_client, tools, agent_config)
+
+        assert len(result) == 1
+        assert result[0].name == "browser_click"
+
+
+class TestRunDatasetToolFiltering:
+    """Test tool filtering via run_dataset with agent_config in both init and task."""
+
+    @pytest.fixture
+    def all_tools(self):
+        """Fixture for a standard set of tools."""
+        return [
+            types.Tool(name="browser_click", description="Click", inputSchema={}),
+            types.Tool(name="browser_type", description="Type", inputSchema={}),
+            types.Tool(name="browser_debug", description="Debug", inputSchema={}),
+            types.Tool(name="system_screenshot", description="Screenshot", inputSchema={}),
+            types.Tool(name="system_execute", description="Execute", inputSchema={}),
+        ]
+
+    @pytest.fixture
+    def captured_agent_fixture(self):
+        """Fixture that returns a dictionary to capture the agent instance."""
+        return {"agent": None}
+
+    @pytest.fixture
+    def mock_run_context(self, captured_agent_fixture):
+        """Fixture for mocking _run_context."""
+        async def _mock(self, context, max_steps=10):
+            captured_agent_fixture["agent"] = self
+            return Trace(reward=1.0, done=True, content="Done")
+        return _mock
+
+    @pytest.fixture
+    def mock_call_tools(self):
+        """Fixture for mocking call_tools."""
+        async def _mock(self, tool_call=None):
+            return []
+        return _mock
+
+    @pytest.fixture
+    def mock_client_instance(self, all_tools):
+        """Fixture for mock MCP client instance."""
+        mock_client = MagicMock()
+        mock_client.initialize = AsyncMock()
+        mock_client.list_tools = AsyncMock(return_value=all_tools)
+        mock_client.shutdown = AsyncMock()
+        mock_client.mcp_config = {"local": {"url": "http://localhost:8765/mcp"}}
+        return mock_client
+
+    @pytest.mark.asyncio
+    async def test_agent_config_intersection_union_via_run_dataset(
+        self, all_tools, captured_agent_fixture, mock_run_context, mock_call_tools, mock_client_instance
+    ) -> None:
+        """Test that allowed_tools intersect and disallowed_tools union when set in both __init__ and task.agent_config."""
+        from hud.agents import ClaudeAgent
+        from hud.datasets.runner import run_dataset
+
+        # Create a task with its own agent_config
+        task_dict = {
+            "prompt": "Test task",
+            "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}},
+            "agent_config": {
+                "allowed_tools": ["browser_*", "system_screenshot"],  # Task wants browser_* and system_screenshot
+                "disallowed_tools": ["*_debug", "*_execute"],  # Task disallows *_debug and *_execute
+            }
+        }
+
+        # Agent config passed to __init__ via run_dataset
+        agent_init_config = {
+            "allowed_tools": ["browser_*", "system_*"],  # Agent init wants browser_* and system_*
+            "disallowed_tools": ["browser_debug"],  # Agent init disallows browser_debug
+            "validate_api_key": False,
+        }
+
+        with patch("hud.job"), \
+             patch("hud.trace"), \
+             patch.object(ClaudeAgent, "_run_context", mock_run_context), \
+             patch.object(ClaudeAgent, "call_tools", mock_call_tools), \
+             patch("hud.clients.MCPClient", return_value=mock_client_instance):
+
+            # Run the dataset
+            await run_dataset(
+                name="test_job",
+                dataset=[task_dict],
+                agent_class=ClaudeAgent,
+                agent_config=agent_init_config,
+                max_steps=10,
+            )
+
+        # Verify agent was created and ran
+        captured_agent = captured_agent_fixture["agent"]
+        assert captured_agent is not None
+
+        # Get the filtered tools
+        filtered_tools = captured_agent.get_available_tools()
+        filtered_names = {tool.name for tool in filtered_tools}
+
+        # Expected behavior:
+        # 1. allowed_tools intersection: ["browser_*", "system_*"] ∩ ["browser_*", "system_screenshot"]
+        #    Exact string intersection: only "browser_*" is in both lists
+        #    So only tools matching browser_* are allowed: browser_click, browser_type, browser_debug
+        # 2. disallowed_tools union: ["browser_debug"] ∪ ["*_debug", "*_execute"]
+        #    Result: ["browser_debug", "*_debug", "*_execute"] (all patterns included)
+        # 3. Final: {browser_click, browser_type, browser_debug} - {browser_debug}
+        #    Result: browser_click, browser_type
+
+        expected_tools = {"browser_click", "browser_type"}
+        assert filtered_names == expected_tools, f"Expected {expected_tools}, got {filtered_names}"
+
+    @pytest.mark.asyncio
+    async def test_no_allowed_tools_keeps_all_tools_except_disallowed(
+        self, all_tools, captured_agent_fixture, mock_run_context, mock_call_tools, mock_client_instance
+    ) -> None:
+        """Test that when allowed_tools is not set, all tools are available except disallowed ones."""
+        from hud.agents import ClaudeAgent
+        from hud.datasets.runner import run_dataset
+
+        # Create a task with its own agent_config (no allowed_tools)
+        task_dict = {
+            "prompt": "Test task",
+            "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}},
+            "agent_config": {
+                # No allowed_tools set - should allow all tools
+                "disallowed_tools": ["*_execute"],  # Task disallows *_execute
+            }
+        }
+
+        # Agent config passed to __init__ via run_dataset (no allowed_tools)
+        agent_init_config = {
+            # No allowed_tools set - should allow all tools
+            "disallowed_tools": ["browser_debug"],  # Agent init disallows browser_debug
+            "validate_api_key": False,
+        }
+
+        with patch("hud.job"), \
+             patch("hud.trace"), \
+             patch.object(ClaudeAgent, "_run_context", mock_run_context), \
+             patch.object(ClaudeAgent, "call_tools", mock_call_tools), \
+             patch("hud.clients.MCPClient", return_value=mock_client_instance):
+
+            # Run the dataset
+            await run_dataset(
+                name="test_job",
+                dataset=[task_dict],
+                agent_class=ClaudeAgent,
+                agent_config=agent_init_config,
+                max_steps=10,
+            )
+
+        # Verify agent was created and ran
+        captured_agent = captured_agent_fixture["agent"]
+        assert captured_agent is not None
+
+        # Get the filtered tools
+        filtered_tools = captured_agent.get_available_tools()
+        filtered_names = {tool.name for tool in filtered_tools}
+
+        # Expected behavior:
+        # 1. allowed_tools: None (no allowed_tools set in either init or task)
+        #    Result: All tools are initially allowed
+        # 2. disallowed_tools union: ["browser_debug"] ∪ ["*_execute"]
+        #    Result: ["browser_debug", "*_execute"] (all patterns included)
+        # 3. Final: {all tools} - {browser_debug, system_execute}
+        #    Result: browser_click, browser_type, system_screenshot
+
+        expected_tools = {"browser_click", "browser_type", "system_screenshot"}
+        assert filtered_names == expected_tools, f"Expected {expected_tools}, got {filtered_names}"
+
+
+class TestSystemPromptHandling:
+    """Test system prompt handling through run_dataset flow."""
+
+    @pytest.fixture
+    def mock_mcp_client(self):
+        """Fixture for mock MCP client."""
+        client = MagicMock()
+        client.initialize = AsyncMock()
+        client.list_tools = AsyncMock(return_value=[])
+        client.shutdown = AsyncMock()
+        client.mcp_config = {"local": {"url": "http://localhost:8765/mcp"}}
+        return client
+
+    @pytest.fixture
+    def captured_agent_fixture(self):
+        """Fixture that returns a dictionary to capture the agent instance."""
+        return {"agent": None}
+
+    @pytest.fixture
+    def mock_run_context(self, captured_agent_fixture):
+        """Fixture for mocking _run_context to capture agent."""
+        async def _mock(self, context, max_steps=10):
+            captured_agent_fixture["agent"] = self
+            return Trace(reward=1.0, done=True, content="Done")
+        return _mock
+
+    @pytest.fixture
+    def mock_call_tools(self):
+        """Fixture for mocking call_tools."""
+        async def _mock(self, tool_call=None):
+            return []
+        return _mock
+
+    @pytest.mark.asyncio
+    async def test_task_system_prompt_only(
+        self, captured_agent_fixture, mock_run_context, mock_call_tools, mock_mcp_client
+    ) -> None:
+        """Test that task system_prompt is appended when agent has default system prompt."""
+        from hud.agents import ClaudeAgent
+        from hud.agents.base import GLOBAL_SYSTEM_PROMPT
+        from hud.datasets.runner import run_dataset
+
+        task_system_prompt = "Task prompt"
+
+        # Create a task with its own system_prompt in agent_config
+        task_dict = {
+            "prompt": "Test task",
+            "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}},
+            "agent_config": {
+                "system_prompt": task_system_prompt,
+            }
+        }
+
+        # Agent config with no custom system_prompt (will use default)
+        agent_init_config = {
+            "validate_api_key": False,
+        }
+
+        with patch("hud.job"), \
+             patch("hud.trace"), \
+             patch.object(ClaudeAgent, "_run_context", mock_run_context), \
+             patch.object(ClaudeAgent, "call_tools", mock_call_tools), \
+             patch("hud.clients.MCPClient", return_value=mock_mcp_client):
+
+            # Run the dataset
+            await run_dataset(
+                name="test_job",
+                dataset=[task_dict],
+                agent_class=ClaudeAgent,
+                agent_config=agent_init_config,
+                max_steps=10,
+            )
+
+        # Verify agent was created and ran
+        captured_agent = captured_agent_fixture["agent"]
+        assert captured_agent is not None
+
+        # Verify the task system prompt was appended
+        assert captured_agent.system_prompt.endswith(f"\n\n{task_system_prompt}")
+        # Verify it starts with the base global system prompt
+        assert captured_agent.system_prompt.startswith(GLOBAL_SYSTEM_PROMPT)
+
+    @pytest.mark.asyncio
+    async def test_both_agent_and_task_system_prompts(
+        self, captured_agent_fixture, mock_run_context, mock_call_tools, mock_mcp_client
+    ) -> None:
+        """Test that both agent init and task system prompts are present when both are set."""
+        from hud.agents import ClaudeAgent
+        from hud.datasets.runner import run_dataset
+
+        agent_custom_prompt = "Agent init prompt"
+        task_system_prompt = "Task prompt"
+
+        # Create a task with its own system_prompt in agent_config
+        task_dict = {
+            "prompt": "Test task",
+            "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}},
+            "agent_config": {
+                "system_prompt": task_system_prompt,
+            }
+        }
+
+        # Agent config WITH custom system_prompt
+        agent_init_config = {
+            "system_prompt": agent_custom_prompt,
+            "validate_api_key": False,
+        }
+
+        with patch("hud.job"), \
+             patch("hud.trace"), \
+             patch.object(ClaudeAgent, "_run_context", mock_run_context), \
+             patch.object(ClaudeAgent, "call_tools", mock_call_tools), \
+             patch("hud.clients.MCPClient", return_value=mock_mcp_client):
+
+            # Run the dataset
+            await run_dataset(
+                name="test_job",
+                dataset=[task_dict],
+                agent_class=ClaudeAgent,
+                agent_config=agent_init_config,
+                max_steps=10,
+            )
+
+        # Verify agent was created and ran
+        captured_agent = captured_agent_fixture["agent"]
+        assert captured_agent is not None
+
+        # Verify the task system prompt was appended at the end
+        assert captured_agent.system_prompt.endswith(f"\n\n{task_system_prompt}")
+        # Verify it starts with the agent custom prompt
+        assert captured_agent.system_prompt.startswith(agent_custom_prompt)
+        # Verify both prompts are present
+        assert agent_custom_prompt in captured_agent.system_prompt
+        assert task_system_prompt in captured_agent.system_prompt
@@ -22,7 +22,7 @@ class TestColors:
         assert Colors.YELLOW == "\033[93m"
         assert Colors.GOLD == "\033[33m"
         assert Colors.RED == "\033[91m"
-        assert Colors.GRAY == "\033[90m"
+        assert Colors.GRAY == "\033[37m"
         assert Colors.ENDC == "\033[0m"
         assert Colors.BOLD == "\033[1m"
 
hud/datasets/parallel.py CHANGED
@@ -261,7 +261,6 @@ async def run_dataset_parallel_manual(
     max_steps: int = 10,
     split: str = "train",
     auto_respond: bool = False,
-    custom_system_prompt: str | None = None,
 ) -> list[Any]:
     """
     Run all tasks in a dataset using process-based parallelism with manual configuration.
@@ -282,7 +281,6 @@ async def run_dataset_parallel_manual(
         max_steps: Maximum steps per task
         split: Dataset split when loading from string
         auto_respond: Whether to use ResponseAgent
-        custom_system_prompt: Override system prompt for all tasks
 
     Returns:
         List of results in the same order as the input dataset
@@ -349,14 +347,6 @@ async def run_dataset_parallel_manual(
     else:
         raise ValueError(f"Dataset must be string, Dataset, or list, got {type(dataset)}")
 
-    # Apply custom system prompt if provided
-    if custom_system_prompt:
-        for task_dict in task_dicts:
-            if "system_prompt" not in task_dict:
-                task_dict["system_prompt"] = custom_system_prompt
-            else:
-                task_dict["system_prompt"] += "\n" + custom_system_prompt
-
     # Prepare job metadata
     job_metadata = metadata or {}
     job_metadata.update(
@@ -380,8 +370,6 @@ async def run_dataset_parallel_manual(
     except Exception:
        logger.warning("Failed to extract dataset verification info")
 
-    # task_dicts = task_dicts[:10]
-
     # Create job context
     with hud.job(name, metadata=job_metadata, dataset_link=dataset_link) as job_obj:
         # Prepare agent class info for pickling
hud/datasets/runner.py CHANGED
@@ -27,7 +27,6 @@ async def run_dataset(
     max_steps: int = 10,
     split: str = "train",
     auto_respond: bool = False,
-    custom_system_prompt: str | None = None,
 ) -> list[Any]:
     """
     Run all tasks in a dataset with automatic job tracking.
@@ -43,7 +42,6 @@ async def run_dataset(
         max_steps: Maximum steps per task
         split: Dataset split to use when loading from string (default: "train")
         auto_respond: Whether to use auto-response agent
-        custom_system_prompt: Override system prompt for all tasks
 
     Returns:
         List of results from agent.run() in dataset order
@@ -102,8 +100,7 @@ async def run_dataset(
         async with sem:
             # Create trace for this task
             task_name = task_dict.get("prompt") or f"Task {index}"
-            if custom_system_prompt and "system_prompt" not in task_dict:
-                task_dict["system_prompt"] = custom_system_prompt
+
             # Ensure task_id is a string for baggage propagation
             raw_task_id = task_dict.get("id")
             safe_task_id = str(raw_task_id) if raw_task_id is not None else None
hud/rl/actor.py CHANGED
@@ -37,7 +37,7 @@ class Actor:
         # Match connection limits to parallel_episodes to avoid bottlenecks
         # Use shorter per-request timeout and keep retries modest to avoid long blocking
         http_client = create_retry_httpx_client(
-            timeout=httpx.Timeout(30.0),
+            timeout=httpx.Timeout(60.0),
         )
         return AsyncOpenAI(
             base_url=base_url,
@@ -151,7 +151,9 @@ if __name__ == "__main__":
             "name": "evaluate",
             "arguments": {"name": "game_2048_max_number", "arguments": {"target": 128}},
         },
-        "system_prompt": "You are an expert 2048 game player. Use arrow keys to reach the target tile. First take a screenshot, then make strategic moves.",  # noqa: E501
+        "agent_config": {
+            "system_prompt": "You are an expert 2048 game player. Use arrow keys to reach the target tile. First take a screenshot, then make strategic moves.",  # noqa: E501
+        },
     }
 
     task = Task(**task_data)
hud/rl/distributed.py CHANGED
@@ -81,7 +81,7 @@ def broadcast_object(obj: Any, src: int = 0) -> Any:
         return obj
 
     obj_list = [obj] if dist.get_rank() == src else [None]
-    dist.broadcast_object_list(obj_list, src=src, device=torch.device("cpu"))
+    dist.broadcast_object_list(obj_list, src=src)
     return obj_list[0]
 
 
hud/rl/learner.py CHANGED
@@ -148,11 +148,12 @@ class GRPOLearner:
 
         # Add LoRA adapters or load existing adapter
         policy.config.use_cache = False
-
+
         if model_cfg.adapter_path:
             # Load existing adapter as baseline
             self.log(f"Loading existing LoRA adapter from: {model_cfg.adapter_path}")
             from peft import PeftModel
+
             policy = PeftModel.from_pretrained(policy, model_cfg.adapter_path)
             # Enable adapter training
             policy.train()
hud/rl/train.py CHANGED
@@ -95,7 +95,7 @@ async def train(config: Config, tasks: list[Task]) -> None:
         if is_main_process()
         else None
     )
-
+
     # Load initial adapter if provided
     if is_main_process() and config.model.adapter_path and vllm:
         hud_console.info(f"Loading baseline adapter from: {config.model.adapter_path}")
hud/telemetry/trace.py CHANGED
@@ -139,7 +139,7 @@ def trace(
     else:
         # Use a placeholder for custom backends
         logger.warning(
-            "HUD API key is not set, using a placeholder for the task run ID. If this looks wrong, check your API key." # noqa: E501
+            "HUD API key is not set, using a placeholder for the task run ID. If this looks wrong, check your API key."  # noqa: E501
         )
         task_run_id = str(uuid.uuid4())