open-edison 0.1.19__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ from contextvars import ContextVar
12
12
  from dataclasses import dataclass, field
13
13
  from datetime import datetime
14
14
  from pathlib import Path
15
- from typing import Any
15
+ from typing import Any, cast
16
16
 
17
17
  import mcp.types as mt
18
18
  from fastmcp.prompts.prompt import FunctionPrompt
@@ -27,8 +27,13 @@ from sqlalchemy import JSON, Column, Integer, String, create_engine, event
27
27
  from sqlalchemy.orm import Session, declarative_base
28
28
  from sqlalchemy.sql import select
29
29
 
30
+ from src import events
30
31
  from src.config import get_config_dir # type: ignore[reportMissingImports]
31
- from src.middleware.data_access_tracker import DataAccessTracker
32
+ from src.middleware.data_access_tracker import (
33
+ DataAccessTracker,
34
+ SecurityError,
35
+ )
36
+ from src.permissions import Permissions
32
37
  from src.telemetry import (
33
38
  record_prompt_used,
34
39
  record_resource_used,
@@ -67,7 +72,10 @@ class MCPSessionModel(Base): # type: ignore
67
72
  data_access_summary = Column(JSON) # type: ignore
68
73
 
69
74
 
70
- current_session_id_ctxvar: ContextVar[str | None] = ContextVar("current_session_id", default=None)
75
+ current_session_id_ctxvar: ContextVar[str | None] = ContextVar(
76
+ "current_session_id",
77
+ default=cast(str | None, None), # noqa: B039
78
+ )
71
79
 
72
80
 
73
81
  def get_current_session_data_tracker() -> DataAccessTracker | None:
@@ -101,7 +109,7 @@ def create_db_session() -> Generator[Session, None, None]:
101
109
  engine = create_engine(f"sqlite:///{db_path}")
102
110
 
103
111
  # Ensure changes are flushed to the main database file (avoid WAL for sql.js compatibility)
104
- @event.listens_for(engine, "connect")
112
+ @event.listens_for(engine, "connect") # noqa
105
113
  def _set_sqlite_pragmas(dbapi_connection, connection_record): # type: ignore[no-untyped-def] # noqa
106
114
  cur = dbapi_connection.cursor() # type: ignore[attr-defined]
107
115
  try:
@@ -171,7 +179,7 @@ def get_session_from_db(session_id: str) -> MCPSession:
171
179
  "has_external_communication", False
172
180
  )
173
181
  # Restore ACL highest level if present
174
- if isinstance(summary_data, dict) and "acl" in summary_data:
182
+ if isinstance(summary_data, dict) and "acl" in summary_data: # type: ignore
175
183
  acl_summary: Any = summary_data.get("acl") # type: ignore
176
184
  if isinstance(acl_summary, dict):
177
185
  highest = acl_summary.get("highest_acl_level") # type: ignore
@@ -214,7 +222,7 @@ class SessionTrackingMiddleware(Middleware):
214
222
  return session, session_id
215
223
 
216
224
  # General hooks for on_request, on_message, etc.
