proxilion 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. proxilion/__init__.py +136 -0
  2. proxilion/audit/__init__.py +133 -0
  3. proxilion/audit/base_exporters.py +527 -0
  4. proxilion/audit/compliance/__init__.py +130 -0
  5. proxilion/audit/compliance/base.py +457 -0
  6. proxilion/audit/compliance/eu_ai_act.py +603 -0
  7. proxilion/audit/compliance/iso27001.py +544 -0
  8. proxilion/audit/compliance/soc2.py +491 -0
  9. proxilion/audit/events.py +493 -0
  10. proxilion/audit/explainability.py +1173 -0
  11. proxilion/audit/exporters/__init__.py +58 -0
  12. proxilion/audit/exporters/aws_s3.py +636 -0
  13. proxilion/audit/exporters/azure_storage.py +608 -0
  14. proxilion/audit/exporters/cloud_base.py +468 -0
  15. proxilion/audit/exporters/gcp_storage.py +570 -0
  16. proxilion/audit/exporters/multi_exporter.py +498 -0
  17. proxilion/audit/hash_chain.py +652 -0
  18. proxilion/audit/logger.py +543 -0
  19. proxilion/caching/__init__.py +49 -0
  20. proxilion/caching/tool_cache.py +633 -0
  21. proxilion/context/__init__.py +73 -0
  22. proxilion/context/context_window.py +556 -0
  23. proxilion/context/message_history.py +505 -0
  24. proxilion/context/session.py +735 -0
  25. proxilion/contrib/__init__.py +51 -0
  26. proxilion/contrib/anthropic.py +609 -0
  27. proxilion/contrib/google.py +1012 -0
  28. proxilion/contrib/langchain.py +641 -0
  29. proxilion/contrib/mcp.py +893 -0
  30. proxilion/contrib/openai.py +646 -0
  31. proxilion/core.py +3058 -0
  32. proxilion/decorators.py +966 -0
  33. proxilion/engines/__init__.py +287 -0
  34. proxilion/engines/base.py +266 -0
  35. proxilion/engines/casbin_engine.py +412 -0
  36. proxilion/engines/opa_engine.py +493 -0
  37. proxilion/engines/simple.py +437 -0
  38. proxilion/exceptions.py +887 -0
  39. proxilion/guards/__init__.py +54 -0
  40. proxilion/guards/input_guard.py +522 -0
  41. proxilion/guards/output_guard.py +634 -0
  42. proxilion/observability/__init__.py +198 -0
  43. proxilion/observability/cost_tracker.py +866 -0
  44. proxilion/observability/hooks.py +683 -0
  45. proxilion/observability/metrics.py +798 -0
  46. proxilion/observability/session_cost_tracker.py +1063 -0
  47. proxilion/policies/__init__.py +67 -0
  48. proxilion/policies/base.py +304 -0
  49. proxilion/policies/builtin.py +486 -0
  50. proxilion/policies/registry.py +376 -0
  51. proxilion/providers/__init__.py +201 -0
  52. proxilion/providers/adapter.py +468 -0
  53. proxilion/providers/anthropic_adapter.py +330 -0
  54. proxilion/providers/gemini_adapter.py +391 -0
  55. proxilion/providers/openai_adapter.py +294 -0
  56. proxilion/py.typed +0 -0
  57. proxilion/resilience/__init__.py +81 -0
  58. proxilion/resilience/degradation.py +615 -0
  59. proxilion/resilience/fallback.py +555 -0
  60. proxilion/resilience/retry.py +554 -0
  61. proxilion/scheduling/__init__.py +57 -0
  62. proxilion/scheduling/priority_queue.py +419 -0
  63. proxilion/scheduling/scheduler.py +459 -0
  64. proxilion/security/__init__.py +244 -0
  65. proxilion/security/agent_trust.py +968 -0
  66. proxilion/security/behavioral_drift.py +794 -0
  67. proxilion/security/cascade_protection.py +869 -0
  68. proxilion/security/circuit_breaker.py +428 -0
  69. proxilion/security/cost_limiter.py +690 -0
  70. proxilion/security/idor_protection.py +460 -0
  71. proxilion/security/intent_capsule.py +849 -0
  72. proxilion/security/intent_validator.py +495 -0
  73. proxilion/security/memory_integrity.py +767 -0
  74. proxilion/security/rate_limiter.py +509 -0
  75. proxilion/security/scope_enforcer.py +680 -0
  76. proxilion/security/sequence_validator.py +636 -0
  77. proxilion/security/trust_boundaries.py +784 -0
  78. proxilion/streaming/__init__.py +70 -0
  79. proxilion/streaming/detector.py +761 -0
  80. proxilion/streaming/transformer.py +674 -0
  81. proxilion/timeouts/__init__.py +55 -0
  82. proxilion/timeouts/decorators.py +477 -0
  83. proxilion/timeouts/manager.py +545 -0
  84. proxilion/tools/__init__.py +69 -0
  85. proxilion/tools/decorators.py +493 -0
  86. proxilion/tools/registry.py +732 -0
  87. proxilion/types.py +339 -0
  88. proxilion/validation/__init__.py +93 -0
  89. proxilion/validation/pydantic_schema.py +351 -0
  90. proxilion/validation/schema.py +651 -0
  91. proxilion-0.0.1.dist-info/METADATA +872 -0
  92. proxilion-0.0.1.dist-info/RECORD +94 -0
  93. proxilion-0.0.1.dist-info/WHEEL +4 -0
  94. proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
