mcp-use 1.3.10__py3-none-any.whl → 1.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-use might be problematic. Click here for more details.

@@ -5,11 +5,19 @@ This module provides a connector for communicating with MCP implementations
5
5
  through HTTP APIs with SSE or Streamable HTTP for transport.
6
6
  """
7
7
 
8
+ from typing import Any
9
+
8
10
  import httpx
9
11
  from mcp import ClientSession
10
12
  from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
13
+ from mcp.shared.exceptions import McpError
14
+
15
+ from mcp_use.auth.oauth import OAuthClientProvider
11
16
 
17
+ from ..auth import BearerAuth, OAuth
18
+ from ..exceptions import OAuthAuthenticationError, OAuthDiscoveryError
12
19
  from ..logging import logger
20
+ from ..middleware import CallbackClientSession, Middleware
13
21
  from ..task_managers import SseConnectionManager, StreamableHttpConnectionManager
14
22
  from .base import BaseConnector
15
23
 
@@ -24,23 +32,27 @@ class HttpConnector(BaseConnector):
24
32
  def __init__(
25
33
  self,
26
34
  base_url: str,
27
- auth_token: str | None = None,
28
35
  headers: dict[str, str] | None = None,
29
36
  timeout: float = 5,
30
37
  sse_read_timeout: float = 60 * 5,
38
+ auth: str | dict[str, Any] | httpx.Auth | None = None,
31
39
  sampling_callback: SamplingFnT | None = None,
32
40
  elicitation_callback: ElicitationFnT | None = None,
33
41
  message_handler: MessageHandlerFnT | None = None,
34
42
  logging_callback: LoggingFnT | None = None,
43
+ middleware: list[Middleware] | None = None,
35
44
  ):
36
45
  """Initialize a new HTTP connector.
37
46
 
38
47
  Args:
39
48
  base_url: The base URL of the MCP HTTP API.
40
- auth_token: Optional authentication token.
41
49
  headers: Optional additional headers.
42
50
  timeout: Timeout for HTTP operations in seconds.
43
51
  sse_read_timeout: Timeout for SSE read operations in seconds.
52
+ auth: Authentication method - can be:
53
+ - A string token: Use Bearer token authentication
54
+ - A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."}
55
+ - An httpx.Auth object: Use custom authentication
44
56
  sampling_callback: Optional sampling callback.
45
57
  elicitation_callback: Optional elicitation callback.
