hud-python 0.2.10__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +14 -5
- hud/env/docker_client.py +1 -1
- hud/env/environment.py +10 -7
- hud/env/local_docker_client.py +1 -1
- hud/env/remote_client.py +1 -1
- hud/env/remote_docker_client.py +2 -2
- hud/exceptions.py +2 -1
- hud/mcp_agent/__init__.py +15 -0
- hud/mcp_agent/base.py +723 -0
- hud/mcp_agent/claude.py +316 -0
- hud/mcp_agent/langchain.py +231 -0
- hud/mcp_agent/openai.py +318 -0
- hud/mcp_agent/tests/__init__.py +1 -0
- hud/mcp_agent/tests/test_base.py +437 -0
- hud/settings.py +14 -2
- hud/task.py +4 -0
- hud/telemetry/__init__.py +11 -7
- hud/telemetry/_trace.py +82 -71
- hud/telemetry/context.py +9 -27
- hud/telemetry/exporter.py +6 -5
- hud/telemetry/instrumentation/mcp.py +174 -410
- hud/telemetry/mcp_models.py +13 -74
- hud/telemetry/tests/test_context.py +9 -6
- hud/telemetry/tests/test_trace.py +92 -61
- hud/tools/__init__.py +21 -0
- hud/tools/base.py +65 -0
- hud/tools/bash.py +137 -0
- hud/tools/computer/__init__.py +13 -0
- hud/tools/computer/anthropic.py +411 -0
- hud/tools/computer/hud.py +315 -0
- hud/tools/computer/openai.py +283 -0
- hud/tools/edit.py +290 -0
- hud/tools/executors/__init__.py +13 -0
- hud/tools/executors/base.py +331 -0
- hud/tools/executors/pyautogui.py +585 -0
- hud/tools/executors/tests/__init__.py +1 -0
- hud/tools/executors/tests/test_base_executor.py +338 -0
- hud/tools/executors/tests/test_pyautogui_executor.py +162 -0
- hud/tools/executors/xdo.py +503 -0
- hud/tools/helper/README.md +56 -0
- hud/tools/helper/__init__.py +9 -0
- hud/tools/helper/mcp_server.py +78 -0
- hud/tools/helper/server_initialization.py +115 -0
- hud/tools/helper/utils.py +58 -0
- hud/tools/playwright_tool.py +373 -0
- hud/tools/tests/__init__.py +3 -0
- hud/tools/tests/test_bash.py +152 -0
- hud/tools/tests/test_computer.py +52 -0
- hud/tools/tests/test_computer_actions.py +34 -0
- hud/tools/tests/test_edit.py +233 -0
- hud/tools/tests/test_init.py +27 -0
- hud/tools/tests/test_playwright_tool.py +183 -0
- hud/tools/tests/test_tools.py +154 -0
- hud/tools/tests/test_utils.py +156 -0
- hud/tools/utils.py +50 -0
- hud/types.py +10 -1
- hud/utils/tests/test_init.py +21 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.2.10.dist-info → hud_python-0.3.0.dist-info}/METADATA +9 -6
- hud_python-0.3.0.dist-info/RECORD +124 -0
- hud_python-0.2.10.dist-info/RECORD +0 -85
- {hud_python-0.2.10.dist-info → hud_python-0.3.0.dist-info}/WHEEL +0 -0
- {hud_python-0.2.10.dist-info → hud_python-0.3.0.dist-info}/licenses/LICENSE +0 -0
hud/mcp_agent/base.py
ADDED
@@ -0,0 +1,723 @@

```python
"""Base MCP Agent implementation."""

from __future__ import annotations

import asyncio
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import mcp.types as types
from mcp_use import MCPClient

if TYPE_CHECKING:
    from hud.task import Task

logger = logging.getLogger(__name__)


class BaseMCPAgent(ABC):
    """
    Base class for MCP-enabled agents.

    This class provides the foundation for agents that interact with MCP servers,
    handling tool discovery and filtering while leaving provider-specific
    implementation details to subclasses.
    """

    def __init__(
        self,
        client: MCPClient | None = None,
        allowed_tools: list[str] | None = None,
        disallowed_tools: list[str] | None = None,
        initial_screenshot: bool = False,
        max_screenshot_history: int = 3,
        append_tool_system_prompt: bool = True,
        custom_system_prompt: str | None = None,
        lifecycle_tools: dict[str, str] | None = None,
    ) -> None:
        """
        Initialize the base MCP agent.

        Args:
            client: MCPClient instance for server connections
            allowed_tools: List of tool names to allow (None = all tools)
            disallowed_tools: List of tool names to disallow
            initial_screenshot: Whether to capture screenshot before first prompt
            max_screenshot_history: Maximum number of screenshots to keep in context
            append_tool_system_prompt: Whether to append available tools to system prompt
            custom_system_prompt: Custom system prompt to use
            lifecycle_tools: Dict mapping lifecycle phases to tool names. Default:
                {
                    "setup": "setup",  # Setup phase tool
                    "evaluate": "evaluate"  # Evaluation phase tool
                }
        """
        self.client = client
        self.allowed_tools = allowed_tools
        self.disallowed_tools = disallowed_tools or []
        self.initial_screenshot = initial_screenshot
        self.max_screenshot_history = max_screenshot_history
        self.append_tool_system_prompt = append_tool_system_prompt
        self.custom_system_prompt = custom_system_prompt

        # Default lifecycle tool mapping
        default_lifecycle = {"setup": "setup", "evaluate": "evaluate"}
        self.lifecycle_tools = {**default_lifecycle, **(lifecycle_tools or {})}

        self._available_tools: list[types.Tool] = []
        self._tool_map: dict[str, tuple[str, types.Tool]] = {}
        self._sessions: dict[str, Any] = {}

        if client is None:
            self.client = MCPClient()

    async def initialize(self) -> None:
        """Initialize the agent and discover available tools."""
        # Get existing sessions or create new ones
        if self.client is None:
            raise ValueError("Client is not initialized")

        sessions = self.client.get_all_active_sessions()

        if not sessions:
            logger.info("No active sessions found, creating new ones...")
            sessions = await self.client.create_all_sessions()

        self._sessions = sessions

        # Discover tools from all servers
        self._available_tools = []
        self._tool_map = {}

        for server_name, session in sessions.items():
            try:
                # Ensure session is initialized
                if not hasattr(session, "connector") or not hasattr(
                    session.connector, "client_session"
                ):
                    await session.initialize()

                if session.connector.client_session is None:
                    raise ValueError("Client session is not initialized")

                tools_result = await session.connector.client_session.list_tools()

                # Log all tools before filtering
                logger.info(
                    "Tools from '%s' (pre-filter): %s",
                    server_name,
                    [tool.name for tool in tools_result.tools],
                )

                for tool in tools_result.tools:
                    # Always include lifecycle tools for framework use
                    is_lifecycle_tool = tool.name in self.lifecycle_tools.values()

                    # Apply filtering (but always allow lifecycle tools)
                    if not is_lifecycle_tool:
                        if self.allowed_tools and tool.name not in self.allowed_tools:
                            continue
                        if tool.name in self.disallowed_tools:
                            continue

                    self._available_tools.append(tool)
                    # Store tool with server reference for execution
                    self._tool_map[tool.name] = (server_name, tool)

            except Exception as e:
                logger.error("Failed to list tools from server %s: %s", server_name, e)

        # Separate lifecycle tools from regular tools for clearer logging
        lifecycle_tool_names = list(self.lifecycle_tools.values())
        regular_tools = [
            t.name for t in self._available_tools if t.name not in lifecycle_tool_names
        ]
        lifecycle_tools_found = [
            t.name for t in self._available_tools if t.name in lifecycle_tool_names
        ]

        logger.info(
            "Agent initialized with %s tools (%s regular, %s lifecycle)",
            len(self._available_tools),
            len(regular_tools),
            len(lifecycle_tools_found),
        )
        if regular_tools:
            logger.info("Regular tools: %s", regular_tools)
        if lifecycle_tools_found:
            logger.info("Lifecycle tools: %s", lifecycle_tools_found)

    def get_available_tools(self) -> list[types.Tool]:
        """Get list of available MCP tools for LLM use (excludes lifecycle tools)."""
        lifecycle_tool_names = list(self.lifecycle_tools.values())
        return [tool for tool in self._available_tools if tool.name not in lifecycle_tool_names]

    def get_tool_map(self) -> dict[str, tuple[str, types.Tool]]:
        """Get mapping of tool names to (server_name, tool) tuples."""
        return self._tool_map

    def get_sessions(self) -> dict[str, Any]:
        """Get active MCP sessions."""
        return self._sessions

    def get_tools_by_server(self) -> dict[str, list[types.Tool]]:
        """Get tools grouped by server name."""
        tools_by_server = {}
        for server_name, tool in self._tool_map.values():
            if server_name not in tools_by_server:
                tools_by_server[server_name] = []
            tools_by_server[server_name].append(tool)
        return tools_by_server

    def get_tools_by_connector(self) -> dict[Any, list[types.Tool]]:
        """Get tools grouped by connector instance."""
        tools_by_connector = {}
        for server_name, tool in self._tool_map.values():
            session = self._sessions[server_name]
            connector = session.connector

            if connector not in tools_by_connector:
                tools_by_connector[connector] = []
            tools_by_connector[connector].append(tool)
        return tools_by_connector

    def get_system_prompt(self) -> str:
        """Generate system prompt with optional tool information."""
        base_prompt = self.custom_system_prompt or "You are a helpful assistant."

        if self.append_tool_system_prompt and self._available_tools:
            tool_descriptions = []
            for tool in self._available_tools:
                desc = f"- {tool.name}: {tool.description}"
                if tool.inputSchema:
                    desc += f" (parameters: {tool.inputSchema})"
                tool_descriptions.append(desc)

            tools_prompt = "\n\nYou have access to the following tools:\n" + "\n".join(
                tool_descriptions
            )
            return base_prompt + tools_prompt

        return base_prompt

    async def call_tool(self, tool_call: dict[str, Any]) -> types.CallToolResult:
        """
        Call a tool through the MCP client.

        Args:
            tool_call: Dict with 'name' and optional 'arguments' keys

        Returns:
            The raw MCP CallToolResult
        """
        tool_name = tool_call.get("name")
        if not tool_name:
            raise ValueError("Tool call must have a 'name' field")

        tool_args = tool_call.get("arguments", {})

        if tool_name not in self._tool_map:
            raise ValueError(f"Tool '{tool_name}' not found or not allowed")

        if self.client is None:
            raise ValueError("Client is not initialized")

        server_name, tool = self._tool_map[tool_name]
        session = self.client.get_session(server_name)

        logger.info(
            "Calling tool '%s' on server '%s' with args: %s",
            tool_name,
            server_name,
            tool_args,
        )
        if session.connector.client_session is None:
            raise ValueError("Client session is not initialized")

        result = await session.connector.client_session.call_tool(tool_name, tool_args)

        # Log result for debugging
        if result.isError:
            logger.error("Tool '%s' returned error: %s", tool_name, result.content)
        else:
            logger.debug("Tool '%s' completed successfully", tool_name)

        return result

    def has_computer_tools(self) -> bool:
        """Check if any computer control tools are available."""
        computer_tools = {"computer", "computer_anthropic", "computer_openai", "screenshot"}
        return any(tool.name in computer_tools for tool in self._available_tools)

    def get_tool_schemas(self) -> list[dict]:
        """Get tool schemas in a format suitable for the model."""
        schemas = []
        for tool in self._available_tools:
            # Filter out lifecycle tools from LLM conversation
            if tool.name in self.lifecycle_tools.values():
                continue

            schema = {
                "name": tool.name,
                "description": tool.description,
            }
            if tool.inputSchema:
                schema["parameters"] = tool.inputSchema
            schemas.append(schema)
        return schemas

    async def capture_screenshot(self) -> str | None:
        """Capture a screenshot using available tools."""
        if not self.has_computer_tools():
            return None

        # Try different screenshot tools
        for tool_name in [
            "computer",
            "screenshot",
            "computer_anthropic",
            "computer_openai",
            "anthropic_computer",
            "openai_computer",
        ]:
            if tool_name in self._tool_map:
                try:
                    # Different tools have different APIs
                    if tool_name == "computer_openai":
                        tool_call = {"name": tool_name, "arguments": {"type": "screenshot"}}
                    else:
                        tool_call = {"name": tool_name, "arguments": {"action": "screenshot"}}

                    result = await self.call_tool(tool_call)

                    # Extract screenshot from result
                    for content in result.content:
                        if isinstance(content, types.ImageContent):
                            logger.info("Captured screenshot")
                            return content.data

                except Exception as e:
                    logger.warning("Failed to capture screenshot with %s: %s", tool_name, e)

        return None

    def process_tool_results(self, tool_results: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Process tool results into a standardized format.

        Returns a dict with:
        - text: Combined text output from all tools
        - screenshot: Latest screenshot if any tool returned one
        - errors: List of any errors encountered
        - results: List of (tool_name, content_blocks) tuples for provider-specific formatting
        """
        text_parts = []
        latest_screenshot = None
        errors = []
        results = []

        for tool_result in tool_results:
            tool_name = tool_result["tool_name"]
            content_blocks = []

            if tool_result.get("error"):
                error_msg = f"{tool_name}: {tool_result.get('error_message', 'Unknown error')}"
                errors.append(error_msg)
                text_parts.append(f"Error - {error_msg}")
                content_blocks.append(
                    {
                        "type": "error",
                        "text": tool_result.get("error_message", "Unknown error"),
                    }
                )
            else:
                result = tool_result["result"]
                if result.isError:
                    # Extract error from content
                    error_text = "Tool execution failed"
                    for content in result.content:
                        if isinstance(content, types.TextContent):
                            error_text = content.text
                            break
                    error_msg = f"{tool_name}: {error_text}"
                    errors.append(error_msg)
                    text_parts.append(f"Error - {error_msg}")
                    content_blocks.append(
                        {
                            "type": "error",
                            "text": error_text,
                        }
                    )
                else:
                    # Process success content
                    tool_output = []
                    for content in result.content:
                        if isinstance(content, types.TextContent):
                            tool_output.append(content.text)
                            content_blocks.append(
                                {
                                    "type": "text",
                                    "text": content.text,
                                }
                            )
                        elif isinstance(content, types.ImageContent):
                            # Keep the latest screenshot
                            latest_screenshot = content.data
                            content_blocks.append(
                                {
                                    "type": "image",
                                    "data": content.data,
                                }
                            )

                    if tool_output:
                        text_parts.append(f"{tool_name}: " + " ".join(tool_output))

            results.append((tool_name, content_blocks))

        return {
            "text": "\n".join(text_parts) if text_parts else "No output from tools",
            "screenshot": latest_screenshot,
            "errors": errors,
            "results": results,  # List of (tool_name, content_blocks) for provider-specific use
        }

    async def run(
        self, prompt_or_task: str | Task, max_steps: int = 10, conversation_mode: bool = False
    ) -> dict[str, Any]:
        """
        Run the agent with the given prompt or task.

        Args:
            prompt_or_task: Either a string prompt for simple execution or a Task object
            max_steps: Maximum number of steps
            conversation_mode: If True, continue even when model returns text without tool calls

        Returns:
            For string prompts: The final response string
            For Task objects: Evaluation result dict with 'reward', 'done', 'info' keys
        """
        # Import here to avoid circular imports
        from hud.task import Task

        if not self._available_tools:
            await self.initialize()

        # Handle Task objects with full lifecycle
        if isinstance(prompt_or_task, Task):
            return await self._run_task(prompt_or_task, max_steps)

        # Handle simple string prompts (existing behavior)
        elif isinstance(prompt_or_task, str):
            return await self._run_prompt(prompt_or_task, max_steps, conversation_mode)

        else:
            raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}")

    async def _run_task(self, task: Task, max_steps: int = 10) -> dict[str, Any]:
        """
        Execute a task with setup and evaluate phases.

        Args:
            task: Task object with prompt, setup, and evaluate configs
            max_steps: Maximum steps for task execution

        Returns:
            Evaluation result dict with 'reward', 'done', 'info' keys
        """
        try:
            # Setup phase
            if task.setup is not None:
                setup_tool = self.lifecycle_tools.get("setup", "setup")
                await self._call_tool_safe(setup_tool, task.setup)

            # Execute the task prompt
            await self._run_prompt(task.prompt, max_steps, conversation_mode=False)

            # Evaluate phase
            if task.evaluate is not None:
                evaluate_tool = self.lifecycle_tools.get("evaluate", "evaluate")
                eval_result = await self._call_tool_safe(evaluate_tool, task.evaluate)

                # Return evaluation result if it's properly formatted
                if (
                    isinstance(eval_result, dict)
                    and "reward" in eval_result
                    and "done" in eval_result
                ):
                    return eval_result
                elif isinstance(eval_result, dict) and "grade" in eval_result:
                    return {
                        "reward": eval_result.get("grade", 0.0),
                        "done": True,
                        "info": {
                            "error": eval_result.get("error"),
                            "logs": eval_result.get("logs", ""),
                            "original_result": eval_result,
                        },
                    }
                else:
                    # Fallback for invalid evaluation format
                    return {
                        "reward": 0.0,
                        "done": True,
                        "info": {"error": "Invalid evaluation result", "eval_result": eval_result},
                    }
            else:
                # No evaluation - assume success
                return {
                    "reward": 0.0,
                    "done": True,
                    "info": {"message": "Task completed (no evaluation specified)"},
                }

        except Exception as e:
            return {"reward": 0.0, "done": True, "info": {"error": str(e)}}

    async def _call_tool_safe(self, tool_name: str, arguments: Any) -> Any:
        """
        Safely call a tool and return its result.

        Args:
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool (config from task)

        Returns:
            Tool result or None if tool not available/failed
        """
        try:
            if tool_name in self._tool_map:
                tool_call = {"name": tool_name, "arguments": arguments}
                result = await self.call_tool(tool_call)

                if result.isError:
                    logger.error("Tool %s returned error: %s", tool_name, result.content)
                    return {"error": result.content}
                else:
                    # Extract content from MCP result
                    if hasattr(result, "content") and result.content:
                        if len(result.content) == 1:
                            content_item = result.content[0]
                            # Check if content_item is a text type
                            if hasattr(content_item, "text") and hasattr(content_item, "type"):
                                if getattr(content_item, "type", None) == "text":
                                    # Try to parse as JSON if it looks like structured data
                                    text = content_item.text  # type: ignore[reportAttributeAccessIssue]
                                    if text.strip().startswith("{") and text.strip().endswith("}"):
                                        try:
                                            import json

                                            return json.loads(text)
                                        except json.JSONDecodeError:
                                            return text
                                    return text
                            else:
                                return content_item
                        else:
                            return result.content
                    return result
            else:
                logger.warning("Tool %s not available", tool_name)
                return None
        except Exception as e:
            logger.error("Failed to call tool %s: %s", tool_name, e)
            return {"error": str(e)}

    async def _run_prompt(
        self,
        prompt: str,
        max_steps: int = 10,
        conversation_mode: bool = False,
    ) -> dict[str, Any]:
        """
        Run the agent with the given prompt.

        Args:
            prompt: The task to complete
            max_steps: Maximum number of steps
            conversation_mode: If True, continue even when model returns text without tool calls

        Returns:
            The final response or result
        """
        try:
            latest_screenshot = None
            if self.initial_screenshot:
                latest_screenshot = await self.capture_screenshot()

            messages = await self.create_initial_messages(prompt, latest_screenshot)

            step = 0
            while step < max_steps:
                step += 1
                logger.info("step %s/%s", step, max_steps)

                try:
                    response = await self.get_model_response(messages, step)

                    # Log the model's response
                    logger.info("Model response - Content: %s", response.get("content", ""))
                    logger.info(
                        "Model response - Tool calls: %s",
                        [tc.get("name") for tc in response.get("tool_calls", [])],
                    )
                    logger.info("Model response - Done: %s", response.get("done", False))

                    # Check if we should stop
                    if response.get("done", False) and not conversation_mode:
                        return response.get("content", "Task completed")

                    tool_calls = response.get("tool_calls", [])
                    if not tool_calls:
                        if conversation_mode:
                            # In conversation mode, if model responds without tools,
                            # show the response and get user input
                            model_response = response.get("content", "")
                            if model_response:
                                print(f"\n🤖 Agent: {model_response}")  # noqa: T201
                                user_input = input("\n👤 You: ").strip()
                                if user_input.lower() in ["exit", "quit", "bye"]:
                                    return {
                                        "done": True,
                                        "reward": 0.0,
                                        "info": {"message": "Conversation ended by user."},
                                    }
                                # Add user's response to the conversation
                                # This needs to be handled by subclass-specific format
                                user_message = await self.create_user_message(user_input)
                                messages.append(user_message)
                                continue
                            else:
                                # No content and no tools - something went wrong
                                return {
                                    "done": False,
                                    "reward": 0.0,
                                    "info": {"message": "No response generated"},
                                }
                        else:
                            # In task mode, no tool calls means we're done
                            logger.info("In task mode with no tool calls - stopping execution")
                            logger.info(
                                "Final message: %s",
                                response.get("content", "No response generated"),
                            )
                            return {
                                "done": True,
                                "reward": 0.0,
                                "info": {
                                    "message": response.get("content", "No response generated"),
                                },
                            }

                    # Execute tool calls
                    tool_results = []
                    for tool_call in tool_calls:
                        if not tool_call.get("name"):
                            continue
                        try:
                            result = await self.call_tool(tool_call)
                            tool_results.append(
                                {
                                    "tool_name": tool_call["name"],
                                    "result": result,
                                    "error": False,
                                }
                            )
                        except Exception as e:
                            logger.error("Tool execution failed: %s", e)
                            tool_results.append(
                                {
                                    "tool_name": tool_call["name"],
                                    "error": True,
                                    "error_message": str(e),
                                }
                            )

                    # Process results
                    processed_results = self.process_tool_results(tool_results)

                    # Update screenshot if we got a new one
                    if processed_results["screenshot"]:
                        latest_screenshot = processed_results["screenshot"]

                    # Format tool results for the model
                    tool_messages = await self.format_tool_results(
                        processed_results,
                        response.get("tool_calls", []),
                    )
                    messages.extend(tool_messages)

                except Exception as e:
                    logger.error("Model call failed: %s", e)
                    return {"done": False, "reward": 0.0, "info": {"message": f"Error: {e}"}}

            return {"done": True, "reward": 0.0, "info": {"message": "Task completed"}}

        except KeyboardInterrupt:
            logger.info("Agent execution interrupted by user")
            return {
                "done": False,
                "reward": 0.0,
                "info": {"message": "Execution interrupted by user (Ctrl+C)"},
            }
        except asyncio.CancelledError:
            logger.info("Agent execution cancelled")
            return {"done": False, "reward": 0.0, "info": {"message": "Execution cancelled"}}

    @abstractmethod
    async def create_initial_messages(self, prompt: str, screenshot: str | None) -> list[Any]:
        """
        Create initial messages for the conversation.

        Args:
            prompt: The user's prompt
            screenshot: Optional initial screenshot

        Returns:
            List of messages in provider-specific format
        """

    @abstractmethod
    async def get_model_response(self, messages: list[Any], step: int) -> dict[str, Any]:
        """
        Get response from the model including any tool calls.

        Args:
            messages: List of messages in provider-specific format
            step: Current step number

        Returns:
            Dict with 'content', 'tool_calls', and 'done' keys
        """

    @abstractmethod
    async def format_tool_results(
        self, processed_results: dict[str, Any], tool_calls: list[dict[str, Any]]
    ) -> list[Any]:
        """
        Format tool results into messages for the model.

        Args:
            processed_results: Processed tool results from process_tool_results
            tool_calls: Original tool calls from the model

        Returns:
            List of formatted messages to append to conversation
        """
        raise NotImplementedError

    async def create_user_message(self, text: str) -> Any:
        """
        Create a user message in the format expected by the model.

        Default implementation for text-only messages.
        Subclasses can override for specific formats.

        Args:
            text: User's text input

        Returns:
            Formatted user message
        """
        return {"role": "user", "content": text}
```
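
To make the new subclass contract concrete: a provider integration only has to implement the three `@abstractmethod`s at the bottom of the file (`create_initial_messages`, `get_model_response`, `format_tool_results`). The sketch below is hypothetical and not shipped in this wheel — the real implementations are `hud/mcp_agent/claude.py`, `hud/mcp_agent/openai.py`, and `hud/mcp_agent/langchain.py`; `EchoAgent` and its canned response are invented for illustration.

```python
# Hypothetical subclass sketch -- not part of hud-python 0.3.0.
import asyncio
from typing import Any

from hud.mcp_agent.base import BaseMCPAgent


class EchoAgent(BaseMCPAgent):
    """Toy agent that answers every prompt in one step, without tools."""

    async def create_initial_messages(self, prompt: str, screenshot: str | None) -> list[Any]:
        # Message format is provider-specific; plain role/content dicts here.
        return [
            {"role": "system", "content": self.get_system_prompt()},
            {"role": "user", "content": prompt},
        ]

    async def get_model_response(self, messages: list[Any], step: int) -> dict[str, Any]:
        # A real implementation calls an LLM API and maps its tool calls to
        # [{"name": ..., "arguments": {...}}] dicts; done=True ends the run loop.
        return {"content": "Done.", "tool_calls": [], "done": True}

    async def format_tool_results(
        self, processed_results: dict[str, Any], tool_calls: list[dict[str, Any]]
    ) -> list[Any]:
        # Feed the combined tool output back to the model as a user turn.
        return [{"role": "user", "content": processed_results["text"]}]


if __name__ == "__main__":
    # With no client argument, __init__ falls back to a bare MCPClient(), so
    # no servers (and no tools) are discovered -- enough for this demo.
    agent = EchoAgent()
    print(asyncio.run(agent.run("Say hello", max_steps=3)))  # prints "Done."
```

Note the asymmetry documented in `run()`: plain string prompts return the model's final string, while `Task` objects return a `{"reward", "done", "info"}` dict, so callers should branch on the input type.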
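The constructor options control what the model sees. Per the `__init__` docstring, `allowed_tools`/`disallowed_tools` filter the discovered MCP tools, and `lifecycle_tools` remaps the `setup`/`evaluate` phase tools, which are always kept for framework use but hidden from the LLM. A small sketch (the tool names passed in are hypothetical, chosen to match whatever the target MCP server exposes):

```python
# Hypothetical configuration -- "reset_env", "grade_run", and "bash" stand in
# for tool names exposed by a concrete MCP server.
agent = EchoAgent(
    lifecycle_tools={"setup": "reset_env", "evaluate": "grade_run"},
    disallowed_tools=["bash"],  # discovered, but never offered to the model
)
```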