hud_python-0.3.0-py3-none-any.whl → hud_python-0.3.2-py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Potentially problematic release.
This version of hud-python might be problematic.
- hud/__init__.py +7 -4
- hud/adapters/common/adapter.py +14 -3
- hud/adapters/common/tests/test_adapter.py +16 -4
- hud/datasets.py +188 -0
- hud/env/docker_client.py +14 -2
- hud/env/local_docker_client.py +28 -6
- hud/gym.py +0 -9
- hud/{mcp_agent → mcp}/__init__.py +2 -0
- hud/mcp/base.py +631 -0
- hud/{mcp_agent → mcp}/claude.py +52 -47
- hud/mcp/client.py +312 -0
- hud/{mcp_agent → mcp}/langchain.py +52 -33
- hud/{mcp_agent → mcp}/openai.py +56 -40
- hud/{mcp_agent → mcp}/tests/test_base.py +129 -54
- hud/mcp/tests/test_claude.py +294 -0
- hud/mcp/tests/test_client.py +324 -0
- hud/mcp/tests/test_openai.py +238 -0
- hud/settings.py +6 -0
- hud/task.py +2 -88
- hud/taskset.py +2 -23
- hud/telemetry/__init__.py +5 -0
- hud/telemetry/_trace.py +180 -17
- hud/telemetry/context.py +79 -0
- hud/telemetry/exporter.py +165 -6
- hud/telemetry/job.py +141 -0
- hud/telemetry/tests/test_trace.py +36 -25
- hud/tools/__init__.py +14 -1
- hud/tools/computer/hud.py +13 -0
- hud/tools/executors/__init__.py +19 -2
- hud/tools/executors/pyautogui.py +84 -50
- hud/tools/executors/tests/test_pyautogui_executor.py +4 -1
- hud/tools/playwright_tool.py +73 -67
- hud/tools/tests/test_edit.py +8 -1
- hud/tools/tests/test_tools.py +3 -0
- hud/trajectory.py +5 -1
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/METADATA +20 -14
- {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/RECORD +42 -47
- hud/evaluators/__init__.py +0 -9
- hud/evaluators/base.py +0 -32
- hud/evaluators/inspect.py +0 -24
- hud/evaluators/judge.py +0 -189
- hud/evaluators/match.py +0 -156
- hud/evaluators/remote.py +0 -65
- hud/evaluators/tests/__init__.py +0 -0
- hud/evaluators/tests/test_inspect.py +0 -12
- hud/evaluators/tests/test_judge.py +0 -231
- hud/evaluators/tests/test_match.py +0 -115
- hud/evaluators/tests/test_remote.py +0 -98
- hud/mcp_agent/base.py +0 -723
- /hud/{mcp_agent → mcp}/tests/__init__.py +0 -0
- {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/WHEEL +0 -0
- {hud_python-0.3.0.dist-info → hud_python-0.3.2.dist-info}/licenses/LICENSE +0 -0
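
The headline change is the rename of the hud.mcp_agent package to hud.mcp (with a new client.py and a rewritten base.py) plus the removal of the hud.evaluators package. Below is a minimal import-migration sketch, assuming only the module paths and the OpenAIMCPAgent class name visible in this diff; the try/except fallback is illustrative:

# Sketch: adapt imports across the 0.3.0 -> 0.3.2 rename of hud.mcp_agent to hud.mcp.
try:
    from hud.mcp.openai import OpenAIMCPAgent  # 0.3.2 layout
except ImportError:  # running against 0.3.0
    from hud.mcp_agent.openai import OpenAIMCPAgent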
hud/{mcp_agent → mcp}/openai.py RENAMED
@@ -3,8 +3,11 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
+import mcp.types as types
+from mcp.types import CallToolRequestParams as MCPToolCall
+from mcp.types import CallToolResult as MCPToolResult
 from openai import AsyncOpenAI
 from openai.types.responses import (
     ResponseComputerToolCall,
@@ -16,7 +19,10 @@ from openai.types.responses import (
 
 from hud.settings import settings
 
-from .base import BaseMCPAgent
+from .base import AgentResult, BaseMCPAgent, ModelResponse
+
+if TYPE_CHECKING:
+    from hud.datasets import TaskConfig
 
 logger = logging.getLogger(__name__)
 
@@ -69,6 +75,8 @@ class OpenAIMCPAgent(BaseMCPAgent):
         self.pending_call_id: str | None = None
         self.pending_safety_checks: list[Any] = []
 
+        self.model_name = "openai-" + self.model
+
         # Base system prompt for autonomous operation
         self.base_system_prompt = """
 You are an autonomous computer-using agent. Follow these guidelines:
@@ -84,11 +92,9 @@ class OpenAIMCPAgent(BaseMCPAgent):
 Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
 """  # noqa: E501
 
-    async def run(
-        self, prompt: str, max_steps: int = 10, conversation_mode: bool = False
-    ) -> dict[str, Any]:
+    async def run(self, prompt_or_task: str | TaskConfig, max_steps: int = 10) -> AgentResult:
         """
-        Run the agent with the given prompt.
+        Run the agent with the given prompt or task.
 
         Override to reset OpenAI-specific state.
         """
@@ -98,9 +104,11 @@ class OpenAIMCPAgent(BaseMCPAgent):
         self.pending_safety_checks = []
 
         # Use base implementation
-        return await super().run(
+        return await super().run(prompt_or_task, max_steps)
 
-    async def create_initial_messages(
+    async def create_initial_messages(
+        self, prompt: str, screenshot: str | None = None
+    ) -> list[Any]:
         """
         Create initial messages for OpenAI.
 
@@ -111,7 +119,7 @@ class OpenAIMCPAgent(BaseMCPAgent):
         # Just return a list with the prompt and screenshot
         return [{"prompt": prompt, "screenshot": screenshot}]
 
-    async def get_model_response(self, messages: list[Any]
+    async def get_model_response(self, messages: list[Any]) -> ModelResponse:
         """Get response from OpenAI including any tool calls."""
         # OpenAI's API is stateful, so we handle messages differently
 
@@ -124,11 +132,11 @@ class OpenAIMCPAgent(BaseMCPAgent):
 
         if not computer_tool_name:
             # No computer tools available, just return a text response
-            return
-
-
-
-
+            return ModelResponse(
+                content="No computer use tools available",
+                tool_calls=[],
+                done=True,
+            )
 
         # Define the computer use tool
         computer_tool: ToolParam = {  # type: ignore[reportAssignmentType]
@@ -193,11 +201,11 @@ class OpenAIMCPAgent(BaseMCPAgent):
 
         if not latest_screenshot:
             logger.warning("No screenshot provided for response to action")
-            return
-
-
-
-
+            return ModelResponse(
+                content="No screenshot available for next action",
+                tool_calls=[],
+                done=True,
+            )
 
         # Create response to previous action
         input_param_followup: ResponseInputParam = [  # type: ignore[reportAssignmentType]
@@ -226,12 +234,11 @@ class OpenAIMCPAgent(BaseMCPAgent):
         self.last_response_id = response.id
 
         # Process response
-        result =
-
-
-
-
-        }
+        result = ModelResponse(
+            content="",
+            tool_calls=[],
+            done=False,  # Will be set to True only if no tool calls
+        )
 
         self.pending_call_id = None
 
@@ -244,7 +251,7 @@ class OpenAIMCPAgent(BaseMCPAgent):
 
         if computer_calls:
             # Process computer calls
-            result
+            result.done = False
             for computer_call in computer_calls:
                 self.pending_call_id = computer_call.call_id
                 self.pending_safety_checks = computer_call.pending_safety_checks
@@ -252,13 +259,15 @@ class OpenAIMCPAgent(BaseMCPAgent):
                 # Convert OpenAI action to MCP tool call
                 action = computer_call.action.model_dump()
 
-                #
-
-
-
-
-
+                # Create MCPToolCall object with OpenAI metadata as extra fields
+                # Pyright will complain but the tool class accepts extra fields
+                tool_call = MCPToolCall(
+                    name=computer_tool_name,
+                    arguments=action,
+                    call_id=computer_call.call_id,  # type: ignore
+                    pending_safety_checks=computer_call.pending_safety_checks,  # type: ignore
+                )
+                result.tool_calls.append(tool_call)
         else:
             # No computer calls, check for text response
             for item in response.output:
@@ -270,7 +279,7 @@ class OpenAIMCPAgent(BaseMCPAgent):
                     if isinstance(content, ResponseOutputText)
                 ]
                 if text_parts:
-                    result
+                    result.content = "".join(text_parts)
                 break
 
         # Extract reasoning if present
@@ -280,16 +289,16 @@ class OpenAIMCPAgent(BaseMCPAgent):
                 reasoning_text += f"Thinking: {item.summary[0].text}\n"
 
         if reasoning_text:
-            result
+            result.content = reasoning_text + result.content if result.content else reasoning_text
 
         # Set done=True if no tool calls (task complete or waiting for user)
-        if not result
-            result
+        if not result.tool_calls:
+            result.done = True
 
         return result
 
     async def format_tool_results(
-        self,
+        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
     ) -> list[Any]:
         """
         Format tool results for OpenAI's stateful API.
@@ -297,12 +306,19 @@ class OpenAIMCPAgent(BaseMCPAgent):
         OpenAI doesn't use a traditional message format - we just need to
         preserve the screenshot for the next step.
         """
-        #
+        # Extract latest screenshot from results
+        latest_screenshot = None
+        for result in tool_results:
+            if not result.isError:
+                for content in result.content:
+                    if isinstance(content, types.ImageContent):
+                        latest_screenshot = content.data
+
         # Return a simple dict that get_model_response can use
         return [
             {
                 "type": "tool_result",
-                "screenshot":
+                "screenshot": latest_screenshot,
             }
         ]
 
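The recurring change in this file is the replacement of loosely typed dict returns with a structured ModelResponse. The real class lives in hud.mcp.base and is not shown in this diff; the sketch below is a stand-in built only from the three fields the hunks above actually use:

# Stand-in for the ModelResponse shape implied by the diff above; the real
# class in hud.mcp.base may define more fields and different defaults.
from dataclasses import dataclass, field
from typing import Any

@dataclass
class ModelResponseSketch:
    content: str = ""                                     # model text / reasoning
    tool_calls: list[Any] = field(default_factory=list)   # MCPToolCall objects
    done: bool = False                                    # no further tool calls

# Mirrors the tail of get_model_response above:
result = ModelResponseSketch()
if not result.tool_calls:
    result.done = True  # task complete or waiting for user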
hud/{mcp_agent → mcp}/tests/test_base.py RENAMED
@@ -5,10 +5,18 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Any
 from unittest.mock import MagicMock
 
+# Import AsyncMock from unittest.mock if available (Python 3.8+)
+try:
+    from unittest.mock import AsyncMock
+except ImportError:
+    # Fallback for older Python versions
+    from unittest.mock import MagicMock as AsyncMock
+
 import pytest
 from mcp import types
+from mcp.types import CallToolRequestParams as MCPToolCall
 
-from hud.
+from hud.mcp.base import BaseMCPAgent
 from hud.tools.executors.base import BaseExecutor
 
 if TYPE_CHECKING:
@@ -18,8 +26,13 @@ if TYPE_CHECKING:
 class MockMCPAgent(BaseMCPAgent):
     """Concrete implementation of BaseMCPAgent for testing."""
 
-    def __init__(self, **kwargs: Any) -> None:
-
+    def __init__(self, mcp_client: Any = None, **kwargs: Any) -> None:
+        if mcp_client is None:
+            # Create a mock client if none provided
+            mcp_client = MagicMock()
+            mcp_client.get_all_active_sessions = MagicMock(return_value={})
+            mcp_client.get_available_tools = MagicMock(return_value=[])
+        super().__init__(mcp_client=mcp_client, **kwargs)
         self.executor = BaseExecutor()  # Use simulated executor
         self._messages = []
 
@@ -66,46 +79,58 @@ class TestBaseMCPAgent:
         """Test initialization with default values."""
         agent = MockMCPAgent()
 
-        assert agent.
+        assert agent.mcp_client is not None
         assert agent.allowed_tools is None
         assert agent.disallowed_tools == []
         assert agent.initial_screenshot is False
         assert agent.max_screenshot_history == 3
         assert agent.append_tool_system_prompt is True
         assert agent.custom_system_prompt is None
-        assert agent.lifecycle_tools ==
+        assert agent.lifecycle_tools == []
 
     def test_init_with_params(self):
         """Test initialization with custom parameters."""
         client = MagicMock()
         agent = MockMCPAgent(
-
+            mcp_client=client,
             allowed_tools=["tool1", "tool2"],
             disallowed_tools=["bad_tool"],
             initial_screenshot=True,
             max_screenshot_history=5,
             append_tool_system_prompt=False,
             custom_system_prompt="Custom prompt",
-            lifecycle_tools=
+            lifecycle_tools=["custom_setup", "custom_eval"],
         )
 
-        assert agent.
+        assert agent.mcp_client == client
         assert agent.allowed_tools == ["tool1", "tool2"]
         assert agent.disallowed_tools == ["bad_tool"]
         assert agent.initial_screenshot is True
        assert agent.max_screenshot_history == 5
         assert agent.append_tool_system_prompt is False
         assert agent.custom_system_prompt == "Custom prompt"
-        assert agent.lifecycle_tools ==
+        assert agent.lifecycle_tools == ["custom_setup", "custom_eval"]
 
-
-
-
-
-
+    def test_init_no_client(self):
+        """Test init fails without client."""
+
+        # Create a minimal concrete implementation to test the ValueError
+        class TestAgent(BaseMCPAgent):
+            def create_initial_messages(
+                self, prompt: str, screenshot: str | None = None
+            ) -> list[dict[str, Any]]:
+                return []
 
-
-
+            def format_tool_results(
+                self, results: list[tuple[str, Any]], screenshot: str | None = None
+            ) -> list[dict[str, Any]]:
+                return []
+
+            async def get_model_response(self, messages: list[dict[str, Any]]) -> dict[str, Any]:
+                return {"content": "test"}
+
+        with pytest.raises(ValueError, match="MCPClient is required"):
+            TestAgent(mcp_client=None)
 
     @pytest.mark.asyncio
     async def test_initialize_with_sessions(self):
@@ -133,14 +158,31 @@ class TestBaseMCPAgent:
 
         mock_session.connector.client_session.list_tools = mock_list_tools
 
-        assert agent.
-        agent.
+        assert agent.mcp_client is not None
+        agent.mcp_client.get_all_active_sessions = MagicMock(return_value={"server1": mock_session})
+
+        # Mock get_tool_map to return tools discovered from sessions
+        tool_map = {
+            "tool1": (
+                "server1",
+                types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}),
+            ),
+            "tool2": (
+                "server1",
+                types.Tool(name="tool2", description="Tool 2", inputSchema={"type": "object"}),
+            ),
+            "setup": (
+                "server1",
+                types.Tool(name="setup", description="Setup tool", inputSchema={"type": "object"}),
+            ),
+        }
+        agent.mcp_client.get_tool_map = MagicMock(return_value=tool_map)
 
         await agent.initialize()
 
         # Check available tools were populated (excludes lifecycle tools)
         tools = agent.get_available_tools()
-        assert len(tools) ==
+        assert len(tools) == 3  # All tools (setup is not in default lifecycle tools)
 
         # Check tool map was populated (includes all tools)
         tool_map = agent.get_tool_map()
@@ -173,15 +215,36 @@ class TestBaseMCPAgent:
 
         mock_session.connector.client_session.list_tools = mock_list_tools
 
-        assert agent.
-        agent.
+        assert agent.mcp_client is not None
+        agent.mcp_client.get_all_active_sessions = MagicMock(return_value={"server1": mock_session})
+
+        # Mock get_tool_map to return tools discovered from sessions
+        tool_map = {
+            "tool1": (
+                "server1",
+                types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}),
+            ),
+            "tool2": (
+                "server1",
+                types.Tool(name="tool2", description="Tool 2", inputSchema={"type": "object"}),
+            ),
+            "tool3": (
+                "server1",
+                types.Tool(name="tool3", description="Tool 3", inputSchema={"type": "object"}),
+            ),
+            "setup": (
+                "server1",
+                types.Tool(name="setup", description="Setup", inputSchema={"type": "object"}),
+            ),
+        }
+        agent.mcp_client.get_tool_map = MagicMock(return_value=tool_map)
 
         await agent.initialize()
 
         # Check filtering worked - get_available_tools excludes lifecycle tools
         tools = agent.get_available_tools()
         tool_names = [t.name for t in tools]
-        assert len(tools) == 1  # Only tool1 (
+        assert len(tools) == 1  # Only tool1 (tool2 and tool3 are filtered out)
         assert "tool1" in tool_names
         assert "setup" not in tool_names  # Lifecycle tool excluded from available tools
         assert "tool2" not in tool_names  # Not in allowed list
@@ -216,14 +279,26 @@ class TestBaseMCPAgent:
 
         mock_session.connector.client_session.call_tool = mock_call_tool
 
-        assert agent.
-        agent.
-
+        assert agent.mcp_client is not None
+        agent.mcp_client.get_all_active_sessions = MagicMock(return_value={"server1": mock_session})
+
+        # Mock get_tool_map to return tools discovered from sessions
+        tool_map = {
+            "test_tool": (
+                "server1",
+                types.Tool(name="test_tool", description="Test", inputSchema={"type": "object"}),
+            )
+        }
+        agent.mcp_client.get_tool_map = MagicMock(return_value=tool_map)
+
+        # Mock the client's call_tool method directly
+        agent.mcp_client.call_tool = AsyncMock(return_value=mock_result)
 
         await agent.initialize()
 
         # Call the tool
-
+        tool_call = MCPToolCall(name="test_tool", arguments={"param": "value"})
+        result = await agent.call_tool(tool_call)
 
         assert result == mock_result
         assert not result.isError
@@ -240,22 +315,25 @@ class TestBaseMCPAgent:
             return types.ListToolsResult(tools=[])
 
         mock_session.list_tools = mock_list_tools
-        assert agent.
-        agent.
+        assert agent.mcp_client is not None
+        agent.mcp_client.get_all_active_sessions = MagicMock(return_value={"server1": mock_session})
 
         await agent.initialize()
 
         # Try to call unknown tool
         with pytest.raises(ValueError, match="Tool 'unknown_tool' not found"):
-
+            tool_call = MCPToolCall(name="unknown_tool", arguments={})
+            await agent.call_tool(tool_call)
 
     @pytest.mark.asyncio
     async def test_call_tool_no_name(self):
         """Test calling tool without name."""
+        # MCPToolCall accepts empty names, but the agent should validate
         agent = MockMCPAgent()
+        tool_call = MCPToolCall(name="", arguments={})
 
         with pytest.raises(ValueError, match="Tool call must have a 'name' field"):
-            await agent.call_tool(
+            await agent.call_tool(tool_call)
 
     def test_get_system_prompt_default(self):
         """Test get_system_prompt with default settings."""
@@ -307,6 +385,9 @@ class TestBaseMCPAgent:
         """Test getting tool schemas."""
         agent = MockMCPAgent()
 
+        # Add setup to lifecycle tools to test filtering
+        agent.lifecycle_tools = ["setup"]
+
         agent._available_tools = [
             types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}),
             types.Tool(name="setup", description="Setup", inputSchema={"type": "object"}),
@@ -360,36 +441,30 @@ class TestBaseMCPAgent:
 
         mock_session.connector.client_session.call_tool = mock_call_tool
 
-        assert agent.
-        agent.
-
+        assert agent.mcp_client is not None
+        agent.mcp_client.get_all_active_sessions = MagicMock(return_value={"server1": mock_session})
+
+        # Mock get_tool_map to return tools discovered from sessions
+        tool_map = {
+            "screenshot": (
+                "server1",
+                types.Tool(
+                    name="screenshot", description="Screenshot", inputSchema={"type": "object"}
+                ),
+            )
+        }
+        agent.mcp_client.get_tool_map = MagicMock(return_value=tool_map)
+
+        # Mock the client's call_tool method directly
+        agent.mcp_client.call_tool = AsyncMock(return_value=mock_result)
 
         await agent.initialize()
 
         screenshot = await agent.capture_screenshot()
         assert screenshot == "base64imagedata"
 
-
-
-        agent = MockMCPAgent()
-
-        # Create a proper CallToolResult object
-        result = types.CallToolResult(
-            content=[
-                types.TextContent(type="text", text="Result text"),
-                types.ImageContent(type="image", data="imagedata", mimeType="image/png"),
-            ],
-            isError=False,
-        )
-
-        tool_results = [{"tool_name": "test_tool", "result": result}]
-
-        processed = agent.process_tool_results(tool_results)
-
-        assert "text" in processed
-        assert "Result text" in processed["text"]
-        assert "results" in processed
-        assert len(processed["results"]) == 1
+        # process_tool_results method was removed from base class
+        # This functionality is now handled internally
 
     def test_get_tools_by_server(self):
         """Test getting tools grouped by server."""