mcp-use 1.3.6__py3-none-any.whl → 1.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-use might be problematic. Click here for more details.

@@ -0,0 +1,239 @@
1
+ """
2
+ Remote agent implementation for executing agents via API.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from typing import Any, TypeVar
8
+
9
+ import httpx
10
+ from langchain.schema import BaseMessage
11
+ from pydantic import BaseModel
12
+
13
+ from ..logging import logger
14
+
15
+ T = TypeVar("T", bound=BaseModel)
16
+
17
+
18
class RemoteAgent:
    """Agent that executes remotely via API.

    Each query is sent as an HTTP POST to
    ``{base_url}/api/v1/agents/{agent_id}/run``. The instance owns an
    ``httpx.AsyncClient``; release it with :meth:`close`, or use the agent as an
    ``async with`` context manager so the client is closed automatically.
    """

    def __init__(self, agent_id: str, api_key: str | None = None, base_url: str = "https://cloud.mcp-use.com"):
        """Initialize remote agent.

        Args:
            agent_id: The ID of the remote agent to execute.
            api_key: API key for authentication. If None, the ``MCP_USE_API_KEY``
                environment variable is used.
            base_url: Base URL for the remote API. A trailing ``/`` is stripped so the
                request URL never contains ``//``.

        Raises:
            ValueError: If no API key is passed and ``MCP_USE_API_KEY`` is unset or empty.
        """
        self.agent_id = agent_id
        # Normalise the same way HttpConnector does: "https://host/" must not
        # produce "https://host//api/v1/...".
        self.base_url = base_url.rstrip("/")

        # Handle API key validation
        if api_key is None:
            api_key = os.getenv("MCP_USE_API_KEY")
        if not api_key:
            raise ValueError(
                "API key is required for remote execution. "
                "Please provide it as a parameter or set the MCP_USE_API_KEY environment variable. "
                "You can get an API key from https://cloud.mcp-use.com"
            )

        self.api_key = api_key
        # Configure client with reasonable timeouts for agent execution
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=10.0,  # 10 seconds to establish connection
                read=300.0,  # 5 minutes to read response (agents can take time)
                write=10.0,  # 10 seconds to send request
                pool=10.0,  # 10 seconds to get connection from pool
            )
        )

    async def __aenter__(self) -> "RemoteAgent":
        """Return the agent itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the underlying HTTP client when leaving an ``async with`` block."""
        await self.close()

    def _pydantic_to_json_schema(self, model_class: type[T]) -> dict[str, Any]:
        """Convert a Pydantic model to JSON schema for API transmission.

        Args:
            model_class: The Pydantic model class to convert.

        Returns:
            JSON schema representation of the model (``model_json_schema()``).
        """
        return model_class.model_json_schema()

    def _parse_structured_response(self, response_data: Any, output_schema: type[T]) -> T:
        """Parse the API response into the structured output format.

        The payload is unwrapped before validation:

        * ``{"result": {"result": X, ...}}`` -> ``X`` (nested agent-execution envelope)
        * ``{"result": X}`` -> ``X``
        * any other dict -> the dict itself
        * a string -> its JSON decoding, or ``{"content": <string>}`` if it is not JSON
        * anything else -> unchanged

        Args:
            response_data: Raw response data from the API.
            output_schema: The Pydantic model to parse into.

        Returns:
            Parsed structured output.

        Raises:
            Exception: Whatever ``model_validate`` raises when the data does not fit the
                schema and the ``content`` fallback is unavailable or also fails.
        """
        # Handle different response formats
        if isinstance(response_data, dict):
            if "result" in response_data:
                outer_result = response_data["result"]
                # Check if this is a nested result structure (agent execution response)
                if isinstance(outer_result, dict) and "result" in outer_result:
                    result_data = outer_result["result"]
                else:
                    result_data = outer_result
            else:
                result_data = response_data
        elif isinstance(response_data, str):
            try:
                result_data = json.loads(response_data)
            except json.JSONDecodeError:
                # If it's not valid JSON, try to create the model from the string content
                result_data = {"content": response_data}
        else:
            result_data = response_data

        # Parse into the Pydantic model
        try:
            return output_schema.model_validate(result_data)
        except Exception as e:
            logger.warning(f"Failed to parse structured output: {e}")
            # Fallback: try to parse it as raw content if the model has a content field
            if hasattr(output_schema, "model_fields") and "content" in output_schema.model_fields:
                return output_schema.model_validate({"content": str(result_data)})
            raise

    def _http_status_error(self, error: httpx.HTTPStatusError) -> RuntimeError:
        """Log a non-2xx response and build the RuntimeError that ``run`` raises for it.

        Args:
            error: The exception raised by ``response.raise_for_status()``.

        Returns:
            A RuntimeError carrying a status-specific, user-facing message.
        """
        status_code = error.response.status_code
        response_text = error.response.text

        # Provide specific error messages based on status code
        if status_code == 401:
            logger.error(f"❌ Authentication failed: {response_text}")
            return RuntimeError(
                "Authentication failed: Invalid or missing API key. "
                "Please check your API key and ensure the MCP_USE_API_KEY environment variable is set correctly."
            )
        if status_code == 403:
            logger.error(f"❌ Access forbidden: {response_text}")
            return RuntimeError(
                f"Access denied: You don't have permission to execute agent '{self.agent_id}'. "
                "Check if the agent exists and you have the necessary permissions."
            )
        if status_code == 404:
            logger.error(f"❌ Agent not found: {response_text}")
            return RuntimeError(
                f"Agent not found: Agent '{self.agent_id}' does not exist or you don't have access to it. "
                "Please verify the agent ID and ensure it exists in your account."
            )
        if status_code == 422:
            logger.error(f"❌ Validation error: {response_text}")
            return RuntimeError(
                f"Request validation failed: {response_text}. "
                "Please check your query parameters and output schema format."
            )
        if status_code == 500:
            logger.error(f"❌ Server error: {response_text}")
            return RuntimeError(
                "Internal server error occurred during agent execution. "
                "Please try again later or contact support if the issue persists."
            )
        logger.error(f"❌ Remote execution failed with status {status_code}: {response_text}")
        return RuntimeError(f"Remote agent execution failed: {status_code} - {response_text}")

    def _raise_for_error_payload(self, result: Any) -> None:
        """Raise RuntimeError when a successfully delivered response reports a failure.

        Args:
            result: The decoded JSON body of the response.

        Raises:
            RuntimeError: If the body signals an error (``status == "error"`` or a
                non-null ``error`` field), or mentions an agent initialization failure.
        """
        # Check for error responses (even with 200 status)
        if isinstance(result, dict) and (result.get("status") == "error" or result.get("error") is not None):
            # `or` (not a .get default): "error" may be present but null/empty when only
            # `status == "error"` is set, and "None" is not a useful message.
            error_msg = result.get("error") or str(result)
            logger.error(f"❌ Remote agent execution failed: {error_msg}")
            raise RuntimeError(f"Remote agent execution failed: {error_msg}")

        # NOTE(review): substring match over the whole payload, so an agent answer that
        # merely mentions "failed to initialize" also trips this. Kept for compatibility.
        if "failed to initialize" in str(result):
            logger.error(f"❌ Agent initialization failed: {result}")
            raise RuntimeError(
                "Agent initialization failed on remote server. "
                "This usually indicates:\n"
                "• Invalid agent configuration (LLM model, system prompt)\n"
                "• Missing or invalid MCP server configurations\n"
                "• Network connectivity issues with MCP servers\n"
                "• Missing environment variables or credentials\n"
                f"Raw error: {result}"
            )

    async def run(
        self,
        query: str,
        max_steps: int | None = None,
        manage_connector: bool = True,
        external_history: list[BaseMessage] | None = None,
        output_schema: type[T] | None = None,
    ) -> str | T:
        """Run a query on the remote agent.

        Args:
            query: The query to execute.
            max_steps: Maximum number of steps; ``None`` (or ``0``) is sent as 10.
            manage_connector: Ignored for remote execution.
            external_history: Ignored for remote execution (not supported yet; a warning is logged).
            output_schema: Optional Pydantic model for structured output. Its JSON schema
                is sent to the server and the response is validated against it.

        Returns:
            The result from the remote agent execution (string or structured output).

        Raises:
            RuntimeError: On HTTP errors, timeouts, connection failures, undecodable
                responses, error payloads, or structured-output parsing failures.
        """
        if external_history is not None:
            logger.warning("External history is not yet supported for remote execution")

        payload: dict[str, Any] = {"query": query, "max_steps": max_steps or 10}

        # Add structured output schema if provided
        if output_schema is not None:
            payload["output_schema"] = self._pydantic_to_json_schema(output_schema)
            logger.info(f"🔧 Using structured output with schema: {output_schema.__name__}")

        headers = {"Content-Type": "application/json", "x-api-key": self.api_key}
        url = f"{self.base_url}/api/v1/agents/{self.agent_id}/run"

        # Only transport / HTTP / decoding failures are translated here. The payload checks
        # below raise their own RuntimeErrors and must not be caught and re-wrapped by the
        # generic handler (which previously double-logged and double-prefixed them).
        try:
            logger.info(f"🌐 Executing query on remote agent {self.agent_id}")
            response = await self._client.post(url, json=payload, headers=headers)
            response.raise_for_status()
            result = response.json()
        except httpx.HTTPStatusError as e:
            raise self._http_status_error(e) from e
        except httpx.TimeoutException as e:
            logger.error(f"❌ Remote execution timed out: {e}")
            raise RuntimeError(
                "Remote agent execution timed out. The server may be overloaded or the query is taking too long to "
                "process. Try again or use a simpler query."
            ) from e
        except httpx.ConnectError as e:
            logger.error(f"❌ Remote execution connection error: {e}")
            raise RuntimeError(
                f"Remote agent connection failed: Cannot connect to {self.base_url}. "
                "Check if the server is running and the URL is correct."
            ) from e
        except Exception as e:
            logger.error(f"❌ Remote execution error: {e}")
            raise RuntimeError(f"Remote agent execution failed: {str(e)}") from e

        # Full payloads may be large and contain user data; keep them out of INFO logs.
        logger.debug(f"🔧 Response: {result}")

        self._raise_for_error_payload(result)
        logger.info("✅ Remote execution completed successfully")

        # Handle structured output; parsing failures keep surfacing as RuntimeError.
        if output_schema is not None:
            try:
                return self._parse_structured_response(result, output_schema)
            except Exception as e:
                logger.error(f"❌ Remote execution error: {e}")
                raise RuntimeError(f"Remote agent execution failed: {str(e)}") from e

        # Regular string output
        if isinstance(result, dict) and "result" in result:
            return result["result"]
        if isinstance(result, str):
            return result
        return str(result)

    async def close(self) -> None:
        """Close the HTTP client."""
        await self._client.aclose()
        logger.info("🔌 Remote agent client closed")
mcp_use/client.py CHANGED
@@ -9,6 +9,8 @@ import json
9
9
  import warnings
10
10
  from typing import Any
11
11
 
12
+ from mcp.client.session import ElicitationFnT, SamplingFnT
13
+
12
14
  from mcp_use.types.sandbox import SandboxOptions
13
15
 
14
16
  from .config import create_connector_from_config, load_config_file
@@ -26,8 +28,11 @@ class MCPClient:
26
28
  def __init__(
27
29
  self,
28
30
  config: str | dict[str, Any] | None = None,
31
+ allowed_servers: list[str] | None = None,
29
32
  sandbox: bool = False,
30
33
  sandbox_options: SandboxOptions | None = None,
34
+ sampling_callback: SamplingFnT | None = None,
35
+ elicitation_callback: ElicitationFnT | None = None,
31
36
  ) -> None:
32
37
  """Initialize a new MCP client.
