mcp-use 1.3.10__py3-none-any.whl → 1.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-use might be problematic. Click here for more details.
- mcp_use/adapters/langchain_adapter.py +9 -52
- mcp_use/agents/mcpagent.py +88 -37
- mcp_use/agents/prompts/templates.py +1 -10
- mcp_use/agents/remote.py +154 -128
- mcp_use/auth/__init__.py +6 -0
- mcp_use/auth/bearer.py +17 -0
- mcp_use/auth/oauth.py +625 -0
- mcp_use/auth/oauth_callback.py +214 -0
- mcp_use/client.py +25 -1
- mcp_use/config.py +7 -2
- mcp_use/connectors/base.py +25 -12
- mcp_use/connectors/http.py +135 -27
- mcp_use/connectors/sandbox.py +12 -3
- mcp_use/connectors/stdio.py +11 -3
- mcp_use/connectors/websocket.py +15 -6
- mcp_use/exceptions.py +31 -0
- mcp_use/middleware/__init__.py +50 -0
- mcp_use/middleware/logging.py +31 -0
- mcp_use/middleware/metrics.py +314 -0
- mcp_use/middleware/middleware.py +262 -0
- mcp_use/task_managers/base.py +13 -23
- mcp_use/task_managers/sse.py +5 -0
- mcp_use/task_managers/streamable_http.py +5 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/METADATA +21 -25
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/RECORD +28 -19
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/WHEEL +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/entry_points.txt +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/licenses/LICENSE +0 -0
mcp_use/connectors/http.py
CHANGED
|
@@ -5,11 +5,19 @@ This module provides a connector for communicating with MCP implementations
|
|
|
5
5
|
through HTTP APIs with SSE or Streamable HTTP for transport.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
8
10
|
import httpx
|
|
9
11
|
from mcp import ClientSession
|
|
10
12
|
from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
|
13
|
+
from mcp.shared.exceptions import McpError
|
|
14
|
+
|
|
15
|
+
from mcp_use.auth.oauth import OAuthClientProvider
|
|
11
16
|
|
|
17
|
+
from ..auth import BearerAuth, OAuth
|
|
18
|
+
from ..exceptions import OAuthAuthenticationError, OAuthDiscoveryError
|
|
12
19
|
from ..logging import logger
|
|
20
|
+
from ..middleware import CallbackClientSession, Middleware
|
|
13
21
|
from ..task_managers import SseConnectionManager, StreamableHttpConnectionManager
|
|
14
22
|
from .base import BaseConnector
|
|
15
23
|
|
|
@@ -24,23 +32,27 @@ class HttpConnector(BaseConnector):
|
|
|
24
32
|
def __init__(
|
|
25
33
|
self,
|
|
26
34
|
base_url: str,
|
|
27
|
-
auth_token: str | None = None,
|
|
28
35
|
headers: dict[str, str] | None = None,
|
|
29
36
|
timeout: float = 5,
|
|
30
37
|
sse_read_timeout: float = 60 * 5,
|
|
38
|
+
auth: str | dict[str, Any] | httpx.Auth | None = None,
|
|
31
39
|
sampling_callback: SamplingFnT | None = None,
|
|
32
40
|
elicitation_callback: ElicitationFnT | None = None,
|
|
33
41
|
message_handler: MessageHandlerFnT | None = None,
|
|
34
42
|
logging_callback: LoggingFnT | None = None,
|
|
43
|
+
middleware: list[Middleware] | None = None,
|
|
35
44
|
):
|
|
36
45
|
"""Initialize a new HTTP connector.
|
|
37
46
|
|
|
38
47
|
Args:
|
|
39
48
|
base_url: The base URL of the MCP HTTP API.
|
|
40
|
-
auth_token: Optional authentication token.
|
|
41
49
|
headers: Optional additional headers.
|
|
42
50
|
timeout: Timeout for HTTP operations in seconds.
|
|
43
51
|
sse_read_timeout: Timeout for SSE read operations in seconds.
|
|
52
|
+
auth: Authentication method - can be:
|
|
53
|
+
- A string token: Use Bearer token authentication
|
|
54
|
+
- A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."}
|
|
55
|
+
- An httpx.Auth object: Use custom authentication
|
|
44
56
|
sampling_callback: Optional sampling callback.
|
|
45
57
|
elicitation_callback: Optional elicitation callback.
|
|
46
58
|
"""
|
|
@@ -49,14 +61,60 @@ class HttpConnector(BaseConnector):
|
|
|
49
61
|
elicitation_callback=elicitation_callback,
|
|
50
62
|
message_handler=message_handler,
|
|
51
63
|
logging_callback=logging_callback,
|
|
64
|
+
middleware=middleware,
|
|
52
65
|
)
|
|
53
66
|
self.base_url = base_url.rstrip("/")
|
|
54
|
-
self.auth_token = auth_token
|
|
55
67
|
self.headers = headers or {}
|
|
56
|
-
if auth_token:
|
|
57
|
-
self.headers["Authorization"] = f"Bearer {auth_token}"
|
|
58
68
|
self.timeout = timeout
|
|
59
69
|
self.sse_read_timeout = sse_read_timeout
|
|
70
|
+
self._auth: httpx.Auth | None = None
|
|
71
|
+
self._oauth: OAuth | None = None
|
|
72
|
+
|
|
73
|
+
# Handle authentication
|
|
74
|
+
if auth is not None:
|
|
75
|
+
self._set_auth(auth)
|
|
76
|
+
|
|
77
|
+
def _set_auth(self, auth: str | dict[str, Any] | httpx.Auth) -> None:
    """Configure how this connector authenticates its HTTP requests.

    Args:
        auth: Authentication method - can be:
            - A string token: Bearer token authentication. A ``BearerAuth``
              is stored in ``self._auth`` and an ``Authorization: Bearer ...``
              header is added to ``self.headers``.
            - A dict with OAuth config, e.g. ``{"client_id": "...",
              "client_secret": "...", "scope": "...", "callback_port": ...}``.
              An optional ``"oauth_provider"`` entry may hold either a
              provider object or a dict of ``OAuthClientProvider`` keyword
              arguments. An ``OAuth`` helper is stored in ``self._oauth`` and
              the raw dict in ``self._oauth_config``; no header is set here.
            - An httpx.Auth object: stored in ``self._auth`` and used as-is.

    Raises:
        ValueError: If ``auth`` is none of the supported types.
    """
    if isinstance(auth, str):
        # Treat as bearer token
        self._auth = BearerAuth(token=auth)
        # Rebind rather than mutate: ``self.headers`` may be the exact dict the
        # caller passed to ``__init__`` (``headers or {}``), and writing the
        # token into it would leak the credential into the caller's object.
        self.headers = {**self.headers, "Authorization": f"Bearer {auth}"}
    elif isinstance(auth, dict):
        oauth_kwargs: dict[str, Any] = {
            "scope": auth.get("scope"),
            "client_id": auth.get("client_id"),
            "client_secret": auth.get("client_secret"),
            "callback_port": auth.get("callback_port"),
        }
        # Forward ``oauth_provider`` only when the key is present so that
        # OAuth keeps its own default otherwise.
        if "oauth_provider" in auth:
            oauth_provider = auth["oauth_provider"]
            if isinstance(oauth_provider, dict):
                oauth_provider = OAuthClientProvider(**oauth_provider)
            oauth_kwargs["oauth_provider"] = oauth_provider
        self._oauth = OAuth(self.base_url, **oauth_kwargs)
        self._oauth_config = auth
    elif isinstance(auth, httpx.Auth):
        self._auth = auth
    else:
        raise ValueError(f"Invalid auth type: {type(auth)}")