proxilion/core.py ADDED
@@ -0,0 +1,3058 @@
1
+ """
2
+ Core Proxilion class and authorization logic.
3
+
4
+ This module provides the main entry point for the Proxilion SDK,
5
+ integrating policy evaluation, schema validation, rate limiting,
6
+ circuit breaking, and audit logging into a unified API.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import contextvars
13
+ import functools
14
+ import inspect
15
+ import logging
16
+ import threading
17
+ from collections.abc import Callable
18
+ from contextlib import contextmanager
19
+ from datetime import datetime
20
+ from pathlib import Path
21
+ from typing import Any, ParamSpec, TypeVar
22
+
23
+ from proxilion.audit.events import AuditEventV2
24
+ from proxilion.audit.logger import AuditLogger, InMemoryAuditLogger, LoggerConfig
25
+ from proxilion.context.context_window import (
26
+ ContextStrategy,
27
+ ContextWindow,
28
+ )
29
+ from proxilion.context.session import (
30
+ Session,
31
+ SessionConfig,
32
+ SessionManager,
33
+ )
34
+ from proxilion.engines import EngineFactory
35
+ from proxilion.exceptions import (
36
+ AuthorizationError,
37
+ CircuitOpenError,
38
+ IDORViolationError,
39
+ InputGuardViolation,
40
+ OutputGuardViolation,
41
+ PolicyNotFoundError,
42
+ PolicyViolation,
43
+ RateLimitExceeded,
44
+ SchemaValidationError,
45
+ )
46
+ from proxilion.guards import GuardAction, GuardResult, InputGuard, OutputGuard
47
+ from proxilion.observability.cost_tracker import (
48
+ CostSummary,
49
+ CostTracker,
50
+ UsageRecord,
51
+ )
52
+ from proxilion.policies.base import Policy
53
+ from proxilion.policies.registry import PolicyRegistry
54
+ from proxilion.providers import (
55
+ Provider,
56
+ ProviderAdapter,
57
+ UnifiedResponse,
58
+ UnifiedToolCall,
59
+ get_adapter,
60
+ )
61
+ from proxilion.resilience.degradation import (
62
+ DegradationTier,
63
+ GracefulDegradation,
64
+ )
65
+ from proxilion.resilience.fallback import (
66
+ FallbackChain,
67
+ FallbackOption,
68
+ ModelFallback,
69
+ ToolFallback,
70
+ )
71
+ from proxilion.resilience.retry import (
72
+ DEFAULT_RETRY_POLICY,
73
+ RetryPolicy,
74
+ retry_async,
75
+ retry_with_backoff,
76
+ )
77
+ from proxilion.security.circuit_breaker import (
78
+ CircuitBreakerRegistry,
79
+ )
80
+ from proxilion.security.cost_limiter import (
81
+ CostLimiter,
82
+ CostLimitResult,
83
+ )
84
+ from proxilion.security.idor_protection import IDORProtector
85
+ from proxilion.security.rate_limiter import (
86
+ RateLimiterMiddleware,
87
+ TokenBucketRateLimiter,
88
+ )
89
+ from proxilion.security.scope_enforcer import (
90
+ ExecutionScope,
91
+ ScopeBinding,
92
+ ScopeContext,
93
+ ScopeEnforcer,
94
+ )
95
+ from proxilion.security.sequence_validator import (
96
+ SequenceRule,
97
+ SequenceValidator,
98
+ SequenceViolation,
99
+ )
100
+ from proxilion.streaming.detector import (
101
+ DetectedToolCall,
102
+ StreamEventType,
103
+ StreamingToolCallDetector,
104
+ )
105
+ from proxilion.streaming.transformer import (
106
+ StreamTransformer,
107
+ create_authorization_stream,
108
+ create_guarded_stream,
109
+ )
110
+ from proxilion.timeouts.manager import (
111
+ DeadlineContext,
112
+ TimeoutConfig,
113
+ TimeoutManager,
114
+ )
115
+ from proxilion.timeouts.manager import (
116
+ TimeoutError as ProxilionTimeoutError,
117
+ )
118
+ from proxilion.tools.decorators import (
119
+ tool,
120
+ )
121
+ from proxilion.tools.registry import (
122
+ RiskLevel,
123
+ ToolCategory,
124
+ ToolDefinition,
125
+ ToolExecutionResult,
126
+ ToolRegistry,
127
+ )
128
+ from proxilion.types import (
129
+ AgentContext,
130
+ AuthorizationResult,
131
+ ToolCallRequest,
132
+ UserContext,
133
+ )
134
+ from proxilion.validation.schema import SchemaValidator, ToolSchema
135
+
136
+ logger = logging.getLogger(__name__)
137
+
138
+ P = ParamSpec("P")
139
+ T = TypeVar("T")
140
+
141
+ # Context variable for current user context
142
+ _current_user: contextvars.ContextVar[UserContext | None] = contextvars.ContextVar(
143
+ "proxilion_user", default=None
144
+ )
145
+
146
+ # Context variable for current agent context
147
+ _current_agent: contextvars.ContextVar[AgentContext | None] = contextvars.ContextVar(
148
+ "proxilion_agent", default=None
149
+ )
150
+
151
+
152
def get_current_user() -> UserContext | None:
    """
    Return the user bound to the active execution context.

    Reads the module-level ``_current_user`` context variable, which yields
    None when no user has been set in the current context.
    """
    active_user = _current_user.get()
    return active_user
155
+
156
+
157
def get_current_agent() -> AgentContext | None:
    """
    Return the agent bound to the active execution context.

    Reads the module-level ``_current_agent`` context variable, which yields
    None when no agent has been set in the current context.
    """
    active_agent = _current_agent.get()
    return active_agent
160
+
161
+
162
+ class Proxilion:
163
+ """
164
+ Main entry point for the Proxilion authorization SDK.
165
+
166
+ Proxilion provides application-layer security for LLM tool calls,
167
+ combining policy-based authorization, schema validation, rate limiting,
168
+ circuit breaking, and tamper-evident audit logging.
169
+
170
+ Features:
171
+ - Pundit-style policy definitions
172
+ - Multiple policy engine backends (simple, Casbin, OPA)
173
+ - Schema validation with security checks
174
+ - Token bucket rate limiting
175
+ - Circuit breaker pattern
176
+ - Hash-chained audit logs
177
+
178
+ Example:
179
+ >>> from proxilion import Proxilion, Policy, UserContext
180
+ >>>
181
+ >>> auth = Proxilion(
182
+ ... policy_engine="simple",
183
+ ... audit_log_path="./logs/audit.jsonl"
184
+ ... )
185
+ >>>
186
+ >>> @auth.policy("search")
187
+ ... class SearchPolicy(Policy):
188
+ ... def can_execute(self, context):
189
+ ... return True # All authenticated users
190
+ ...
191
+ ... def can_search_private(self, context):
192
+ ... return "admin" in self.user.roles
193
+ >>>
194
+ >>> @auth.authorize("execute", resource="search")
195
+ ... async def search_tool(query: str, user: UserContext = None):
196
+ ... return await perform_search(query)
197
+ >>>
198
+ >>> user = UserContext(user_id="alice", roles=["analyst"])
199
+ >>> result = await search_tool("find documents", user=user)
200
+ """
201
+
202
def __init__(
    self,
    policy_engine: str = "simple",
    engine_config: dict[str, Any] | None = None,
    audit_log_path: Path | str | None = None,
    rate_limit_config: dict[str, Any] | None = None,
    enable_circuit_breaker: bool = True,
    circuit_breaker_config: dict[str, Any] | None = None,
    enable_idor_protection: bool = True,
    default_deny: bool = True,
    input_guard: InputGuard | None = None,
    output_guard: OutputGuard | None = None,
    sequence_validator: SequenceValidator | None = None,
    scope_enforcer: ScopeEnforcer | None = None,
    cost_tracker: CostTracker | None = None,
    cost_limiter: CostLimiter | None = None,
    session_manager: SessionManager | None = None,
    session_config: SessionConfig | None = None,
    timeout_manager: TimeoutManager | None = None,
    timeout_config: TimeoutConfig | None = None,
    retry_policy: RetryPolicy | None = None,
    graceful_degradation: GracefulDegradation | None = None,
    enable_degradation: bool = False,
    tool_registry: ToolRegistry | None = None,
) -> None:
    """
    Initialize Proxilion.

    Args:
        policy_engine: Policy engine type ("simple", "casbin", "opa").
        engine_config: Configuration for the policy engine.
        audit_log_path: Path for audit log file (None or empty for in-memory).
        rate_limit_config: Rate limiting configuration; a falsy value disables
            rate limiting.
        enable_circuit_breaker: Whether to enable circuit breaker.
        circuit_breaker_config: Circuit breaker configuration. Only the
            "failure_threshold" (default 5) and "reset_timeout" (default 30.0)
            keys are read.
        enable_idor_protection: Whether to enable IDOR protection.
        default_deny: If True, deny requests with no matching policy.
        input_guard: Optional InputGuard for prompt injection detection.
        output_guard: Optional OutputGuard for data leakage detection.
        sequence_validator: Optional SequenceValidator for tool sequence rules.
        scope_enforcer: Optional ScopeEnforcer for semantic scope enforcement.
        cost_tracker: Optional CostTracker for usage and cost tracking.
        cost_limiter: Optional CostLimiter for cost-based rate limiting. When
            both a limiter and a tracker are given, the tracker is attached to
            the limiter.
        session_manager: Optional SessionManager for session tracking.
        session_config: Optional SessionConfig for session defaults.
        timeout_manager: Optional TimeoutManager for timeout handling.
        timeout_config: Optional TimeoutConfig for timeout defaults.
        retry_policy: Optional RetryPolicy (defaults to DEFAULT_RETRY_POLICY).
        graceful_degradation: Optional GracefulDegradation for tier management.
            Only used when ``enable_degradation`` is True; otherwise a warning
            is logged and it is ignored.
        enable_degradation: Whether to enable graceful degradation.
        tool_registry: Optional ToolRegistry for tool management.

    Note:
        Caller-supplied collaborators are kept whenever they are not None.
        Identity checks are used instead of truthiness, so an empty registry
        or manager that defines ``__len__`` is never swapped for a new default.
    """
    self._lock = threading.RLock()
    self._default_deny = default_deny

    # Policy registry and evaluation engine.
    self._registry = PolicyRegistry()
    self._engine = EngineFactory.create(
        engine_type=policy_engine,
        config=engine_config or {},
    )

    # Schema validation for tool arguments.
    self._schema_validator = SchemaValidator()

    # Rate limiting is only configured when a non-empty config is supplied.
    self._rate_limiter: RateLimiterMiddleware | None = None
    if rate_limit_config:
        self._setup_rate_limiter(rate_limit_config)

    # Circuit breaker registry with defaults for the two recognised keys.
    self._circuit_breakers: CircuitBreakerRegistry | None = None
    if enable_circuit_breaker:
        cb_config = circuit_breaker_config or {}
        self._circuit_breakers = CircuitBreakerRegistry(
            default_config={
                "failure_threshold": cb_config.get("failure_threshold", 5),
                "reset_timeout": cb_config.get("reset_timeout", 30.0),
            }
        )

    # IDOR protection.
    self._idor_protector: IDORProtector | None = None
    if enable_idor_protection:
        self._idor_protector = IDORProtector()

    # Input/output guards.
    self._input_guard = input_guard
    self._output_guard = output_guard

    # Sequence validation.
    self._sequence_validator = sequence_validator

    # Scope enforcement; no scope is active initially.
    self._scope_enforcer = scope_enforcer
    self._current_scope: ScopeBinding | None = None

    # Cost tracking / limiting. Identity checks so that collaborators which
    # define __len__/__bool__ are not mistaken for "not provided".
    self._cost_tracker = cost_tracker
    self._cost_limiter = cost_limiter
    if cost_limiter is not None and cost_tracker is not None:
        cost_limiter.set_cost_tracker(cost_tracker)

    # Session management.
    self._session_config = (
        session_config if session_config is not None else SessionConfig()
    )
    self._session_manager = (
        session_manager
        if session_manager is not None
        else SessionManager(self._session_config)
    )

    # Timeout management.
    self._timeout_config = (
        timeout_config if timeout_config is not None else TimeoutConfig()
    )
    self._timeout_manager = (
        timeout_manager
        if timeout_manager is not None
        else TimeoutManager(self._timeout_config)
    )

    # Resilience components.
    self._retry_policy = (
        retry_policy if retry_policy is not None else DEFAULT_RETRY_POLICY
    )
    self._graceful_degradation: GracefulDegradation | None = None
    if enable_degradation:
        self._graceful_degradation = (
            graceful_degradation
            if graceful_degradation is not None
            else GracefulDegradation()
        )
    elif graceful_degradation is not None:
        # A supplied instance would otherwise be dropped without any signal.
        logger.warning(
            "graceful_degradation was provided but enable_degradation is False; "
            "graceful degradation is disabled"
        )

    # Tool registry: keep the caller's instance even if it is currently empty.
    self._tool_registry = (
        tool_registry if tool_registry is not None else ToolRegistry()
    )

    # Audit logger: file-backed when a path is given, otherwise in-memory.
    self._audit_logger: AuditLogger | InMemoryAuditLogger
    if audit_log_path:
        config = LoggerConfig.default(Path(audit_log_path))
        self._audit_logger = AuditLogger(config)
    else:
        self._audit_logger = InMemoryAuditLogger()
331
+
332
def _setup_rate_limiter(self, config: dict[str, Any]) -> None:
    """
    Build ``self._rate_limiter`` from a rate-limit configuration mapping.

    Recognised sections:
        "user":   mapping with optional "capacity" (default 100) and
                  "refill_rate" (default 10.0).
        "global": mapping with optional "capacity" (default 1000) and
                  "refill_rate" (default 100.0).
        "tools":  mapping of tool name -> mapping with optional "capacity"
                  (default 50) and "refill_rate" (default 5.0).

    A missing section leaves the corresponding limit unset (None / empty).
    """

    def make_bucket(
        section: dict[str, Any], capacity: int, refill_rate: float
    ) -> TokenBucketRateLimiter:
        # Each section falls back to its own default capacity and refill rate.
        return TokenBucketRateLimiter(
            capacity=section.get("capacity", capacity),
            refill_rate=section.get("refill_rate", refill_rate),
        )

    user_limit = make_bucket(config["user"], 100, 10.0) if "user" in config else None
    global_limit = (
        make_bucket(config["global"], 1000, 100.0) if "global" in config else None
    )

    per_tool: dict[str, TokenBucketRateLimiter] = {}
    if "tools" in config:
        per_tool = {
            name: make_bucket(section, 50, 5.0)
            for name, section in config["tools"].items()
        }

    self._rate_limiter = RateLimiterMiddleware(
        user_limit=user_limit,
        tool_limits=per_tool,
        global_limit=global_limit,
    )
363
+
364
+ # ==================== Policy Registration ====================
365
+
366
def policy(self, resource_name: str) -> Callable[[type[Policy]], type[Policy]]:
    """
    Return a class decorator that registers a Policy for ``resource_name``.

    The decorator is produced by this instance's PolicyRegistry.

    Args:
        resource_name: The resource this policy applies to.

    Returns:
        A decorator that registers the decorated policy class.

    Example:
        >>> @auth.policy("database_query")
        ... class DatabaseQueryPolicy(Policy):
        ...     def can_execute(self, context):
        ...         return "analyst" in self.user.roles
    """
    registry = self._registry
    decorator = registry.policy(resource_name)
    return decorator
383
+
384
def register_policy(
    self,
    resource_name: str,
    policy_class: type[Policy],
) -> None:
    """
    Register ``policy_class`` for ``resource_name`` without using a decorator.

    Args:
        resource_name: The resource this policy applies to.
        policy_class: The policy class to register.
    """
    registry = self._registry
    registry.register(resource_name, policy_class)
397
+
398
+ # ==================== Schema Registration ====================
399
+
400
def register_schema(self, tool_name: str, schema: ToolSchema) -> None:
    """
    Attach a ToolSchema to ``tool_name`` for argument validation.

    Args:
        tool_name: The tool name.
        schema: The schema to validate this tool's arguments against.
    """
    validator = self._schema_validator
    validator.register_schema(tool_name, schema)
409
+
410
+ # ==================== IDOR Protection ====================
411
+
412
def register_scope(
    self,
    user_id: str,
    resource_type: str,
    allowed_ids: set[str],
) -> None:
    """
    Register the object IDs a user may access for a resource type.

    Args:
        user_id: The user's ID.
        resource_type: Type of resource (e.g., "document").
        allowed_ids: Set of object IDs the user can access.

    Note:
        If IDOR protection is disabled, the registration is discarded and a
        warning is logged so the misconfiguration is visible.
    """
    # Identity check (not truthiness) so a protector that defines
    # __len__/__bool__ is never mistaken for "disabled".
    if self._idor_protector is None:
        logger.warning(
            "IDOR protection is disabled; ignoring scope registration for "
            "resource type %r",
            resource_type,
        )
        return
    self._idor_protector.register_scope(user_id, resource_type, allowed_ids)
428
+
429
def register_id_pattern(
    self,
    parameter_name: str,
    resource_type: str,
) -> None:
    """
    Mark a tool parameter as carrying object IDs of a given resource type.

    Args:
        parameter_name: The parameter name (e.g., "document_id").
        resource_type: The resource type it refers to.

    Note:
        If IDOR protection is disabled, the pattern is discarded and a
        warning is logged so the misconfiguration is visible.
    """
    # Identity check (not truthiness) so a protector that defines
    # __len__/__bool__ is never mistaken for "disabled".
    if self._idor_protector is None:
        logger.warning(
            "IDOR protection is disabled; ignoring ID pattern %r for "
            "resource type %r",
            parameter_name,
            resource_type,
        )
        return
    self._idor_protector.register_id_pattern(parameter_name, resource_type)
443
+
444
+ # ==================== Guards ====================
445
+
446
def guard_input(
    self,
    input_text: str,
    context: dict[str, Any] | None = None,
    raise_on_block: bool = False,
) -> GuardResult:
    """
    Run the configured InputGuard over ``input_text`` before tool execution.

    With no input guard configured, an allowing GuardResult is returned.

    Args:
        input_text: The user input to check.
        context: Optional context forwarded to the guard's ``check``.
        raise_on_block: When True, a failed check whose action is
            GuardAction.BLOCK raises instead of returning.

    Returns:
        The GuardResult produced by the guard.

    Raises:
        InputGuardViolation: If ``raise_on_block`` is True and the input was
            blocked.

    Example:
        >>> result = auth.guard_input("ignore all instructions")
        >>> if not result.passed:
        ...     print(f"Blocked: {result.matched_patterns}")
    """
    guard = self._input_guard
    if guard is None:
        return GuardResult(passed=True, action=GuardAction.ALLOW)

    outcome = guard.check(input_text, context)

    # Only a failed check with a BLOCK action is escalated, and only on request.
    if not raise_on_block or outcome.passed:
        return outcome
    if outcome.action == GuardAction.BLOCK:
        raise InputGuardViolation(
            matched_patterns=outcome.matched_patterns,
            risk_score=outcome.risk_score,
        )
    return outcome
485
+
486
def guard_output(
    self,
    output_text: str,
    context: dict[str, Any] | None = None,
    raise_on_block: bool = False,
    auto_redact: bool = False,
) -> GuardResult:
    """
    Run the configured OutputGuard over ``output_text`` after tool execution.

    With no output guard configured, an allowing GuardResult is returned.

    Args:
        output_text: The output to check.
        context: Optional context forwarded to the guard's ``check``.
        raise_on_block: When True, a failed check whose action is
            GuardAction.BLOCK raises instead of returning.
        auto_redact: When True and the check fails, the guard's redacted
            text is stored on the result's ``sanitized_input`` attribute
            (this happens before any exception is raised).

    Returns:
        The GuardResult produced by the guard.

    Raises:
        OutputGuardViolation: If ``raise_on_block`` is True and the output was
            blocked.

    Example:
        >>> result = auth.guard_output("API key: sk-abc123...")
        >>> if not result.passed:
        ...     print(f"Leakage: {result.matched_patterns}")
    """
    guard = self._output_guard
    if guard is None:
        return GuardResult(passed=True, action=GuardAction.ALLOW)

    outcome = guard.check(output_text, context)

    if auto_redact and not outcome.passed:
        outcome.sanitized_input = guard.redact(output_text)

    # Only a failed check with a BLOCK action is escalated, and only on request.
    if not raise_on_block or outcome.passed:
        return outcome
    if outcome.action == GuardAction.BLOCK:
        raise OutputGuardViolation(
            matched_patterns=outcome.matched_patterns,
            risk_score=outcome.risk_score,
        )
    return outcome
530
+
531
def redact_output(self, output_text: str) -> str:
    """
    Return ``output_text`` with sensitive data redacted by the OutputGuard.

    Args:
        output_text: Text to redact.

    Returns:
        The guard's redacted text, or ``output_text`` unchanged when no
        output guard is configured.
    """
    guard = self._output_guard
    return output_text if guard is None else guard.redact(output_text)
544
+
545
def set_input_guard(self, guard: InputGuard | None) -> None:
    """
    Install a new input guard, replacing any existing one.

    Args:
        guard: The InputGuard to use; pass None to turn input checks off.
    """
    # None makes guard_input() return an allowing result.
    self._input_guard = guard
553
+
554
def set_output_guard(self, guard: OutputGuard | None) -> None:
    """
    Install a new output guard, replacing any existing one.

    Args:
        guard: The OutputGuard to use; pass None to turn output checks off.
    """
    # None makes guard_output() return an allowing result and
    # redact_output() return text unchanged.
    self._output_guard = guard
562
+
563
+ # ==================== Sequence Validation ====================
564
+
565
def validate_sequence(
    self,
    tool_name: str,
    user: UserContext,
) -> tuple[bool, SequenceViolation | None]:
    """
    Check ``tool_name`` against the configured sequence rules for ``user``.

    Delegates to the SequenceValidator's ``validate_call`` with the user's
    ID. Rules can express patterns such as "confirm before delete",
    "no execute after download", or "no rapid consecutive calls".

    Args:
        tool_name: Name of the tool about to be called.
        user: The user context.

    Returns:
        ``(allowed, violation)``. Without a validator this is always
        ``(True, None)``.

    Example:
        >>> allowed, violation = auth.validate_sequence("delete_file", user)
        >>> if not allowed:
        ...     print(f"Blocked: {violation.message}")
    """
    validator = self._sequence_validator
    if validator is not None:
        return validator.validate_call(tool_name, user.user_id)
    return True, None
595
+
596
def record_tool_call(
    self,
    tool_name: str,
    user: UserContext,
) -> None:
    """
    Record a completed tool call in the sequence-validation history.

    Call this after a tool call completes successfully so later sequence
    checks see an accurate history. Does nothing when no sequence validator
    is configured.

    Args:
        tool_name: Name of the tool called.
        user: The user context.

    Example:
        >>> auth.record_tool_call("confirm_delete", user)
        >>> auth.record_tool_call("delete_file", user)
    """
    # Identity check, consistent with validate_sequence(); a validator that
    # defines __len__ and is currently empty must still record calls.
    if self._sequence_validator is not None:
        self._sequence_validator.record_call(tool_name, user.user_id)
617
+
618
def add_sequence_rule(self, rule: SequenceRule) -> None:
    """
    Add a sequence validation rule to the configured validator.

    Args:
        rule: The rule to add.

    Note:
        If no sequence validator is configured, the rule cannot be enforced.
        It is discarded and a warning is logged.

    Example:
        >>> auth.add_sequence_rule(SequenceRule(
        ...     name="require_auth",
        ...     action=SequenceAction.REQUIRE_BEFORE,
        ...     target_pattern="access_*",
        ...     required_pattern="authenticate",
        ... ))
    """
    # Identity check, consistent with validate_sequence(); a truthiness test
    # would drop the first rule added to a validator whose __len__ is 0.
    if self._sequence_validator is None:
        logger.warning(
            "No sequence validator configured; sequence rule %r was not added",
            getattr(rule, "name", rule),
        )
        return
    self._sequence_validator.add_rule(rule)
635
+
636
def remove_sequence_rule(self, name: str) -> bool:
    """
    Remove a sequence rule by name.

    Args:
        name: The rule name to remove.

    Returns:
        The validator's ``remove_rule`` result (True if removed). Returns
        False when no sequence validator is configured.
    """
    # Identity check, consistent with validate_sequence().
    if self._sequence_validator is not None:
        return self._sequence_validator.remove_rule(name)
    return False
649
+
650
def get_tool_history(
    self,
    user: UserContext,
    limit: int | None = None,
) -> list[tuple[str, Any]]:
    """
    Get the recorded tool call history for a user.

    Args:
        user: The user context.
        limit: Maximum entries to return (forwarded to the validator).

    Returns:
        The validator's history for ``user.user_id``, documented as
        ``(tool_name, timestamp)`` tuples, most recent first. An empty list
        when no sequence validator is configured.
    """
    # Identity check, consistent with validate_sequence().
    if self._sequence_validator is not None:
        return self._sequence_validator.get_history(user.user_id, limit)
    return []
668
+
669
def clear_tool_history(self, user: UserContext | None = None) -> None:
    """
    Clear recorded tool call history.

    Args:
        user: User whose history to clear; None clears history for all users.
    """
    # Identity checks: both the validator and the user are optional objects,
    # and truthiness could misfire on types that define __len__/__bool__.
    if self._sequence_validator is not None:
        user_id = user.user_id if user is not None else None
        self._sequence_validator.clear_history(user_id)
679
+
680
def set_sequence_validator(
    self,
    validator: SequenceValidator | None,
) -> None:
    """
    Install a new sequence validator, replacing any existing one.

    Args:
        validator: The SequenceValidator to use; pass None to turn sequence
            validation off.
    """
    # None makes validate_sequence() always allow.
    self._sequence_validator = validator
691
+
692
+ # ==================== Scope Enforcement Methods ====================
693
+
694
def enter_scope(
    self,
    scope: ExecutionScope | str,
    user: UserContext,
) -> ScopeContext:
    """
    Open a scoped execution context for ``user``.

    A string is resolved through the enforcer's ``get_scope``. Any other
    value is passed to ``create_scope_from_enum``.

    NOTE(review): if ExecutionScope is a ``str``-based Enum, its members also
    satisfy ``isinstance(scope, str)`` and take the ``get_scope`` path. Confirm
    this is intended.

    Args:
        scope: Scope name or ExecutionScope member.
        user: The user context for this execution.

    Returns:
        A ScopeContext bound to the enforcer, the resolved scope, and the user.

    Raises:
        ProxilionError: If no scope enforcer is configured.

    Example:
        >>> ctx = auth.enter_scope("read_only", user)
        >>> try:
        ...     ctx.validate_tool("get_user")     # OK
        ...     ctx.validate_tool("delete_user")  # Raises ScopeViolationError
        ... finally:
        ...     ctx.close()
    """
    enforcer = self._scope_enforcer
    if enforcer is None:
        from proxilion.exceptions import ProxilionError
        raise ProxilionError("Scope enforcer not configured")

    binding = (
        enforcer.get_scope(scope)
        if isinstance(scope, str)
        else enforcer.create_scope_from_enum(scope)
    )
    return ScopeContext(enforcer, binding, user)
733
+
734
def validate_scope(
    self,
    tool_name: str,
    action: str = "execute",
) -> tuple[bool, str | None]:
    """
    Check a tool/action pair against the currently active scope.

    Args:
        tool_name: Name of the tool to validate.
        action: Action being performed.

    Returns:
        ``(allowed, reason)``. When there is no active scope or no enforcer,
        this is ``(True, None)``.
    """
    active = self._current_scope
    enforcer = self._scope_enforcer
    if active is not None and enforcer is not None:
        return enforcer.validate_in_scope(tool_name, action, active)
    # Nothing to enforce.
    return True, None
755
+
756
def set_scope_enforcer(
    self,
    enforcer: ScopeEnforcer | None,
) -> None:
    """
    Install a new scope enforcer, replacing any existing one.

    Args:
        enforcer: The ScopeEnforcer to use; pass None to turn scope
            enforcement off.
    """
    # None makes validate_scope() always allow; enter_scope(), create_scope()
    # and classify_tool() then raise ProxilionError.
    self._scope_enforcer = enforcer
767
+
768
def get_scope_enforcer(self) -> ScopeEnforcer | None:
    """Return the configured ScopeEnforcer, or None if enforcement is off."""
    enforcer = self._scope_enforcer
    return enforcer
771
+
772
def create_scope(
    self,
    name: str,
    allowed_tools: set[str] | None = None,
    denied_tools: set[str] | None = None,
    allowed_actions: set[str] | None = None,
    denied_actions: set[str] | None = None,
    description: str = "",
) -> ScopeBinding:
    """
    Define a custom scope via the configured ScopeEnforcer.

    Args:
        name: Unique name for the scope.
        allowed_tools: Tool patterns allowed in the scope.
        denied_tools: Tool patterns denied in the scope.
        allowed_actions: Actions allowed in the scope.
        denied_actions: Actions denied in the scope.
        description: Human-readable description.

    Returns:
        The ScopeBinding returned by the enforcer.

    Raises:
        ProxilionError: If no scope enforcer is configured.

    Example:
        >>> scope = auth.create_scope(
        ...     name="user_data",
        ...     allowed_tools={"get_user_*", "search_users"},
        ...     denied_tools={"delete_*"},
        ...     allowed_actions={"read", "list"},
        ... )
    """
    enforcer = self._scope_enforcer
    if enforcer is None:
        from proxilion.exceptions import ProxilionError
        raise ProxilionError("Scope enforcer not configured")

    spec = {
        "name": name,
        "allowed_tools": allowed_tools,
        "denied_tools": denied_tools,
        "allowed_actions": allowed_actions,
        "denied_actions": denied_actions,
        "description": description,
    }
    return enforcer.create_scope(**spec)
818
+
819
def classify_tool(
    self,
    tool_name: str,
    scope: ExecutionScope | None = None,
    actions: set[str] | None = None,
) -> None:
    """
    Assign a scope (and optionally actions) to a tool via the ScopeEnforcer.

    Args:
        tool_name: Name of the tool.
        scope: Scope to assign, or None to let the enforcer classify it.
        actions: Actions the tool performs.

    Raises:
        ProxilionError: If no scope enforcer is configured.
    """
    enforcer = self._scope_enforcer
    if enforcer is None:
        from proxilion.exceptions import ProxilionError
        raise ProxilionError("Scope enforcer not configured")
    enforcer.classify_tool(tool_name, scope, actions)
841
+
842
+ # ==================== Cost Tracking Methods ====================
843
+
844
def record_usage(
    self,
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
    tool_name: str | None = None,
    user: UserContext | None = None,
    request_id: str | None = None,
) -> UsageRecord | None:
    """
    Forward token usage to the CostTracker, which records it and computes cost.

    Args:
        model: Model identifier.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_read_tokens: Number of cached tokens read.
        cache_write_tokens: Number of tokens written to cache.
        tool_name: Tool that triggered the usage.
        user: User who incurred the usage; only its ``user_id`` is forwarded.
        request_id: Request identifier.

    Returns:
        The tracker's UsageRecord, or None when no cost tracker is configured.

    Example:
        >>> record = auth.record_usage(
        ...     model="claude-sonnet-4-20250514",
        ...     input_tokens=1000,
        ...     output_tokens=500,
        ...     user=user,
        ... )
        >>> print(f"Cost: ${record.cost_usd:.4f}")
    """
    tracker = self._cost_tracker
    if tracker is None:
        return None

    owner_id = user.user_id if user else None
    return tracker.record_usage(
        model=model,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cache_read_tokens=cache_read_tokens,
        cache_write_tokens=cache_write_tokens,
        tool_name=tool_name,
        user_id=owner_id,
        request_id=request_id,
    )
893
+
894
def check_budget(
    self,
    user: UserContext,
    estimated_cost: float = 0.0,
    estimated_tokens: int = 0,
) -> tuple[bool, str | None]:
    """
    Ask the cost tracker whether a request fits within budget.

    Args:
        user: User making the request.
        estimated_cost: Estimated cost of the request.
        estimated_tokens: Estimated tokens for the request.

    Returns:
        ``(allowed, reason)`` from the tracker; ``(True, None)`` when no
        cost tracker is configured.

    Example:
        >>> allowed, reason = auth.check_budget(user, estimated_tokens=10000)
        >>> if not allowed:
        ...     print(f"Budget issue: {reason}")
    """
    tracker = self._cost_tracker
    if tracker is not None:
        return tracker.check_budget(
            user_id=user.user_id,
            estimated_cost=estimated_cost,
            estimated_tokens=estimated_tokens,
        )
    # Without a tracker there is no budget to exceed.
    return True, None
924
+
925
def check_cost_limit(
    self,
    user: UserContext,
    estimated_cost: float,
) -> CostLimitResult | None:
    """
    Evaluate cost-based rate limits for a user.

    Args:
        user: User making the request.
        estimated_cost: Estimated cost of the request.

    Returns:
        The limiter's CostLimitResult, or None when no cost limiter is set.
    """
    limiter = self._cost_limiter
    if limiter is not None:
        return limiter.check_limit(user.user_id, estimated_cost)
    return None
944
+
945
def get_cost_summary(
    self,
    user: UserContext | None = None,
    start: datetime | None = None,
    end: datetime | None = None,
) -> CostSummary | None:
    """
    Summarize recorded costs over a period.

    Args:
        user: Restrict the summary to this user, or None for all users.
        start: Start of period.
        end: End of period.

    Returns:
        The tracker's CostSummary, or None when no cost tracker is configured.
    """
    tracker = self._cost_tracker
    if tracker is None:
        return None

    filter_user_id = None if not user else user.user_id
    return tracker.get_summary(start=start, end=end, user_id=filter_user_id)
970
+
971
def get_budget_status(self, user: UserContext) -> dict[str, Any]:
    """
    Report the current budget status for a user.

    Args:
        user: User to check.

    Returns:
        The tracker's status dictionary, or ``{"cost_tracking_enabled": False}``
        when no cost tracker is configured.
    """
    tracker = self._cost_tracker
    return (
        tracker.get_budget_status(user.user_id)
        if tracker is not None
        else {"cost_tracking_enabled": False}
    )
985
+
986
def set_cost_tracker(self, tracker: CostTracker | None) -> None:
    """
    Install (or clear) the cost tracker.

    When both a limiter and a (truthy) tracker are present afterwards, the
    limiter is pointed at the new tracker as well.

    Args:
        tracker: The tracker to use, or None to disable cost tracking.
    """
    self._cost_tracker = tracker
    limiter = self._cost_limiter
    if not (limiter and tracker):
        return
    limiter.set_cost_tracker(tracker)
996
+
997
def set_cost_limiter(self, limiter: CostLimiter | None) -> None:
    """
    Install (or clear) the cost limiter.

    If a cost tracker is already configured, the new limiter is wired to it.

    Args:
        limiter: The limiter to use, or None to disable cost limiting.
    """
    self._cost_limiter = limiter
    tracker = self._cost_tracker
    if not (limiter and tracker):
        return
    limiter.set_cost_tracker(tracker)
1007
+
1008
def get_cost_tracker(self) -> CostTracker | None:
    """
    Return the configured cost tracker.

    Returns:
        The current CostTracker, or None if none is set.
    """
    tracker = self._cost_tracker
    return tracker
1011
+
1012
def get_cost_limiter(self) -> CostLimiter | None:
    """
    Return the configured cost limiter.

    Returns:
        The current CostLimiter, or None if none is set.
    """
    limiter = self._cost_limiter
    return limiter
1015
+
1016
+ # ==================== Session Management Methods ====================
1017
+
1018
def create_session(
    self,
    user: UserContext,
    session_id: str | None = None,
    config: SessionConfig | None = None,
    metadata: dict[str, Any] | None = None,
) -> Session:
    """
    Open a new session for a user via the session manager.

    Args:
        user: The user context.
        session_id: Optional session ID (auto-generated if not provided).
        config: Optional session-specific configuration.
        metadata: Optional initial metadata.

    Returns:
        The session created by the session manager.

    Example:
        >>> session = auth.create_session(user)
        >>> session.add_message(MessageRole.USER, "Hello!")
    """
    manager = self._session_manager
    new_session = manager.create_session(
        user=user,
        session_id=session_id,
        config=config,
        metadata=metadata,
    )
    return new_session
1047
+
1048
def get_session(self, session_id: str) -> Session | None:
    """
    Look up a session by ID.

    Args:
        session_id: The session ID.

    Returns:
        Whatever the session manager returns for this ID: the session if
        found and not expired, None otherwise.
    """
    manager = self._session_manager
    return manager.get_session(session_id)
1059
+
1060
def get_or_create_session(
    self,
    user: UserContext,
    session_id: str | None = None,
    config: SessionConfig | None = None,
) -> tuple[Session, bool]:
    """
    Fetch an existing session or open a new one.

    Args:
        user: The user context.
        session_id: Optional session ID to look up.
        config: Optional session configuration used if one is created.

    Returns:
        ``(session, created)`` as reported by the session manager; ``created``
        is True when a new session was made.

    Example:
        >>> session, is_new = auth.get_or_create_session(user, "sess_123")
        >>> if is_new:
        ...     print("Created new session")
    """
    manager = self._session_manager
    outcome = manager.get_or_create_session(
        user=user,
        session_id=session_id,
        config=config,
    )
    return outcome
1087
+
1088
def get_user_sessions(
    self,
    user: UserContext,
    include_expired: bool = False,
) -> list[Session]:
    """
    List the sessions belonging to a user.

    Args:
        user: The user context; only its ``user_id`` is used.
        include_expired: Whether to include expired sessions.

    Returns:
        The sessions reported by the session manager.
    """
    manager = self._session_manager
    owner_id = user.user_id
    return manager.get_user_sessions(user_id=owner_id, include_expired=include_expired)
1107
+
1108
def terminate_session(
    self,
    session_id: str,
    reason: str | None = None,
) -> bool:
    """
    End a single session.

    Args:
        session_id: The session ID.
        reason: Optional reason for termination.

    Returns:
        The session manager's result: True if the session was found and terminated.
    """
    manager = self._session_manager
    return manager.terminate_session(session_id, reason)
1124
+
1125
def terminate_user_sessions(
    self,
    user: UserContext,
    reason: str | None = None,
) -> int:
    """
    End every session that belongs to a user.

    Args:
        user: The user context; only its ``user_id`` is used.
        reason: Optional reason for termination.

    Returns:
        Number of sessions terminated, as reported by the session manager.
    """
    manager = self._session_manager
    owner_id = user.user_id
    return manager.terminate_user_sessions(owner_id, reason)
1141
+
1142
def cleanup_expired_sessions(self) -> int:
    """
    Purge expired and terminated sessions from the session manager.

    Returns:
        Number of sessions removed.
    """
    manager = self._session_manager
    removed = manager.cleanup_expired()
    return removed
1150
+
1151
def get_session_stats(self) -> dict[str, Any]:
    """
    Return the session manager's statistics dictionary.

    Returns:
        Dictionary with session statistics.

    Example:
        >>> stats = auth.get_session_stats()
        >>> print(f"Active: {stats['active']}, Expired: {stats['expired']}")
    """
    manager = self._session_manager
    return manager.get_stats()
1163
+
1164
def get_active_session_count(self) -> int:
    """
    Count active (non-expired) sessions.

    Returns:
        Number of active sessions, as reported by the session manager.
    """
    manager = self._session_manager
    return manager.get_active_count()
1172
+
1173
def set_session_manager(self, manager: SessionManager | None) -> None:
    """
    Install a session manager.

    Args:
        manager: The manager to use. When None, a fresh default
            ``SessionManager`` is built from ``self._session_config``.
    """
    self._session_manager = (
        manager if manager is not None else SessionManager(self._session_config)
    )
1184
+
1185
def get_session_manager(self) -> SessionManager:
    """
    Return the session manager currently in use.

    Returns:
        The active SessionManager.
    """
    manager = self._session_manager
    return manager
1188
+
1189
def create_context_window(
    self,
    max_tokens: int,
    strategy: ContextStrategy = ContextStrategy.KEEP_SYSTEM_RECENT,
    reserve_output: int = 1000,
) -> ContextWindow:
    """
    Build a ContextWindow for fitting messages into an LLM context.

    Args:
        max_tokens: Maximum tokens for context.
        strategy: Strategy for fitting messages.
        reserve_output: Tokens to reserve for output.

    Returns:
        A newly constructed ContextWindow with the given settings.

    Example:
        >>> window = auth.create_context_window(8000)
        >>> messages = session.history.get_messages()
        >>> fitted = window.fit_messages(messages)
    """
    window = ContextWindow(
        max_tokens=max_tokens,
        strategy=strategy,
        reserve_output=reserve_output,
    )
    return window
1216
+
1217
+ # ==================== Timeout Management Methods ====================
1218
+
1219
def get_timeout(self, operation: str) -> float:
    """
    Look up the configured timeout for an operation or tool.

    Args:
        operation: Name of the operation or tool.

    Returns:
        Timeout value in seconds, as resolved by the timeout manager.

    Example:
        >>> timeout = auth.get_timeout("web_search")
    """
    manager = self._timeout_manager
    return manager.get_timeout(operation)
1233
+
1234
def get_llm_timeout(self) -> float:
    """
    Return the timeout applied to LLM operations.

    Returns:
        LLM timeout in seconds, as resolved by the timeout manager.
    """
    manager = self._timeout_manager
    return manager.get_llm_timeout()
1242
+
1243
def set_tool_timeout(self, tool_name: str, timeout: float) -> None:
    """
    Override the timeout for one tool.

    Args:
        tool_name: Name of the tool.
        timeout: Timeout value in seconds.
    """
    manager = self._timeout_manager
    manager.set_tool_timeout(tool_name, timeout)
1252
+
1253
def create_deadline(
    self,
    timeout: float | None = None,
    operation: str | None = None,
) -> DeadlineContext:
    """
    Start a deadline context for tracking a time budget.

    Args:
        timeout: Explicit timeout; when None the timeout manager picks its
            default (documented as total_request_timeout).
        operation: Optional operation name.

    Returns:
        DeadlineContext produced by the timeout manager.

    Example:
        >>> async with auth.create_deadline(30.0) as deadline:
        ...     result1 = await tool1(timeout=deadline.remaining())
        ...     result2 = await tool2(timeout=deadline.remaining())
    """
    manager = self._timeout_manager
    return manager.create_deadline(timeout, operation)
1274
+
1275
def create_tool_deadline(self, tool_name: str) -> DeadlineContext:
    """
    Start a deadline context sized by a tool's configured timeout.

    Args:
        tool_name: Name of the tool.

    Returns:
        DeadlineContext produced by the timeout manager for this tool.
    """
    manager = self._timeout_manager
    return manager.create_tool_deadline(tool_name)
1286
+
1287
def get_effective_timeout(
    self,
    operation: str,
    requested_timeout: float | None = None,
) -> float:
    """
    Resolve the timeout to actually use for an operation.

    Delegates to the timeout manager, which is documented to cap the
    requested timeout by the remaining time of any active deadline.

    Args:
        operation: Name of the operation.
        requested_timeout: Requested timeout (configuration value if None).

    Returns:
        Effective timeout in seconds.
    """
    manager = self._timeout_manager
    return manager.get_effective_timeout(operation, requested_timeout)
1306
+
1307
def set_timeout_manager(self, manager: TimeoutManager | None) -> None:
    """
    Install a timeout manager.

    Args:
        manager: The manager to use. When None, a fresh default
            ``TimeoutManager`` is built from ``self._timeout_config``.
    """
    self._timeout_manager = (
        manager if manager is not None else TimeoutManager(self._timeout_config)
    )
1318
+
1319
def get_timeout_manager(self) -> TimeoutManager:
    """
    Return the timeout manager currently in use.

    Returns:
        The active TimeoutManager.
    """
    manager = self._timeout_manager
    return manager
1322
+
1323
async def authorize_with_timeout(
    self,
    user: UserContext,
    action: str,
    resource: str,
    timeout: float | None = None,
    context: dict[str, Any] | None = None,
) -> AuthorizationResult:
    """
    Run ``self.check`` with an upper bound on how long it may take.

    Args:
        user: The user context.
        action: The action to perform.
        resource: The resource name.
        timeout: Timeout in seconds. Only None selects
            ``self._timeout_config.default_timeout``; an explicit value
            (including 0) is used as given.
        context: Additional context for the policy.

    Returns:
        The AuthorizationResult from ``self.check`` if it finishes in time.

    Raises:
        ProxilionTimeoutError: If the check does not finish within the timeout.

    Note:
        ``self.check`` runs in a worker thread via ``asyncio.to_thread``.
        On timeout the awaiting task is cancelled, but the thread itself
        cannot be interrupted and keeps running until ``check`` returns.

    Example:
        >>> result = await auth.authorize_with_timeout(
        ...     user, "execute", "search", timeout=5.0
        ... )
    """
    # `timeout or default` would silently turn an explicit 0/0.0 into the
    # default; only a missing (None) timeout should fall back.
    effective_timeout = (
        timeout if timeout is not None else self._timeout_config.default_timeout
    )
    try:
        # check() is synchronous; running it in a thread keeps the event
        # loop responsive while it executes.
        return await asyncio.wait_for(
            asyncio.to_thread(self.check, user, action, resource, context),
            timeout=effective_timeout,
        )
    except asyncio.TimeoutError as e:
        raise ProxilionTimeoutError(
            message="Authorization check timed out",
            operation=f"authorize:{action}:{resource}",
            timeout=effective_timeout,
        ) from e
1364
+
1365
+ # ==================== Authorization Methods ====================
1366
+
1367
def can(
    self,
    user: UserContext,
    action: str,
    resource: str,
    context: dict[str, Any] | None = None,
) -> bool:
    """
    Boolean shortcut around ``check``.

    Runs the same evaluation as ``check`` and reduces the detailed
    AuthorizationResult to its ``allowed`` flag.

    Args:
        user: The user context.
        action: The action to perform (e.g., "execute", "read").
        resource: The resource name.
        context: Additional context for the policy.

    Returns:
        True if authorized, False otherwise.

    Example:
        >>> if auth.can(user, "execute", "database_query"):
        ...     result = await database_query(query)
    """
    return self.check(user, action, resource, context).allowed
1395
+
1396
def check(
    self,
    user: UserContext,
    action: str,
    resource: str,
    context: dict[str, Any] | None = None,
) -> AuthorizationResult:
    """
    Check authorization and return a detailed result.

    Resolution order:
      1. If the registry has a policy class for ``resource`` and that policy
         defines ``can_<action>``, its return value decides (coerced to bool).
         Any exception while instantiating the policy or evaluating the
         method yields a deny result (fail closed).
      2. Otherwise the decision is delegated to the policy engine.

    Args:
        user: The user context.
        action: The action to perform.
        resource: The resource name.
        context: Additional context for the policy (empty dict if None).

    Returns:
        AuthorizationResult with allowed status, reason and the policies
        that were evaluated.

    Example:
        >>> result = auth.check(user, "execute", "database_query")
        >>> if not result.allowed:
        ...     print(f"Denied: {result.reason}")
    """
    context = context or {}

    # Try policy registry first (Pundit-style)
    try:
        policy_class = self._registry.get_policy(resource)
    except PolicyNotFoundError:
        policy_class = None

    if policy_class:
        policy_name = policy_class.__name__
        method_name = f"can_{action}"

        # Instantiate inside a guard so that a policy whose constructor
        # raises fails closed like a raising policy method does, instead of
        # propagating out of check()/can().
        try:
            policy = policy_class(user, resource)
        except Exception as e:
            logger.error("Policy %s could not be instantiated: %s", policy_name, e)
            return AuthorizationResult.deny(
                reason=f"Policy evaluation failed: {e}",
                policies=[policy_name],
            )

        if hasattr(policy, method_name):
            method = getattr(policy, method_name)
            try:
                allowed = method(context)
                return AuthorizationResult(
                    allowed=bool(allowed),
                    reason=f"Policy {policy_name}.{method_name} returned {allowed}",
                    policies_evaluated=[policy_name],
                )
            except Exception as e:
                logger.error("Policy %s.%s raised: %s", policy_name, method_name, e)
                return AuthorizationResult.deny(
                    reason=f"Policy evaluation failed: {e}",
                    policies=[policy_name],
                )

    # No applicable policy method: fall back to the policy engine.
    return self._engine.evaluate(user, action, resource, context)
1454
+
1455
def authorize_or_raise(
    self,
    user: UserContext,
    action: str,
    resource: str,
    context: dict[str, Any] | None = None,
) -> AuthorizationResult:
    """
    Run ``check`` and turn a deny into an exception.

    Args:
        user: The user context.
        action: The action to perform.
        resource: The resource name.
        context: Additional context.

    Returns:
        The AuthorizationResult when access is allowed.

    Raises:
        AuthorizationError: If authorization is denied; carries the user ID,
            action, resource and the result's reason.
    """
    decision = self.check(user, action, resource, context)
    if decision.allowed:
        return decision

    raise AuthorizationError(
        user=user.user_id,
        action=action,
        resource=resource,
        reason=decision.reason,
    )
1488
+
1489
+ # ==================== Decorator ====================
1490
+
1491
def authorize(
    self,
    action: str,
    resource: str | None = None,
    user_param: str = "user",
    agent_param: str = "agent",
    validate_schema: bool = True,
    apply_rate_limit: bool = True,
    check_circuit_breaker: bool = True,
    log_audit: bool = True,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator to protect a function with authorization.

    Wraps the function with:
    1. Schema validation (if registered)
    2. Authorization check
    3. Rate limiting
    4. Circuit breaker
    5. Audit logging

    Works with both sync and async functions: coroutine functions are
    routed through ``_execute_with_auth``, everything else through
    ``_execute_with_auth_sync``.

    Args:
        action: The action being performed (e.g., "execute").
        resource: The resource name. When None, each decorated function
            uses its own ``__name__``, even if the same decorator object is
            applied to several functions.
        user_param: Parameter name containing UserContext.
        agent_param: Parameter name containing AgentContext.
        validate_schema: Whether to validate against registered schema.
        apply_rate_limit: Whether to apply rate limiting.
        check_circuit_breaker: Whether to check circuit breaker.
        log_audit: Whether to log audit events.

    Returns:
        A decorator function.

    Example:
        >>> @auth.authorize("execute", resource="search")
        ... async def search_tool(query: str, user: UserContext = None):
        ...     return await perform_search(query)
    """
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        # Resolve the resource per decorated function. Rebinding the
        # enclosing `resource` (nonlocal) would make every later function
        # decorated by this same object inherit the first function's name.
        resource_name = resource if resource is not None else func.__name__

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                return await self._execute_with_auth(
                    func=func,
                    args=args,
                    kwargs=kwargs,
                    action=action,
                    resource=resource_name,
                    user_param=user_param,
                    agent_param=agent_param,
                    validate_schema=validate_schema,
                    apply_rate_limit=apply_rate_limit,
                    check_circuit_breaker=check_circuit_breaker,
                    log_audit=log_audit,
                    is_async=True,
                )
            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            # Plain functions are executed synchronously in the caller's thread.
            return self._execute_with_auth_sync(
                func=func,
                args=args,
                kwargs=kwargs,
                action=action,
                resource=resource_name,
                user_param=user_param,
                agent_param=agent_param,
                validate_schema=validate_schema,
                apply_rate_limit=apply_rate_limit,
                check_circuit_breaker=check_circuit_breaker,
                log_audit=log_audit,
            )
        return sync_wrapper  # type: ignore

    return decorator