33
38
 
@@ -36,13 +41,16 @@ class MCPClient:
36
41
  If None, an empty configuration is used.
37
42
  sandbox: Whether to use sandboxed execution mode for running MCP servers.
38
43
  sandbox_options: Optional sandbox configuration options.
44
+ sampling_callback: Optional sampling callback function.
39
45
  """
40
46
  self.config: dict[str, Any] = {}
47
+ self.allowed_servers: list[str] = allowed_servers
41
48
  self.sandbox = sandbox
42
49
  self.sandbox_options = sandbox_options
43
50
  self.sessions: dict[str, MCPSession] = {}
44
51
  self.active_sessions: list[str] = []
45
-
52
+ self.sampling_callback = sampling_callback
53
+ self.elicitation_callback = elicitation_callback
46
54
  # Load configuration if provided
47
55
  if config is not None:
48
56
  if isinstance(config, str):
@@ -56,6 +64,8 @@ class MCPClient:
56
64
  config: dict[str, Any],
57
65
  sandbox: bool = False,
58
66
  sandbox_options: SandboxOptions | None = None,
67
+ sampling_callback: SamplingFnT | None = None,
68
+ elicitation_callback: ElicitationFnT | None = None,
59
69
  ) -> "MCPClient":
60
70
  """Create a MCPClient from a dictionary.
