appkit-assistant 0.14.1__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- appkit_assistant/backend/mcp_auth_service.py +796 -0
- appkit_assistant/backend/model_manager.py +2 -1
- appkit_assistant/backend/models.py +43 -0
- appkit_assistant/backend/processors/openai_responses_processor.py +265 -36
- appkit_assistant/backend/repositories.py +1 -1
- appkit_assistant/backend/system_prompt_cache.py +5 -5
- appkit_assistant/components/mcp_server_dialogs.py +327 -21
- appkit_assistant/components/message.py +62 -0
- appkit_assistant/components/thread.py +99 -1
- appkit_assistant/state/mcp_server_state.py +42 -1
- appkit_assistant/state/system_prompt_state.py +4 -4
- appkit_assistant/state/thread_list_state.py +5 -5
- appkit_assistant/state/thread_state.py +190 -28
- {appkit_assistant-0.14.1.dist-info → appkit_assistant-0.15.1.dist-info}/METADATA +1 -1
- appkit_assistant-0.15.1.dist-info/RECORD +29 -0
- appkit_assistant-0.14.1.dist-info/RECORD +0 -28
- {appkit_assistant-0.14.1.dist-info → appkit_assistant-0.15.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,796 @@
|
|
|
1
|
+
"""MCP OAuth Authentication Service.
|
|
2
|
+
|
|
3
|
+
Handles OAuth 2.0 flows for MCP servers including:
|
|
4
|
+
- Discovery of OAuth metadata via RFC 8414
|
|
5
|
+
- Authorization URL construction with PKCE support
|
|
6
|
+
- Token exchange (authorization code for tokens)
|
|
7
|
+
- Token refresh
|
|
8
|
+
- Token storage and retrieval
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import base64
|
|
12
|
+
import hashlib
|
|
13
|
+
import logging
|
|
14
|
+
import secrets
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from datetime import UTC, datetime, timedelta
|
|
17
|
+
from http import HTTPStatus
|
|
18
|
+
from urllib.parse import urlencode, urlparse
|
|
19
|
+
|
|
20
|
+
import httpx
|
|
21
|
+
from sqlmodel import Session, select
|
|
22
|
+
|
|
23
|
+
from appkit_assistant.backend.models import (
|
|
24
|
+
AssistantMCPUserToken,
|
|
25
|
+
MCPAuthType,
|
|
26
|
+
MCPServer,
|
|
27
|
+
)
|
|
28
|
+
from appkit_user.authentication.backend.entities import OAuthStateEntity
|
|
29
|
+
|
|
30
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Discovery paths per RFC 8414.
# Each path is appended to the server's scheme://host base URL; the OAuth
# authorization-server metadata path is listed before the OpenID Connect
# configuration path.
WELL_KNOWN_PATHS = [
    "/.well-known/oauth-authorization-server",
    "/.well-known/openid-configuration",
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _generate_pkce_pair() -> tuple[str, str]:
    """Create a fresh PKCE verifier/challenge pair using the S256 method.

    Returns:
        Tuple of (code_verifier, code_challenge).
    """
    # 32 random bytes encode to a 43-character unpadded base64url string,
    # the shortest verifier allowed (valid range is 43-128 characters).
    verifier = secrets.token_urlsafe(32)

    # S256 challenge: unpadded base64url of the SHA-256 digest of the verifier.
    digest = hashlib.sha256(verifier.encode("utf-8")).digest()
    challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")

    return verifier, challenge
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
|
|
61
|
+
class OAuthDiscoveryResult:
|
|
62
|
+
"""Result of OAuth metadata discovery."""
|
|
63
|
+
|
|
64
|
+
issuer: str | None = None
|
|
65
|
+
authorization_endpoint: str | None = None
|
|
66
|
+
token_endpoint: str | None = None
|
|
67
|
+
registration_endpoint: str | None = None
|
|
68
|
+
scopes_supported: list[str] | None = None
|
|
69
|
+
error: str | None = None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
|
|
73
|
+
class ClientRegistrationResult:
|
|
74
|
+
"""Result of OAuth Dynamic Client Registration (RFC 7591)."""
|
|
75
|
+
|
|
76
|
+
client_id: str | None = None
|
|
77
|
+
client_secret: str | None = None
|
|
78
|
+
client_id_issued_at: int | None = None
|
|
79
|
+
client_secret_expires_at: int | None = None
|
|
80
|
+
error: str | None = None
|
|
81
|
+
error_description: str | None = None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
|
|
85
|
+
class TokenResult:
|
|
86
|
+
"""Result of token exchange or refresh."""
|
|
87
|
+
|
|
88
|
+
access_token: str | None = None
|
|
89
|
+
refresh_token: str | None = None
|
|
90
|
+
expires_in: int | None = None
|
|
91
|
+
token_type: str | None = None
|
|
92
|
+
scope: str | None = None
|
|
93
|
+
error: str | None = None
|
|
94
|
+
error_description: str | None = None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class MCPAuthService:
    """Service for handling MCP OAuth authentication flows.

    Covers metadata discovery, Dynamic Client Registration, authorization URL
    construction with PKCE, token exchange and refresh, and persistence of
    per-user tokens. One ``httpx.AsyncClient`` is created lazily and reused
    until :meth:`close` is called.
    """

    # Lifetime assumed for a token whose response has no usable expires_in.
    _DEFAULT_TOKEN_LIFETIME_SECONDS = 3600
    # Safety margin applied when deciding whether a token is still usable.
    _EXPIRY_SKEW = timedelta(seconds=30)
    # How long a stored OAuth state (and its PKCE verifier) remains usable.
    _STATE_LIFETIME = timedelta(minutes=10)

    def __init__(self, redirect_uri: str) -> None:
        """Initialize the service.

        Args:
            redirect_uri: The OAuth redirect URI for the callback endpoint.
        """
        self.redirect_uri = redirect_uri
        # Created on first use by _get_client() and released by close().
        self._http_client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client."""
        if self._http_client is None:
            self._http_client = httpx.AsyncClient(timeout=30.0)
        return self._http_client

    async def close(self) -> None:
        """Close the HTTP client."""
        if self._http_client is not None:
            await self._http_client.aclose()
            self._http_client = None

    # Internal helpers

    @staticmethod
    def _as_utc(moment: datetime) -> datetime:
        """Return ``moment`` as a timezone-aware datetime.

        A naive value is tagged as UTC so it can be compared with
        ``datetime.now(UTC)``; comparing a naive with an aware datetime raises
        ``TypeError``. This assumes stored naive timestamps are UTC, which
        matches how this service writes them (``datetime.now(UTC)``).
        """
        if moment.tzinfo is None:
            return moment.replace(tzinfo=UTC)
        return moment

    @staticmethod
    def _json_object(response: httpx.Response) -> dict | None:
        """Return the response body as a JSON object, or None if it is not one."""
        try:
            payload = response.json()
        except ValueError:
            # Covers JSON decoding errors and undecodable bytes.
            return None
        return payload if isinstance(payload, dict) else None

    @staticmethod
    def _append_query(url: str, params: dict[str, str]) -> str:
        """Append URL-encoded ``params`` to ``url``, keeping any existing query."""
        encoded = urlencode(params)
        if "?" not in url:
            return f"{url}?{encoded}"
        if url.endswith(("?", "&")):
            return f"{url}{encoded}"
        return f"{url}&{encoded}"

    @staticmethod
    def _provider_key(server: MCPServer) -> str:
        """Return the provider key under which OAuth state is stored."""
        return f"mcp:{server.id}" if server.id else "mcp:unknown"

    def _parse_token_response(
        self,
        response: httpx.Response,
        fallback_refresh_token: str | None = None,
    ) -> TokenResult:
        """Convert a token endpoint response into a TokenResult.

        Args:
            response: The HTTP response returned by the token endpoint.
            fallback_refresh_token: Refresh token to keep when a successful
                response does not contain a ``refresh_token`` key.

        Returns:
            TokenResult with tokens on HTTP 200, otherwise with error details.
        """
        payload = self._json_object(response)

        if response.status_code == HTTPStatus.OK:
            if payload is None:
                # A 200 without a JSON object cannot carry tokens.
                return TokenResult(
                    error="invalid_response",
                    error_description="Token endpoint returned a non-JSON response",
                )
            return TokenResult(
                access_token=payload.get("access_token"),
                refresh_token=payload.get("refresh_token", fallback_refresh_token),
                expires_in=payload.get("expires_in"),
                token_type=payload.get("token_type", "Bearer"),
                scope=payload.get("scope"),
            )

        status_text = f"HTTP {response.status_code}"
        if payload is None:
            return TokenResult(error="http_error", error_description=status_text)
        return TokenResult(
            error=payload.get("error", "unknown_error"),
            error_description=payload.get("error_description", status_text),
        )

    def _consume_oauth_state(
        self,
        session: Session,
        server: MCPServer,
        state: str,
    ) -> tuple[str | None, TokenResult | None]:
        """Validate the stored OAuth state and fetch its PKCE verifier.

        A valid state is deleted so it cannot be replayed.

        Args:
            session: Database session.
            server: The MCP server the flow belongs to.
            state: The state value received on the callback.

        Returns:
            Tuple of (code_verifier, error). ``error`` is None when the state
            is valid; otherwise it is the TokenResult to hand to the caller.
        """
        oauth_state = session.exec(
            select(OAuthStateEntity).where(
                OAuthStateEntity.state == state,
                OAuthStateEntity.provider == self._provider_key(server),
            )
        ).first()

        rejected = TokenResult(
            error="invalid_grant",
            error_description="OAuth state expired or invalid",
        )

        if oauth_state is None:
            # An unknown state means the callback does not belong to a flow
            # started here; continuing would defeat the CSRF check.
            logger.warning("No OAuth state found for state %s.", state)
            return None, rejected

        if self._as_utc(oauth_state.expires_at) < datetime.now(UTC):
            logger.warning("OAuth state expired for state %s", state)
            return None, rejected

        # Read the verifier before the row is deleted.
        code_verifier = oauth_state.code_verifier
        session.delete(oauth_state)
        session.commit()
        return code_verifier, None

    async def _register_via_dcr(
        self,
        server: MCPServer,
        session: Session | None,
    ) -> None:
        """Obtain a client_id for ``server`` via Dynamic Client Registration.

        Updates ``server`` in place with the registered credentials and any
        discovered endpoints it was missing, then persists it when a session
        is given.

        Args:
            server: The MCP server configuration to update.
            session: DB session used to persist the update, if any.

        Raises:
            ValueError: If discovery fails, no registration endpoint is
                advertised, or the registration does not yield a client_id.
        """
        logger.debug(
            "No client_id for server %s, attempting DCR",
            server.name,
        )

        # Discover OAuth config to find the registration endpoint
        discovery_result = await self.discover_oauth_config(server.url)

        if discovery_result.error:
            msg = f"OAuth discovery failed: {discovery_result.error}"
            raise ValueError(msg)

        if not discovery_result.registration_endpoint:
            msg = "Server has no client ID and no registration endpoint available"
            raise ValueError(msg)

        reg_result = await self.register_client(
            registration_endpoint=discovery_result.registration_endpoint,
            client_name="AppKit Assistant",
        )

        if reg_result.error:
            msg = (
                f"Client registration failed: "
                f"{reg_result.error}: {reg_result.error_description}"
            )
            raise ValueError(msg)

        if not reg_result.client_id:
            msg = "Client registration succeeded but no client_id returned"
            raise ValueError(msg)

        server.oauth_client_id = reg_result.client_id
        if reg_result.client_secret:
            server.oauth_client_secret = reg_result.client_secret

        # Fill in discovered endpoints only where none is configured yet.
        if not server.oauth_authorize_url and discovery_result.authorization_endpoint:
            server.oauth_authorize_url = discovery_result.authorization_endpoint
        if not server.oauth_token_url and discovery_result.token_endpoint:
            server.oauth_token_url = discovery_result.token_endpoint

        if session:
            session.add(server)
            try:
                session.commit()
                logger.debug(
                    "Persisted DCR client_id %s for server %s",
                    reg_result.client_id,
                    server.name,
                )
            except Exception as e:
                # Best effort: a persistence failure is logged, not raised.
                logger.error("Failed to persist DCR client_id: %s", e)
                session.rollback()

    # OAuth flow

    async def discover_oauth_config(self, server_url: str) -> OAuthDiscoveryResult:
        """Discover OAuth metadata from the server.

        Attempts to fetch OAuth metadata from well-known endpoints per RFC 8414.
        Only the scheme and host of ``server_url`` are used; the well-known
        paths are tried in order and the first HTTP 200 JSON document wins.

        Args:
            server_url: The MCP server URL to discover OAuth config from.

        Returns:
            OAuthDiscoveryResult with discovered endpoints or error.
        """
        client = await self._get_client()
        parsed = urlparse(server_url)
        base_url = f"{parsed.scheme}://{parsed.netloc}"

        for path in WELL_KNOWN_PATHS:
            discovery_url = f"{base_url}{path}"
            logger.debug("Attempting OAuth discovery at: %s", discovery_url)

            try:
                response = await client.get(discovery_url)
                if response.status_code != HTTPStatus.OK:
                    continue
                metadata = response.json()
                return OAuthDiscoveryResult(
                    issuer=metadata.get("issuer"),
                    authorization_endpoint=metadata.get("authorization_endpoint"),
                    token_endpoint=metadata.get("token_endpoint"),
                    registration_endpoint=metadata.get("registration_endpoint"),
                    scopes_supported=metadata.get("scopes_supported"),
                )
            except httpx.RequestError as e:
                logger.debug("Discovery failed at %s: %s", discovery_url, e)
                continue
            except Exception as e:
                # Best effort: an unusable document at one path must not
                # prevent trying the next one.
                logger.warning("Unexpected error during discovery: %s", e)
                continue

        return OAuthDiscoveryResult(
            error="OAuth discovery failed: No valid metadata found at well-known paths"
        )

    async def register_client(
        self,
        registration_endpoint: str,
        client_name: str = "AppKit Assistant",
        additional_redirect_uris: list[str] | None = None,
    ) -> ClientRegistrationResult:
        """Register a new OAuth client via Dynamic Client Registration (RFC 7591).

        Some OAuth providers like Atlassian MCP require clients to register
        dynamically before they can authenticate.

        Args:
            registration_endpoint: The OAuth registration endpoint URL.
            client_name: The name to register the client with.
            additional_redirect_uris: Additional redirect URIs to register.

        Returns:
            ClientRegistrationResult with client_id and optionally client_secret,
            or with error details when the registration failed.
        """
        client = await self._get_client()

        redirect_uris = [self.redirect_uri]
        if additional_redirect_uris:
            redirect_uris.extend(additional_redirect_uris)

        # Client metadata per RFC 7591
        client_metadata = {
            "client_name": client_name,
            "redirect_uris": redirect_uris,
            "grant_types": ["authorization_code", "refresh_token"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "none",  # Public client (no secret)
        }

        logger.debug(
            "Registering OAuth client at %s with metadata: %s",
            registration_endpoint,
            client_metadata,
        )

        try:
            response = await client.post(
                registration_endpoint,
                json=client_metadata,
                headers={"Content-Type": "application/json"},
            )
        except httpx.RequestError as e:
            logger.error("Client registration request failed: %s", e)
            return ClientRegistrationResult(
                error="request_failed",
                error_description=str(e),
            )

        payload = self._json_object(response)
        status_text = f"HTTP {response.status_code}: {response.text}"

        if response.status_code in (HTTPStatus.OK, HTTPStatus.CREATED):
            if payload is None:
                # A success status without a JSON object carries no client_id.
                return ClientRegistrationResult(
                    error="registration_failed",
                    error_description=status_text,
                )
            logger.debug(
                "Successfully registered OAuth client: %s",
                payload.get("client_id"),
            )
            return ClientRegistrationResult(
                client_id=payload.get("client_id"),
                client_secret=payload.get("client_secret"),
                client_id_issued_at=payload.get("client_id_issued_at"),
                client_secret_expires_at=payload.get("client_secret_expires_at"),
            )

        # Handle error response
        if payload is None:
            return ClientRegistrationResult(
                error="registration_failed",
                error_description=status_text,
            )
        return ClientRegistrationResult(
            error=payload.get("error", "registration_failed"),
            error_description=payload.get("error_description", status_text),
        )

    async def build_authorization_url_with_registration(
        self,
        server: MCPServer,
        state: str | None = None,
        session: Session | None = None,
        user_id: int | None = None,
    ) -> tuple[str, str]:
        """Build the OAuth authorization URL, registering client via DCR if needed.

        This method will automatically discover OAuth config and perform Dynamic
        Client Registration (RFC 7591) if the server has no client_id configured
        and a registration_endpoint is available.

        Args:
            server: The MCP server configuration.
            state: Optional state parameter. If not provided, a random one is generated.
            session: DB Session for storing PKCE state and updating server config.
            user_id: User ID for binding state.

        Returns:
            Tuple of (authorization_url, state) where state should be stored
            for verification.

        Raises:
            ValueError: If client registration is required but fails, or if
                the server still lacks an authorization URL or client ID.
        """
        # Ensure we work with a session-attached object to avoid StateProxy issues
        # and to allow persisting changes if DCR happens.
        if session and server.id:
            db_server = session.get(MCPServer, server.id)
            if db_server:
                server = db_server

        # If no client_id, attempt Dynamic Client Registration
        if not server.oauth_client_id:
            await self._register_via_dcr(server, session)

        # Now delegate to the synchronous URL builder
        return self.build_authorization_url(
            server=server,
            state=state,
            session=session,
            user_id=user_id,
        )

    def build_authorization_url(
        self,
        server: MCPServer,
        state: str | None = None,
        session: Session | None = None,
        user_id: int | None = None,
    ) -> tuple[str, str]:
        """Build the OAuth authorization URL for user login.

        Supports PKCE by generating code_verifier and storing it if a session is
        provided.

        Args:
            server: The MCP server configuration.
            state: Optional state parameter. If not provided, a random one is generated.
            session: DB Session for storing PKCE state.
            user_id: User ID for binding state.

        Returns:
            Tuple of (authorization_url, state) where state should be stored
            for verification.

        Raises:
            ValueError: If the server has no authorization URL or no client ID.
        """
        if not server.oauth_authorize_url:
            msg = "Server has no authorization URL configured"
            raise ValueError(msg)

        if not server.oauth_client_id:
            msg = "Server has no client ID configured"
            raise ValueError(msg)

        if state is None:
            state = secrets.token_urlsafe(32)
            # The state value doubles as the CSRF token, so it is not logged.
            logger.info("Generated new OAuth state")

        # Generate PKCE parameters (required by OAuth 2.1 / MCP servers)
        code_verifier, code_challenge = _generate_pkce_pair()

        params = {
            "response_type": "code",
            "redirect_uri": self.redirect_uri,
            "state": state,
            "client_id": server.oauth_client_id,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }

        # Store state in DB for CSRF protection, server mapping, and PKCE verifier
        if session:
            provider_key = self._provider_key(server)
            logger.info(
                "Saving OAuth state to DB: provider=%s, user_id=%s",
                provider_key,
                user_id,
            )
            oauth_state = OAuthStateEntity(
                session_id="mcp_auth_flow",
                state=state,
                provider=provider_key,
                code_verifier=code_verifier,
                expires_at=datetime.now(UTC) + self._STATE_LIFETIME,
                user_id=user_id,
            )
            session.add(oauth_state)
            try:
                session.commit()
                logger.info("OAuth state committed successfully")
            except Exception as e:
                # Best effort: the URL is still returned, but the callback
                # will not find the state.
                logger.error("Failed to commit OAuth state: %s", e)
                session.rollback()
        else:
            logger.warning(
                "No DB session provided to build_authorization_url. "
                "PKCE and state check will fail."
            )

        if server.oauth_scopes:
            params["scope"] = server.oauth_scopes

        # The configured endpoint may already carry a query string.
        auth_url = self._append_query(server.oauth_authorize_url, params)
        # The full URL embeds the state value, so only the base is logged.
        logger.debug(
            "build_authorization_url: base=%s",
            server.oauth_authorize_url,
        )
        return auth_url, state

    async def exchange_code_for_tokens(
        self,
        server: MCPServer,
        code: str,
        state: str | None = None,
        session: Session | None = None,
    ) -> TokenResult:
        """Exchange authorization code for access and refresh tokens.

        When both ``state`` and ``session`` are given, the stored OAuth state
        is validated and consumed first; an unknown or expired state aborts
        the exchange with ``invalid_grant``.

        Args:
            server: The MCP server configuration.
            code: The authorization code received from the callback.
            state: The state parameter from the callback (required for PKCE).
            session: DB Session for retrieving PKCE verifier.

        Returns:
            TokenResult with tokens or error information.
        """
        if not server.oauth_token_url:
            return TokenResult(error="no_token_url", error_description="No token URL")

        if not server.oauth_client_id:
            logger.error("Missing client_id for server %s", server.name)
            return TokenResult(
                error="config_missing",
                error_description="Client ID missing in server configuration",
            )

        data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
            "client_id": server.oauth_client_id,
        }

        # Check OAuth state (CSRF) and retrieve PKCE code_verifier
        code_verifier: str | None = None
        if state and session:
            code_verifier, state_error = self._consume_oauth_state(
                session, server, state
            )
            if state_error is not None:
                return state_error

        # Add PKCE code_verifier if available (required by providers like Atlassian)
        if code_verifier:
            data["code_verifier"] = code_verifier

        if server.oauth_client_secret:
            data["client_secret"] = server.oauth_client_secret

        client = await self._get_client()
        try:
            response = await client.post(
                server.oauth_token_url,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
        except httpx.RequestError as e:
            logger.error("Token exchange request failed: %s", e)
            return TokenResult(
                error="request_failed",
                error_description=str(e),
            )

        return self._parse_token_response(response)

    async def refresh_access_token(
        self,
        server: MCPServer,
        refresh_token: str,
    ) -> TokenResult:
        """Refresh an access token using a refresh token.

        If the token endpoint does not return a new refresh token, the one
        passed in is carried over into the result.

        Args:
            server: The MCP server configuration.
            refresh_token: The refresh token to use.

        Returns:
            TokenResult with new tokens or error information.
        """
        if not server.oauth_token_url:
            return TokenResult(error="no_token_url", error_description="No token URL")

        if not server.oauth_client_id:
            return TokenResult(error="no_client_id", error_description="No client ID")

        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": server.oauth_client_id,
        }

        if server.oauth_client_secret:
            data["client_secret"] = server.oauth_client_secret

        client = await self._get_client()
        try:
            response = await client.post(
                server.oauth_token_url,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
        except httpx.RequestError as e:
            logger.error("Token refresh request failed: %s", e)
            return TokenResult(
                error="request_failed",
                error_description=str(e),
            )

        return self._parse_token_response(
            response,
            fallback_refresh_token=refresh_token,
        )

    # Database operations

    def get_user_token(
        self,
        session: Session,
        user_id: int,
        mcp_server_id: int,
    ) -> AssistantMCPUserToken | None:
        """Get a user's token for an MCP server.

        Args:
            session: Database session.
            user_id: The user's ID.
            mcp_server_id: The MCP server's ID.

        Returns:
            The token record or None if not found.
        """
        statement = select(AssistantMCPUserToken).where(
            AssistantMCPUserToken.user_id == user_id,
            AssistantMCPUserToken.mcp_server_id == mcp_server_id,
        )
        return session.exec(statement).first()

    def save_user_token(
        self,
        session: Session,
        user_id: int,
        mcp_server_id: int,
        token_result: TokenResult,
    ) -> AssistantMCPUserToken:
        """Save or update a user's token for an MCP server.

        An existing record keeps its refresh token unless ``token_result``
        carries a new one.

        Args:
            session: Database session.
            user_id: The user's ID.
            mcp_server_id: The MCP server's ID.
            token_result: The token data from exchange or refresh.

        Returns:
            The saved token record.
        """
        # Calculate expiry. expires_in comes straight from the provider's
        # JSON, so a numeric string is converted before use; anything
        # missing, zero or unparsable falls back to the default lifetime.
        lifetime = token_result.expires_in
        if isinstance(lifetime, str):
            try:
                lifetime = int(lifetime)
            except ValueError:
                lifetime = None
        lifetime = lifetime or self._DEFAULT_TOKEN_LIFETIME_SECONDS
        expires_at = datetime.now(UTC) + timedelta(seconds=lifetime)

        record = self.get_user_token(session, user_id, mcp_server_id)

        if record:
            record.access_token = token_result.access_token or ""
            if token_result.refresh_token:
                record.refresh_token = token_result.refresh_token
            record.expires_at = expires_at
            record.updated_at = datetime.now(UTC)
        else:
            record = AssistantMCPUserToken(
                user_id=user_id,
                mcp_server_id=mcp_server_id,
                access_token=token_result.access_token or "",
                refresh_token=token_result.refresh_token,
                expires_at=expires_at,
            )

        session.add(record)
        session.commit()
        session.refresh(record)
        return record

    def delete_user_token(
        self,
        session: Session,
        user_id: int,
        mcp_server_id: int,
    ) -> bool:
        """Delete a user's token for an MCP server.

        Args:
            session: Database session.
            user_id: The user's ID.
            mcp_server_id: The MCP server's ID.

        Returns:
            True if a token was deleted, False otherwise.
        """
        token = self.get_user_token(session, user_id, mcp_server_id)
        if not token:
            return False

        session.delete(token)
        session.commit()
        return True

    def is_token_valid(self, token: AssistantMCPUserToken) -> bool:
        """Check if a token is still valid (not expired).

        Args:
            token: The token to check.

        Returns:
            True if the token is valid, False if expired.
        """
        # Add 30 second buffer for clock skew
        deadline = datetime.now(UTC) + self._EXPIRY_SKEW
        return self._as_utc(token.expires_at) > deadline

    async def ensure_valid_token(
        self,
        session: Session,
        server: MCPServer,
        token: AssistantMCPUserToken,
    ) -> AssistantMCPUserToken | None:
        """Ensure a token is valid, refreshing if necessary.

        Args:
            session: Database session.
            server: The MCP server configuration.
            token: The token to validate/refresh.

        Returns:
            A valid token or None if refresh failed.
        """
        if self.is_token_valid(token):
            return token

        # Token expired, try to refresh
        if not token.refresh_token:
            logger.warning("Token expired and no refresh token available")
            return None

        logger.debug("Refreshing expired token for server %s", server.name)
        result = await self.refresh_access_token(server, token.refresh_token)

        if result.error:
            logger.error(
                "Token refresh failed: %s - %s",
                result.error,
                result.error_description,
            )
            return None

        if not result.access_token:
            # Saving this would store an empty access token as if it were valid.
            logger.error("Token refresh failed: response contained no access token")
            return None

        # Save the refreshed token
        return self.save_user_token(
            session,
            token.user_id,
            token.mcp_server_id,
            result,
        )

    def update_server_oauth_config(
        self,
        session: Session,
        server: MCPServer,
        discovery_result: OAuthDiscoveryResult,
    ) -> MCPServer:
        """Update server with discovered OAuth configuration.

        Issuer, authorization URL and token URL are overwritten with the
        discovered values; scopes are replaced only when the discovery
        result lists any.

        Args:
            session: Database session.
            server: The MCP server to update.
            discovery_result: The discovery result with OAuth metadata.

        Returns:
            The updated server.
        """
        server.oauth_issuer = discovery_result.issuer
        server.oauth_authorize_url = discovery_result.authorization_endpoint
        server.oauth_token_url = discovery_result.token_endpoint
        if discovery_result.scopes_supported:
            server.oauth_scopes = " ".join(discovery_result.scopes_supported)
        server.oauth_discovered_at = datetime.now(UTC)
        server.auth_type = MCPAuthType.OAUTH_DISCOVERY

        session.add(server)
        session.commit()
        session.refresh(server)
        return server
|