1579
+
1580
async def _execute_with_auth(
    self,
    func: Callable[..., Any],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    action: str,
    resource: str,
    user_param: str,
    agent_param: str,
    validate_schema: bool,
    apply_rate_limit: bool,
    check_circuit_breaker: bool,
    log_audit: bool,
    is_async: bool,
) -> Any:
    """
    Run ``func`` behind the full authorization pipeline (async version).

    Steps, in order: schema validation, IDOR check, authorization, rate
    limiting, circuit-breaker availability, execution; audit logging runs
    in ``finally`` for every outcome.

    Args:
        func: The wrapped callable.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``; also used as policy context.
        action: Action being authorized.
        resource: Resource/tool name; also the circuit-breaker key.
        user_param: Keyword holding the UserContext; falls back to
            ``get_current_user()``.
        agent_param: Keyword holding the AgentContext; falls back to
            ``get_current_agent()``.
        validate_schema: Whether to run schema validation.
        apply_rate_limit: Whether to apply rate limiting.
        check_circuit_breaker: Whether to refuse calls while the circuit is open.
        log_audit: Whether to write an audit event.
        is_async: Whether ``func`` must be awaited.

    Returns:
        Whatever ``func`` returns.

    Raises:
        AuthorizationError: If no user is available or access is denied.
        SchemaValidationError, IDORViolationError, RateLimitExceeded,
        CircuitOpenError: From the corresponding pre-checks.
        Exception: Anything raised by ``func`` (or by the pre-checks) is re-raised.
    """
    # Extract user and agent from kwargs or context
    user = kwargs.get(user_param) or get_current_user()
    agent = kwargs.get(agent_param) or get_current_agent()

    if user is None:
        raise AuthorizationError(
            user="unknown",
            action=action,
            resource=resource,
            reason="No user context provided",
        )

    # Policy context: keyword arguments plus positional ones under "_args".
    context = dict(kwargs)
    context["_args"] = args

    tool_request = ToolCallRequest(
        tool_name=resource,
        arguments=context,
    )

    auth_result: AuthorizationResult | None = None
    execution_result: dict[str, Any] | None = None
    error_message: str | None = None
    # Flipped just before func runs. Only failures from that point on are
    # charged to the tool's circuit breaker: errors raised by the
    # pre-checks (schema validator, IDOR protector, policy engine, ...) say
    # nothing about the tool's health and must not trip its circuit.
    executed = False

    try:
        # 1. Schema validation
        if validate_schema:
            self._validate_schema(resource, context)

        # 2. IDOR protection: reject on the first reported violation.
        if self._idor_protector:
            violations = self._idor_protector.validate_arguments(user.user_id, context)
            if violations:
                _param, res_type, obj_id = violations[0]
                raise IDORViolationError(
                    user_id=user.user_id,
                    resource_type=res_type,
                    object_id=obj_id,
                )

        # 3. Authorization check (raises AuthorizationError on deny)
        auth_result = self.authorize_or_raise(user, action, resource, context)

        # 4. Rate limiting
        if apply_rate_limit and self._rate_limiter:
            self._apply_rate_limit(user, resource)

        # 5. Circuit breaker: fail fast while the tool's circuit is open.
        if check_circuit_breaker and self._circuit_breakers:
            breaker = self._circuit_breakers.get(resource)
            if not breaker.is_available():
                stats = breaker.stats
                raise CircuitOpenError(
                    circuit_name=resource,
                    failure_count=stats.consecutive_failures,
                    reset_timeout=breaker.reset_timeout,
                    last_failure=stats.last_failure_error,
                )

        # 6. Execute function
        executed = True
        if is_async:
            result = await func(*args, **kwargs)
        else:
            result = func(*args, **kwargs)

        # Record success for circuit breaker
        if self._circuit_breakers:
            self._circuit_breakers.get(resource)._record_success()

        execution_result = {"success": True, "result_type": type(result).__name__}
        return result

    except (AuthorizationError, PolicyViolation, SchemaValidationError,
            RateLimitExceeded, CircuitOpenError, IDORViolationError):
        # Re-raise Proxilion exceptions; they never count as tool failures.
        raise

    except Exception as e:
        # Record failure for circuit breaker only if the tool actually ran.
        if executed and self._circuit_breakers:
            self._circuit_breakers.get(resource)._record_failure(e)

        error_message = str(e)
        raise

    finally:
        # 7. Audit logging
        if log_audit:
            self._log_audit_event(
                user=user,
                agent=agent,
                tool_request=tool_request,
                auth_result=auth_result,
                execution_result=execution_result,
                error_message=error_message,
            )