61
71
 
@@ -63,12 +73,25 @@ class MCPClient:
63
73
  config: The configuration dictionary.
64
74
  sandbox: Whether to use sandboxed execution mode for running MCP servers.
65
75
  sandbox_options: Optional sandbox configuration options.
76
+ sampling_callback: Optional sampling callback function.
77
+ elicitation_callback: Optional elicitation callback function.
66
78
  """
67
- return cls(config=config, sandbox=sandbox, sandbox_options=sandbox_options)
79
+ return cls(
80
+ config=config,
81
+ sandbox=sandbox,
82
+ sandbox_options=sandbox_options,
83
+ sampling_callback=sampling_callback,
84
+ elicitation_callback=elicitation_callback,
85
+ )
68
86
 
69
87
  @classmethod
70
88
  def from_config_file(
71
- cls, filepath: str, sandbox: bool = False, sandbox_options: SandboxOptions | None = None
89
+ cls,
90
+ filepath: str,
91
+ sandbox: bool = False,
92
+ sandbox_options: SandboxOptions | None = None,
93
+ sampling_callback: SamplingFnT | None = None,
94
+ elicitation_callback: ElicitationFnT | None = None,
72
95
  ) -> "MCPClient":
73
96
  """Create a MCPClient from a configuration file.
74
97
 
