open-edison 0.1.19__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {open_edison-0.1.19.dist-info → open_edison-0.1.29.dist-info}/METADATA +66 -45
- open_edison-0.1.29.dist-info/RECORD +17 -0
- src/cli.py +2 -1
- src/config.py +71 -71
- src/events.py +153 -0
- src/middleware/data_access_tracker.py +164 -434
- src/middleware/session_tracking.py +133 -37
- src/oauth_manager.py +281 -0
- src/permissions.py +281 -0
- src/server.py +491 -134
- src/single_user_mcp.py +230 -158
- src/telemetry.py +4 -40
- open_edison-0.1.19.dist-info/RECORD +0 -14
- {open_edison-0.1.19.dist-info → open_edison-0.1.29.dist-info}/WHEEL +0 -0
- {open_edison-0.1.19.dist-info → open_edison-0.1.29.dist-info}/entry_points.txt +0 -0
- {open_edison-0.1.19.dist-info → open_edison-0.1.29.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@ from contextvars import ContextVar
|
|
12
12
|
from dataclasses import dataclass, field
|
13
13
|
from datetime import datetime
|
14
14
|
from pathlib import Path
|
15
|
-
from typing import Any
|
15
|
+
from typing import Any, cast
|
16
16
|
|
17
17
|
import mcp.types as mt
|
18
18
|
from fastmcp.prompts.prompt import FunctionPrompt
|
@@ -27,8 +27,13 @@ from sqlalchemy import JSON, Column, Integer, String, create_engine, event
|
|
27
27
|
from sqlalchemy.orm import Session, declarative_base
|
28
28
|
from sqlalchemy.sql import select
|
29
29
|
|
30
|
+
from src import events
|
30
31
|
from src.config import get_config_dir # type: ignore[reportMissingImports]
|
31
|
-
from src.middleware.data_access_tracker import
|
32
|
+
from src.middleware.data_access_tracker import (
|
33
|
+
DataAccessTracker,
|
34
|
+
SecurityError,
|
35
|
+
)
|
36
|
+
from src.permissions import Permissions
|
32
37
|
from src.telemetry import (
|
33
38
|
record_prompt_used,
|
34
39
|
record_resource_used,
|
@@ -67,7 +72,10 @@ class MCPSessionModel(Base): # type: ignore
|
|
67
72
|
data_access_summary = Column(JSON) # type: ignore
|
68
73
|
|
69
74
|
|
70
|
-
current_session_id_ctxvar: ContextVar[str | None] = ContextVar(
|
75
|
+
current_session_id_ctxvar: ContextVar[str | None] = ContextVar(
|
76
|
+
"current_session_id",
|
77
|
+
default=cast(str | None, None), # noqa: B039
|
78
|
+
)
|
71
79
|
|
72
80
|
|
73
81
|
def get_current_session_data_tracker() -> DataAccessTracker | None:
|
@@ -101,7 +109,7 @@ def create_db_session() -> Generator[Session, None, None]:
|
|
101
109
|
engine = create_engine(f"sqlite:///{db_path}")
|
102
110
|
|
103
111
|
# Ensure changes are flushed to the main database file (avoid WAL for sql.js compatibility)
|
104
|
-
@event.listens_for(engine, "connect")
|
112
|
+
@event.listens_for(engine, "connect") # noqa
|
105
113
|
def _set_sqlite_pragmas(dbapi_connection, connection_record): # type: ignore[no-untyped-def] # noqa
|
106
114
|
cur = dbapi_connection.cursor() # type: ignore[attr-defined]
|
107
115
|
try:
|
@@ -171,7 +179,7 @@ def get_session_from_db(session_id: str) -> MCPSession:
|
|
171
179
|
"has_external_communication", False
|
172
180
|
)
|
173
181
|
# Restore ACL highest level if present
|
174
|
-
if isinstance(summary_data, dict) and "acl" in summary_data:
|
182
|
+
if isinstance(summary_data, dict) and "acl" in summary_data: # type: ignore
|
175
183
|
acl_summary: Any = summary_data.get("acl") # type: ignore
|
176
184
|
if isinstance(acl_summary, dict):
|
177
185
|
highest = acl_summary.get("highest_acl_level") # type: ignore
|
@@ -214,7 +222,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
214
222
|
return session, session_id
|
215
223
|
|
216
224
|
# General hooks for on_request, on_message, etc.
|
217
|
-
async def on_request(
|
225
|
+
async def on_request( # noqa
|
218
226
|
self,
|
219
227
|
context: MiddlewareContext[mt.Request[Any, Any]], # type: ignore
|
220
228
|
call_next: CallNext[mt.Request[Any, Any], Any], # type: ignore
|
@@ -225,17 +233,25 @@ class SessionTrackingMiddleware(Middleware):
|
|
225
233
|
# Get or create session stats
|
226
234
|
_, _session_id = self._get_or_create_session_stats(context)
|
227
235
|
|
228
|
-
|
236
|
+
try:
|
237
|
+
return await call_next(context) # type: ignore
|
238
|
+
except Exception:
|
239
|
+
log.exception("MCP request handling failed")
|
240
|
+
raise
|
229
241
|
|
230
242
|
# Hooks for Tools
|
231
|
-
async def on_list_tools(
|
243
|
+
async def on_list_tools( # noqa
|
232
244
|
self,
|
233
245
|
context: MiddlewareContext[Any], # type: ignore
|
234
246
|
call_next: CallNext[Any, Any], # type: ignore
|
235
247
|
) -> Any:
|
236
248
|
log.debug("🔍 on_list_tools")
|
237
249
|
# Get the original response
|
238
|
-
|
250
|
+
try:
|
251
|
+
response = await call_next(context)
|
252
|
+
except Exception:
|
253
|
+
log.exception("MCP list_tools failed")
|
254
|
+
raise
|
239
255
|
log.trace(f"🔍 on_list_tools response: {response}")
|
240
256
|
|
241
257
|
session_id = current_session_id_ctxvar.get()
|
@@ -247,8 +263,11 @@ class SessionTrackingMiddleware(Middleware):
|
|
247
263
|
|
248
264
|
# Filter out specific tools or return empty list
|
249
265
|
allowed_tools: list[FunctionTool | ProxyTool | Any] = []
|
266
|
+
perms = Permissions()
|
250
267
|
for tool in response:
|
251
|
-
|
268
|
+
# Due to proxy & server naming
|
269
|
+
tool_name = tool.key
|
270
|
+
log.trace(f"🔍 Processing tool listing {tool_name}")
|
252
271
|
if isinstance(tool, FunctionTool):
|
253
272
|
log.trace("🔍 Tool is built-in")
|
254
273
|
log.trace(f"🔍 Tool is a FunctionTool: {tool}")
|
@@ -260,20 +279,18 @@ class SessionTrackingMiddleware(Middleware):
|
|
260
279
|
log.trace(f"🔍 Tool is a unknown type: {tool}")
|
261
280
|
continue
|
262
281
|
|
263
|
-
log.trace(f"🔍 Getting permissions for tool {
|
264
|
-
|
265
|
-
log.trace(f"🔍 Tool permissions: {permissions}")
|
266
|
-
if permissions["enabled"]:
|
282
|
+
log.trace(f"🔍 Getting permissions for tool {tool_name}")
|
283
|
+
if perms.is_tool_enabled(tool_name):
|
267
284
|
allowed_tools.append(tool)
|
268
285
|
else:
|
269
286
|
log.warning(
|
270
|
-
f"🔍 Tool {
|
287
|
+
f"🔍 Tool {tool_name} is disabled or not configured and will not be allowed"
|
271
288
|
)
|
272
289
|
continue
|
273
290
|
|
274
291
|
return allowed_tools # type: ignore
|
275
292
|
|
276
|
-
async def on_call_tool(
|
293
|
+
async def on_call_tool( # noqa
|
277
294
|
self,
|
278
295
|
context: MiddlewareContext[mt.CallToolRequestParams], # type: ignore
|
279
296
|
call_next: CallNext[mt.CallToolRequestParams, ToolResult], # type: ignore
|
@@ -296,7 +313,28 @@ class SessionTrackingMiddleware(Middleware):
|
|
296
313
|
|
297
314
|
assert session.data_access_tracker is not None
|
298
315
|
log.debug(f"🔍 Analyzing tool {context.message.name} for security implications")
|
299
|
-
|
316
|
+
try:
|
317
|
+
session.data_access_tracker.add_tool_call(context.message.name)
|
318
|
+
except SecurityError as e:
|
319
|
+
# Publish pre-block event enriched with session_id then wait up to 30s for approval
|
320
|
+
events.fire_and_forget(
|
321
|
+
{
|
322
|
+
"type": "mcp_pre_block",
|
323
|
+
"kind": "tool",
|
324
|
+
"name": context.message.name,
|
325
|
+
"session_id": session_id,
|
326
|
+
"error": str(e),
|
327
|
+
}
|
328
|
+
)
|
329
|
+
approved = await events.wait_for_approval(
|
330
|
+
session_id, "tool", context.message.name, timeout_s=30.0
|
331
|
+
)
|
332
|
+
if not approved:
|
333
|
+
raise
|
334
|
+
# Approved: apply effects and proceed
|
335
|
+
session.data_access_tracker.apply_effects_after_manual_approval(
|
336
|
+
"tool", context.message.name
|
337
|
+
)
|
300
338
|
# Telemetry: record tool call
|
301
339
|
record_tool_call(context.message.name)
|
302
340
|
|
@@ -330,7 +368,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
330
368
|
return await call_next(context) # type: ignore
|
331
369
|
|
332
370
|
# Hooks for Resources
|
333
|
-
async def on_list_resources(
|
371
|
+
async def on_list_resources( # noqa
|
334
372
|
self,
|
335
373
|
context: MiddlewareContext[Any], # type: ignore
|
336
374
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -338,7 +376,11 @@ class SessionTrackingMiddleware(Middleware):
|
|
338
376
|
"""Process resource access and track security implications."""
|
339
377
|
log.trace("🔍 on_list_resources")
|
340
378
|
# Get the original response
|
341
|
-
|
379
|
+
try:
|
380
|
+
response = await call_next(context)
|
381
|
+
except Exception:
|
382
|
+
log.exception("MCP list_resources failed")
|
383
|
+
raise
|
342
384
|
log.trace(f"🔍 on_list_resources response: {response}")
|
343
385
|
|
344
386
|
session_id = current_session_id_ctxvar.get()
|
@@ -350,6 +392,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
350
392
|
|
351
393
|
# Filter out specific tools or return empty list
|
352
394
|
allowed_resources: list[FunctionResource | ProxyResource | Any] = []
|
395
|
+
perms = Permissions()
|
353
396
|
for resource in response:
|
354
397
|
resource_name = str(resource.uri)
|
355
398
|
log.trace(f"🔍 Processing resource listing {resource_name}")
|
@@ -365,19 +408,17 @@ class SessionTrackingMiddleware(Middleware):
|
|
365
408
|
continue
|
366
409
|
|
367
410
|
log.trace(f"🔍 Getting permissions for resource {resource_name}")
|
368
|
-
|
369
|
-
log.trace(f"🔍 Resource permissions: {permissions}")
|
370
|
-
if permissions["enabled"]:
|
411
|
+
if perms.is_resource_enabled(resource_name):
|
371
412
|
allowed_resources.append(resource)
|
372
413
|
else:
|
373
414
|
log.warning(
|
374
|
-
f"🔍 Resource {resource_name} is disabled
|
415
|
+
f"🔍 Resource {resource_name} is disabled or not configured and will not be allowed"
|
375
416
|
)
|
376
417
|
continue
|
377
418
|
|
378
419
|
return allowed_resources # type: ignore
|
379
420
|
|
380
|
-
async def on_read_resource(
|
421
|
+
async def on_read_resource( # noqa
|
381
422
|
self,
|
382
423
|
context: MiddlewareContext[Any], # type: ignore
|
383
424
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -386,7 +427,11 @@ class SessionTrackingMiddleware(Middleware):
|
|
386
427
|
session_id = current_session_id_ctxvar.get()
|
387
428
|
if session_id is None:
|
388
429
|
log.warning("No session ID found for resource access tracking")
|
389
|
-
|
430
|
+
try:
|
431
|
+
return await call_next(context)
|
432
|
+
except Exception:
|
433
|
+
log.exception("MCP read_resource failed")
|
434
|
+
raise
|
390
435
|
|
391
436
|
session = get_session_from_db(session_id)
|
392
437
|
log.trace(f"Adding resource access to session {session_id}")
|
@@ -396,7 +441,26 @@ class SessionTrackingMiddleware(Middleware):
|
|
396
441
|
resource_name = str(context.message.uri)
|
397
442
|
|
398
443
|
log.debug(f"🔍 Analyzing resource {resource_name} for security implications")
|
399
|
-
|
444
|
+
try:
|
445
|
+
_ = session.data_access_tracker.add_resource_access(resource_name)
|
446
|
+
except SecurityError as e:
|
447
|
+
events.fire_and_forget(
|
448
|
+
{
|
449
|
+
"type": "mcp_pre_block",
|
450
|
+
"kind": "resource",
|
451
|
+
"name": resource_name,
|
452
|
+
"session_id": session_id,
|
453
|
+
"error": str(e),
|
454
|
+
}
|
455
|
+
)
|
456
|
+
approved = await events.wait_for_approval(
|
457
|
+
session_id, "resource", resource_name, timeout_s=30.0
|
458
|
+
)
|
459
|
+
if not approved:
|
460
|
+
raise
|
461
|
+
session.data_access_tracker.apply_effects_after_manual_approval(
|
462
|
+
"resource", resource_name
|
463
|
+
)
|
400
464
|
record_resource_used(resource_name)
|
401
465
|
|
402
466
|
# Update database session
|
@@ -409,10 +473,14 @@ class SessionTrackingMiddleware(Middleware):
|
|
409
473
|
db_session.commit()
|
410
474
|
|
411
475
|
log.trace(f"Resource access {resource_name} added to session {session_id}")
|
412
|
-
|
476
|
+
try:
|
477
|
+
return await call_next(context)
|
478
|
+
except Exception:
|
479
|
+
log.exception("MCP read_resource failed")
|
480
|
+
raise
|
413
481
|
|
414
482
|
# Hooks for Prompts
|
415
|
-
async def on_list_prompts(
|
483
|
+
async def on_list_prompts( # noqa
|
416
484
|
self,
|
417
485
|
context: MiddlewareContext[Any], # type: ignore
|
418
486
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -420,7 +488,11 @@ class SessionTrackingMiddleware(Middleware):
|
|
420
488
|
"""Process resource access and track security implications."""
|
421
489
|
log.debug("🔍 on_list_prompts")
|
422
490
|
# Get the original response
|
423
|
-
|
491
|
+
try:
|
492
|
+
response = await call_next(context)
|
493
|
+
except Exception:
|
494
|
+
log.exception("MCP list_prompts failed")
|
495
|
+
raise
|
424
496
|
log.debug(f"🔍 on_list_prompts response: {response}")
|
425
497
|
|
426
498
|
session_id = current_session_id_ctxvar.get()
|
@@ -432,6 +504,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
432
504
|
|
433
505
|
# Filter out specific tools or return empty list
|
434
506
|
allowed_prompts: list[ProxyPrompt | Any] = []
|
507
|
+
perms = Permissions()
|
435
508
|
for prompt in response:
|
436
509
|
prompt_name = str(prompt.name)
|
437
510
|
log.trace(f"🔍 Processing prompt listing {prompt_name}")
|
@@ -447,19 +520,17 @@ class SessionTrackingMiddleware(Middleware):
|
|
447
520
|
continue
|
448
521
|
|
449
522
|
log.trace(f"🔍 Getting permissions for prompt {prompt_name}")
|
450
|
-
|
451
|
-
log.trace(f"🔍 Prompt permissions: {permissions}")
|
452
|
-
if permissions["enabled"]:
|
523
|
+
if perms.is_prompt_enabled(prompt_name):
|
453
524
|
allowed_prompts.append(prompt)
|
454
525
|
else:
|
455
526
|
log.warning(
|
456
|
-
f"🔍 Prompt {prompt_name} is disabled
|
527
|
+
f"🔍 Prompt {prompt_name} is disabled or not configured and will not be allowed"
|
457
528
|
)
|
458
529
|
continue
|
459
530
|
|
460
531
|
return allowed_prompts # type: ignore
|
461
532
|
|
462
|
-
async def on_get_prompt(
|
533
|
+
async def on_get_prompt( # noqa
|
463
534
|
self,
|
464
535
|
context: MiddlewareContext[Any], # type: ignore
|
465
536
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -468,7 +539,11 @@ class SessionTrackingMiddleware(Middleware):
|
|
468
539
|
session_id = current_session_id_ctxvar.get()
|
469
540
|
if session_id is None:
|
470
541
|
log.warning("No session ID found for prompt access tracking")
|
471
|
-
|
542
|
+
try:
|
543
|
+
return await call_next(context)
|
544
|
+
except Exception:
|
545
|
+
log.exception("MCP get_prompt failed")
|
546
|
+
raise
|
472
547
|
|
473
548
|
session = get_session_from_db(session_id)
|
474
549
|
log.trace(f"Adding prompt access to session {session_id}")
|
@@ -477,7 +552,24 @@ class SessionTrackingMiddleware(Middleware):
|
|
477
552
|
prompt_name = context.message.name
|
478
553
|
|
479
554
|
log.debug(f"🔍 Analyzing prompt {prompt_name} for security implications")
|
480
|
-
|
555
|
+
try:
|
556
|
+
_ = session.data_access_tracker.add_prompt_access(prompt_name)
|
557
|
+
except SecurityError as e:
|
558
|
+
events.fire_and_forget(
|
559
|
+
{
|
560
|
+
"type": "mcp_pre_block",
|
561
|
+
"kind": "prompt",
|
562
|
+
"name": prompt_name,
|
563
|
+
"session_id": session_id,
|
564
|
+
"error": str(e),
|
565
|
+
}
|
566
|
+
)
|
567
|
+
approved = await events.wait_for_approval(
|
568
|
+
session_id, "prompt", prompt_name, timeout_s=30.0
|
569
|
+
)
|
570
|
+
if not approved:
|
571
|
+
raise
|
572
|
+
session.data_access_tracker.apply_effects_after_manual_approval("prompt", prompt_name)
|
481
573
|
record_prompt_used(prompt_name)
|
482
574
|
|
483
575
|
# Update database session
|
@@ -490,4 +582,8 @@ class SessionTrackingMiddleware(Middleware):
|
|
490
582
|
db_session.commit()
|
491
583
|
|
492
584
|
log.trace(f"Prompt access {prompt_name} added to session {session_id}")
|
493
|
-
|
585
|
+
try:
|
586
|
+
return await call_next(context)
|
587
|
+
except Exception:
|
588
|
+
log.exception("MCP get_prompt failed")
|
589
|
+
raise
|
src/oauth_manager.py
ADDED
@@ -0,0 +1,281 @@
|
|
1
|
+
"""
|
2
|
+
OAuth Manager for OpenEdison MCP Gateway
|
3
|
+
|
4
|
+
Handles OAuth 2.1 authentication for MCP servers using FastMCP's built-in OAuth support.
|
5
|
+
Provides detection, token management, and authentication flow coordination.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import asyncio
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from enum import Enum
|
11
|
+
from pathlib import Path
|
12
|
+
|
13
|
+
from fastmcp.client.auth.oauth import (
|
14
|
+
FileTokenStorage,
|
15
|
+
OAuth,
|
16
|
+
check_if_auth_required,
|
17
|
+
default_cache_dir,
|
18
|
+
)
|
19
|
+
from loguru import logger as log
|
20
|
+
|
21
|
+
|
22
|
+
class OAuthStatus(Enum):
    """Authentication state the gateway records for a single MCP server.

    Each member's value is the lowercase string form of its name.
    """

    # State has not been determined.
    UNKNOWN = "unknown"  # noqa
    # Server can be used without OAuth (local servers, or remote ones that do not ask for it).
    NOT_REQUIRED = "not_required"
    # Server requires OAuth and no usable access token is cached.
    NEEDS_AUTH = "needs_auth"
    # Server requires OAuth and a cached access token is present.
    AUTHENTICATED = "authenticated"
    # The requirement check itself failed or timed out.
    ERROR = "error"
    # Cached credentials are no longer valid.
    EXPIRED = "expired"  # noqa
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class OAuthServerInfo:
    """Snapshot of what the gateway knows about one MCP server's OAuth state.

    Attributes:
        server_name: Name of the MCP server this record describes.
        mcp_url: URL of the MCP endpoint (an empty string is stored for local servers).
        status: Current OAuth status of the server.
        scopes: OAuth scopes to request, or None when none are known.
        client_name: Client name used for OAuth registration.
        error_message: Human-readable failure description, or None.
        token_expires_at: Token expiry rendered as a string, or None when unknown.
        has_refresh_token: Whether the cached tokens include a refresh token.
    """

    # Required fields (positional order is part of the constructor contract).
    server_name: str
    mcp_url: str
    status: OAuthStatus

    # Optional details, all defaulted.
    scopes: list[str] | None = None
    client_name: str = "OpenEdison MCP Gateway"
    error_message: str | None = None
    token_expires_at: str | None = None
    has_refresh_token: bool = False
|
45
|
+
|
46
|
+
|
47
|
+
class OAuthManager:
    """
    Manages OAuth authentication for MCP servers.

    This class provides a centralized interface for:
    - Detecting which MCP servers require OAuth
    - Managing OAuth tokens and credentials
    - Providing OAuth authentication objects for FastMCP clients
    - Handling token refresh and expiration

    Detection results are cached per server name in ``_oauth_info`` and are
    overwritten every time a server is (re)checked.
    """

    # Local port used for the OAuth redirect callback unless the caller picks another one.
    DEFAULT_CALLBACK_PORT = 50001

    def __init__(self, cache_dir: Path | None = None):
        """
        Initialize OAuth manager.

        Args:
            cache_dir: Directory for token cache. Defaults to FastMCP's default.
        """
        self.cache_dir = cache_dir or default_cache_dir()
        self._oauth_info: dict[str, OAuthServerInfo] = {}

        log.info(f"🔐 OAuth Manager initialized with cache dir: {self.cache_dir}")

    def _remember(self, info: OAuthServerInfo) -> OAuthServerInfo:
        """Cache ``info`` under its server name and hand it back to the caller."""
        self._oauth_info[info.server_name] = info
        return info

    @staticmethod
    def _token_expiry(tokens: object) -> str | None:
        """
        Best-effort extraction of an expiry string from a stored token object.

        Prefers an ``expires_at`` attribute; otherwise derives a timestamp from
        ``expires_in``. Never raises: any failure is logged at debug level and
        reported as "unknown" (None).
        """
        try:
            expires_at = getattr(tokens, "expires_at", None)
            if expires_at:
                return str(expires_at)

            expires_in = getattr(tokens, "expires_in", None)
            if not expires_in:
                return None

            from datetime import datetime, timedelta

            # NOTE(review): expires_in is presumably relative to when the token was
            # issued, not to "now", so this is only an approximation — confirm
            # before relying on it. The result is a naive local-time ISO string.
            return (datetime.now() + timedelta(seconds=expires_in)).isoformat()
        except Exception as e:
            # Expiry is informational only; never let it break status detection.
            log.debug(f"Could not determine token expiration: {e}")
            return None

    async def check_oauth_requirement(
        self, server_name: str, mcp_url: str | None, timeout_seconds: float = 10.0
    ) -> OAuthServerInfo:
        """
        Check if an MCP server requires OAuth authentication.

        The result is cached and can later be read with ``get_server_info``.
        Failures never propagate: they are reported as an ``ERROR`` status
        with ``error_message`` set.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint (None for local servers)
            timeout_seconds: Timeout for the check request

        Returns:
            OAuthServerInfo with detection results
        """
        log.info(f"🔍 Checking OAuth requirement for {server_name}")

        # If no mcp_url provided, this is a local server - no OAuth needed
        if not mcp_url:
            log.info(f"✅ {server_name} is a local server - no OAuth required")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name, mcp_url="", status=OAuthStatus.NOT_REQUIRED
                )
            )

        log.info(f"🔍 Checking OAuth requirement for remote server {server_name} at {mcp_url}")

        try:
            # Check if auth is required (with timeout)
            auth_required = await asyncio.wait_for(
                check_if_auth_required(mcp_url), timeout=timeout_seconds
            )

            if not auth_required:
                log.info(f"✅ {server_name} does not require OAuth")
                return self._remember(
                    OAuthServerInfo(
                        server_name=server_name,
                        mcp_url=mcp_url,
                        status=OAuthStatus.NOT_REQUIRED,
                    )
                )

            # OAuth is required, proceed with token checking
            log.info(f"🔐 {server_name} requires OAuth authentication")

            # Check if we have existing tokens in the cache
            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            existing_tokens = await token_storage.get_tokens()

            status = OAuthStatus.NEEDS_AUTH
            token_expires_at = None
            has_refresh_token = False

            if existing_tokens:
                has_refresh_token = bool(existing_tokens.refresh_token)
                if existing_tokens.access_token:
                    # We have tokens, assume they're valid for now.
                    # The actual validation will happen when the client tries to use them.
                    status = OAuthStatus.AUTHENTICATED
                    token_expires_at = self._token_expiry(existing_tokens)

            log.info(f"🔐 {server_name} OAuth status: {status.value}")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=status,
                    scopes=None,  # We don't have metadata discovery, so no scopes info
                    token_expires_at=token_expires_at,
                    has_refresh_token=has_refresh_token,
                )
            )

        # asyncio.TimeoutError only became an alias of the builtin TimeoutError in
        # Python 3.11; naming both keeps the timeout branch working on older versions.
        except (asyncio.TimeoutError, TimeoutError):
            log.warning(f"⏰ OAuth check for {server_name} timed out")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=OAuthStatus.ERROR,
                    error_message=f"OAuth check timed out after {timeout_seconds}s",
                )
            )

        except Exception as e:
            log.error(f"❌ OAuth check for {server_name} failed: {e}")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=OAuthStatus.ERROR,
                    error_message=str(e),
                )
            )

    def get_oauth_auth(
        self,
        server_name: str,
        mcp_url: str,
        scopes: list[str] | None = None,
        client_name: str | None = None,
        *,
        callback_port: int = DEFAULT_CALLBACK_PORT,
    ) -> OAuth | None:
        """
        Get OAuth authentication object for FastMCP client.

        Relies on a previous ``check_oauth_requirement`` call for this server;
        without cached info no auth object is created.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint
            scopes: OAuth scopes to request (falls back to the cached scopes)
            client_name: Client name for OAuth registration (falls back to the cached name)
            callback_port: Local port for the OAuth redirect callback

        Returns:
            OAuth authentication object, or None if OAuth is not required, the
            server has not been checked, the last check failed, or construction failed
        """
        info = self._oauth_info.get(server_name)

        if not info or info.status == OAuthStatus.NOT_REQUIRED:
            return None

        if info.status == OAuthStatus.ERROR:
            log.warning(f"⚠️ Cannot create OAuth auth for {server_name}: {info.error_message}")
            return None

        try:
            oauth = OAuth(
                mcp_url=mcp_url,
                scopes=scopes or info.scopes,
                client_name=client_name or info.client_name,
                token_storage_cache_dir=self.cache_dir,
                callback_port=callback_port,
            )
        except Exception as e:
            log.error(f"❌ Failed to create OAuth auth for {server_name}: {e}")
            return None

        log.info(f"🔐 Created OAuth auth for {server_name}")
        return oauth

    def clear_tokens(self, server_name: str, mcp_url: str) -> bool:
        """
        Clear stored OAuth tokens for a server.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint

        Returns:
            True if tokens were cleared successfully, False if clearing raised
        """
        try:
            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            token_storage.clear()

            # Update our cached info so it no longer claims usable credentials
            cached = self._oauth_info.get(server_name)
            if cached is not None:
                cached.status = OAuthStatus.NEEDS_AUTH
                cached.token_expires_at = None
                cached.has_refresh_token = False

            log.info(f"🗑️ Cleared OAuth tokens for {server_name}")
            return True

        except Exception as e:
            log.error(f"❌ Failed to clear tokens for {server_name}: {e}")
            return False

    def get_server_info(self, server_name: str) -> OAuthServerInfo | None:
        """Get cached OAuth info for a server, or None if it has not been checked."""
        return self._oauth_info.get(server_name)

    async def refresh_server_status(self, server_name: str, mcp_url: str) -> OAuthServerInfo:
        """
        Refresh OAuth status for a server.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint

        Returns:
            Updated OAuthServerInfo
        """
        return await self.check_oauth_requirement(server_name, mcp_url)
|
270
|
+
|
271
|
+
|
272
|
+
# Module-wide OAuthManager singleton; stays None until the first get_oauth_manager() call.
_oauth_manager: OAuthManager | None = None


def get_oauth_manager() -> OAuthManager:
    """Return the shared OAuthManager, constructing it on first use."""
    global _oauth_manager
    manager = _oauth_manager
    if manager is None:
        # First request: build the manager with its default cache directory and keep it.
        manager = _oauth_manager = OAuthManager()
    return manager
|