1695
+
1696
def _execute_with_auth_sync(
    self,
    func: Callable[..., Any],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    action: str,
    resource: str,
    user_param: str,
    agent_param: str,
    validate_schema: bool,
    apply_rate_limit: bool,
    check_circuit_breaker: bool,
    log_audit: bool,
) -> Any:
    """
    Run ``func`` behind the full authorization pipeline (sync version).

    Steps, in order: schema validation, IDOR check, authorization, rate
    limiting, circuit-breaker availability, execution; audit logging runs
    in ``finally`` for every outcome.

    Args:
        func: The wrapped callable (called directly, never awaited).
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func``; also used as policy context.
        action: Action being authorized.
        resource: Resource/tool name; also the circuit-breaker key.
        user_param: Keyword holding the UserContext; falls back to
            ``get_current_user()``.
        agent_param: Keyword holding the AgentContext; falls back to
            ``get_current_agent()``.
        validate_schema: Whether to run schema validation.
        apply_rate_limit: Whether to apply rate limiting.
        check_circuit_breaker: Whether to refuse calls while the circuit is open.
        log_audit: Whether to write an audit event.

    Returns:
        Whatever ``func`` returns.

    Raises:
        AuthorizationError: If no user is available or access is denied.
        SchemaValidationError, IDORViolationError, RateLimitExceeded,
        CircuitOpenError: From the corresponding pre-checks.
        Exception: Anything raised by ``func`` (or by the pre-checks) is re-raised.
    """
    # Extract user and agent from kwargs or context
    user = kwargs.get(user_param) or get_current_user()
    agent = kwargs.get(agent_param) or get_current_agent()

    if user is None:
        raise AuthorizationError(
            user="unknown",
            action=action,
            resource=resource,
            reason="No user context provided",
        )

    # Policy context: keyword arguments plus positional ones under "_args".
    context = dict(kwargs)
    context["_args"] = args

    tool_request = ToolCallRequest(
        tool_name=resource,
        arguments=context,
    )

    auth_result: AuthorizationResult | None = None
    execution_result: dict[str, Any] | None = None
    error_message: str | None = None
    # Flipped just before func runs. Only failures from that point on are
    # charged to the tool's circuit breaker: errors raised by the
    # pre-checks (schema validator, IDOR protector, policy engine, ...) say
    # nothing about the tool's health and must not trip its circuit.
    executed = False

    try:
        # 1. Schema validation
        if validate_schema:
            self._validate_schema(resource, context)

        # 2. IDOR protection: reject on the first reported violation.
        if self._idor_protector:
            violations = self._idor_protector.validate_arguments(user.user_id, context)
            if violations:
                _param, res_type, obj_id = violations[0]
                raise IDORViolationError(
                    user_id=user.user_id,
                    resource_type=res_type,
                    object_id=obj_id,
                )

        # 3. Authorization check (raises AuthorizationError on deny)
        auth_result = self.authorize_or_raise(user, action, resource, context)

        # 4. Rate limiting
        if apply_rate_limit and self._rate_limiter:
            self._apply_rate_limit(user, resource)

        # 5. Circuit breaker: fail fast while the tool's circuit is open.
        if check_circuit_breaker and self._circuit_breakers:
            breaker = self._circuit_breakers.get(resource)
            if not breaker.is_available():
                stats = breaker.stats
                raise CircuitOpenError(
                    circuit_name=resource,
                    failure_count=stats.consecutive_failures,
                    reset_timeout=breaker.reset_timeout,
                    last_failure=stats.last_failure_error,
                )

        # 6. Execute function
        executed = True
        result = func(*args, **kwargs)

        # Record success for circuit breaker
        if self._circuit_breakers:
            self._circuit_breakers.get(resource)._record_success()

        execution_result = {"success": True, "result_type": type(result).__name__}
        return result

    except (AuthorizationError, PolicyViolation, SchemaValidationError,
            RateLimitExceeded, CircuitOpenError, IDORViolationError):
        # Re-raise Proxilion exceptions; they never count as tool failures.
        raise

    except Exception as e:
        # Record failure for circuit breaker only if the tool actually ran.
        if executed and self._circuit_breakers:
            self._circuit_breakers.get(resource)._record_failure(e)

        error_message = str(e)
        raise

    finally:
        # 7. Audit logging
        if log_audit:
            self._log_audit_event(
                user=user,
                agent=agent,
                tool_request=tool_request,
                auth_result=auth_result,
                execution_result=execution_result,
                error_message=error_message,
            )