@@ -76,8 +99,16 @@ class MCPClient:
76
99
  filepath: The path to the configuration file.
77
100
  sandbox: Whether to use sandboxed execution mode for running MCP servers.
78
101
  sandbox_options: Optional sandbox configuration options.
102
+ sampling_callback: Optional sampling callback function.
103
+ elicitation_callback: Optional elicitation callback function.
79
104
  """
80
- return cls(config=load_config_file(filepath), sandbox=sandbox, sandbox_options=sandbox_options)
105
+ return cls(
106
+ config=load_config_file(filepath),
107
+ sandbox=sandbox,
108
+ sandbox_options=sandbox_options,
109
+ sampling_callback=sampling_callback,
110
+ elicitation_callback=elicitation_callback,
111
+ )
81
112
 
82
113
  def add_server(
83
114
  self,
@@ -151,7 +182,11 @@ class MCPClient:
151
182
 
152
183
  # Create connector with options
153
184
  connector = create_connector_from_config(
154
- server_config, sandbox=self.sandbox, sandbox_options=self.sandbox_options
185
+ server_config,
186
+ sandbox=self.sandbox,
187
+ sandbox_options=self.sandbox_options,
188
+ sampling_callback=self.sampling_callback,
189
+ elicitation_callback=self.elicitation_callback,
155
190
  )
156
191
 
157
192
  # Create the session
@@ -187,9 +222,10 @@ class MCPClient:
187
222
  warnings.warn("No MCP servers defined in config", UserWarning, stacklevel=2)
188
223
  return {}
189
224
 
190
- # Create sessions for all servers
225
+ # Create sessions only for allowed servers if applicable else create for all servers
191
226
  for name in servers:
192
- await self.create_session(name, auto_initialize)
227
+ if self.allowed_servers is None or name in self.allowed_servers:
228
+ await self.create_session(name, auto_initialize)
193
229
 
194
230
  return self.sessions
195
231
 
mcp_use/config.py CHANGED
@@ -7,6 +7,8 @@ This module provides functionality to load MCP configuration from JSON files.
7
7
  import json
8
8
  from typing import Any
9
9
 
10
+ from mcp.client.session import ElicitationFnT, SamplingFnT
11
+
10
12
  from mcp_use.types.sandbox import SandboxOptions
11
13
 
12
14
  from .connectors import (
@@ -36,6 +38,8 @@ def create_connector_from_config(
36
38
  server_config: dict[str, Any],
37
39
  sandbox: bool = False,
38
40
  sandbox_options: SandboxOptions | None = None,
41
+ sampling_callback: SamplingFnT | None = None,
42
+ elicitation_callback: ElicitationFnT | None = None,
39
43
  ) -> BaseConnector:
40
44
  """Create a connector based on server configuration.
