mcp-use 1.3.10__py3-none-any.whl → 1.3.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-use might be problematic. Click here for more details.
- mcp_use/adapters/langchain_adapter.py +9 -52
- mcp_use/agents/mcpagent.py +88 -37
- mcp_use/agents/prompts/templates.py +1 -10
- mcp_use/agents/remote.py +154 -128
- mcp_use/auth/__init__.py +6 -0
- mcp_use/auth/bearer.py +17 -0
- mcp_use/auth/oauth.py +625 -0
- mcp_use/auth/oauth_callback.py +214 -0
- mcp_use/client.py +25 -1
- mcp_use/config.py +7 -2
- mcp_use/connectors/base.py +25 -12
- mcp_use/connectors/http.py +135 -27
- mcp_use/connectors/sandbox.py +12 -3
- mcp_use/connectors/stdio.py +11 -3
- mcp_use/connectors/websocket.py +15 -6
- mcp_use/exceptions.py +31 -0
- mcp_use/middleware/__init__.py +50 -0
- mcp_use/middleware/logging.py +31 -0
- mcp_use/middleware/metrics.py +314 -0
- mcp_use/middleware/middleware.py +262 -0
- mcp_use/task_managers/base.py +13 -23
- mcp_use/task_managers/sse.py +5 -0
- mcp_use/task_managers/streamable_http.py +5 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/METADATA +21 -25
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/RECORD +28 -19
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/WHEEL +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/entry_points.txt +0 -0
- {mcp_use-1.3.10.dist-info → mcp_use-1.3.12.dist-info}/licenses/LICENSE +0 -0
mcp_use/auth/oauth.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
1
|
+
"""OAuth authentication support for MCP clients."""
|
|
2
|
+
|
|
3
|
+
import json
import os
import secrets
import webbrowser
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

import httpx
from authlib.common.errors import AuthlibBaseError
from authlib.integrations.httpx_client import AsyncOAuth2Client
from authlib.oauth2 import OAuth2Error
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, SecretStr

