open-edison 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,477 @@
1
+ """
2
+ Session Tracking Middleware for Open Edison
3
+
4
+ This middleware tracks tool usage patterns across all mounted tool calls,
5
+ providing session-level statistics accessible via contextvar.
6
+ """
7
+
8
+ import uuid
9
+ from collections.abc import Generator
10
+ from contextlib import contextmanager
11
+ from contextvars import ContextVar
12
+ from dataclasses import dataclass, field
13
+ from datetime import datetime
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ import mcp.types as mt
18
+ from fastmcp.prompts.prompt import FunctionPrompt
19
+ from fastmcp.resources import FunctionResource
20
+ from fastmcp.server.middleware import Middleware
21
+ from fastmcp.server.middleware.middleware import CallNext, MiddlewareContext
22
+ from fastmcp.server.proxy import ProxyPrompt, ProxyResource, ProxyTool
23
+ from fastmcp.tools import FunctionTool
24
+ from fastmcp.tools.tool import ToolResult
25
+ from loguru import logger as log
26
+ from sqlalchemy import JSON, Column, Integer, String, create_engine, event
27
+ from sqlalchemy.orm import Session, declarative_base
28
+ from sqlalchemy.sql import select
29
+
30
+ from src.config import get_config_dir # type: ignore[reportMissingImports]
31
+ from src.middleware.data_access_tracker import DataAccessTracker
32
+
33
+
34
+ @dataclass
35
+ class ToolCall:
36
+ id: str
37
+ tool_name: str
38
+ parameters: dict[str, Any]
39
+ timestamp: datetime
40
+ duration_ms: float | None = None
41
+ status: str = "pending"
42
+ result: Any | None = None
43
+
44
+
45
@dataclass
class MCPSession:
    """In-memory view of one MCP session and the activity recorded for it."""

    # Identifier of the session.
    session_id: str
    # Correlation identifier stored alongside the session.
    correlation_id: str
    # Tool calls recorded for the session; every instance gets its own list.
    tool_calls: list[ToolCall] = field(default_factory=list)
    # Tracker holding the session's data-access state, when one is attached.
    data_access_tracker: DataAccessTracker | None = None
51
+
52
+
53
# Declarative base class that the ORM models in this module derive from.
Base = declarative_base()
54
+
55
+
56
class MCPSessionModel(Base):  # type: ignore
    """ORM mapping of one row in the ``mcp_sessions`` table."""

    __tablename__: str = "mcp_sessions"

    # Surrogate integer primary key.
    id = Column(Integer, primary_key=True)  # type: ignore
    # Session identifier; the column carries a unique constraint.
    session_id = Column(String, unique=True)  # type: ignore
    # Correlation identifier stored for the session.
    correlation_id = Column(String)  # type: ignore
    # JSON column holding the serialized tool calls.
    tool_calls = Column(JSON)  # type: ignore
    # JSON column holding the serialized data-access summary.
    data_access_summary = Column(JSON)  # type: ignore
63
+
64
+
65
+ current_session_id_ctxvar: ContextVar[str | None] = ContextVar("current_session_id", default=None)
66
+
67
+
68
def get_current_session_data_tracker() -> DataAccessTracker | None:
    """
    Look up the data access tracker that belongs to the current session.

    The session id is read from ``current_session_id_ctxvar`` and the session
    is loaded through ``get_session_from_db``. Any failure while loading is
    logged and reported as "no tracker" rather than raised.

    Returns:
        The session's DataAccessTracker, or None when no session id is set or
        the session could not be loaded.
    """
    active_session_id = current_session_id_ctxvar.get()
    if active_session_id is not None:
        try:
            return get_session_from_db(active_session_id).data_access_tracker
        except Exception as e:
            log.error(f"Failed to get current session data tracker: {e}")
    return None
