hud-python 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +7 -4
- hud/adapters/common/adapter.py +14 -3
- hud/adapters/common/tests/test_adapter.py +16 -4
- hud/datasets.py +188 -0
- hud/env/docker_client.py +14 -2
- hud/env/local_docker_client.py +28 -6
- hud/gym.py +0 -9
- hud/{mcp_agent → mcp}/__init__.py +2 -0
- hud/mcp/base.py +631 -0
- hud/{mcp_agent → mcp}/claude.py +52 -47
- hud/mcp/client.py +312 -0
- hud/{mcp_agent → mcp}/langchain.py +52 -33
- hud/{mcp_agent → mcp}/openai.py +56 -40
- hud/{mcp_agent → mcp}/tests/test_base.py +129 -54
- hud/mcp/tests/test_claude.py +294 -0
- hud/mcp/tests/test_client.py +324 -0
- hud/mcp/tests/test_openai.py +238 -0
- hud/settings.py +6 -0
- hud/task.py +1 -88
- hud/taskset.py +2 -23
- hud/telemetry/__init__.py +5 -0
- hud/telemetry/_trace.py +180 -17
- hud/telemetry/context.py +79 -0
- hud/telemetry/exporter.py +165 -6
- hud/telemetry/job.py +141 -0
- hud/telemetry/tests/test_trace.py +36 -25
- hud/tools/__init__.py +14 -1
- hud/tools/executors/__init__.py +19 -2
- hud/tools/executors/pyautogui.py +84 -50
- hud/tools/executors/tests/test_pyautogui_executor.py +4 -1
- hud/tools/playwright_tool.py +73 -67
- hud/tools/tests/test_edit.py +8 -1
- hud/tools/tests/test_tools.py +3 -0
- hud/trajectory.py +5 -1
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/METADATA +20 -14
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/RECORD +41 -46
- hud/evaluators/__init__.py +0 -9
- hud/evaluators/base.py +0 -32
- hud/evaluators/inspect.py +0 -24
- hud/evaluators/judge.py +0 -189
- hud/evaluators/match.py +0 -156
- hud/evaluators/remote.py +0 -65
- hud/evaluators/tests/__init__.py +0 -0
- hud/evaluators/tests/test_inspect.py +0 -12
- hud/evaluators/tests/test_judge.py +0 -231
- hud/evaluators/tests/test_match.py +0 -115
- hud/evaluators/tests/test_remote.py +0 -98
- hud/mcp_agent/base.py +0 -723
- /hud/{mcp_agent → mcp}/tests/__init__.py +0 -0
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/WHEEL +0 -0
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/licenses/LICENSE +0 -0
hud/mcp/base.py
ADDED
|
@@ -0,0 +1,631 @@
|
|
|
1
|
+
"""Base MCP Agent implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
import mcp.types as types
|
|
11
|
+
from mcp.types import CallToolRequestParams as MCPToolCall
|
|
12
|
+
from mcp.types import CallToolResult as MCPToolResult
|
|
13
|
+
from pydantic import BaseModel, Field
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from hud.datasets import TaskConfig
|
|
17
|
+
|
|
18
|
+
from .client import MCPClient
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ModelResponse(BaseModel):
    """Outcome of one ``get_model_response`` call.

    The agent loops read these fields to decide the next step: stop when
    ``done`` is set, execute ``tool_calls`` when any are present, otherwise
    treat ``content`` as the model's final/conversational text.
    """

    # Text produced by the model for this turn, if any.
    content: str | None = None
    # Tool invocations the model requested; empty when it requested none.
    tool_calls: list[MCPToolCall] = Field(default_factory=list)
    # True when the model signalled that it has finished.
    done: bool = False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AgentResult(BaseModel):
    """Single result type returned by every agent execution mode (task or prompt).

    Fields:
    - done: Whether execution finished
    - reward: Numeric reward, mainly produced by task evaluation
    - info: Free-form metadata dict
    - content: Final text content from the agent
    - error: Error message when execution failed
    - messages: Full conversation history (populated in prompt mode)
    """

    # Defaults to True; the agent loops pass done=False explicitly on failure/interrupt.
    done: bool = True
    reward: float = 0.0
    info: dict[str, Any] = Field(default_factory=dict)
    content: str | None = None
    error: str | None = None
    # Provider-specific message objects, so the element type is left open.
    messages: list[Any] = Field(default_factory=list)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Module-scoped logger; named after this module so handlers can target it.
logger = logging.getLogger(name=__name__)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class BaseMCPAgent(ABC):
    """
    Base class for MCP-enabled agents.

    This class provides the foundation for agents that interact with MCP servers:
    tool discovery and allow/deny filtering, routing tool calls through the
    ``MCPClient``, and the generic agent loops (task, prompt and interactive
    conversation modes). Provider-specific message formats and model calls are
    supplied by subclasses through ``create_initial_messages``,
    ``get_model_response`` and ``format_tool_results``.
    """

    def __init__(
        self,
        mcp_client: MCPClient | None = None,
        allowed_tools: list[str] | None = None,
        disallowed_tools: list[str] | None = None,
        initial_screenshot: bool = False,
        max_screenshot_history: int = 3,
        append_tool_system_prompt: bool = True,
        custom_system_prompt: str | None = None,
        lifecycle_tools: list[str] | None = None,
    ) -> None:
        """
        Initialize the base MCP agent.

        Args:
            mcp_client: MCPClient instance for server connections (required).
            allowed_tools: Tool names to allow (None or empty = all tools).
            disallowed_tools: Tool names to exclude.
            initial_screenshot: Whether to capture a screenshot before the first prompt.
            max_screenshot_history: Maximum number of screenshots to keep in context
                (stored here; not enforced by this base class).
            append_tool_system_prompt: Whether to list available tools in the system prompt.
            custom_system_prompt: System prompt to use instead of the default one.
            lifecycle_tools: Tool names reserved for task setup/evaluation; they are
                hidden from the model but may still be called by the agent itself.

        Raises:
            ValueError: If ``mcp_client`` is not provided.
        """
        if not mcp_client:
            raise ValueError(
                "MCPClient is required. Please provide a configured MCPClient instance."
            )
        self.mcp_client = mcp_client
        self.allowed_tools = allowed_tools
        self.disallowed_tools = disallowed_tools or []
        self.initial_screenshot = initial_screenshot
        self.max_screenshot_history = max_screenshot_history
        self.append_tool_system_prompt = append_tool_system_prompt
        self.custom_system_prompt = custom_system_prompt

        # Copy the caller's list: ``initialize`` appends task lifecycle tools to it,
        # and those additions must not leak into a list the caller owns or reuses.
        self.lifecycle_tools = list(lifecycle_tools or [])

        # Placeholder name; presumably overridden by concrete subclasses.
        self.model_name = "test-agent"

        # Initialize these here so methods can be called before initialize()
        self._available_tools: list[types.Tool] = []
        self._tool_map: dict[str, tuple[str, types.Tool]] = {}
        self.screenshot_history: list[str] = []

    def _filter_tools(self) -> None:
        """Rebuild ``_available_tools`` and ``_tool_map`` from the client using the allow/deny lists."""
        tool_map = self.mcp_client.get_tool_map()

        self._available_tools = []
        self._tool_map = {}

        for tool_name, (server_name, tool) in tool_map.items():
            # A None/empty allow-list means every tool is allowed.
            if self.allowed_tools and tool_name not in self.allowed_tools:
                continue
            if tool_name in self.disallowed_tools:
                continue

            self._available_tools.append(tool)
            self._tool_map[tool_name] = (server_name, tool)

    async def initialize(self, task: str | TaskConfig | None = None) -> None:
        """
        Prepare the agent for a run.

        Connects the client if it has no sessions yet, registers a task's setup and
        evaluate tools as lifecycle tools, and re-applies tool filtering.

        Args:
            task: Optional prompt or TaskConfig. Only a TaskConfig affects lifecycle tools.
        """
        # If client wasn't initialized on construction, do it now
        if not self.mcp_client.get_sessions():
            await self.mcp_client.initialize()

        # Local import to avoid a circular import with hud.datasets.
        from hud.datasets import TaskConfig

        if isinstance(task, TaskConfig):
            for lifecycle_call in (task.setup_tool, task.evaluate_tool):
                # Skip names already registered so repeated initialize() calls
                # do not keep growing the list.
                if lifecycle_call and lifecycle_call.name not in self.lifecycle_tools:
                    self.lifecycle_tools.append(lifecycle_call.name)

        self._filter_tools()

        logger.info(
            "Agent initialized with %d available tools (after filtering)",
            len(self._available_tools),
        )

    def get_available_tools(self) -> list[types.Tool]:
        """Get list of available MCP tools for LLM use (excludes lifecycle tools)."""
        return [tool for tool in self._available_tools if tool.name not in self.lifecycle_tools]

    def get_tool_map(self) -> dict[str, tuple[str, types.Tool]]:
        """Get mapping of tool names to (server_name, tool) tuples."""
        return self._tool_map

    def get_sessions(self) -> dict[str, Any]:
        """Get active MCP sessions from the client."""
        return self.mcp_client.get_sessions()

    def get_tools_by_server(self) -> dict[str, list[types.Tool]]:
        """Get filtered tools grouped by server name."""
        tools_by_server: dict[str, list[types.Tool]] = {}
        for server_name, tool in self._tool_map.values():
            tools_by_server.setdefault(server_name, []).append(tool)
        return tools_by_server

    def get_tools_by_connector(self) -> dict[Any, list[types.Tool]]:
        """
        Get filtered tools grouped by the connector of their server's session.

        Raises:
            KeyError: If a tool's server has no entry in the client's sessions.
        """
        tools_by_connector: dict[Any, list[types.Tool]] = {}
        sessions = self.mcp_client.get_sessions()
        for server_name, tool in self._tool_map.values():
            connector = sessions[server_name].connector
            tools_by_connector.setdefault(connector, []).append(tool)
        return tools_by_connector

    def get_system_prompt(self) -> str:
        """
        Build the system prompt, optionally listing the tools the model may use.

        Lifecycle tools are excluded, matching ``get_tool_schemas``, so setup/evaluate
        tools are never advertised to the model.
        """
        base_prompt = self.custom_system_prompt or "You are a helpful assistant."

        visible_tools = self.get_available_tools()
        if not (self.append_tool_system_prompt and visible_tools):
            return base_prompt

        tool_descriptions = []
        for tool in visible_tools:
            desc = f"- {tool.name}: {tool.description}"
            if tool.inputSchema:
                desc += f" (parameters: {tool.inputSchema})"
            tool_descriptions.append(desc)

        tools_prompt = "\n\nYou have access to the following tools:\n" + "\n".join(
            tool_descriptions
        )
        return base_prompt + tools_prompt

    async def call_tool(self, tool_call: MCPToolCall) -> MCPToolResult:
        """
        Call a tool through the MCP client.

        Args:
            tool_call: MCPToolCall with a ``name`` and optional ``arguments``.

        Returns:
            The raw MCPToolResult (error results are logged, not raised).

        Raises:
            ValueError: If the call has no name, the tool is neither available nor a
                lifecycle tool, or the client is missing.
        """
        tool_name = tool_call.name
        if not tool_name:
            raise ValueError("Tool call must have a 'name' field")

        tool_args = tool_call.arguments

        # Lifecycle tools may be called even when filtering removed them.
        if tool_name not in self._tool_map and tool_name not in self.lifecycle_tools:
            raise ValueError(f"Tool '{tool_name}' not found or not allowed")

        if self.mcp_client is None:
            raise ValueError("Client is not initialized")

        # Use client's call_tool method which handles routing
        result = await self.mcp_client.call_tool(tool_name, tool_args)

        if result.isError:
            logger.error("Tool '%s' returned error: %s", tool_name, result.content)
        else:
            logger.debug("Tool '%s' completed successfully", tool_name)

        return result

    def has_computer_tools(self) -> bool:
        """Check if any computer control tools are available."""
        computer_tools = {"computer", "computer_anthropic", "computer_openai", "screenshot"}
        return any(tool.name in computer_tools for tool in self._available_tools)

    def get_tool_schemas(self) -> list[dict]:
        """Get name/description/parameters schemas for the tools the model may use."""
        schemas = []
        # get_available_tools() already filters out lifecycle tools.
        for tool in self.get_available_tools():
            schema = {
                "name": tool.name,
                "description": tool.description,
            }
            if tool.inputSchema:
                schema["parameters"] = tool.inputSchema
            schemas.append(schema)
        return schemas

    async def capture_screenshot(self) -> str | None:
        """
        Capture a screenshot with the first computer/screenshot tool that yields an image.

        Returns:
            The ``data`` of the first ``ImageContent`` returned, or None when no
            computer tool is available or every attempt fails.
        """
        if not self.has_computer_tools():
            return None

        # NOTE(review): has_computer_tools() does not recognise the last two names,
        # so they are only tried when one of the recognised tools is also present.
        for tool_name in [
            "computer",
            "screenshot",
            "computer_anthropic",
            "computer_openai",
            "anthropic_computer",
            "openai_computer",
        ]:
            if tool_name not in self._tool_map:
                continue
            try:
                # Different tools have different APIs
                if tool_name == "computer_openai":
                    tool_call = MCPToolCall(name=tool_name, arguments={"type": "screenshot"})
                else:
                    tool_call = MCPToolCall(name=tool_name, arguments={"action": "screenshot"})

                result = await self.call_tool(tool_call)

                for content in result.content:
                    if isinstance(content, types.ImageContent):
                        logger.info("Captured screenshot")
                        return content.data

            except Exception as e:
                # Best effort: log and fall through to the next candidate tool.
                logger.warning("Failed to capture screenshot with %s: %s", tool_name, e)

        return None

    def extract_latest_screenshot(self, tool_results: list[MCPToolResult]) -> str | None:
        """Return the data of the last ImageContent among non-error tool results, if any."""
        latest_screenshot = None
        for result in tool_results:
            if result.isError:
                continue
            for content in result.content:
                if isinstance(content, types.ImageContent):
                    latest_screenshot = content.data
        return latest_screenshot

    async def run(self, prompt_or_task: str | TaskConfig, max_steps: int = 10) -> AgentResult:
        """
        Run the agent with the given prompt or task.

        Initializes the agent first when no tools have been loaded yet.

        Args:
            prompt_or_task: A plain string prompt, or a TaskConfig whose setup and
                evaluate tools are run around its prompt.
            max_steps: Maximum number of model steps.

        Returns:
            AgentResult with fields populated according to the execution mode.

        Raises:
            TypeError: If ``prompt_or_task`` is neither a str nor a TaskConfig.
        """
        # Import here to avoid circular imports
        from hud.datasets import TaskConfig

        if not self._available_tools:
            await self.initialize(prompt_or_task)

        if isinstance(prompt_or_task, TaskConfig):
            return await self._run_task(prompt_or_task, max_steps)

        if isinstance(prompt_or_task, str):
            return await self._run_prompt(prompt_or_task, max_steps)

        raise TypeError(f"prompt_or_task must be str or TaskConfig, got {type(prompt_or_task)}")

    async def _run_task(self, task: TaskConfig, max_steps: int = 10) -> AgentResult:
        """
        Execute a task with setup and evaluate phases.

        Args:
            task: TaskConfig with a prompt and optional setup/evaluate tool calls.
            max_steps: Maximum steps for the prompt phase.

        Returns:
            AgentResult with reward, done and content/error fields. Any exception is
            converted into an AgentResult carrying its message in ``error``.
        """
        try:
            if task.setup_tool is not None:
                await self.call_tool(task.setup_tool)

            prompt_result = await self._run_prompt(task.prompt, max_steps)

            if task.evaluate_tool is None:
                # No evaluation configured: report the prompt outcome with zero reward.
                return AgentResult(
                    reward=0.0,
                    done=True,
                    content=prompt_result.content,
                    messages=prompt_result.messages,
                )

            eval_result = await self.call_tool(task.evaluate_tool)

            if (
                isinstance(eval_result, MCPToolResult)
                and eval_result.structuredContent is not None
            ):
                structured = eval_result.structuredContent
                return AgentResult(
                    reward=self._find_reward(eval_result),
                    done=True,
                    # "content" is optional in the evaluator's output; indexing it
                    # directly turned a valid reward into a KeyError/error result.
                    content=structured.get("content") if isinstance(structured, dict) else None,
                    messages=prompt_result.messages,
                )

            # Fallback for invalid evaluation format
            return AgentResult(
                reward=0.0,
                done=True,
                error="Invalid evaluation result",
                info={"eval_result": eval_result},
                messages=prompt_result.messages,
            )

        except Exception as e:
            return AgentResult(reward=0.0, done=True, error=str(e))

    def _find_reward(self, result: MCPToolResult) -> float:
        """Find the reward in the result.

        Agent accepts "reward", "grade", "score" (checked in that order) in the
        result's structured content. If none is found, return 0.0.
        """
        structured = result.structuredContent
        if not isinstance(structured, dict):
            return 0.0
        for key in ("reward", "grade", "score"):
            if key in structured:
                return structured[key]
        return 0.0

    def _format_error_result(self, error_message: str) -> MCPToolResult:
        """Wrap an error message in an MCPToolResult flagged with ``isError=True``."""
        return MCPToolResult(
            content=[types.TextContent(text=error_message, type="text")], isError=True
        )

    async def _execute_tool_calls(self, tool_calls: list[MCPToolCall]) -> list[MCPToolResult]:
        """Run tool calls in order; a call that raises yields an error result instead."""
        tool_results: list[MCPToolResult] = []
        for tool_call in tool_calls:
            try:
                tool_results.append(await self.call_tool(tool_call))
            except Exception as e:
                logger.error("Tool execution failed: %s", e)
                tool_results.append(self._format_error_result(str(e)))
        return tool_results

    async def run_conversation(self, prompt: str, max_steps: int = 10) -> AgentResult:
        """
        Run the agent in interactive conversation mode.

        When the model replies without tool calls, its text is printed and the next
        user turn is read from stdin ("exit", "quit" or "bye" ends the conversation).

        Args:
            prompt: The initial prompt to start the conversation.
            max_steps: Maximum number of model steps in total.

        Returns:
            AgentResult when the conversation ends, fails or is interrupted.
        """
        # Bound before the try so the interrupt handlers can always return it.
        messages: list[Any] = []
        try:
            latest_screenshot = None
            if self.initial_screenshot:
                latest_screenshot = await self.capture_screenshot()

            messages = await self.create_initial_messages(prompt, latest_screenshot)

            for step in range(1, max_steps + 1):
                logger.info("Conversation step %s/%s", step, max_steps)

                try:
                    response = await self.get_model_response(messages)

                    logger.info("Model response - Content: %s", response.content)
                    logger.info(
                        "Model response - Tool calls: %s",
                        [tc.name for tc in response.tool_calls],
                    )

                    tool_calls = response.tool_calls
                    if not tool_calls:
                        model_response = response.content
                        if not model_response:
                            # No content and no tools - something went wrong
                            return AgentResult(
                                done=False,
                                reward=0.0,
                                error="No response generated",
                                messages=messages,
                            )

                        # NOTE(review): input() blocks the event loop while waiting.
                        print(f"\n🤖 Agent: {model_response}")  # noqa: T201
                        user_input = input("\n👤 You: ").strip()
                        if user_input.lower() in ["exit", "quit", "bye"]:
                            return AgentResult(done=True, reward=0.0, messages=messages)
                        user_message = await self.create_user_message(user_input)
                        messages.append(user_message)
                        continue

                    tool_results = await self._execute_tool_calls(tool_calls)
                    tool_messages = await self.format_tool_results(tool_calls, tool_results)
                    messages.extend(tool_messages)

                except Exception as e:
                    logger.error("Model call failed: %s", e)
                    return AgentResult(done=False, reward=0.0, error=str(e), messages=messages)

            return AgentResult(done=True, reward=0.0, messages=messages)

        except KeyboardInterrupt:
            logger.info("Conversation interrupted by user")
            return AgentResult(done=False, reward=0.0, messages=messages)
        except asyncio.CancelledError:
            # NOTE(review): swallowing cancellation stops it propagating to the caller.
            logger.info("Conversation cancelled")
            return AgentResult(done=False, reward=0.0, messages=messages)

    async def _run_prompt(self, prompt: str, max_steps: int = 10) -> AgentResult:
        """
        Run the agent on a prompt in task mode (no user interaction).

        Stops when the model reports ``done``, replies without tool calls, fails, or
        ``max_steps`` is reached.

        Args:
            prompt: The task to complete.
            max_steps: Maximum number of model steps.

        Returns:
            AgentResult for task completion.
        """
        # Bound before the try: an interrupt during screenshot capture or message
        # creation previously hit the handlers below with ``messages`` unbound.
        messages: list[Any] = []
        try:
            latest_screenshot = None
            if self.initial_screenshot:
                latest_screenshot = await self.capture_screenshot()

            messages = await self.create_initial_messages(prompt, latest_screenshot)

            for step in range(1, max_steps + 1):
                logger.info("step %s/%s", step, max_steps)

                try:
                    response = await self.get_model_response(messages)

                    logger.info("Model response - Content: %s", response.content)
                    logger.info(
                        "Model response - Tool calls: %s",
                        [tc.name for tc in response.tool_calls],
                    )
                    logger.info("Model response - Done: %s", response.done)

                    if response.done:
                        return AgentResult(
                            content=response.content, done=response.done, messages=messages
                        )

                    tool_calls = response.tool_calls
                    if not tool_calls:
                        # In task mode, no tool calls means we're done
                        logger.info("No tool calls - stopping execution")
                        logger.info(
                            "Final message: %s",
                            response.content,
                        )
                        return AgentResult(
                            done=True, reward=0.0, content=response.content, messages=messages
                        )

                    tool_results = await self._execute_tool_calls(tool_calls)
                    tool_messages = await self.format_tool_results(tool_calls, tool_results)
                    messages.extend(tool_messages)

                except Exception as e:
                    logger.error("Model call failed: %s", e)
                    return AgentResult(done=False, reward=0.0, error=str(e), messages=messages)

            return AgentResult(done=True, reward=0.0, messages=messages)

        except KeyboardInterrupt:
            logger.info("Agent execution interrupted by user")
            return AgentResult(done=False, reward=0.0, messages=messages)
        except asyncio.CancelledError:
            # NOTE(review): swallowing cancellation stops it propagating to the caller.
            logger.info("Agent execution cancelled")
            return AgentResult(done=False, reward=0.0, messages=messages)

    @abstractmethod
    async def create_initial_messages(self, prompt: str, screenshot: str | None) -> list[Any]:
        """
        Create initial messages for the conversation.

        Args:
            prompt: The user's prompt
            screenshot: Optional initial screenshot

        Returns:
            List of messages in provider-specific format
        """

    @abstractmethod
    async def get_model_response(self, messages: list[Any]) -> ModelResponse:
        """
        Get response from the model including any tool calls.

        Args:
            messages: List of messages in provider-specific format

        Returns:
            ModelResponse with content, tool_calls, and done fields
        """

    @abstractmethod
    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[Any]:
        """
        Format tool results into messages for the model.

        Args:
            tool_calls: List of MCPToolCall objects that were executed
            tool_results: List of MCPToolResult objects from tool execution, in the
                same order as ``tool_calls``

        Returns:
            List of formatted messages to append to conversation
        """
        raise NotImplementedError

    async def create_user_message(self, text: str) -> Any:
        """
        Create a user message in the format expected by the model.

        Default implementation returns a ``{"role": "user", "content": text}`` dict;
        subclasses can override for provider-specific formats.

        Args:
            text: User's text input

        Returns:
            Formatted user message
        """
        return {"role": "user", "content": text}
|