open-edison 0.1.19__py3-none-any.whl → 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {open_edison-0.1.19.dist-info → open_edison-0.1.26.dist-info}/METADATA +60 -41
- open_edison-0.1.26.dist-info/RECORD +17 -0
- src/cli.py +2 -1
- src/config.py +63 -51
- src/events.py +153 -0
- src/middleware/data_access_tracker.py +164 -434
- src/middleware/session_tracking.py +93 -29
- src/oauth_manager.py +281 -0
- src/permissions.py +292 -0
- src/server.py +484 -132
- src/single_user_mcp.py +221 -159
- src/telemetry.py +4 -40
- open_edison-0.1.19.dist-info/RECORD +0 -14
- {open_edison-0.1.19.dist-info → open_edison-0.1.26.dist-info}/WHEEL +0 -0
- {open_edison-0.1.19.dist-info → open_edison-0.1.26.dist-info}/entry_points.txt +0 -0
- {open_edison-0.1.19.dist-info → open_edison-0.1.26.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@ from contextvars import ContextVar
|
|
12
12
|
from dataclasses import dataclass, field
|
13
13
|
from datetime import datetime
|
14
14
|
from pathlib import Path
|
15
|
-
from typing import Any
|
15
|
+
from typing import Any, cast
|
16
16
|
|
17
17
|
import mcp.types as mt
|
18
18
|
from fastmcp.prompts.prompt import FunctionPrompt
|
@@ -27,8 +27,13 @@ from sqlalchemy import JSON, Column, Integer, String, create_engine, event
|
|
27
27
|
from sqlalchemy.orm import Session, declarative_base
|
28
28
|
from sqlalchemy.sql import select
|
29
29
|
|
30
|
+
from src import events
|
30
31
|
from src.config import get_config_dir # type: ignore[reportMissingImports]
|
31
|
-
from src.middleware.data_access_tracker import
|
32
|
+
from src.middleware.data_access_tracker import (
|
33
|
+
DataAccessTracker,
|
34
|
+
SecurityError,
|
35
|
+
)
|
36
|
+
from src.permissions import Permissions
|
32
37
|
from src.telemetry import (
|
33
38
|
record_prompt_used,
|
34
39
|
record_resource_used,
|
@@ -67,7 +72,10 @@ class MCPSessionModel(Base): # type: ignore
|
|
67
72
|
data_access_summary = Column(JSON) # type: ignore
|
68
73
|
|
69
74
|
|
70
|
-
current_session_id_ctxvar: ContextVar[str | None] = ContextVar(
|
75
|
+
current_session_id_ctxvar: ContextVar[str | None] = ContextVar(
|
76
|
+
"current_session_id",
|
77
|
+
default=cast(str | None, None), # noqa: B039
|
78
|
+
)
|
71
79
|
|
72
80
|
|
73
81
|
def get_current_session_data_tracker() -> DataAccessTracker | None:
|
@@ -101,7 +109,7 @@ def create_db_session() -> Generator[Session, None, None]:
|
|
101
109
|
engine = create_engine(f"sqlite:///{db_path}")
|
102
110
|
|
103
111
|
# Ensure changes are flushed to the main database file (avoid WAL for sql.js compatibility)
|
104
|
-
@event.listens_for(engine, "connect")
|
112
|
+
@event.listens_for(engine, "connect") # noqa
|
105
113
|
def _set_sqlite_pragmas(dbapi_connection, connection_record): # type: ignore[no-untyped-def] # noqa
|
106
114
|
cur = dbapi_connection.cursor() # type: ignore[attr-defined]
|
107
115
|
try:
|
@@ -171,7 +179,7 @@ def get_session_from_db(session_id: str) -> MCPSession:
|
|
171
179
|
"has_external_communication", False
|
172
180
|
)
|
173
181
|
# Restore ACL highest level if present
|
174
|
-
if isinstance(summary_data, dict) and "acl" in summary_data:
|
182
|
+
if isinstance(summary_data, dict) and "acl" in summary_data: # type: ignore
|
175
183
|
acl_summary: Any = summary_data.get("acl") # type: ignore
|
176
184
|
if isinstance(acl_summary, dict):
|
177
185
|
highest = acl_summary.get("highest_acl_level") # type: ignore
|
@@ -214,7 +222,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
214
222
|
return session, session_id
|
215
223
|
|
216
224
|
# General hooks for on_request, on_message, etc.
|
217
|
-
async def on_request(
|
225
|
+
async def on_request( # noqa
|
218
226
|
self,
|
219
227
|
context: MiddlewareContext[mt.Request[Any, Any]], # type: ignore
|
220
228
|
call_next: CallNext[mt.Request[Any, Any], Any], # type: ignore
|
@@ -228,7 +236,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
228
236
|
return await call_next(context) # type: ignore
|
229
237
|
|
230
238
|
# Hooks for Tools
|
231
|
-
async def on_list_tools(
|
239
|
+
async def on_list_tools( # noqa
|
232
240
|
self,
|
233
241
|
context: MiddlewareContext[Any], # type: ignore
|
234
242
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -247,8 +255,11 @@ class SessionTrackingMiddleware(Middleware):
|
|
247
255
|
|
248
256
|
# Filter out specific tools or return empty list
|
249
257
|
allowed_tools: list[FunctionTool | ProxyTool | Any] = []
|
258
|
+
perms = Permissions()
|
250
259
|
for tool in response:
|
251
|
-
|
260
|
+
# Due to proxy & server naming
|
261
|
+
tool_name = tool.key
|
262
|
+
log.trace(f"🔍 Processing tool listing {tool_name}")
|
252
263
|
if isinstance(tool, FunctionTool):
|
253
264
|
log.trace("🔍 Tool is built-in")
|
254
265
|
log.trace(f"🔍 Tool is a FunctionTool: {tool}")
|
@@ -260,20 +271,18 @@ class SessionTrackingMiddleware(Middleware):
|
|
260
271
|
log.trace(f"🔍 Tool is a unknown type: {tool}")
|
261
272
|
continue
|
262
273
|
|
263
|
-
log.trace(f"🔍 Getting permissions for tool {
|
264
|
-
|
265
|
-
log.trace(f"🔍 Tool permissions: {permissions}")
|
266
|
-
if permissions["enabled"]:
|
274
|
+
log.trace(f"🔍 Getting permissions for tool {tool_name}")
|
275
|
+
if perms.is_tool_enabled(tool_name):
|
267
276
|
allowed_tools.append(tool)
|
268
277
|
else:
|
269
278
|
log.warning(
|
270
|
-
f"🔍 Tool {
|
279
|
+
f"🔍 Tool {tool_name} is disabled or not configured and will not be allowed"
|
271
280
|
)
|
272
281
|
continue
|
273
282
|
|
274
283
|
return allowed_tools # type: ignore
|
275
284
|
|
276
|
-
async def on_call_tool(
|
285
|
+
async def on_call_tool( # noqa
|
277
286
|
self,
|
278
287
|
context: MiddlewareContext[mt.CallToolRequestParams], # type: ignore
|
279
288
|
call_next: CallNext[mt.CallToolRequestParams, ToolResult], # type: ignore
|
@@ -296,7 +305,28 @@ class SessionTrackingMiddleware(Middleware):
|
|
296
305
|
|
297
306
|
assert session.data_access_tracker is not None
|
298
307
|
log.debug(f"🔍 Analyzing tool {context.message.name} for security implications")
|
299
|
-
|
308
|
+
try:
|
309
|
+
session.data_access_tracker.add_tool_call(context.message.name)
|
310
|
+
except SecurityError as e:
|
311
|
+
# Publish pre-block event enriched with session_id then wait up to 30s for approval
|
312
|
+
events.fire_and_forget(
|
313
|
+
{
|
314
|
+
"type": "mcp_pre_block",
|
315
|
+
"kind": "tool",
|
316
|
+
"name": context.message.name,
|
317
|
+
"session_id": session_id,
|
318
|
+
"error": str(e),
|
319
|
+
}
|
320
|
+
)
|
321
|
+
approved = await events.wait_for_approval(
|
322
|
+
session_id, "tool", context.message.name, timeout_s=30.0
|
323
|
+
)
|
324
|
+
if not approved:
|
325
|
+
raise
|
326
|
+
# Approved: apply effects and proceed
|
327
|
+
session.data_access_tracker.apply_effects_after_manual_approval(
|
328
|
+
"tool", context.message.name
|
329
|
+
)
|
300
330
|
# Telemetry: record tool call
|
301
331
|
record_tool_call(context.message.name)
|
302
332
|
|
@@ -330,7 +360,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
330
360
|
return await call_next(context) # type: ignore
|
331
361
|
|
332
362
|
# Hooks for Resources
|
333
|
-
async def on_list_resources(
|
363
|
+
async def on_list_resources( # noqa
|
334
364
|
self,
|
335
365
|
context: MiddlewareContext[Any], # type: ignore
|
336
366
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -350,6 +380,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
350
380
|
|
351
381
|
# Filter out specific tools or return empty list
|
352
382
|
allowed_resources: list[FunctionResource | ProxyResource | Any] = []
|
383
|
+
perms = Permissions()
|
353
384
|
for resource in response:
|
354
385
|
resource_name = str(resource.uri)
|
355
386
|
log.trace(f"🔍 Processing resource listing {resource_name}")
|
@@ -365,19 +396,17 @@ class SessionTrackingMiddleware(Middleware):
|
|
365
396
|
continue
|
366
397
|
|
367
398
|
log.trace(f"🔍 Getting permissions for resource {resource_name}")
|
368
|
-
|
369
|
-
log.trace(f"🔍 Resource permissions: {permissions}")
|
370
|
-
if permissions["enabled"]:
|
399
|
+
if perms.is_resource_enabled(resource_name):
|
371
400
|
allowed_resources.append(resource)
|
372
401
|
else:
|
373
402
|
log.warning(
|
374
|
-
f"🔍 Resource {resource_name} is disabled
|
403
|
+
f"🔍 Resource {resource_name} is disabled or not configured and will not be allowed"
|
375
404
|
)
|
376
405
|
continue
|
377
406
|
|
378
407
|
return allowed_resources # type: ignore
|
379
408
|
|
380
|
-
async def on_read_resource(
|
409
|
+
async def on_read_resource( # noqa
|
381
410
|
self,
|
382
411
|
context: MiddlewareContext[Any], # type: ignore
|
383
412
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -396,7 +425,26 @@ class SessionTrackingMiddleware(Middleware):
|
|
396
425
|
resource_name = str(context.message.uri)
|
397
426
|
|
398
427
|
log.debug(f"🔍 Analyzing resource {resource_name} for security implications")
|
399
|
-
|
428
|
+
try:
|
429
|
+
_ = session.data_access_tracker.add_resource_access(resource_name)
|
430
|
+
except SecurityError as e:
|
431
|
+
events.fire_and_forget(
|
432
|
+
{
|
433
|
+
"type": "mcp_pre_block",
|
434
|
+
"kind": "resource",
|
435
|
+
"name": resource_name,
|
436
|
+
"session_id": session_id,
|
437
|
+
"error": str(e),
|
438
|
+
}
|
439
|
+
)
|
440
|
+
approved = await events.wait_for_approval(
|
441
|
+
session_id, "resource", resource_name, timeout_s=30.0
|
442
|
+
)
|
443
|
+
if not approved:
|
444
|
+
raise
|
445
|
+
session.data_access_tracker.apply_effects_after_manual_approval(
|
446
|
+
"resource", resource_name
|
447
|
+
)
|
400
448
|
record_resource_used(resource_name)
|
401
449
|
|
402
450
|
# Update database session
|
@@ -412,7 +460,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
412
460
|
return await call_next(context)
|
413
461
|
|
414
462
|
# Hooks for Prompts
|
415
|
-
async def on_list_prompts(
|
463
|
+
async def on_list_prompts( # noqa
|
416
464
|
self,
|
417
465
|
context: MiddlewareContext[Any], # type: ignore
|
418
466
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -432,6 +480,7 @@ class SessionTrackingMiddleware(Middleware):
|
|
432
480
|
|
433
481
|
# Filter out specific tools or return empty list
|
434
482
|
allowed_prompts: list[ProxyPrompt | Any] = []
|
483
|
+
perms = Permissions()
|
435
484
|
for prompt in response:
|
436
485
|
prompt_name = str(prompt.name)
|
437
486
|
log.trace(f"🔍 Processing prompt listing {prompt_name}")
|
@@ -447,19 +496,17 @@ class SessionTrackingMiddleware(Middleware):
|
|
447
496
|
continue
|
448
497
|
|
449
498
|
log.trace(f"🔍 Getting permissions for prompt {prompt_name}")
|
450
|
-
|
451
|
-
log.trace(f"🔍 Prompt permissions: {permissions}")
|
452
|
-
if permissions["enabled"]:
|
499
|
+
if perms.is_prompt_enabled(prompt_name):
|
453
500
|
allowed_prompts.append(prompt)
|
454
501
|
else:
|
455
502
|
log.warning(
|
456
|
-
f"🔍 Prompt {prompt_name} is disabled
|
503
|
+
f"🔍 Prompt {prompt_name} is disabled or not configured and will not be allowed"
|
457
504
|
)
|
458
505
|
continue
|
459
506
|
|
460
507
|
return allowed_prompts # type: ignore
|
461
508
|
|
462
|
-
async def on_get_prompt(
|
509
|
+
async def on_get_prompt( # noqa
|
463
510
|
self,
|
464
511
|
context: MiddlewareContext[Any], # type: ignore
|
465
512
|
call_next: CallNext[Any, Any], # type: ignore
|
@@ -477,7 +524,24 @@ class SessionTrackingMiddleware(Middleware):
|
|
477
524
|
prompt_name = context.message.name
|
478
525
|
|
479
526
|
log.debug(f"🔍 Analyzing prompt {prompt_name} for security implications")
|
480
|
-
|
527
|
+
try:
|
528
|
+
_ = session.data_access_tracker.add_prompt_access(prompt_name)
|
529
|
+
except SecurityError as e:
|
530
|
+
events.fire_and_forget(
|
531
|
+
{
|
532
|
+
"type": "mcp_pre_block",
|
533
|
+
"kind": "prompt",
|
534
|
+
"name": prompt_name,
|
535
|
+
"session_id": session_id,
|
536
|
+
"error": str(e),
|
537
|
+
}
|
538
|
+
)
|
539
|
+
approved = await events.wait_for_approval(
|
540
|
+
session_id, "prompt", prompt_name, timeout_s=30.0
|
541
|
+
)
|
542
|
+
if not approved:
|
543
|
+
raise
|
544
|
+
session.data_access_tracker.apply_effects_after_manual_approval("prompt", prompt_name)
|
481
545
|
record_prompt_used(prompt_name)
|
482
546
|
|
483
547
|
# Update database session
|
src/oauth_manager.py
ADDED
@@ -0,0 +1,281 @@
|
|
1
|
+
"""
|
2
|
+
OAuth Manager for OpenEdison MCP Gateway
|
3
|
+
|
4
|
+
Handles OAuth 2.1 authentication for MCP servers using FastMCP's built-in OAuth support.
|
5
|
+
Provides detection, token management, and authentication flow coordination.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import asyncio
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from enum import Enum
|
11
|
+
from pathlib import Path
|
12
|
+
|
13
|
+
from fastmcp.client.auth.oauth import (
|
14
|
+
FileTokenStorage,
|
15
|
+
OAuth,
|
16
|
+
check_if_auth_required,
|
17
|
+
default_cache_dir,
|
18
|
+
)
|
19
|
+
from loguru import logger as log
|
20
|
+
|
21
|
+
|
22
|
+
class OAuthStatus(Enum):
    """Where an MCP server stands with respect to OAuth authentication.

    Each member's value is the lowercase string reported for the status
    (for example via ``status.value`` in log output).
    """

    # Not yet determined.
    UNKNOWN = "unknown"  # noqa
    # Server is local or answered without demanding authentication.
    NOT_REQUIRED = "not_required"
    # Server demands OAuth and no usable access token is cached.
    NEEDS_AUTH = "needs_auth"
    # Server demands OAuth and a cached access token was found.
    AUTHENTICATED = "authenticated"
    # The OAuth requirement check failed or timed out.
    ERROR = "error"
    # Cached credentials are known to have expired.
    EXPIRED = "expired"  # noqa
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class OAuthServerInfo:
    """Snapshot of what is known about one MCP server's OAuth situation."""

    # Configured name of the MCP server this record describes.
    server_name: str
    # MCP endpoint URL; empty string for local servers.
    mcp_url: str
    # Outcome of the most recent OAuth requirement check.
    status: OAuthStatus
    # OAuth scopes to request; None when not known.
    scopes: list[str] | None = None
    # Client name presented during OAuth client registration.
    client_name: str = "OpenEdison MCP Gateway"
    # Human-readable failure description, populated on ERROR outcomes.
    error_message: str | None = None
    # String form of the cached token's expiry, when it could be determined.
    token_expires_at: str | None = None
    # Whether the cached token set includes a refresh token.
    has_refresh_token: bool = False
|
45
|
+
|
46
|
+
|
47
|
+
class OAuthManager:
    """
    Manages OAuth authentication for MCP servers.

    This class provides a centralized interface for:
    - Detecting which MCP servers require OAuth
    - Managing OAuth tokens and credentials
    - Providing OAuth authentication objects for FastMCP clients
    - Handling token refresh and expiration

    Detection results are cached per server name and consulted by
    ``get_oauth_auth`` / ``get_server_info``.
    """

    def __init__(self, cache_dir: Path | None = None, callback_port: int = 50001):
        """
        Initialize OAuth manager.

        Args:
            cache_dir: Directory for token cache. Defaults to FastMCP's default.
            callback_port: Local port passed to FastMCP's ``OAuth`` helper as
                ``callback_port``. Defaults to 50001, the value previously
                hard-coded.
        """
        self.cache_dir = cache_dir or default_cache_dir()
        self.callback_port = callback_port
        self._oauth_info: dict[str, OAuthServerInfo] = {}

        log.info(f"🔐 OAuth Manager initialized with cache dir: {self.cache_dir}")

    def _remember(self, info: OAuthServerInfo) -> OAuthServerInfo:
        """Cache ``info`` under its server name and return it unchanged."""
        self._oauth_info[info.server_name] = info
        return info

    @staticmethod
    def _describe_token_expiry(tokens: object) -> str | None:
        """
        Best-effort ISO 8601 expiry string for a stored token object.

        Prefers an ``expires_at`` attribute when the object has one; otherwise
        derives a UTC timestamp from ``expires_in``. Returns None when neither
        attribute is present/truthy.
        """
        from datetime import datetime, timedelta, timezone

        expires_at = getattr(tokens, "expires_at", None)
        if expires_at:
            # Render datetimes as ISO 8601 so both branches share one format.
            if isinstance(expires_at, datetime):
                return expires_at.isoformat()
            return str(expires_at)

        expires_in = getattr(tokens, "expires_in", None)
        if expires_in:
            # NOTE(review): assumes expires_in counts seconds remaining as of
            # now (i.e. as reported by the token storage at load time).
            expiry = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
            return expiry.isoformat()
        return None

    async def check_oauth_requirement(
        self, server_name: str, mcp_url: str | None, timeout_seconds: float = 10.0
    ) -> OAuthServerInfo:
        """
        Check if an MCP server requires OAuth authentication.

        The result is cached (see ``get_server_info``) and returned. Failures
        and timeouts are not raised; they are reported as ``OAuthStatus.ERROR``
        with ``error_message`` set.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint (None for local servers)
            timeout_seconds: Timeout for the auth-required check request

        Returns:
            OAuthServerInfo with detection results
        """
        log.info(f"🔍 Checking OAuth requirement for {server_name}")

        # If no mcp_url provided, this is a local server - no OAuth needed
        if not mcp_url:
            log.info(f"✅ {server_name} is a local server - no OAuth required")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name, mcp_url="", status=OAuthStatus.NOT_REQUIRED
                )
            )

        log.info(f"🔍 Checking OAuth requirement for remote server {server_name} at {mcp_url}")

        try:
            auth_required = await asyncio.wait_for(
                check_if_auth_required(mcp_url), timeout=timeout_seconds
            )

            if not auth_required:
                log.info(f"✅ {server_name} does not require OAuth")
                return self._remember(
                    OAuthServerInfo(
                        server_name=server_name,
                        mcp_url=mcp_url,
                        status=OAuthStatus.NOT_REQUIRED,
                    )
                )

            log.info(f"🔐 {server_name} requires OAuth authentication")

            # Look for previously stored tokens for this endpoint.
            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            existing_tokens = await token_storage.get_tokens()

            status = OAuthStatus.NEEDS_AUTH
            token_expires_at: str | None = None
            has_refresh_token = False

            if existing_tokens:
                has_refresh_token = bool(existing_tokens.refresh_token)
                if existing_tokens.access_token:
                    # Tokens are assumed valid here; real validation happens
                    # when a client actually uses them.
                    status = OAuthStatus.AUTHENTICATED
                    try:
                        token_expires_at = self._describe_token_expiry(existing_tokens)
                    except Exception as e:
                        # Expiry is informational only; don't fail the check.
                        log.debug(f"Could not determine token expiry for {server_name}: {e}")

            log.info(f"🔐 {server_name} OAuth status: {status.value}")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=status,
                    scopes=None,  # We don't have metadata discovery, so no scopes info
                    token_expires_at=token_expires_at,
                    has_refresh_token=has_refresh_token,
                )
            )

        except TimeoutError:
            # asyncio.wait_for raises the builtin TimeoutError on Python 3.11+.
            log.warning(f"⏰ OAuth check for {server_name} timed out")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=OAuthStatus.ERROR,
                    error_message=f"OAuth check timed out after {timeout_seconds}s",
                )
            )

        except Exception as e:
            log.error(f"❌ OAuth check for {server_name} failed: {e}")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=OAuthStatus.ERROR,
                    error_message=str(e),
                )
            )

    def get_oauth_auth(
        self,
        server_name: str,
        mcp_url: str,
        scopes: list[str] | None = None,
        client_name: str | None = None,
    ) -> OAuth | None:
        """
        Get OAuth authentication object for FastMCP client.

        Relies on a prior ``check_oauth_requirement`` call: servers that were
        never checked, don't need OAuth, or whose check errored get None.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint
            scopes: OAuth scopes to request (falls back to cached scopes)
            client_name: Client name for OAuth registration (falls back to cached name)

        Returns:
            OAuth authentication object, or None if OAuth is not required or
            the object could not be created
        """
        info = self._oauth_info.get(server_name)

        if not info or info.status == OAuthStatus.NOT_REQUIRED:
            return None

        if info.status == OAuthStatus.ERROR:
            log.warning(f"⚠️ Cannot create OAuth auth for {server_name}: {info.error_message}")
            return None

        try:
            oauth = OAuth(
                mcp_url=mcp_url,
                scopes=scopes or info.scopes,
                client_name=client_name or info.client_name,
                token_storage_cache_dir=self.cache_dir,
                callback_port=self.callback_port,
            )
            log.info(f"🔐 Created OAuth auth for {server_name}")
            return oauth

        except Exception as e:
            log.error(f"❌ Failed to create OAuth auth for {server_name}: {e}")
            return None

    def clear_tokens(self, server_name: str, mcp_url: str) -> bool:
        """
        Clear stored OAuth tokens for a server.

        Servers cached as NOT_REQUIRED keep that status: clearing tokens must
        not make ``get_oauth_auth`` start building OAuth objects for them.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint

        Returns:
            True if tokens were cleared successfully, False on any error
        """
        try:
            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            token_storage.clear()

            # Update our cached info
            info = self._oauth_info.get(server_name)
            if info is not None:
                if info.status != OAuthStatus.NOT_REQUIRED:
                    info.status = OAuthStatus.NEEDS_AUTH
                    # Any earlier failure message no longer describes the state.
                    info.error_message = None
                info.token_expires_at = None
                info.has_refresh_token = False

            log.info(f"🗑️ Cleared OAuth tokens for {server_name}")
            return True

        except Exception as e:
            log.error(f"❌ Failed to clear tokens for {server_name}: {e}")
            return False

    def get_server_info(self, server_name: str) -> OAuthServerInfo | None:
        """Get cached OAuth info for a server, or None if never checked."""
        return self._oauth_info.get(server_name)

    async def refresh_server_status(self, server_name: str, mcp_url: str) -> OAuthServerInfo:
        """
        Refresh OAuth status for a server by re-running the requirement check.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint

        Returns:
            Updated OAuthServerInfo
        """
        return await self.check_oauth_requirement(server_name, mcp_url)
|
270
|
+
|
271
|
+
|
272
|
+
# Process-wide OAuthManager singleton; populated lazily by get_oauth_manager().
_oauth_manager: OAuthManager | None
_oauth_manager = None
|
274
|
+
|
275
|
+
|
276
|
+
def get_oauth_manager() -> OAuthManager:
    """Return the shared OAuthManager, constructing it on first use."""
    global _oauth_manager
    manager = _oauth_manager
    if manager is None:
        manager = _oauth_manager = OAuthManager()
    return manager
|