appkit-assistant 0.14.1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,796 @@
1
+ """MCP OAuth Authentication Service.
2
+
3
+ Handles OAuth 2.0 flows for MCP servers including:
4
+ - Discovery of OAuth metadata via RFC 8414
5
+ - Authorization URL construction with PKCE support
6
+ - Token exchange (authorization code for tokens)
7
+ - Token refresh
8
+ - Token storage and retrieval
9
+ """
10
+
11
+ import base64
12
+ import hashlib
13
+ import logging
14
+ import secrets
15
+ from dataclasses import dataclass
16
+ from datetime import UTC, datetime, timedelta
17
+ from http import HTTPStatus
18
+ from urllib.parse import urlencode, urlparse
19
+
20
+ import httpx
21
+ from sqlmodel import Session, select
22
+
23
+ from appkit_assistant.backend.models import (
24
+ AssistantMCPUserToken,
25
+ MCPAuthType,
26
+ MCPServer,
27
+ )
28
+ from appkit_user.authentication.backend.entities import OAuthStateEntity
29
+
30
logger = logging.getLogger(__name__)

# Metadata documents probed during OAuth discovery, listed in the order they
# are tried: RFC 8414 authorization-server metadata first, then the OpenID
# Connect discovery document.
WELL_KNOWN_PATHS = [
    "/.well-known/oauth-authorization-server",
    "/.well-known/openid-configuration",
]
37
+
38
+
39
+ def _generate_pkce_pair() -> tuple[str, str]:
40
+ """Generate PKCE code verifier and challenge (S256).
41
+
42
+ Returns:
43
+ Tuple of (code_verifier, code_challenge)
44
+ """
45
+ # Generate code verifier (43-128 characters)
46
+ code_verifier = (
47
+ base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
48
+ )
49
+
50
+ # Generate code challenge (SHA256 hash of verifier)
51
+ code_challenge = (
52
+ base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("utf-8")).digest())
53
+ .decode("utf-8")
54
+ .rstrip("=")
55
+ )
56
+
57
+ return code_verifier, code_challenge
58
+
59
+
60
+ @dataclass
61
+ class OAuthDiscoveryResult:
62
+ """Result of OAuth metadata discovery."""
63
+
64
+ issuer: str | None = None
65
+ authorization_endpoint: str | None = None
66
+ token_endpoint: str | None = None
67
+ registration_endpoint: str | None = None
68
+ scopes_supported: list[str] | None = None
69
+ error: str | None = None
70
+
71
+
72
+ @dataclass
73
+ class ClientRegistrationResult:
74
+ """Result of OAuth Dynamic Client Registration (RFC 7591)."""
75
+
76
+ client_id: str | None = None
77
+ client_secret: str | None = None
78
+ client_id_issued_at: int | None = None
79
+ client_secret_expires_at: int | None = None
80
+ error: str | None = None
81
+ error_description: str | None = None
82
+
83
+
84
+ @dataclass
85
+ class TokenResult:
86
+ """Result of token exchange or refresh."""
87
+
88
+ access_token: str | None = None
89
+ refresh_token: str | None = None
90
+ expires_in: int | None = None
91
+ token_type: str | None = None
92
+ scope: str | None = None
93
+ error: str | None = None
94
+ error_description: str | None = None
95
+
96
+
97
+ class MCPAuthService:
98
+ """Service for handling MCP OAuth authentication flows."""
99
+
100
+ def __init__(self, redirect_uri: str) -> None:
101
+ """Initialize the service.
102
+
103
+ Args:
104
+ redirect_uri: The OAuth redirect URI for the callback endpoint.
105
+ """
106
+ self.redirect_uri = redirect_uri
107
+ self._http_client: httpx.AsyncClient | None = None
108
+
109
+ async def _get_client(self) -> httpx.AsyncClient:
110
+ """Get or create the HTTP client."""
111
+ if self._http_client is None:
112
+ self._http_client = httpx.AsyncClient(timeout=30.0)
113
+ return self._http_client
114
+
115
+ async def close(self) -> None:
116
+ """Close the HTTP client."""
117
+ if self._http_client is not None:
118
+ await self._http_client.aclose()
119
+ self._http_client = None
120
+
121
+ async def discover_oauth_config(self, server_url: str) -> OAuthDiscoveryResult:
122
+ """Discover OAuth metadata from the server.
123
+
124
+ Attempts to fetch OAuth metadata from well-known endpoints per RFC 8414.
125
+
126
+ Args:
127
+ server_url: The MCP server URL to discover OAuth config from.
128
+
129
+ Returns:
130
+ OAuthDiscoveryResult with discovered endpoints or error.
131
+ """
132
+ client = await self._get_client()
133
+ parsed = urlparse(server_url)
134
+ base_url = f"{parsed.scheme}://{parsed.netloc}"
135
+
136
+ for path in WELL_KNOWN_PATHS:
137
+ discovery_url = f"{base_url}{path}"
138
+ logger.debug("Attempting OAuth discovery at: %s", discovery_url)
139
+
140
+ try:
141
+ response = await client.get(discovery_url)
142
+ if response.status_code == HTTPStatus.OK:
143
+ data = response.json()
144
+ return OAuthDiscoveryResult(
145
+ issuer=data.get("issuer"),
146
+ authorization_endpoint=data.get("authorization_endpoint"),
147
+ token_endpoint=data.get("token_endpoint"),
148
+ registration_endpoint=data.get("registration_endpoint"),
149
+ scopes_supported=data.get("scopes_supported"),
150
+ )
151
+ except httpx.RequestError as e:
152
+ logger.debug("Discovery failed at %s: %s", discovery_url, str(e))
153
+ continue
154
+ except Exception as e:
155
+ logger.warning("Unexpected error during discovery: %s", str(e))
156
+ continue
157
+
158
+ return OAuthDiscoveryResult(
159
+ error="OAuth discovery failed: No valid metadata found at well-known paths"
160
+ )
161
+
162
+ async def register_client(
163
+ self,
164
+ registration_endpoint: str,
165
+ client_name: str = "AppKit Assistant",
166
+ additional_redirect_uris: list[str] | None = None,
167
+ ) -> ClientRegistrationResult:
168
+ """Register a new OAuth client via Dynamic Client Registration (RFC 7591).
169
+
170
+ Some OAuth providers like Atlassian MCP require clients to register
171
+ dynamically before they can authenticate.
172
+
173
+ Args:
174
+ registration_endpoint: The OAuth registration endpoint URL.
175
+ client_name: The name to register the client with.
176
+ additional_redirect_uris: Additional redirect URIs to register.
177
+
178
+ Returns:
179
+ ClientRegistrationResult with client_id and optionally client_secret.
180
+ """
181
+ client = await self._get_client()
182
+
183
+ redirect_uris = [self.redirect_uri]
184
+ if additional_redirect_uris:
185
+ redirect_uris.extend(additional_redirect_uris)
186
+
187
+ # Client metadata per RFC 7591
188
+ client_metadata = {
189
+ "client_name": client_name,
190
+ "redirect_uris": redirect_uris,
191
+ "grant_types": ["authorization_code", "refresh_token"],
192
+ "response_types": ["code"],
193
+ "token_endpoint_auth_method": "none", # Public client (no secret)
194
+ }
195
+
196
+ logger.debug(
197
+ "Registering OAuth client at %s with metadata: %s",
198
+ registration_endpoint,
199
+ client_metadata,
200
+ )
201
+
202
+ try:
203
+ response = await client.post(
204
+ registration_endpoint,
205
+ json=client_metadata,
206
+ headers={"Content-Type": "application/json"},
207
+ )
208
+
209
+ if response.status_code in (HTTPStatus.OK, HTTPStatus.CREATED):
210
+ data = response.json()
211
+ logger.debug(
212
+ "Successfully registered OAuth client: %s",
213
+ data.get("client_id"),
214
+ )
215
+ return ClientRegistrationResult(
216
+ client_id=data.get("client_id"),
217
+ client_secret=data.get("client_secret"),
218
+ client_id_issued_at=data.get("client_id_issued_at"),
219
+ client_secret_expires_at=data.get("client_secret_expires_at"),
220
+ )
221
+
222
+ # Handle error response
223
+ try:
224
+ error_data = response.json()
225
+ return ClientRegistrationResult(
226
+ error=error_data.get("error", "registration_failed"),
227
+ error_description=error_data.get(
228
+ "error_description",
229
+ f"HTTP {response.status_code}: {response.text}",
230
+ ),
231
+ )
232
+ except Exception:
233
+ return ClientRegistrationResult(
234
+ error="registration_failed",
235
+ error_description=f"HTTP {response.status_code}: {response.text}",
236
+ )
237
+
238
+ except httpx.RequestError as e:
239
+ logger.error("Client registration request failed: %s", str(e))
240
+ return ClientRegistrationResult(
241
+ error="request_failed",
242
+ error_description=str(e),
243
+ )
244
+
245
+ async def build_authorization_url_with_registration(
246
+ self,
247
+ server: MCPServer,
248
+ state: str | None = None,
249
+ session: Session | None = None,
250
+ user_id: int | None = None,
251
+ ) -> tuple[str, str]:
252
+ """Build the OAuth authorization URL, registering client via DCR if needed.
253
+
254
+ This method will automatically discover OAuth config and perform Dynamic
255
+ Client Registration (RFC 7591) if the server has no client_id configured
256
+ and a registration_endpoint is available.
257
+
258
+ Args:
259
+ server: The MCP server configuration.
260
+ state: Optional state parameter. If not provided, a random one is generated.
261
+ session: DB Session for storing PKCE state and updating server config.
262
+ user_id: User ID for binding state.
263
+
264
+ Returns:
265
+ Tuple of (authorization_url, state) where state should be stored
266
+ for verification.
267
+ """
268
+ # Ensure we work with a session-attached object to avoid StateProxy issues
269
+ # and to allow persisting changes if DCR happens.
270
+ if session and server.id:
271
+ db_server = session.get(MCPServer, server.id)
272
+ if db_server:
273
+ server = db_server
274
+
275
+ # If no client_id, attempt Dynamic Client Registration
276
+ if not server.oauth_client_id:
277
+ logger.debug(
278
+ "No client_id for server %s, attempting DCR",
279
+ server.name,
280
+ )
281
+
282
+ # Discover OAuth config to find registration endpoint
283
+ discovery_result = await self.discover_oauth_config(server.url)
284
+
285
+ if discovery_result.error:
286
+ msg = f"OAuth discovery failed: {discovery_result.error}"
287
+ raise ValueError(msg)
288
+
289
+ if not discovery_result.registration_endpoint:
290
+ msg = "Server has no client ID and no registration endpoint available"
291
+ raise ValueError(msg)
292
+
293
+ # Perform Dynamic Client Registration
294
+ reg_result = await self.register_client(
295
+ registration_endpoint=discovery_result.registration_endpoint,
296
+ client_name="AppKit Assistant",
297
+ )
298
+
299
+ if reg_result.error:
300
+ msg = (
301
+ f"Client registration failed: "
302
+ f"{reg_result.error}: {reg_result.error_description}"
303
+ )
304
+ raise ValueError(msg)
305
+
306
+ if not reg_result.client_id:
307
+ msg = "Client registration succeeded but no client_id returned"
308
+ raise ValueError(msg)
309
+
310
+ # Update server with registered client_id
311
+ server.oauth_client_id = reg_result.client_id
312
+ if reg_result.client_secret:
313
+ server.oauth_client_secret = reg_result.client_secret
314
+
315
+ # Also update discovered endpoints if not already set
316
+ auth_endpoint = discovery_result.authorization_endpoint
317
+ if not server.oauth_authorize_url and auth_endpoint:
318
+ server.oauth_authorize_url = auth_endpoint
319
+ if not server.oauth_token_url and discovery_result.token_endpoint:
320
+ server.oauth_token_url = discovery_result.token_endpoint
321
+
322
+ # Persist updated server configuration
323
+ if session:
324
+ session.add(server)
325
+ try:
326
+ session.commit()
327
+ logger.debug(
328
+ "Persisted DCR client_id %s for server %s",
329
+ reg_result.client_id,
330
+ server.name,
331
+ )
332
+ except Exception as e:
333
+ logger.error("Failed to persist DCR client_id: %s", e)
334
+ session.rollback()
335
+
336
+ # Now delegate to the synchronous URL builder
337
+ return self.build_authorization_url(
338
+ server=server,
339
+ state=state,
340
+ session=session,
341
+ user_id=user_id,
342
+ )
343
+
344
+ def build_authorization_url(
345
+ self,
346
+ server: MCPServer,
347
+ state: str | None = None,
348
+ session: Session | None = None,
349
+ user_id: int | None = None,
350
+ ) -> tuple[str, str]:
351
+ """Build the OAuth authorization URL for user login.
352
+
353
+ Supports PKCE by generating code_verifier and storing it if a session is
354
+ provided.
355
+
356
+ Args:
357
+ server: The MCP server configuration.
358
+ state: Optional state parameter. If not provided, a random one is generated.
359
+ session: DB Session for storing PKCE state.
360
+ user_id: User ID for binding state.
361
+
362
+ Returns:
363
+ Tuple of (authorization_url, state) where state should be stored
364
+ for verification.
365
+ """
366
+ if not server.oauth_authorize_url:
367
+ msg = "Server has no authorization URL configured"
368
+ raise ValueError(msg)
369
+
370
+ if state is None:
371
+ state = secrets.token_urlsafe(32)
372
+ logger.info("Generated new OAuth state: %s", state)
373
+
374
+ if not server.oauth_client_id:
375
+ msg = "Server has no client ID configured"
376
+ raise ValueError(msg)
377
+
378
+ # Generate PKCE parameters (required by OAuth 2.1 / MCP servers)
379
+ code_verifier, code_challenge = _generate_pkce_pair()
380
+
381
+ params = {
382
+ "response_type": "code",
383
+ "redirect_uri": self.redirect_uri,
384
+ "state": state,
385
+ "client_id": server.oauth_client_id,
386
+ "code_challenge": code_challenge,
387
+ "code_challenge_method": "S256",
388
+ }
389
+
390
+ # Store state in DB for CSRF protection, server mapping, and PKCE verifier
391
+ if session:
392
+ provider_key = f"mcp:{server.id}" if server.id else "mcp:unknown"
393
+ logger.info(
394
+ "Saving OAuth state to DB: state=%s, provider=%s, user_id=%s",
395
+ state,
396
+ provider_key,
397
+ user_id,
398
+ )
399
+ oauth_state = OAuthStateEntity(
400
+ session_id="mcp_auth_flow",
401
+ state=state,
402
+ provider=provider_key,
403
+ code_verifier=code_verifier,
404
+ expires_at=datetime.now(UTC) + timedelta(minutes=10),
405
+ user_id=user_id,
406
+ )
407
+ session.add(oauth_state)
408
+ try:
409
+ session.commit()
410
+ logger.info("OAuth state committed successfully: %s", state)
411
+ except Exception as e:
412
+ logger.error("Failed to commit OAuth state: %s", e)
413
+ session.rollback()
414
+ else:
415
+ logger.warning(
416
+ "No DB session provided to build_authorization_url. "
417
+ "PKCE and state check will fail."
418
+ )
419
+
420
+ if server.oauth_scopes:
421
+ params["scope"] = server.oauth_scopes
422
+
423
+ auth_url = f"{server.oauth_authorize_url}?{urlencode(params)}"
424
+ logger.debug(
425
+ "build_authorization_url: base=%s, result=%s",
426
+ server.oauth_authorize_url,
427
+ auth_url,
428
+ )
429
+ return auth_url, state
430
+
431
+ async def exchange_code_for_tokens( # noqa: PLR0911
432
+ self,
433
+ server: MCPServer,
434
+ code: str,
435
+ state: str | None = None,
436
+ session: Session | None = None,
437
+ ) -> TokenResult:
438
+ """Exchange authorization code for access and refresh tokens.
439
+
440
+ Args:
441
+ server: The MCP server configuration.
442
+ code: The authorization code received from the callback.
443
+ state: The state parameter from the callback (required for PKCE).
444
+ session: DB Session for retrieving PKCE verifier.
445
+
446
+ Returns:
447
+ TokenResult with tokens or error information.
448
+ """
449
+ if not server.oauth_token_url:
450
+ return TokenResult(error="no_token_url", error_description="No token URL")
451
+
452
+ client = await self._get_client()
453
+
454
+ data = {
455
+ "grant_type": "authorization_code",
456
+ "code": code,
457
+ "redirect_uri": self.redirect_uri,
458
+ }
459
+
460
+ if not server.oauth_client_id:
461
+ logger.error("Missing client_id for server %s", server.name)
462
+ return TokenResult(
463
+ error="config_missing",
464
+ error_description="Client ID missing in server configuration",
465
+ )
466
+
467
+ data["client_id"] = server.oauth_client_id
468
+
469
+ # Helper to get session from manager if not provided
470
+ # (Though arguments type hint says session is Optional,
471
+ # we expect it for state check)
472
+
473
+ # Check OAuth state (CSRF) and retrieve PKCE code_verifier
474
+ code_verifier: str | None = None
475
+ if state and session:
476
+ provider_key = f"mcp:{server.id}" if server.id else "mcp:unknown"
477
+ oauth_state = session.exec(
478
+ select(OAuthStateEntity).where(
479
+ OAuthStateEntity.state == state,
480
+ OAuthStateEntity.provider == provider_key,
481
+ )
482
+ ).first()
483
+
484
+ if oauth_state:
485
+ if oauth_state.expires_at < datetime.now(UTC):
486
+ logger.warning("OAuth state expired for state %s", state)
487
+ return TokenResult(
488
+ error="invalid_grant",
489
+ error_description="OAuth state expired or invalid",
490
+ )
491
+
492
+ # Retrieve PKCE code_verifier before cleanup
493
+ code_verifier = oauth_state.code_verifier
494
+
495
+ # Clean up used state
496
+ session.delete(oauth_state)
497
+ session.commit()
498
+ else:
499
+ logger.warning("No OAuth state found for state %s.", state)
500
+
501
+ # Add PKCE code_verifier if available (required by providers like Atlassian)
502
+ if code_verifier:
503
+ data["code_verifier"] = code_verifier
504
+
505
+ if server.oauth_client_secret:
506
+ data["client_secret"] = server.oauth_client_secret
507
+
508
+ try:
509
+ response = await client.post(
510
+ server.oauth_token_url,
511
+ data=data,
512
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
513
+ )
514
+
515
+ if response.status_code == HTTPStatus.OK:
516
+ token_data = response.json()
517
+ return TokenResult(
518
+ access_token=token_data.get("access_token"),
519
+ refresh_token=token_data.get("refresh_token"),
520
+ expires_in=token_data.get("expires_in"),
521
+ token_type=token_data.get("token_type", "Bearer"),
522
+ scope=token_data.get("scope"),
523
+ )
524
+
525
+ # Handle error response
526
+ try:
527
+ error_data = response.json()
528
+ return TokenResult(
529
+ error=error_data.get("error", "unknown_error"),
530
+ error_description=error_data.get(
531
+ "error_description",
532
+ f"HTTP {response.status_code}",
533
+ ),
534
+ )
535
+ except Exception:
536
+ return TokenResult(
537
+ error="http_error",
538
+ error_description=f"HTTP {response.status_code}",
539
+ )
540
+
541
+ except httpx.RequestError as e:
542
+ logger.error("Token exchange request failed: %s", str(e))
543
+ return TokenResult(
544
+ error="request_failed",
545
+ error_description=str(e),
546
+ )
547
+
548
+ async def refresh_access_token(
549
+ self,
550
+ server: MCPServer,
551
+ refresh_token: str,
552
+ ) -> TokenResult:
553
+ """Refresh an access token using a refresh token.
554
+
555
+ Args:
556
+ server: The MCP server configuration.
557
+ refresh_token: The refresh token to use.
558
+
559
+ Returns:
560
+ TokenResult with new tokens or error information.
561
+ """
562
+ if not server.oauth_token_url:
563
+ return TokenResult(error="no_token_url", error_description="No token URL")
564
+
565
+ if not server.oauth_client_id:
566
+ return TokenResult(error="no_client_id", error_description="No client ID")
567
+
568
+ client = await self._get_client()
569
+
570
+ data = {
571
+ "grant_type": "refresh_token",
572
+ "refresh_token": refresh_token,
573
+ "client_id": server.oauth_client_id,
574
+ }
575
+
576
+ if server.oauth_client_secret:
577
+ data["client_secret"] = server.oauth_client_secret
578
+
579
+ try:
580
+ response = await client.post(
581
+ server.oauth_token_url,
582
+ data=data,
583
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
584
+ )
585
+
586
+ if response.status_code == HTTPStatus.OK:
587
+ token_data = response.json()
588
+ return TokenResult(
589
+ access_token=token_data.get("access_token"),
590
+ refresh_token=token_data.get("refresh_token", refresh_token),
591
+ expires_in=token_data.get("expires_in"),
592
+ token_type=token_data.get("token_type", "Bearer"),
593
+ scope=token_data.get("scope"),
594
+ )
595
+
596
+ try:
597
+ error_data = response.json()
598
+ return TokenResult(
599
+ error=error_data.get("error", "unknown_error"),
600
+ error_description=error_data.get(
601
+ "error_description",
602
+ f"HTTP {response.status_code}",
603
+ ),
604
+ )
605
+ except Exception:
606
+ return TokenResult(
607
+ error="http_error",
608
+ error_description=f"HTTP {response.status_code}",
609
+ )
610
+
611
+ except httpx.RequestError as e:
612
+ logger.error("Token refresh request failed: %s", str(e))
613
+ return TokenResult(
614
+ error="request_failed",
615
+ error_description=str(e),
616
+ )
617
+
618
+ # Database operations
619
+
620
+ def get_user_token(
621
+ self,
622
+ session: Session,
623
+ user_id: int,
624
+ mcp_server_id: int,
625
+ ) -> AssistantMCPUserToken | None:
626
+ """Get a user's token for an MCP server.
627
+
628
+ Args:
629
+ session: Database session.
630
+ user_id: The user's ID.
631
+ mcp_server_id: The MCP server's ID.
632
+
633
+ Returns:
634
+ The token record or None if not found.
635
+ """
636
+ statement = select(AssistantMCPUserToken).where(
637
+ AssistantMCPUserToken.user_id == user_id,
638
+ AssistantMCPUserToken.mcp_server_id == mcp_server_id,
639
+ )
640
+ return session.exec(statement).first()
641
+
642
+ def save_user_token(
643
+ self,
644
+ session: Session,
645
+ user_id: int,
646
+ mcp_server_id: int,
647
+ token_result: TokenResult,
648
+ ) -> AssistantMCPUserToken:
649
+ """Save or update a user's token for an MCP server.
650
+
651
+ Args:
652
+ session: Database session.
653
+ user_id: The user's ID.
654
+ mcp_server_id: The MCP server's ID.
655
+ token_result: The token data from exchange or refresh.
656
+
657
+ Returns:
658
+ The saved token record.
659
+ """
660
+ # Calculate expiry
661
+ expires_in = token_result.expires_in or 3600 # Default 1 hour
662
+ expires_at = datetime.now(UTC) + timedelta(seconds=expires_in)
663
+
664
+ # Check for existing token
665
+ existing = self.get_user_token(session, user_id, mcp_server_id)
666
+
667
+ if existing:
668
+ existing.access_token = token_result.access_token or ""
669
+ if token_result.refresh_token:
670
+ existing.refresh_token = token_result.refresh_token
671
+ existing.expires_at = expires_at
672
+ existing.updated_at = datetime.now(UTC)
673
+ session.add(existing)
674
+ session.commit()
675
+ session.refresh(existing)
676
+ return existing
677
+
678
+ # Create new token
679
+ new_token = AssistantMCPUserToken(
680
+ user_id=user_id,
681
+ mcp_server_id=mcp_server_id,
682
+ access_token=token_result.access_token or "",
683
+ refresh_token=token_result.refresh_token,
684
+ expires_at=expires_at,
685
+ )
686
+ session.add(new_token)
687
+ session.commit()
688
+ session.refresh(new_token)
689
+ return new_token
690
+
691
+ def delete_user_token(
692
+ self,
693
+ session: Session,
694
+ user_id: int,
695
+ mcp_server_id: int,
696
+ ) -> bool:
697
+ """Delete a user's token for an MCP server.
698
+
699
+ Args:
700
+ session: Database session.
701
+ user_id: The user's ID.
702
+ mcp_server_id: The MCP server's ID.
703
+
704
+ Returns:
705
+ True if a token was deleted, False otherwise.
706
+ """
707
+ token = self.get_user_token(session, user_id, mcp_server_id)
708
+ if token:
709
+ session.delete(token)
710
+ session.commit()
711
+ return True
712
+ return False
713
+
714
+ def is_token_valid(self, token: AssistantMCPUserToken) -> bool:
715
+ """Check if a token is still valid (not expired).
716
+
717
+ Args:
718
+ token: The token to check.
719
+
720
+ Returns:
721
+ True if the token is valid, False if expired.
722
+ """
723
+ # Add 30 second buffer for clock skew
724
+ return token.expires_at > datetime.now(UTC) + timedelta(seconds=30)
725
+
726
+ async def ensure_valid_token(
727
+ self,
728
+ session: Session,
729
+ server: MCPServer,
730
+ token: AssistantMCPUserToken,
731
+ ) -> AssistantMCPUserToken | None:
732
+ """Ensure a token is valid, refreshing if necessary.
733
+
734
+ Args:
735
+ session: Database session.
736
+ server: The MCP server configuration.
737
+ token: The token to validate/refresh.
738
+
739
+ Returns:
740
+ A valid token or None if refresh failed.
741
+ """
742
+ if self.is_token_valid(token):
743
+ return token
744
+
745
+ # Token expired, try to refresh
746
+ if not token.refresh_token:
747
+ logger.warning("Token expired and no refresh token available")
748
+ return None
749
+
750
+ logger.debug("Refreshing expired token for server %s", server.name)
751
+ result = await self.refresh_access_token(server, token.refresh_token)
752
+
753
+ if result.error:
754
+ logger.error(
755
+ "Token refresh failed: %s - %s",
756
+ result.error,
757
+ result.error_description,
758
+ )
759
+ return None
760
+
761
+ # Save the refreshed token
762
+ return self.save_user_token(
763
+ session,
764
+ token.user_id,
765
+ token.mcp_server_id,
766
+ result,
767
+ )
768
+
769
+ def update_server_oauth_config(
770
+ self,
771
+ session: Session,
772
+ server: MCPServer,
773
+ discovery_result: OAuthDiscoveryResult,
774
+ ) -> MCPServer:
775
+ """Update server with discovered OAuth configuration.
776
+
777
+ Args:
778
+ session: Database session.
779
+ server: The MCP server to update.
780
+ discovery_result: The discovery result with OAuth metadata.
781
+
782
+ Returns:
783
+ The updated server.
784
+ """
785
+ server.oauth_issuer = discovery_result.issuer
786
+ server.oauth_authorize_url = discovery_result.authorization_endpoint
787
+ server.oauth_token_url = discovery_result.token_endpoint
788
+ if discovery_result.scopes_supported:
789
+ server.oauth_scopes = " ".join(discovery_result.scopes_supported)
790
+ server.oauth_discovered_at = datetime.now(UTC)
791
+ server.auth_type = MCPAuthType.OAUTH_DISCOVERY
792
+
793
+ session.add(server)
794
+ session.commit()
795
+ session.refresh(server)
796
+ return server