proxilion 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. proxilion/__init__.py +136 -0
  2. proxilion/audit/__init__.py +133 -0
  3. proxilion/audit/base_exporters.py +527 -0
  4. proxilion/audit/compliance/__init__.py +130 -0
  5. proxilion/audit/compliance/base.py +457 -0
  6. proxilion/audit/compliance/eu_ai_act.py +603 -0
  7. proxilion/audit/compliance/iso27001.py +544 -0
  8. proxilion/audit/compliance/soc2.py +491 -0
  9. proxilion/audit/events.py +493 -0
  10. proxilion/audit/explainability.py +1173 -0
  11. proxilion/audit/exporters/__init__.py +58 -0
  12. proxilion/audit/exporters/aws_s3.py +636 -0
  13. proxilion/audit/exporters/azure_storage.py +608 -0
  14. proxilion/audit/exporters/cloud_base.py +468 -0
  15. proxilion/audit/exporters/gcp_storage.py +570 -0
  16. proxilion/audit/exporters/multi_exporter.py +498 -0
  17. proxilion/audit/hash_chain.py +652 -0
  18. proxilion/audit/logger.py +543 -0
  19. proxilion/caching/__init__.py +49 -0
  20. proxilion/caching/tool_cache.py +633 -0
  21. proxilion/context/__init__.py +73 -0
  22. proxilion/context/context_window.py +556 -0
  23. proxilion/context/message_history.py +505 -0
  24. proxilion/context/session.py +735 -0
  25. proxilion/contrib/__init__.py +51 -0
  26. proxilion/contrib/anthropic.py +609 -0
  27. proxilion/contrib/google.py +1012 -0
  28. proxilion/contrib/langchain.py +641 -0
  29. proxilion/contrib/mcp.py +893 -0
  30. proxilion/contrib/openai.py +646 -0
  31. proxilion/core.py +3058 -0
  32. proxilion/decorators.py +966 -0
  33. proxilion/engines/__init__.py +287 -0
  34. proxilion/engines/base.py +266 -0
  35. proxilion/engines/casbin_engine.py +412 -0
  36. proxilion/engines/opa_engine.py +493 -0
  37. proxilion/engines/simple.py +437 -0
  38. proxilion/exceptions.py +887 -0
  39. proxilion/guards/__init__.py +54 -0
  40. proxilion/guards/input_guard.py +522 -0
  41. proxilion/guards/output_guard.py +634 -0
  42. proxilion/observability/__init__.py +198 -0
  43. proxilion/observability/cost_tracker.py +866 -0
  44. proxilion/observability/hooks.py +683 -0
  45. proxilion/observability/metrics.py +798 -0
  46. proxilion/observability/session_cost_tracker.py +1063 -0
  47. proxilion/policies/__init__.py +67 -0
  48. proxilion/policies/base.py +304 -0
  49. proxilion/policies/builtin.py +486 -0
  50. proxilion/policies/registry.py +376 -0
  51. proxilion/providers/__init__.py +201 -0
  52. proxilion/providers/adapter.py +468 -0
  53. proxilion/providers/anthropic_adapter.py +330 -0
  54. proxilion/providers/gemini_adapter.py +391 -0
  55. proxilion/providers/openai_adapter.py +294 -0
  56. proxilion/py.typed +0 -0
  57. proxilion/resilience/__init__.py +81 -0
  58. proxilion/resilience/degradation.py +615 -0
  59. proxilion/resilience/fallback.py +555 -0
  60. proxilion/resilience/retry.py +554 -0
  61. proxilion/scheduling/__init__.py +57 -0
  62. proxilion/scheduling/priority_queue.py +419 -0
  63. proxilion/scheduling/scheduler.py +459 -0
  64. proxilion/security/__init__.py +244 -0
  65. proxilion/security/agent_trust.py +968 -0
  66. proxilion/security/behavioral_drift.py +794 -0
  67. proxilion/security/cascade_protection.py +869 -0
  68. proxilion/security/circuit_breaker.py +428 -0
  69. proxilion/security/cost_limiter.py +690 -0
  70. proxilion/security/idor_protection.py +460 -0
  71. proxilion/security/intent_capsule.py +849 -0
  72. proxilion/security/intent_validator.py +495 -0
  73. proxilion/security/memory_integrity.py +767 -0
  74. proxilion/security/rate_limiter.py +509 -0
  75. proxilion/security/scope_enforcer.py +680 -0
  76. proxilion/security/sequence_validator.py +636 -0
  77. proxilion/security/trust_boundaries.py +784 -0
  78. proxilion/streaming/__init__.py +70 -0
  79. proxilion/streaming/detector.py +761 -0
  80. proxilion/streaming/transformer.py +674 -0
  81. proxilion/timeouts/__init__.py +55 -0
  82. proxilion/timeouts/decorators.py +477 -0
  83. proxilion/timeouts/manager.py +545 -0
  84. proxilion/tools/__init__.py +69 -0
  85. proxilion/tools/decorators.py +493 -0
  86. proxilion/tools/registry.py +732 -0
  87. proxilion/types.py +339 -0
  88. proxilion/validation/__init__.py +93 -0
  89. proxilion/validation/pydantic_schema.py +351 -0
  90. proxilion/validation/schema.py +651 -0
  91. proxilion-0.0.1.dist-info/METADATA +872 -0
  92. proxilion-0.0.1.dist-info/RECORD +94 -0
  93. proxilion-0.0.1.dist-info/WHEEL +4 -0
  94. proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,893 @@
