hud-python 0.4.47__py3-none-any.whl → 0.4.48__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of hud-python has been flagged as potentially problematic.
- hud/agents/base.py +49 -142
- hud/agents/claude.py +5 -6
- hud/agents/misc/integration_test_agent.py +2 -0
- hud/agents/tests/test_base.py +2 -5
- hud/cli/__init__.py +2 -2
- hud/cli/eval.py +14 -9
- hud/cli/flows/tasks.py +2 -4
- hud/cli/rl/local_runner.py +25 -13
- hud/cli/rl/vllm.py +2 -0
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_eval.py +525 -0
- hud/cli/tests/test_utils.py +1 -1
- hud/datasets/parallel.py +0 -12
- hud/datasets/runner.py +1 -4
- hud/rl/actor.py +4 -2
- hud/rl/distributed.py +1 -1
- hud/rl/learner.py +2 -1
- hud/rl/train.py +1 -1
- hud/telemetry/trace.py +1 -1
- hud/tools/base.py +11 -9
- hud/tools/computer/__init__.py +2 -0
- hud/tools/computer/qwen.py +431 -0
- hud/tools/computer/settings.py +16 -0
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/playwright.py +1 -1
- hud/types.py +2 -3
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.47.dist-info → hud_python-0.4.48.dist-info}/METADATA +1 -1
- {hud_python-0.4.47.dist-info → hud_python-0.4.48.dist-info}/RECORD +33 -31
- {hud_python-0.4.47.dist-info → hud_python-0.4.48.dist-info}/WHEEL +0 -0
- {hud_python-0.4.47.dist-info → hud_python-0.4.48.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.47.dist-info → hud_python-0.4.48.dist-info}/licenses/LICENSE +0 -0
hud/agents/base.py
CHANGED
@@ -3,10 +3,11 @@
 from __future__ import annotations
 
 import asyncio
+import fnmatch
 import json
 import logging
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, ClassVar, Literal
+from typing import TYPE_CHECKING, Any, ClassVar, List, Literal
 
 import mcp.types as types
 
@@ -96,12 +97,9 @@ class MCPAgent(ABC):
             self.console.set_verbose(True)
 
         # User filtering
-        self.allowed_tools = allowed_tools
-        self.disallowed_tools = disallowed_tools
-
-        # Task filtering
-        self.agent_tools = None
-        self.lifecycle_tools = []
+        self.allowed_tools: List[str] | None = allowed_tools
+        self.disallowed_tools: List[str] | None = disallowed_tools
+        self._available_tools: List[types.Tool] | None = None
 
         # Messages
         self.system_prompt = system_prompt
@@ -109,7 +107,6 @@ class MCPAgent(ABC):
         self.initial_screenshot = initial_screenshot
 
         # Initialize these here so methods can be called before initialize()
-        self._available_tools: list[types.Tool] = []
         self._tool_map: dict[str, types.Tool] = {}  # Simplified: just name to tool
         self.response_tool_name = None
 
@@ -146,37 +143,48 @@ class MCPAgent(ABC):
         except Exception as e:
             self._handle_connection_error(e)
 
-        # If task is provided, add lifecycle tools
+        # If task is provided, apply agent_config and add lifecycle tools
         if isinstance(task, Task):
-            if …
-            …
-            )
-            # …
+            # Apply agent_config if present
+            if task.agent_config:
+                if "system_prompt" in task.agent_config and task.agent_config["system_prompt"]:
+                    self.system_prompt += "\n\n" + task.agent_config["system_prompt"]
+                if "append_setup_output" in task.agent_config:
+                    self.append_setup_output = task.agent_config["append_setup_output"]
+                if "initial_screenshot" in task.agent_config:
+                    self.initial_screenshot = task.agent_config["initial_screenshot"]
+                if "allowed_tools" in task.agent_config:
+                    # If allowed_tools has already been set, we take the intersection of the two
+                    # If the list had been empty, we were allowing all tools, so we overwrite in this case
+                    if isinstance(self.allowed_tools, list) and len(self.allowed_tools) > 0:
+                        self.allowed_tools = [tool for tool in self.allowed_tools if tool in task.agent_config["allowed_tools"]]
+                    else:  # If allowed_tools is None, we overwrite it
+                        self.allowed_tools = task.agent_config["allowed_tools"]
+                if "disallowed_tools" in task.agent_config:
+                    # If disallowed_tools has already been set, we take the union of the two
+                    if isinstance(self.disallowed_tools, list):
+                        self.disallowed_tools.extend(task.agent_config["disallowed_tools"])
+                    else:  # If disallowed_tools is None, we overwrite it
+                        self.disallowed_tools = task.agent_config["disallowed_tools"]
+
+        all_tools = await self.mcp_client.list_tools()
+        self._available_tools = []
+
+        # Filter tools based on allowed and disallowed patterns
+        # No allowed tools and no disallowed tools -> we accept all tools
+        # No allowed tools and disallowed tools -> we accept all tools except the disallowed ones
+        for tool in all_tools:
+            if self.allowed_tools is not None:
+                if not any(fnmatch.fnmatch(tool.name, pattern) for pattern in self.allowed_tools):
+                    continue
+            if self.disallowed_tools is not None:
+                if any(fnmatch.fnmatch(tool.name, pattern) for pattern in self.disallowed_tools):
+                    continue
+            self._available_tools.append(tool)
+
+        self.console.info(
+            f"Agent initialized with {len(self.get_available_tools())} tools: {', '.join([t.name for t in self.get_available_tools()])}"  # noqa: E501
+        )
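A note on the merge rules in the initialize() hunk above: a task's agent_config narrows allowed_tools (intersection) but widens disallowed_tools (union). Below is a minimal standalone sketch of those semantics; the helper name merge_tool_filters is illustrative, not part of hud-python.

# Illustrative sketch of the merge rules above; helper name is hypothetical.
def merge_tool_filters(allowed, disallowed, config):
    if "allowed_tools" in config:
        if isinstance(allowed, list) and len(allowed) > 0:
            # Both sides restrict: keep only tools both permit (intersection)
            allowed = [t for t in allowed if t in config["allowed_tools"]]
        else:
            # None (or an empty "allow all" list): take the task's list
            allowed = config["allowed_tools"]
    if "disallowed_tools" in config:
        # Denials accumulate (union)
        disallowed = (disallowed or []) + config["disallowed_tools"]
    return allowed, disallowed

print(merge_tool_filters(["browser_click", "browser_type"], None,
                         {"allowed_tools": ["browser_click"], "disallowed_tools": ["shell"]}))
# -> (['browser_click'], ['shell'])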
@@ -575,108 +583,6 @@ class MCPAgent(ABC):
 
         return await self.format_blocks(blocks)
 
-    async def _filter_tools(self) -> None:
-        """Apply tool filtering based on allowed/disallowed lists."""
-        # Get all tools from client
-        if self.mcp_client is None:
-            raise ValueError("MCP client is not initialized")
-
-        all_tools = await self.mcp_client.list_tools()
-
-        response_tools_by_server: dict[str, str] = {}  # server_name -> tool_name
-        for tool in all_tools:
-            if "response" in tool.name or tool.name == "response":
-                self.console.debug(f"Found response tool: '{tool.name}'")
-                # Extract server name from tool name (e.g., "grader_response" -> "grader")
-                if "_" in tool.name:
-                    server_name = tool.name.split("_", 1)[0]
-                    response_tools_by_server[server_name] = tool.name
-                else:
-                    response_tools_by_server["_default"] = tool.name
-
-        # Add response tool to lifecycle tools BEFORE filtering
-        if response_tools_by_server and hasattr(self.mcp_client, "mcp_config"):
-            # Get server names in order from mcp_config
-            server_names = list(self.mcp_client.mcp_config.keys())
-            self.console.debug(f"Server names: {server_names}")
-
-            # Try to find response tool from last server first
-            response_tool_name = None
-            for server_name in reversed(server_names):
-                if server_name in response_tools_by_server:
-                    response_tool_name = response_tools_by_server[server_name]
-                    self.console.debug(
-                        f"Found response tool '{response_tool_name}' from server '{server_name}'"
-                    )
-                    break
-
-            # Fallback to any response tool
-            if not response_tool_name and response_tools_by_server:
-                response_tool_name = next(iter(response_tools_by_server.values()))
-                self.console.debug(f"Using fallback response tool '{response_tool_name}'")
-
-            # Add to lifecycle tools if found
-            if response_tool_name and response_tool_name not in self.lifecycle_tools:
-                self.console.debug(f"Auto-detected '{response_tool_name}' tool as a lifecycle tool")
-                self.response_tool_name = response_tool_name
-                self.lifecycle_tools.append(response_tool_name)
-            elif response_tool_name:
-                self.console.debug(
-                    f"Response tool '{response_tool_name}' already in lifecycle_tools"
-                )
-                self.response_tool_name = response_tool_name
-        else:
-            self.console.debug("No response tools found or no mcp_config")
-
-        # Filter tools
-        self._available_tools = []
-        self._tool_map = {}
-
-        self.console.debug(f"All tools: {[t.name for t in all_tools]}")
-        self.console.debug(f"Allowed tools: {self.allowed_tools}")
-        self.console.debug(f"Agent tools: {self.agent_tools}")
-        self.console.debug(f"Disallowed tools: {self.disallowed_tools}")
-        self.console.debug(f"Lifecycle tools: {self.lifecycle_tools}")
-
-        for tool in all_tools:
-            # Lifecycle tools (setup, evaluate, response) should always be included
-            is_lifecycle = tool.name in self.lifecycle_tools
-
-            # Check if tool should be included
-            if not is_lifecycle:
-                if self.allowed_tools and tool.name not in self.allowed_tools:
-                    self.console.debug(f"Skipping tool '{tool.name}' - not in allowed_tools")
-                    continue
-                if self.agent_tools and tool.name not in self.agent_tools:
-                    self.console.debug(f"Skipping tool '{tool.name}' - not in agent_tools")
-                    continue
-                if tool.name in self.disallowed_tools:
-                    self.console.debug(f"Skipping tool '{tool.name}' - in disallowed_tools")
-                    continue
-
-            self.console.debug(
-                f"Adding tool '{tool.name}' to available tools (lifecycle={is_lifecycle})"
-            )
-            self._available_tools.append(tool)
-            self._tool_map[tool.name] = tool
-
-        # Check if all required tools are available
-        if self.required_tools:
-            available_tool_names = {tool.name for tool in self._available_tools}
-            missing_tools = [
-                tool for tool in self.required_tools if tool not in available_tool_names
-            ]
-            if missing_tools:
-                raise ValueError(
-                    f"Required tools not available: {missing_tools}. "
-                    f"Available tools: {list(available_tool_names)}"
-                )
-
-        available_tools = self.get_available_tools()
-        self.console.info(
-            f"Agent initialized with {len(available_tools)} tools: {', '.join([t.name for t in available_tools])}"  # noqa: E501
-        )
-
     async def _maybe_submit_response(self, response: AgentResponse, messages: list[Any]) -> None:
         """Submit response through lifecycle tool if available.
 
@@ -715,8 +621,9 @@ class MCPAgent(ABC):
 
     def get_available_tools(self) -> list[types.Tool]:
         """Get list of available MCP tools for LLM use (excludes lifecycle tools)."""
-        …
-        …
+        if self._available_tools is None:
+            raise RuntimeError("Tools have not been initialized. Call initialize() before accessing available tools.")
+        return self._available_tools
 
     def get_tool_schemas(self) -> list[dict]:
         """Get tool schemas in a format suitable for the model."""
hud/agents/claude.py
CHANGED
@@ -326,7 +326,7 @@ class ClaudeAgent(MCPAgent):
         selected_computer_tool = None
 
         for priority_name in computer_tool_priority:
-            for tool in self.…
+            for tool in self.get_available_tools():
                 # Check both exact match and suffix match (for prefixed tools)
                 if tool.name == priority_name or tool.name.endswith(f"_{priority_name}"):
                     selected_computer_tool = tool
@@ -350,13 +350,12 @@ class ClaudeAgent(MCPAgent):
         )
 
         # Add other non-computer tools
-        for tool in self.…
-            # Skip computer tools (already handled)
-            …
+        for tool in self.get_available_tools():
+            # Skip computer tools (already handled)
+            if any(
                 tool.name == priority_name or tool.name.endswith(f"_{priority_name}")
                 for priority_name in computer_tool_priority
-            )
-            if is_computer_tool or tool.name in self.lifecycle_tools:
+            ):
                 continue
 
             claude_tool = {
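The selection loops above match a priority name either exactly or as a server-prefixed suffix. A one-function sketch of that predicate; the names are illustrative.

def matches_priority(tool_name, priority_name):
    # Exact name, or a server-prefixed variant such as "browser_computer"
    return tool_name == priority_name or tool_name.endswith(f"_{priority_name}")

assert matches_priority("computer", "computer")
assert matches_priority("browser_computer", "computer")
assert not matches_priority("computer_use", "computer")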
hud/agents/misc/integration_test_agent.py
CHANGED
@@ -17,6 +17,8 @@ class IntegrationTestRunner(MCPAgent):
         # Initialize using base to set up client and telemetry correctly
         await self.initialize(task)
 
+        self.console.info(f"Full system prompt: {self.system_prompt}")
+
         # Validate task shape
         if not getattr(task, "integration_test_tool", None):
             raise ValueError(
hud/agents/tests/test_base.py
CHANGED
@@ -326,9 +326,6 @@ class TestBaseMCPAgent:
         """Test getting tool schemas."""
         agent = MockMCPAgent()
 
-        # Add setup to lifecycle tools to test filtering
-        agent.lifecycle_tools = ["setup"]
-
         agent._available_tools = [
             types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}),
             types.Tool(name="setup", description="Setup", inputSchema={"type": "object"}),
@@ -598,7 +595,7 @@ class TestMCPAgentExtended:
         agent = MockAgentExtended(mcp_client=mock_client, allowed_tools=["tool1", "tool3"])
         await agent.initialize("test")
 
-        available_names = [tool.name for tool in agent.…
+        available_names = [tool.name for tool in agent.get_available_tools()]
         assert "tool1" in available_names
         assert "tool3" in available_names
        assert "tool2" not in available_names
@@ -617,7 +614,7 @@ class TestMCPAgentExtended:
         agent = MockAgentExtended(mcp_client=mock_client, disallowed_tools=["tool2"])
         await agent.initialize("test")
 
-        available_names = [tool.name for tool in agent.…
+        available_names = [tool.name for tool in agent.get_available_tools()]
         assert "tool1" in available_names
         assert "tool3" in available_names
         assert "tool2" not in available_names
hud/cli/__init__.py
CHANGED
@@ -935,8 +935,8 @@ def eval(
         "--max-concurrent",
         help="Max concurrent tasks (prevents rate limits in both asyncio and parallel modes)",
     ),
-    max_steps: int = typer.Option(
-        …
+    max_steps: int | None = typer.Option(
+        None,
         "--max-steps",
         help="Maximum steps per task (default: 10 for single, 50 for full)",
     ),
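Making --max-steps Optional lets the default depend on context, as the help text describes (10 for single runs, 50 for --full). A minimal sketch of the pattern, assuming a recent typer that accepts X | None annotations; the resolution logic shown is illustrative, not copied from hud.

import typer

app = typer.Typer()

@app.command()
def evaluate(
    full: bool = typer.Option(False, "--full"),
    max_steps: int | None = typer.Option(None, "--max-steps"),
) -> None:
    # None means "not set on the CLI"; resolve per mode downstream
    resolved = max_steps if max_steps is not None else (50 if full else 10)
    typer.echo(f"max_steps={resolved}")

if __name__ == "__main__":
    app()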
hud/cli/eval.py
CHANGED
@@ -199,6 +199,8 @@ async def run_single_task(
 ) -> None:
     """Load one task and execute it, or detect if JSON contains a list and run as dataset."""
 
+    # Provide early feedback to user
+    hud_console.info("🔧 Initializing evaluation...")
     # Import Task and run_dataset lazily
     try:
         from hud.utils.tasks import load_tasks
@@ -318,7 +320,10 @@ async def run_single_task(
         )
         display_group_statistics(stats, show_details=True)
     else:
-        # …
+        # Enable agent step logging for single task mode
+        logging.getLogger("hud.agents").setLevel(logging.INFO)
+        logging.getLogger("hud.agents.base").setLevel(logging.INFO)
+
         with hud.trace(name=task_prompt):
             agent = build_agent(
                 agent_type,
@@ -352,6 +357,9 @@ async def run_full_dataset(
     Uses either asyncio-based run_dataset or process-based parallel execution
     depending on the parallel flag."""
 
+    # Provide early feedback to user
+    hud_console.info("🔧 Initializing evaluation...")
+
     # Import run_dataset lazily
     try:
         from hud.datasets import run_dataset, run_dataset_parallel, run_dataset_parallel_manual
@@ -367,7 +375,7 @@ async def run_full_dataset(
     hud_console.info(f"📊 Loading tasks from: {source}…")
     tasks: list[Task] = load_tasks(source)  # type: ignore[assignment]
 
-    if …
+    if len(tasks) == 0:
         hud_console.error(f"No tasks found in: {source}")
         raise typer.Exit(1)
 
@@ -646,10 +654,10 @@ def eval_command(
         hud eval hud-evals/SheetBench-50 --full --agent claude
 
         # Run large dataset with PARALLEL execution (auto-optimized)
-        hud eval hud-evals/OSWorld-Verified-…
+        hud eval hud-evals/OSWorld-Verified-Gold --full --parallel
 
         # Parallel mode with manual configuration (16 workers, 25 tasks each)
-        hud eval hud-evals/OSWorld-Verified-…
+        hud eval hud-evals/OSWorld-Verified-Gold --full --parallel --max-workers 16
 
         # Limit total concurrent tasks to prevent rate limits
         hud eval hud-evals/SheetBench-50 --full --parallel --max-concurrent 20
@@ -674,6 +682,8 @@ def eval_command(
     """
     from hud.settings import settings
 
+    # Always configure basic logging so agent steps can be logged
+    # Set to INFO by default for consistency with run_evaluation.py
     if very_verbose:
         logging.basicConfig(
             level=logging.DEBUG,
@@ -683,11 +693,6 @@ def eval_command(
         logging.getLogger("hud.agents").setLevel(logging.DEBUG)
         logging.getLogger("hud.agents.base").setLevel(logging.DEBUG)
     elif verbose:
-        logging.basicConfig(
-            level=logging.INFO,
-            format="%(asctime)s - %(name)s - %(message)s",
-            datefmt="%H:%M:%S",
-        )
         logging.getLogger("hud.agents").setLevel(logging.INFO)
         logging.getLogger("hud.agents.base").setLevel(logging.INFO)
 
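The removal in the verbose branch leans on a logging quirk: logging.basicConfig is a no-op once the root logger already has handlers, so repeating it per branch added nothing, and per-logger setLevel is the reliable knob. A short demonstration of that pitfall:

import logging

logging.basicConfig(level=logging.DEBUG, format="%(name)s - %(message)s")
logging.basicConfig(level=logging.ERROR)  # silently ignored: root already has handlers
logging.getLogger("hud.agents").setLevel(logging.INFO)

logging.getLogger("hud.agents").info("visible at INFO")
logging.getLogger("hud.agents").debug("suppressed below INFO")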
hud/cli/flows/tasks.py
CHANGED
@@ -364,10 +364,8 @@ def convert_tasks_to_remote(tasks_file: str) -> str:
             item["setup_tool"] = _simplify_tool_call(t.setup_tool)
         if t.evaluate_tool is not None:
             item["evaluate_tool"] = _simplify_tool_call(t.evaluate_tool)
-        if t.…
-            item["…
-        if t.system_prompt is not None:
-            item["system_prompt"] = t.system_prompt
+        if t.agent_config is not None:
+            item["agent_config"] = t.agent_config
         if t.metadata:
             item["metadata"] = t.metadata
         if t.id is not None:
hud/cli/rl/local_runner.py
CHANGED
@@ -230,19 +230,33 @@ def run_local_training(
         console.print("Enter the model name (HuggingFace ID):")
         model = input().strip()
 
-    # …
-    if …
+    # try to get model from config file
+    if config_file:
+        console.print(f"\n[cyan]Loading configuration from: {config_file}[/cyan]")
+        config = load_config(config_file)
+        if hasattr(config, "model") and hasattr(config.model, "base_model"):
+            if model is None:
+                model = config.model.base_model
+            else:
+                console.print(
+                    f"[yellow]Model already set to {model}, using that instead "
+                    f"of {config.model.base_model}[/yellow] (override)"
+                )
+
+    if model is None:
+        console.print("[red]❌ No model specified either through CLI or config file[/red]")
         try:
-            …
-        except ValueError as e:
-            console.print(f"\n[red]❌ {e}[/red]")
-            try:
-                import typer
+            import typer
 
-            …
-            …
-            …
-            …
+            raise typer.Exit(1)
+        except Exception:
+            return
+
+    # Validate model is a VL model (whether provided via CLI or selected)
+    try:
+        validate_vl_model(model)
+    except ValueError as e:
+        console.print(f"\n[red]❌ {e}[/red]")
         try:
             import typer
 
@@ -488,7 +502,6 @@ def run_local_training(
         from .vllm import start_vllm_server, wait_for_vllm_server
 
         start_vllm_server(config.model.base_model, vllm_gpu_idx, restart=restart)
-
         server_ready = asyncio.run(wait_for_vllm_server())
         if not server_ready:
             console.print("[red]❌ Failed to start vLLM server[/red]")
@@ -507,7 +520,6 @@ def run_local_training(
             f"\n[bold green]🎯 Starting DDP training on {len(training_gpus)} GPUs...[/bold green]\n"
         )
         launch_ddp_training(training_gpus, tasks_file, temp_config_path, verbose)
-        console.print("\n[green]✅ Training completed successfully![/green]")
     else:
         console.print("\n[bold green]🎯 Starting single-GPU training...[/bold green]\n")
         try:
hud/cli/rl/vllm.py
CHANGED
@@ -165,6 +165,8 @@ async def wait_for_vllm_server(timeout: int = 360) -> bool:  # noqa: ASYNC109
             if response.status_code == 200:
                 console.print("[green]✅ vLLM server is ready![/green]")
                 return True
+        except httpx.ConnectError:
+            pass
         except Exception as e:
             hud_console.error(f"Failed to connect to vLLM server: {e}")
 
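The new except clause separates "server not listening yet" (httpx.ConnectError, expected while the server boots) from genuine failures, so polling no longer logs an error on every attempt. A minimal polling loop under the same assumption; the URL is illustrative.

import asyncio
import httpx

async def wait_for_server(url: str, timeout: float = 360.0) -> bool:
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    async with httpx.AsyncClient() as client:
        while loop.time() < deadline:
            try:
                response = await client.get(url)
                if response.status_code == 200:
                    return True
            except httpx.ConnectError:
                pass  # not up yet; keep polling
            except Exception as exc:
                print(f"Failed to connect: {exc}")
                return False
            await asyncio.sleep(2)
    return False

# asyncio.run(wait_for_server("http://localhost:8000/health"))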
hud/cli/tests/test_analyze_metadata.py
CHANGED
@@ -214,6 +214,7 @@ class TestAnalyzeFromMetadata:
 
     @mock.patch("hud.cli.utils.metadata.check_local_cache")
     @mock.patch("hud.cli.utils.metadata.fetch_lock_from_registry")
+    @mock.patch("hud.cli.utils.metadata.hud_console")
     @mock.patch("hud.cli.utils.metadata.console")
     async def test_analyze_not_found(self, mock_console, mock_hud_console, mock_fetch, mock_check):
         """Test when environment not found anywhere."""
@@ -222,9 +223,9 @@ class TestAnalyzeFromMetadata:
 
         await analyze_from_metadata("test/notfound:latest", "json", verbose=False)
 
-        # Should show error
+        # Should show error via hud_console
         mock_hud_console.error.assert_called_with("Environment metadata not found")
-        # Should print suggestions
+        # Should print suggestions via console
         mock_console.print.assert_called()
 
     @mock.patch("hud.cli.utils.metadata.check_local_cache")
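One subtlety in the test above: stacked mock.patch decorators apply bottom-up, so the patch nearest the function supplies the first mock argument. Inserting the hud_console patch between fetch_lock_from_registry and console therefore keeps the existing (mock_console, mock_hud_console, mock_fetch, mock_check) signature valid. A tiny illustration with stand-in targets:

from unittest import mock

@mock.patch("os.getcwd")   # outermost decorator -> last mock argument
@mock.patch("os.getpid")   # innermost decorator -> first mock argument
def check(mock_getpid, mock_getcwd):
    # Decorators apply bottom-up: the one nearest the function injects first.
    assert mock_getpid is not mock_getcwd

check()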
|