proxilion 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proxilion/__init__.py +136 -0
- proxilion/audit/__init__.py +133 -0
- proxilion/audit/base_exporters.py +527 -0
- proxilion/audit/compliance/__init__.py +130 -0
- proxilion/audit/compliance/base.py +457 -0
- proxilion/audit/compliance/eu_ai_act.py +603 -0
- proxilion/audit/compliance/iso27001.py +544 -0
- proxilion/audit/compliance/soc2.py +491 -0
- proxilion/audit/events.py +493 -0
- proxilion/audit/explainability.py +1173 -0
- proxilion/audit/exporters/__init__.py +58 -0
- proxilion/audit/exporters/aws_s3.py +636 -0
- proxilion/audit/exporters/azure_storage.py +608 -0
- proxilion/audit/exporters/cloud_base.py +468 -0
- proxilion/audit/exporters/gcp_storage.py +570 -0
- proxilion/audit/exporters/multi_exporter.py +498 -0
- proxilion/audit/hash_chain.py +652 -0
- proxilion/audit/logger.py +543 -0
- proxilion/caching/__init__.py +49 -0
- proxilion/caching/tool_cache.py +633 -0
- proxilion/context/__init__.py +73 -0
- proxilion/context/context_window.py +556 -0
- proxilion/context/message_history.py +505 -0
- proxilion/context/session.py +735 -0
- proxilion/contrib/__init__.py +51 -0
- proxilion/contrib/anthropic.py +609 -0
- proxilion/contrib/google.py +1012 -0
- proxilion/contrib/langchain.py +641 -0
- proxilion/contrib/mcp.py +893 -0
- proxilion/contrib/openai.py +646 -0
- proxilion/core.py +3058 -0
- proxilion/decorators.py +966 -0
- proxilion/engines/__init__.py +287 -0
- proxilion/engines/base.py +266 -0
- proxilion/engines/casbin_engine.py +412 -0
- proxilion/engines/opa_engine.py +493 -0
- proxilion/engines/simple.py +437 -0
- proxilion/exceptions.py +887 -0
- proxilion/guards/__init__.py +54 -0
- proxilion/guards/input_guard.py +522 -0
- proxilion/guards/output_guard.py +634 -0
- proxilion/observability/__init__.py +198 -0
- proxilion/observability/cost_tracker.py +866 -0
- proxilion/observability/hooks.py +683 -0
- proxilion/observability/metrics.py +798 -0
- proxilion/observability/session_cost_tracker.py +1063 -0
- proxilion/policies/__init__.py +67 -0
- proxilion/policies/base.py +304 -0
- proxilion/policies/builtin.py +486 -0
- proxilion/policies/registry.py +376 -0
- proxilion/providers/__init__.py +201 -0
- proxilion/providers/adapter.py +468 -0
- proxilion/providers/anthropic_adapter.py +330 -0
- proxilion/providers/gemini_adapter.py +391 -0
- proxilion/providers/openai_adapter.py +294 -0
- proxilion/py.typed +0 -0
- proxilion/resilience/__init__.py +81 -0
- proxilion/resilience/degradation.py +615 -0
- proxilion/resilience/fallback.py +555 -0
- proxilion/resilience/retry.py +554 -0
- proxilion/scheduling/__init__.py +57 -0
- proxilion/scheduling/priority_queue.py +419 -0
- proxilion/scheduling/scheduler.py +459 -0
- proxilion/security/__init__.py +244 -0
- proxilion/security/agent_trust.py +968 -0
- proxilion/security/behavioral_drift.py +794 -0
- proxilion/security/cascade_protection.py +869 -0
- proxilion/security/circuit_breaker.py +428 -0
- proxilion/security/cost_limiter.py +690 -0
- proxilion/security/idor_protection.py +460 -0
- proxilion/security/intent_capsule.py +849 -0
- proxilion/security/intent_validator.py +495 -0
- proxilion/security/memory_integrity.py +767 -0
- proxilion/security/rate_limiter.py +509 -0
- proxilion/security/scope_enforcer.py +680 -0
- proxilion/security/sequence_validator.py +636 -0
- proxilion/security/trust_boundaries.py +784 -0
- proxilion/streaming/__init__.py +70 -0
- proxilion/streaming/detector.py +761 -0
- proxilion/streaming/transformer.py +674 -0
- proxilion/timeouts/__init__.py +55 -0
- proxilion/timeouts/decorators.py +477 -0
- proxilion/timeouts/manager.py +545 -0
- proxilion/tools/__init__.py +69 -0
- proxilion/tools/decorators.py +493 -0
- proxilion/tools/registry.py +732 -0
- proxilion/types.py +339 -0
- proxilion/validation/__init__.py +93 -0
- proxilion/validation/pydantic_schema.py +351 -0
- proxilion/validation/schema.py +651 -0
- proxilion-0.0.1.dist-info/METADATA +872 -0
- proxilion-0.0.1.dist-info/RECORD +94 -0
- proxilion-0.0.1.dist-info/WHEEL +4 -0
- proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,509 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rate limiting implementations for Proxilion.
|
|
3
|
+
|
|
4
|
+
This module provides various rate limiting strategies to prevent
|
|
5
|
+
unbounded consumption and protect against denial-of-service attacks.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import threading
|
|
12
|
+
import time
|
|
13
|
+
from collections import defaultdict
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
from proxilion.exceptions import RateLimitExceeded
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class RateLimitState:
    """Mutable per-key bucket state managed by TokenBucketRateLimiter.

    Timestamps come from time.monotonic(), so they are only meaningful
    relative to other monotonic readings within the same process.
    """
    # Tokens currently available; fractional values accrue between refills.
    tokens: float
    # time.monotonic() timestamp of the most recent refill calculation.
    last_update: float
    # Number of requests granted from this bucket (never decremented).
    request_count: int = 0
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TokenBucketRateLimiter:
    """
    Token bucket rate limiter.

    The token bucket algorithm allows bursts up to the bucket capacity
    while maintaining a long-term average rate. Tokens are added at a
    fixed rate and consumed by requests.

    Attributes:
        capacity: Maximum number of tokens in the bucket.
        refill_rate: Tokens added per second.

    Thread Safety:
        All operations are thread-safe using internal locking.

    Example:
        >>> limiter = TokenBucketRateLimiter(capacity=100, refill_rate=10)
        >>> if limiter.allow_request("user_123"):
        ...     # Process request
        ...     pass
        >>> else:
        ...     # Rate limit exceeded
        ...     raise RateLimitExceeded(...)
    """

    def __init__(
        self,
        capacity: int,
        refill_rate: float,
        key_func: Callable[[Any], str] | None = None,
    ) -> None:
        """
        Initialize the token bucket rate limiter.

        Args:
            capacity: Maximum tokens in the bucket.
            refill_rate: Tokens added per second.
            key_func: Optional function to extract rate limit keys from
                requests. NOTE(review): stored but never consulted by this
                class; presumably used by callers -- confirm.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.key_func = key_func

        self._buckets: dict[str, RateLimitState] = {}
        # RLock so internal helpers can be reached while the lock is held.
        self._lock = threading.RLock()

    def _get_or_create_bucket(self, key: str) -> RateLimitState:
        """Return the bucket for *key*, creating a full one on first use."""
        # Single lookup instead of `in` test + subscript.
        bucket = self._buckets.get(key)
        if bucket is None:
            # New buckets start full so callers get their initial burst.
            bucket = RateLimitState(
                tokens=float(self.capacity),
                last_update=time.monotonic(),
            )
            self._buckets[key] = bucket
        return bucket

    def _refill_bucket(self, bucket: RateLimitState) -> None:
        """Add tokens accrued since the last update, capped at capacity."""
        now = time.monotonic()
        elapsed = now - bucket.last_update
        bucket.tokens = min(self.capacity, bucket.tokens + elapsed * self.refill_rate)
        bucket.last_update = now

    def allow_request(self, key: str, cost: int = 1) -> bool:
        """
        Check if a request is allowed and consume tokens.

        Args:
            key: The rate limit key (e.g., user ID, IP address).
            cost: Number of tokens to consume (default 1).

        Returns:
            True if the request is allowed, False if rate limited.

        Example:
            >>> if limiter.allow_request("user_123", cost=5):
            ...     # Expensive operation
            ...     pass
        """
        with self._lock:
            bucket = self._get_or_create_bucket(key)
            self._refill_bucket(bucket)

            if bucket.tokens >= cost:
                bucket.tokens -= cost
                bucket.request_count += 1
                # Lazy %-style args avoid string formatting when DEBUG is off.
                logger.debug(
                    "Rate limit: key=%s, tokens_remaining=%.1f", key, bucket.tokens
                )
                return True

            logger.debug(
                "Rate limit exceeded: key=%s, tokens=%.1f, cost=%d",
                key,
                bucket.tokens,
                cost,
            )
            return False

    def get_remaining(self, key: str) -> int:
        """
        Get remaining tokens for a key.

        Args:
            key: The rate limit key.

        Returns:
            Number of available tokens (floored to int).
        """
        with self._lock:
            bucket = self._get_or_create_bucket(key)
            self._refill_bucket(bucket)
            return int(bucket.tokens)

    def get_retry_after(self, key: str, cost: int = 1) -> float:
        """
        Get seconds until enough tokens are available.

        Args:
            key: The rate limit key.
            cost: Number of tokens needed.

        Returns:
            Seconds to wait, or 0 if tokens are available.
        """
        with self._lock:
            bucket = self._get_or_create_bucket(key)
            self._refill_bucket(bucket)

            if bucket.tokens >= cost:
                return 0.0

            # Deficit divided by the refill rate gives the wait time.
            tokens_needed = cost - bucket.tokens
            return tokens_needed / self.refill_rate

    def reset(self, key: str) -> None:
        """Reset a bucket to full capacity (next access recreates it full)."""
        with self._lock:
            self._buckets.pop(key, None)

    def reset_all(self) -> None:
        """Reset all buckets."""
        with self._lock:
            self._buckets.clear()
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class SlidingWindowRateLimiter:
    """
    Sliding window rate limiter.

    Records the timestamp of every granted request and only admits a new
    one while the number of timestamps still inside the window leaves room
    for it. Compared with a token bucket this gives a more consistent rate
    and prevents bursts at window boundaries.

    Example:
        >>> limiter = SlidingWindowRateLimiter(
        ...     max_requests=100,
        ...     window_seconds=60
        ... )
        >>> if limiter.allow_request("user_123"):
        ...     # Process request
        ...     pass
    """

    def __init__(
        self,
        max_requests: int,
        window_seconds: float,
    ) -> None:
        """
        Initialize the sliding window rate limiter.

        Args:
            max_requests: Maximum requests allowed in the window.
            window_seconds: Window size in seconds.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds

        # Per-key list of time.monotonic() stamps for granted requests.
        self._requests: dict[str, list[float]] = defaultdict(list)
        self._lock = threading.RLock()

    def _cleanup_old_requests(self, key: str) -> None:
        """Drop timestamps that have aged out of the window."""
        horizon = time.monotonic() - self.window_seconds
        self._requests[key] = [
            stamp for stamp in self._requests[key] if stamp > horizon
        ]

    def allow_request(self, key: str, cost: int = 1) -> bool:
        """
        Check if a request is allowed.

        Args:
            key: The rate limit key.
            cost: Number of "requests" to count (for weighted limiting).

        Returns:
            True if allowed, False if rate limited.
        """
        with self._lock:
            self._cleanup_old_requests(key)

            history = self._requests[key]
            # Guard clause: deny when the weighted request would overflow.
            if len(history) + cost > self.max_requests:
                return False

            history.extend([time.monotonic()] * cost)
            return True

    def get_remaining(self, key: str) -> int:
        """Get remaining requests allowed in current window."""
        with self._lock:
            self._cleanup_old_requests(key)
            used = len(self._requests[key])
            return max(0, self.max_requests - used)

    def get_retry_after(self, key: str) -> float:
        """Get seconds until the oldest request expires from window."""
        with self._lock:
            self._cleanup_old_requests(key)

            history = self._requests[key]
            # Room left (or nothing recorded at all) means no wait is needed.
            if len(history) < self.max_requests or not history:
                return 0.0

            wait = min(history) + self.window_seconds - time.monotonic()
            return max(0.0, wait)

    def reset(self, key: str) -> None:
        """Reset request history for a key."""
        with self._lock:
            self._requests.pop(key, None)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@dataclass
