proxilion 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proxilion/__init__.py +136 -0
- proxilion/audit/__init__.py +133 -0
- proxilion/audit/base_exporters.py +527 -0
- proxilion/audit/compliance/__init__.py +130 -0
- proxilion/audit/compliance/base.py +457 -0
- proxilion/audit/compliance/eu_ai_act.py +603 -0
- proxilion/audit/compliance/iso27001.py +544 -0
- proxilion/audit/compliance/soc2.py +491 -0
- proxilion/audit/events.py +493 -0
- proxilion/audit/explainability.py +1173 -0
- proxilion/audit/exporters/__init__.py +58 -0
- proxilion/audit/exporters/aws_s3.py +636 -0
- proxilion/audit/exporters/azure_storage.py +608 -0
- proxilion/audit/exporters/cloud_base.py +468 -0
- proxilion/audit/exporters/gcp_storage.py +570 -0
- proxilion/audit/exporters/multi_exporter.py +498 -0
- proxilion/audit/hash_chain.py +652 -0
- proxilion/audit/logger.py +543 -0
- proxilion/caching/__init__.py +49 -0
- proxilion/caching/tool_cache.py +633 -0
- proxilion/context/__init__.py +73 -0
- proxilion/context/context_window.py +556 -0
- proxilion/context/message_history.py +505 -0
- proxilion/context/session.py +735 -0
- proxilion/contrib/__init__.py +51 -0
- proxilion/contrib/anthropic.py +609 -0
- proxilion/contrib/google.py +1012 -0
- proxilion/contrib/langchain.py +641 -0
- proxilion/contrib/mcp.py +893 -0
- proxilion/contrib/openai.py +646 -0
- proxilion/core.py +3058 -0
- proxilion/decorators.py +966 -0
- proxilion/engines/__init__.py +287 -0
- proxilion/engines/base.py +266 -0
- proxilion/engines/casbin_engine.py +412 -0
- proxilion/engines/opa_engine.py +493 -0
- proxilion/engines/simple.py +437 -0
- proxilion/exceptions.py +887 -0
- proxilion/guards/__init__.py +54 -0
- proxilion/guards/input_guard.py +522 -0
- proxilion/guards/output_guard.py +634 -0
- proxilion/observability/__init__.py +198 -0
- proxilion/observability/cost_tracker.py +866 -0
- proxilion/observability/hooks.py +683 -0
- proxilion/observability/metrics.py +798 -0
- proxilion/observability/session_cost_tracker.py +1063 -0
- proxilion/policies/__init__.py +67 -0
- proxilion/policies/base.py +304 -0
- proxilion/policies/builtin.py +486 -0
- proxilion/policies/registry.py +376 -0
- proxilion/providers/__init__.py +201 -0
- proxilion/providers/adapter.py +468 -0
- proxilion/providers/anthropic_adapter.py +330 -0
- proxilion/providers/gemini_adapter.py +391 -0
- proxilion/providers/openai_adapter.py +294 -0
- proxilion/py.typed +0 -0
- proxilion/resilience/__init__.py +81 -0
- proxilion/resilience/degradation.py +615 -0
- proxilion/resilience/fallback.py +555 -0
- proxilion/resilience/retry.py +554 -0
- proxilion/scheduling/__init__.py +57 -0
- proxilion/scheduling/priority_queue.py +419 -0
- proxilion/scheduling/scheduler.py +459 -0
- proxilion/security/__init__.py +244 -0
- proxilion/security/agent_trust.py +968 -0
- proxilion/security/behavioral_drift.py +794 -0
- proxilion/security/cascade_protection.py +869 -0
- proxilion/security/circuit_breaker.py +428 -0
- proxilion/security/cost_limiter.py +690 -0
- proxilion/security/idor_protection.py +460 -0
- proxilion/security/intent_capsule.py +849 -0
- proxilion/security/intent_validator.py +495 -0
- proxilion/security/memory_integrity.py +767 -0
- proxilion/security/rate_limiter.py +509 -0
- proxilion/security/scope_enforcer.py +680 -0
- proxilion/security/sequence_validator.py +636 -0
- proxilion/security/trust_boundaries.py +784 -0
- proxilion/streaming/__init__.py +70 -0
- proxilion/streaming/detector.py +761 -0
- proxilion/streaming/transformer.py +674 -0
- proxilion/timeouts/__init__.py +55 -0
- proxilion/timeouts/decorators.py +477 -0
- proxilion/timeouts/manager.py +545 -0
- proxilion/tools/__init__.py +69 -0
- proxilion/tools/decorators.py +493 -0
- proxilion/tools/registry.py +732 -0
- proxilion/types.py +339 -0
- proxilion/validation/__init__.py +93 -0
- proxilion/validation/pydantic_schema.py +351 -0
- proxilion/validation/schema.py +651 -0
- proxilion-0.0.1.dist-info/METADATA +872 -0
- proxilion-0.0.1.dist-info/RECORD +94 -0
- proxilion-0.0.1.dist-info/WHEEL +4 -0
- proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
proxilion/decorators.py
ADDED
|
@@ -0,0 +1,966 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Decorators for Proxilion authorization.
|
|
3
|
+
|
|
4
|
+
This module provides standalone decorators that can be used
|
|
5
|
+
independently of the main Proxilion class for more flexibility.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import functools
|
|
12
|
+
import inspect
|
|
13
|
+
import logging
|
|
14
|
+
from abc import ABC, abstractmethod
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
from typing import Any, ParamSpec, TypeVar
|
|
17
|
+
|
|
18
|
+
from proxilion.exceptions import AuthorizationError, SequenceViolationError
|
|
19
|
+
from proxilion.types import UserContext
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
P = ParamSpec("P")
|
|
24
|
+
T = TypeVar("T")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ApprovalStrategy(ABC):
    """
    Base interface for deciding whether a high-risk operation may proceed.

    Concrete strategies provide both a blocking and a coroutine variant of
    the approval request, so they can guard sync and async callables alike.
    """

    @abstractmethod
    def request_approval(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        """
        Ask whether ``user`` may perform ``action`` on ``resource``.

        Args:
            user: Identity of the caller asking to run the action.
            action: Name of the action being attempted.
            resource: Name of the resource the action targets.
            context: Extra information describing the call.

        Returns:
            True when the request is approved, False when it is not.
        """
        ...

    @abstractmethod
    async def request_approval_async(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        """Coroutine counterpart of ``request_approval``."""
        ...
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class AlwaysApproveStrategy(ApprovalStrategy):
    """
    Approval strategy that grants every request unconditionally.

    Every grant is logged at WARNING level so it stands out.

    WARNING: Only use for testing or development.
    """

    def request_approval(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        # Make each automatic grant visible in the logs before approving.
        message = f"AlwaysApproveStrategy: Auto-approving {action} on {resource} for {user.user_id}"
        logger.warning(message)
        return True

    async def request_approval_async(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        # Reuse the synchronous implementation (honours subclass overrides).
        approved = self.request_approval(user, action, resource, context)
        return approved
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class AlwaysDenyStrategy(ApprovalStrategy):
    """Approval strategy that rejects every request, logging each rejection at INFO level."""

    def request_approval(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        message = f"AlwaysDenyStrategy: Denying {action} on {resource} for {user.user_id}"
        logger.info(message)
        return False

    async def request_approval_async(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        # Reuse the synchronous implementation (honours subclass overrides).
        denied = self.request_approval(user, action, resource, context)
        return denied
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class CallbackApprovalStrategy(ApprovalStrategy):
    """
    Approval strategy that hands the decision to caller-supplied callables.

    The synchronous callback is always required. When an async callback is
    also given, it is awaited by ``request_approval_async`` instead.

    Example:
        >>> def my_approval_callback(user, action, resource, context):
        ...     # Custom approval logic
        ...     return user.has_role("approver")
        >>>
        >>> strategy = CallbackApprovalStrategy(my_approval_callback)
    """

    def __init__(
        self,
        callback: Callable[[UserContext, str, str, dict[str, Any]], bool],
        async_callback: Callable[[UserContext, str, str, dict[str, Any]], Any] | None = None,
    ) -> None:
        """
        Store the decision callbacks.

        Args:
            callback: Called as ``callback(user, action, resource, context)``
                for sync requests, and for async requests when no truthy
                ``async_callback`` was supplied.
            async_callback: Optional callable with the same arguments whose
                result is awaited by ``request_approval_async``.
        """
        self._async_callback = async_callback
        self._callback = callback

    def request_approval(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        decide = self._callback
        return decide(user, action, resource, context)

    async def request_approval_async(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        async_decide = self._async_callback
        if not async_decide:
            # No async callback configured: fall back to the sync one.
            return self._callback(user, action, resource, context)
        return await async_decide(user, action, resource, context)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class QueueApprovalStrategy(ApprovalStrategy):
    """
    Strategy that queues requests for later approval.

    Useful for asynchronous approval workflows where a human reviews pending
    requests. A requesting caller waits until ``approve``/``deny`` is called
    for its request id, or until ``timeout`` seconds pass, in which case the
    request is dropped and denied.

    The blocking ``request_approval`` keeps its thread busy while waiting, so
    decisions necessarily arrive from another thread. All bookkeeping is
    therefore guarded by a lock.

    Example:
        >>> strategy = QueueApprovalStrategy()
        >>>
        >>> # Request gets queued
        >>> @require_approval(strategy=strategy)
        ... async def delete_database(db_name, user=None):
        ...     pass
        >>>
        >>> # Admin reviews queue
        >>> for request in strategy.pending_requests:
        ...     if should_approve(request):
        ...         strategy.approve(request["id"])
    """

    # Seconds to wait between checks for a reviewer decision.
    _POLL_INTERVAL = 0.1

    def __init__(self, timeout: float = 300.0) -> None:
        """
        Initialize the queue strategy.

        Args:
            timeout: Seconds to wait for approval (default 5 minutes).
        """
        import threading

        self.timeout = timeout
        self._pending: dict[str, dict[str, Any]] = {}
        self._approved: set[str] = set()
        self._denied: set[str] = set()
        self._request_counter = 0
        # Protects the counter and the three collections above.
        self._lock = threading.Lock()

    @property
    def pending_requests(self) -> list[dict[str, Any]]:
        """Get a snapshot list of pending approval requests."""
        with self._lock:
            return list(self._pending.values())

    def approve(self, request_id: str) -> None:
        """Approve a pending request (no-op if it is not pending)."""
        with self._lock:
            if request_id in self._pending:
                del self._pending[request_id]
                self._approved.add(request_id)

    def deny(self, request_id: str) -> None:
        """Deny a pending request (no-op if it is not pending)."""
        with self._lock:
            if request_id in self._pending:
                del self._pending[request_id]
                self._denied.add(request_id)

    def _enqueue(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> str:
        """Register a new pending request and return its id."""
        with self._lock:
            self._request_counter += 1
            request_id = f"req_{self._request_counter}"
            self._pending[request_id] = {
                "id": request_id,
                "user_id": user.user_id,
                "action": action,
                "resource": resource,
                "context": context,
            }
        logger.info(f"Approval request queued: {request_id}")
        return request_id

    def _pop_decision(self, request_id: str) -> bool | None:
        """Consume a recorded decision; None if still undecided. Caller holds the lock."""
        if request_id in self._approved:
            self._approved.discard(request_id)
            return True
        if request_id in self._denied:
            self._denied.discard(request_id)
            return False
        return None

    def _poll(self, request_id: str) -> bool | None:
        """Thread-safe wrapper around ``_pop_decision``."""
        with self._lock:
            return self._pop_decision(request_id)

    def _expire(self, request_id: str) -> bool:
        """
        Finish a request whose wait time has elapsed.

        A decision recorded before this runs (e.g. during the final sleep) is
        honoured. Otherwise the request is removed from the queue under the same
        lock, so a later ``approve``/``deny`` becomes a no-op instead of leaving
        a stale id behind.
        """
        with self._lock:
            decision = self._pop_decision(request_id)
            if decision is None:
                self._pending.pop(request_id, None)
        if decision is not None:
            return decision
        logger.warning(f"Approval request timed out: {request_id}")
        return False

    def request_approval(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        """Queue request and wait for approval (blocking)."""
        import time

        request_id = self._enqueue(user, action, resource, context)
        # monotonic() is unaffected by wall-clock adjustments, unlike time().
        deadline = time.monotonic() + self.timeout
        while (remaining := deadline - time.monotonic()) > 0:
            decision = self._poll(request_id)
            if decision is not None:
                return decision
            time.sleep(min(self._POLL_INTERVAL, remaining))
        return self._expire(request_id)

    async def request_approval_async(
        self,
        user: UserContext,
        action: str,
        resource: str,
        context: dict[str, Any],
    ) -> bool:
        """Queue request and wait for approval (async)."""
        import time

        request_id = self._enqueue(user, action, resource, context)
        # Measure real elapsed time; summing nominal sleep lengths drifts
        # because each asyncio.sleep can overrun.
        deadline = time.monotonic() + self.timeout
        while (remaining := deadline - time.monotonic()) > 0:
            decision = self._poll(request_id)
            if decision is not None:
                return decision
            await asyncio.sleep(min(self._POLL_INTERVAL, remaining))
        return self._expire(request_id)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def require_approval(
    strategy: ApprovalStrategy | None = None,
    reason_param: str = "approval_reason",
    user_param: str = "user",
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator that requires approval before executing a function.

    For high-risk operations that need human-in-the-loop approval. The
    decorated call's keyword arguments are given to the strategy as the
    approval context, and the function name is used as the resource.

    Args:
        strategy: The approval strategy to use (default: AlwaysDenyStrategy).
        reason_param: Reserved for passing an approval reason; the decorator
            does not currently read it.
        user_param: Name of the parameter containing the UserContext. The user
            may be supplied by keyword or positionally.

    Returns:
        A decorator function.

    Raises:
        AuthorizationError: At call time, if no user is supplied or the
            strategy does not approve the call.

    Example:
        >>> @require_approval(strategy=QueueApprovalStrategy())
        ... async def delete_all_data(user: UserContext = None):
        ...     # Only runs if approved
        ...     await perform_deletion()
    """
    if strategy is None:
        strategy = AlwaysDenyStrategy()

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        resource = func.__name__
        try:
            signature: inspect.Signature | None = inspect.signature(func)
        except (TypeError, ValueError):
            # Some callables expose no signature; then the user can only be
            # found among the keyword arguments.
            signature = None

        def _resolve_user(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            """Return the argument bound to ``user_param`` (keyword or positional), or None."""
            if user_param in kwargs:
                return kwargs[user_param]
            if signature is None:
                return None
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                # Arguments don't match the signature; treat the user as missing.
                return None
            return bound.arguments.get(user_param)

        def _require_user(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            """Resolve the user or reject the call."""
            user = _resolve_user(args, kwargs)
            if user is None:
                raise AuthorizationError(
                    user="unknown",
                    action="execute",
                    resource=resource,
                    reason="No user context provided for approval",
                )
            return user

        def _ensure_approved(user: Any, approved: bool) -> None:
            """Reject the call when the strategy did not approve it."""
            if not approved:
                raise AuthorizationError(
                    user=user.user_id,
                    action="execute",
                    resource=resource,
                    reason="Approval denied or timed out",
                )

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                user = _require_user(args, kwargs)
                approved = await strategy.request_approval_async(
                    user, "execute", resource, dict(kwargs)
                )
                _ensure_approved(user, approved)
                return await func(*args, **kwargs)

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            user = _require_user(args, kwargs)
            approved = strategy.request_approval(
                user, "execute", resource, dict(kwargs)
            )
            _ensure_approved(user, approved)
            return func(*args, **kwargs)

        return sync_wrapper  # type: ignore

    return decorator
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def authorize_tool_call(
    proxilion: Any,
    action: str = "execute",
    resource: str | None = None,
    user_param: str = "user",
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Build an authorization decorator from an existing Proxilion instance.

    Thin convenience wrapper: it forwards its options unchanged to
    ``proxilion.authorize(...)`` and returns whatever decorator that produces.

    Args:
        proxilion: The Proxilion instance to use for authorization.
        action: The action being performed.
        resource: The resource name (defaults to function name).
        user_param: Parameter name containing UserContext.

    Returns:
        The decorator returned by ``proxilion.authorize``.

    Example:
        >>> from proxilion import Proxilion
        >>> from proxilion.decorators import authorize_tool_call
        >>>
        >>> auth = Proxilion()
        >>>
        >>> @authorize_tool_call(auth, action="execute", resource="search")
        ... async def search(query: str, user: UserContext = None):
        ...     return await perform_search(query)
    """
    options = {
        "action": action,
        "resource": resource,
        "user_param": user_param,
    }
    return proxilion.authorize(**options)
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def rate_limited(
    capacity: int = 100,
    refill_rate: float = 10.0,
    user_param: str = "user",
    key_func: Callable[[dict[str, Any]], str] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Standalone rate limiting decorator.

    Apply rate limiting to a function without using the full Proxilion
    authorization flow. A single TokenBucketRateLimiter is created per call
    to ``rate_limited(...)``. Every function decorated with the returned
    decorator therefore shares that limiter.

    Args:
        capacity: Maximum tokens in bucket.
        refill_rate: Tokens added per second.
        user_param: Parameter name containing UserContext. The user may be
            passed by keyword or positionally. Calls without a user use the
            key ``"anonymous"``.
        key_func: Custom function to extract the rate limit key; it receives a
            copy of the call's keyword arguments.

    Returns:
        A decorator function.

    Raises:
        RateLimitExceeded: At call time, when the limiter rejects the key.

    Example:
        >>> @rate_limited(capacity=10, refill_rate=1.0)
        ... async def expensive_operation(user: UserContext = None):
        ...     return await perform_operation()
    """
    from proxilion.exceptions import RateLimitExceeded
    from proxilion.security.rate_limiter import TokenBucketRateLimiter

    limiter = TokenBucketRateLimiter(capacity=capacity, refill_rate=refill_rate)

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        try:
            signature: inspect.Signature | None = inspect.signature(func)
        except (TypeError, ValueError):
            # No introspectable signature: only keyword arguments can be searched.
            signature = None

        def _resolve_user(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            """Return the argument bound to ``user_param`` (keyword or positional), or None."""
            if user_param in kwargs:
                return kwargs[user_param]
            if signature is None:
                return None
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                return None
            return bound.arguments.get(user_param)

        def _check_limit(args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
            """Consume a token for this call's key or raise RateLimitExceeded."""
            if key_func:
                key = key_func(dict(kwargs))
            else:
                user = _resolve_user(args, kwargs)
                key = user.user_id if user is not None else "anonymous"

            if not limiter.allow_request(key):
                retry_after = limiter.get_retry_after(key)
                raise RateLimitExceeded(
                    limit_type="function",
                    limit_key=key,
                    limit_value=capacity,
                    retry_after=retry_after,
                )

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                _check_limit(args, kwargs)
                return await func(*args, **kwargs)

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            _check_limit(args, kwargs)
            return func(*args, **kwargs)

        return sync_wrapper  # type: ignore

    return decorator
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def circuit_protected(
    failure_threshold: int = 5,
    reset_timeout: float = 30.0,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Guard a function with a circuit breaker to stop cascading failures.

    One CircuitBreaker is created per call to ``circuit_protected(...)``.
    Every invocation of the decorated function is routed through it:
    ``call_async`` is used for coroutine functions and ``call`` otherwise.

    Args:
        failure_threshold: Failures before opening circuit.
        reset_timeout: Seconds before attempting reset.

    Returns:
        A decorator function.

    Example:
        >>> @circuit_protected(failure_threshold=3, reset_timeout=60.0)
        ... async def external_api_call():
        ...     return await call_external_api()
    """
    from proxilion.security.circuit_breaker import CircuitBreaker

    breaker = CircuitBreaker(
        failure_threshold=failure_threshold,
        reset_timeout=reset_timeout,
    )

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        if not inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                return breaker.call(func, *args, **kwargs)

            return sync_wrapper  # type: ignore

        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            return await breaker.call_async(func, *args, **kwargs)

        return async_wrapper  # type: ignore

    return decorator
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
def sequence_validated(
    proxilion: Any,
    tool_name: str | None = None,
    user_param: str = "user",
    record_on_success: bool = True,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator that validates tool call sequence before execution.

    Validates the tool call against the sequence rules configured in the
    Proxilion instance. The check fails closed: any denial raises
    SequenceViolationError and the function is not executed, even when no
    violation details are returned.

    Args:
        proxilion: The Proxilion instance with sequence validator.
        tool_name: Tool name to use (defaults to function name).
        user_param: Parameter name containing UserContext. The user may be
            passed by keyword or positionally.
        record_on_success: Whether to record the call via
            ``proxilion.record_tool_call`` after the function returns normally.

    Returns:
        A decorator function.

    Raises:
        AuthorizationError: At call time, if no user is supplied.
        SequenceViolationError: At call time, if the sequence check denies the call.

    Example:
        >>> from proxilion import Proxilion
        >>> from proxilion.decorators import sequence_validated
        >>>
        >>> auth = Proxilion()
        >>>
        >>> @sequence_validated(auth, tool_name="delete_file")
        ... def delete_file(path: str, user: UserContext = None):
        ...     os.remove(path)
        ...
        >>> # Will fail if confirm_* wasn't called first
        >>> delete_file("/path/to/file", user=user)
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        name = tool_name or func.__name__
        try:
            signature: inspect.Signature | None = inspect.signature(func)
        except (TypeError, ValueError):
            # No introspectable signature: only keyword arguments can be searched.
            signature = None

        def _resolve_user(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            """Return the argument bound to ``user_param`` (keyword or positional), or None."""
            if user_param in kwargs:
                return kwargs[user_param]
            if signature is None:
                return None
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                return None
            return bound.arguments.get(user_param)

        def _validate(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            """Check the sequence rules for this call and return the calling user."""
            user = _resolve_user(args, kwargs)
            if user is None:
                raise AuthorizationError(
                    user="unknown",
                    action="execute",
                    resource=name,
                    reason="No user context provided for sequence validation",
                )

            allowed, violation = proxilion.validate_sequence(name, user)
            if allowed:
                return user
            if violation is None:
                # A denial without details must still block the call.
                raise SequenceViolationError(
                    rule_name="unknown",
                    tool_name=name,
                    required_prior=None,
                    forbidden_prior=None,
                    violation_type=None,
                    consecutive_count=None,
                )
            raise SequenceViolationError(
                rule_name=violation.rule_name,
                tool_name=name,
                required_prior=violation.required_prior,
                forbidden_prior=violation.forbidden_prior,
                violation_type=(
                    violation.violation_type.value if violation.violation_type else None
                ),
                consecutive_count=(
                    violation.consecutive_count if violation.consecutive_count else None
                ),
            )

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                user = _validate(args, kwargs)
                result = await func(*args, **kwargs)
                # Only calls that completed without raising are recorded.
                if record_on_success:
                    proxilion.record_tool_call(name, user)
                return result

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            user = _validate(args, kwargs)
            result = func(*args, **kwargs)
            # Only calls that completed without raising are recorded.
            if record_on_success:
                proxilion.record_tool_call(name, user)
            return result

        return sync_wrapper  # type: ignore

    return decorator
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
def enforce_scope(
    proxilion: Any,
    scope: Any,  # ExecutionScope | str
    user_param: str = "user",
    action: str = "execute",
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator that enforces execution scope on a function.

    Opens a scope context with ``proxilion.enter_scope(scope, user)`` for
    each call and closes it when the call finishes, even if the call raises.
    The context is passed to the function as the ``_scope_context`` keyword
    argument when the function declares that parameter or accepts
    ``**kwargs``. This lets the function forward it to nested tools
    (e.g. ones decorated with ``@scoped_tool``). Functions that cannot
    receive it are called without it.

    Args:
        proxilion: The Proxilion instance with scope enforcer.
        scope: Scope name or ExecutionScope enum.
        user_param: Parameter name containing UserContext. The user may be
            passed by keyword or positionally.
        action: Accepted for API compatibility; not currently used by this
            decorator.

    Returns:
        A decorator function.

    Raises:
        AuthorizationError: At call time, if no user is supplied.

    Example:
        >>> from proxilion import Proxilion
        >>> from proxilion.decorators import enforce_scope
        >>> from proxilion.security.scope_enforcer import ExecutionScope
        >>>
        >>> auth = Proxilion()
        >>>
        >>> @enforce_scope(auth, "read_only")
        ... def handle_user_query(query: str, user: UserContext = None):
        ...     # Any tool calls here must be read-only
        ...     return get_user_data(query)
        ...
        >>> # If this function calls delete_user, it will raise ScopeViolationError
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        resource = func.__name__
        try:
            signature: inspect.Signature | None = inspect.signature(func)
        except (TypeError, ValueError):
            signature = None

        if signature is None:
            # Cannot inspect the function: keep always passing the context.
            inject_ctx = True
        else:
            # Passing an undeclared keyword would make every call fail with
            # TypeError, so only inject when the function can accept it.
            keyword_kinds = (
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                inspect.Parameter.KEYWORD_ONLY,
            )
            inject_ctx = any(
                param.kind is inspect.Parameter.VAR_KEYWORD
                or (param.name == "_scope_context" and param.kind in keyword_kinds)
                for param in signature.parameters.values()
            )

        def _resolve_user(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            """Return the argument bound to ``user_param`` (keyword or positional), or None."""
            if user_param in kwargs:
                return kwargs[user_param]
            if signature is None:
                return None
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                return None
            return bound.arguments.get(user_param)

        def _enter(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
            """Resolve the user and open the scope context for this call."""
            user = _resolve_user(args, kwargs)
            if user is None:
                raise AuthorizationError(
                    user="unknown",
                    action="execute",
                    resource=resource,
                    reason="No user context provided for scope enforcement",
                )
            return proxilion.enter_scope(scope, user)

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                ctx = _enter(args, kwargs)
                try:
                    if inject_ctx:
                        kwargs["_scope_context"] = ctx
                    return await func(*args, **kwargs)
                finally:
                    ctx.close()

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            ctx = _enter(args, kwargs)
            try:
                if inject_ctx:
                    kwargs["_scope_context"] = ctx
                return func(*args, **kwargs)
            finally:
                ctx.close()

        return sync_wrapper  # type: ignore

    return decorator
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def scoped_tool(
    proxilion: Any,
    tool_name: str | None = None,
    action: str = "execute",
    user_param: str = "user",
    scope_context_param: str = "_scope_context",
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Validate a tool invocation against the active scope context.

    Apply this to individual tool functions that are invoked from inside a
    function decorated with ``@enforce_scope``. When the caller passes the
    scope context through the keyword argument named by
    ``scope_context_param``, the wrapper calls
    ``scope_context.validate_tool(name, action)`` before running the tool.
    If that keyword is absent (or ``None``) the tool runs without any scope
    validation.

    The scope-context keyword is always removed before the wrapped function
    is called, so the tool itself never receives it.

    Args:
        proxilion: The Proxilion instance with scope enforcer. Not consulted
            by the wrapper itself.
        tool_name: Tool name used for validation (defaults to the function's
            ``__name__``).
        action: Action name passed to ``validate_tool``.
        user_param: Parameter name containing UserContext. Not consulted by
            the wrapper itself.
        scope_context_param: Keyword argument name that carries the scope
            context.

    Returns:
        A decorator that wraps both sync and async functions.

    Example:
        >>> @scoped_tool(auth, action="delete")
        ... def delete_user(user_id: str, user: UserContext = None):
        ...     # Validated when called with _scope_context=<scope context>
        ...     ...
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        resolved_name = tool_name or func.__name__

        def _validated_kwargs(call_kwargs: dict[str, Any]) -> dict[str, Any]:
            # ``call_kwargs`` is the wrapper's own per-call dict, so popping
            # the internal keyword out of it is safe and keeps it from
            # reaching the tool.
            scope = call_kwargs.pop(scope_context_param, None)
            if scope is not None:
                scope.validate_tool(resolved_name, action)
            return call_kwargs

        if not inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                return func(*args, **_validated_kwargs(kwargs))

            return sync_wrapper  # type: ignore

        @functools.wraps(func)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            return await func(*args, **_validated_kwargs(kwargs))

        return async_wrapper  # type: ignore

    return decorator
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
def cost_limited(
    limiter: Any,  # CostLimiter or HybridRateLimiter
    estimate_cost: Callable[..., float] | float = 0.01,
    user_param: str = "user",
    record_actual: bool = True,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Decorator that enforces cost limits on a function.

    Checks cost limits before execution and, optionally, records the cost
    after a successful execution.

    The limiter is used by duck typing:

    * If it has ``allow_request``, it is called as
      ``allow_request(user_id, cost)`` and must return ``(allowed, reason)``
      (HybridRateLimiter style).
    * Otherwise ``check_limit(user_id, cost)`` is called and the returned
      object's ``allowed`` attribute is consulted (CostLimiter style).
    * Recording calls ``record_usage(user_id, cost, func_name)`` if present,
      else ``record_spend(user_id, cost, func_name)``; if the limiter has
      neither method, nothing is recorded.

    Args:
        limiter: CostLimiter or HybridRateLimiter instance.
        estimate_cost: Fixed cost estimate, or a callable invoked with the
            wrapped function's positional and keyword arguments that returns
            the estimate. Either form is converted with ``float()``.
        user_param: Keyword argument name containing the UserContext. Only
            keyword arguments are inspected. The user id is the object's
            ``user_id`` attribute if it has one, otherwise ``str(user)``;
            ``"anonymous"`` when the keyword is missing or ``None``.
        record_actual: Whether to record the cost after the wrapped function
            returns successfully. NOTE: the amount recorded is the same
            *estimate* used for the pre-check, not a measured cost, and
            nothing is recorded if the wrapped function raises.

    Returns:
        A decorator function that wraps both sync and async functions.

    Raises:
        BudgetExceededError: Raised by the wrapped call when the limiter
            denies the request. The error currently carries
            ``current_spend=0.0`` and ``limit=0.0`` placeholders rather than
            real limiter state.

    Example:
        >>> from proxilion.security.cost_limiter import CostLimiter
        >>> from proxilion.decorators import cost_limited
        >>>
        >>> limiter = CostLimiter(limits=[...])
        >>>
        >>> @cost_limited(limiter, estimate_cost=0.05)
        ... def call_llm(prompt: str, user: UserContext = None):
        ...     return client.chat(prompt)
        ...
        >>> # Or with dynamic estimation
        >>> @cost_limited(limiter, estimate_cost=lambda model, **kw: MODEL_COSTS[model])
        ... def call_model(model: str, prompt: str, user: UserContext = None):
        ...     return client.chat(model, prompt)
    """
    from proxilion.exceptions import BudgetExceededError

    def _estimate(*args: Any, **kwargs: Any) -> float:
        """Return the cost estimate for one call, always as a float."""
        # Coerce both forms so the limiter never receives a non-float estimate.
        raw = estimate_cost(*args, **kwargs) if callable(estimate_cost) else estimate_cost
        return float(raw)

    def _user_id(kwargs: dict[str, Any]) -> str:
        """Extract the user identifier from the call's keyword arguments."""
        user = kwargs.get(user_param)
        if user is None:
            return "anonymous"
        if hasattr(user, "user_id"):
            return user.user_id
        return str(user)

    def _enforce_budget(user_id: str, cost: float) -> None:
        """Raise BudgetExceededError if the limiter rejects ``cost`` for ``user_id``."""
        if hasattr(limiter, "allow_request"):
            # HybridRateLimiter; the denial reason is not surfaced yet.
            allowed, _reason = limiter.allow_request(user_id, cost)
        else:
            # CostLimiter; the reason is available in the check result's
            # ``limit_name`` when not allowed.
            allowed = limiter.check_limit(user_id, cost).allowed

        if not allowed:
            raise BudgetExceededError(
                limit_type="cost_limit",
                current_spend=0.0,  # Could get from limiter status
                limit=0.0,
                estimated_cost=cost,
                user_id=user_id,
            )

    def _record(user_id: str, cost: float, name: str) -> None:
        """Record ``cost`` against ``user_id`` when enabled and supported by the limiter."""
        if not record_actual:
            return
        if hasattr(limiter, "record_usage"):
            limiter.record_usage(user_id, cost, name)
        elif hasattr(limiter, "record_spend"):
            limiter.record_spend(user_id, cost, name)

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                user_id = _user_id(kwargs)
                cost = _estimate(*args, **kwargs)
                _enforce_budget(user_id, cost)
                outcome = await func(*args, **kwargs)
                # Only reached on success: failed calls are not charged.
                _record(user_id, cost, func.__name__)
                return outcome

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            user_id = _user_id(kwargs)
            cost = _estimate(*args, **kwargs)
            _enforce_budget(user_id, cost)
            outcome = func(*args, **kwargs)
            # Only reached on success: failed calls are not charged.
            _record(user_id, cost, func.__name__)
            return outcome

        return sync_wrapper  # type: ignore

    return decorator
|