|
|
60
118
|
|
|
61
119
|
async def connect(self) -> None:
|
|
62
120
|
"""Establish a connection to the MCP implementation."""
|
|
@@ -64,6 +122,29 @@ class HttpConnector(BaseConnector):
|
|
|
64
122
|
logger.debug("Already connected to MCP implementation")
|
|
65
123
|
return
|
|
66
124
|
|
|
125
|
+
# Handle OAuth if needed
|
|
126
|
+
if self._oauth:
|
|
127
|
+
try:
|
|
128
|
+
# Create a temporary client for OAuth metadata discovery
|
|
129
|
+
async with httpx.AsyncClient() as client:
|
|
130
|
+
bearer_auth = await self._oauth.initialize(client)
|
|
131
|
+
if not bearer_auth:
|
|
132
|
+
# Need to perform OAuth flow
|
|
133
|
+
logger.info("OAuth authentication required")
|
|
134
|
+
bearer_auth = await self._oauth.authenticate()
|
|
135
|
+
|
|
136
|
+
# Update auth and headers
|
|
137
|
+
self._auth = bearer_auth
|
|
138
|
+
self.headers["Authorization"] = f"Bearer {bearer_auth.token.get_secret_value()}"
|
|
139
|
+
except OAuthDiscoveryError:
|
|
140
|
+
# OAuth discovery failed - it means server doesn't support OAuth default urls
|
|
141
|
+
logger.debug("OAuth discovery failed, continuing without initialization.")
|
|
142
|
+
self._oauth = None
|
|
143
|
+
self._auth = None
|
|
144
|
+
except OAuthAuthenticationError as e:
|
|
145
|
+
logger.error(f"OAuth initialization failed: {e}")
|
|
146
|
+
raise
|
|
147
|
+
|
|
67
148
|
# Try streamable HTTP first (new transport), fall back to SSE (old transport)
|
|
68
149
|
# This implements backwards compatibility per MCP specification
|
|
69
150
|
self.transport_type = None
|
|
@@ -73,14 +154,14 @@ class HttpConnector(BaseConnector):
|
|
|
73
154
|
# First, try the new streamable HTTP transport
|
|
74
155
|
logger.debug(f"Attempting streamable HTTP connection to: {self.base_url}")
|
|
75
156
|
connection_manager = StreamableHttpConnectionManager(
|
|
76
|
-
self.base_url, self.headers, self.timeout, self.sse_read_timeout
|
|
157
|
+
self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
|
|
77
158
|
)
|
|
78
159
|
|
|
79
160
|
# Test if this is a streamable HTTP server by attempting initialization
|
|
80
161
|
read_stream, write_stream = await connection_manager.start()
|
|
81
162
|
|
|
82
163
|
# Test if this actually works by trying to create a client session and initialize it
|
|
83
|
-
|
|
164
|
+
raw_test_client = ClientSession(
|
|
84
165
|
read_stream,
|
|
85
166
|
write_stream,
|
|
86
167
|
sampling_callback=self.sampling_callback,
|
|
@@ -89,14 +170,17 @@ class HttpConnector(BaseConnector):
|
|
|
89
170
|
logging_callback=self.logging_callback,
|
|
90
171
|
client_info=self.client_info,
|
|
91
172
|
)
|
|
92
|
-
await
|
|
173
|
+
await raw_test_client.__aenter__()
|
|
174
|
+
|
|
175
|
+
# Wrap test client with middleware temporarily for testing
|
|
176
|
+
test_client = CallbackClientSession(raw_test_client, self.public_identifier, self.middleware_manager)
|
|
93
177
|
|
|
94
178
|
try:
|
|
95
179
|
# Try to initialize - this is where streamable HTTP vs SSE difference should show up
|
|
96
180
|
result = await test_client.initialize()
|
|
181
|
+
logger.debug(f"Streamable HTTP initialization result: {result}")
|
|
97
182
|
|
|
98
183
|
# If we get here, streamable HTTP works
|
|
99
|
-
|
|
100
184
|
self.client_session = test_client
|
|
101
185
|
self.transport_type = "streamable HTTP"
|
|
102
186
|
self._initialized = True # Mark as initialized since we just called initialize()
|
|
@@ -125,14 +209,28 @@ class HttpConnector(BaseConnector):
|
|
|
125
209
|
else:
|
|
126
210
|
self._prompts = []
|
|
127
211
|
|
|
128
|
-
|
|
212
|
+
# Only McpError is raised from client's initialization because
|
|
213
|
+
# exceptions are handled internally.
|
|
214
|
+
except McpError as mcp_error:
|
|
215
|
+
logger.error("MCP protocol error during initialization: %s", mcp_error.error)
|
|
129
216
|
# Clean up the test client
|
|
130
217
|
try:
|
|
131
|
-
await
|
|
218
|
+
await raw_test_client.__aexit__(None, None, None)
|
|
219
|
+
except Exception:
|
|
220
|
+
pass
|
|
221
|
+
raise mcp_error
|
|
222
|
+
|
|
223
|
+
except Exception as init_error:
|
|
224
|
+
# This catches non-McpError exceptions, like a direct httpx timeout
|
|
225
|
+
# but in the most cases this won't happen. It's for safety.
|
|
226
|
+
try:
|
|
227
|
+
await raw_test_client.__aexit__(None, None, None)
|
|
132
228
|
except Exception:
|
|
133
229
|
pass
|
|
134
230
|
raise init_error
|
|
135
231
|
|
|
232
|
+
# Exception from the inner try is propagated here and in
|
|
233
|
+
# the most cases is an McpError, so checking instances is useless
|
|
136
234
|
except Exception as streamable_error:
|
|
137
235
|
logger.debug(f"Streamable HTTP failed: {streamable_error}")
|
|
138
236
|
|
|
@@ -143,30 +241,23 @@ class HttpConnector(BaseConnector):
|
|
|
143
241
|
except Exception:
|
|
144
242
|
pass
|
|
145
243
|
|
|
146
|
-
#
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
should_fallback = True
|
|
151
|
-
elif "405 Method Not Allowed" in str(streamable_error) or "404 Not Found" in str(streamable_error):
|
|
152
|
-
should_fallback = True
|
|
153
|
-
else:
|
|
154
|
-
# For other errors, still try fallback but they might indicate
|
|
155
|
-
# real connectivity issues
|
|
156
|
-
should_fallback = True
|
|
244
|
+
# It doesn't make sense to check error types. Because client
|
|
245
|
+
# always return a McpError, if he can't reach the server
|
|
246
|
+
# because it's offline, or if it has an auth problem.
|
|
247
|
+
should_fallback = True
|
|
157
248
|
|
|
158
249
|
if should_fallback:
|
|
159
250
|
try:
|
|
160
251
|
# Fall back to the old SSE transport
|
|
161
252
|
logger.debug(f"Attempting SSE fallback connection to: {self.base_url}")
|
|
162
253
|
connection_manager = SseConnectionManager(
|
|
163
|
-
self.base_url, self.headers, self.timeout, self.sse_read_timeout
|
|
254
|
+
self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
|
|
164
255
|
)
|
|
165
256
|
|
|
166
257
|
read_stream, write_stream = await connection_manager.start()
|
|
167
258
|
|
|
168
259
|
# Create the client session for SSE
|
|
169
|
-
|
|
260
|
+
raw_client_session = ClientSession(
|
|
170
261
|
read_stream,
|
|
171
262
|
write_stream,
|
|
172
263
|
sampling_callback=self.sampling_callback,
|
|
@@ -175,10 +266,26 @@ class HttpConnector(BaseConnector):
|
|
|
175
266
|
logging_callback=self.logging_callback,
|
|
176
267
|
client_info=self.client_info,
|
|
177
268
|
)
|
|
178
|
-
await
|
|
269
|
+
await raw_client_session.__aenter__()
|
|
270
|
+
|
|
271
|
+
# Wrap with middleware
|
|
272
|
+
self.client_session = CallbackClientSession(
|
|
273
|
+
raw_client_session, self.public_identifier, self.middleware_manager
|
|
274
|
+
)
|
|
179
275
|
self.transport_type = "SSE"
|
|
180
276
|
|
|
181
|
-
except Exception as sse_error:
|
|
277
|
+
except* Exception as sse_error:
|
|
278
|
+
# Get the exception from the ExceptionGroup, and here we will get the correct type.
|
|
279
|
+
sse_error = sse_error.exceptions[0]
|
|
280
|
+
if isinstance(sse_error, httpx.HTTPStatusError) and sse_error.response.status_code in [
|
|
281
|
+
401,
|
|
282
|
+
403,
|
|
283
|
+
407,
|
|
284
|
+
]:
|
|
285
|
+
raise OAuthAuthenticationError(
|
|
286
|
+
f"Server requires authentication (HTTP {sse_error.response.status_code}) "
|
|
287
|
+
"but auth failed. Please provide auth configuration manually."
|
|
288
|
+
) from sse_error
|
|
182
289
|
logger.error(
|
|
183
290
|
f"Both transport methods failed. Streamable HTTP: {streamable_error}, SSE: {sse_error}"
|
|
184
291
|
)
|
|
@@ -194,4 +301,5 @@ class HttpConnector(BaseConnector):
|
|
|
194
301
|
@property
def public_identifier(self) -> str:
    """Return an identifier of the form ``"<transport>:<base_url>"``.

    ``connect()`` resets ``transport_type`` to ``None`` before probing the
    transports. It then reads this property while wrapping the client session
    with middleware, which happens before the final transport name is
    assigned. A ``getattr`` default alone would therefore produce
    ``"None:<url>"``. Fall back to ``"http"`` whenever the attribute is unset
    or falsy.
    """
    transport_type = getattr(self, "transport_type", None) or "http"
    return f"{transport_type}:{self.base_url}"