85
+
86
+
87
@contextmanager
def create_db_session() -> Generator[Session, None, None]:
    """Yield a SQLAlchemy session bound to the local ``sessions.db`` SQLite file.

    The database file lives in the directory returned by ``get_config_dir()``;
    if that lookup raises, the current working directory is used instead.
    Missing parent directories and tables are created before the session is
    handed out.

    Yields:
        An open ``Session``. When the ``with`` block exits (normally or through
        an exception) the session is closed and the engine created for it is
        disposed, so pooled SQLite connections are released right away.
    """
    try:
        cfg_dir = get_config_dir()
    except Exception as e:
        # Deliberate best effort: fall back instead of failing, but leave a trace.
        log.debug(f"Could not resolve config dir, using current directory for sessions.db: {e}")
        cfg_dir = Path.cwd()
    db_path = cfg_dir / "sessions.db"
    db_path.parent.mkdir(parents=True, exist_ok=True)
    engine = create_engine(f"sqlite:///{db_path}")

    try:
        # Ensure changes are flushed to the main database file (avoid WAL for sql.js compatibility)
        @event.listens_for(engine, "connect")
        def _set_sqlite_pragmas(dbapi_connection, connection_record):  # type: ignore[no-untyped-def]
            cur = dbapi_connection.cursor()  # type: ignore[attr-defined]
            try:
                cur.execute("PRAGMA journal_mode=DELETE")  # type: ignore[attr-defined]
                cur.execute("PRAGMA synchronous=FULL")  # type: ignore[attr-defined]
            finally:
                cur.close()  # type: ignore[attr-defined]

        # Ensure tables exist
        Base.metadata.create_all(engine)  # type: ignore
        with Session(engine) as session:
            yield session
    finally:
        # A fresh engine is built on every call; without dispose() its
        # connection pool would linger until garbage collection.
        engine.dispose()
115
+
116
+
117
def _tool_call_from_record(record: dict[str, Any]) -> ToolCall:
    """Rebuild a ``ToolCall`` from the dict form stored in the JSON column."""
    fields = dict(record)  # copy so the dict owned by the ORM row is not mutated
    if "timestamp" in fields:
        # Timestamps are persisted as ISO-8601 strings.
        fields["timestamp"] = datetime.fromisoformat(fields["timestamp"])
    return ToolCall(**fields)


def _tracker_from_summary(summary: dict[str, Any] | None) -> DataAccessTracker:
    """Create a tracker and restore the ``lethal_trifecta`` flags found in *summary*."""
    tracker = DataAccessTracker()
    if summary and "lethal_trifecta" in summary:
        # Tolerate a key that is present but stored as null.
        trifecta = summary["lethal_trifecta"] or {}
        tracker.has_private_data_access = trifecta.get("has_private_data_access", False)
        tracker.has_untrusted_content_exposure = trifecta.get(
            "has_untrusted_content_exposure", False
        )
        tracker.has_external_communication = trifecta.get("has_external_communication", False)
    return tracker


def get_session_from_db(session_id: str) -> MCPSession:
    """Load the session identified by *session_id*, creating its row if missing.

    Args:
        session_id: Identifier of the session to look up.

    Returns:
        An ``MCPSession`` populated from the stored row. When no row exists, a
        new one with a fresh correlation id, no tool calls and an empty summary
        is inserted and committed, and a matching empty session is returned.
        The returned session always carries a ``DataAccessTracker``.
    """
    with create_db_session() as db_session:
        row = db_session.execute(
            select(MCPSessionModel).where(MCPSessionModel.session_id == session_id)
        ).scalar_one_or_none()

        if row is None:
            # Generate the id up front so the result does not have to read an
            # attribute of the ORM object after commit() has expired it.
            correlation_id = str(uuid.uuid4())
            db_session.add(
                MCPSessionModel(
                    session_id=session_id,
                    correlation_id=correlation_id,
                    tool_calls=[],  # type: ignore
                    data_access_summary={},  # type: ignore
                )
            )
            db_session.commit()
            return MCPSession(
                session_id=session_id,
                correlation_id=correlation_id,
                tool_calls=[],
                data_access_tracker=DataAccessTracker(),
            )

        stored_calls = row.tool_calls or []  # type: ignore
        return MCPSession(
            session_id=session_id,
            correlation_id=str(row.correlation_id),
            tool_calls=[_tool_call_from_record(record) for record in stored_calls],  # type: ignore
            data_access_tracker=_tracker_from_summary(row.data_access_summary),  # type: ignore
        )