1804
+
1805
def _validate_schema(self, tool_name: str, arguments: dict[str, Any]) -> None:
    """
    Validate tool arguments against the registered schema.

    Args:
        tool_name: Tool whose schema is used.
        arguments: Arguments to validate.

    Raises:
        SchemaValidationError: If the validator reports the arguments invalid;
            carries the validator's error list.
    """
    outcome = self._schema_validator.validate(tool_name, arguments)
    if outcome.valid:
        return
    raise SchemaValidationError(
        tool_name=tool_name,
        errors=outcome.errors,
    )
1813
+
1814
def _apply_rate_limit(self, user: UserContext, tool_name: str) -> None:
    """
    Enforce the configured rate limiter for this user/tool pair.

    Does nothing when no rate limiter is configured; otherwise any
    exception raised by the limiter propagates to the caller.

    Args:
        user: The caller; only its ``user_id`` is passed to the limiter.
        tool_name: Tool being invoked.
    """
    limiter = self._rate_limiter
    if not limiter:
        return
    limiter.check_rate_limit(user.user_id, tool_name)
1818
+
1819
def _log_audit_event(
    self,
    user: UserContext,
    agent: AgentContext | None,
    tool_request: ToolCallRequest,
    auth_result: AuthorizationResult | None,
    execution_result: dict[str, Any] | None,
    error_message: str | None,
) -> None:
    """
    Write one authorization/execution record to the audit logger.

    Argument keys starting with "_" (such as the "_args" entry holding
    positional arguments) are omitted. Values that are not str, int, float,
    bool, list, dict or None are replaced by a ``"<TypeName>"`` placeholder.

    Args:
        user: The caller.
        agent: The agent context, if any.
        tool_request: Tool name and arguments of the call.
        auth_result: Authorization outcome, or None if the call was rejected
            before authorization completed (logged as not allowed).
        execution_result: Execution summary, if the call succeeded.
        error_message: Message of an exception raised during the call. When
            set and no ``execution_result`` exists, it is recorded as
            ``{"success": False, "error": error_message}``.
    """
    # Note: event_type is determined internally by log_authorization based on allowed flag
    plain_types = (str, int, float, bool, list, dict, type(None))
    filtered_args = {
        key: value if isinstance(value, plain_types) else f"<{type(value).__name__}>"
        for key, value in tool_request.arguments.items()
        if not key.startswith("_")
    }

    # error_message used to be accepted but silently dropped, so failed
    # executions were audited with no trace of what went wrong.
    if error_message is not None and execution_result is None:
        execution_result = {"success": False, "error": error_message}

    self._audit_logger.log_authorization(
        user_id=user.user_id,
        user_roles=list(user.roles),
        tool_name=tool_request.tool_name,
        tool_arguments=filtered_args,
        allowed=auth_result.allowed if auth_result else False,
        reason=auth_result.reason if auth_result else None,
        policies_evaluated=auth_result.policies_evaluated if auth_result else [],
        session_id=user.session_id,
        user_attributes=dict(user.attributes),
        agent_id=agent.agent_id if agent else None,
        agent_capabilities=list(agent.capabilities) if agent else [],
        agent_trust_score=agent.trust_score if agent else None,
        execution_result=execution_result,
    )