41
45
  This function can be called with just the server_config parameter:
@@ -44,7 +48,7 @@ def create_connector_from_config(
44
48
  server_config: The server configuration section
45
49
  sandbox: Whether to use sandboxed execution mode for running MCP servers.
46
50
  sandbox_options: Optional sandbox configuration options.
47
-
51
+ sampling_callback: Optional sampling callback function.
48
52
  Returns:
49
53
  A configured connector instance
50
54
  """
@@ -55,6 +59,8 @@ def create_connector_from_config(
55
59
  command=server_config["command"],
56
60
  args=server_config["args"],
57
61
  env=server_config.get("env", None),
62
+ sampling_callback=sampling_callback,
63
+ elicitation_callback=elicitation_callback,
58
64
  )
59
65
 
60
66
  # Sandboxed connector
@@ -64,6 +70,8 @@ def create_connector_from_config(
64
70
  args=server_config["args"],
65
71
  env=server_config.get("env", None),
66
72
  e2b_options=sandbox_options,
73
+ sampling_callback=sampling_callback,
74
+ elicitation_callback=elicitation_callback,
67
75
  )
68
76
 
69
77
  # HTTP connector
@@ -72,6 +80,10 @@ def create_connector_from_config(
72
80
  base_url=server_config["url"],
73
81
  headers=server_config.get("headers", None),
74
82
  auth_token=server_config.get("auth_token", None),
83
+ timeout=server_config.get("timeout", 5),
84
+ sse_read_timeout=server_config.get("sse_read_timeout", 60 * 5),
85
+ sampling_callback=sampling_callback,
86
+ elicitation_callback=elicitation_callback,
75
87
  )
76
88
 
77
89
  # WebSocket connector
@@ -6,13 +6,17 @@ must implement.
6
6
  """
7
7
 
8
8
  from abc import ABC, abstractmethod
9
+ from datetime import timedelta
9
10
  from typing import Any
10
11
 
11
- from mcp import ClientSession
12
+ from mcp import ClientSession, Implementation
13
+ from mcp.client.session import ElicitationFnT, SamplingFnT
12
14
  from mcp.shared.exceptions import McpError
13
15
  from mcp.types import CallToolResult, GetPromptResult, Prompt, ReadResourceResult, Resource, Tool
14
16
  from pydantic import AnyUrl
15
17
 
18
+ import mcp_use
19
+
16
20
  from ..logging import logger
17
21
  from ..task_managers import ConnectionManager
18
22
 
@@ -23,7 +27,11 @@ class BaseConnector(ABC):
23
27
  This class defines the interface that all MCP connectors must implement.
24
28
  """
25
29
 
26
- def __init__(self):
30
+ def __init__(
31
+ self,
32
+ sampling_callback: SamplingFnT | None = None,
33
+ elicitation_callback: ElicitationFnT | None = None,
34
+ ):
27
35
  """Initialize base connector with common attributes."""
28
36
  self.client_session: ClientSession | None = None
29
37
  self._connection_manager: ConnectionManager | None = None
@@ -33,6 +41,17 @@ class BaseConnector(ABC):
33
41
  self._connected = False
34
42
  self._initialized = False # Track if client_session.initialize() has been called
35
43
  self.auto_reconnect = True # Whether to automatically reconnect on connection loss (not configurable for now)
44
+ self.sampling_callback = sampling_callback
45
+ self.elicitation_callback = elicitation_callback
46
+
47
+ @property
48
+ def client_info(self) -> Implementation:
49
+ """Get the client info for the connector."""
50
+ return Implementation(
51
+ name="mcp-use",
52
+ version=mcp_use.__version__,
53
+ url="https://github.com/mcp-use/mcp-use",
54
+ )
36
55
 
37
56
  @abstractmethod
38
57
  async def connect(self) -> None:
@@ -110,28 +129,41 @@ class BaseConnector(ABC):
110
129
 
111
130
  if server_capabilities.tools:
112
131
  # Get available tools directly from client session
113
- tools_result = await self.client_session.list_tools()
114
- self._tools = tools_result.tools if tools_result else []
132
+ try:
133
+ tools_result = await self.client_session.list_tools()
134
+ self._tools = tools_result.tools if tools_result else []
135
+ except Exception as e:
136
+ logger.error(f"Error listing tools: {e}")
137
+ self._tools = []
115
138
  else:
116
139
  self._tools = []
117
140
 
118
141
  if server_capabilities.resources:
119
142
  # Get available resources directly from client session
120
- resources_result = await self.client_session.list_resources()
121
- self._resources = resources_result.resources if resources_result else []
143
+ try:
144
+ resources_result = await self.client_session.list_resources()
145
+ self._resources = resources_result.resources if resources_result else []
146
+ except Exception as e:
147
+ logger.error(f"Error listing resources: {e}")
148
+ self._resources = []
122
149
  else:
123
150
  self._resources = []
124
151
 
125
152
  if server_capabilities.prompts:
126
153
  # Get available prompts directly from client session
127
- prompts_result = await self.client_session.list_prompts()
128
- self._prompts = prompts_result.prompts if prompts_result else []
154
+ try:
155
+ prompts_result = await self.client_session.list_prompts()
156
+ self._prompts = prompts_result.prompts if prompts_result else []
157
+ except Exception as e:
158
+ logger.error(f"Error listing prompts: {e}")
159
+ self._prompts = []
129
160
  else:
130
161
  self._prompts = []
131
162
 
132
163
  logger.debug(
133
164
  f"MCP session initialized with {len(self._tools)} tools, "
134
- "{len(self._resources)} resources, and {len(self._prompts)} prompts"
165
+ f"{len(self._resources)} resources, "
166
+ f"and {len(self._prompts)} prompts"
135
167
  )
136
168
 
137
169
  return result
@@ -235,12 +267,15 @@ class BaseConnector(ABC):
235
267
  "Connection to MCP server has been lost. Auto-reconnection is disabled. Please reconnect manually."
236
268
  )
237
269
 
238
- async def call_tool(self, name: str, arguments: dict[str, Any]) -> CallToolResult:
270
+ async def call_tool(
271
+ self, name: str, arguments: dict[str, Any], read_timeout_seconds: timedelta | None = None
272
+ ) -> CallToolResult:
239
273
  """Call an MCP tool with automatic reconnection handling.
240
274
 
241
275
  Args:
242
276
  name: The name of the tool to call.
243
277
  arguments: The arguments to pass to the tool.
278
+ read_timeout_seconds: timeout seconds when calling tool
244
279
 
245
280
  Returns:
246
281
  The result of the tool call.
@@ -254,7 +289,7 @@ class BaseConnector(ABC):
254
289
 
255
290
  logger.debug(f"Calling tool '{name}' with arguments: {arguments}")
256
291
  try:
257
- result = await self.client_session.call_tool(name, arguments)
292
+ result = await self.client_session.call_tool(name, arguments, read_timeout_seconds)
258
293
  logger.debug(f"Tool '{name}' called with result: {result}")
259
294
  return result
260
295
  except Exception as e:
@@ -7,9 +7,10 @@ through HTTP APIs with SSE or Streamable HTTP for transport.
7
7
 
8
8
  import httpx
9
9
  from mcp import ClientSession
10
+ from mcp.client.session import ElicitationFnT, SamplingFnT
10
11
 
11
12
  from ..logging import logger
12
- from ..task_managers import ConnectionManager, SseConnectionManager, StreamableHttpConnectionManager
13
+ from ..task_managers import SseConnectionManager, StreamableHttpConnectionManager
13
14
  from .base import BaseConnector
14
15
 
15
16
 
@@ -27,6 +28,8 @@ class HttpConnector(BaseConnector):
27
28
  headers: dict[str, str] | None = None,
28
29
  timeout: float = 5,
29
30
  sse_read_timeout: float = 60 * 5,
31
+ sampling_callback: SamplingFnT | None = None,
32
+ elicitation_callback: ElicitationFnT | None = None,
30
33
  ):
