open-edison 0.1.17__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,7 +12,7 @@ from contextvars import ContextVar
12
12
  from dataclasses import dataclass, field
13
13
  from datetime import datetime
14
14
  from pathlib import Path
15
- from typing import Any
15
+ from typing import Any, cast
16
16
 
17
17
  import mcp.types as mt
18
18
  from fastmcp.prompts.prompt import FunctionPrompt
@@ -27,8 +27,13 @@ from sqlalchemy import JSON, Column, Integer, String, create_engine, event
27
27
  from sqlalchemy.orm import Session, declarative_base
28
28
  from sqlalchemy.sql import select
29
29
 
30
+ from src import events
30
31
  from src.config import get_config_dir # type: ignore[reportMissingImports]
31
- from src.middleware.data_access_tracker import DataAccessTracker
32
+ from src.middleware.data_access_tracker import (
33
+ DataAccessTracker,
34
+ SecurityError,
35
+ )
36
+ from src.permissions import Permissions
32
37
  from src.telemetry import (
33
38
  record_prompt_used,
34
39
  record_resource_used,
@@ -67,7 +72,10 @@ class MCPSessionModel(Base): # type: ignore
67
72
  data_access_summary = Column(JSON) # type: ignore
68
73
 
69
74
 
70
- current_session_id_ctxvar: ContextVar[str | None] = ContextVar("current_session_id", default=None)
75
+ current_session_id_ctxvar: ContextVar[str | None] = ContextVar(
76
+ "current_session_id",
77
+ default=cast(str | None, None), # noqa: B039
78
+ )
71
79
 
72
80
 
73
81
  def get_current_session_data_tracker() -> DataAccessTracker | None:
@@ -101,7 +109,7 @@ def create_db_session() -> Generator[Session, None, None]:
101
109
  engine = create_engine(f"sqlite:///{db_path}")
102
110
 
103
111
  # Ensure changes are flushed to the main database file (avoid WAL for sql.js compatibility)
104
- @event.listens_for(engine, "connect")
112
+ @event.listens_for(engine, "connect") # noqa
105
113
  def _set_sqlite_pragmas(dbapi_connection, connection_record): # type: ignore[no-untyped-def] # noqa
106
114
  cur = dbapi_connection.cursor() # type: ignore[attr-defined]
107
115
  try:
@@ -171,7 +179,7 @@ def get_session_from_db(session_id: str) -> MCPSession:
171
179
  "has_external_communication", False
172
180
  )
173
181
  # Restore ACL highest level if present
174
- if isinstance(summary_data, dict) and "acl" in summary_data:
182
+ if isinstance(summary_data, dict) and "acl" in summary_data: # type: ignore
175
183
  acl_summary: Any = summary_data.get("acl") # type: ignore
176
184
  if isinstance(acl_summary, dict):
177
185
  highest = acl_summary.get("highest_acl_level") # type: ignore
@@ -214,7 +222,7 @@ class SessionTrackingMiddleware(Middleware):
214
222
  return session, session_id
215
223
 
216
224
  # General hooks for on_request, on_message, etc.