217
- async def on_request(
225
+ async def on_request( # noqa
218
226
  self,
219
227
  context: MiddlewareContext[mt.Request[Any, Any]], # type: ignore
220
228
  call_next: CallNext[mt.Request[Any, Any], Any], # type: ignore
@@ -225,17 +233,25 @@ class SessionTrackingMiddleware(Middleware):
225
233
  # Get or create session stats
226
234
  _, _session_id = self._get_or_create_session_stats(context)
227
235
 
228
- return await call_next(context) # type: ignore
236
+ try:
237
+ return await call_next(context) # type: ignore
238
+ except Exception:
239
+ log.exception("MCP request handling failed")
240
+ raise
229
241
 
230
242
  # Hooks for Tools
231
- async def on_list_tools(
243
+ async def on_list_tools( # noqa
232
244
  self,
233
245
  context: MiddlewareContext[Any], # type: ignore
234
246
  call_next: CallNext[Any, Any], # type: ignore
235
247
  ) -> Any:
236
248
  log.debug("🔍 on_list_tools")
237
249
  # Get the original response
238
- response = await call_next(context)
250
+ try:
251
+ response = await call_next(context)
252
+ except Exception:
253
+ log.exception("MCP list_tools failed")
254
+ raise
239
255
  log.trace(f"🔍 on_list_tools response: {response}")
240
256
 
241
257
  session_id = current_session_id_ctxvar.get()
@@ -247,8 +263,11 @@ class SessionTrackingMiddleware(Middleware):
247
263
 
248
264
  # Filter out specific tools or return empty list
249
265
  allowed_tools: list[FunctionTool | ProxyTool | Any] = []
266
+ perms = Permissions()
250
267
  for tool in response:
251
- log.trace(f"🔍 Processing tool listing {tool.name}")
268
+ # Due to proxy & server naming
269
+ tool_name = tool.key
270
+ log.trace(f"🔍 Processing tool listing {tool_name}")
252
271
  if isinstance(tool, FunctionTool):
253
272
  log.trace("🔍 Tool is built-in")
254
273
  log.trace(f"🔍 Tool is a FunctionTool: {tool}")
@@ -260,20 +279,18 @@ class SessionTrackingMiddleware(Middleware):
260
279
  log.trace(f"🔍 Tool is a unknown type: {tool}")
261
280
  continue
262
281
 
263
- log.trace(f"🔍 Getting permissions for tool {tool.name}")
264
- permissions = session.data_access_tracker.get_tool_permissions(tool.name)
265
- log.trace(f"🔍 Tool permissions: {permissions}")
266
- if permissions["enabled"]:
282
+ log.trace(f"🔍 Getting permissions for tool {tool_name}")
283
+ if perms.is_tool_enabled(tool_name):
267
284
  allowed_tools.append(tool)
268
285
  else:
269
286
  log.warning(
270
- f"🔍 Tool {tool.name} is disabled on not configured and will not be allowed"
287
+ f"🔍 Tool {tool_name} is disabled or not configured and will not be allowed"
271
288
  )
272
289
  continue
273
290
 
274
291
  return allowed_tools # type: ignore
275
292
 
276
- async def on_call_tool(
293
+ async def on_call_tool( # noqa
277
294
  self,
278
295
  context: MiddlewareContext[mt.CallToolRequestParams], # type: ignore
279
296
  call_next: CallNext[mt.CallToolRequestParams, ToolResult], # type: ignore
@@ -296,7 +313,28 @@ class SessionTrackingMiddleware(Middleware):
296
313
 
297
314
  assert session.data_access_tracker is not None
298
315
  log.debug(f"🔍 Analyzing tool {context.message.name} for security implications")
299
- session.data_access_tracker.add_tool_call(context.message.name)
316
+ try:
317
+ session.data_access_tracker.add_tool_call(context.message.name)
318
+ except SecurityError as e:
319
+ # Publish pre-block event enriched with session_id then wait up to 30s for approval
320
+ events.fire_and_forget(
321
+ {
322
+ "type": "mcp_pre_block",
323
+ "kind": "tool",
324
+ "name": context.message.name,
325
+ "session_id": session_id,
326
+ "error": str(e),
327
+ }
328
+ )
329
+ approved = await events.wait_for_approval(
330
+ session_id, "tool", context.message.name, timeout_s=30.0
331
+ )
332
+ if not approved:
333
+ raise
334
+ # Approved: apply effects and proceed
335
+ session.data_access_tracker.apply_effects_after_manual_approval(
336
+ "tool", context.message.name
337
+ )
300
338
  # Telemetry: record tool call
301
339
  record_tool_call(context.message.name)
302
340
 
@@ -330,7 +368,7 @@ class SessionTrackingMiddleware(Middleware):
330
368
  return await call_next(context) # type: ignore
331
369
 
332
370
  # Hooks for Resources
333
- async def on_list_resources(
371
+ async def on_list_resources( # noqa
334
372
  self,
335
373
  context: MiddlewareContext[Any], # type: ignore
336
374
  call_next: CallNext[Any, Any], # type: ignore
@@ -338,7 +376,11 @@ class SessionTrackingMiddleware(Middleware):
338
376
  """Process resource access and track security implications."""
339
377
  log.trace("🔍 on_list_resources")
340
378
  # Get the original response
341
- response = await call_next(context)
379
+ try:
380
+ response = await call_next(context)
381
+ except Exception:
382
+ log.exception("MCP list_resources failed")
383
+ raise
342
384
  log.trace(f"🔍 on_list_resources response: {response}")
343
385
 
344
386
  session_id = current_session_id_ctxvar.get()
@@ -350,6 +392,7 @@ class SessionTrackingMiddleware(Middleware):
350
392
 
351
393
  # Keep only the resources that are enabled; the result may be an empty list
352
394
  allowed_resources: list[FunctionResource | ProxyResource | Any] = []
395
+ perms = Permissions()
353
396
  for resource in response:
354
397
  resource_name = str(resource.uri)
355
398
  log.trace(f"🔍 Processing resource listing {resource_name}")
@@ -365,19 +408,17 @@ class SessionTrackingMiddleware(Middleware):
365
408
  continue
366
409
 
367
410
  log.trace(f"🔍 Getting permissions for resource {resource_name}")
368
- permissions = session.data_access_tracker.get_resource_permissions(resource_name)
369
- log.trace(f"🔍 Resource permissions: {permissions}")
370
- if permissions["enabled"]:
411
+ if perms.is_resource_enabled(resource_name):
371
412
  allowed_resources.append(resource)
372
413
  else:
373
414
  log.warning(
374
- f"🔍 Resource {resource_name} is disabled on not configured and will not be allowed"
415
+ f"🔍 Resource {resource_name} is disabled or not configured and will not be allowed"
375
416
  )
376
417
  continue
377
418
 
378
419
  return allowed_resources # type: ignore
379
420
 
380
- async def on_read_resource(
421
+ async def on_read_resource( # noqa
381
422
  self,
382
423
  context: MiddlewareContext[Any], # type: ignore
383
424
  call_next: CallNext[Any, Any], # type: ignore
@@ -386,7 +427,11 @@ class SessionTrackingMiddleware(Middleware):
386
427
  session_id = current_session_id_ctxvar.get()
387
428
  if session_id is None:
388
429
  log.warning("No session ID found for resource access tracking")
389
- return await call_next(context)
430
+ try:
431
+ return await call_next(context)
432
+ except Exception:
433
+ log.exception("MCP read_resource failed")
434
+ raise
390
435
 
391
436
  session = get_session_from_db(session_id)
392
437
  log.trace(f"Adding resource access to session {session_id}")
@@ -396,7 +441,26 @@ class SessionTrackingMiddleware(Middleware):
396
441
  resource_name = str(context.message.uri)
397
442
 
398
443
  log.debug(f"🔍 Analyzing resource {resource_name} for security implications")
399
- _ = session.data_access_tracker.add_resource_access(resource_name)
444
+ try:
445
+ _ = session.data_access_tracker.add_resource_access(resource_name)
446
+ except SecurityError as e:
447
+ events.fire_and_forget(
448
+ {
449
+ "type": "mcp_pre_block",
450
+ "kind": "resource",
451
+ "name": resource_name,
452
+ "session_id": session_id,
453
+ "error": str(e),
454
+ }
455
+ )
456
+ approved = await events.wait_for_approval(
457
+ session_id, "resource", resource_name, timeout_s=30.0
458
+ )
459
+ if not approved:
460
+ raise
461
+ session.data_access_tracker.apply_effects_after_manual_approval(
462
+ "resource", resource_name
463
+ )
400
464
  record_resource_used(resource_name)
401
465
 
402
466
  # Update database session
@@ -409,10 +473,14 @@ class SessionTrackingMiddleware(Middleware):
409
473
  db_session.commit()
410
474
 
411
475
  log.trace(f"Resource access {resource_name} added to session {session_id}")
412
- return await call_next(context)
476
+ try:
477
+ return await call_next(context)
478
+ except Exception:
479
+ log.exception("MCP read_resource failed")
480
+ raise
413
481
 
414
482
  # Hooks for Prompts
415
- async def on_list_prompts(
483
+ async def on_list_prompts( # noqa
416
484
  self,
417
485
  context: MiddlewareContext[Any], # type: ignore
418
486
  call_next: CallNext[Any, Any], # type: ignore
@@ -420,7 +488,11 @@ class SessionTrackingMiddleware(Middleware):
420
488
  """Process resource access and track security implications."""
421
489
  log.debug("🔍 on_list_prompts")
422
490
  # Get the original response
423
- response = await call_next(context)
491
+ try:
492
+ response = await call_next(context)
493
+ except Exception:
494
+ log.exception("MCP list_prompts failed")
495
+ raise
424
496
  log.debug(f"🔍 on_list_prompts response: {response}")
425
497
 
426
498
  session_id = current_session_id_ctxvar.get()
@@ -432,6 +504,7 @@ class SessionTrackingMiddleware(Middleware):
432
504
 
433
505
  # Keep only the prompts that are enabled; the result may be an empty list
434
506
  allowed_prompts: list[ProxyPrompt | Any] = []
507
+ perms = Permissions()
435
508
  for prompt in response:
436
509
  prompt_name = str(prompt.name)
437
510
  log.trace(f"🔍 Processing prompt listing {prompt_name}")
@@ -447,19 +520,17 @@ class SessionTrackingMiddleware(Middleware):
447
520
  continue
448
521
 
449
522
  log.trace(f"🔍 Getting permissions for prompt {prompt_name}")
450
- permissions = session.data_access_tracker.get_prompt_permissions(prompt_name)
451
- log.trace(f"🔍 Prompt permissions: {permissions}")
452
- if permissions["enabled"]:
523
+ if perms.is_prompt_enabled(prompt_name):
453
524
  allowed_prompts.append(prompt)
454
525
  else:
455
526
  log.warning(
456
- f"🔍 Prompt {prompt_name} is disabled on not configured and will not be allowed"
527
+ f"🔍 Prompt {prompt_name} is disabled or not configured and will not be allowed"
457
528
  )
458
529
  continue
459
530
 
460
531
  return allowed_prompts # type: ignore
461
532
 
462
- async def on_get_prompt(
533
+ async def on_get_prompt( # noqa
463
534
  self,
464
535
  context: MiddlewareContext[Any], # type: ignore
465
536
  call_next: CallNext[Any, Any], # type: ignore
@@ -468,7 +539,11 @@ class SessionTrackingMiddleware(Middleware):
468
539
  session_id = current_session_id_ctxvar.get()
469
540
  if session_id is None:
470
541
  log.warning("No session ID found for prompt access tracking")
471
- return await call_next(context)
542
+ try:
543
+ return await call_next(context)
544
+ except Exception:
545
+ log.exception("MCP get_prompt failed")
546
+ raise
472
547
 
473
548
  session = get_session_from_db(session_id)
474
549
  log.trace(f"Adding prompt access to session {session_id}")
@@ -477,7 +552,24 @@ class SessionTrackingMiddleware(Middleware):
477
552
  prompt_name = context.message.name
478
553
 
479
554
  log.debug(f"🔍 Analyzing prompt {prompt_name} for security implications")
480
- _ = session.data_access_tracker.add_prompt_access(prompt_name)
555
+ try:
556
+ _ = session.data_access_tracker.add_prompt_access(prompt_name)
557
+ except SecurityError as e:
558
+ events.fire_and_forget(
559
+ {
560
+ "type": "mcp_pre_block",
561
+ "kind": "prompt",
562
+ "name": prompt_name,
563
+ "session_id": session_id,
564
+ "error": str(e),
565
+ }
566
+ )
567
+ approved = await events.wait_for_approval(
568
+ session_id, "prompt", prompt_name, timeout_s=30.0
569
+ )
570
+ if not approved:
571
+ raise
572
+ session.data_access_tracker.apply_effects_after_manual_approval("prompt", prompt_name)
481
573
  record_prompt_used(prompt_name)
482
574
 
483
575
  # Update database session
@@ -490,4 +582,8 @@ class SessionTrackingMiddleware(Middleware):
490
582
  db_session.commit()
491
583
 
492
584
  log.trace(f"Prompt access {prompt_name} added to session {session_id}")
493
- return await call_next(context)
585
+ try:
586
+ return await call_next(context)
587
+ except Exception:
588
+ log.exception("MCP get_prompt failed")
589
+ raise
src/oauth_manager.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ OAuth Manager for OpenEdison MCP Gateway
3
+
4
+ Handles OAuth 2.1 authentication for MCP servers using FastMCP's built-in OAuth support.
5
+ Provides detection, token management, and authentication flow coordination.
6
+ """
7
+
8
+ import asyncio
9
+ from dataclasses import dataclass
10
+ from enum import Enum
11
+ from pathlib import Path
12
+
13
+ from fastmcp.client.auth.oauth import (
14
+ FileTokenStorage,
15
+ OAuth,
16
+ check_if_auth_required,
17
+ default_cache_dir,
18
+ )
19
+ from loguru import logger as log
20
+
21
+
22
class OAuthStatus(Enum):
    """OAuth authentication state recorded for an MCP server."""

    # Not assigned anywhere in this module.
    UNKNOWN = "unknown"  # noqa
    # Local server (no URL), or the remote endpoint reported that no auth is required.
    NOT_REQUIRED = "not_required"
    # Auth is required and no stored access token was found; also set after tokens are cleared.
    NEEDS_AUTH = "needs_auth"
    # Auth is required and a stored access token exists (the token itself is not validated here).
    AUTHENTICATED = "authenticated"
    # The requirement check timed out or raised.
    ERROR = "error"
    # Not assigned anywhere in this module.
    EXPIRED = "expired"  # noqa
31
+
32
+
33
@dataclass
class OAuthServerInfo:
    """Snapshot of what is known about one MCP server's OAuth state."""

    # Name the server is registered under; also the cache key used by OAuthManager.
    server_name: str
    # MCP endpoint URL; stored as "" for local servers that have no URL.
    mcp_url: str
    # Result of the most recent requirement check.
    status: OAuthStatus
    # Scopes to request; used as the fallback when the caller passes none.
    scopes: list[str] | None = None
    # Client name used for OAuth registration when the caller passes none.
    client_name: str = "OpenEdison MCP Gateway"
    # Populated when the requirement check timed out or raised.
    error_message: str | None = None
    # String form of the stored token's expiry, when it could be determined.
    token_expires_at: str | None = None
    # True when the stored tokens include a refresh token.
    has_refresh_token: bool = False
45
+
46
+
47
class OAuthManager:
    """
    Manages OAuth authentication for MCP servers.

    This class provides a centralized interface for:
    - Detecting which MCP servers require OAuth
    - Inspecting and clearing cached OAuth tokens
    - Providing OAuth authentication objects for FastMCP clients

    The result of every requirement check is cached per server name and can be
    read back with ``get_server_info``.
    """

    # Callback port handed to FastMCP's ``OAuth`` helper unless overridden.
    DEFAULT_CALLBACK_PORT = 50001

    def __init__(self, cache_dir: Path | None = None, callback_port: int = DEFAULT_CALLBACK_PORT):
        """
        Initialize OAuth manager.

        Args:
            cache_dir: Directory for token cache. Defaults to FastMCP's default.
            callback_port: Port passed to the ``OAuth`` objects created by
                ``get_oauth_auth``. Defaults to 50001.
        """
        self.cache_dir = cache_dir or default_cache_dir()
        self.callback_port = callback_port
        self._oauth_info: dict[str, OAuthServerInfo] = {}

        log.info(f"🔐 OAuth Manager initialized with cache dir: {self.cache_dir}")

    def _remember(self, info: OAuthServerInfo) -> OAuthServerInfo:
        """Cache ``info`` under its server name and return it."""
        self._oauth_info[info.server_name] = info
        return info

    @staticmethod
    def _describe_token_expiry(tokens: object) -> str | None:
        """Return a best-effort expiry string for a stored token object, or None."""
        try:
            expires_at = getattr(tokens, "expires_at", None)
            if expires_at:
                return str(expires_at)

            expires_in = getattr(tokens, "expires_in", None)
            if expires_in:
                from datetime import datetime, timedelta

                # NOTE(review): this counts expires_in from "now", not from the moment the
                # token was issued, so for cached tokens the result is an upper bound.
                return (datetime.now() + timedelta(seconds=expires_in)).isoformat()
        except Exception as e:
            # Expiry is informational only; a malformed value must not fail the check.
            log.debug(f"Could not determine OAuth token expiration: {e}")
        return None

    async def check_oauth_requirement(
        self, server_name: str, mcp_url: str | None, timeout_seconds: float = 10.0
    ) -> OAuthServerInfo:
        """
        Check if an MCP server requires OAuth authentication.

        Never raises: failures are reported through an ``OAuthServerInfo`` whose
        status is ``OAuthStatus.ERROR``. The result is cached under ``server_name``.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint (None for local servers)
            timeout_seconds: Timeout for the check request

        Returns:
            OAuthServerInfo with detection results
        """
        log.info(f"🔍 Checking OAuth requirement for {server_name}")

        # If no mcp_url provided, this is a local server - no OAuth needed
        if not mcp_url:
            info = OAuthServerInfo(
                server_name=server_name, mcp_url="", status=OAuthStatus.NOT_REQUIRED
            )
            log.info(f"✅ {server_name} is a local server - no OAuth required")
            return self._remember(info)

        log.info(f"🔍 Checking OAuth requirement for remote server {server_name} at {mcp_url}")

        try:
            # Check if auth is required (with timeout)
            auth_required = await asyncio.wait_for(
                check_if_auth_required(mcp_url), timeout=timeout_seconds
            )

            if not auth_required:
                info = OAuthServerInfo(
                    server_name=server_name, mcp_url=mcp_url, status=OAuthStatus.NOT_REQUIRED
                )
                log.info(f"✅ {server_name} does not require OAuth")
                return self._remember(info)

            # OAuth is required, proceed with token checking
            log.info(f"🔐 {server_name} requires OAuth authentication")

            # Check if we have existing tokens in the cache
            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            existing_tokens = await token_storage.get_tokens()

            status = OAuthStatus.NEEDS_AUTH
            token_expires_at = None
            has_refresh_token = False

            if existing_tokens:
                has_refresh_token = bool(existing_tokens.refresh_token)
                if existing_tokens.access_token:
                    # We have tokens, assume they're valid for now.
                    # The actual validation will happen when the client tries to use them.
                    status = OAuthStatus.AUTHENTICATED
                    token_expires_at = self._describe_token_expiry(existing_tokens)

            info = OAuthServerInfo(
                server_name=server_name,
                mcp_url=mcp_url,
                status=status,
                scopes=None,  # We don't have metadata discovery, so no scopes info
                token_expires_at=token_expires_at,
                has_refresh_token=has_refresh_token,
            )

            log.info(f"🔐 {server_name} OAuth status: {status.value}")
            return self._remember(info)

        # Before Python 3.11 asyncio.wait_for raises asyncio.TimeoutError, which is not
        # the builtin TimeoutError; catch both so timeouts are always reported as such.
        except (asyncio.TimeoutError, TimeoutError):
            info = OAuthServerInfo(
                server_name=server_name,
                mcp_url=mcp_url,
                status=OAuthStatus.ERROR,
                error_message=f"OAuth check timed out after {timeout_seconds}s",
            )
            log.warning(f"⏰ OAuth check for {server_name} timed out")
            return self._remember(info)

        except Exception as e:
            info = OAuthServerInfo(
                server_name=server_name,
                mcp_url=mcp_url,
                status=OAuthStatus.ERROR,
                error_message=str(e),
            )
            log.error(f"❌ OAuth check for {server_name} failed: {e}")
            return self._remember(info)

    def get_oauth_auth(
        self,
        server_name: str,
        mcp_url: str,
        scopes: list[str] | None = None,
        client_name: str | None = None,
    ) -> OAuth | None:
        """
        Get OAuth authentication object for FastMCP client.

        Relies on the cached result of a previous ``check_oauth_requirement``
        call for ``server_name``; without one, None is returned.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint
            scopes: OAuth scopes to request (falls back to the cached info's scopes)
            client_name: Client name for OAuth registration (falls back to the cached one)

        Returns:
            OAuth authentication object, or None if the server is unknown, does not
            require OAuth, its last check failed, or the object could not be created
        """
        info = self._oauth_info.get(server_name)

        if not info or info.status == OAuthStatus.NOT_REQUIRED:
            return None

        if info.status == OAuthStatus.ERROR:
            log.warning(f"⚠️ Cannot create OAuth auth for {server_name}: {info.error_message}")
            return None

        try:
            oauth = OAuth(
                mcp_url=mcp_url,
                scopes=scopes or info.scopes,
                client_name=client_name or info.client_name,
                token_storage_cache_dir=self.cache_dir,
                callback_port=self.callback_port,
            )
            log.info(f"🔐 Created OAuth auth for {server_name}")
            return oauth

        except Exception as e:
            log.error(f"❌ Failed to create OAuth auth for {server_name}: {e}")
            return None

    def clear_tokens(self, server_name: str, mcp_url: str) -> bool:
        """
        Clear stored OAuth tokens for a server.

        On success the cached info for ``server_name`` (if any) is reset to
        ``NEEDS_AUTH`` with no expiry and no refresh token.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint

        Returns:
            True if tokens were cleared successfully, False if clearing raised
        """
        try:
            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            token_storage.clear()

            # Update our cached info
            cached = self._oauth_info.get(server_name)
            if cached is not None:
                cached.status = OAuthStatus.NEEDS_AUTH
                cached.token_expires_at = None
                cached.has_refresh_token = False

            log.info(f"🗑️ Cleared OAuth tokens for {server_name}")
            return True

        except Exception as e:
            log.error(f"❌ Failed to clear tokens for {server_name}: {e}")
            return False

    def get_server_info(self, server_name: str) -> OAuthServerInfo | None:
        """Get the cached OAuth info for a server, or None if it was never checked."""
        return self._oauth_info.get(server_name)

    async def refresh_server_status(self, server_name: str, mcp_url: str) -> OAuthServerInfo:
        """
        Refresh OAuth status for a server.

        Re-runs ``check_oauth_requirement`` with its default timeout.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint

        Returns:
            Updated OAuthServerInfo
        """
        return await self.check_oauth_requirement(server_name, mcp_url)
270
+
271
+
272
# Process-wide OAuth manager; stays None until get_oauth_manager() is first called.
_oauth_manager: OAuthManager | None = None


def get_oauth_manager() -> OAuthManager:
    """Return the shared OAuthManager, constructing it with defaults on first use."""
    global _oauth_manager
    manager = _oauth_manager
    if manager is None:
        manager = OAuthManager()
        _oauth_manager = manager
    return manager