class RateLimitConfig:
    """Configuration for a rate limit dimension.

    Consumed by MultiDimensionalRateLimiter: capacity/refill_rate configure
    a token bucket, while window_seconds (together with that limiter's
    use_sliding_window flag) selects a sliding window that reuses capacity
    as the maximum request count.
    """
    # Bucket capacity (token bucket) or max requests per window (sliding).
    capacity: int
    # Tokens added per second; ignored when sliding-window mode is selected.
    refill_rate: float
    # Window length in seconds; None/absent keeps token-bucket mode.
    window_seconds: float | None = None  # For sliding window
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
class MultiDimensionalRateLimiter:
    """
    Multi-dimensional rate limiter.

    Applies different rate limits based on multiple dimensions:
    user, tool, action, time of day, etc.

    Example:
        >>> limiter = MultiDimensionalRateLimiter({
        ...     "user": RateLimitConfig(capacity=100, refill_rate=10),
        ...     "tool": RateLimitConfig(capacity=50, refill_rate=5),
        ...     "global": RateLimitConfig(capacity=1000, refill_rate=100),
        ... })
        >>>
        >>> keys = {"user": "user_123", "tool": "database_query"}
        >>> if limiter.allow_request(keys):
        ...     # All limits passed
        ...     pass
    """

    def __init__(
        self,
        limits: dict[str, RateLimitConfig],
        use_sliding_window: bool = False,
    ) -> None:
        """
        Initialize the multi-dimensional rate limiter.

        Args:
            limits: Dictionary of dimension name to RateLimitConfig.
            use_sliding_window: If True, use sliding window instead of token bucket.
        """
        self.limits = limits
        self._limiters: dict[str, TokenBucketRateLimiter | SlidingWindowRateLimiter] = {}

        for dimension, config in limits.items():
            # Sliding window requires a configured window; otherwise fall
            # back to a token bucket for that dimension.
            if use_sliding_window and config.window_seconds:
                self._limiters[dimension] = SlidingWindowRateLimiter(
                    max_requests=config.capacity,
                    window_seconds=config.window_seconds,
                )
            else:
                self._limiters[dimension] = TokenBucketRateLimiter(
                    capacity=config.capacity,
                    refill_rate=config.refill_rate,
                )

    def allow_request(
        self,
        keys: dict[str, str],
        costs: dict[str, int] | None = None,
    ) -> bool:
        """
        Check if request is allowed across all dimensions.

        Args:
            keys: Dictionary mapping dimension names to keys.
            costs: Optional per-dimension costs (default 1 for all).

        Returns:
            True if all dimensions allow the request.

        Note:
            The check pass and the consume pass are not atomic as a whole;
            concurrent callers can race between them. Each per-dimension
            limiter is individually thread-safe.
        """
        costs = costs or {}

        # Check all dimensions first (don't consume until we know all pass).
        # Both limiter types expose get_remaining(), so a single code path
        # covers token buckets and sliding windows (the original duplicated
        # identical logic behind an isinstance check).
        for dimension, key in keys.items():
            limiter = self._limiters.get(dimension)
            if limiter is None:
                continue

            if limiter.get_remaining(key) < costs.get(dimension, 1):
                logger.debug(
                    "Rate limit failed: dimension=%s, key=%s", dimension, key
                )
                return False

        # All checks passed, now consume tokens in every dimension.
        for dimension, key in keys.items():
            limiter = self._limiters.get(dimension)
            if limiter is None:
                continue
            limiter.allow_request(key, costs.get(dimension, 1))

        return True

    def get_most_restrictive(
        self,
        keys: dict[str, str],
    ) -> tuple[str, int]:
        """
        Get the most restrictive dimension.

        Args:
            keys: Dictionary mapping dimension names to keys.

        Returns:
            Tuple of (dimension_name, remaining_tokens). When no key maps
            to a configured dimension, returns ("", 0) instead of failing.
        """
        min_remaining = float("inf")
        min_dimension = ""

        for dimension, key in keys.items():
            limiter = self._limiters.get(dimension)
            if limiter is None:
                continue

            remaining = limiter.get_remaining(key)
            if remaining < min_remaining:
                min_remaining = remaining
                min_dimension = dimension

        # Guard: int(float("inf")) raises OverflowError when nothing matched.
        if not min_dimension:
            return "", 0
        return min_dimension, int(min_remaining)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