217
- async def on_request(
225
+ async def on_request( # noqa
218
226
  self,
219
227
  context: MiddlewareContext[mt.Request[Any, Any]], # type: ignore
220
228
  call_next: CallNext[mt.Request[Any, Any], Any], # type: ignore
@@ -228,7 +236,7 @@ class SessionTrackingMiddleware(Middleware):
228
236
  return await call_next(context) # type: ignore
229
237
 
230
238
  # Hooks for Tools
231
- async def on_list_tools(
239
+ async def on_list_tools( # noqa
232
240
  self,
233
241
  context: MiddlewareContext[Any], # type: ignore
234
242
  call_next: CallNext[Any, Any], # type: ignore
@@ -247,8 +255,11 @@ class SessionTrackingMiddleware(Middleware):
247
255
 
248
256
  # Filter out specific tools or return empty list
249
257
  allowed_tools: list[FunctionTool | ProxyTool | Any] = []
258
+ perms = Permissions()
250
259
  for tool in response:
251
- log.trace(f"🔍 Processing tool listing {tool.name}")
260
+ # Due to proxy & server naming
261
+ tool_name = tool.key
262
+ log.trace(f"🔍 Processing tool listing {tool_name}")
252
263
  if isinstance(tool, FunctionTool):
253
264
  log.trace("🔍 Tool is built-in")
254
265
  log.trace(f"🔍 Tool is a FunctionTool: {tool}")
@@ -260,20 +271,18 @@ class SessionTrackingMiddleware(Middleware):
260
271
  log.trace(f"🔍 Tool is a unknown type: {tool}")
261
272
  continue
262
273
 
263
- log.trace(f"🔍 Getting permissions for tool {tool.name}")
264
- permissions = session.data_access_tracker.get_tool_permissions(tool.name)
265
- log.trace(f"🔍 Tool permissions: {permissions}")
266
- if permissions["enabled"]:
274
+ log.trace(f"🔍 Getting permissions for tool {tool_name}")
275
+ if perms.is_tool_enabled(tool_name):
267
276
  allowed_tools.append(tool)
268
277
  else:
269
278
  log.warning(
270
- f"🔍 Tool {tool.name} is disabled on not configured and will not be allowed"
279
+ f"🔍 Tool {tool_name} is disabled or not configured and will not be allowed"
271
280
  )
272
281
  continue
273
282
 
274
283
  return allowed_tools # type: ignore
275
284
 
276
- async def on_call_tool(
285
+ async def on_call_tool( # noqa
277
286
  self,
278
287
  context: MiddlewareContext[mt.CallToolRequestParams], # type: ignore
279
288
  call_next: CallNext[mt.CallToolRequestParams, ToolResult], # type: ignore
@@ -296,7 +305,28 @@ class SessionTrackingMiddleware(Middleware):
296
305
 
297
306
  assert session.data_access_tracker is not None
298
307
  log.debug(f"🔍 Analyzing tool {context.message.name} for security implications")
299
- session.data_access_tracker.add_tool_call(context.message.name)
308
+ try:
309
+ session.data_access_tracker.add_tool_call(context.message.name)
310
+ except SecurityError as e:
311
+ # Publish pre-block event enriched with session_id then wait up to 30s for approval
312
+ events.fire_and_forget(
313
+ {
314
+ "type": "mcp_pre_block",
315
+ "kind": "tool",
316
+ "name": context.message.name,
317
+ "session_id": session_id,
318
+ "error": str(e),
319
+ }
320
+ )
321
+ approved = await events.wait_for_approval(
322
+ session_id, "tool", context.message.name, timeout_s=30.0
323
+ )
324
+ if not approved:
325
+ raise
326
+ # Approved: apply effects and proceed
327
+ session.data_access_tracker.apply_effects_after_manual_approval(
328
+ "tool", context.message.name
329
+ )
300
330
  # Telemetry: record tool call
301
331
  record_tool_call(context.message.name)
302
332
 
@@ -330,7 +360,7 @@ class SessionTrackingMiddleware(Middleware):
330
360
  return await call_next(context) # type: ignore
331
361
 
332
362
  # Hooks for Resources
333
- async def on_list_resources(
363
+ async def on_list_resources( # noqa
334
364
  self,
335
365
  context: MiddlewareContext[Any], # type: ignore
336
366
  call_next: CallNext[Any, Any], # type: ignore
@@ -350,6 +380,7 @@ class SessionTrackingMiddleware(Middleware):
350
380
 
351
381
  # Filter out specific tools or return empty list
352
382
  allowed_resources: list[FunctionResource | ProxyResource | Any] = []
383
+ perms = Permissions()
353
384
  for resource in response:
354
385
  resource_name = str(resource.uri)
355
386
  log.trace(f"🔍 Processing resource listing {resource_name}")
@@ -365,19 +396,17 @@ class SessionTrackingMiddleware(Middleware):
365
396
  continue
366
397
 
367
398
  log.trace(f"🔍 Getting permissions for resource {resource_name}")
368
- permissions = session.data_access_tracker.get_resource_permissions(resource_name)
369
- log.trace(f"🔍 Resource permissions: {permissions}")
370
- if permissions["enabled"]:
399
+ if perms.is_resource_enabled(resource_name):
371
400
  allowed_resources.append(resource)
372
401
  else:
373
402
  log.warning(
374
- f"🔍 Resource {resource_name} is disabled on not configured and will not be allowed"
403
+ f"🔍 Resource {resource_name} is disabled or not configured and will not be allowed"
375
404
  )
376
405
  continue
377
406
 
378
407
  return allowed_resources # type: ignore
379
408
 
380
- async def on_read_resource(
409
+ async def on_read_resource( # noqa
381
410
  self,
382
411
  context: MiddlewareContext[Any], # type: ignore
383
412
  call_next: CallNext[Any, Any], # type: ignore
@@ -396,7 +425,26 @@ class SessionTrackingMiddleware(Middleware):
396
425
  resource_name = str(context.message.uri)
397
426
 
398
427
  log.debug(f"🔍 Analyzing resource {resource_name} for security implications")
399
- _ = session.data_access_tracker.add_resource_access(resource_name)
428
+ try:
429
+ _ = session.data_access_tracker.add_resource_access(resource_name)
430
+ except SecurityError as e:
431
+ events.fire_and_forget(
432
+ {
433
+ "type": "mcp_pre_block",
434
+ "kind": "resource",
435
+ "name": resource_name,
436
+ "session_id": session_id,
437
+ "error": str(e),
438
+ }
439
+ )
440
+ approved = await events.wait_for_approval(
441
+ session_id, "resource", resource_name, timeout_s=30.0
442
+ )
443
+ if not approved:
444
+ raise
445
+ session.data_access_tracker.apply_effects_after_manual_approval(
446
+ "resource", resource_name
447
+ )
400
448
  record_resource_used(resource_name)
401
449
 
402
450
  # Update database session
@@ -412,7 +460,7 @@ class SessionTrackingMiddleware(Middleware):
412
460
  return await call_next(context)
413
461
 
414
462
  # Hooks for Prompts
415
- async def on_list_prompts(
463
+ async def on_list_prompts( # noqa
416
464
  self,
417
465
  context: MiddlewareContext[Any], # type: ignore
418
466
  call_next: CallNext[Any, Any], # type: ignore
@@ -432,6 +480,7 @@ class SessionTrackingMiddleware(Middleware):
432
480
 
433
481
  # Filter out specific tools or return empty list
434
482
  allowed_prompts: list[ProxyPrompt | Any] = []
483
+ perms = Permissions()
435
484
  for prompt in response:
436
485
  prompt_name = str(prompt.name)
437
486
  log.trace(f"🔍 Processing prompt listing {prompt_name}")
@@ -447,19 +496,17 @@ class SessionTrackingMiddleware(Middleware):
447
496
  continue
448
497
 
449
498
  log.trace(f"🔍 Getting permissions for prompt {prompt_name}")
450
- permissions = session.data_access_tracker.get_prompt_permissions(prompt_name)
451
- log.trace(f"🔍 Prompt permissions: {permissions}")
452
- if permissions["enabled"]:
499
+ if perms.is_prompt_enabled(prompt_name):
453
500
  allowed_prompts.append(prompt)
454
501
  else:
455
502
  log.warning(
456
- f"🔍 Prompt {prompt_name} is disabled on not configured and will not be allowed"
503
+ f"🔍 Prompt {prompt_name} is disabled or not configured and will not be allowed"
457
504
  )
458
505
  continue
459
506
 
460
507
  return allowed_prompts # type: ignore
461
508
 
462
- async def on_get_prompt(
509
+ async def on_get_prompt( # noqa
463
510
  self,
464
511
  context: MiddlewareContext[Any], # type: ignore
465
512
  call_next: CallNext[Any, Any], # type: ignore
@@ -477,7 +524,24 @@ class SessionTrackingMiddleware(Middleware):
477
524
  prompt_name = context.message.name
478
525
 
479
526
  log.debug(f"🔍 Analyzing prompt {prompt_name} for security implications")
480
- _ = session.data_access_tracker.add_prompt_access(prompt_name)
527
+ try:
528
+ _ = session.data_access_tracker.add_prompt_access(prompt_name)
529
+ except SecurityError as e:
530
+ events.fire_and_forget(
531
+ {
532
+ "type": "mcp_pre_block",
533
+ "kind": "prompt",
534
+ "name": prompt_name,
535
+ "session_id": session_id,
536
+ "error": str(e),
537
+ }
538
+ )
539
+ approved = await events.wait_for_approval(
540
+ session_id, "prompt", prompt_name, timeout_s=30.0
541
+ )
542
+ if not approved:
543
+ raise
544
+ session.data_access_tracker.apply_effects_after_manual_approval("prompt", prompt_name)
481
545
  record_prompt_used(prompt_name)
482
546
 
483
547
  # Update database session
src/oauth_manager.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ OAuth Manager for OpenEdison MCP Gateway
3
+
4
+ Handles OAuth 2.1 authentication for MCP servers using FastMCP's built-in OAuth support.
5
+ Provides detection, token management, and authentication flow coordination.
6
+ """
7
+
8
+ import asyncio
9
+ from dataclasses import dataclass
10
+ from enum import Enum
11
+ from pathlib import Path
12
+
13
+ from fastmcp.client.auth.oauth import (
14
+ FileTokenStorage,
15
+ OAuth,
16
+ check_if_auth_required,
17
+ default_cache_dir,
18
+ )
19
+ from loguru import logger as log
20
+
21
+
22
class OAuthStatus(Enum):
    """Authentication state the gateway has recorded for one MCP server.

    The string values are what callers see via ``.value`` (for example when
    the status is logged or reported), so they are kept stable.
    """

    # No detection has been performed / result is undetermined.
    UNKNOWN = "unknown"  # noqa
    # Local server, or remote server that answered without demanding auth.
    NOT_REQUIRED = "not_required"
    # Server demands OAuth and no usable access token is cached.
    NEEDS_AUTH = "needs_auth"
    # Server demands OAuth and a cached access token was found.
    AUTHENTICATED = "authenticated"
    # Detection failed (timeout or other exception).
    ERROR = "error"
    # Reserved for tokens known to be expired; not assigned in this module.
    EXPIRED = "expired"  # noqa
31
+
32
+
33
@dataclass
class OAuthServerInfo:
    """Snapshot of what the gateway knows about one MCP server's OAuth setup.

    Instances are mutable: the manager updates ``status`` and the token
    fields in place when cached tokens are cleared.
    """

    # Configured name of the MCP server (used as the cache key by the manager).
    server_name: str
    # Endpoint URL; empty string for local servers.
    mcp_url: str
    # Result of the most recent detection.
    status: OAuthStatus
    # Scopes to request; None means "let the OAuth helper decide".
    scopes: list[str] | None = None
    # Client name sent during dynamic client registration.
    client_name: str = "OpenEdison MCP Gateway"
    # Human-readable reason when status is ERROR.
    error_message: str | None = None
    # Best-effort, stringified expiry of the cached access token, if known.
    token_expires_at: str | None = None
    # Whether the cached token set included a refresh token.
    has_refresh_token: bool = False
45
+
46
+
47
class OAuthManager:
    """
    Manages OAuth authentication for MCP servers.

    This class provides a centralized interface for:
    - Detecting which MCP servers require OAuth
    - Managing OAuth tokens and credentials
    - Providing OAuth authentication objects for FastMCP clients
    - Handling token refresh and expiration

    Detection results are cached in memory per server name. ``get_oauth_auth``
    and ``get_server_info`` only read that cache, so ``check_oauth_requirement``
    must have been run for a server before an auth object can be produced.
    """

    def __init__(self, cache_dir: Path | None = None):
        """
        Initialize OAuth manager.

        Args:
            cache_dir: Directory for token cache. Defaults to FastMCP's
                ``default_cache_dir()`` when None.
        """
        self.cache_dir = cache_dir or default_cache_dir()
        self._oauth_info: dict[str, OAuthServerInfo] = {}

        log.info(f"🔐 OAuth Manager initialized with cache dir: {self.cache_dir}")

    def _remember(self, info: OAuthServerInfo) -> OAuthServerInfo:
        """Cache ``info`` under its server name and return it unchanged."""
        self._oauth_info[info.server_name] = info
        return info

    @staticmethod
    def _describe_token_expiry(tokens: object) -> str | None:
        """
        Best-effort, human-readable expiry for a cached token set.

        Prefers an ``expires_at`` attribute when present. Otherwise derives a
        naive local ISO timestamp from ``expires_in`` measured from *now*.
        The token's issuance time is not available here, so that value is an
        approximation that may overstate the token's remaining lifetime.

        Returns:
            The expiry as a string, or None when it cannot be determined.
        """
        try:
            expires_at = getattr(tokens, "expires_at", None)
            if expires_at:
                return str(expires_at)
            expires_in = getattr(tokens, "expires_in", None)
            if expires_in:
                from datetime import datetime, timedelta

                return (datetime.now() + timedelta(seconds=expires_in)).isoformat()
        except Exception as e:
            # Expiry is informational only; never fail detection over it.
            log.debug(f"Could not determine token expiry: {e}")
        return None

    async def check_oauth_requirement(
        self, server_name: str, mcp_url: str | None, timeout_seconds: float = 10.0
    ) -> OAuthServerInfo:
        """
        Check if an MCP server requires OAuth authentication.

        The result is also cached for later ``get_oauth_auth`` /
        ``get_server_info`` calls. This method never raises for detection
        failures; they are reported as ``OAuthStatus.ERROR`` with an
        ``error_message``.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint (None for local servers)
            timeout_seconds: Timeout for the auth-required probe request

        Returns:
            OAuthServerInfo with detection results
        """
        log.info(f"🔍 Checking OAuth requirement for {server_name}")

        # If no mcp_url provided, this is a local server - no OAuth needed
        if not mcp_url:
            log.info(f"✅ {server_name} is a local server - no OAuth required")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name, mcp_url="", status=OAuthStatus.NOT_REQUIRED
                )
            )

        log.info(f"🔍 Checking OAuth requirement for remote server {server_name} at {mcp_url}")

        try:
            # Only the network probe is bounded by the timeout.
            auth_required = await asyncio.wait_for(
                check_if_auth_required(mcp_url), timeout=timeout_seconds
            )

            if not auth_required:
                log.info(f"✅ {server_name} does not require OAuth")
                return self._remember(
                    OAuthServerInfo(
                        server_name=server_name,
                        mcp_url=mcp_url,
                        status=OAuthStatus.NOT_REQUIRED,
                    )
                )

            # OAuth is required, proceed with token checking
            log.info(f"🔐 {server_name} requires OAuth authentication")

            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            existing_tokens = await token_storage.get_tokens()

            status = OAuthStatus.NEEDS_AUTH
            token_expires_at: str | None = None
            has_refresh_token = False

            if existing_tokens:
                has_refresh_token = bool(existing_tokens.refresh_token)
                if existing_tokens.access_token:
                    # Validity is not verified here; the client finds out when
                    # it actually uses the token.
                    status = OAuthStatus.AUTHENTICATED
                    token_expires_at = self._describe_token_expiry(existing_tokens)

            info = OAuthServerInfo(
                server_name=server_name,
                mcp_url=mcp_url,
                status=status,
                scopes=None,  # We don't have metadata discovery, so no scopes info
                token_expires_at=token_expires_at,
                has_refresh_token=has_refresh_token,
            )

            log.info(f"🔐 {server_name} OAuth status: {status.value}")
            return self._remember(info)

        # asyncio.TimeoutError is only an alias of the builtin TimeoutError from
        # Python 3.11 on; catching both keeps 3.10 timeouts out of the generic
        # handler below (which would report an empty error message).
        except (TimeoutError, asyncio.TimeoutError):
            log.warning(f"⏰ OAuth check for {server_name} timed out")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=OAuthStatus.ERROR,
                    error_message=f"OAuth check timed out after {timeout_seconds}s",
                )
            )

        except Exception as e:
            log.error(f"❌ OAuth check for {server_name} failed: {e}")
            return self._remember(
                OAuthServerInfo(
                    server_name=server_name,
                    mcp_url=mcp_url,
                    status=OAuthStatus.ERROR,
                    error_message=str(e),
                )
            )

    def get_oauth_auth(
        self,
        server_name: str,
        mcp_url: str,
        scopes: list[str] | None = None,
        client_name: str | None = None,
        callback_port: int = 50001,
    ) -> OAuth | None:
        """
        Get OAuth authentication object for FastMCP client.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint
            scopes: OAuth scopes to request (falls back to the cached info's scopes)
            client_name: Client name for OAuth registration (falls back to the
                cached info's client name)
            callback_port: Local port for the OAuth redirect callback server

        Returns:
            OAuth authentication object, or None if the server has not been
            checked yet, does not require OAuth, is in the ERROR state, or the
            auth object could not be created.
        """
        info = self._oauth_info.get(server_name)

        if not info or info.status == OAuthStatus.NOT_REQUIRED:
            return None

        if info.status == OAuthStatus.ERROR:
            log.warning(f"⚠️ Cannot create OAuth auth for {server_name}: {info.error_message}")
            return None

        try:
            oauth = OAuth(
                mcp_url=mcp_url,
                scopes=scopes or info.scopes,
                client_name=client_name or info.client_name,
                token_storage_cache_dir=self.cache_dir,
                callback_port=callback_port,
            )
            log.info(f"🔐 Created OAuth auth for {server_name}")
            return oauth

        except Exception as e:
            log.error(f"❌ Failed to create OAuth auth for {server_name}: {e}")
            return None

    def clear_tokens(self, server_name: str, mcp_url: str) -> bool:
        """
        Clear stored OAuth tokens for a server.

        Also resets the cached token fields. A server previously detected as
        not requiring OAuth keeps that status; any other cached status becomes
        NEEDS_AUTH with its error message cleared.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint

        Returns:
            True if tokens were cleared successfully, False on any error
        """
        try:
            token_storage = FileTokenStorage(server_url=mcp_url, cache_dir=self.cache_dir)
            token_storage.clear()

            info = self._oauth_info.get(server_name)
            if info is not None:
                # Flipping a NOT_REQUIRED server to NEEDS_AUTH would make
                # get_oauth_auth() start an OAuth flow for a server without one.
                if info.status != OAuthStatus.NOT_REQUIRED:
                    info.status = OAuthStatus.NEEDS_AUTH
                    info.error_message = None
                info.token_expires_at = None
                info.has_refresh_token = False

            log.info(f"🗑️ Cleared OAuth tokens for {server_name}")
            return True

        except Exception as e:
            log.error(f"❌ Failed to clear tokens for {server_name}: {e}")
            return False

    def get_server_info(self, server_name: str) -> OAuthServerInfo | None:
        """Return the cached OAuth info for a server, or None if never checked."""
        return self._oauth_info.get(server_name)

    async def refresh_server_status(
        self, server_name: str, mcp_url: str, timeout_seconds: float = 10.0
    ) -> OAuthServerInfo:
        """
        Refresh OAuth status for a server by re-running detection.

        Args:
            server_name: Name of the MCP server
            mcp_url: URL of the MCP endpoint
            timeout_seconds: Timeout for the auth-required probe request

        Returns:
            Updated OAuthServerInfo
        """
        return await self.check_oauth_requirement(
            server_name, mcp_url, timeout_seconds=timeout_seconds
        )
270
+
271
+
272
# Process-wide OAuthManager, created lazily on first request.
_oauth_manager: OAuthManager | None = None


def get_oauth_manager() -> OAuthManager:
    """Return the shared OAuthManager, constructing it on first use.

    No locking is performed; the instance is built with default arguments.
    """
    global _oauth_manager
    if _oauth_manager is not None:
        return _oauth_manager
    _oauth_manager = OAuthManager()
    return _oauth_manager