from ..exceptions import OAuthAuthenticationError, OAuthDiscoveryError
from ..logging import logger
from .bearer import BearerAuth
from .oauth_callback import OAuthCallbackServer
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ServerOAuthMetadata(BaseModel):
    """OAuth authorization-server metadata used for an MCP server.

    It is essentially a configuration that tells the MCP client:

    - Where to send users for authorization
    - Where to exchange authorization codes for tokens
    - Which OAuth features are supported
    - Where to register new OAuth clients with Dynamic Client Registration (DCR)

    Fields not declared below are kept on the model rather than rejected.
    """

    # Pydantic v2 configuration (the nested ``class Config`` form is deprecated in v2).
    model_config = ConfigDict(extra="allow")  # Allow additional fields

    issuer: HttpUrl  # The OAuth server's identity
    authorization_endpoint: HttpUrl  # URL with endpoint for client auth
    token_endpoint: HttpUrl  # URL with endpoint for tokens' exchange
    userinfo_endpoint: HttpUrl | None = None
    revocation_endpoint: HttpUrl | None = None
    introspection_endpoint: HttpUrl | None = None
    registration_endpoint: HttpUrl | None = None  # Endpoint for DCR
    jwks_uri: HttpUrl | None = None
    response_types_supported: list[str] = Field(default_factory=lambda: ["code"])
    subject_types_supported: list[str] = Field(default_factory=lambda: ["public"])
    id_token_signing_alg_values_supported: list[str] = Field(default_factory=lambda: ["RS256"])
    scopes_supported: list[str] | None = None  # Which permissions are supported
    token_endpoint_auth_methods_supported: list[str] = Field(default_factory=lambda: ["client_secret_basic"])
    claims_supported: list[str] | None = None
    code_challenge_methods_supported: list[str] | None = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class OAuthClientProvider(BaseModel):
    """Pre-configured OAuth client provider for one particular server.

    Carries the server's OAuth metadata up front, so the handler does not
    have to discover it from the server's well-known endpoints.
    """

    id: str  # Unique identifier
    display_name: str
    metadata: ServerOAuthMetadata | dict[str, Any]

    @property
    def oauth_metadata(self) -> ServerOAuthMetadata:
        """Return ``metadata`` as a ``ServerOAuthMetadata``, building one from a raw dict if needed."""
        raw = self.metadata
        if not isinstance(raw, dict):
            # Already a model instance: hand it back unchanged.
            return raw
        return ServerOAuthMetadata(**raw)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class TokenData(BaseModel):
    """OAuth token data.

    This is the information received after a successful authentication;
    ``FileTokenStorage`` persists it as JSON.
    """

    access_token: str  # Actual credential used for requests
    token_type: str = "Bearer"
    expires_at: float | None = None  # Expiry as a POSIX timestamp in seconds; None when unknown
    refresh_token: str | None = None  # Used to obtain a new access token without re-authorizing
    scope: str | None = None  # Scope string as returned by the token endpoint
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class ClientRegistrationResponse(BaseModel):
    """Dynamic Client Registration (DCR) response.

    It represents the response from an OAuth server when a new OAuth client
    is registered dynamically. Fields not declared below are kept on the model.
    """

    # Pydantic v2 configuration (the nested ``class Config`` form is deprecated in v2).
    model_config = ConfigDict(extra="allow")  # Allow additional fields from server

    client_id: str
    client_secret: str | None = None
    client_id_issued_at: int | None = None
    # POSIX timestamp; 0 or None skips the expiry check when a stored registration is loaded
    client_secret_expires_at: int | None = None
    redirect_uris: list[str] | None = None  # Where auth server should redirect after auth
    grant_types: list[str] | None = None  # Which OAuth flows it uses
    response_types: list[str] | None = None
    client_name: str | None = None
    token_endpoint_auth_method: str | None = None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class FileTokenStorage:
    """File-based token storage.

    It's responsible for:

    - Saving OAuth tokens to disk after authentication
    - Loading saved tokens when the app restarts
    - Deleting tokens when they're revoked
    - Organizing tokens by server URL (one JSON file per server)

    Token files hold bearer credentials, so the storage directory is created
    with mode 0o700 and token files are written with mode 0o600. On platforms
    without POSIX permissions (e.g. Windows) these modes are best-effort.
    """

    def __init__(self, base_dir: Path | None = None):
        """Initialize token storage.

        Args:
            base_dir: Base directory for token storage. Defaults to ~/.mcp_use/tokens
        """
        self.base_dir = base_dir or Path.home() / ".mcp_use" / "tokens"
        logger.debug(f"FileTokenStorage initialized with base_dir: {self.base_dir}")
        # ``mode`` applies only to the leaf directory and only if it is created here;
        # missing parents get default permissions.
        self.base_dir.mkdir(mode=0o700, parents=True, exist_ok=True)

    def _get_token_path(self, server_url: str) -> Path:
        """Get the token file path for a server.

        The filename is built from the URL's network location and path only
        (query and fragment are ignored).
        """
        parsed = urlparse(server_url)
        filename = f"{parsed.netloc}_{parsed.path.replace('/', '_')}.json"
        path = self.base_dir / filename
        logger.debug(f"Token path for server '{server_url}' is '{path}'")
        return path

    @staticmethod
    def _write_private(path: Path, content: str) -> None:
        """Write ``content`` to ``path`` as UTF-8, readable and writable by the owner only."""
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(content)
        # os.open's mode is only used when the file is created; also tighten files
        # left behind with broader permissions by earlier versions.
        try:
            os.chmod(path, 0o600)
        except OSError:
            pass  # best-effort where chmod is unsupported

    async def save_tokens(self, server_url: str, tokens: dict[str, Any]) -> None:
        """Validate ``tokens`` as ``TokenData`` and write them to the server's token file."""
        token_path = self._get_token_path(server_url)
        logger.debug(f"Saving tokens for '{server_url}' to '{token_path}'")
        token_data = TokenData(**tokens)
        self._write_private(token_path, token_data.model_dump_json())
        logger.debug(f"Tokens saved successfully for '{server_url}'")

    async def load_tokens(self, server_url: str) -> TokenData | None:
        """Load tokens from file.

        Returns:
            The stored ``TokenData``, or None if the file is missing, unreadable,
            or does not contain a valid token object.
        """
        token_path = self._get_token_path(server_url)
        logger.debug(f"Attempting to load tokens for '{server_url}' from '{token_path}'")
        if not token_path.exists():
            logger.debug(f"Token file not found: '{token_path}'")
            return None

        try:
            data = json.loads(token_path.read_text(encoding="utf-8"))
            # TypeError covers valid JSON that is not an object (e.g. a list).
            token_data = TokenData(**data)
        except (OSError, json.JSONDecodeError, ValueError, TypeError) as e:
            logger.debug(f"Failed to load or parse token file '{token_path}': {e}")
            return None

        logger.debug(f"Successfully loaded tokens for '{server_url}'")
        return token_data

    async def delete_tokens(self, server_url: str) -> None:
        """Delete the stored tokens for a server, if any."""
        token_path = self._get_token_path(server_url)
        logger.debug(f"Deleting tokens for '{server_url}' at '{token_path}'")
        try:
            token_path.unlink()
        except FileNotFoundError:
            logger.debug(f"Token file '{token_path}' not found, nothing to delete.")
        else:
            logger.debug(f"Token file '{token_path}' deleted.")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class OAuth:
    """OAuth authentication handler for MCP clients.

    This is the main class that handles all the authentication.
    It has several features:

    - Discovers OAuth server capabilities automatically
    - Registers the client dynamically (DCR) when possible
    - Runs the authorization-code flow with PKCE (S256)
    - Stores tokens and can refresh them via ``refresh_token()``"""

    def __init__(
        self,
        server_url: str,
        token_storage: FileTokenStorage | None = None,
        scope: str | None = None,
        client_id: str | None = None,
        client_secret: str | None = None,
        callback_port: int | None = None,
        oauth_provider: OAuthClientProvider | None = None,
    ):
        """Initialize OAuth handler.

        Args:
            server_url: The MCP server URL
            token_storage: Token storage implementation. Defaults to FileTokenStorage
            scope: OAuth scopes to request
            client_id: OAuth client ID. If not provided, will attempt dynamic registration
            client_secret: OAuth client secret (for confidential clients)
            callback_port: Port for the local callback server; 8080 when not provided (or 0)
            oauth_provider: OAuth client provider whose metadata replaces discovery
        """
        logger.debug(f"Initializing OAuth for server: {server_url}")
        self.server_url = server_url
        self.token_storage = token_storage or FileTokenStorage()
        self.scope = scope
        self.client_id = client_id
        self.client_secret = client_secret

        if callback_port:
            self.callback_port = callback_port
            logger.info(f"Using custom callback port {self.callback_port} provided in config")
        else:
            self.callback_port = 8080
            logger.info(f"Using default callback port {self.callback_port}")

        # Default redirect URI (sent during DCR). authenticate() sets the OAuth
        # client's redirect_uri to the URI returned by the callback server.
        self.redirect_uri = f"http://localhost:{self.callback_port}/callback"
        self._oauth_provider = oauth_provider
        self._metadata: ServerOAuthMetadata | None = None

        if self._oauth_provider:
            self._metadata = self._oauth_provider.oauth_metadata
            logger.debug(f"Using OAuth provider {self._oauth_provider.id} with metadata")

        self._client: AsyncOAuth2Client | None = None
        self._bearer_auth: BearerAuth | None = None
        logger.debug(f"OAuth initialized with scope='{self.scope}', client_id='{self.client_id}'")

    async def initialize(self, client: httpx.AsyncClient) -> BearerAuth | None:
        """Return bearer auth from stored tokens if they are still valid.

        Otherwise ensures OAuth metadata is available (discovering it with
        ``client`` unless it was supplied via ``oauth_provider``) and returns None.
        """
        logger.debug(f"OAuth.initialize called for {self.server_url}")
        # Try to load existing tokens
        logger.debug("Attempting to load existing tokens")
        token_data = await self.token_storage.load_tokens(self.server_url)
        if token_data:
            logger.debug("Found existing tokens, checking validity")
            if self._is_token_valid(token_data):
                logger.debug("Existing token is valid, creating BearerAuth")
                self._bearer_auth = BearerAuth(token=SecretStr(token_data.access_token))
                logger.debug("OAuth.initialize returning existing valid BearerAuth")
                return self._bearer_auth
            else:
                logger.debug("Existing token is expired")
        else:
            logger.debug("No existing tokens found")

        # Discover OAuth metadata
        if not self._metadata:
            logger.debug("No valid token, proceeding to discover OAuth metadata")
            await self._discover_metadata(client)
        else:
            logger.debug("Using provided OAuth metadata, skipping discovery")

        logger.debug("OAuth.initialize finished, no valid token available yet")
        return None

    async def authenticate(self) -> BearerAuth:
        """Perform the interactive OAuth authorization-code flow with PKCE.

        Obtains a client_id (configured, previously registered, or via Dynamic
        Client Registration), opens the browser at the authorization endpoint,
        waits for the local callback, verifies ``state``, exchanges the code for
        tokens, and saves them.

        Returns:
            BearerAuth carrying the new access token.

        Raises:
            OAuthAuthenticationError: metadata is missing, no client_id can be
                obtained, the callback times out or reports an error, no code is
                received, ``state`` does not match, or the token exchange fails.
            OSError, ValueError, OverflowError: the callback port cannot be bound.
        """
        logger.debug("OAuth.authenticate called")
        if not self._metadata:
            logger.error("OAuth.authenticate called before metadata was discovered.")
            raise OAuthAuthenticationError("OAuth metadata not discovered")

        # The port check should be done now. OAuth servers register the client_id
        # together with its redirect_uri, so the port must be usable before DCR.
        import socket

        try:
            # The context manager closes the probe socket even when bind() fails.
            with socket.socket() as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(("127.0.0.1", self.callback_port))
            logger.debug(f"Using registered port {self.callback_port} for callback")
        except (ValueError, OverflowError, OSError):
            logger.error(f"The port {self.callback_port} is not available! Try using a different port!")
            raise

        # Try to get client_id - either from config or dynamic registration
        client_id = self.client_id
        client_secret = self.client_secret

        if not client_id:
            logger.debug("No client_id provided, attempting dynamic client registration")
            # Try to load previously registered client
            registration = await self._load_client_registration()

            if registration:
                logger.debug("Using previously registered client")
            else:
                # _try_dynamic_registration also persists a successful registration.
                registration = await self._try_dynamic_registration()
                if not registration:
                    logger.error("Dynamic client registration failed or not supported")
                    raise OAuthAuthenticationError(
                        "OAuth requires a client_id. Server does not support dynamic registration. "
                        "Please provide one in the auth configuration. "
                        "Example: {'auth': {'client_id': 'your-registered-client-id'}}"
                    )
                logger.debug("Dynamic registration successful")

            client_id = registration.client_id
            client_secret = registration.client_secret

        logger.debug(f"Using client_id: {client_id}")

        # Create OAuth client; code_challenge_method enables PKCE in authlib.
        logger.debug("Creating AsyncOAuth2Client")
        self._client = AsyncOAuth2Client(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=self.redirect_uri,
            scope=self.scope,
            code_challenge_method="S256",
        )

        # Start callback server
        logger.debug("Starting OAuth callback server")

        callback_server = OAuthCallbackServer(port=self.callback_port)
        redirect_uri = await callback_server.start()
        self._client.redirect_uri = redirect_uri
        logger.debug(f"Callback server started, redirect_uri: {redirect_uri}")

        # Generate state for CSRF protection
        state = secrets.token_urlsafe(32)
        logger.debug(f"Generated state for CSRF protection: {state}")

        # PKCE (RFC 7636): the verifier stays local; only its S256 challenge is sent
        # with the authorization request. 64 random bytes -> 86 URL-safe characters.
        code_verifier = secrets.token_urlsafe(64)

        # Build authorization URL
        logger.debug("Creating authorization URL")
        auth_url, _ = self._client.create_authorization_url(
            str(self._metadata.authorization_endpoint),
            state=state,
            code_verifier=code_verifier,
        )

        logger.debug("OAuth flow started:")
        logger.debug(f"  Client ID: {client_id}")
        logger.debug(f"  Authorization endpoint: {self._metadata.authorization_endpoint}")
        logger.debug(f"  Redirect URI: {redirect_uri}")
        logger.debug(f"  Scope: {self.scope}")

        # Open browser for authorization
        print(f"Opening browser for authorization: {auth_url}")
        webbrowser.open(auth_url)

        # Wait for callback
        logger.debug("Waiting for authorization code from callback server")
        try:
            response = await callback_server.wait_for_code()
            logger.debug("Received response from callback server")
        except TimeoutError as e:
            logger.error(f"OAuth callback timed out: {e}")
            raise OAuthAuthenticationError(f"OAuth timeout: {e}") from e

        if response.error:
            logger.error("OAuth authorization failed:")
            logger.error(f"  Error: {response.error}")
            logger.error(f"  Description: {response.error_description}")
            logger.error("  The OAuth server returned this error, likely because:")
            logger.error(f"  1. The client_id '{client_id}' is not registered with the OAuth server")
            logger.error("  2. The redirect_uri doesn't match the registered one")
            logger.error("  3. The requested scopes are invalid")
            raise OAuthAuthenticationError(f"{response.error}: {response.error_description}")

        if not response.code:
            logger.error("Callback response did not contain an authorization code")
            raise OAuthAuthenticationError("No authorization code received")

        logger.debug(f"Received authorization code: {response.code[:10]}...")

        # Verify state
        logger.debug(f"Verifying state. Expected: {state}, Got: {response.state}")
        if response.state != state:
            logger.error("State mismatch in OAuth callback. Possible CSRF attack.")
            raise OAuthAuthenticationError("Invalid state parameter - possible CSRF attack")
        logger.debug("State verified successfully")

        # Exchange code for tokens
        logger.debug("Exchanging authorization code for tokens")
        try:
            # Pass the code directly instead of re-parsing a rebuilt callback URL, so it
            # is form-encoded correctly; the verifier completes the PKCE exchange.
            token_response = await self._client.fetch_token(
                str(self._metadata.token_endpoint),
                code=response.code,
                grant_type="authorization_code",
                code_verifier=code_verifier,
            )
            logger.debug("Successfully fetched tokens")
        except (OAuth2Error, AuthlibBaseError) as e:
            # authlib's httpx client raises integrations.base_client.OAuthError for
            # token-endpoint errors; it is not an OAuth2Error but both derive from
            # AuthlibBaseError.
            logger.error(f"Token exchange failed: {e}")
            raise OAuthAuthenticationError(f"Token exchange failed: {e}") from e

        # Save tokens
        logger.debug("Saving fetched tokens")
        await self.token_storage.save_tokens(self.server_url, token_response)

        # Create bearer auth
        logger.debug("Creating BearerAuth with new access token")
        self._bearer_auth = BearerAuth(token=SecretStr(token_response["access_token"]))
        return self._bearer_auth

    async def _discover_metadata(self, client: httpx.AsyncClient) -> None:
        """Discover OAuth metadata from the server's origin.

        Tries ``/.well-known/oauth-authorization-server`` and then
        ``/.well-known/openid-configuration``; GitHub's MCP host uses fixed values.

        Raises:
            OAuthDiscoveryError: if neither well-known document can be fetched and parsed.
        """
        logger.debug(f"Discovering OAuth metadata for {self.server_url}")
        # Try well-known endpoint first
        parsed = urlparse(self.server_url)

        # Edge case for GH that doesn't have metadata discovery
        if parsed.netloc == "api.githubcopilot.com":
            logger.debug("Detected GitHub MCP server, using its metadata")
            issuer = "https://github.com/login/oauth"
            authorization_endpoint = "https://github.com/login/oauth/authorize"
            token_endpoint = "https://github.com/login/oauth/access_token"
            self._metadata = ServerOAuthMetadata(
                issuer=issuer, authorization_endpoint=authorization_endpoint, token_endpoint=token_endpoint
            )
            return

        base_url = f"{parsed.scheme}://{parsed.netloc}"
        well_known_url = f"{base_url}/.well-known/oauth-authorization-server"

        try:
            logger.debug(f"Trying OAuth metadata discovery at: {well_known_url}")
            response = await client.get(well_known_url)
            response.raise_for_status()
            metadata = response.json()
            self._metadata = ServerOAuthMetadata(**metadata)
            logger.debug("Successfully discovered OAuth metadata")
            logger.debug(f"  Authorization endpoint: {self._metadata.authorization_endpoint}")
            logger.debug(f"  Token endpoint: {self._metadata.token_endpoint}")
            return
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers invalid JSON and pydantic validation errors; fall through.
            logger.debug(f"Failed to discover OAuth metadata at {well_known_url}: {e}")

        # Try OpenID Connect discovery
        oidc_url = f"{base_url}/.well-known/openid-configuration"
        logger.debug(f"Trying OpenID Connect discovery at: {oidc_url}")
        try:
            response = await client.get(oidc_url)
            response.raise_for_status()
            metadata = response.json()
            self._metadata = ServerOAuthMetadata(**metadata)
            logger.debug("Successfully discovered OIDC metadata")
            logger.debug(f"  Authorization endpoint: {self._metadata.authorization_endpoint}")
            logger.debug(f"  Token endpoint: {self._metadata.token_endpoint}")
            return
        except (httpx.HTTPError, ValueError) as e:
            logger.debug(f"Failed to discover OIDC metadata at {oidc_url}: {e}")

        # If discovery fails, we'll need the metadata from somewhere else
        logger.error(f"Failed to discover OAuth/OIDC metadata for {self.server_url}")
        raise OAuthDiscoveryError(
            f"Failed to discover OAuth metadata for {self.server_url}. "
            "Server must support OAuth metadata discovery at "
            "/.well-known/oauth-authorization-server or /.well-known/openid-configuration"
        )

    def _is_token_valid(self, token_data: TokenData) -> bool:
        """Return True unless the token expires within the next 60 seconds.

        A token without ``expires_at`` (None or 0) is treated as valid.
        """
        logger.debug("Checking token validity")
        if not token_data.expires_at:
            logger.debug("Token has no expiration time, assuming it's valid.")
            return True  # No expiration info, assume valid

        # Check if token expires in more than 60 seconds
        expires_at = datetime.fromtimestamp(token_data.expires_at, tz=UTC)
        now = datetime.now(tz=UTC)
        is_valid = expires_at > now + timedelta(seconds=60)
        logger.debug(f"Token expires at {expires_at}, current time is {now}. Valid: {is_valid}")
        return is_valid

    async def _try_dynamic_registration(self) -> ClientRegistrationResponse | None:
        """Try Dynamic Client Registration if the server advertises an endpoint.

        On success, updates ``client_id``/``client_secret`` and persists the
        registration. Returns None on any failure (best-effort).
        """
        if not self._metadata or not self._metadata.registration_endpoint:
            logger.debug("No registration endpoint available, skipping DCR")
            return None

        logger.info("Attempting Dynamic Client Registration")
        logger.debug(f"DCR endpoint: {self._metadata.registration_endpoint}")

        registration_data = {
            "client_name": "mcp-use",
            "redirect_uris": [self.redirect_uri],
            # refresh_token is requested too because refresh_token() uses that grant.
            "grant_types": ["authorization_code", "refresh_token"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "none",  # Public client
            "application_type": "native",
        }

        # Add scope if specified
        if self.scope:
            registration_data["scope"] = self.scope

        logger.debug(f"DCR request payload: {registration_data}")
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    str(self._metadata.registration_endpoint),
                    json=registration_data,
                    headers={"Content-Type": "application/json"},
                )
                logger.debug(f"DCR response status: {response.status_code}")
                response.raise_for_status()

                # Parse registration response
                reg_response_data = response.json()
                logger.debug(f"DCR response body: {reg_response_data}")
                reg_response = ClientRegistrationResponse(**reg_response_data)

                # Update our credentials
                self.client_id = reg_response.client_id
                self.client_secret = reg_response.client_secret

                logger.info(f"Dynamic Client Registration successful: {self.client_id}")

                # Store the registered client info for future use
                await self._store_client_registration(reg_response)

                return reg_response

        except httpx.HTTPError as e:
            logger.warning(f"Dynamic Client Registration failed: {e}")
            # Only status errors carry a server response worth logging
            if isinstance(e, httpx.HTTPStatusError):
                logger.debug(f"DCR response: {e.response.status_code} - {e.response.text}")
            return None
        except Exception as e:
            logger.warning(f"Unexpected error during DCR: {e}")
            return None

    def _registration_path(self) -> Path:
        """Return the file that stores the DCR result for ``self.server_url``."""
        # Create a safe filename from the server URL
        parsed = urlparse(self.server_url)
        filename = f"{parsed.netloc}_{parsed.path.replace('/', '_')}_registration.json"
        return self.token_storage.base_dir / "registrations" / filename

    async def _store_client_registration(self, registration: ClientRegistrationResponse) -> None:
        """Store client registration data next to the tokens for future use.

        The file may include a client secret, so it is written with mode 0o600.
        """
        logger.debug("Storing client registration data")
        reg_path = self._registration_path()
        reg_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True)
        logger.debug(f"Storing client registration to '{reg_path}'")

        fd = os.open(reg_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(registration.model_dump_json())
        try:
            # os.open's mode only applies to newly created files
            os.chmod(reg_path, 0o600)
        except OSError:
            pass  # best-effort where chmod is unsupported
        logger.debug("Client registration data stored successfully")

    async def _load_client_registration(self) -> ClientRegistrationResponse | None:
        """Load previously registered client credentials if available.

        On success sets ``client_id``/``client_secret``. Returns None if the
        file is missing, unreadable, invalid, or its secret has expired.
        """
        logger.debug("Attempting to load client registration data")
        reg_path = self._registration_path()
        logger.debug(f"Checking for client registration file at '{reg_path}'")

        if not reg_path.exists():
            logger.debug("Client registration file not found")
            return None

        logger.debug("Client registration file found")
        try:
            data = json.loads(reg_path.read_text(encoding="utf-8"))
            reg_response = ClientRegistrationResponse(**data)

            # Check if registration is still valid; 0/None skips the check
            if reg_response.client_secret_expires_at:
                expires_at = datetime.fromtimestamp(reg_response.client_secret_expires_at, tz=UTC)
                now = datetime.now(tz=UTC)
                logger.debug(f"Checking client registration expiry. Expires at: {expires_at}, Now: {now}")
                if expires_at <= now:
                    logger.debug("Stored client registration has expired")
                    return None

            self.client_id = reg_response.client_id
            self.client_secret = reg_response.client_secret
            logger.debug(f"Loaded stored client registration: {self.client_id}")
            return reg_response

        except Exception as e:
            # Deliberate best-effort: a corrupt file just means registering again.
            logger.debug(f"Failed to load client registration: {e}")
            return None

    async def refresh_token(self) -> BearerAuth | None:
        """Refresh the access token using the stored refresh token, if possible.

        Returns:
            New BearerAuth, or None when there is no refresh token, no metadata,
            no client_id (configured or stored from DCR), or the server rejects
            the refresh (re-authentication is then required).
        """
        logger.debug("Attempting to refresh token")
        token_data = await self.token_storage.load_tokens(self.server_url)
        if not token_data or not token_data.refresh_token:
            logger.debug("No token data or refresh token found, cannot refresh.")
            return None

        if not self._metadata:
            logger.debug("No OAuth metadata available, cannot refresh token.")
            return None

        if not self._client:
            if not self.client_id:
                # A dynamically registered client_id only survives restarts on disk;
                # loading it sets self.client_id / self.client_secret.
                await self._load_client_registration()
            if not self.client_id:
                logger.debug("Cannot refresh token without client_id")
                return None
            logger.debug("Creating temporary AsyncOAuth2Client for token refresh")
            self._client = AsyncOAuth2Client(client_id=self.client_id, client_secret=self.client_secret)

        logger.debug("Calling client.refresh_token")
        try:
            token_response = await self._client.refresh_token(
                str(self._metadata.token_endpoint),
                refresh_token=token_data.refresh_token,
            )
            logger.debug("Token refresh successful")

            # Save new tokens
            logger.debug("Saving new tokens after refresh")
            await self.token_storage.save_tokens(self.server_url, token_response)

            # Update bearer auth
            logger.debug("Updating BearerAuth with new access token")
            self._bearer_auth = BearerAuth(token=SecretStr(token_response["access_token"]))
            return self._bearer_auth

        except (OAuth2Error, AuthlibBaseError) as e:
            # AuthlibBaseError also covers the httpx integration's OAuthError.
            logger.warning(f"Token refresh failed: {e}. Re-authentication is required.")
            # Refresh failed, need to re-authenticate
            return None
|