175
+
176
+
177
class SessionTrackingMiddleware(Middleware):
    """
    Middleware that tracks tool, resource and prompt usage per session.

    ``on_request`` records the session id in ``current_session_id_ctxvar``.
    The ``on_list_*`` hooks drop entries the session's data access tracker
    does not report as enabled; ``on_call_tool``, ``on_read_resource`` and
    ``on_get_prompt`` register the access with the tracker and persist the
    updated state before forwarding the request.
    """

    def _get_or_create_session_stats(
        self,
        context: MiddlewareContext[mt.Request[Any, Any]],  # type: ignore
    ) -> tuple[MCPSession, str]:
        """Load (or create) the session for this request and remember its id.

        Returns:
            ``(session, session_id)``.

        Raises:
            ValueError: If the middleware context carries no FastMCP context.
        """
        fastmcp_context = context.fastmcp_context
        if fastmcp_context is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise ValueError("No FastMCP context available for session tracking")
        session_id = fastmcp_context.session_id

        # For debugging, let's log what we got
        log.debug(f"FastMCP context session_id: {session_id}")

        session = get_session_from_db(session_id)
        _ = current_session_id_ctxvar.set(session_id)
        return session, session_id

    @staticmethod
    def _require_session_id() -> str:
        """Return the session id stored in the context variable or raise ValueError."""
        session_id = current_session_id_ctxvar.get()
        if session_id is None:
            raise ValueError("No session ID found in context")
        return session_id

    @staticmethod
    def _load_tracker(session_id: str) -> tuple[MCPSession, DataAccessTracker]:
        """Load the session and return it together with its data access tracker."""
        session = get_session_from_db(session_id)
        tracker = session.data_access_tracker
        # Internal invariant (sessions loaded from the db carry a tracker);
        # the assert also narrows the Optional type for type checkers.
        assert tracker is not None
        return session, tracker

    @staticmethod
    def _persist_session(
        session_id: str,
        tracker: DataAccessTracker,
        tool_calls: list[ToolCall] | None = None,
    ) -> None:
        """Write the tracker summary, and optionally the tool calls, to the session row.

        Args:
            session_id: Identifier of the row to update; the row must exist.
            tracker: Tracker whose ``to_dict()`` output becomes the stored summary.
            tool_calls: When given, replaces the stored tool call list.
        """
        with create_db_session() as db_session:
            row = db_session.execute(
                select(MCPSessionModel).where(MCPSessionModel.session_id == session_id)
            ).scalar_one()

            if tool_calls is not None:
                # Convert tool calls to dict format for JSON storage
                row.tool_calls = [  # type: ignore
                    {
                        "id": tc.id,
                        "tool_name": tc.tool_name,
                        "parameters": tc.parameters,
                        "timestamp": tc.timestamp.isoformat(),
                        "duration_ms": tc.duration_ms,
                        "status": tc.status,
                        "result": tc.result,
                    }
                    for tc in tool_calls
                ]
            row.data_access_summary = tracker.to_dict()  # type: ignore
            db_session.commit()

    @staticmethod
    def _filter_enabled(
        items: Any,
        *,
        label: str,
        builtin_type: type,
        proxy_type: type,
        name_of: Any,
        permissions_for: Any,
    ) -> list[Any]:
        """Keep the listed items whose permissions report ``enabled``.

        Args:
            items: Iterable of listed objects returned by the next handler.
            label: Capitalized kind used in log messages ("Tool", "Resource", "Prompt").
            builtin_type: Class of built-in (function based) items.
            proxy_type: Class of user-mounted (proxied) items.
            name_of: Callable mapping an item to the name used for the permission lookup.
            permissions_for: Callable mapping that name to a permissions mapping
                with an ``"enabled"`` entry.

        Returns:
            The items that are of a known type and enabled, in their original order.
        """
        kind = label.lower()
        allowed: list[Any] = []
        for item in items:
            name = name_of(item)
            log.trace(f"🔍 Processing {kind} listing {name}")
            if isinstance(item, builtin_type):
                log.trace(f"🔍 {label} is built-in")
                log.trace(f"🔍 {label} is a {builtin_type.__name__}: {item}")
            elif isinstance(item, proxy_type):
                log.trace(f"🔍 {label} is a user-mounted tool")
                log.trace(f"🔍 {label} is a {proxy_type.__name__}: {item}")
            else:
                # Unknown item types are never listed.
                log.warning(f"🔍 {label} is of unknown type and will be disabled")
                log.trace(f"🔍 {label} is a unknown type: {item}")
                continue

            log.trace(f"🔍 Getting permissions for {kind} {name}")
            permissions = permissions_for(name)
            log.trace(f"🔍 {label} permissions: {permissions}")
            if permissions["enabled"]:
                allowed.append(item)
            else:
                log.warning(
                    f"🔍 {label} {name} is disabled or not configured and will not be allowed"
                )
        return allowed

    # General hooks for on_request, on_message, etc.
    async def on_request(
        self,
        context: MiddlewareContext[mt.Request[Any, Any]],  # type: ignore
        call_next: CallNext[mt.Request[Any, Any], Any],  # type: ignore
    ) -> Any:
        """
        Make sure a session exists for the request, then forward it.
        """
        # Called for its side effects: creates the row if needed and sets the contextvar.
        _ = self._get_or_create_session_stats(context)

        return await call_next(context)  # type: ignore

    # Hooks for Tools
    async def on_list_tools(
        self,
        context: MiddlewareContext[Any],  # type: ignore
        call_next: CallNext[Any, Any],  # type: ignore
    ) -> Any:
        """List tools, keeping only those enabled for the current session.

        Raises:
            ValueError: If no session id has been recorded in the context.
        """
        log.debug("🔍 on_list_tools")
        # Get the original response
        response = await call_next(context)
        log.trace(f"🔍 on_list_tools response: {response}")

        session_id = self._require_session_id()
        _, tracker = self._load_tracker(session_id)
        log.trace(f"Getting tool permissions for session {session_id}")

        return self._filter_enabled(
            response,
            label="Tool",
            builtin_type=FunctionTool,
            proxy_type=ProxyTool,
            name_of=lambda tool: tool.name,
            permissions_for=tracker.get_tool_permissions,
        )

    async def on_call_tool(
        self,
        context: MiddlewareContext[mt.CallToolRequestParams],  # type: ignore
        call_next: CallNext[mt.CallToolRequestParams, ToolResult],  # type: ignore
    ) -> ToolResult:
        """Record the tool call for the session, then forward it.

        The call is appended to the session, registered with the data access
        tracker and persisted before the tool runs. If the tracker raises, the
        exception propagates and the call is neither persisted nor forwarded.

        Raises:
            ValueError: If no session id has been recorded in the context.
        """
        session_id = self._require_session_id()
        session, tracker = self._load_tracker(session_id)
        log.trace(f"Adding tool call to session {session_id}")

        tool_name = context.message.name
        session.tool_calls.append(
            ToolCall(
                id=str(uuid.uuid4()),
                tool_name=tool_name,
                parameters=context.message.arguments or {},
                timestamp=datetime.now(),
            )
        )

        log.debug(f"🔍 Analyzing tool {tool_name} for security implications")
        _ = tracker.add_tool_call(tool_name)

        # Update database session
        self._persist_session(session_id, tracker, tool_calls=session.tool_calls)
        log.trace(f"Tool call {tool_name} added to session {session_id}")

        return await call_next(context)  # type: ignore

    # Hooks for Resources
    async def on_list_resources(
        self,
        context: MiddlewareContext[Any],  # type: ignore
        call_next: CallNext[Any, Any],  # type: ignore
    ) -> Any:
        """List resources, keeping only those enabled for the current session.

        Resources are looked up by the string form of their URI.

        Raises:
            ValueError: If no session id has been recorded in the context.
        """
        log.trace("🔍 on_list_resources")
        # Get the original response
        response = await call_next(context)
        log.trace(f"🔍 on_list_resources response: {response}")

        session_id = self._require_session_id()
        _, tracker = self._load_tracker(session_id)
        log.trace(f"Getting resource permissions for session {session_id}")

        return self._filter_enabled(
            response,
            label="Resource",
            builtin_type=FunctionResource,
            proxy_type=ProxyResource,
            name_of=lambda resource: str(resource.uri),
            permissions_for=tracker.get_resource_permissions,
        )

    async def on_read_resource(
        self,
        context: MiddlewareContext[Any],  # type: ignore
        call_next: CallNext[Any, Any],  # type: ignore
    ) -> Any:
        """Register the resource access with the session tracker, then forward it.

        When no session id is recorded, a warning is logged and the request is
        forwarded without tracking.
        """
        session_id = current_session_id_ctxvar.get()
        if session_id is None:
            log.warning("No session ID found for resource access tracking")
            return await call_next(context)

        _, tracker = self._load_tracker(session_id)
        log.trace(f"Adding resource access to session {session_id}")

        # Get the resource name from the context
        resource_name = str(context.message.uri)

        log.debug(f"🔍 Analyzing resource {resource_name} for security implications")
        _ = tracker.add_resource_access(resource_name)

        # Update database session
        self._persist_session(session_id, tracker)

        log.trace(f"Resource access {resource_name} added to session {session_id}")
        return await call_next(context)

    # Hooks for Prompts
    async def on_list_prompts(
        self,
        context: MiddlewareContext[Any],  # type: ignore
        call_next: CallNext[Any, Any],  # type: ignore
    ) -> Any:
        """List prompts, keeping only those enabled for the current session.

        Raises:
            ValueError: If no session id has been recorded in the context.
        """
        log.debug("🔍 on_list_prompts")
        # Get the original response
        response = await call_next(context)
        log.debug(f"🔍 on_list_prompts response: {response}")

        session_id = self._require_session_id()
        _, tracker = self._load_tracker(session_id)
        log.trace(f"Getting prompt permissions for session {session_id}")

        return self._filter_enabled(
            response,
            label="Prompt",
            builtin_type=FunctionPrompt,
            proxy_type=ProxyPrompt,
            name_of=lambda prompt: str(prompt.name),
            permissions_for=tracker.get_prompt_permissions,
        )

    async def on_get_prompt(
        self,
        context: MiddlewareContext[Any],  # type: ignore
        call_next: CallNext[Any, Any],  # type: ignore
    ) -> Any:
        """Register the prompt access with the session tracker, then forward it.

        When no session id is recorded, a warning is logged and the request is
        forwarded without tracking.
        """
        session_id = current_session_id_ctxvar.get()
        if session_id is None:
            log.warning("No session ID found for prompt access tracking")
            return await call_next(context)

        _, tracker = self._load_tracker(session_id)
        log.trace(f"Adding prompt access to session {session_id}")

        prompt_name = context.message.name

        log.debug(f"🔍 Analyzing prompt {prompt_name} for security implications")
        _ = tracker.add_prompt_access(prompt_name)

        # Update database session
        self._persist_session(session_id, tracker)

        log.trace(f"Prompt access {prompt_name} added to session {session_id}")
        return await call_next(context)