|
mcp_use/connectors/sandbox.py
CHANGED
|
@@ -15,6 +15,7 @@ from mcp import ClientSession
|
|
|
15
15
|
from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
|
16
16
|
|
|
17
17
|
from ..logging import logger
|
|
18
|
+
from ..middleware import CallbackClientSession, Middleware
|
|
18
19
|
from ..task_managers import SseConnectionManager
|
|
19
20
|
|
|
20
21
|
# Import E2B SDK components (optional dependency)
|
|
@@ -52,6 +53,7 @@ class SandboxConnector(BaseConnector):
|
|
|
52
53
|
elicitation_callback: ElicitationFnT | None = None,
|
|
53
54
|
message_handler: MessageHandlerFnT | None = None,
|
|
54
55
|
logging_callback: LoggingFnT | None = None,
|
|
56
|
+
middleware: list[Middleware] | None = None,
|
|
55
57
|
):
|
|
56
58
|
"""Initialize a new sandbox connector.
|
|
57
59
|
|
|
@@ -71,6 +73,7 @@ class SandboxConnector(BaseConnector):
|
|
|
71
73
|
elicitation_callback=elicitation_callback,
|
|
72
74
|
message_handler=message_handler,
|
|
73
75
|
logging_callback=logging_callback,
|
|
76
|
+
middleware=middleware,
|
|
74
77
|
)
|
|
75
78
|
if Sandbox is None:
|
|
76
79
|
raise ImportError(
|
|
@@ -226,7 +229,7 @@ class SandboxConnector(BaseConnector):
|
|
|
226
229
|
read_stream, write_stream = await self._connection_manager.start()
|
|
227
230
|
|
|
228
231
|
# Create the client session
|
|
229
|
-
|
|
232
|
+
raw_client_session = ClientSession(
|
|
230
233
|
read_stream,
|
|
231
234
|
write_stream,
|
|
232
235
|
sampling_callback=self.sampling_callback,
|
|
@@ -235,7 +238,12 @@ class SandboxConnector(BaseConnector):
|
|
|
235
238
|
logging_callback=self.logging_callback,
|
|
236
239
|
client_info=self.client_info,
|
|
237
240
|
)
|
|
238
|
-
await
|
|
241
|
+
await raw_client_session.__aenter__()
|
|
242
|
+
|
|
243
|
+
# Wrap with middleware
|
|
244
|
+
self.client_session = CallbackClientSession(
|
|
245
|
+
raw_client_session, self.public_identifier, self.middleware_manager
|
|
246
|
+
)
|
|
239
247
|
|
|
240
248
|
# Mark as connected
|
|
241
249
|
self._connected = True
|
|
@@ -299,4 +307,5 @@ class SandboxConnector(BaseConnector):
|
|
|
299
307
|
@property
def public_identifier(self) -> str:
    """Build ``"sandbox:<command> <args...>"`` with surrounding whitespace trimmed.

    Missing or empty ``user_args`` contribute nothing, so the identifier is
    then just ``"sandbox:<command>"``.
    """
    joined_args = " ".join(self.user_args or ())
    identifier = f"sandbox:{self.user_command} {joined_args}"
    return identifier.strip()
|
mcp_use/connectors/stdio.py
CHANGED
|
@@ -11,6 +11,7 @@ from mcp import ClientSession, StdioServerParameters
|
|
|
11
11
|
from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
|
12
12
|
|
|
13
13
|
from ..logging import logger
|
|
14
|
+
from ..middleware import CallbackClientSession, Middleware
|
|
14
15
|
from ..task_managers import StdioConnectionManager
|
|
15
16
|
from .base import BaseConnector
|
|
16
17
|
|
|
@@ -33,6 +34,7 @@ class StdioConnector(BaseConnector):
|
|
|
33
34
|
elicitation_callback: ElicitationFnT | None = None,
|
|
34
35
|
message_handler: MessageHandlerFnT | None = None,
|
|
35
36
|
logging_callback: LoggingFnT | None = None,
|
|
37
|
+
middleware: list[Middleware] | None = None,
|
|
36
38
|
):
|
|
37
39
|
"""Initialize a new stdio connector.
|
|
38
40
|
|
|
@@ -49,6 +51,7 @@ class StdioConnector(BaseConnector):
|
|
|
49
51
|
elicitation_callback=elicitation_callback,
|
|
50
52
|
message_handler=message_handler,
|
|
51
53
|
logging_callback=logging_callback,
|
|
54
|
+
middleware=middleware,
|
|
52
55
|
)
|
|
53
56
|
self.command = command
|
|
54
57
|
self.args = args or [] # Ensure args is never None
|
|
@@ -71,7 +74,7 @@ class StdioConnector(BaseConnector):
|
|
|
71
74
|
read_stream, write_stream = await self._connection_manager.start()
|
|
72
75
|
|
|
73
76
|
# Create the client session
|
|
74
|
-
|
|
77
|
+
raw_client_session = ClientSession(
|
|
75
78
|
read_stream,
|
|
76
79
|
write_stream,
|
|
77
80
|
sampling_callback=self.sampling_callback,
|
|
@@ -80,7 +83,12 @@ class StdioConnector(BaseConnector):
|
|
|
80
83
|
logging_callback=self.logging_callback,
|
|
81
84
|
client_info=self.client_info,
|
|
82
85
|
)
|
|
83
|
-
await
|
|
86
|
+
await raw_client_session.__aenter__()
|
|
87
|
+
|
|
88
|
+
# Wrap with middleware
|
|
89
|
+
self.client_session = CallbackClientSession(
|
|
90
|
+
raw_client_session, self.public_identifier, self.middleware_manager
|
|
91
|
+
)
|
|
84
92
|
|
|
85
93
|
# Mark as connected
|
|
86
94
|
self._connected = True
|
|
@@ -98,4 +106,4 @@ class StdioConnector(BaseConnector):
|
|
|
98
106
|
@property
def public_identifier(self) -> str:
    """Return an identifier of the form ``"stdio:<command> <args...>"``.

    The command and its arguments are joined with single spaces. With no
    arguments the result is just ``"stdio:<command>"`` and carries no
    dangling trailing space. This matches the stripped format used by the
    sandbox connector.
    """
    command_line = " ".join([f"{self.command}", *self.args])
    return f"stdio:{command_line}"
