hud-python 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +7 -4
- hud/adapters/common/adapter.py +14 -3
- hud/adapters/common/tests/test_adapter.py +16 -4
- hud/datasets.py +188 -0
- hud/env/docker_client.py +14 -2
- hud/env/local_docker_client.py +28 -6
- hud/gym.py +0 -9
- hud/{mcp_agent → mcp}/__init__.py +2 -0
- hud/mcp/base.py +631 -0
- hud/{mcp_agent → mcp}/claude.py +52 -47
- hud/mcp/client.py +312 -0
- hud/{mcp_agent → mcp}/langchain.py +52 -33
- hud/{mcp_agent → mcp}/openai.py +56 -40
- hud/{mcp_agent → mcp}/tests/test_base.py +129 -54
- hud/mcp/tests/test_claude.py +294 -0
- hud/mcp/tests/test_client.py +324 -0
- hud/mcp/tests/test_openai.py +238 -0
- hud/settings.py +6 -0
- hud/task.py +1 -88
- hud/taskset.py +2 -23
- hud/telemetry/__init__.py +5 -0
- hud/telemetry/_trace.py +180 -17
- hud/telemetry/context.py +79 -0
- hud/telemetry/exporter.py +165 -6
- hud/telemetry/job.py +141 -0
- hud/telemetry/tests/test_trace.py +36 -25
- hud/tools/__init__.py +14 -1
- hud/tools/executors/__init__.py +19 -2
- hud/tools/executors/pyautogui.py +84 -50
- hud/tools/executors/tests/test_pyautogui_executor.py +4 -1
- hud/tools/playwright_tool.py +73 -67
- hud/tools/tests/test_edit.py +8 -1
- hud/tools/tests/test_tools.py +3 -0
- hud/trajectory.py +5 -1
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/METADATA +20 -14
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/RECORD +41 -46
- hud/evaluators/__init__.py +0 -9
- hud/evaluators/base.py +0 -32
- hud/evaluators/inspect.py +0 -24
- hud/evaluators/judge.py +0 -189
- hud/evaluators/match.py +0 -156
- hud/evaluators/remote.py +0 -65
- hud/evaluators/tests/__init__.py +0 -0
- hud/evaluators/tests/test_inspect.py +0 -12
- hud/evaluators/tests/test_judge.py +0 -231
- hud/evaluators/tests/test_match.py +0 -115
- hud/evaluators/tests/test_remote.py +0 -98
- hud/mcp_agent/base.py +0 -723
- /hud/{mcp_agent → mcp}/tests/__init__.py +0 -0
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/WHEEL +0 -0
- {hud_python-0.3.0.dist-info → hud_python-0.3.1.dist-info}/licenses/LICENSE +0 -0
hud/{mcp_agent → mcp}/claude.py
RENAMED
|
@@ -17,9 +17,15 @@ if TYPE_CHECKING:
|
|
|
17
17
|
BetaToolResultBlockParam,
|
|
18
18
|
)
|
|
19
19
|
|
|
20
|
+
from hud.datasets import TaskConfig
|
|
21
|
+
|
|
22
|
+
import mcp.types as types
|
|
23
|
+
from mcp.types import CallToolRequestParams as MCPToolCall
|
|
24
|
+
from mcp.types import CallToolResult as MCPToolResult
|
|
25
|
+
|
|
20
26
|
from hud.settings import settings
|
|
21
27
|
|
|
22
|
-
from .base import BaseMCPAgent
|
|
28
|
+
from .base import BaseMCPAgent, ModelResponse
|
|
23
29
|
|
|
24
30
|
logger = logging.getLogger(__name__)
|
|
25
31
|
|
|
@@ -66,13 +72,13 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
66
72
|
Initialize Claude MCP agent.
|
|
67
73
|
|
|
68
74
|
Args:
|
|
69
|
-
|
|
75
|
+
model_client: AsyncAnthropic client (created if not provided)
|
|
70
76
|
model: Claude model to use
|
|
71
77
|
max_tokens: Maximum tokens for response
|
|
72
78
|
display_width_px: Display width for computer use tools
|
|
73
79
|
display_height_px: Display height for computer use tools
|
|
74
80
|
use_computer_beta: Whether to use computer-use beta features
|
|
75
|
-
**kwargs: Additional arguments passed to BaseMCPAgent
|
|
81
|
+
**kwargs: Additional arguments passed to BaseMCPAgent (including mcp_client)
|
|
76
82
|
"""
|
|
77
83
|
super().__init__(**kwargs)
|
|
78
84
|
|
|
@@ -90,17 +96,19 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
90
96
|
self.display_height_px = display_height_px
|
|
91
97
|
self.use_computer_beta = use_computer_beta
|
|
92
98
|
|
|
99
|
+
self.model_name = self.model
|
|
100
|
+
|
|
93
101
|
# Track mapping from Claude tool names to MCP tool names
|
|
94
102
|
self._claude_to_mcp_tool_map: dict[str, str] = {}
|
|
95
103
|
|
|
96
|
-
async def initialize(self) -> None:
|
|
104
|
+
async def initialize(self, task: str | TaskConfig | None = None) -> None:
|
|
97
105
|
"""Initialize the agent and build tool mappings."""
|
|
98
|
-
await super().initialize()
|
|
106
|
+
await super().initialize(task)
|
|
99
107
|
# Build tool mappings after tools are discovered
|
|
100
108
|
self._convert_tools_for_claude()
|
|
101
109
|
|
|
102
110
|
async def create_initial_messages(
|
|
103
|
-
self, prompt: str, screenshot: str | None
|
|
111
|
+
self, prompt: str, screenshot: str | None = None
|
|
104
112
|
) -> list[BetaMessageParam]:
|
|
105
113
|
"""Create initial messages for Claude."""
|
|
106
114
|
user_content: list[BetaImageBlockParam | BetaTextBlockParam] = []
|
|
@@ -123,9 +131,7 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
123
131
|
)
|
|
124
132
|
]
|
|
125
133
|
|
|
126
|
-
async def get_model_response(
|
|
127
|
-
self, messages: list[BetaMessageParam], step: int
|
|
128
|
-
) -> dict[str, Any]:
|
|
134
|
+
async def get_model_response(self, messages: list[BetaMessageParam]) -> ModelResponse:
|
|
129
135
|
"""Get response from Claude including any tool calls."""
|
|
130
136
|
# Get Claude tools
|
|
131
137
|
claude_tools = self._convert_tools_for_claude()
|
|
@@ -166,7 +172,6 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
166
172
|
else:
|
|
167
173
|
raise
|
|
168
174
|
|
|
169
|
-
# Add assistant response to messages (for next step)
|
|
170
175
|
messages.append(
|
|
171
176
|
cast(
|
|
172
177
|
"BetaMessageParam",
|
|
@@ -178,12 +183,7 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
178
183
|
)
|
|
179
184
|
|
|
180
185
|
# Process response
|
|
181
|
-
result =
|
|
182
|
-
"content": "",
|
|
183
|
-
"tool_calls": [],
|
|
184
|
-
"done": True,
|
|
185
|
-
"raw_response": response.model_dump(), # For debugging
|
|
186
|
-
}
|
|
186
|
+
result = ModelResponse(content="", tool_calls=[], done=True)
|
|
187
187
|
|
|
188
188
|
# Extract text content and reasoning
|
|
189
189
|
text_content = ""
|
|
@@ -194,16 +194,16 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
194
194
|
# Map Claude tool name back to MCP tool name
|
|
195
195
|
mcp_tool_name = self._claude_to_mcp_tool_map.get(block.name, block.name)
|
|
196
196
|
|
|
197
|
-
#
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
}
|
|
197
|
+
# Create MCPToolCall object with Claude metadata as extra fields
|
|
198
|
+
# Pyright will complain but the tool class accepts extra fields
|
|
199
|
+
tool_call = MCPToolCall(
|
|
200
|
+
name=mcp_tool_name,
|
|
201
|
+
arguments=block.input,
|
|
202
|
+
tool_use_id=block.id, # type: ignore
|
|
203
|
+
claude_name=block.name, # type: ignore
|
|
205
204
|
)
|
|
206
|
-
result
|
|
205
|
+
result.tool_calls.append(tool_call)
|
|
206
|
+
result.done = False
|
|
207
207
|
elif block.type == "text":
|
|
208
208
|
text_content += block.text
|
|
209
209
|
elif hasattr(block, "type") and block.type == "thinking":
|
|
@@ -211,41 +211,44 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
211
211
|
|
|
212
212
|
# Combine text and thinking for final content
|
|
213
213
|
if thinking_content:
|
|
214
|
-
result
|
|
214
|
+
result.content = thinking_content + text_content
|
|
215
215
|
else:
|
|
216
|
-
result
|
|
216
|
+
result.content = text_content
|
|
217
217
|
|
|
218
218
|
return result
|
|
219
219
|
|
|
220
220
|
async def format_tool_results(
|
|
221
|
-
self,
|
|
221
|
+
self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
|
|
222
222
|
) -> list[BetaMessageParam]:
|
|
223
223
|
"""Format tool results into Claude messages."""
|
|
224
|
-
# Build a mapping of tool_name to tool_use_id from the original calls
|
|
225
|
-
tool_id_map = {}
|
|
226
|
-
for tool_call in tool_calls:
|
|
227
|
-
if "tool_use_id" in tool_call:
|
|
228
|
-
tool_id_map[tool_call["name"]] = tool_call["tool_use_id"]
|
|
229
|
-
|
|
230
224
|
# Process each tool result
|
|
231
225
|
user_content = []
|
|
232
226
|
|
|
233
|
-
for
|
|
234
|
-
#
|
|
235
|
-
tool_use_id =
|
|
227
|
+
for tool_call, result in zip(tool_calls, tool_results, strict=True):
|
|
228
|
+
# Extract Claude-specific metadata from extra fields
|
|
229
|
+
tool_use_id = getattr(tool_call, "tool_use_id", None)
|
|
236
230
|
if not tool_use_id:
|
|
237
|
-
logger.warning("No tool_use_id found for %s",
|
|
231
|
+
logger.warning("No tool_use_id found for %s", tool_call.name)
|
|
238
232
|
continue
|
|
239
233
|
|
|
240
|
-
# Convert
|
|
234
|
+
# Convert MCP tool results to Claude format
|
|
241
235
|
claude_blocks = []
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
236
|
+
|
|
237
|
+
if result.isError:
|
|
238
|
+
# Extract error message from content
|
|
239
|
+
error_msg = "Tool execution failed"
|
|
240
|
+
for content in result.content:
|
|
241
|
+
if isinstance(content, types.TextContent):
|
|
242
|
+
error_msg = content.text
|
|
243
|
+
break
|
|
244
|
+
claude_blocks.append(text_to_content_block(f"Error: {error_msg}"))
|
|
245
|
+
else:
|
|
246
|
+
# Process success content
|
|
247
|
+
for content in result.content:
|
|
248
|
+
if isinstance(content, types.TextContent):
|
|
249
|
+
claude_blocks.append(text_to_content_block(content.text))
|
|
250
|
+
elif isinstance(content, types.ImageContent):
|
|
251
|
+
claude_blocks.append(base64_to_content_block(content.data))
|
|
249
252
|
|
|
250
253
|
# Add tool result
|
|
251
254
|
user_content.append(tool_use_content_block(tool_use_id, claude_blocks))
|
|
@@ -282,7 +285,7 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
282
285
|
}
|
|
283
286
|
# Map Claude's "computer" back to the actual MCP tool name
|
|
284
287
|
self._claude_to_mcp_tool_map["computer"] = tool.name
|
|
285
|
-
|
|
288
|
+
elif tool.name not in self.lifecycle_tools:
|
|
286
289
|
# Convert regular tools
|
|
287
290
|
claude_tool = {
|
|
288
291
|
"name": tool.name,
|
|
@@ -295,6 +298,8 @@ class ClaudeMCPAgent(BaseMCPAgent):
|
|
|
295
298
|
}
|
|
296
299
|
# Direct mapping for non-computer tools
|
|
297
300
|
self._claude_to_mcp_tool_map[tool.name] = tool.name
|
|
301
|
+
else:
|
|
302
|
+
continue
|
|
298
303
|
|
|
299
304
|
claude_tools.append(claude_tool)
|
|
300
305
|
|
hud/mcp/client.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
"""MCP Client wrapper with automatic initialization and debugging capabilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
from mcp_use.client import MCPClient as MCPUseClient
|
|
10
|
+
from pydantic import AnyUrl
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from typing import Self
|
|
14
|
+
|
|
15
|
+
from mcp import types
|
|
16
|
+
from mcp_use.session import MCPSession as MCPUseSession
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MCPClient:
|
|
22
|
+
"""
|
|
23
|
+
High-level MCP client wrapper that handles initialization, tool discovery,
|
|
24
|
+
and provides debugging capabilities.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(
|
|
28
|
+
self,
|
|
29
|
+
mcp_config: dict[str, dict[str, Any]],
|
|
30
|
+
verbose: bool = False,
|
|
31
|
+
) -> None:
|
|
32
|
+
"""
|
|
33
|
+
Initialize the MCP client.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
mcp_config: MCP server configuration dict (required)
|
|
37
|
+
verbose: Enable verbose logging of server communications
|
|
38
|
+
auto_initialize: Whether to automatically initialize on construction
|
|
39
|
+
"""
|
|
40
|
+
self.verbose = verbose
|
|
41
|
+
|
|
42
|
+
# Initialize mcp_use client with proper config
|
|
43
|
+
# Use from_dict to properly initialize with config
|
|
44
|
+
config = {"mcpServers": mcp_config}
|
|
45
|
+
self._mcp_client = MCPUseClient.from_dict(config)
|
|
46
|
+
|
|
47
|
+
self._sessions: dict[str, MCPUseSession] = {}
|
|
48
|
+
self._available_tools: list[types.Tool] = []
|
|
49
|
+
self._tool_map: dict[str, tuple[str, types.Tool]] = {}
|
|
50
|
+
self._telemetry_data: dict[str, Any] = {}
|
|
51
|
+
|
|
52
|
+
# Set up verbose logging if requested
|
|
53
|
+
if self.verbose:
|
|
54
|
+
self._setup_verbose_logging()
|
|
55
|
+
|
|
56
|
+
def _setup_verbose_logging(self) -> None:
|
|
57
|
+
"""Configure verbose logging for debugging."""
|
|
58
|
+
# Set MCP-related loggers to DEBUG
|
|
59
|
+
logging.getLogger("mcp").setLevel(logging.DEBUG)
|
|
60
|
+
logging.getLogger("mcp_use").setLevel(logging.DEBUG)
|
|
61
|
+
logging.getLogger("mcp.client.stdio").setLevel(logging.DEBUG)
|
|
62
|
+
|
|
63
|
+
# Add handler for server communications
|
|
64
|
+
if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
|
|
65
|
+
handler = logging.StreamHandler()
|
|
66
|
+
handler.setFormatter(
|
|
67
|
+
logging.Formatter("[%(levelname)s] %(asctime)s - %(name)s - %(message)s")
|
|
68
|
+
)
|
|
69
|
+
logger.addHandler(handler)
|
|
70
|
+
logger.setLevel(logging.DEBUG)
|
|
71
|
+
|
|
72
|
+
async def initialize(self) -> None:
|
|
73
|
+
"""Perform async initialization tasks."""
|
|
74
|
+
await self.create_sessions()
|
|
75
|
+
await self.discover_tools()
|
|
76
|
+
await self.fetch_telemetry()
|
|
77
|
+
|
|
78
|
+
async def create_sessions(self) -> dict[str, MCPUseSession]:
|
|
79
|
+
# Create all sessions at once
|
|
80
|
+
try:
|
|
81
|
+
self._sessions = await self._mcp_client.create_all_sessions()
|
|
82
|
+
except Exception as e:
|
|
83
|
+
# If session creation fails, try to get Docker logs
|
|
84
|
+
logger.error("Failed to create sessions: %s", e)
|
|
85
|
+
if self.verbose:
|
|
86
|
+
logger.info("Attempting to check Docker container status...")
|
|
87
|
+
# await self._check_docker_containers()
|
|
88
|
+
raise
|
|
89
|
+
|
|
90
|
+
# Log session details in verbose mode
|
|
91
|
+
if self.verbose and self._sessions:
|
|
92
|
+
for name, session in self._sessions.items():
|
|
93
|
+
logger.debug(" - %s: %s", name, type(session).__name__)
|
|
94
|
+
|
|
95
|
+
return self._sessions
|
|
96
|
+
|
|
97
|
+
async def discover_tools(self) -> list[types.Tool]:
|
|
98
|
+
"""Discover all available tools from connected servers."""
|
|
99
|
+
logger.info("Discovering available tools...")
|
|
100
|
+
|
|
101
|
+
self._available_tools = []
|
|
102
|
+
self._tool_map = {}
|
|
103
|
+
|
|
104
|
+
for server_name, session in self._sessions.items():
|
|
105
|
+
try:
|
|
106
|
+
# Ensure session is initialized
|
|
107
|
+
if not hasattr(session, "connector") or not hasattr(
|
|
108
|
+
session.connector, "client_session"
|
|
109
|
+
):
|
|
110
|
+
await session.initialize()
|
|
111
|
+
|
|
112
|
+
if session.connector.client_session is None:
|
|
113
|
+
logger.warning("Client session not initialized for %s", server_name)
|
|
114
|
+
continue
|
|
115
|
+
|
|
116
|
+
# List tools
|
|
117
|
+
tools_result = await session.connector.client_session.list_tools()
|
|
118
|
+
|
|
119
|
+
logger.info(
|
|
120
|
+
"Discovered %d tools from '%s': %s",
|
|
121
|
+
len(tools_result.tools),
|
|
122
|
+
server_name,
|
|
123
|
+
[tool.name for tool in tools_result.tools],
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
# Add to collections
|
|
127
|
+
for tool in tools_result.tools:
|
|
128
|
+
self._available_tools.append(tool)
|
|
129
|
+
self._tool_map[tool.name] = (server_name, tool)
|
|
130
|
+
|
|
131
|
+
# Log detailed tool info in verbose mode
|
|
132
|
+
if self.verbose:
|
|
133
|
+
for tool in tools_result.tools:
|
|
134
|
+
description = tool.description or ""
|
|
135
|
+
logger.debug(
|
|
136
|
+
" Tool '%s': %s",
|
|
137
|
+
tool.name,
|
|
138
|
+
description[:100] + "..." if len(description) > 100 else description,
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
except Exception as e:
|
|
142
|
+
logger.error("Error discovering tools from '%s': %s", server_name, e)
|
|
143
|
+
if self.verbose:
|
|
144
|
+
logger.exception("Full error details:")
|
|
145
|
+
|
|
146
|
+
logger.info("Total tools discovered: %d", len(self._available_tools))
|
|
147
|
+
return self._available_tools
|
|
148
|
+
|
|
149
|
+
async def fetch_telemetry(self) -> dict[str, Any]:
|
|
150
|
+
"""Fetch telemetry resource from all servers that provide it."""
|
|
151
|
+
logger.info("Fetching telemetry resources...")
|
|
152
|
+
|
|
153
|
+
for server_name, session in self._sessions.items():
|
|
154
|
+
try:
|
|
155
|
+
if not hasattr(session, "connector") or not hasattr(
|
|
156
|
+
session.connector, "client_session"
|
|
157
|
+
):
|
|
158
|
+
continue
|
|
159
|
+
|
|
160
|
+
if session.connector.client_session is None:
|
|
161
|
+
continue
|
|
162
|
+
|
|
163
|
+
# Try to read telemetry resource
|
|
164
|
+
try:
|
|
165
|
+
result = await session.connector.client_session.read_resource(
|
|
166
|
+
AnyUrl("telemetry://live")
|
|
167
|
+
)
|
|
168
|
+
if result and result.contents and len(result.contents) > 0:
|
|
169
|
+
telemetry_data = json.loads(result.contents[0].text) # type: ignore
|
|
170
|
+
self._telemetry_data[server_name] = telemetry_data
|
|
171
|
+
|
|
172
|
+
logger.info("📡 Telemetry data from server '%s':", server_name)
|
|
173
|
+
if "live_url" in telemetry_data:
|
|
174
|
+
logger.info(" 🖥️ Live URL: %s", telemetry_data["live_url"])
|
|
175
|
+
if "status" in telemetry_data:
|
|
176
|
+
logger.info(" 📊 Status: %s", telemetry_data["status"])
|
|
177
|
+
if "services" in telemetry_data:
|
|
178
|
+
logger.info(" 📋 Services:")
|
|
179
|
+
for service, status in telemetry_data["services"].items():
|
|
180
|
+
status_icon = "✅" if status == "running" else "❌"
|
|
181
|
+
logger.info(" %s %s: %s", status_icon, service, status)
|
|
182
|
+
|
|
183
|
+
if self.verbose:
|
|
184
|
+
logger.debug(
|
|
185
|
+
"Full telemetry data:\n%s", json.dumps(telemetry_data, indent=2)
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
except Exception as e:
|
|
189
|
+
# Resource might not exist, which is fine
|
|
190
|
+
if self.verbose:
|
|
191
|
+
logger.debug("No telemetry resource from '%s': %s", server_name, e)
|
|
192
|
+
|
|
193
|
+
except Exception as e:
|
|
194
|
+
logger.error("Error fetching telemetry from '%s': %s", server_name, e)
|
|
195
|
+
|
|
196
|
+
return self._telemetry_data
|
|
197
|
+
|
|
198
|
+
async def call_tool(
|
|
199
|
+
self, tool_name: str, arguments: dict[str, Any] | None = None
|
|
200
|
+
) -> types.CallToolResult:
|
|
201
|
+
"""
|
|
202
|
+
Call a tool by name with the given arguments.
|
|
203
|
+
|
|
204
|
+
Args:
|
|
205
|
+
tool_name: Name of the tool to call
|
|
206
|
+
arguments: Tool arguments
|
|
207
|
+
|
|
208
|
+
Returns:
|
|
209
|
+
Tool execution result
|
|
210
|
+
|
|
211
|
+
Raises:
|
|
212
|
+
ValueError: If tool not found
|
|
213
|
+
"""
|
|
214
|
+
if tool_name not in self._tool_map:
|
|
215
|
+
raise ValueError(f"Tool '{tool_name}' not found")
|
|
216
|
+
|
|
217
|
+
server_name, tool = self._tool_map[tool_name]
|
|
218
|
+
session = self._sessions[server_name]
|
|
219
|
+
|
|
220
|
+
if self.verbose:
|
|
221
|
+
logger.debug(
|
|
222
|
+
"Calling tool '%s' on server '%s' with arguments: %s",
|
|
223
|
+
tool_name,
|
|
224
|
+
server_name,
|
|
225
|
+
json.dumps(arguments, indent=2) if arguments else "None",
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
if session.connector.client_session is None:
|
|
229
|
+
raise ValueError(f"Client session not initialized for {server_name}")
|
|
230
|
+
|
|
231
|
+
result = await session.connector.client_session.call_tool(
|
|
232
|
+
name=tool_name, arguments=arguments or {}
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
if self.verbose:
|
|
236
|
+
logger.debug("Tool '%s' result: %s", tool_name, result)
|
|
237
|
+
|
|
238
|
+
return result
|
|
239
|
+
|
|
240
|
+
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult | None:
|
|
241
|
+
"""
|
|
242
|
+
Read a resource by URI from any server that provides it.
|
|
243
|
+
|
|
244
|
+
Args:
|
|
245
|
+
uri: Resource URI (e.g., "telemetry://live")
|
|
246
|
+
|
|
247
|
+
Returns:
|
|
248
|
+
Resource contents or None if not found
|
|
249
|
+
"""
|
|
250
|
+
for server_name, session in self._sessions.items():
|
|
251
|
+
try:
|
|
252
|
+
if not hasattr(session, "connector") or not hasattr(
|
|
253
|
+
session.connector, "client_session"
|
|
254
|
+
):
|
|
255
|
+
continue
|
|
256
|
+
|
|
257
|
+
if session.connector.client_session is None:
|
|
258
|
+
continue
|
|
259
|
+
|
|
260
|
+
result = await session.connector.client_session.read_resource(uri)
|
|
261
|
+
|
|
262
|
+
if self.verbose:
|
|
263
|
+
logger.debug(
|
|
264
|
+
"Successfully read resource '%s' from server '%s'", uri, server_name
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
return result
|
|
268
|
+
|
|
269
|
+
except Exception as e:
|
|
270
|
+
if self.verbose:
|
|
271
|
+
logger.debug(
|
|
272
|
+
"Could not read resource '%s' from server '%s': %s", uri, server_name, e
|
|
273
|
+
)
|
|
274
|
+
continue
|
|
275
|
+
|
|
276
|
+
return None
|
|
277
|
+
|
|
278
|
+
def get_available_tools(self) -> list[types.Tool]:
|
|
279
|
+
"""Get list of all available tools."""
|
|
280
|
+
return self._available_tools
|
|
281
|
+
|
|
282
|
+
def get_tool_map(self) -> dict[str, tuple[str, types.Tool]]:
|
|
283
|
+
"""Get mapping of tool names to (server_name, tool) tuples."""
|
|
284
|
+
return self._tool_map
|
|
285
|
+
|
|
286
|
+
def get_sessions(self) -> dict[str, MCPUseSession]:
|
|
287
|
+
"""Get active MCP sessions."""
|
|
288
|
+
return self._sessions
|
|
289
|
+
|
|
290
|
+
def get_telemetry_data(self) -> dict[str, Any]:
|
|
291
|
+
"""Get collected telemetry data from all servers."""
|
|
292
|
+
return self._telemetry_data
|
|
293
|
+
|
|
294
|
+
def get_all_active_sessions(self) -> dict[str, MCPUseSession]:
|
|
295
|
+
"""Get all active sessions (compatibility method)."""
|
|
296
|
+
return self._sessions
|
|
297
|
+
|
|
298
|
+
async def close(self) -> None:
|
|
299
|
+
"""Close all active sessions."""
|
|
300
|
+
await self._mcp_client.close_all_sessions()
|
|
301
|
+
|
|
302
|
+
self._sessions = {}
|
|
303
|
+
self._available_tools = []
|
|
304
|
+
self._tool_map = {}
|
|
305
|
+
|
|
306
|
+
async def __aenter__(self) -> Self:
|
|
307
|
+
"""Async context manager entry."""
|
|
308
|
+
return self
|
|
309
|
+
|
|
310
|
+
async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
|
|
311
|
+
"""Async context manager exit."""
|
|
312
|
+
await self.close()
|
|
@@ -5,15 +5,18 @@ from __future__ import annotations
|
|
|
5
5
|
import logging
|
|
6
6
|
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
|
+
import mcp.types as types
|
|
8
9
|
from langchain.agents import AgentExecutor, create_tool_calling_agent
|
|
9
10
|
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
10
11
|
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
12
|
+
from mcp.types import CallToolRequestParams as MCPToolCall
|
|
13
|
+
from mcp.types import CallToolResult as MCPToolResult
|
|
11
14
|
from mcp_use.adapters.langchain_adapter import LangChainAdapter
|
|
12
15
|
|
|
13
16
|
if TYPE_CHECKING:
|
|
14
17
|
from langchain.schema.language_model import BaseLanguageModel
|
|
15
18
|
from langchain_core.tools import BaseTool
|
|
16
|
-
from .base import BaseMCPAgent
|
|
19
|
+
from .base import BaseMCPAgent, ModelResponse
|
|
17
20
|
|
|
18
21
|
logger = logging.getLogger(__name__)
|
|
19
22
|
|
|
@@ -44,6 +47,12 @@ class LangChainMCPAgent(BaseMCPAgent):
|
|
|
44
47
|
self.adapter = LangChainAdapter(disallowed_tools=self.disallowed_tools)
|
|
45
48
|
self._langchain_tools: list[BaseTool] | None = None
|
|
46
49
|
|
|
50
|
+
self.model_name = (
|
|
51
|
+
"langchain-" + self.llm.model_name # type: ignore
|
|
52
|
+
if hasattr(self.llm, "model_name")
|
|
53
|
+
else "unknown"
|
|
54
|
+
)
|
|
55
|
+
|
|
47
56
|
def _get_langchain_tools(self) -> list[BaseTool]:
|
|
48
57
|
"""Get or create LangChain tools from MCP tools."""
|
|
49
58
|
if self._langchain_tools is not None:
|
|
@@ -86,7 +95,7 @@ class LangChainMCPAgent(BaseMCPAgent):
|
|
|
86
95
|
|
|
87
96
|
return messages
|
|
88
97
|
|
|
89
|
-
async def get_model_response(self, messages: list[BaseMessage]
|
|
98
|
+
async def get_model_response(self, messages: list[BaseMessage]) -> ModelResponse:
|
|
90
99
|
"""Get response from LangChain model including any tool calls."""
|
|
91
100
|
# Get LangChain tools (created lazily)
|
|
92
101
|
langchain_tools = self._get_langchain_tools()
|
|
@@ -133,11 +142,7 @@ class LangChainMCPAgent(BaseMCPAgent):
|
|
|
133
142
|
break
|
|
134
143
|
|
|
135
144
|
if not last_user_msg:
|
|
136
|
-
return
|
|
137
|
-
"content": "No user message found",
|
|
138
|
-
"tool_calls": [],
|
|
139
|
-
"done": True,
|
|
140
|
-
}
|
|
145
|
+
return ModelResponse(content="No user message found", tool_calls=[], done=True)
|
|
141
146
|
|
|
142
147
|
# Extract text from message content
|
|
143
148
|
input_text = ""
|
|
@@ -175,54 +180,68 @@ class LangChainMCPAgent(BaseMCPAgent):
|
|
|
175
180
|
for action, _ in result["intermediate_steps"]:
|
|
176
181
|
if hasattr(action, "tool") and hasattr(action, "tool_input"):
|
|
177
182
|
tool_calls.append(
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
183
|
+
MCPToolCall(
|
|
184
|
+
name=action.tool,
|
|
185
|
+
arguments=action.tool_input,
|
|
186
|
+
)
|
|
182
187
|
)
|
|
183
188
|
|
|
184
|
-
return
|
|
185
|
-
"content": output,
|
|
186
|
-
"tool_calls": tool_calls,
|
|
187
|
-
"done": False, # Continue if tools were called
|
|
188
|
-
}
|
|
189
|
+
return ModelResponse(content=output, tool_calls=tool_calls, done=False)
|
|
189
190
|
else:
|
|
190
191
|
# No tools called, just text response
|
|
191
|
-
return
|
|
192
|
-
"content": output,
|
|
193
|
-
"tool_calls": [],
|
|
194
|
-
"done": True,
|
|
195
|
-
}
|
|
192
|
+
return ModelResponse(content=output, tool_calls=[], done=True)
|
|
196
193
|
|
|
197
194
|
except Exception as e:
|
|
198
195
|
logger.error("Agent execution failed: %s", e)
|
|
199
|
-
return {
|
|
200
|
-
"content": f"Error: {e!s}",
|
|
201
|
-
"tool_calls": [],
|
|
202
|
-
"done": True,
|
|
203
|
-
}
|
|
196
|
+
return ModelResponse(content=f"Error: {e!s}", tool_calls=[], done=True)
|
|
204
197
|
|
|
205
198
|
async def format_tool_results(
|
|
206
|
-
self,
|
|
199
|
+
self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
|
|
207
200
|
) -> list[BaseMessage]:
|
|
208
201
|
"""Format tool results into LangChain messages."""
|
|
209
202
|
# Create an AI message with the tool calls and results
|
|
210
203
|
messages = []
|
|
211
204
|
|
|
212
205
|
# First add an AI message indicating tools were called
|
|
213
|
-
tool_names = [tc
|
|
206
|
+
tool_names = [tc.name for tc in tool_calls]
|
|
214
207
|
ai_content = f"I'll use the following tools: {', '.join(tool_names)}"
|
|
215
208
|
messages.append(AIMessage(content=ai_content))
|
|
216
209
|
|
|
217
|
-
#
|
|
218
|
-
|
|
219
|
-
|
|
210
|
+
# Build result text from tool results
|
|
211
|
+
text_parts = []
|
|
212
|
+
latest_screenshot = None
|
|
213
|
+
|
|
214
|
+
for tool_call, result in zip(tool_calls, tool_results, strict=False):
|
|
215
|
+
if result.isError:
|
|
216
|
+
error_text = "Tool execution failed"
|
|
217
|
+
for content in result.content:
|
|
218
|
+
if isinstance(content, types.TextContent):
|
|
219
|
+
error_text = content.text
|
|
220
|
+
break
|
|
221
|
+
text_parts.append(f"Error - {tool_call.name}: {error_text}")
|
|
222
|
+
else:
|
|
223
|
+
# Process success content
|
|
224
|
+
tool_output = []
|
|
225
|
+
for content in result.content:
|
|
226
|
+
if isinstance(content, types.TextContent):
|
|
227
|
+
tool_output.append(content.text)
|
|
228
|
+
elif isinstance(content, types.ImageContent):
|
|
229
|
+
latest_screenshot = content.data
|
|
220
230
|
|
|
221
|
-
|
|
231
|
+
if tool_output:
|
|
232
|
+
text_parts.append(f"{tool_call.name}: " + " ".join(tool_output))
|
|
233
|
+
|
|
234
|
+
result_text = "\n".join(text_parts) if text_parts else "No output from tools"
|
|
235
|
+
|
|
236
|
+
# Then add a human message with the tool results
|
|
237
|
+
if latest_screenshot:
|
|
222
238
|
# Include screenshot in multimodal format
|
|
223
239
|
content = [
|
|
224
240
|
{"type": "text", "text": f"Tool results:\n{result_text}"},
|
|
225
|
-
{
|
|
241
|
+
{
|
|
242
|
+
"type": "image_url",
|
|
243
|
+
"image_url": {"url": f"data:image/png;base64,{latest_screenshot}"},
|
|
244
|
+
},
|
|
226
245
|
]
|
|
227
246
|
messages.append(HumanMessage(content=content))
|
|
228
247
|
else:
|