1856
+
1857
+ # ==================== Context Managers ====================
1858
+
1859
@contextmanager
def user_context(self, user: UserContext):
    """
    Temporarily install ``user`` as the current user.

    Sets the ``_current_user`` context variable for the duration of the
    ``with`` block and restores its previous value on exit, even if the
    block raises.

    Args:
        user: The user context to set.

    Yields:
        The same user context.

    Example:
        >>> with auth.user_context(user) as ctx:
        ...     result = await tool_function(args)
    """
    reset_token = _current_user.set(user)
    try:
        yield user
    finally:
        _current_user.reset(reset_token)
1882
+
1883
@contextmanager
def agent_context(self, agent: AgentContext):
    """
    Temporarily install ``agent`` as the current agent.

    Sets the ``_current_agent`` context variable for the duration of the
    ``with`` block and restores its previous value on exit.

    Args:
        agent: The agent context to set.

    Yields:
        The same agent context.
    """
    reset_token = _current_agent.set(agent)
    try:
        yield agent
    finally:
        _current_agent.reset(reset_token)
1899
+
1900
@contextmanager
def session(self, user: UserContext, agent: AgentContext | None = None):
    """
    Temporarily install a user and, optionally, an agent as current.

    The agent context variable is only touched when a (truthy) agent is
    given; otherwise whatever agent was already current stays in place.
    Both variables are restored on exit.

    Args:
        user: The user context.
        agent: Optional agent context.

    Yields:
        Tuple of (user, agent).

    Example:
        >>> with auth.session(user, agent) as (u, a):
        ...     result = await tool_function(args)
    """
    user_reset = _current_user.set(user)
    agent_reset = None
    if agent:
        agent_reset = _current_agent.set(agent)
    try:
        yield (user, agent)
    finally:
        _current_user.reset(user_reset)
        if agent_reset:
            _current_agent.reset(agent_reset)
1924
+
1925
+ # ==================== Resilience Methods ====================
1926
+
1927
def get_retry_policy(self) -> RetryPolicy:
    """Return the retry policy currently stored on this instance."""
    current_policy = self._retry_policy
    return current_policy
1930
+
1931
def set_retry_policy(self, policy: RetryPolicy) -> None:
    """
    Replace the stored retry policy.

    Args:
        policy: The retry policy to store; later retry helpers fall back to it.
    """
    setattr(self, "_retry_policy", policy)
1939
+
1940
def get_graceful_degradation(self) -> GracefulDegradation | None:
    """Return the configured graceful-degradation instance, or ``None`` if unset."""
    degradation = self._graceful_degradation
    return degradation
1943
+
1944
def set_graceful_degradation(
    self, degradation: GracefulDegradation | None
) -> None:
    """
    Replace the graceful-degradation instance.

    Args:
        degradation: The instance to use. Pass ``None`` to turn degradation
            handling off; the tier helpers then report full availability.
    """
    setattr(self, "_graceful_degradation", degradation)
1954
+
1955
def get_current_tier(self) -> DegradationTier:
    """
    Report the active degradation tier.

    Returns:
        The tier reported by the degradation instance, or
        ``DegradationTier.FULL`` when no degradation instance is configured.
    """
    degradation = self._graceful_degradation
    if not degradation:
        return DegradationTier.FULL
    return degradation.get_current_tier()
1965
+
1966
def is_tool_available_at_tier(self, tool_name: str) -> bool:
    """
    Ask whether a tool may be used at the active degradation tier.

    Args:
        tool_name: Name of the tool to look up.

    Returns:
        The degradation instance's answer, or ``True`` when no degradation
        instance is configured.
    """
    degradation = self._graceful_degradation
    if not degradation:
        return True
    return degradation.is_tool_available(tool_name)
1979
+
1980
def is_model_available_at_tier(self, model_name: str) -> bool:
    """
    Ask whether a model may be used at the active degradation tier.

    Args:
        model_name: Name of the model to look up.

    Returns:
        The degradation instance's answer, or ``True`` when no degradation
        instance is configured.
    """
    degradation = self._graceful_degradation
    if not degradation:
        return True
    return degradation.is_model_available(model_name)
1993
+
1994
def record_operation_failure(self, component: str) -> None:
    """
    Forward a component failure to the degradation tracker, if one is configured.

    The degradation instance may use these reports to lower the active tier.
    That policy lives in the degradation instance, not here. Does nothing
    when degradation is disabled.

    Args:
        component: Name of the component that failed.
    """
    degradation = self._graceful_degradation
    if not degradation:
        return
    degradation.record_failure(component)
2005
+
2006
def record_operation_success(self, component: str) -> None:
    """
    Forward a component success to the degradation tracker, if one is configured.

    The degradation instance may use these reports to raise the active tier
    again. Does nothing when degradation is disabled.

    Args:
        component: Name of the component that succeeded.
    """
    degradation = self._graceful_degradation
    if not degradation:
        return
    degradation.record_success(component)
2017
+
2018
def create_retry_decorator(
    self,
    policy: RetryPolicy | None = None,
    on_retry: Any | None = None,
) -> Any:
    """
    Build a retry decorator via ``retry_with_backoff``.

    Args:
        policy: Policy to use. A falsy value (e.g. ``None``) falls back to
            the instance's stored retry policy.
        on_retry: Optional callback forwarded to ``retry_with_backoff``.

    Returns:
        Whatever ``retry_with_backoff`` returns (a decorator).

    Example:
        >>> @auth.create_retry_decorator()
        ... async def call_llm():
        ...     return await client.chat.completions.create(...)
    """
    chosen_policy = policy if policy else self._retry_policy
    return retry_with_backoff(policy=chosen_policy, on_retry=on_retry)
2040
+
2041
def create_fallback_chain(
    self,
    options: list[FallbackOption] | None = None,
) -> FallbackChain[Any]:
    """
    Construct a new ``FallbackChain``.

    Args:
        options: Initial fallback options, passed through as ``options=``.

    Returns:
        A freshly built FallbackChain.

    Example:
        >>> chain = auth.create_fallback_chain()
        >>> chain.add_option(FallbackOption("primary", primary_handler))
        >>> chain.add_option(FallbackOption("backup", backup_handler))
        >>> result = await chain.execute_async()
    """
    chain = FallbackChain(options=options)
    return chain
2061
+
2062
def create_model_fallback(self) -> ModelFallback:
    """
    Construct a new, empty ``ModelFallback``.

    Returns:
        A fresh ModelFallback; add models to it before use.

    Example:
        >>> fallback = auth.create_model_fallback()
        >>> fallback.add_model("claude-opus", call_claude)
        >>> fallback.add_model("gpt-4o", call_gpt)
        >>> result = await fallback.complete(prompt="Hello")
    """
    model_fallback = ModelFallback()
    return model_fallback
2076
+
2077
def create_tool_fallback(self) -> ToolFallback:
    """
    Construct a new, empty ``ToolFallback``.

    Returns:
        A fresh ToolFallback; add tools to it before use.

    Example:
        >>> fallback = auth.create_tool_fallback()
        >>> fallback.add_tool("google_search", google_search)
        >>> fallback.add_tool("bing_search", bing_search)
        >>> result = await fallback.invoke(query="test")
    """
    tool_fallback = ToolFallback()
    return tool_fallback
2091
+
2092
async def execute_with_retry(
    self,
    func: Any,
    *args: Any,
    policy: RetryPolicy | None = None,
    **kwargs: Any,
) -> Any:
    """
    Await ``func`` under ``retry_async`` using the chosen retry policy.

    Note that ``policy`` is consumed here and is therefore never forwarded
    to ``func`` as a keyword argument.

    Args:
        func: Async function to execute.
        *args: Positional arguments forwarded to ``func``.
        policy: Policy to use. A falsy value falls back to the stored policy.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value produced by ``retry_async``.

    Raises:
        Exception: Whatever ``retry_async`` propagates once retries are exhausted.

    Example:
        >>> result = await auth.execute_with_retry(
        ...     call_llm_api,
        ...     prompt="Hello",
        ... )
    """
    chosen_policy = policy if policy else self._retry_policy
    outcome = await retry_async(func, *args, policy=chosen_policy, **kwargs)
    return outcome
2122
+
2123
+ # ==================== Streaming Methods ====================
2124
+
2125
def create_tool_call_detector(
    self,
    provider: str = "auto",
) -> StreamingToolCallDetector:
    """
    Construct a ``StreamingToolCallDetector`` for the given provider.

    Args:
        provider: LLM provider ("openai", "anthropic", "google", or "auto").

    Returns:
        A new StreamingToolCallDetector.

    Example:
        >>> detector = auth.create_tool_call_detector()
        >>> async for chunk in llm_stream:
        ...     events = detector.process_chunk(chunk)
        ...     for event in events:
        ...         if event.type == StreamEventType.TOOL_CALL_END:
        ...             tool_call = event.tool_call
        ...             auth.authorize_tool_call(user, tool_call)
    """
    detector = StreamingToolCallDetector(provider=provider)
    return detector
2148
+
2149
def create_stream_transformer(self) -> StreamTransformer:
    """
    Construct a ``StreamTransformer``, wired to the output guard when one is set.

    If ``self._output_guard`` is configured, a filter is added. The filter
    drops content the guard blocks (returns ``None``), returns the guard's
    redaction for content it flags for sanitizing, and passes everything
    else through unchanged.

    Returns:
        The new StreamTransformer.

    Example:
        >>> transformer = auth.create_stream_transformer()
        >>> transformer.add_filter(my_filter)
        >>> async for chunk in transformer.transform(stream):
        ...     yield chunk
    """
    transformer = StreamTransformer()
    if not self._output_guard:
        return transformer

    def guard_filter(content: str) -> str | None:
        # The guard is read from ``self`` on every call rather than captured
        # here, so the filter follows later reassignments of the attribute.
        verdict = self._output_guard.check(content).action
        if verdict == GuardAction.BLOCK:
            return None
        if verdict == GuardAction.SANITIZE:
            return self._output_guard.redact(content)
        return content

    transformer.add_filter(guard_filter)
    return transformer
2177
+
2178
def create_guarded_stream(
    self,
    stream: Any,  # AsyncIterator[str]
) -> Any:  # AsyncIterator[str]
    """
    Wrap ``stream`` with the configured output guard.

    Args:
        stream: Source async iterator of strings.

    Returns:
        The module-level ``create_guarded_stream`` wrapper when an output
        guard is configured; otherwise ``stream`` itself, unchanged.

    Example:
        >>> async for chunk in auth.create_guarded_stream(llm_stream):
        ...     ws.send(chunk)  # Filtered content
    """
    guard = self._output_guard
    if not guard:
        return stream
    return create_guarded_stream(stream, guard)
2198
+
2199
def create_authorized_stream(
    self,
    stream: Any,  # AsyncIterator[Any]
    user: UserContext,
    detector: StreamingToolCallDetector | None = None,
) -> Any:  # AsyncIterator[StreamEvent]
    """
    Wrap a raw LLM stream so that detected tool calls are authorized.

    The authorizer given to ``create_authorization_stream`` answers whether
    ``user`` may ``"execute"`` the detected tool (looked up by its name)
    via ``self.can``.

    Args:
        stream: Raw LLM streaming chunks.
        user: User context for authorization.
        detector: Optional detector instance, passed through unchanged.

    Returns:
        The iterator produced by ``create_authorization_stream``.

    Example:
        >>> async for event in auth.create_authorized_stream(llm_stream, user):
        ...     if event.type == StreamEventType.TOOL_CALL_END:
        ...         # Tool call is authorized
        ...         result = execute_tool(event.tool_call)
        ...     elif event.type == StreamEventType.ERROR:
        ...         # Tool call was not authorized
        ...         handle_unauthorized(event.error)
    """

    def may_execute(detected: DetectedToolCall) -> bool:
        return self.can(user, "execute", detected.name)

    authorized = create_authorization_stream(stream, may_execute, detector)
    return authorized