|
mcp_use/connectors/websocket.py
CHANGED
|
@@ -10,6 +10,7 @@ import json
|
|
|
10
10
|
import uuid
|
|
11
11
|
from typing import Any
|
|
12
12
|
|
|
13
|
+
import httpx
|
|
13
14
|
from mcp.types import Tool
|
|
14
15
|
from websockets import ClientConnection
|
|
15
16
|
|
|
@@ -28,21 +29,29 @@ class WebSocketConnector(BaseConnector):
|
|
|
28
29
|
def __init__(
|
|
29
30
|
self,
|
|
30
31
|
url: str,
|
|
31
|
-
auth_token: str | None = None,
|
|
32
32
|
headers: dict[str, str] | None = None,
|
|
33
|
+
auth: str | dict[str, Any] | httpx.Auth | None = None,
|
|
33
34
|
):
|
|
34
35
|
"""Initialize a new WebSocket connector.
|
|
35
36
|
|
|
36
37
|
Args:
|
|
37
38
|
url: The WebSocket URL to connect to.
|
|
38
|
-
auth_token: Optional authentication token.
|
|
39
39
|
headers: Optional additional headers.
|
|
40
|
+
auth: Authentication method - can be:
|
|
41
|
+
- A string token: Use Bearer token authentication
|
|
42
|
+
- A dict: Not supported for WebSocket (will log warning)
|
|
43
|
+
- An httpx.Auth object: Not supported for WebSocket (will log warning)
|
|
40
44
|
"""
|
|
41
45
|
self.url = url
|
|
42
|
-
self.auth_token = auth_token
|
|
43
46
|
self.headers = headers or {}
|
|
44
|
-
|
|
45
|
-
|
|
47
|
+
|
|
48
|
+
# Handle authentication - WebSocket only supports bearer tokens
|
|
49
|
+
# An auth field it's not needed
|
|
50
|
+
if auth is not None:
|
|
51
|
+
if isinstance(auth, str):
|
|
52
|
+
self.headers["Authorization"] = f"Bearer {auth}"
|
|
53
|
+
else:
|
|
54
|
+
logger.warning("WebSocket connector only supports bearer token authentication")
|
|
46
55
|
|
|
47
56
|
self.ws: ClientConnection | None = None
|
|
48
57
|
self._connection_manager: ConnectionManager | None = None
|
|
@@ -245,4 +254,4 @@ class WebSocketConnector(BaseConnector):
|
|
|
245
254
|
@property
def public_identifier(self) -> str:
    """Identify this connector as ``"websocket:<url>"``."""
    return "websocket:{}".format(self.url)