46
58
  """
@@ -49,14 +61,60 @@ class HttpConnector(BaseConnector):
49
61
  elicitation_callback=elicitation_callback,
50
62
  message_handler=message_handler,
51
63
  logging_callback=logging_callback,
64
+ middleware=middleware,
52
65
  )
53
66
  self.base_url = base_url.rstrip("/")
54
- self.auth_token = auth_token
55
67
  self.headers = headers or {}
56
- if auth_token:
57
- self.headers["Authorization"] = f"Bearer {auth_token}"
58
68
  self.timeout = timeout
59
69
  self.sse_read_timeout = sse_read_timeout
70
+ self._auth: httpx.Auth | None = None
71
+ self._oauth: OAuth | None = None
72
+
73
+ # Handle authentication
74
+ if auth is not None:
75
+ self._set_auth(auth)
76
+
77
+ def _set_auth(self, auth: str | dict[str, Any] | httpx.Auth) -> None:
78
+ """Set authentication method.
79
+
80
+ Args:
81
+ auth: Authentication method - can be:
82
+ - A string token: Use Bearer token authentication
83
+ - A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."}
84
+ - An httpx.Auth object: Use custom authentication
85
+ """
86
+ if isinstance(auth, str):
87
+ # Treat as bearer token
88
+ self._auth = BearerAuth(token=auth)
89
+ self.headers["Authorization"] = f"Bearer {auth}"
90
+ elif isinstance(auth, dict):
91
+ # Check if this is an OAuth provider configuration
92
+ if "oauth_provider" in auth:
93
+ oauth_provider = auth["oauth_provider"]
94
+ if isinstance(oauth_provider, dict):
95
+ oauth_provider = OAuthClientProvider(**oauth_provider)
96
+ self._oauth = OAuth(
97
+ self.base_url,
98
+ scope=auth.get("scope"),
99
+ client_id=auth.get("client_id"),
100
+ client_secret=auth.get("client_secret"),
101
+ callback_port=auth.get("callback_port"),
102
+ oauth_provider=oauth_provider,
103
+ )
104
+ self._oauth_config = auth
105
+ else:
106
+ self._oauth = OAuth(
107
+ self.base_url,
108
+ scope=auth.get("scope"),
109
+ client_id=auth.get("client_id"),
110
+ client_secret=auth.get("client_secret"),
111
+ callback_port=auth.get("callback_port"),
112
+ )
113
+ self._oauth_config = auth
114
+ elif isinstance(auth, httpx.Auth):
115
+ self._auth = auth
116
+ else:
117
+ raise ValueError(f"Invalid auth type: {type(auth)}")
60
118
 
61
119
  async def connect(self) -> None:
62
120
  """Establish a connection to the MCP implementation."""
@@ -64,6 +122,29 @@ class HttpConnector(BaseConnector):
64
122
  logger.debug("Already connected to MCP implementation")
65
123
  return
66
124
 
125
+ # Handle OAuth if needed
126
+ if self._oauth:
127
+ try:
128
+ # Create a temporary client for OAuth metadata discovery
129
+ async with httpx.AsyncClient() as client:
130
+ bearer_auth = await self._oauth.initialize(client)
131
+ if not bearer_auth:
132
+ # Need to perform OAuth flow
133
+ logger.info("OAuth authentication required")
134
+ bearer_auth = await self._oauth.authenticate()
135
+
136
+ # Update auth and headers
137
+ self._auth = bearer_auth
138
+ self.headers["Authorization"] = f"Bearer {bearer_auth.token.get_secret_value()}"
139
+ except OAuthDiscoveryError:
140
+ # OAuth discovery failed - it means server doesn't support OAuth default urls
141
+ logger.debug("OAuth discovery failed, continuing without initialization.")
142
+ self._oauth = None
143
+ self._auth = None
144
+ except OAuthAuthenticationError as e:
145
+ logger.error(f"OAuth initialization failed: {e}")
146
+ raise
147
+
67
148
  # Try streamable HTTP first (new transport), fall back to SSE (old transport)
68
149
  # This implements backwards compatibility per MCP specification
69
150
  self.transport_type = None
@@ -73,14 +154,14 @@ class HttpConnector(BaseConnector):
73
154
  # First, try the new streamable HTTP transport
74
155
  logger.debug(f"Attempting streamable HTTP connection to: {self.base_url}")
75
156
  connection_manager = StreamableHttpConnectionManager(
76
- self.base_url, self.headers, self.timeout, self.sse_read_timeout
157
+ self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
77
158
  )
78
159
 
79
160
  # Test if this is a streamable HTTP server by attempting initialization
80
161
  read_stream, write_stream = await connection_manager.start()
81
162
 
82
163
  # Test if this actually works by trying to create a client session and initialize it
83
- test_client = ClientSession(
164
+ raw_test_client = ClientSession(
84
165
  read_stream,
85
166
  write_stream,
86
167
  sampling_callback=self.sampling_callback,
@@ -89,14 +170,17 @@ class HttpConnector(BaseConnector):
89
170
  logging_callback=self.logging_callback,
90
171
  client_info=self.client_info,
91
172
  )
92
- await test_client.__aenter__()
173
+ await raw_test_client.__aenter__()
174
+
175
+ # Wrap test client with middleware temporarily for testing
176
+ test_client = CallbackClientSession(raw_test_client, self.public_identifier, self.middleware_manager)
93
177
 
94
178
  try:
95
179
  # Try to initialize - this is where streamable HTTP vs SSE difference should show up
96
180
  result = await test_client.initialize()
181
+ logger.debug(f"Streamable HTTP initialization result: {result}")
97
182
 
98
183
  # If we get here, streamable HTTP works
99
-
100
184
  self.client_session = test_client
101
185
  self.transport_type = "streamable HTTP"
102
186
  self._initialized = True # Mark as initialized since we just called initialize()
@@ -125,14 +209,28 @@ class HttpConnector(BaseConnector):
125
209
  else:
126
210
  self._prompts = []
127
211
 
128
- except Exception as init_error:
212
+ # Only McpError is raised from client's initialization because
213
+ # exceptions are handled internally.
214
+ except McpError as mcp_error:
215
+ logger.error("MCP protocol error during initialization: %s", mcp_error.error)
129
216
  # Clean up the test client
130
217
  try:
131
- await test_client.__aexit__(None, None, None)
218
+ await raw_test_client.__aexit__(None, None, None)
219
+ except Exception:
220
+ pass
221
+ raise mcp_error
222
+
223
+ except Exception as init_error:
224
+ # This catches non-McpError exceptions, like a direct httpx timeout
225
+ # but in most cases this won't happen. It's here for safety.
226
+ try:
227
+ await raw_test_client.__aexit__(None, None, None)
132
228
  except Exception:
133
229
  pass
134
230
  raise init_error
135
231
 
232
+ # An exception from the inner try propagates here and, in
233
+ # most cases, is an McpError, so checking its type is pointless
136
234
  except Exception as streamable_error:
137
235
  logger.debug(f"Streamable HTTP failed: {streamable_error}")
138
236
 
@@ -143,30 +241,23 @@ class HttpConnector(BaseConnector):
143
241
  except Exception:
144
242
  pass
145
243
 
146
- # Check if this is a 4xx error that indicates we should try SSE fallback
147
- should_fallback = False
148
- if isinstance(streamable_error, httpx.HTTPStatusError):
149
- if streamable_error.response.status_code in [404, 405]:
150
- should_fallback = True
151
- elif "405 Method Not Allowed" in str(streamable_error) or "404 Not Found" in str(streamable_error):
152
- should_fallback = True
153
- else:
154
- # For other errors, still try fallback but they might indicate
155
- # real connectivity issues
156
- should_fallback = True
244
+ # It doesn't make sense to check error types, because the client
245
+ # always raises an McpError, whether it can't reach the server
246
+ # because it's offline or because it has an auth problem.
247
+ should_fallback = True
157
248
 
158
249
  if should_fallback:
159
250
  try:
160
251
  # Fall back to the old SSE transport
161
252
  logger.debug(f"Attempting SSE fallback connection to: {self.base_url}")
162
253
  connection_manager = SseConnectionManager(
163
- self.base_url, self.headers, self.timeout, self.sse_read_timeout
254
+ self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
164
255
  )
165
256
 
166
257
  read_stream, write_stream = await connection_manager.start()
167
258
 
168
259
  # Create the client session for SSE
169
- self.client_session = ClientSession(
260
+ raw_client_session = ClientSession(
170
261
  read_stream,
171
262
  write_stream,
172
263
  sampling_callback=self.sampling_callback,
@@ -175,10 +266,26 @@ class HttpConnector(BaseConnector):
175
266
  logging_callback=self.logging_callback,
176
267
  client_info=self.client_info,
177
268
  )
178
- await self.client_session.__aenter__()
269
+ await raw_client_session.__aenter__()
270
+
271
+ # Wrap with middleware
272
+ self.client_session = CallbackClientSession(
273
+ raw_client_session, self.public_identifier, self.middleware_manager
274
+ )
179
275
  self.transport_type = "SSE"
180
276
 
181
- except Exception as sse_error:
277
+ except* Exception as sse_error:
278
+ # Get the exception from the ExceptionGroup, and here we will get the correct type.
279
+ sse_error = sse_error.exceptions[0]
280
+ if isinstance(sse_error, httpx.HTTPStatusError) and sse_error.response.status_code in [
281
+ 401,
282
+ 403,
283
+ 407,
284
+ ]:
285
+ raise OAuthAuthenticationError(
286
+ f"Server requires authentication (HTTP {sse_error.response.status_code}) "
287
+ "but auth failed. Please provide auth configuration manually."
288
+ ) from sse_error
182
289
  logger.error(
183
290
  f"Both transport methods failed. Streamable HTTP: {streamable_error}, SSE: {sse_error}"
184
291
  )
@@ -194,4 +301,5 @@ class HttpConnector(BaseConnector):
194
301
  @property
195
302
  def public_identifier(self) -> str:
196
303
  """Get the identifier for the connector."""
197
- return {"type": self.transport_type, "base_url": self.base_url}
304
+ transport_type = getattr(self, "transport_type", "http")
305
+ return f"{transport_type}:{self.base_url}"
@@ -15,6 +15,7 @@ from mcp import ClientSession
15
15
  from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
16
16
 
17
17
  from ..logging import logger
18
+ from ..middleware import CallbackClientSession, Middleware
18
19
  from ..task_managers import SseConnectionManager
19
20
 
20
21
  # Import E2B SDK components (optional dependency)
@@ -52,6 +53,7 @@ class SandboxConnector(BaseConnector):
52
53
  elicitation_callback: ElicitationFnT | None = None,
53
54
  message_handler: MessageHandlerFnT | None = None,
54
55
  logging_callback: LoggingFnT | None = None,
56
+ middleware: list[Middleware] | None = None,
55
57
  ):
56
58
  """Initialize a new sandbox connector.