2233
+
2234
async def process_stream_with_authorization(
    self,
    stream: Any,  # AsyncIterator[Any]
    user: UserContext,
    provider: str = "auto",
    on_text: Any | None = None,  # Callable[[str], None] | Callable[[str], Awaitable[None]]
    # Callable[[DetectedToolCall], None] | Callable[[DetectedToolCall], Awaitable[None]]
    on_tool_call: Any | None = None,
    # Callable[[DetectedToolCall, str], None] |
    # Callable[[DetectedToolCall, str], Awaitable[None]]
    on_unauthorized: Any | None = None,
) -> dict[str, Any]:
    """
    Consume an LLM stream, guarding text and authorizing tool calls.

    Every chunk goes through a tool-call detector created for ``provider``.

    Text events with content are run through the output guard, if one is
    configured. Blocked text is dropped and sanitize verdicts are replaced by
    the guard's redaction. Surviving text is accumulated and passed to
    ``on_text``.

    Completed tool calls are checked with ``self.can(user, "execute", name)``.
    Allowed calls go to ``on_tool_call``. Denied calls go to
    ``on_unauthorized`` together with a reason string.

    Callbacks may be plain functions or coroutine functions. A coroutine
    returned by a callback is awaited.

    Args:
        stream: Raw LLM streaming chunks.
        user: User context for authorization.
        provider: LLM provider for detection.
        on_text: Callback for text chunks.
        on_tool_call: Callback for authorized tool calls.
        on_unauthorized: Callback for unauthorized tool calls.

    Returns:
        Dictionary with:
            - text: Full accumulated (guarded) text
            - tool_calls: List of authorized tool calls
            - unauthorized_calls: List of (tool_call, reason) tuples
            - stats: The detector's ``get_stats()`` result

    Example:
        >>> async def handle_text(text):
        ...     await ws.send(text)
        >>>
        >>> async def handle_tool_call(tool_call):
        ...     result = await execute_tool(tool_call)
        ...     await ws.send(f"Tool result: {result}")
        >>>
        >>> result = await auth.process_stream_with_authorization(
        ...     stream, user,
        ...     on_text=handle_text,
        ...     on_tool_call=handle_tool_call,
        ... )
    """
    import inspect

    detector = self.create_tool_call_detector(provider)
    pieces: list[str] = []
    allowed: list[DetectedToolCall] = []
    denied: list[tuple[DetectedToolCall, str]] = []

    async def _notify(callback: Any, *payload: Any) -> None:
        # Support both sync callbacks and ``async def`` callbacks.
        if callback is None:
            return
        maybe_coro = callback(*payload)
        if inspect.iscoroutine(maybe_coro):
            await maybe_coro

    async def _on_text_event(content: str) -> None:
        guard = self._output_guard
        if guard:
            verdict = guard.check(content).action
            if verdict == GuardAction.BLOCK:
                return
            if verdict == GuardAction.SANITIZE:
                content = guard.redact(content)
        pieces.append(content)
        await _notify(on_text, content)

    async def _on_tool_call_event(call: DetectedToolCall) -> None:
        if not self.can(user, "execute", call.name):
            reason = f"User {user.user_id} not authorized to execute {call.name}"
            denied.append((call, reason))
            await _notify(on_unauthorized, call, reason)
            return
        allowed.append(call)
        await _notify(on_tool_call, call)

    async for chunk in stream:
        for event in detector.process_chunk(chunk):
            if event.type == StreamEventType.TEXT and event.content:
                await _on_text_event(event.content)
            elif event.type == StreamEventType.TOOL_CALL_END and event.tool_call:
                await _on_tool_call_event(event.tool_call)

    return {
        "text": "".join(pieces),
        "tool_calls": allowed,
        "unauthorized_calls": denied,
        "stats": detector.get_stats(),
    }
2329
+
2330
def authorize_detected_tool_call(
    self,
    user: UserContext,
    tool_call: DetectedToolCall,
    context: dict[str, Any] | None = None,
) -> AuthorizationResult:
    """
    Run an ``execute`` authorization check for a tool call detected in a stream.

    The policy context is a shallow copy of ``context`` (the caller's dict is
    not modified) extended with:

    - ``tool_call_id``: ``tool_call.id``
    - ``tool_call_arguments``: ``tool_call.arguments``

    These two keys are written last, so they override same-named keys from
    ``context``.

    Args:
        user: The user context.
        tool_call: The detected tool call from streaming.
        context: Additional context for the policy.

    Returns:
        The AuthorizationResult returned by ``self.check``. Callers must
        inspect ``result.allowed``; this method does not itself raise on a
        denial. Any exception raised by ``check`` propagates unchanged.

    Example:
        >>> for event in events:
        ...     if event.type == StreamEventType.TOOL_CALL_END:
        ...         result = auth.authorize_detected_tool_call(user, event.tool_call)
        ...         if result.allowed:
        ...             execute_tool(event.tool_call)
    """
    # Arguments are nested under one key rather than flattened, so they
    # cannot collide with caller-supplied context keys.
    merged_context = dict(context or {})
    merged_context["tool_call_id"] = tool_call.id
    merged_context["tool_call_arguments"] = tool_call.arguments

    return self.check(user, "execute", tool_call.name, merged_context)
2363
+
2364
+ # ==================== Tool Registry Methods ====================
2365
+
2366
def get_tool_registry(self) -> ToolRegistry:
    """
    Expose the tool registry backing this instance.

    Returns:
        The ToolRegistry currently in use.
    """
    registry = self._tool_registry
    return registry
2374
+
2375
def set_tool_registry(self, registry: ToolRegistry | None) -> None:
    """
    Swap in a different tool registry.

    Args:
        registry: The registry to install. ``None`` installs a brand-new,
            default-constructed ``ToolRegistry``.
    """
    self._tool_registry = registry if registry is not None else ToolRegistry()
2386
+
2387
def register_tool(
    self,
    tool_def: ToolDefinition,
) -> None:
    """
    Add a tool definition to this instance's registry.

    Args:
        tool_def: The tool definition to register.

    Example:
        >>> auth.register_tool(ToolDefinition(
        ...     name="search_web",
        ...     description="Search the web",
        ...     parameters={"type": "object", "properties": {...}},
        ...     category=ToolCategory.SEARCH,
        ... ))
    """
    registry = self._tool_registry
    registry.register(tool_def)
2406
+
2407
def unregister_tool(self, name: str) -> bool:
    """
    Remove a tool from the registry by name.

    Args:
        name: The tool name to remove.

    Returns:
        The registry's result; ``True`` when the tool existed and was removed.
    """
    registry = self._tool_registry
    return registry.unregister(name)
2418
+
2419
def get_tool(self, name: str) -> ToolDefinition | None:
    """
    Look up a tool definition by name.

    Args:
        name: The tool name.

    Returns:
        The registry's ToolDefinition, or ``None`` when it has none by that name.
    """
    registry = self._tool_registry
    return registry.get(name)
2430
+
2431
def list_tools(
    self,
    category: ToolCategory | None = None,
    max_risk_level: RiskLevel | None = None,
    enabled_only: bool = True,
) -> list[ToolDefinition]:
    """
    Return registered tools, narrowed by the optional filters.

    Filters are applied in sequence to ``list_all()``: category equality,
    then risk level, then enabled state. The risk filter compares
    ``risk_level.value`` numerically-by-operator (``<=``); it presumably
    relies on RiskLevel values being ordered (verify against RiskLevel).
    When no filter applies, the registry's list is returned as-is.

    Args:
        category: Keep only tools in this category.
        max_risk_level: Keep only tools whose risk value is <= this one's.
        enabled_only: Keep only enabled tools.

    Returns:
        The matching tool definitions, in registry order.

    Example:
        >>> tools = auth.list_tools(category=ToolCategory.SEARCH)
        >>> for tool in tools:
        ...     print(f"{tool.name}: {tool.description}")
    """
    selected = self._tool_registry.list_all()

    predicates = []
    if category is not None:
        predicates.append(lambda t: t.category == category)
    if max_risk_level is not None:
        predicates.append(lambda t: t.risk_level.value <= max_risk_level.value)
    if enabled_only:
        predicates.append(lambda t: t.enabled)

    for keep in predicates:
        selected = [t for t in selected if keep(t)]
    return selected
2465
+
2466
def export_tools(
    self,
    format: str = "openai",
    category: ToolCategory | None = None,
    max_risk_level: RiskLevel | None = None,
) -> list[dict[str, Any]]:
    """
    Export tools to an LLM provider's tool-schema format.

    Tool selection:
      - ``category`` given: ``list_by_category(category)``, further narrowed
        to tools also returned by ``list_by_risk_level(max_risk_level)``
        when ``max_risk_level`` is given too. Previously the risk filter was
        silently dropped in that case.
      - only ``max_risk_level`` given: ``list_by_risk_level(max_risk_level)``.
      - neither: ``list_enabled()``.
    Tools are matched between registry listings by ``name``.

    Args:
        format: Target format: "openai", "anthropic" or "gemini". Any other
            value falls back to each tool's ``to_dict()``.
        category: Filter by category.
        max_risk_level: Filter by maximum risk level.

    Returns:
        List of tools in the requested format.

    Example:
        >>> tools = auth.export_tools(format="openai")
        >>> response = client.chat.completions.create(
        ...     model="gpt-4o",
        ...     messages=[...],
        ...     tools=tools,
        ... )
    """
    registry = self._tool_registry
    if category is not None:
        tools = registry.list_by_category(category)
        if max_risk_level is not None:
            # Honour both filters instead of letting ``category`` win silently.
            within_risk = {t.name for t in registry.list_by_risk_level(max_risk_level)}
            tools = [t for t in tools if t.name in within_risk]
    elif max_risk_level is not None:
        tools = registry.list_by_risk_level(max_risk_level)
    else:
        tools = registry.list_enabled()

    if format == "openai":
        return [t.to_openai_format() for t in tools]
    if format == "anthropic":
        return [t.to_anthropic_format() for t in tools]
    if format == "gemini":
        return [t.to_gemini_format() for t in tools]
    # Unknown format: fall back to the generic dict representation.
    return [t.to_dict() for t in tools]
2508
+
2509
def execute_tool(
    self,
    name: str,
    user: UserContext,
    authorize: bool = True,
    **kwargs: Any,
) -> ToolExecutionResult:
    """
    Execute a registered tool synchronously, optionally authorizing first.

    When ``authorize`` is true, the ``execute`` permission on ``name`` is
    checked for every call. This includes names that are not (or not yet)
    registered, which previously bypassed the check entirely. A warning is
    logged for registered tools flagged ``requires_approval``; no approval
    workflow is run here.

    Args:
        name: The tool name.
        user: The user context for authorization.
        authorize: Whether to check authorization before executing.
        **kwargs: Arguments forwarded to the tool. A copy is also used as
            the policy context.

    Returns:
        The ToolExecutionResult produced by the registry's ``execute``.

    Raises:
        AuthorizationError: If ``authorize`` is true and the check denies.

    Example:
        >>> result = auth.execute_tool(
        ...     "search_web",
        ...     user,
        ...     query="python async",
        ... )
        >>> if result.success:
        ...     print(result.result)
    """
    if authorize:
        # Give the policy a copy so nothing it does to its context can leak
        # into the arguments the tool actually receives.
        auth_result = self.check(user, "execute", name, dict(kwargs))
        if not auth_result.allowed:
            raise AuthorizationError(
                user=user.user_id,
                action="execute",
                resource=name,
                reason=auth_result.reason,
            )

        tool_def = self._tool_registry.get(name)
        if tool_def and tool_def.requires_approval:
            # No approval workflow exists yet; make that visible in the logs.
            logger.warning(
                "Tool %s requires approval but automatic approval is not implemented",
                name,
            )

    return self._tool_registry.execute(name, **kwargs)
2561
+
2562
async def execute_tool_async(
    self,
    name: str,
    user: UserContext,
    authorize: bool = True,
    **kwargs: Any,
) -> ToolExecutionResult:
    """
    Execute a registered tool asynchronously, optionally authorizing first.

    When ``authorize`` is true, the ``execute`` permission on ``name`` is
    checked for every call. This includes names that are not (or not yet)
    registered, which previously bypassed the check entirely. A warning is
    logged for registered tools flagged ``requires_approval``; no approval
    workflow is run here.

    Args:
        name: The tool name.
        user: The user context for authorization.
        authorize: Whether to check authorization before executing.
        **kwargs: Arguments forwarded to the tool. A copy is also used as
            the policy context.

    Returns:
        The ToolExecutionResult produced by the registry's ``execute_async``.

    Raises:
        AuthorizationError: If ``authorize`` is true and the check denies.

    Example:
        >>> result = await auth.execute_tool_async(
        ...     "search_web",
        ...     user,
        ...     query="python async",
        ... )
        >>> if result.success:
        ...     print(result.result)
    """
    if authorize:
        # Give the policy a copy so nothing it does to its context can leak
        # into the arguments the tool actually receives.
        auth_result = self.check(user, "execute", name, dict(kwargs))
        if not auth_result.allowed:
            raise AuthorizationError(
                user=user.user_id,
                action="execute",
                resource=name,
                reason=auth_result.reason,
            )

        tool_def = self._tool_registry.get(name)
        if tool_def and tool_def.requires_approval:
            # No approval workflow exists yet; make that visible in the logs.
            logger.warning(
                "Tool %s requires approval but automatic approval is not implemented",
                name,
            )

    return await self._tool_registry.execute_async(name, **kwargs)