class RateLimiterMiddleware:
    """
    Rate limiter middleware for tool calls.

    Integrates rate limiting with the authorization flow,
    raising RateLimitExceeded when limits are hit.

    Example:
        >>> middleware = RateLimiterMiddleware(
        ...     user_limit=TokenBucketRateLimiter(100, 10),
        ...     tool_limits={"database_query": TokenBucketRateLimiter(10, 1)},
        ... )
        >>>
        >>> middleware.check_rate_limit(user, "database_query")
    """

    def __init__(
        self,
        user_limit: TokenBucketRateLimiter | None = None,
        tool_limits: dict[str, TokenBucketRateLimiter] | None = None,
        global_limit: TokenBucketRateLimiter | None = None,
    ) -> None:
        """
        Initialize the middleware.

        Args:
            user_limit: Per-user rate limiter.
            tool_limits: Per-tool rate limiters.
            global_limit: Global rate limiter.
        """
        self.user_limit = user_limit
        self.tool_limits = tool_limits or {}
        self.global_limit = global_limit

    @staticmethod
    def _enforce(
        limiter: TokenBucketRateLimiter,
        limit_type: str,
        key: str,
        cost: int,
    ) -> None:
        """Consume *cost* tokens from *limiter* or raise RateLimitExceeded."""
        if limiter.allow_request(key, cost):
            return
        raise RateLimitExceeded(
            limit_type=limit_type,
            limit_key=key,
            limit_value=limiter.capacity,
            retry_after=limiter.get_retry_after(key, cost),
        )

    def check_rate_limit(
        self,
        user_id: str,
        tool_name: str,
        cost: int = 1,
    ) -> None:
        """
        Check rate limits and raise if exceeded.

        Checks run in order global -> user -> tool; an earlier limiter's
        tokens are consumed even if a later check raises.

        Args:
            user_id: The user's ID.
            tool_name: The tool being called.
            cost: Token cost for this request.

        Raises:
            RateLimitExceeded: If any rate limit is exceeded.
        """
        if self.global_limit is not None:
            self._enforce(self.global_limit, "global", "global", cost)

        if self.user_limit is not None:
            self._enforce(self.user_limit, "user", user_id, cost)

        tool_limiter = self.tool_limits.get(tool_name)
        if tool_limiter is not None:
            # Tool limits are scoped per user, not shared across users.
            self._enforce(tool_limiter, "tool", f"{user_id}:{tool_name}", cost)

    def get_headers(
        self,
        user_id: str,
        tool_name: str,
    ) -> dict[str, str]:
        """
        Get rate limit headers for API responses.

        Args:
            user_id: The user's ID.
            tool_name: The tool name. NOTE(review): accepted but not
                consulted here -- confirm whether per-tool headers were
                intended.

        Returns:
            Dictionary of rate limit headers (empty if no user limiter).
        """
        headers: dict[str, str] = {}

        if self.user_limit is not None:
            headers["X-RateLimit-Limit"] = str(self.user_limit.capacity)
            headers["X-RateLimit-Remaining"] = str(
                self.user_limit.get_remaining(user_id)
            )

        return headers
|