1
+ """
2
+ Model Context Protocol (MCP) integration for Proxilion.
3
+
4
+ This module provides authorization wrappers for MCP tools and servers,
5
+ enabling secure tool execution with user-context authorization.
6
+
7
+ MCP is Anthropic's protocol for connecting AI agents to tools. Proxilion
8
+ intercepts MCP tool calls to add authorization before execution.
9
+
10
+ Key Security Features:
11
+ - Validate client authentication from MCP session
12
+ - Extract user context from MCP metadata
13
+ - Apply authorization before tool execution
14
+ - Audit log all MCP tool invocations
15
+ - Prevent tool shadowing attacks (verify tool definitions)
16
+ - No ambient authority (capabilities must be explicitly passed)
17
+ - Session-bound permissions (expire with session)
18
+ - Cross-server trust isolation
19
+
20
+ Example:
21
+ >>> from proxilion import Proxilion, Policy
22
+ >>> from proxilion.contrib.mcp import MCPToolWrapper, ProxilionMCPServer
23
+ >>>
24
+ >>> auth = Proxilion()
25
+ >>>
26
+ >>> @auth.policy("file_read")
27
+ ... class FileReadPolicy(Policy):
28
+ ... def can_execute(self, context):
29
+ ... path = context.get("path", "")
30
+ ... return not path.startswith("/etc/") and ".." not in path
31
+ >>>
32
+ >>> wrapped_tool = MCPToolWrapper(
33
+ ... original_tool=file_read_tool,
34
+ ... proxilion=auth,
35
+ ... resource="file_read"
36
+ ... )
37
+ """
38
+
39
+ from __future__ import annotations
40
+
41
+ import hashlib
42
+ import json
43
+ import logging
44
+ import threading
45
+ from collections.abc import Callable
46
+ from dataclasses import dataclass, field
47
+ from datetime import datetime, timezone
48
+ from typing import Any, Protocol, TypeVar, runtime_checkable
49
+
50
+ from proxilion.exceptions import (
51
+ AuthorizationError,
52
+ ProxilionError,
53
+ )
54
+ from proxilion.types import AgentContext, UserContext
55
+
56
# Module logger; records are emitted under this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)

# Generic type variable. NOTE(review): not referenced by any definition in this module.
T = TypeVar("T")
59
+
60
+
61
class MCPSecurityError(ProxilionError):
    """
    Common base class for the MCP security errors defined in this module.

    Catching it handles ToolShadowingError, SessionExpiredError and
    InvalidClientError in one place.
    """
64
+
65
+
66
class ToolShadowingError(MCPSecurityError):
    """
    Signals that a tool's definition hash differs from the registered one.

    Attributes:
        tool_name: Name of the tool whose definition changed.
        expected_hash: Hash recorded when the tool was registered.
        actual_hash: Hash of the definition presented now.
    """

    def __init__(self, tool_name: str, expected_hash: str, actual_hash: str) -> None:
        self.tool_name, self.expected_hash, self.actual_hash = (
            tool_name,
            expected_hash,
            actual_hash,
        )
        # Only a 16-character prefix of each hash is shown in the message.
        message = (
            f"Tool shadowing detected for '{tool_name}': "
            f"expected hash {expected_hash[:16]}..., got {actual_hash[:16]}..."
        )
        super().__init__(message)