|
mcp_use/exceptions.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""MCP-use exceptions."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class MCPError(Exception):
    """Root of the mcp-use exception hierarchy; catch this to handle any library error."""


class OAuthDiscoveryError(MCPError):
    """Error while discovering the server's OAuth authorization metadata."""


class OAuthAuthenticationError(MCPError):
    """Error while obtaining or using OAuth credentials, or when a server rejects authentication."""


# NOTE: this name shadows the builtin ``ConnectionError`` for any module that
# imports it directly. It derives from ``MCPError``, not ``OSError``, so
# ``except OSError`` will not catch it.
class ConnectionError(MCPError):
    """Error related to establishing or maintaining a connection."""


class ConfigurationError(MCPError):
    """Error caused by invalid or incomplete configuration."""
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Middleware package for MCP request interception and processing.
|
|
3
|
+
|
|
4
|
+
This package provides a flexible middleware system for intercepting MCP requests
|
|
5
|
+
and responses, enabling logging, metrics, caching, and custom processing.
|
|
6
|
+
|
|
7
|
+
The middleware system follows an Express.js-style pattern where middleware functions
|
|
8
|
+
receive a request context and a call_next function, allowing them to process both
|
|
9
|
+
incoming requests and outgoing responses.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
# Core middleware implementation
|
|
13
|
+
# Default logging middleware
|
|
14
|
+
from .logging import default_logging_middleware
|
|
15
|
+
|
|
16
|
+
# Metrics middleware classes
|
|
17
|
+
from .metrics import (
|
|
18
|
+
CombinedAnalyticsMiddleware,
|
|
19
|
+
ErrorTrackingMiddleware,
|
|
20
|
+
MetricsMiddleware,
|
|
21
|
+
PerformanceMetricsMiddleware,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Protocol types for type-safe middleware
|
|
25
|
+
from .middleware import (
|
|
26
|
+
CallbackClientSession,
|
|
27
|
+
MCPResponseContext,
|
|
28
|
+
Middleware,
|
|
29
|
+
MiddlewareContext,
|
|
30
|
+
MiddlewareManager,
|
|
31
|
+
NextFunctionT,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Names re-exported as the public API of the middleware package.
__all__ = [
    # Request/response context types, the middleware base class, the manager
    # and the middleware-aware client session wrapper
    "MiddlewareContext",
    "MCPResponseContext",
    "Middleware",
    "MiddlewareManager",
    "CallbackClientSession",
    # Type of the ``call_next`` continuation handed to each middleware
    "NextFunctionT",
    # Ready-made debug logging middleware instance
    "default_logging_middleware",
    # Metrics / analytics middleware classes
    "MetricsMiddleware",
    "PerformanceMetricsMiddleware",
    "ErrorTrackingMiddleware",
    "CombinedAnalyticsMiddleware",
]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Default logging middleware for MCP requests.
|
|
3
|
+
|
|
4
|
+
Simple debug logging for all MCP requests and responses.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from ..logging import logger
|
|
11
|
+
from .middleware import Middleware, MiddlewareContext, NextFunctionT
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LoggingMiddleware(Middleware):
    """Trace every MCP request and its outcome via ``logger.debug``.

    An outgoing line (``->``) is logged before delegating to the rest of the
    chain. An incoming line (``<-``) follows, with the elapsed seconds computed
    as ``time.time() - context.timestamp`` (assumes ``timestamp`` is a
    ``time.time()`` value; TODO confirm). Failures are logged with ``FAILED``
    and the original exception is re-raised.
    """

    async def on_request(self, context: MiddlewareContext[Any], call_next: NextFunctionT) -> Any:
        """Log the request, await the next handler, then log success or failure."""
        logger.debug(f"[{context.id}] {context.connection_id} -> {context.method}")
        try:
            response = await call_next(context)
            elapsed = time.time() - context.timestamp
            logger.debug(f"[{context.id}] {context.connection_id} <- {context.method} ({elapsed:.3f}s)")
        except Exception as exc:
            elapsed = time.time() - context.timestamp
            logger.debug(
                f"[{context.id}] {context.connection_id} <- {context.method} FAILED ({elapsed:.3f}s): {exc}"
            )
            raise
        return response


# Module-level instance exported for direct reuse.
default_logging_middleware = LoggingMiddleware()
|