57
59
 
@@ -71,6 +73,7 @@ class SandboxConnector(BaseConnector):
71
73
  elicitation_callback=elicitation_callback,
72
74
  message_handler=message_handler,
73
75
  logging_callback=logging_callback,
76
+ middleware=middleware,
74
77
  )
75
78
  if Sandbox is None:
76
79
  raise ImportError(
@@ -226,7 +229,7 @@ class SandboxConnector(BaseConnector):
226
229
  read_stream, write_stream = await self._connection_manager.start()
227
230
 
228
231
  # Create the client session
229
- self.client_session = ClientSession(
232
+ raw_client_session = ClientSession(
230
233
  read_stream,
231
234
  write_stream,
232
235
  sampling_callback=self.sampling_callback,
@@ -235,7 +238,12 @@ class SandboxConnector(BaseConnector):
235
238
  logging_callback=self.logging_callback,
236
239
  client_info=self.client_info,
237
240
  )
238
- await self.client_session.__aenter__()
241
+ await raw_client_session.__aenter__()
242
+
243
+ # Wrap with middleware
244
+ self.client_session = CallbackClientSession(
245
+ raw_client_session, self.public_identifier, self.middleware_manager
246
+ )
239
247
 
240
248
  # Mark as connected
241
249
  self._connected = True
@@ -299,4 +307,5 @@ class SandboxConnector(BaseConnector):
299
307
  @property
300
308
  def public_identifier(self) -> str:
301
309
  """Get the identifier for the connector."""
302
- return {"type": "sandbox", "command": self.user_command, "args": self.user_args}
310
+ args_str = " ".join(self.user_args) if self.user_args else ""
311
+ return f"sandbox:{self.user_command} {args_str}".strip()
@@ -11,6 +11,7 @@ from mcp import ClientSession, StdioServerParameters
11
11
  from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
12
12
 
13
13
  from ..logging import logger
14
+ from ..middleware import CallbackClientSession, Middleware
14
15
  from ..task_managers import StdioConnectionManager
15
16
  from .base import BaseConnector
16
17
 
@@ -33,6 +34,7 @@ class StdioConnector(BaseConnector):
33
34
  elicitation_callback: ElicitationFnT | None = None,
34
35
  message_handler: MessageHandlerFnT | None = None,
35
36
  logging_callback: LoggingFnT | None = None,
37
+ middleware: list[Middleware] | None = None,
36
38
  ):
37
39
  """Initialize a new stdio connector.
38
40
 
@@ -49,6 +51,7 @@ class StdioConnector(BaseConnector):
49
51
  elicitation_callback=elicitation_callback,
50
52
  message_handler=message_handler,
51
53
  logging_callback=logging_callback,
54
+ middleware=middleware,
52
55
  )
53
56
  self.command = command
54
57
  self.args = args or [] # Ensure args is never None
@@ -71,7 +74,7 @@ class StdioConnector(BaseConnector):
71
74
  read_stream, write_stream = await self._connection_manager.start()
72
75
 
73
76
  # Create the client session
74
- self.client_session = ClientSession(
77
+ raw_client_session = ClientSession(
75
78
  read_stream,
76
79
  write_stream,
77
80
  sampling_callback=self.sampling_callback,
@@ -80,7 +83,12 @@ class StdioConnector(BaseConnector):
80
83
  logging_callback=self.logging_callback,
81
84
  client_info=self.client_info,
82
85
  )
83
- await self.client_session.__aenter__()
86
+ await raw_client_session.__aenter__()
87
+
88
+ # Wrap with middleware
89
+ self.client_session = CallbackClientSession(
90
+ raw_client_session, self.public_identifier, self.middleware_manager
91
+ )
84
92
 
85
93
  # Mark as connected
86
94
  self._connected = True
@@ -98,4 +106,4 @@ class StdioConnector(BaseConnector):
98
106
  @property
99
107
  def public_identifier(self) -> str:
100
108
  """Get the identifier for the connector."""
101
- return {"type": "stdio", "command&args": f"{self.command} {' '.join(self.args)}"}
109
+ return f"stdio:{self.command} {' '.join(self.args)}"
@@ -10,6 +10,7 @@ import json
10
10
  import uuid
11
11
  from typing import Any
12
12
 
13
+ import httpx
13
14
  from mcp.types import Tool
14
15
  from websockets import ClientConnection
15
16
 
@@ -28,21 +29,29 @@ class WebSocketConnector(BaseConnector):
28
29
  def __init__(
29
30
  self,
30
31
  url: str,
31
- auth_token: str | None = None,
32
32
  headers: dict[str, str] | None = None,
33
+ auth: str | dict[str, Any] | httpx.Auth | None = None,
33
34
  ):
34
35
  """Initialize a new WebSocket connector.
35
36
 
36
37
  Args:
37
38
  url: The WebSocket URL to connect to.
38
- auth_token: Optional authentication token.
39
39
  headers: Optional additional headers.
40
+ auth: Authentication method - can be:
41
+ - A string token: Use Bearer token authentication
42
+ - A dict: Not supported for WebSocket (will log warning)
43
+ - An httpx.Auth object: Not supported for WebSocket (will log warning)
40
44
  """
41
45
  self.url = url
42
- self.auth_token = auth_token
43
46
  self.headers = headers or {}
44
- if auth_token:
45
- self.headers["Authorization"] = f"Bearer {auth_token}"
47
+
48
+ # Handle authentication - WebSocket only supports bearer tokens
49
+ # An auth field is not needed
50
+ if auth is not None:
51
+ if isinstance(auth, str):
52
+ self.headers["Authorization"] = f"Bearer {auth}"
53
+ else:
54
+ logger.warning("WebSocket connector only supports bearer token authentication")
46
55
 
47
56
  self.ws: ClientConnection | None = None
48
57
  self._connection_manager: ConnectionManager | None = None
@@ -245,4 +254,4 @@ class WebSocketConnector(BaseConnector):
245
254
  @property
246
255
  def public_identifier(self) -> str:
247
256
  """Get the identifier for the connector."""
248
- return {"type": "websocket", "url": self.url}
257
+ return f"websocket:{self.url}"
mcp_use/exceptions.py ADDED
@@ -0,0 +1,31 @@
1
+ """MCP-use exceptions."""
2
+
3
+
4
+ class MCPError(Exception):
5
+ """Base exception for MCP-use."""
6
+
7
+ pass
8
+
9
+
10
+ class OAuthDiscoveryError(MCPError):
11
+ """OAuth discovery auth metadata error"""
12
+
13
+ pass
14
+
15
+
16
+ class OAuthAuthenticationError(MCPError):
17
+ """OAuth authentication-related errors"""
18
+
19
+ pass
20
+
21
+
22
+ class ConnectionError(MCPError):
23
+ """Connection-related errors."""
24
+
25
+ pass
26
+
27
+
28
+ class ConfigurationError(MCPError):
29
+ """Configuration-related errors."""
30
+
31
+ pass
@@ -0,0 +1,50 @@
1
+ """
2
+ Middleware package for MCP request interception and processing.
3
+
4
+ This package provides a flexible middleware system for intercepting MCP requests
5
+ and responses, enabling logging, metrics, caching, and custom processing.
6
+
7
+ The middleware system follows an Express.js-style pattern where middleware functions
8
+ receive a request context and a call_next function, allowing them to process both
9
+ incoming requests and outgoing responses.
10
+ """
11
+
12
+ # Core middleware implementation
13
+ # Default logging middleware
14
+ from .logging import default_logging_middleware
15
+
16
+ # Metrics middleware classes
17
+ from .metrics import (
18
+ CombinedAnalyticsMiddleware,
19
+ ErrorTrackingMiddleware,
20
+ MetricsMiddleware,
21
+ PerformanceMetricsMiddleware,
22
+ )
23
+
24
+ # Protocol types for type-safe middleware
25
+ from .middleware import (
26
+ CallbackClientSession,
27
+ MCPResponseContext,
28
+ Middleware,
29
+ MiddlewareContext,
30
+ MiddlewareManager,
31
+ NextFunctionT,
32
+ )
33
+
34
+ __all__ = [
35
+ # Core types and classes
36
+ "MiddlewareContext",
37
+ "MCPResponseContext",
38
+ "Middleware",
39
+ "MiddlewareManager",
40
+ "CallbackClientSession",
41
+ # Protocol types
42
+ "NextFunctionT",
43
+ # Default logging middleware
44
+ "default_logging_middleware",
45
+ # Metrics middleware
46
+ "MetricsMiddleware",
47
+ "PerformanceMetricsMiddleware",
48
+ "ErrorTrackingMiddleware",
49
+ "CombinedAnalyticsMiddleware",
50
+ ]
@@ -0,0 +1,31 @@
1
+ """
2
+ Default logging middleware for MCP requests.
3
+
4
+ Simple debug logging for all MCP requests and responses.
5
+ """
6
+
7
+ import time
8
+ from typing import Any
9
+
10
+ from ..logging import logger
11
+ from .middleware import Middleware, MiddlewareContext, NextFunctionT
12
+
13
+
14
+ class LoggingMiddleware(Middleware):
15
+ """Default logging middleware that logs all MCP requests and responses with logger.debug."""
16
+
17
+ async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
18
+ """Logs all MCP requests and responses with logger.debug."""
19
+ logger.debug(f"[{context.id}] {context.connection_id} -> {context.method}")
20
+ try:
21
+ result = await call_next(context)
22
+ duration = time.time() - context.timestamp
23
+ logger.debug(f"[{context.id}] {context.connection_id} <- {context.method} ({duration:.3f}s)")
24
+ return result
25
+ except Exception as e:
26
+ duration = time.time() - context.timestamp
27
+ logger.debug(f"[{context.id}] {context.connection_id} <- {context.method} FAILED ({duration:.3f}s): {e}")
28
+ raise
29
+
30
+
31
+ default_logging_middleware = LoggingMiddleware()