31
34
  """Initialize a new HTTP connector.
32
35
 
@@ -36,8 +39,10 @@ class HttpConnector(BaseConnector):
36
39
  headers: Optional additional headers.
37
40
  timeout: Timeout for HTTP operations in seconds.
38
41
  sse_read_timeout: Timeout for SSE read operations in seconds.
42
+ sampling_callback: Optional sampling callback.
43
+ elicitation_callback: Optional elicitation callback.
39
44
  """
40
- super().__init__()
45
+ super().__init__(sampling_callback=sampling_callback, elicitation_callback=elicitation_callback)
41
46
  self.base_url = base_url.rstrip("/")
42
47
  self.auth_token = auth_token
43
48
  self.headers = headers or {}
@@ -46,14 +51,6 @@ class HttpConnector(BaseConnector):
46
51
  self.timeout = timeout
47
52
  self.sse_read_timeout = sse_read_timeout
48
53
 
49
- async def _setup_client(self, connection_manager: ConnectionManager) -> None:
50
- """Set up the client session with the provided connection manager."""
51
-
52
- self._connection_manager = connection_manager
53
- read_stream, write_stream = await self._connection_manager.start()
54
- self.client_session = ClientSession(read_stream, write_stream, sampling_callback=None)
55
- await self.client_session.__aenter__()
56
-
57
54
  async def connect(self) -> None:
58
55
  """Establish a connection to the MCP implementation."""
59
56
  if self._connected:
@@ -76,7 +73,13 @@ class HttpConnector(BaseConnector):
76
73
  read_stream, write_stream = await connection_manager.start()
77
74
 
78
75
  # Test if this actually works by trying to create a client session and initialize it
79
- test_client = ClientSession(read_stream, write_stream, sampling_callback=None)
76
+ test_client = ClientSession(
77
+ read_stream,
78
+ write_stream,
79
+ sampling_callback=self.sampling_callback,
80
+ elicitation_callback=self.elicitation_callback,
81
+ client_info=self.client_info,
82
+ )
80
83
  await test_client.__aenter__()
81
84
 
82
85
  try:
@@ -154,7 +157,13 @@ class HttpConnector(BaseConnector):
154
157
  read_stream, write_stream = await connection_manager.start()
155
158
 
156
159
  # Create the client session for SSE
157
- self.client_session = ClientSession(read_stream, write_stream, sampling_callback=None)
160
+ self.client_session = ClientSession(
161
+ read_stream,
162
+ write_stream,
163
+ sampling_callback=self.sampling_callback,
164
+ elicitation_callback=self.elicitation_callback,
165
+ client_info=self.client_info,
166
+ )
158
167
  await self.client_session.__aenter__()
159
168
  self.transport_type = "SSE"
160
169
 
@@ -12,6 +12,7 @@ import time
12
12
 
13
13
  import aiohttp
14
14
  from mcp import ClientSession
15
+ from mcp.client.session import ElicitationFnT, SamplingFnT
15
16
 
16
17
  from ..logging import logger
17
18
  from ..task_managers import SseConnectionManager
@@ -50,6 +51,8 @@ class SandboxConnector(BaseConnector):
50
51
  e2b_options: SandboxOptions | None = None,
51
52
  timeout: float = 5,
52
53
  sse_read_timeout: float = 60 * 5,
54
+ sampling_callback: SamplingFnT | None = None,
55
+ elicitation_callback: ElicitationFnT | None = None,
53
56
  ):
54
57
  """Initialize a new sandbox connector.
55
58
 
@@ -61,8 +64,10 @@ class SandboxConnector(BaseConnector):
61
64
  See SandboxOptions for available options and defaults.
62
65
  timeout: Timeout for the sandbox process in seconds.
63
66
  sse_read_timeout: Timeout for the SSE connection in seconds.
67
+ sampling_callback: Optional sampling callback.
68
+ elicitation_callback: Optional elicitation callback.
64
69
  """
65
- super().__init__()
70
+ super().__init__(sampling_callback=sampling_callback, elicitation_callback=elicitation_callback)
66
71
  if Sandbox is None:
67
72
  raise ImportError(
68
73
  "E2B SDK (e2b-code-interpreter) not found. Please install it with "
@@ -217,7 +222,13 @@ class SandboxConnector(BaseConnector):
217
222
  read_stream, write_stream = await self._connection_manager.start()
218
223
 
219
224
  # Create the client session
220
- self.client_session = ClientSession(read_stream, write_stream, sampling_callback=None)
225
+ self.client_session = ClientSession(
226
+ read_stream,
227
+ write_stream,
228
+ sampling_callback=self.sampling_callback,
229
+ elicitation_callback=self.elicitation_callback,
230
+ client_info=self.client_info,
231
+ )
221
232
  await self.client_session.__aenter__()
222
233
 
223
234
  # Mark as connected