hud-python 0.4.1 → 0.4.3 (py3-none-any.whl)
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
Potentially problematic release.
This version of hud-python might be problematic.
- hud/__init__.py +22 -22
- hud/agents/__init__.py +13 -15
- hud/agents/base.py +599 -599
- hud/agents/claude.py +373 -373
- hud/agents/langchain.py +261 -250
- hud/agents/misc/__init__.py +7 -7
- hud/agents/misc/response_agent.py +82 -80
- hud/agents/openai.py +352 -352
- hud/agents/openai_chat_generic.py +154 -154
- hud/agents/tests/__init__.py +1 -1
- hud/agents/tests/test_base.py +742 -742
- hud/agents/tests/test_claude.py +324 -324
- hud/agents/tests/test_client.py +363 -363
- hud/agents/tests/test_openai.py +237 -237
- hud/cli/__init__.py +617 -617
- hud/cli/__main__.py +8 -8
- hud/cli/analyze.py +371 -371
- hud/cli/analyze_metadata.py +230 -230
- hud/cli/build.py +498 -427
- hud/cli/clone.py +185 -185
- hud/cli/cursor.py +92 -92
- hud/cli/debug.py +392 -392
- hud/cli/docker_utils.py +83 -83
- hud/cli/init.py +280 -281
- hud/cli/interactive.py +353 -353
- hud/cli/mcp_server.py +764 -756
- hud/cli/pull.py +330 -336
- hud/cli/push.py +404 -370
- hud/cli/remote_runner.py +311 -311
- hud/cli/runner.py +160 -160
- hud/cli/tests/__init__.py +3 -3
- hud/cli/tests/test_analyze.py +284 -284
- hud/cli/tests/test_cli_init.py +265 -265
- hud/cli/tests/test_cli_main.py +27 -27
- hud/cli/tests/test_clone.py +142 -142
- hud/cli/tests/test_cursor.py +253 -253
- hud/cli/tests/test_debug.py +453 -453
- hud/cli/tests/test_mcp_server.py +139 -139
- hud/cli/tests/test_utils.py +388 -388
- hud/cli/utils.py +263 -263
- hud/clients/README.md +143 -143
- hud/clients/__init__.py +16 -16
- hud/clients/base.py +378 -379
- hud/clients/fastmcp.py +222 -222
- hud/clients/mcp_use.py +298 -278
- hud/clients/tests/__init__.py +1 -1
- hud/clients/tests/test_client_integration.py +111 -111
- hud/clients/tests/test_fastmcp.py +342 -342
- hud/clients/tests/test_protocol.py +188 -188
- hud/clients/utils/__init__.py +1 -1
- hud/clients/utils/retry_transport.py +160 -160
- hud/datasets.py +327 -322
- hud/misc/__init__.py +1 -1
- hud/misc/claude_plays_pokemon.py +292 -292
- hud/otel/__init__.py +35 -35
- hud/otel/collector.py +142 -142
- hud/otel/config.py +164 -164
- hud/otel/context.py +536 -536
- hud/otel/exporters.py +366 -366
- hud/otel/instrumentation.py +97 -97
- hud/otel/processors.py +118 -118
- hud/otel/tests/__init__.py +1 -1
- hud/otel/tests/test_processors.py +197 -197
- hud/server/__init__.py +5 -5
- hud/server/context.py +114 -114
- hud/server/helper/__init__.py +5 -5
- hud/server/low_level.py +132 -132
- hud/server/server.py +170 -166
- hud/server/tests/__init__.py +3 -3
- hud/settings.py +73 -73
- hud/shared/__init__.py +5 -5
- hud/shared/exceptions.py +180 -180
- hud/shared/requests.py +264 -264
- hud/shared/tests/test_exceptions.py +157 -157
- hud/shared/tests/test_requests.py +275 -275
- hud/telemetry/__init__.py +25 -25
- hud/telemetry/instrument.py +379 -379
- hud/telemetry/job.py +309 -309
- hud/telemetry/replay.py +74 -74
- hud/telemetry/trace.py +83 -83
- hud/tools/__init__.py +33 -33
- hud/tools/base.py +365 -365
- hud/tools/bash.py +161 -161
- hud/tools/computer/__init__.py +15 -15
- hud/tools/computer/anthropic.py +437 -437
- hud/tools/computer/hud.py +376 -376
- hud/tools/computer/openai.py +295 -295
- hud/tools/computer/settings.py +82 -82
- hud/tools/edit.py +314 -314
- hud/tools/executors/__init__.py +30 -30
- hud/tools/executors/base.py +539 -539
- hud/tools/executors/pyautogui.py +621 -621
- hud/tools/executors/tests/__init__.py +1 -1
- hud/tools/executors/tests/test_base_executor.py +338 -338
- hud/tools/executors/tests/test_pyautogui_executor.py +165 -165
- hud/tools/executors/xdo.py +511 -511
- hud/tools/playwright.py +412 -412
- hud/tools/tests/__init__.py +3 -3
- hud/tools/tests/test_base.py +282 -282
- hud/tools/tests/test_bash.py +158 -158
- hud/tools/tests/test_bash_extended.py +197 -197
- hud/tools/tests/test_computer.py +425 -425
- hud/tools/tests/test_computer_actions.py +34 -34
- hud/tools/tests/test_edit.py +259 -259
- hud/tools/tests/test_init.py +27 -27
- hud/tools/tests/test_playwright_tool.py +183 -183
- hud/tools/tests/test_tools.py +145 -145
- hud/tools/tests/test_utils.py +156 -156
- hud/tools/types.py +72 -72
- hud/tools/utils.py +50 -50
- hud/types.py +136 -136
- hud/utils/__init__.py +10 -10
- hud/utils/async_utils.py +65 -65
- hud/utils/design.py +236 -168
- hud/utils/mcp.py +55 -55
- hud/utils/progress.py +149 -149
- hud/utils/telemetry.py +66 -66
- hud/utils/tests/test_async_utils.py +173 -173
- hud/utils/tests/test_init.py +17 -17
- hud/utils/tests/test_progress.py +261 -261
- hud/utils/tests/test_telemetry.py +82 -82
- hud/utils/tests/test_version.py +8 -8
- hud/version.py +7 -7
- {hud_python-0.4.1.dist-info → hud_python-0.4.3.dist-info}/METADATA +10 -8
- hud_python-0.4.3.dist-info/RECORD +131 -0
- {hud_python-0.4.1.dist-info → hud_python-0.4.3.dist-info}/licenses/LICENSE +21 -21
- hud/agents/art.py +0 -101
- hud_python-0.4.1.dist-info/RECORD +0 -132
- {hud_python-0.4.1.dist-info → hud_python-0.4.3.dist-info}/WHEEL +0 -0
- {hud_python-0.4.1.dist-info → hud_python-0.4.3.dist-info}/entry_points.txt +0 -0
hud/agents/base.py
CHANGED
@@ -1,599 +1,599 @@
(The diff re-emits all 599 lines, but the removed and added versions are textually identical, so the file content is shown once below.)

"""Base MCP Agent implementation."""

from __future__ import annotations

import asyncio
import json
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Literal

import mcp.types as types

from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace
from hud.utils.mcp import MCPConfigPatch, patch_mcp_config, setup_hud_telemetry

if TYPE_CHECKING:
    from hud.clients.base import AgentMCPClient
    from hud.datasets import Task

    from .misc import ResponseAgent


logger = logging.getLogger(__name__)

GLOBAL_SYSTEM_PROMPT = "You are an assistant that can use tools to help the user. You will be given a task and you will need to use the tools to complete the task."  # noqa: E501


class MCPAgent(ABC):
    """
    Base class for MCP-enabled agents.

    This class provides the foundation for agents that interact with MCP servers,
    handling tool discovery and filtering while leaving provider-specific
    implementation details to subclasses.
    """

    metadata: dict[str, Any]

    def __init__(
        self,
        mcp_client: AgentMCPClient | None = None,
        # Filtering
        allowed_tools: list[str] | None = None,
        disallowed_tools: list[str] | None = None,
        lifecycle_tools: list[str] | None = None,
        # Messages
        system_prompt: str = GLOBAL_SYSTEM_PROMPT,
        append_setup_output: bool = True,
        initial_screenshot: bool = True,
        # Misc
        model_name: str = "mcp-agent",
        response_agent: ResponseAgent | None = None,
        auto_trace: bool = True,
    ) -> None:
        """
        Initialize the base MCP agent.

        Args:
            mcp_client: AgentMCPClient instance for server connections
            allowed_tools: List of tool names to allow (None = all tools)
            disallowed_tools: List of tool names to disallow
            lifecycle_tools: List of tool names to use for lifecycle tools
            initial_screenshot: Whether to capture screenshot before first prompt
            system_prompt: System prompt to use
            append_setup_output: Whether to append setup tool output to initial messages
        """

        self.mcp_client = mcp_client
        self._auto_created_client = False  # Track if we created the client

        self.model_name = model_name

        # Filtering
        self.allowed_tools = allowed_tools
        self.disallowed_tools = disallowed_tools or []
        self.lifecycle_tools = lifecycle_tools or []

        # Messages
        self.system_prompt = system_prompt
        self.append_setup_output = append_setup_output
        self.initial_screenshot = initial_screenshot

        # Initialize these here so methods can be called before initialize()
        self._available_tools: list[types.Tool] = []
        self._tool_map: dict[str, types.Tool] = {}  # Simplified: just name to tool
        self.screenshot_history: list[str] = []
        self._auto_trace = auto_trace
        self.initialization_complete = False

        # Response agent to automatically interact with the model
        self.response_agent = response_agent

    async def initialize(self, task: str | Task | None = None) -> None:
        """Initialize the agent with task-specific configuration."""
        from hud.datasets import Task

        # Create client if needed
        if self.mcp_client is None and isinstance(task, Task) and task.mcp_config:
            from hud.clients import MCPClient

            self.mcp_client = MCPClient(mcp_config=task.mcp_config)
            self._auto_created_client = True
            logger.info("Auto-created MCPClient from task.mcp_config")

        # Ensure we have a client
        if self.mcp_client is None:
            raise ValueError(
                "No MCPClient. Please provide one when initializing the agent or pass a Task with mcp_config."  # noqa: E501
            )

        await self._setup_config(self.mcp_client.mcp_config)

        # Initialize client if needed
        await self.mcp_client.initialize()

        # If task is provided, add lifecycle tools
        if isinstance(task, Task):
            if task.setup_tool:
                if isinstance(task.setup_tool, list):
                    for tool in task.setup_tool:
                        self.lifecycle_tools.append(tool.name)
                else:
                    self.lifecycle_tools.append(task.setup_tool.name)
            if task.evaluate_tool:
                if isinstance(task.evaluate_tool, list):
                    for tool in task.evaluate_tool:
                        self.lifecycle_tools.append(tool.name)
                else:
                    self.lifecycle_tools.append(task.evaluate_tool.name)
            if task.system_prompt:
                self.system_prompt += "\n\n" + task.system_prompt

        # Re-apply filtering with updated lifecycle tools
        await self._filter_tools()

        logger.info(
            "Agent initialized with %d available tools (after filtering)",
            len(self._available_tools),
        )

    async def run(self, prompt_or_task: str | Task | dict[str, Any], max_steps: int = 10) -> Trace:
        """
        Run the agent with the given prompt or task.

        Args:
            prompt_or_task: Either a string prompt for simple execution or a Task object
            max_steps: Maximum number of steps (-1 for infinite)

        Returns:
            Trace with reward, done, content, isError fields and trace steps
        """
        # Import here to avoid circular imports
        from hud.datasets import Task

        if isinstance(prompt_or_task, dict):
            prompt_or_task = Task(**prompt_or_task)

        try:
            # Establish the connection with the MCP server/Environment
            if not self.initialization_complete:
                await self.initialize(prompt_or_task)
                self.initialization_complete = True

            # Handle Task objects with full lifecycle
            if isinstance(prompt_or_task, Task):
                return await self.run_task(prompt_or_task, max_steps)

            # Handle simple string prompts
            elif isinstance(prompt_or_task, str):
                context = text_to_blocks(prompt_or_task)
                return await self._run_context(context, max_steps=max_steps)

            else:
                raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}")
        finally:
            # Cleanup auto-created resources
            await self._cleanup()

    async def run_task(self, task: Task, max_steps: int = 10) -> Trace:
        """
        Execute a task with setup and evaluate phases.

        Args:
            task: Task object with prompt, setup, and evaluate configs
            max_steps: Maximum steps for task execution (-1 for infinite)

        Returns:
            Trace with reward from evaluation
        """
        prompt_result = None

        try:
            # Setup phase
            start_context: list[types.ContentBlock] = []

            # Extract the initial task information
            if task.prompt:
                start_context.extend(text_to_blocks(task.prompt))

            # Execute the setup tool and append the initial observation to the context
            if task.setup_tool is not None:
                logger.info("Setting up tool phase: %s", task.setup_tool)
                results = await self.call_tools(task.setup_tool)
                if any(result.isError for result in results):
                    raise RuntimeError(f"{results}")

                if self.append_setup_output and isinstance(results[0].content, list):
                    start_context.extend(results[0].content)
                if not self.initial_screenshot:
                    start_context = await self._filter_messages(start_context, include_types=["text"])

            # Execute the task (agent loop) - this returns an empty trace object with the final response  # noqa: E501
            prompt_result = await self._run_context(start_context, max_steps=max_steps)

        except Exception as e:
            logger.error("Task execution failed: %s", e)
            # Create an error result but don't return yet - we still want to evaluate
            prompt_result = Trace(reward=0.0, done=True, content=str(e), isError=True)
            prompt_result.populate_from_context()

        # Always evaluate if we have a prompt result and evaluate tool
        if prompt_result is not None and task.evaluate_tool is not None:
            try:
                logger.info("Evaluating tool phase: %s", task.evaluate_tool)
                results = await self.call_tools(task.evaluate_tool)

                if any(result.isError for result in results):
                    raise RuntimeError(f"{results}")

                # Extract reward and content from evaluation
                if results:
                    reward = find_reward(results[0])
                    eval_content = find_content(results[0])

                    # Update the prompt result with evaluation reward
                    prompt_result.reward = reward

                    # Update the prompt result with evaluation content (if available)
                    if eval_content:
                        # Prompt result may already have final response content, so we append to it
                        if prompt_result.content:
                            prompt_result.content += "\n\n" + eval_content
                        else:
                            prompt_result.content = eval_content

            except Exception as e:
                logger.error("Evaluation phase failed: %s", e)
                # Continue with the prompt result even if evaluation failed

        return (
            prompt_result
            if prompt_result
            else Trace(reward=0.0, done=True, content="No result available", isError=True)
        )

    async def _run_context(
        self, context: list[types.ContentBlock], *, max_steps: int = 10
    ) -> Trace:
        """
        Run the agent with the given context messages. This is the core agent loop.

        Args:
            context: The context to complete
            max_steps: Maximum number of steps (-1 for infinite)

        Returns:
            Trace with reward, done, content fields and trace steps
        """
        final_response = None
        error = None

        try:
            # Start with system messages
            messages = await self.get_system_messages()

            # Add initial context
            messages.extend(await self.format_message(context))
            logger.debug("Messages: %s", messages)

            step_count = 0
            while max_steps == -1 or step_count < max_steps:
                step_count += 1
                if max_steps == -1:
                    logger.info("Step %s (unlimited)", step_count)
                else:
                    logger.info("Step %s/%s", step_count, max_steps)

                try:
                    # 1. Get model response
                    response = await self.get_response(messages)

                    logger.info("Agent:\n%s", response)

                    # Check if we should stop
                    if response.done or not response.tool_calls:
                        # Optional external ResponseAgent to decide whether to stop
                        decision = "STOP"
                        if self.response_agent is not None and response.content:
                            try:
                                decision = await self.response_agent.determine_response(
                                    response.content
                                )
                            except Exception as e:
                                logger.warning("ResponseAgent failed: %s", e)
                        if decision == "STOP":
                            logger.info("Stopping execution")
                            final_response = response
                            break
                        else:
                            logger.info("Continuing execution")
                            messages.extend(await self.format_message(decision))
                            continue

                    # 2. Execute tools
                    tool_calls = response.tool_calls
                    tool_results = await self.call_tools(tool_calls)

                    # 3. Format tool results and add to messages
                    tool_messages = await self.format_tool_results(tool_calls, tool_results)
                    messages.extend(tool_messages)

                except Exception as e:
                    logger.error("Step failed: %s", e)
                    error = str(e)
                    break

        except KeyboardInterrupt:
            logger.info("Agent execution interrupted by user")
            error = "Interrupted by user"
        except asyncio.CancelledError:
            logger.info("Agent execution cancelled")
            error = "Cancelled"
        except Exception as e:
            logger.error("Unexpected error: %s", e)
            error = str(e)

        # Build result
        trace_result = Trace(
            reward=0.0,  # Default - will be set by task evaluation if applicable
            done=True,
            content=final_response.content if final_response else None,
            isError=error is not None,
            info={"error": error} if error else {},
        )

        # Populate trace steps from current context
        trace_result.populate_from_context()

        return trace_result

    async def call_tools(
        self, tool_call: MCPToolCall | list[MCPToolCall] | None = None
    ) -> list[MCPToolResult]:
        """
        Call a tool through the MCP client.

        Args:
            tool_call: MCPToolCall or list of MCPToolCall

        Returns:
            List of MCPToolResult
        """
        if tool_call is None:
            return []

        if isinstance(tool_call, MCPToolCall):
            tool_call = [tool_call]

        if self.mcp_client is None:
            raise ValueError("Client is not initialized")

        results: list[MCPToolResult] = []
        for tc in tool_call:
            try:
                logger.info("Calling tool: %s", tc)
                results.append(await self.mcp_client.call_tool(tc))
            except TimeoutError as e:
                logger.error("Tool execution timed out: %s", e)
                try:
                    await self.mcp_client.shutdown()
                except Exception as close_err:
                    logger.debug("Failed to close MCP client cleanly: %s", close_err)
                raise
            except Exception as e:
                logger.error("Tool execution failed: %s", e)
                results.append(_format_error_result(str(e)))
        return results

    @abstractmethod
    async def get_system_messages(self) -> list[Any]:
        """
        Get the system prompt.
        """
        raise NotImplementedError

    @abstractmethod
    async def get_response(
        self, messages: list[Any]
    ) -> AgentResponse:  # maybe type messages as list[types.ContentBlock]
        """
        Get response from the model including any tool calls.

        NOTE: Subclasses should decorate this method with:
            @hud.instrument(span_type="agent", record_args=False, record_result=True)

        Args:
            messages: Current conversation messages

        Returns:
            AgentResponse with content, tool_calls, and done fields
        """
        raise NotImplementedError

    @abstractmethod
    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]:
        """
        Format a list of content blocks into a list of messages.
        """
        raise NotImplementedError

    @abstractmethod
    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[Any]:
        """
        Format tool results into messages for the model.

        Args:
            tool_calls: List of MCPToolCall objects that were executed
            tool_results: List of MCPToolResult objects from tool execution

        Returns:
            List of formatted messages to append to conversation
        """
        raise NotImplementedError

    async def format_message(
        self,
        message: str
        | list[str]
        | types.ContentBlock
        | list[types.ContentBlock]
        | list[str | types.ContentBlock],
    ) -> list[Any]:  # maybe type messages as list[types.ContentBlock]
        """
        Convenience function.

        Format a single content message into a list of messages for the model.
        """
        blocks: list[types.ContentBlock] = []
        if not isinstance(message, list):
            message = [message]

        for m in message:
            if isinstance(m, str):
                blocks.append(types.TextContent(text=m, type="text"))
            elif isinstance(m, types.ContentBlock):
                blocks.append(m)
            else:
                raise ValueError(f"Invalid message type: {type(m)}")

        return await self.format_blocks(blocks)

    async def _filter_tools(self) -> None:
        """Apply tool filtering based on allowed/disallowed lists."""
        # Get all tools from client
        if self.mcp_client is None:
            raise ValueError("MCP client is not initialized")

        all_tools = await self.mcp_client.list_tools()

        # Filter tools
        self._available_tools = []
        self._tool_map = {}

        for tool in all_tools:
            # Check if tool should be included
            if self.allowed_tools and tool.name not in self.allowed_tools:
                continue
            if tool.name in self.disallowed_tools:
                continue

            self._available_tools.append(tool)
            # Simplified mapping - just tool name to tool
            self._tool_map[tool.name] = tool

    async def _setup_config(self, mcp_config: dict[str, dict[str, Any]]) -> None:
        """Inject agent metadata into the initialize request metadata."""
        if self.metadata:
            patch_mcp_config(
                mcp_config,
                MCPConfigPatch(meta=self.metadata),
            )
        setup_hud_telemetry(mcp_config, auto_trace=self._auto_trace)

    def get_available_tools(self) -> list[types.Tool]:
        """Get list of available MCP tools for LLM use (excludes lifecycle tools)."""
        lifecycle_tool_names = self.lifecycle_tools
        return [tool for tool in self._available_tools if tool.name not in lifecycle_tool_names]

    def get_tool_schemas(self) -> list[dict]:
        """Get tool schemas in a format suitable for the model."""
        schemas = []
        for tool in self.get_available_tools():
            schema = {
                "name": tool.name,
                "description": tool.description,
            }
            if tool.inputSchema:
                schema["parameters"] = tool.inputSchema
            schemas.append(schema)
        return schemas

    async def _filter_messages(
        self,
        message_list: list[types.ContentBlock],
        include_types: list[
            Literal["text", "image", "audio", "resource_link", "embedded_resource"]
        ],
    ) -> list[types.ContentBlock]:
        """
        Filter a list of messages and return only the messages of the given types.

        Args:
            message_list: The list of messages to filter
            include_types: List of content block types to include

        Returns:
            Filtered list of content blocks
        """
        return [message for message in message_list if message.type in include_types]

    async def _cleanup(self) -> None:
        """Cleanup resources."""
        if self._auto_created_client and self.mcp_client:
            try:
                await self.mcp_client.shutdown()
                logger.info("Closed auto-created MCPClient")
            except Exception as e:
                logger.warning("Failed to close auto-created client: %s", e)
            finally:
                self.mcp_client = None
                self._auto_created_client = False


def _format_error_result(error_message: str) -> MCPToolResult:
    return MCPToolResult(content=text_to_blocks(error_message), isError=True)


def text_to_blocks(text: str) -> list[types.ContentBlock]:
    return [types.TextContent(text=text, type="text")]


def find_reward(result: MCPToolResult) -> float:
    """Find the reward in the result.

    Agent accepts "reward", "grade", "score"

    If not found, return 0.0
    """
    accept_keys = ["reward", "grade", "score"]
    for key in accept_keys:
        if isinstance(result.structuredContent, dict) and key in result.structuredContent:
            return result.structuredContent[key]
    if isinstance(result.content, list):
        for content in result.content:
            if isinstance(content, types.TextContent):
                try:
                    json_content = json.loads(content.text)
                    for key, value in json_content.items():
                        if key in accept_keys:
                            return value
                except json.JSONDecodeError:
                    pass
    return 0.0


def find_content(result: MCPToolResult) -> str | None:
    """Find the content in the result.

    Agent accepts "content", "text", "message", or "logs"

    If not found, return an empty string
    """
    accept_keys = ["content", "text", "message", "logs"]
    for key in accept_keys:
        if isinstance(result.structuredContent, dict) and key in result.structuredContent:
            return result.structuredContent[key]
    if isinstance(result.content, list):
        for content in result.content:
            if isinstance(content, types.TextContent):
                try:
                    json_content = json.loads(content.text)
                    for key, value in json_content.items():
                        if key in accept_keys:
                            return value
                except json.JSONDecodeError:
                    pass
    return ""
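For orientation, the sketch below shows how the abstract interface above fits together in a subclass. It is a minimal illustration under stated assumptions: the EchoAgent name, the dict-shaped provider messages, and the canned response are invented for the example and are not part of hud-python; a real subclass would call its model provider in get_response (and, per the docstring above, decorate it with @hud.instrument).

import mcp.types as types

from hud.agents.base import MCPAgent
from hud.types import AgentResponse, MCPToolCall, MCPToolResult


class EchoAgent(MCPAgent):
    """Toy subclass for illustration: never calls tools, just echoes the last user message."""

    # The base class reads self.metadata in _setup_config, so give it a value here.
    metadata: dict = {}

    async def get_system_messages(self) -> list:
        # Provider-specific message shape; plain dicts are an assumption of this sketch.
        return [{"role": "system", "content": self.system_prompt}]

    async def get_response(self, messages: list) -> AgentResponse:
        # A real agent would call its LLM here and map the reply to AgentResponse.
        last = messages[-1]["content"] if messages else ""
        return AgentResponse(content=f"Echo: {last}", tool_calls=[], done=True)

    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list:
        # Collapse text content blocks into a single user message.
        text = " ".join(b.text for b in blocks if isinstance(b, types.TextContent))
        return [{"role": "user", "content": text}]

    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list:
        # Flatten tool output back into a user-visible message.
        return [{"role": "user", "content": str(tool_results)}]

Running such a subclass still requires an AgentMCPClient (or a Task whose mcp_config lets the base class auto-create one), for example await EchoAgent(mcp_client=client).run("hello", max_steps=1).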