77
+
78
+
79
class SessionExpiredError(MCPSecurityError):
    """
    Signals that an MCP session is expired (or could not be resolved).

    Attributes:
        session_id: Identifier of the rejected session.
    """

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id
        message = f"MCP session expired: {session_id}"
        super().__init__(message)
85
+
86
+
87
class InvalidClientError(MCPSecurityError):
    """
    Signals that an MCP client could not be authenticated or identified.

    Attributes:
        reason: Human-readable explanation of the failure.
    """

    def __init__(self, reason: str) -> None:
        self.reason = reason
        message = f"Invalid MCP client: {reason}"
        super().__init__(message)
93
+
94
+
95
@runtime_checkable
class MCPTool(Protocol):
    """
    Structural type for an MCP tool definition.

    Any object exposing ``name``, ``description``, ``input_schema`` and an
    async ``execute`` method satisfies this protocol. Because it is
    ``runtime_checkable``, ``isinstance`` only checks that those members
    exist, not their signatures.
    """

    @property
    def name(self) -> str:
        """Tool name."""

    @property
    def description(self) -> str:
        """Human-readable tool description."""

    @property
    def input_schema(self) -> dict[str, Any]:
        """JSON Schema describing the accepted ``arguments`` mapping."""

    async def execute(self, arguments: dict[str, Any]) -> Any:
        """Run the tool with ``arguments`` and return its result."""
117
+
118
+
119
@runtime_checkable
class MCPServer(Protocol):
    """
    Structural type for an MCP server implementation.

    A server exposes a ``tools`` listing and an async ``handle_tool_call``
    entry point. Because it is ``runtime_checkable``, ``isinstance`` only
    checks that both members exist.
    """

    @property
    def tools(self) -> list[Any]:
        """Tools the server makes available."""

    async def handle_tool_call(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> Any:
        """Dispatch ``arguments`` to the tool named ``tool_name`` and return its result."""
135
+
136
+
137
@dataclass
class MCPSession:
    """
    An MCP session: who is calling, until when, and with which per-resource
    restrictions.

    Attributes:
        session_id: Unique session identifier.
        user_context: The authenticated user for this session.
        agent_context: Optional agent context.
        created_at: Creation time (timezone-aware UTC by default).
        expires_at: Expiry instant, or None for a session that never expires.
        permissions: Maps a resource name to the actions allowed on it.
            Resources absent from this map are not restricted by the session.
        metadata: Free-form additional session data.
    """

    session_id: str
    user_context: UserContext
    agent_context: AgentContext | None = None
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    expires_at: datetime | None = None
    permissions: dict[str, list[str]] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    def is_expired(self) -> bool:
        """
        Return True once the current UTC time is strictly past ``expires_at``.

        Always False when ``expires_at`` is None. NOTE(review): the comparison
        is against an aware UTC datetime, so a naive ``expires_at`` raises
        TypeError here.
        """
        return self.expires_at is not None and datetime.now(timezone.utc) > self.expires_at

    def has_permission(self, resource: str, action: str) -> bool:
        """
        Return whether this session permits ``action`` on ``resource``.

        A resource with no entry in ``permissions`` is unrestricted (True).
        A listed resource allows only the actions in its list.
        """
        return resource not in self.permissions or action in self.permissions[resource]
170
+
171
+
172
@dataclass
class ToolDefinitionHash:
    """
    Registry record pairing a tool name with the hash of its definition.

    Attributes:
        tool_name: Name of the registered tool.
        definition_hash: Hash of the tool definition at registration time.
        registered_at: When the record was created (timezone-aware UTC).
    """

    tool_name: str
    definition_hash: str
    registered_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc),
    )
