mcp-use 1.3.10__py3-none-any.whl → 1.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-use might be problematic. Click here for more details.
- mcp_use/adapters/langchain_adapter.py +7 -5
- mcp_use/agents/mcpagent.py +16 -4
- mcp_use/agents/prompts/templates.py +1 -10
- mcp_use/auth/__init__.py +6 -0
- mcp_use/auth/bearer.py +17 -0
- mcp_use/auth/oauth.py +625 -0
- mcp_use/auth/oauth_callback.py +214 -0
- mcp_use/client.py +1 -1
- mcp_use/config.py +2 -2
- mcp_use/connectors/base.py +17 -12
- mcp_use/connectors/http.py +117 -21
- mcp_use/connectors/websocket.py +14 -5
- mcp_use/exceptions.py +31 -0
- mcp_use/task_managers/base.py +13 -23
- mcp_use/task_managers/sse.py +5 -0
- mcp_use/task_managers/streamable_http.py +5 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.11.dist-info}/METADATA +19 -24
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.11.dist-info}/RECORD +21 -16
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.11.dist-info}/WHEEL +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.11.dist-info}/entry_points.txt +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.11.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
"""OAuth callback server implementation."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
import anyio
|
|
7
|
+
import uvicorn
|
|
8
|
+
from starlette.applications import Starlette
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
from starlette.responses import HTMLResponse
|
|
11
|
+
from starlette.routing import Route
|
|
12
|
+
|
|
13
|
+
from ..logging import logger
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
|
|
17
|
+
class CallbackResponse:
|
|
18
|
+
"""Response data from OAuth callback."""
|
|
19
|
+
|
|
20
|
+
code: str | None = None # Authorization code (success)
|
|
21
|
+
state: str | None = None # CSRF protection token
|
|
22
|
+
error: str | None = None # Errors code (if failed)
|
|
23
|
+
error_description: str | None = None
|
|
24
|
+
error_uri: str | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OAuthCallbackServer:
    """Local HTTP server that receives the OAuth authorization redirect.

    The server listens on 127.0.0.1:<port> and exposes a single ``/callback``
    route. The first callback's query parameters are handed to
    :meth:`wait_for_code` through a single-slot queue.
    """

    def __init__(self, port: int):
        """Initialize the callback server.

        Args:
            port: Port to listen on.
        """
        self.port = port
        self.redirect_uri: str | None = None
        # Single-slot queue passing the callback result from the request handler
        # to wait_for_code(). asyncio.Queue is safe between coroutines on the same
        # event loop; it is not a cross-thread primitive.
        self.response_queue: asyncio.Queue[CallbackResponse] = asyncio.Queue(maxsize=1)
        self.server: uvicorn.Server | None = None
        # Background task running uvicorn's serve(); None until start() is called.
        self._server_task: asyncio.Task[None] | None = None
        self._shutdown_event = anyio.Event()

    async def start(self, startup_timeout: float = 5.0) -> str:
        """Start the callback server in the background and return the redirect URI.

        Args:
            startup_timeout: Maximum seconds to wait for uvicorn to report that it
                has started. If it has not started by then, a warning is logged
                and the redirect URI is returned anyway.

        Returns:
            The redirect URI, ``http://localhost:<port>/callback``.

        Raises:
            RuntimeError: If the server task finishes before uvicorn reports that
                it has started (nothing would then be listening for the callback).
        """
        app = self._create_app()

        # Create the server
        config = uvicorn.Config(
            app,
            host="127.0.0.1",
            port=self.port,
            log_level="error",  # Suppress uvicorn logs
        )
        self.server = uvicorn.Server(config)

        # Start server in background
        self._server_task = asyncio.create_task(self.server.serve())

        # Poll uvicorn's ``started`` flag instead of sleeping a fixed amount, so a
        # slow start is waited for and an early exit is reported immediately
        # instead of surfacing later as a callback timeout.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + startup_timeout
        while not getattr(self.server, "started", False):
            if self._server_task.done():
                cause = None if self._server_task.cancelled() else self._server_task.exception()
                raise RuntimeError(
                    f"OAuth callback server on port {self.port} exited before it started"
                ) from cause
            if loop.time() >= deadline:
                logger.warning(
                    f"OAuth callback server on port {self.port} did not report startup "
                    f"within {startup_timeout} seconds"
                )
                break
            await asyncio.sleep(0.05)

        self.redirect_uri = f"http://localhost:{self.port}/callback"
        return self.redirect_uri

    async def wait_for_code(self, timeout: float = 300) -> CallbackResponse:
        """Wait for the OAuth callback with a timeout (default 5 minutes).

        The server is shut down afterwards in every case (result, timeout, or
        cancellation).

        Args:
            timeout: Seconds to wait for the callback.

        Returns:
            The parsed callback parameters.

        Raises:
            TimeoutError: If no callback arrives within ``timeout`` seconds.
        """
        try:
            return await asyncio.wait_for(self.response_queue.get(), timeout=timeout)
        # asyncio.TimeoutError is only an alias of the builtin from Python 3.11 on.
        except (TimeoutError, asyncio.TimeoutError):
            raise TimeoutError(f"OAuth callback not received within {timeout} seconds") from None
        finally:
            await self.shutdown()

    async def shutdown(self) -> None:
        """Shutdown the callback server.

        Safe to call more than once. Errors raised by the server task are logged
        rather than propagated, so they cannot mask the result of
        :meth:`wait_for_code`, which calls this from its ``finally`` block.
        """
        self._shutdown_event.set()
        if self.server:
            # Ask uvicorn's main loop to exit gracefully.
            self.server.should_exit = True

        task = self._server_task
        if task is None:
            return

        if not task.done():
            try:
                await asyncio.wait_for(task, timeout=5.0)
            except (TimeoutError, asyncio.TimeoutError):
                # wait_for already cancels the task on timeout; this is a safeguard.
                task.cancel()
            except Exception as exc:
                logger.warning(f"OAuth callback server exited with an error: {exc}")
        elif not task.cancelled() and task.exception() is not None:
            # Retrieve the exception so asyncio does not warn that it was never retrieved.
            logger.warning(f"OAuth callback server exited with an error: {task.exception()}")

    def _create_app(self) -> Starlette:
        """Create the Starlette application with the ``/callback`` route."""

        async def callback(request: Request) -> HTMLResponse:
            """Handle the OAuth callback."""
            params = request.query_params

            # Extract OAuth parameters
            response = CallbackResponse(
                code=params.get("code"),
                state=params.get("state"),
                error=params.get("error"),
                error_description=params.get("error_description"),
                error_uri=params.get("error_uri"),
            )

            # Log the callback response
            logger.debug(
                f"OAuth callback received: error={response.error}, error_description={response.error_description}"
            )
            if response.code:
                logger.debug("OAuth callback received authorization code")
            else:
                logger.error(f"OAuth callback error: {response.error} - {response.error_description}")

            # Only the first callback is delivered; later hits are dropped.
            try:
                self.response_queue.put_nowait(response)
            except asyncio.QueueFull:
                pass

            if response.code:
                page = self._success_html()
            else:
                page = self._error_html(response.error, response.error_description)

            return HTMLResponse(content=page)

        routes = [Route("/callback", callback)]
        return Starlette(routes=routes)

    def _success_html(self) -> str:
        """HTML response for successful authorization."""
        return """
        <!DOCTYPE html>
        <html>
        <head>
            <title>Authorization Successful</title>
            <style>
                body {
                    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
                    display: flex;
                    justify-content: center;
                    align-items: center;
                    height: 100vh;
                    margin: 0;
                    background-color: #f5f5f5;
                }
                .container {
                    text-align: center;
                    padding: 2rem;
                    background: white;
                    border-radius: 8px;
                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                }
                h1 { color: #22c55e; margin-bottom: 0.5rem; }
                p { color: #666; margin-top: 0.5rem; }
                .icon { font-size: 48px; margin-bottom: 1rem; }
            </style>
        </head>
        <body>
            <div class="container">
                <div class="icon">✅</div>
                <h1>Authorization Successful!</h1>
                <p>You can now close this window and return to your application.</p>
            </div>
            <script>
                // Auto-close after 3 seconds
                setTimeout(() => window.close(), 3000);
            </script>
        </body>
        </html>
        """

    def _error_html(self, error: str | None, description: str | None) -> str:
        """HTML response for authorization error.

        ``error`` and ``description`` come straight from the callback query
        string, so they are HTML-escaped before being embedded in the page;
        otherwise a crafted link to the local callback URL could inject markup
        or script.
        """
        from html import escape

        error_msg = escape(error or "Unknown error")
        desc_msg = escape(description or "Authorization was not completed successfully.")

        return f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>Authorization Error</title>
            <style>
                body {{
                    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
                    display: flex;
                    justify-content: center;
                    align-items: center;
                    height: 100vh;
                    margin: 0;
                    background-color: #f5f5f5;
                }}
                .container {{
                    text-align: center;
                    padding: 2rem;
                    background: white;
                    border-radius: 8px;
                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                    max-width: 500px;
                }}
                h1 {{ color: #ef4444; margin-bottom: 0.5rem; }}
                .error {{ color: #dc2626; font-weight: 600; margin: 1rem 0; }}
                .description {{ color: #666; margin-top: 0.5rem; }}
                .icon {{ font-size: 48px; margin-bottom: 1rem; }}
            </style>
        </head>
        <body>
            <div class="container">
                <div class="icon">❌</div>
                <h1>Authorization Error</h1>
                <p class="error">{error_msg}</p>
                <p class="description">{desc_msg}</p>
            </div>
        </body>
        </html>
        """
|
mcp_use/client.py
CHANGED
mcp_use/config.py
CHANGED
|
@@ -79,7 +79,7 @@ def create_connector_from_config(
|
|
|
79
79
|
return HttpConnector(
|
|
80
80
|
base_url=server_config["url"],
|
|
81
81
|
headers=server_config.get("headers", None),
|
|
82
|
-
|
|
82
|
+
auth=server_config.get("auth", {}),
|
|
83
83
|
timeout=server_config.get("timeout", 5),
|
|
84
84
|
sse_read_timeout=server_config.get("sse_read_timeout", 60 * 5),
|
|
85
85
|
sampling_callback=sampling_callback,
|
|
@@ -93,7 +93,7 @@ def create_connector_from_config(
|
|
|
93
93
|
return WebSocketConnector(
|
|
94
94
|
url=server_config["ws_url"],
|
|
95
95
|
headers=server_config.get("headers", None),
|
|
96
|
-
|
|
96
|
+
auth=server_config.get("auth", {}),
|
|
97
97
|
)
|
|
98
98
|
|
|
99
99
|
raise ValueError("Cannot determine connector type from config")
|
mcp_use/connectors/base.py
CHANGED
|
@@ -111,29 +111,34 @@ class BaseConnector(ABC):
|
|
|
111
111
|
"""Clean up all resources associated with this connector."""
|
|
112
112
|
errors = []
|
|
113
113
|
|
|
114
|
-
# First
|
|
115
|
-
|
|
114
|
+
# First stop the connection manager, this closes the ClientSession inside
|
|
115
|
+
# the same task where it was opened, avoiding cancel-scope mismatches.
|
|
116
|
+
if self._connection_manager:
|
|
116
117
|
try:
|
|
117
|
-
logger.debug("
|
|
118
|
-
await self.
|
|
118
|
+
logger.debug("Stopping connection manager")
|
|
119
|
+
await self._connection_manager.stop()
|
|
119
120
|
except Exception as e:
|
|
120
|
-
error_msg = f"Error
|
|
121
|
+
error_msg = f"Error stopping connection manager: {e}"
|
|
121
122
|
logger.warning(error_msg)
|
|
122
123
|
errors.append(error_msg)
|
|
123
124
|
finally:
|
|
124
|
-
self.
|
|
125
|
+
self._connection_manager = None
|
|
125
126
|
|
|
126
|
-
#
|
|
127
|
-
|
|
127
|
+
# Ensure the client_session reference is cleared (it should already be
|
|
128
|
+
# closed by the connection manager). Only attempt a direct __aexit__ if
|
|
129
|
+
# the connection manager did *not* exist, this covers edge-cases like
|
|
130
|
+
# failed connections where no manager was started.
|
|
131
|
+
if self.client_session:
|
|
128
132
|
try:
|
|
129
|
-
|
|
130
|
-
|
|
133
|
+
if not self._connection_manager:
|
|
134
|
+
logger.debug("Closing client session (no connection manager)")
|
|
135
|
+
await self.client_session.__aexit__(None, None, None)
|
|
131
136
|
except Exception as e:
|
|
132
|
-
error_msg = f"Error
|
|
137
|
+
error_msg = f"Error closing client session: {e}"
|
|
133
138
|
logger.warning(error_msg)
|
|
134
139
|
errors.append(error_msg)
|
|
135
140
|
finally:
|
|
136
|
-
self.
|
|
141
|
+
self.client_session = None
|
|
137
142
|
|
|
138
143
|
# Reset tools
|
|
139
144
|
self._tools = None
|
mcp_use/connectors/http.py
CHANGED
|
@@ -5,10 +5,17 @@ This module provides a connector for communicating with MCP implementations
|
|
|
5
5
|
through HTTP APIs with SSE or Streamable HTTP for transport.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
8
10
|
import httpx
|
|
9
11
|
from mcp import ClientSession
|
|
10
12
|
from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
|
13
|
+
from mcp.shared.exceptions import McpError
|
|
14
|
+
|
|
15
|
+
from mcp_use.auth.oauth import OAuthClientProvider
|
|
11
16
|
|
|
17
|
+
from ..auth import BearerAuth, OAuth
|
|
18
|
+
from ..exceptions import OAuthAuthenticationError, OAuthDiscoveryError
|
|
12
19
|
from ..logging import logger
|
|
13
20
|
from ..task_managers import SseConnectionManager, StreamableHttpConnectionManager
|
|
14
21
|
from .base import BaseConnector
|
|
@@ -24,10 +31,10 @@ class HttpConnector(BaseConnector):
|
|
|
24
31
|
def __init__(
|
|
25
32
|
self,
|
|
26
33
|
base_url: str,
|
|
27
|
-
auth_token: str | None = None,
|
|
28
34
|
headers: dict[str, str] | None = None,
|
|
29
35
|
timeout: float = 5,
|
|
30
36
|
sse_read_timeout: float = 60 * 5,
|
|
37
|
+
auth: str | dict[str, Any] | httpx.Auth | None = None,
|
|
31
38
|
sampling_callback: SamplingFnT | None = None,
|
|
32
39
|
elicitation_callback: ElicitationFnT | None = None,
|
|
33
40
|
message_handler: MessageHandlerFnT | None = None,
|
|
@@ -37,10 +44,13 @@ class HttpConnector(BaseConnector):
|
|
|
37
44
|
|
|
38
45
|
Args:
|
|
39
46
|
base_url: The base URL of the MCP HTTP API.
|
|
40
|
-
auth_token: Optional authentication token.
|
|
41
47
|
headers: Optional additional headers.
|
|
42
48
|
timeout: Timeout for HTTP operations in seconds.
|
|
43
49
|
sse_read_timeout: Timeout for SSE read operations in seconds.
|
|
50
|
+
auth: Authentication method - can be:
|
|
51
|
+
- A string token: Use Bearer token authentication
|
|
52
|
+
- A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."}
|
|
53
|
+
- An httpx.Auth object: Use custom authentication
|
|
44
54
|
sampling_callback: Optional sampling callback.
|
|
45
55
|
elicitation_callback: Optional elicitation callback.
|
|
46
56
|
"""
|
|
@@ -51,12 +61,57 @@ class HttpConnector(BaseConnector):
|
|
|
51
61
|
logging_callback=logging_callback,
|
|
52
62
|
)
|
|
53
63
|
self.base_url = base_url.rstrip("/")
|
|
54
|
-
self.auth_token = auth_token
|
|
55
64
|
self.headers = headers or {}
|
|
56
|
-
if auth_token:
|
|
57
|
-
self.headers["Authorization"] = f"Bearer {auth_token}"
|
|
58
65
|
self.timeout = timeout
|
|
59
66
|
self.sse_read_timeout = sse_read_timeout
|
|
67
|
+
self._auth: httpx.Auth | None = None
|
|
68
|
+
self._oauth: OAuth | None = None
|
|
69
|
+
|
|
70
|
+
# Handle authentication
|
|
71
|
+
if auth is not None:
|
|
72
|
+
self._set_auth(auth)
|
|
73
|
+
|
|
74
|
+
def _set_auth(self, auth: str | dict[str, Any] | httpx.Auth) -> None:
    """Set authentication method.

    Args:
        auth: Authentication method - can be:
            - A string token: Use Bearer token authentication. An empty string
              is treated as "no authentication".
            - A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."},
              optionally with "callback_port" and "oauth_provider" (either an
              OAuthClientProvider or a dict of its constructor arguments).
            - An httpx.Auth object: Use custom authentication

    Raises:
        ValueError: If ``auth`` is none of the supported types.
    """
    if isinstance(auth, str):
        if not auth:
            # An empty token would otherwise produce a bare "Bearer " header.
            logger.debug("Empty auth token provided, skipping bearer authentication")
            return
        # Treat as bearer token
        self._auth = BearerAuth(token=auth)
        self.headers["Authorization"] = f"Bearer {auth}"
    elif isinstance(auth, dict):
        # Keyword arguments shared by every OAuth configuration.
        oauth_kwargs: dict[str, Any] = {
            "scope": auth.get("scope"),
            "client_id": auth.get("client_id"),
            "client_secret": auth.get("client_secret"),
            "callback_port": auth.get("callback_port"),
        }
        # Only forward oauth_provider when the config names one, so OAuth's own
        # default applies otherwise.
        if "oauth_provider" in auth:
            oauth_provider = auth["oauth_provider"]
            if isinstance(oauth_provider, dict):
                oauth_provider = OAuthClientProvider(**oauth_provider)
            oauth_kwargs["oauth_provider"] = oauth_provider
        self._oauth = OAuth(self.base_url, **oauth_kwargs)
        self._oauth_config = auth
    elif isinstance(auth, httpx.Auth):
        self._auth = auth
    else:
        raise ValueError(f"Invalid auth type: {type(auth)}")
|
|
60
115
|
|
|
61
116
|
async def connect(self) -> None:
|
|
62
117
|
"""Establish a connection to the MCP implementation."""
|
|
@@ -64,6 +119,29 @@ class HttpConnector(BaseConnector):
|
|
|
64
119
|
logger.debug("Already connected to MCP implementation")
|
|
65
120
|
return
|
|
66
121
|
|
|
122
|
+
# Handle OAuth if needed
|
|
123
|
+
if self._oauth:
|
|
124
|
+
try:
|
|
125
|
+
# Create a temporary client for OAuth metadata discovery
|
|
126
|
+
async with httpx.AsyncClient() as client:
|
|
127
|
+
bearer_auth = await self._oauth.initialize(client)
|
|
128
|
+
if not bearer_auth:
|
|
129
|
+
# Need to perform OAuth flow
|
|
130
|
+
logger.info("OAuth authentication required")
|
|
131
|
+
bearer_auth = await self._oauth.authenticate()
|
|
132
|
+
|
|
133
|
+
# Update auth and headers
|
|
134
|
+
self._auth = bearer_auth
|
|
135
|
+
self.headers["Authorization"] = f"Bearer {bearer_auth.token.get_secret_value()}"
|
|
136
|
+
except OAuthDiscoveryError:
|
|
137
|
+
# OAuth discovery failed - it means server doesn't support OAuth default urls
|
|
138
|
+
logger.debug("OAuth discovery failed, continuing without initialization.")
|
|
139
|
+
self._oauth = None
|
|
140
|
+
self._auth = None
|
|
141
|
+
except OAuthAuthenticationError as e:
|
|
142
|
+
logger.error(f"OAuth initialization failed: {e}")
|
|
143
|
+
raise
|
|
144
|
+
|
|
67
145
|
# Try streamable HTTP first (new transport), fall back to SSE (old transport)
|
|
68
146
|
# This implements backwards compatibility per MCP specification
|
|
69
147
|
self.transport_type = None
|
|
@@ -73,7 +151,7 @@ class HttpConnector(BaseConnector):
|
|
|
73
151
|
# First, try the new streamable HTTP transport
|
|
74
152
|
logger.debug(f"Attempting streamable HTTP connection to: {self.base_url}")
|
|
75
153
|
connection_manager = StreamableHttpConnectionManager(
|
|
76
|
-
self.base_url, self.headers, self.timeout, self.sse_read_timeout
|
|
154
|
+
self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
|
|
77
155
|
)
|
|
78
156
|
|
|
79
157
|
# Test if this is a streamable HTTP server by attempting initialization
|
|
@@ -94,9 +172,9 @@ class HttpConnector(BaseConnector):
|
|
|
94
172
|
try:
|
|
95
173
|
# Try to initialize - this is where streamable HTTP vs SSE difference should show up
|
|
96
174
|
result = await test_client.initialize()
|
|
175
|
+
logger.debug(f"Streamable HTTP initialization result: {result}")
|
|
97
176
|
|
|
98
177
|
# If we get here, streamable HTTP works
|
|
99
|
-
|
|
100
178
|
self.client_session = test_client
|
|
101
179
|
self.transport_type = "streamable HTTP"
|
|
102
180
|
self._initialized = True # Mark as initialized since we just called initialize()
|
|
@@ -125,14 +203,28 @@ class HttpConnector(BaseConnector):
|
|
|
125
203
|
else:
|
|
126
204
|
self._prompts = []
|
|
127
205
|
|
|
128
|
-
|
|
206
|
+
# Only McpError is raised from client's initialization because
|
|
207
|
+
# exceptions are handled internally.
|
|
208
|
+
except McpError as mcp_error:
|
|
209
|
+
logger.error("MCP protocol error during initialization: %s", mcp_error.error)
|
|
129
210
|
# Clean up the test client
|
|
211
|
+
try:
|
|
212
|
+
await test_client.__aexit__(None, None, None)
|
|
213
|
+
except Exception:
|
|
214
|
+
pass
|
|
215
|
+
raise mcp_error
|
|
216
|
+
|
|
217
|
+
except Exception as init_error:
|
|
218
|
+
# This catches non-McpError exceptions, like a direct httpx timeout
|
|
219
|
+
# but in the most cases this won't happen. It's for safety.
|
|
130
220
|
try:
|
|
131
221
|
await test_client.__aexit__(None, None, None)
|
|
132
222
|
except Exception:
|
|
133
223
|
pass
|
|
134
224
|
raise init_error
|
|
135
225
|
|
|
226
|
+
# Exception from the inner try is propagated here and in
|
|
227
|
+
# the most cases is an McpError, so checking instances is useless
|
|
136
228
|
except Exception as streamable_error:
|
|
137
229
|
logger.debug(f"Streamable HTTP failed: {streamable_error}")
|
|
138
230
|
|
|
@@ -143,24 +235,17 @@ class HttpConnector(BaseConnector):
|
|
|
143
235
|
except Exception:
|
|
144
236
|
pass
|
|
145
237
|
|
|
146
|
-
#
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
should_fallback = True
|
|
151
|
-
elif "405 Method Not Allowed" in str(streamable_error) or "404 Not Found" in str(streamable_error):
|
|
152
|
-
should_fallback = True
|
|
153
|
-
else:
|
|
154
|
-
# For other errors, still try fallback but they might indicate
|
|
155
|
-
# real connectivity issues
|
|
156
|
-
should_fallback = True
|
|
238
|
+
# It doesn't make sense to check error types. Because client
|
|
239
|
+
# always return a McpError, if he can't reach the server
|
|
240
|
+
# because it's offline, or if it has an auth problem.
|
|
241
|
+
should_fallback = True
|
|
157
242
|
|
|
158
243
|
if should_fallback:
|
|
159
244
|
try:
|
|
160
245
|
# Fall back to the old SSE transport
|
|
161
246
|
logger.debug(f"Attempting SSE fallback connection to: {self.base_url}")
|
|
162
247
|
connection_manager = SseConnectionManager(
|
|
163
|
-
self.base_url, self.headers, self.timeout, self.sse_read_timeout
|
|
248
|
+
self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
|
|
164
249
|
)
|
|
165
250
|
|
|
166
251
|
read_stream, write_stream = await connection_manager.start()
|
|
@@ -178,7 +263,18 @@ class HttpConnector(BaseConnector):
|
|
|
178
263
|
await self.client_session.__aenter__()
|
|
179
264
|
self.transport_type = "SSE"
|
|
180
265
|
|
|
181
|
-
except Exception as sse_error:
|
|
266
|
+
except* Exception as sse_error:
|
|
267
|
+
# Get the exception from the ExceptionGroup, and here we will get the correct type.
|
|
268
|
+
sse_error = sse_error.exceptions[0]
|
|
269
|
+
if isinstance(sse_error, httpx.HTTPStatusError) and sse_error.response.status_code in [
|
|
270
|
+
401,
|
|
271
|
+
403,
|
|
272
|
+
407,
|
|
273
|
+
]:
|
|
274
|
+
raise OAuthAuthenticationError(
|
|
275
|
+
f"Server requires authentication (HTTP {sse_error.response.status_code}) "
|
|
276
|
+
"but auth failed. Please provide auth configuration manually."
|
|
277
|
+
) from sse_error
|
|
182
278
|
logger.error(
|
|
183
279
|
f"Both transport methods failed. Streamable HTTP: {streamable_error}, SSE: {sse_error}"
|
|
184
280
|
)
|
mcp_use/connectors/websocket.py
CHANGED
|
@@ -10,6 +10,7 @@ import json
|
|
|
10
10
|
import uuid
|
|
11
11
|
from typing import Any
|
|
12
12
|
|
|
13
|
+
import httpx
|
|
13
14
|
from mcp.types import Tool
|
|
14
15
|
from websockets import ClientConnection
|
|
15
16
|
|
|
@@ -28,21 +29,29 @@ class WebSocketConnector(BaseConnector):
|
|
|
28
29
|
def __init__(
|
|
29
30
|
self,
|
|
30
31
|
url: str,
|
|
31
|
-
auth_token: str | None = None,
|
|
32
32
|
headers: dict[str, str] | None = None,
|
|
33
|
+
auth: str | dict[str, Any] | httpx.Auth | None = None,
|
|
33
34
|
):
|
|
34
35
|
"""Initialize a new WebSocket connector.
|
|
35
36
|
|
|
36
37
|
Args:
|
|
37
38
|
url: The WebSocket URL to connect to.
|
|
38
|
-
auth_token: Optional authentication token.
|
|
39
39
|
headers: Optional additional headers.
|
|
40
|
+
auth: Authentication method - can be:
|
|
41
|
+
- A string token: Use Bearer token authentication
|
|
42
|
+
- A dict: Not supported for WebSocket (will log warning)
|
|
43
|
+
- An httpx.Auth object: Not supported for WebSocket (will log warning)
|
|
40
44
|
"""
|
|
41
45
|
self.url = url
|
|
42
|
-
self.auth_token = auth_token
|
|
43
46
|
self.headers = headers or {}
|
|
44
|
-
|
|
45
|
-
|
|
47
|
+
|
|
48
|
+
# Handle authentication - WebSocket only supports bearer tokens
|
|
49
|
+
# An auth field it's not needed
|
|
50
|
+
if auth is not None:
|
|
51
|
+
if isinstance(auth, str):
|
|
52
|
+
self.headers["Authorization"] = f"Bearer {auth}"
|
|
53
|
+
else:
|
|
54
|
+
logger.warning("WebSocket connector only supports bearer token authentication")
|
|
46
55
|
|
|
47
56
|
self.ws: ClientConnection | None = None
|
|
48
57
|
self._connection_manager: ConnectionManager | None = None
|
mcp_use/exceptions.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""MCP-use exceptions."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class MCPError(Exception):
    """Common base class for the mcp-use exceptions defined in this module.

    Catching ``MCPError`` also catches every subclass declared here
    (OAuth discovery/authentication, connection and configuration errors).
    """
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OAuthDiscoveryError(MCPError):
    """Raised when OAuth authorization-server metadata cannot be discovered."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OAuthAuthenticationError(MCPError):
    """Raised for failures while authenticating via OAuth."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ConnectionError(MCPError):
    """Connection-related errors.

    NOTE: this name shadows the built-in ``ConnectionError`` inside this
    module. It derives from ``MCPError`` (not from ``OSError``), so an
    ``except ConnectionError`` that refers to the builtin will not catch it.
    Import it qualified (e.g. ``mcp_use.exceptions.ConnectionError``) to
    avoid confusion.
    """
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ConfigurationError(MCPError):
    """Raised for invalid or inconsistent configuration."""
|
mcp_use/task_managers/base.py
CHANGED
|
@@ -22,13 +22,14 @@ class ConnectionManager(Generic[T], ABC):
|
|
|
22
22
|
used with MCP connectors.
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
def __init__(self):
|
|
25
|
+
def __init__(self) -> None:
|
|
26
26
|
"""Initialize a new connection manager."""
|
|
27
27
|
self._ready_event = asyncio.Event()
|
|
28
28
|
self._done_event = asyncio.Event()
|
|
29
|
+
self._stop_event = asyncio.Event()
|
|
29
30
|
self._exception: Exception | None = None
|
|
30
31
|
self._connection: T | None = None
|
|
31
|
-
self._task: asyncio.Task | None = None
|
|
32
|
+
self._task: asyncio.Task[None] | None = None
|
|
32
33
|
|
|
33
34
|
@abstractmethod
|
|
34
35
|
async def _establish_connection(self) -> T:
|
|
@@ -86,20 +87,15 @@ class ConnectionManager(Generic[T], ABC):
|
|
|
86
87
|
|
|
87
88
|
async def stop(self) -> None:
|
|
88
89
|
"""Stop the connection manager and close the connection."""
|
|
90
|
+
# Signal stop to the connection task instead of cancelling it, avoids
|
|
91
|
+
# propagating CancelledError to unrelated tasks.
|
|
89
92
|
if self._task and not self._task.done():
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
await self._task
|
|
97
|
-
except asyncio.CancelledError:
|
|
98
|
-
logger.debug(f"{self.__class__.__name__} task cancelled successfully")
|
|
99
|
-
except Exception as e:
|
|
100
|
-
logger.warning(f"Error stopping {self.__class__.__name__} task: {e}")
|
|
101
|
-
|
|
102
|
-
# Wait for the connection to be done
|
|
93
|
+
logger.debug(f"Signaling stop to {self.__class__.__name__} task")
|
|
94
|
+
self._stop_event.set()
|
|
95
|
+
# Wait for it to finish gracefully
|
|
96
|
+
await self._task
|
|
97
|
+
|
|
98
|
+
# Ensure cleanup completed
|
|
103
99
|
await self._done_event.wait()
|
|
104
100
|
logger.debug(f"{self.__class__.__name__} task completed")
|
|
105
101
|
|
|
@@ -125,14 +121,8 @@ class ConnectionManager(Generic[T], ABC):
|
|
|
125
121
|
# Signal that the connection is ready
|
|
126
122
|
self._ready_event.set()
|
|
127
123
|
|
|
128
|
-
# Wait
|
|
129
|
-
|
|
130
|
-
# This keeps the connection open until cancelled
|
|
131
|
-
await asyncio.Event().wait()
|
|
132
|
-
except asyncio.CancelledError:
|
|
133
|
-
# Expected when stopping
|
|
134
|
-
logger.debug(f"{self.__class__.__name__} task received cancellation")
|
|
135
|
-
pass
|
|
124
|
+
# Wait until stop is requested
|
|
125
|
+
await self._stop_event.wait()
|
|
136
126
|
|
|
137
127
|
except Exception as e:
|
|
138
128
|
# Store the exception
|