2612
+
2613
def tool(
    self,
    name: str | None = None,
    description: str | None = None,
    category: ToolCategory = ToolCategory.CUSTOM,
    risk_level: RiskLevel = RiskLevel.LOW,
    requires_approval: bool = False,
    timeout: float | None = None,
    enabled: bool = True,
    **metadata: Any,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Decorator factory that registers a function as a tool on this instance.

    Thin wrapper over the module-level ``tool`` decorator that pins
    ``registry`` to this instance's tool registry. All other options are
    forwarded unchanged.

    Args:
        name: Tool name (defaults to function name).
        description: Tool description (defaults to docstring).
        category: Tool category for organization.
        risk_level: Risk level for authorization decisions.
        requires_approval: Whether tool requires explicit approval.
        timeout: Execution timeout in seconds.
        enabled: Whether tool is enabled by default.
        **metadata: Additional metadata.

    Returns:
        The decorator produced by the module-level ``tool``.

    Example:
        >>> @auth.tool(
        ...     name="search_web",
        ...     category=ToolCategory.SEARCH,
        ... )
        ... def search_web(query: str, max_results: int = 10) -> list[dict]:
        ...     return perform_search(query, max_results)
    """
    options = {
        "name": name,
        "description": description,
        "category": category,
        "risk_level": risk_level,
        "requires_approval": requires_approval,
        "timeout": timeout,
        "registry": self._tool_registry,
        "enabled": enabled,
    }
    return tool(**options, **metadata)
2661
+
2662
def enable_tool(self, name: str) -> bool:
    """
    Turn a registered tool on.

    Args:
        name: The tool name.

    Returns:
        The registry's result; ``True`` when the tool was found and enabled.
    """
    registry = self._tool_registry
    return registry.enable(name)
2673
+
2674
def disable_tool(self, name: str) -> bool:
    """
    Turn a registered tool off.

    Args:
        name: The tool name.

    Returns:
        The registry's result; ``True`` when the tool was found and disabled.
    """
    registry = self._tool_registry
    return registry.disable(name)
2685
+
2686
def get_tool_stats(self) -> dict[str, Any]:
    """
    Summarize the registered tools.

    Returns:
        A dict with ``total``, ``enabled``, ``disabled`` and
        ``requires_approval`` counts, plus ``by_category`` and
        ``by_risk_level`` maps keyed by the enum member ``.name`` (keys in
        first-seen order).

    Example:
        >>> stats = auth.get_tool_stats()
        >>> print(f"Total tools: {stats['total']}")
        >>> print(f"Enabled: {stats['enabled']}")
    """
    all_tools = self._tool_registry.list_all()

    category_counts: dict[str, int] = {}
    risk_counts: dict[str, int] = {}
    enabled_count = 0
    disabled_count = 0
    approval_count = 0

    # Single pass accumulating every tally at once.
    for entry in all_tools:
        category_key = entry.category.name
        category_counts[category_key] = category_counts.get(category_key, 0) + 1
        risk_key = entry.risk_level.name
        risk_counts[risk_key] = risk_counts.get(risk_key, 0) + 1
        if entry.enabled:
            enabled_count += 1
        else:
            disabled_count += 1
        if entry.requires_approval:
            approval_count += 1

    return {
        "total": len(all_tools),
        "enabled": enabled_count,
        "disabled": disabled_count,
        "requires_approval": approval_count,
        "by_category": category_counts,
        "by_risk_level": risk_counts,
    }
2720
+
2721
+ # ==================== Provider Adapter Methods ====================
2722
+
2723
def get_provider_adapter(
    self,
    provider: str | Provider | None = None,
    response: Any = None,
) -> ProviderAdapter:
    """
    Resolve a provider adapter through ``get_adapter``.

    Either name the provider explicitly or pass a response for
    ``get_adapter`` to detect it from.

    Args:
        provider: Provider name or enum.
        response: Optional response for auto-detection.

    Returns:
        The ProviderAdapter chosen by ``get_adapter``.

    Example:
        >>> adapter = auth.get_provider_adapter("openai")
        >>> tools = adapter.format_tools(auth.list_tools())
    """
    adapter = get_adapter(provider=provider, response=response)
    return adapter
2745
+
2746
def extract_tool_calls_from_response(
    self,
    response: Any,
    provider: str | Provider | None = None,
) -> list[UnifiedToolCall]:
    """
    Pull the tool calls out of an LLM response.

    Args:
        response: LLM response object.
        provider: Optional provider hint; when ``None`` the adapter is
            resolved from ``response``.

    Returns:
        The adapter's list of unified tool calls.

    Example:
        >>> tool_calls = auth.extract_tool_calls_from_response(llm_response)
        >>> for call in tool_calls:
        ...     result = auth.authorize_and_execute(user, call)
    """
    return get_adapter(provider=provider, response=response).extract_tool_calls(response)
2768
+
2769
def extract_unified_response(
    self,
    response: Any,
    provider: str | Provider | None = None,
) -> UnifiedResponse:
    """
    Convert an LLM response into the provider-neutral ``UnifiedResponse``.

    Args:
        response: LLM response object.
        provider: Optional provider hint; when ``None`` the adapter is
            resolved from ``response``.

    Returns:
        The adapter's UnifiedResponse.

    Example:
        >>> unified = auth.extract_unified_response(llm_response)
        >>> print(f"Content: {unified.content}")
        >>> print(f"Tool calls: {len(unified.tool_calls)}")
    """
    return get_adapter(provider=provider, response=response).extract_response(response)
2791
+
2792
def authorize_tool_calls(
    self,
    user: UserContext,
    tool_calls: list[UnifiedToolCall],
    context: dict[str, Any] | None = None,
) -> list[tuple[UnifiedToolCall, AuthorizationResult]]:
    """
    Run an ``execute`` authorization check for each tool call.

    Each call's policy context is built from three layers, lowest precedence
    first:

    1. the call's arguments (``None`` treated as empty),
    2. the caller-supplied ``context``,
    3. ``tool_call_id``.

    Tool-call arguments come from model output. Previously they were merged
    last, so a model could overwrite caller-supplied context keys or
    ``tool_call_id``. Trusted keys now win on collisions.

    Args:
        user: The user context.
        tool_calls: Tool calls to authorize.
        context: Additional (trusted) context for policies. Not modified.

    Returns:
        List of (tool_call, authorization_result) tuples, in input order.

    Example:
        >>> tool_calls = auth.extract_tool_calls_from_response(response)
        >>> results = auth.authorize_tool_calls(user, tool_calls)
        >>> for call, result in results:
        ...     if result.allowed:
        ...         execute_tool(call)
    """
    trusted_context = dict(context or {})
    results: list[tuple[UnifiedToolCall, AuthorizationResult]] = []
    for call in tool_calls:
        merged_context = {
            **(call.arguments or {}),
            **trusted_context,
            "tool_call_id": call.id,
        }
        auth_result = self.check(user, "execute", call.name, merged_context)
        results.append((call, auth_result))
    return results
2826
+
2827
def authorize_and_execute_tool_calls(
    self,
    user: UserContext,
    tool_calls: list[UnifiedToolCall],
    context: dict[str, Any] | None = None,
) -> list[tuple[UnifiedToolCall, ToolExecutionResult | AuthorizationResult]]:
    """
    Authorize each tool call and execute the allowed, registered ones.

    For each tool call:
      1. Check ``execute`` authorization. The policy context layers the
         call's arguments under the caller's ``context`` and
         ``tool_call_id``, so model-supplied arguments cannot override
         trusted keys.
      2. If denied, record the AuthorizationResult.
      3. If allowed and the tool is registered, run it through the registry
         with the call's arguments and record the ToolExecutionResult.
      4. If allowed but the tool is not registered, record the (allowed)
         AuthorizationResult; the caller must execute it.

    Args:
        user: The user context.
        tool_calls: Tool calls to process.
        context: Additional (trusted) context for policies. Not modified.

    Returns:
        List of (tool_call, result) tuples in input order. ``result`` is a
        ToolExecutionResult when the tool was executed. Otherwise it is an
        AuthorizationResult: check ``.allowed`` to tell "denied" apart from
        "allowed but not registered".

    Example:
        >>> tool_calls = auth.extract_tool_calls_from_response(response)
        >>> results = auth.authorize_and_execute_tool_calls(user, tool_calls)
        >>> for call, result in results:
        ...     if isinstance(result, ToolExecutionResult):
        ...         print(f"{call.name}: {result.result}")
        ...     elif not result.allowed:
        ...         print(f"{call.name}: DENIED - {result.reason}")
    """
    trusted_context = dict(context or {})
    results: list[tuple[UnifiedToolCall, ToolExecutionResult | AuthorizationResult]] = []
    for call in tool_calls:
        # ``None`` arguments are treated as "no arguments" rather than crashing.
        arguments = call.arguments or {}
        merged_context = {**arguments, **trusted_context, "tool_call_id": call.id}

        auth_result = self.check(user, "execute", call.name, merged_context)
        if not auth_result.allowed:
            results.append((call, auth_result))
            continue

        tool_def = self._tool_registry.get(call.name)
        if tool_def:
            exec_result = self._tool_registry.execute(call.name, **arguments)
            results.append((call, exec_result))
        else:
            # Allowed but unknown to the registry: hand the decision back so
            # the caller can perform the execution itself.
            results.append((call, auth_result))

    return results
2882
+
2883
def format_tool_results(
    self,
    results: list[tuple[UnifiedToolCall, Any, bool]],
    provider: str | Provider,
) -> list[Any]:
    """
    Convert tool results into a provider's tool-result message format.

    Args:
        results: (tool_call, result, is_error) tuples.
        provider: Target provider, resolved via ``get_adapter``.

    Returns:
        One formatted message per input tuple, in the same order.

    Example:
        >>> results = [
        ...     (call1, {"temp": 72}, False),
        ...     (call2, "Error message", True),
        ... ]
        >>> formatted = auth.format_tool_results(results, "openai")
        >>> messages.extend(formatted)
    """
    adapter = get_adapter(provider=provider)
    formatted: list[Any] = []
    for tool_call, payload, failed in results:
        formatted.append(adapter.format_tool_result(tool_call, payload, failed))
    return formatted
2911
+
2912
def export_tools_for_provider(
    self,
    provider: str | Provider,
    category: ToolCategory | None = None,
    max_risk_level: RiskLevel | None = None,
) -> list[dict[str, Any]]:
    """
    Export tools formatted by a provider adapter.

    Tool selection:
      - ``category`` given: ``list_by_category(category)``, further narrowed
        to tools also returned by ``list_by_risk_level(max_risk_level)``
        when ``max_risk_level`` is given too. Previously the risk filter was
        silently ignored in that case.
      - only ``max_risk_level`` given: ``list_by_risk_level(max_risk_level)``.
      - neither: ``list_enabled()``.
    Tools are matched between registry listings by ``name``.

    Args:
        provider: Target provider.
        category: Optional category filter.
        max_risk_level: Optional risk level filter.

    Returns:
        The adapter's ``format_tools`` output for the selected tools.

    Example:
        >>> openai_tools = auth.export_tools_for_provider("openai")
        >>> response = client.chat.completions.create(
        ...     model="gpt-4o",
        ...     messages=[...],
        ...     tools=openai_tools,
        ... )
    """
    registry = self._tool_registry
    if category is not None:
        tools = registry.list_by_category(category)
        if max_risk_level is not None:
            # Honour both filters instead of letting ``category`` win silently.
            within_risk = {t.name for t in registry.list_by_risk_level(max_risk_level)}
            tools = [t for t in tools if t.name in within_risk]
    elif max_risk_level is not None:
        tools = registry.list_by_risk_level(max_risk_level)
    else:
        tools = registry.list_enabled()

    adapter = get_adapter(provider=provider)
    return adapter.format_tools(tools)
2947
+
2948
async def process_response_with_authorization(
    self,
    response: Any,
    user: UserContext,
    provider: str | Provider | None = None,
    execute_tools: bool = True,
) -> dict[str, Any]:
    """
    Process an LLM response with authorization and optional execution.

    High-level method that:
    1. Extracts tool calls from response
    2. Authorizes each tool call
    3. Optionally executes authorized tools
    4. Returns comprehensive results

    Args:
        response: LLM response object.
        user: User context for authorization.
        provider: Optional provider hint (the adapter is also resolved
            from ``response``).
        execute_tools: Whether to execute authorized tools.

    Returns:
        Dictionary with:
        - unified_response: The unified response
        - authorized_calls: List of authorized tool calls
        - denied_calls: List of (call, reason) for denied calls
        - execution_results: List of (call, result) for authorized calls
          whose tool is registered, when execute_tools=True. Authorized
          calls to unregistered tools are not executed and appear only in
          authorized_calls.

    Example:
        >>> result = await auth.process_response_with_authorization(
        ...     llm_response,
        ...     user,
        ... )
        >>> for call in result["authorized_calls"]:
        ...     print(f"Authorized: {call.name}")
        >>> for call, reason in result["denied_calls"]:
        ...     print(f"Denied: {call.name} - {reason}")
    """
    # Normalize the provider-specific response.
    adapter = get_adapter(provider=provider, response=response)
    unified_response = adapter.extract_response(response)

    authorized_calls: list[UnifiedToolCall] = []
    denied_calls: list[tuple[UnifiedToolCall, Any]] = []
    execution_results: list[tuple[UnifiedToolCall, Any]] = []

    for call in unified_response.tool_calls:
        # The real call ID is written after the model-supplied arguments,
        # so an argument named "tool_call_id" cannot overwrite it in the
        # authorization context.
        context = {**call.arguments, "tool_call_id": call.id}
        auth_result = self.check(user, "execute", call.name, context)

        if not auth_result.allowed:
            denied_calls.append((call, auth_result.reason))
            continue

        authorized_calls.append(call)

        # Execute only when requested and the tool is registered.
        if execute_tools and self._tool_registry.get(call.name):
            exec_result = await self._tool_registry.execute_async(
                call.name, **call.arguments
            )
            execution_results.append((call, exec_result))

    return {
        "unified_response": unified_response,
        "authorized_calls": authorized_calls,
        "denied_calls": denied_calls,
        "execution_results": execution_results,
    }
3024
+
3025
+ # ==================== Utility Methods ====================
3026
+
3027
def get_audit_events(self) -> list[AuditEventV2]:
    """
    Return the recorded audit events, if the logger keeps them in memory.

    Only an ``InMemoryAuditLogger`` exposes its events. For any other
    logger type an empty list is returned.

    Returns:
        The in-memory logger's ``events`` attribute, or ``[]``. It is
        returned as-is; whether it is a copy depends on the logger.
    """
    audit_logger = self._audit_logger
    if not isinstance(audit_logger, InMemoryAuditLogger):
        return []
    return audit_logger.events
3037
+
3038
def verify_audit_chain(self) -> Any:
    """
    Check the integrity of the audit log.

    Delegates to the configured audit logger's ``verify()`` method.

    Returns:
        Whatever ``verify()`` returns (a ChainVerificationResult).
    """
    audit_logger = self._audit_logger
    verification = audit_logger.verify()
    return verification
3046
+
3047
def close(self) -> None:
    """
    Close the Proxilion instance and flush audit logs.

    The audit logger's ``close()`` is called only if the logger has a
    ``close`` attribute. Otherwise this is a no-op.
    """
    audit_logger = self._audit_logger
    if not hasattr(audit_logger, "close"):
        return
    audit_logger.close()
3051
+
3052
def __enter__(self) -> Proxilion:
    """Enter a ``with`` block; the instance itself is bound as the target."""
    instance = self
    return instance
3055
+
3056
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """
    Leave a ``with`` block by calling ``close()``.

    Returns None, so any exception raised inside the block propagates.
    """
    self.close()
    return None