178
+
179
+
180
class MCPSessionManager:
    """
    In-memory store of MCP sessions guarded by a re-entrant lock.

    Creates sessions with an optional time-to-live and returns only
    non-expired sessions. Expired sessions are evicted lazily: on lookup,
    and in a sweep whenever a new session is created.

    Example:
        >>> manager = MCPSessionManager(default_ttl=3600)
        >>> session = manager.create_session(user_context)
        >>> if manager.validate_session(session.session_id):
        ...     # Session is valid
        ...     pass
    """

    def __init__(
        self,
        default_ttl: float | None = 3600.0,
    ) -> None:
        """
        Set up an empty session store.

        Args:
            default_ttl: Lifetime in seconds applied when ``create_session`` is
                called without ``ttl``. None means sessions never expire.
        """
        self.default_ttl = default_ttl
        self._sessions: dict[str, MCPSession] = {}
        # Re-entrant lock protecting ``_sessions``.
        self._lock = threading.RLock()

    def create_session(
        self,
        user_context: UserContext,
        agent_context: AgentContext | None = None,
        ttl: float | None = None,
        permissions: dict[str, list[str]] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> MCPSession:
        """
        Store and return a new session with a random UUID4 identifier.

        Args:
            user_context: The user the session belongs to.
            agent_context: Optional agent context.
            ttl: Lifetime in seconds. None falls back to ``default_ttl``, so a
                per-call "never expires" is only possible when ``default_ttl``
                is itself None.
            permissions: Per-resource allowed actions (empty dict if omitted).
            metadata: Extra session data (empty dict if omitted).

        Returns:
            The newly stored MCPSession.
        """
        import uuid
        from datetime import timedelta

        lifetime = self.default_ttl if ttl is None else ttl
        expiry = (
            None
            if lifetime is None
            else datetime.now(timezone.utc) + timedelta(seconds=lifetime)
        )

        new_session = MCPSession(
            session_id=str(uuid.uuid4()),
            user_context=user_context,
            agent_context=agent_context,
            expires_at=expiry,
            permissions=permissions or {},
            metadata=metadata or {},
        )

        with self._lock:
            self._sessions[new_session.session_id] = new_session
            # Sweep on every creation so abandoned sessions do not accumulate.
            self._cleanup_expired()

        logger.debug(f"Created MCP session: {new_session.session_id} for user {user_context.user_id}")
        return new_session

    def get_session(self, session_id: str) -> MCPSession | None:
        """
        Look up a live session by ID.

        An expired session found here is removed from the store.

        Args:
            session_id: The session ID.

        Returns:
            The session, or None if it is unknown or has expired.
        """
        with self._lock:
            found = self._sessions.get(session_id)
            if found is not None and found.is_expired():
                del self._sessions[session_id]
                found = None
            return found

    def validate_session(self, session_id: str) -> bool:
        """
        Report whether ``session_id`` refers to a live (known, unexpired) session.

        Args:
            session_id: The session ID.

        Returns:
            True if the session is valid.
        """
        found = self.get_session(session_id)
        return found is not None

    def invalidate_session(self, session_id: str) -> None:
        """
        Remove a session. Unknown IDs are ignored.

        Args:
            session_id: The session ID to invalidate.
        """
        with self._lock:
            self._sessions.pop(session_id, None)

    def _cleanup_expired(self) -> None:
        """Delete every expired session. Callers hold ``_lock``."""
        stale_ids = [sid for sid, candidate in self._sessions.items() if candidate.is_expired()]
        for stale_id in stale_ids:
            del self._sessions[stale_id]
303
+
304
+
305
class ToolDefinitionRegistry:
    """
    Records SHA-256 hashes of tool definitions so later presentations of the
    same tool can be checked for tampering ("tool shadowing").

    A definition is hashed from its canonical JSON form (sorted keys, compact
    separators), so key order does not affect the hash.
    """

    def __init__(self) -> None:
        self._hashes: dict[str, ToolDefinitionHash] = {}
        # Re-entrant lock protecting ``_hashes``.
        self._lock = threading.RLock()

    def register_tool(
        self,
        tool_name: str,
        definition: dict[str, Any],
    ) -> str:
        """
        Record (or overwrite) the hash for ``tool_name``.

        Args:
            tool_name: The tool name.
            definition: The tool definition dict. Must be JSON-serializable.

        Returns:
            The hex SHA-256 hash of the canonical definition.
        """
        digest = self._compute_hash(definition)

        with self._lock:
            self._hashes[tool_name] = ToolDefinitionHash(
                tool_name=tool_name,
                definition_hash=digest,
            )

        logger.debug(f"Registered tool definition: {tool_name} -> {digest[:16]}...")
        return digest

    def verify_tool(
        self,
        tool_name: str,
        definition: dict[str, Any],
    ) -> bool:
        """
        Check ``definition`` against the hash registered for ``tool_name``.

        Tools that were never registered are accepted.

        Args:
            tool_name: The tool name.
            definition: The tool definition to verify.

        Returns:
            True when the tool is unregistered or its hash matches.

        Raises:
            ToolShadowingError: If a registered hash exists and differs.
        """
        with self._lock:
            known = self._hashes.get(tool_name)

        if known is None:
            return True

        observed = self._compute_hash(definition)
        if observed == known.definition_hash:
            return True

        raise ToolShadowingError(
            tool_name=tool_name,
            expected_hash=known.definition_hash,
            actual_hash=observed,
        )

    def _compute_hash(self, definition: dict[str, Any]) -> str:
        """Return the hex SHA-256 of ``definition`` serialized as canonical JSON."""
        payload = json.dumps(definition, sort_keys=True, separators=(",", ":")).encode()
        return hashlib.sha256(payload).hexdigest()
384
+
385
+
386
class MCPToolWrapper:
    """
    Wraps an MCP tool with Proxilion authorization.

    Before each call the wrapper:

    - resolves the acting user from an ``MCPSession`` or a direct ``UserContext``;
    - rejects expired sessions and actions the session does not permit;
    - runs ``proxilion.check(user, action, resource, context)``.

    Only then does it delegate to the wrapped tool. Any audit logging is
    expected to happen inside the Proxilion instance; this class does none.

    NOTE(review): ``validate_schema`` is accepted and stored, but this class
    does not currently validate ``arguments`` against ``input_schema``.

    Example:
        >>> from proxilion import Proxilion
        >>> from proxilion.contrib.mcp import MCPToolWrapper
        >>>
        >>> auth = Proxilion()
        >>>
        >>> # Define your MCP tool
        >>> class FileReadTool:
        ...     name = "file_read"
        ...     description = "Read a file"
        ...     input_schema = {"type": "object", "properties": {"path": {"type": "string"}}}
        ...
        ...     async def execute(self, arguments):
        ...         with open(arguments["path"]) as f:
        ...             return f.read()
        >>>
        >>> # Wrap with authorization
        >>> wrapped = MCPToolWrapper(
        ...     original_tool=FileReadTool(),
        ...     proxilion=auth,
        ...     resource="file_read",
        ... )
        >>>
        >>> # Execute with user context
        >>> result = await wrapped.execute(
        ...     arguments={"path": "/data/file.txt"},
        ...     session=mcp_session,
        ... )
    """

    def __init__(
        self,
        original_tool: Any,
        proxilion: Any,
        resource: str | None = None,
        action: str = "execute",
        validate_schema: bool = True,
        require_session: bool = True,
    ) -> None:
        """
        Initialize the tool wrapper.

        Args:
            original_tool: The MCP tool to wrap. It should expose an ``execute``
                method, or be callable itself.
            proxilion: Object providing ``check(user, action, resource, context)``.
            resource: Resource name for policies (default: the tool's ``name``).
            action: Action name for authorization.
            validate_schema: Stored for callers. Not currently enforced here.
            require_session: Only selects the rejection reason when neither a
                session nor a user context is supplied. A direct
                ``user_context`` is accepted regardless of this flag.
        """
        self.original_tool = original_tool
        self.proxilion = proxilion
        self.resource = resource or getattr(original_tool, "name", "unknown")
        self.action = action
        self.validate_schema = validate_schema
        self.require_session = require_session

        # Snapshot the tool's metadata at wrap time.
        self._name = getattr(original_tool, "name", "unknown")
        self._description = getattr(original_tool, "description", "")
        self._input_schema = getattr(original_tool, "input_schema", {})

    @property
    def name(self) -> str:
        """Get the tool name."""
        return self._name

    @property
    def description(self) -> str:
        """Get the tool description."""
        return self._description

    @property
    def input_schema(self) -> dict[str, Any]:
        """Get the input schema."""
        return self._input_schema

    async def execute(
        self,
        arguments: dict[str, Any],
        session: MCPSession | None = None,
        user_context: UserContext | None = None,
    ) -> Any:
        """
        Execute the tool with authorization.

        Args:
            arguments: Tool arguments. None is treated as an empty dict.
            session: MCP session (preferred over ``user_context``).
            user_context: Direct user context (fallback).

        Returns:
            The tool execution result. A synchronous tool's return value is
            passed through; an awaitable result is awaited.

        Raises:
            SessionExpiredError: If the session is expired.
            AuthorizationError: If no user can be resolved, the session lacks
                the permission, or the Proxilion check denies the call.
            ValueError: If the wrapped tool has no ``execute`` and is not callable.
        """
        if arguments is None:
            arguments = {}

        user = self._resolve_user(session, user_context)

        # Flatten arguments FIRST so that the trusted keys written afterwards
        # cannot be spoofed by caller-supplied arguments named "arguments" or
        # "tool_name".
        context = {
            **arguments,
            "arguments": arguments,
            "tool_name": self.name,
        }

        result = self.proxilion.check(user, self.action, self.resource, context)
        if not result.allowed:
            raise AuthorizationError(
                user=user.user_id,
                action=self.action,
                resource=self.resource,
                reason=result.reason,
            )

        return await self._invoke_original(arguments)

    def _resolve_user(
        self,
        session: MCPSession | None,
        user_context: UserContext | None,
    ) -> Any:
        """Return the acting user, enforcing session expiry and session permissions."""
        if session is not None:
            if session.is_expired():
                raise SessionExpiredError(session.session_id)
            user = session.user_context
            if not session.has_permission(self.resource, self.action):
                raise AuthorizationError(
                    user=user.user_id,
                    action=self.action,
                    resource=self.resource,
                    reason="Session lacks required permission",
                )
            return user

        if user_context is not None:
            return user_context

        # No identity at all: always rejected; require_session only picks the reason.
        reason = (
            "No session or user context provided"
            if self.require_session
            else "No user context available"
        )
        raise AuthorizationError(
            user="unknown",
            action=self.action,
            resource=self.resource,
            reason=reason,
        )

    async def _invoke_original(self, arguments: dict[str, Any]) -> Any:
        """Call the wrapped tool's ``execute`` (or the tool itself), awaiting only if needed."""
        import inspect

        execute_method = getattr(self.original_tool, "execute", None)
        if execute_method is None:
            if not callable(self.original_tool):
                raise ValueError(f"Tool {self.name} has no execute method")
            execute_method = self.original_tool

        outcome = execute_method(arguments)
        # Support both async and plain synchronous tools.
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    def __call__(
        self,
        arguments: dict[str, Any],
        session: MCPSession | None = None,
        user_context: UserContext | None = None,
    ) -> Any:
        """
        Synchronous wrapper for ``execute``.

        Uses ``asyncio.run``, so it cannot be called from a thread that already
        has a running event loop (``asyncio.run`` raises RuntimeError there).
        """
        import asyncio
        return asyncio.run(self.execute(arguments, session, user_context))
558
+
559
+
560
class ProxilionMCPServer:
    """
    Wraps an MCP server with Proxilion authorization.

    Calls to tools listed in ``original_server.tools`` are routed through an
    ``MCPToolWrapper``, which applies session checks and ``proxilion.check``.
    Calls to unknown tools follow ``default_policy``. Only the value
    ``"allow"`` (case- and whitespace-insensitive) forwards them, without an
    authorization check, to ``original_server.handle_tool_call``. Any other
    value denies, so a misconfigured policy fails closed.

    Example:
        >>> from proxilion import Proxilion
        >>> from proxilion.contrib.mcp import ProxilionMCPServer
        >>>
        >>> auth = Proxilion()
        >>> secure_server = ProxilionMCPServer(
        ...     original_server=mcp_server,
        ...     proxilion=auth,
        ...     default_policy="deny",
        ... )
        >>>
        >>> # Handle tool call with authorization
        >>> result = await secure_server.handle_tool_call(
        ...     tool_name="file_read",
        ...     arguments={"path": "/data/file.txt"},
        ...     session=mcp_session,
        ... )
    """

    def __init__(
        self,
        original_server: Any,
        proxilion: Any,
        default_policy: str = "deny",
        session_manager: MCPSessionManager | None = None,
        tool_registry: ToolDefinitionRegistry | None = None,
        verify_tool_definitions: bool = True,
    ) -> None:
        """
        Initialize the secure MCP server.

        Args:
            original_server: The original MCP server.
            proxilion: Proxilion instance for authorization.
            default_policy: Handling of unknown tools. "allow" forwards them;
                any other value denies them.
            session_manager: Session manager (created if None).
            tool_registry: Tool definition registry (created if None).
            verify_tool_definitions: Whether to hash the server's tools at
                construction and verify ``tool_definition`` on calls.
        """
        self.original_server = original_server
        self.proxilion = proxilion
        self.default_policy = default_policy
        self.session_manager = session_manager or MCPSessionManager()
        self.tool_registry = tool_registry or ToolDefinitionRegistry()
        self.verify_tool_definitions = verify_tool_definitions

        # Cache of wrapped tools, keyed by tool name.
        self._wrapped_tools: dict[str, MCPToolWrapper] = {}
        self._lock = threading.RLock()

        # Record definition hashes of the tools present at construction time.
        if hasattr(original_server, "tools") and self.verify_tool_definitions:
            self._register_tools()

    def _register_tools(self) -> None:
        """Register name/description/input_schema hashes for the server's named tools."""
        tools = getattr(self.original_server, "tools", [])
        for tool in tools:
            name = getattr(tool, "name", None)
            if name:
                definition = {
                    "name": name,
                    "description": getattr(tool, "description", ""),
                    "input_schema": getattr(tool, "input_schema", {}),
                }
                self.tool_registry.register_tool(name, definition)

    @property
    def tools(self) -> list[Any]:
        """Get the list of available tools."""
        if hasattr(self.original_server, "tools"):
            return self.original_server.tools
        return []

    def get_wrapped_tool(self, tool_name: str) -> MCPToolWrapper | None:
        """
        Get a wrapped tool by name, wrapping and caching it on first use.

        Args:
            tool_name: The tool name.

        Returns:
            The wrapped tool, or None if the server has no such tool.
        """
        with self._lock:
            if tool_name in self._wrapped_tools:
                return self._wrapped_tools[tool_name]

            original_tool = self._find_tool(tool_name)
            if original_tool is None:
                return None

            wrapped = MCPToolWrapper(
                original_tool=original_tool,
                proxilion=self.proxilion,
                resource=tool_name,
            )
            self._wrapped_tools[tool_name] = wrapped
            return wrapped

    def _find_tool(self, tool_name: str) -> Any:
        """Return the first tool on the original server whose ``name`` matches, else None."""
        tools = getattr(self.original_server, "tools", [])
        for tool in tools:
            if getattr(tool, "name", None) == tool_name:
                return tool
        return None

    def _default_allows_unknown_tools(self) -> bool:
        """
        Return True only when ``default_policy`` explicitly says "allow".

        "deny", typos such as "Deny " or "denied", and non-string values all
        return False, so misconfiguration fails closed rather than open.
        """
        return str(self.default_policy).strip().lower() == "allow"

    async def handle_tool_call(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        session: MCPSession | None = None,
        session_id: str | None = None,
        user_context: UserContext | None = None,
        tool_definition: dict[str, Any] | None = None,
    ) -> Any:
        """
        Handle a tool call with authorization.

        Args:
            tool_name: Name of the tool to call.
            arguments: Tool arguments.
            session: MCP session object.
            session_id: Session ID, looked up when ``session`` is None.
            user_context: Direct user context (fallback).
            tool_definition: Tool definition for the shadowing check.

        Returns:
            The tool execution result.

        Raises:
            ToolShadowingError: If the tool definition doesn't match.
            SessionExpiredError: If ``session_id`` is unknown or expired, or
                the session is expired.
            AuthorizationError: If authorization fails, or the tool is unknown
                and the default policy is not "allow".
            ValueError: If an unknown tool is allowed but the original server
                has no ``handle_tool_call``.
        """
        if tool_definition and self.verify_tool_definitions:
            self.tool_registry.verify_tool(tool_name, tool_definition)

        if session is None and session_id is not None:
            session = self.session_manager.get_session(session_id)
            if session is None:
                raise SessionExpiredError(session_id)

        wrapped_tool = self.get_wrapped_tool(tool_name)
        if wrapped_tool is not None:
            return await wrapped_tool.execute(
                arguments=arguments,
                session=session,
                user_context=user_context,
            )

        # Unknown tool: deny unless the policy explicitly allows (fail closed).
        if not self._default_allows_unknown_tools():
            user_id = "unknown"
            if session:
                user_id = session.user_context.user_id
            elif user_context:
                user_id = user_context.user_id

            raise AuthorizationError(
                user=user_id,
                action="execute",
                resource=tool_name,
                reason=f"Tool '{tool_name}' not found and default policy is deny",
            )

        # Explicit "allow": forward unchanged. No Proxilion check is applied here.
        if hasattr(self.original_server, "handle_tool_call"):
            return await self.original_server.handle_tool_call(tool_name, arguments)
        raise ValueError(f"Tool '{tool_name}' not found")

    def create_session(
        self,
        user_context: UserContext,
        agent_context: AgentContext | None = None,
        **kwargs: Any,
    ) -> MCPSession:
        """
        Create a new session through this server's session manager.

        Args:
            user_context: The user for the session.
            agent_context: Optional agent context.
            **kwargs: Forwarded to ``session_manager.create_session``
                (e.g. ttl, permissions, metadata).

        Returns:
            The created session.
        """
        return self.session_manager.create_session(
            user_context=user_context,
            agent_context=agent_context,
            **kwargs,
        )

    def validate_client(
        self,
        client_id: str,
        client_secret: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """
        Validate an MCP client.

        NOTE(review): the default implementation accepts every client. Override
        it to implement real client authentication.

        Args:
            client_id: The client identifier.
            client_secret: Optional client secret.
            metadata: Additional client metadata.

        Returns:
            True if the client is valid.
        """
        return True
791
+
792
+
793
def extract_user_from_mcp_context(
    mcp_context: dict[str, Any],
    user_id_field: str = "user_id",
    roles_field: str = "roles",
    session_id_field: str = "session_id",
) -> UserContext:
    """
    Build a UserContext from MCP context/metadata.

    The user ID, roles and session ID are read from the named fields. Every
    other key of ``mcp_context`` is copied into ``attributes``.

    Args:
        mcp_context: The MCP context dictionary. None is treated as empty.
        user_id_field: Field name for user ID.
        roles_field: Field name for roles. The value may be a list, a single
            string (wrapped in a list), or absent/null (no roles).
        session_id_field: Field name for session ID.

    Returns:
        Extracted UserContext.

    Raises:
        InvalidClientError: If the user ID field is missing or falsy.
    """
    context = mcp_context or {}

    user_id = context.get(user_id_field)
    if not user_id:
        raise InvalidClientError(f"Missing required field: {user_id_field}")

    roles = context.get(roles_field)
    if roles is None:
        # Absent key or explicit null (JSON null decodes to None): no roles.
        roles = []
    elif isinstance(roles, str):
        # A single role may arrive as a bare string.
        roles = [roles]

    session_id = context.get(session_id_field)

    # Everything that is not one of the three identity fields becomes an attribute.
    reserved = {user_id_field, roles_field, session_id_field}
    attributes = {k: v for k, v in context.items() if k not in reserved}

    return UserContext(
        user_id=user_id,
        roles=roles,
        session_id=session_id,
        attributes=attributes,
    )
839
+
840
+
841
def create_mcp_tool_handler(
    proxilion: Any,
    tools: list[Any],
    session_manager: MCPSessionManager | None = None,
) -> Callable[[str, dict[str, Any], MCPSession | None], Any]:
    """
    Create an async tool-call handler for an MCP server implementation.

    Every tool in ``tools`` with a truthy ``name`` is wrapped once in an
    ``MCPToolWrapper`` whose resource is the tool name. If two tools share a
    name, the later one wins. The returned coroutine function rejects unknown
    tool names and otherwise delegates to the matching wrapper.

    Args:
        proxilion: Proxilion instance.
        tools: List of MCP tools.
        session_manager: Optional session manager. When given, the handler
            can resolve a ``session_id`` keyword to an ``MCPSession``.

    Returns:
        ``async handler(tool_name, arguments, session=None, session_id=None,
        user_context=None)``.

    Example:
        >>> handler = create_mcp_tool_handler(auth, [tool1, tool2])
        >>> result = await handler("tool1", {"arg": "value"}, session)
    """
    wrapped_tools: dict[str, MCPToolWrapper] = {}

    for tool in tools:
        name = getattr(tool, "name", None)
        if name:
            wrapped_tools[name] = MCPToolWrapper(
                original_tool=tool,
                proxilion=proxilion,
                resource=name,
            )

    async def handler(
        tool_name: str,
        arguments: dict[str, Any],
        session: MCPSession | None = None,
        session_id: str | None = None,
        user_context: UserContext | None = None,
    ) -> Any:
        """
        Authorize and run ``tool_name``.

        Raises:
            AuthorizationError: For unknown tools, or from the wrapper's checks.
            ValueError: If ``session_id`` is given without a session manager.
            SessionExpiredError: If ``session_id`` is unknown or expired.
        """
        if tool_name not in wrapped_tools:
            raise AuthorizationError(
                user="unknown",
                action="execute",
                resource=tool_name,
                reason=f"Unknown tool: {tool_name}",
            )

        # Resolve a session ID the same way ProxilionMCPServer.handle_tool_call does.
        if session is None and session_id is not None:
            if session_manager is None:
                raise ValueError(
                    "session_id was provided but no session_manager is configured"
                )
            session = session_manager.get_session(session_id)
            if session is None:
                raise SessionExpiredError(session_id)

        return await wrapped_tools[tool_name].execute(
            arguments=arguments,
            session=session,
            user_context=user_context,
        )

    return handler