ai-lib-python 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_lib_python/__init__.py +43 -0
- ai_lib_python/batch/__init__.py +15 -0
- ai_lib_python/batch/collector.py +244 -0
- ai_lib_python/batch/executor.py +224 -0
- ai_lib_python/cache/__init__.py +26 -0
- ai_lib_python/cache/backends.py +380 -0
- ai_lib_python/cache/key.py +237 -0
- ai_lib_python/cache/manager.py +332 -0
- ai_lib_python/client/__init__.py +37 -0
- ai_lib_python/client/builder.py +528 -0
- ai_lib_python/client/cancel.py +368 -0
- ai_lib_python/client/core.py +433 -0
- ai_lib_python/client/response.py +134 -0
- ai_lib_python/embeddings/__init__.py +36 -0
- ai_lib_python/embeddings/client.py +339 -0
- ai_lib_python/embeddings/types.py +234 -0
- ai_lib_python/embeddings/vectors.py +246 -0
- ai_lib_python/errors/__init__.py +41 -0
- ai_lib_python/errors/base.py +316 -0
- ai_lib_python/errors/classification.py +210 -0
- ai_lib_python/guardrails/__init__.py +35 -0
- ai_lib_python/guardrails/base.py +336 -0
- ai_lib_python/guardrails/filters.py +583 -0
- ai_lib_python/guardrails/validators.py +475 -0
- ai_lib_python/pipeline/__init__.py +55 -0
- ai_lib_python/pipeline/accumulate.py +248 -0
- ai_lib_python/pipeline/base.py +240 -0
- ai_lib_python/pipeline/decode.py +281 -0
- ai_lib_python/pipeline/event_map.py +506 -0
- ai_lib_python/pipeline/fan_out.py +284 -0
- ai_lib_python/pipeline/select.py +297 -0
- ai_lib_python/plugins/__init__.py +32 -0
- ai_lib_python/plugins/base.py +294 -0
- ai_lib_python/plugins/hooks.py +296 -0
- ai_lib_python/plugins/middleware.py +285 -0
- ai_lib_python/plugins/registry.py +294 -0
- ai_lib_python/protocol/__init__.py +71 -0
- ai_lib_python/protocol/loader.py +317 -0
- ai_lib_python/protocol/manifest.py +385 -0
- ai_lib_python/protocol/validator.py +460 -0
- ai_lib_python/py.typed +1 -0
- ai_lib_python/resilience/__init__.py +102 -0
- ai_lib_python/resilience/backpressure.py +225 -0
- ai_lib_python/resilience/circuit_breaker.py +318 -0
- ai_lib_python/resilience/executor.py +343 -0
- ai_lib_python/resilience/fallback.py +341 -0
- ai_lib_python/resilience/preflight.py +413 -0
- ai_lib_python/resilience/rate_limiter.py +291 -0
- ai_lib_python/resilience/retry.py +299 -0
- ai_lib_python/resilience/signals.py +283 -0
- ai_lib_python/routing/__init__.py +118 -0
- ai_lib_python/routing/manager.py +593 -0
- ai_lib_python/routing/strategy.py +345 -0
- ai_lib_python/routing/types.py +397 -0
- ai_lib_python/structured/__init__.py +33 -0
- ai_lib_python/structured/json_mode.py +281 -0
- ai_lib_python/structured/schema.py +316 -0
- ai_lib_python/structured/validator.py +334 -0
- ai_lib_python/telemetry/__init__.py +127 -0
- ai_lib_python/telemetry/exporters/__init__.py +9 -0
- ai_lib_python/telemetry/exporters/prometheus.py +111 -0
- ai_lib_python/telemetry/feedback.py +446 -0
- ai_lib_python/telemetry/health.py +409 -0
- ai_lib_python/telemetry/logger.py +389 -0
- ai_lib_python/telemetry/metrics.py +496 -0
- ai_lib_python/telemetry/tracer.py +473 -0
- ai_lib_python/tokens/__init__.py +25 -0
- ai_lib_python/tokens/counter.py +282 -0
- ai_lib_python/tokens/estimator.py +286 -0
- ai_lib_python/transport/__init__.py +34 -0
- ai_lib_python/transport/auth.py +141 -0
- ai_lib_python/transport/http.py +364 -0
- ai_lib_python/transport/pool.py +425 -0
- ai_lib_python/types/__init__.py +41 -0
- ai_lib_python/types/events.py +343 -0
- ai_lib_python/types/message.py +332 -0
- ai_lib_python/types/tool.py +191 -0
- ai_lib_python/utils/__init__.py +21 -0
- ai_lib_python/utils/tool_call_assembler.py +317 -0
- ai_lib_python-0.5.0.dist-info/METADATA +837 -0
- ai_lib_python-0.5.0.dist-info/RECORD +84 -0
- ai_lib_python-0.5.0.dist-info/WHEEL +4 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-APACHE +201 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-MIT +21 -0
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Preflight checks and unified request gating.
|
|
3
|
+
|
|
4
|
+
Provides unified preflight validation before request execution.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
from ai_lib_python.errors import AiLibError
|
|
14
|
+
from ai_lib_python.resilience.signals import SignalsSnapshot
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ai_lib_python.resilience.backpressure import BackpressureController
|
|
18
|
+
from ai_lib_python.resilience.circuit_breaker import CircuitBreaker
|
|
19
|
+
from ai_lib_python.resilience.rate_limiter import RateLimiter
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PreflightError(AiLibError):
    """Raised (or collected) when a preflight gate rejects a request.

    Besides the message it records which gate failed and retry hints.
    """

    def __init__(
        self,
        message: str,
        component: str,
        retryable: bool = True,
        retry_after_ms: int | None = None,
    ) -> None:
        """Create a preflight error.

        Args:
            message: Human-readable description of the failure
            component: Gate that failed ("rate_limiter", "circuit_breaker"
                or "backpressure")
            retryable: Whether retrying the request later may succeed
            retry_after_ms: Suggested delay before retrying, in milliseconds
        """
        super().__init__(message)
        # Retry metadata travels with the exception for callers to inspect.
        self.component, self.retryable, self.retry_after_ms = (
            component,
            retryable,
            retry_after_ms,
        )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class PreflightResult:
    """Outcome of a preflight run.

    Attributes:
        passed: True when no check reported an error
        permit: Backpressure permit held on behalf of the caller, or None
        signals: Signals snapshot, when one was captured
        errors: Failures collected while checking
    """

    passed: bool = True
    permit: Any = None
    signals: SignalsSnapshot | None = None
    errors: list[PreflightError] = field(default_factory=list)

    def release_permit(self) -> None:
        """Give back the held backpressure permit, if any.

        ``ValueError`` and ``RuntimeError`` raised by the permit's
        ``release()`` are ignored. Once the call finishes, ``permit`` is None,
        so calling this again is a no-op.
        """
        held = self.permit
        if held is None:
            return
        try:
            # Presumably a semaphore-like object; a bounded semaphore raises
            # ValueError on over-release, which is harmless here.
            held.release()
        except (ValueError, RuntimeError):
            pass
        self.permit = None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class PreflightConfig:
    """Switches and limits controlling which preflight gates run.

    Attributes:
        check_rate_limiter: Run the rate limiter gate
        check_circuit_breaker: Run the circuit breaker gate
        check_backpressure: Acquire a backpressure permit
        fail_fast: Stop at the first failing gate instead of collecting all
        timeout_ms: Maximum time to wait for a backpressure permit, in ms
    """

    check_rate_limiter: bool = True
    check_circuit_breaker: bool = True
    check_backpressure: bool = True
    fail_fast: bool = True
    timeout_ms: float = 30_000.0
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class PreflightChecker:
    """Unified preflight checker for requests.

    Performs circuit breaker, rate limiter, and backpressure checks (in that
    order) before allowing a request to proceed. Each component is optional
    and can be switched off through ``PreflightConfig``.

    Example:
        >>> checker = PreflightChecker(
        ...     rate_limiter=rate_limiter,
        ...     circuit_breaker=circuit_breaker,
        ...     backpressure=backpressure_controller,
        ... )
        >>>
        >>> result = await checker.check()
        >>> if result.passed:
        ...     try:
        ...         response = await make_request()
        ...     finally:
        ...         result.release_permit()
        >>> else:
        ...     for error in result.errors:
        ...         print(f"Failed: {error.component}: {error}")
    """

    def __init__(
        self,
        rate_limiter: RateLimiter | None = None,
        circuit_breaker: CircuitBreaker | None = None,
        backpressure: BackpressureController | None = None,
        config: PreflightConfig | None = None,
        provider: str | None = None,
        model: str | None = None,
    ) -> None:
        """Initialize preflight checker.

        Args:
            rate_limiter: Optional rate limiter
            circuit_breaker: Optional circuit breaker
            backpressure: Optional backpressure controller
            config: Preflight configuration (defaults to ``PreflightConfig()``)
            provider: Provider identifier for signals
            model: Model identifier for signals
        """
        self._rate_limiter = rate_limiter
        self._circuit_breaker = circuit_breaker
        self._backpressure = backpressure
        self._config = config or PreflightConfig()
        self._provider = provider
        self._model = model

    def _record_failure(
        self,
        result: PreflightResult,
        errors: list[PreflightError],
        error: PreflightError,
    ) -> bool:
        """Record *error*; return True when check() must return now (fail-fast)."""
        errors.append(error)
        if not self._config.fail_fast:
            return False
        result.passed = False
        result.errors = errors
        return True

    def _check_circuit_breaker(self) -> PreflightError | None:
        """Return an error if the circuit breaker rejects the request, else None."""
        import time

        breaker = self._circuit_breaker
        if breaker.allow():
            return None

        cooldown = None
        last_failure = breaker._last_failure
        if last_failure:
            # NOTE(review): assumes _last_failure is a time.time() timestamp
            # and config.cooldown_seconds is in seconds — confirm in CircuitBreaker.
            remaining = breaker.config.cooldown_seconds - (time.time() - last_failure)
            if remaining > 0:
                cooldown = int(remaining * 1000)

        return PreflightError(
            "Circuit breaker is open",
            "circuit_breaker",
            retryable=True,
            retry_after_ms=cooldown,
        )

    async def check(self) -> PreflightResult:
        """Perform all enabled preflight checks.

        Checks run in order: circuit breaker, rate limiter, backpressure.
        With ``fail_fast`` the first failure is returned immediately (without
        a signals snapshot). Otherwise every enabled check runs and all
        failures are collected.

        A backpressure permit is only left on the result when the overall
        check passed. On failure any acquired permit is released here, so it
        cannot leak when the caller skips ``release_permit()``.

        Returns:
            PreflightResult with check status, permit and signals
        """
        result = PreflightResult()
        errors: list[PreflightError] = []

        # 1. Circuit breaker first, so an open breaker rejects before a
        #    rate-limit token or backpressure permit is taken.
        if self._config.check_circuit_breaker and self._circuit_breaker:
            try:
                error = self._check_circuit_breaker()
            except Exception as e:
                error = PreflightError(
                    f"Circuit breaker check failed: {e}", "circuit_breaker"
                )
            if error is not None and self._record_failure(result, errors, error):
                return result

        # 2. Rate limiter
        if self._config.check_rate_limiter and self._rate_limiter:
            error = None
            try:
                outcome = await self._rate_limiter.acquire()
            except Exception as e:
                error = PreflightError(
                    f"Rate limiter check failed: {e}", "rate_limiter"
                )
            else:
                # RateLimiter.acquire() waits for a token and returns the
                # seconds it waited (0.0 when a token was available at once).
                # That float is not a permission flag; treating 0.0 as falsy
                # rejected every uncontended request. Only an explicit False
                # (from bool-returning limiters) is a denial.
                if outcome is False:
                    error = PreflightError(
                        "Rate limit exceeded",
                        "rate_limiter",
                        retryable=True,
                        retry_after_ms=1000,  # Default 1s retry
                    )
            if error is not None and self._record_failure(result, errors, error):
                return result

        # 3. Backpressure permit
        if self._config.check_backpressure and self._backpressure:
            error = None
            try:
                # Compute the timeout before creating the acquire() coroutine
                # so a bad config cannot leave an un-awaited coroutine behind.
                timeout = self._config.timeout_ms / 1000.0
                permit = await asyncio.wait_for(
                    self._backpressure.acquire(),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                error = PreflightError(
                    "Backpressure permit timeout",
                    "backpressure",
                    retryable=True,
                    retry_after_ms=100,
                )
            except Exception as e:
                error = PreflightError(
                    f"Backpressure check failed: {e}", "backpressure"
                )
            else:
                if permit:
                    result.permit = permit
                else:
                    error = PreflightError(
                        "Backpressure limit reached",
                        "backpressure",
                        retryable=True,
                        retry_after_ms=100,
                    )
            if error is not None and self._record_failure(result, errors, error):
                return result

        result.errors = errors
        result.passed = not errors
        if not result.passed:
            # Non-fail-fast mode can reach here holding a permit even though
            # an earlier check failed; callers only release on success.
            result.release_permit()

        try:
            result.signals = self.get_signals()
        except BaseException:
            # Do not leak a concurrency slot if the snapshot fails.
            result.release_permit()
            raise

        return result

    def get_signals(self) -> SignalsSnapshot:
        """Get current signals snapshot.

        Returns:
            SignalsSnapshot built from the configured components; ``inflight``
            is ``(max_concurrent, in_use)`` when a backpressure controller is set
        """
        inflight = None
        if self._backpressure:
            max_concurrent = self._backpressure.max_concurrent
            in_use = max_concurrent - self._backpressure.available
            inflight = (max_concurrent, in_use)

        return SignalsSnapshot.from_components(
            inflight=inflight,
            rate_limiter=self._rate_limiter,
            circuit_breaker=self._circuit_breaker,
            provider=self._provider,
            model=self._model,
        )

    def on_success(self) -> None:
        """Report successful request completion to the circuit breaker."""
        if self._circuit_breaker:
            self._circuit_breaker.on_success()

    def on_failure(self) -> None:
        """Report request failure to the circuit breaker."""
        if self._circuit_breaker:
            self._circuit_breaker.on_failure()

    async def update_rate_limits(self, headers: dict[str, str]) -> None:
        """Update rate limiter state from response headers.

        Looks for a remaining-request count and a reset delay under common
        header names, matched case-insensitively. Reset values above
        1_000_000_000 are treated as epoch timestamps and converted to a delay
        from now (never negative); smaller values are taken as seconds.

        If the limiter has an ``update_budget(remaining, reset_after)``
        coroutine, the parsed values are forwarded to it. A limiter that lacks
        it but has ``update_from_headers`` (e.g. ``AdaptiveRateLimiter``)
        receives the raw headers instead. Any other limiter is left untouched
        rather than raising ``AttributeError``.

        Args:
            headers: Response headers
        """
        if not self._rate_limiter:
            return

        import time

        # HTTP field names are case-insensitive; normalise once so variants
        # like "X-RateLimit-Remaining" are found too.
        lowered = {str(name).lower(): value for name, value in headers.items()}

        # Try to extract remaining count
        remaining: int | None = None
        for name in (
            "x-ratelimit-remaining",
            "x-ratelimit-remaining-requests",
            "ratelimit-remaining",
        ):
            value = lowered.get(name)
            if not value:
                continue
            try:
                remaining = int(value)
            except ValueError:
                continue
            break

        # Try to extract reset time
        reset_after: float | None = None
        for name in (
            "x-ratelimit-reset",
            "x-ratelimit-reset-requests",
            "ratelimit-reset",
            "retry-after",
        ):
            value = lowered.get(name)
            if not value:
                continue
            try:
                val = float(value)
            except ValueError:
                continue
            # Large values are epoch timestamps rather than relative seconds;
            # a timestamp already in the past means "reset now".
            reset_after = max(0.0, val - time.time()) if val > 1_000_000_000 else val
            break

        if remaining is None and reset_after is None:
            return

        update_budget = getattr(self._rate_limiter, "update_budget", None)
        if update_budget is not None:
            await update_budget(remaining, reset_after)
            return

        update_from_headers = getattr(self._rate_limiter, "update_from_headers", None)
        if update_from_headers is not None:
            update_from_headers(headers)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
class PreflightContext:
    """Async context manager that runs preflight checks on entry.

    The backpressure permit (if any) is always released on exit. Success is
    not reported automatically; call ``on_success()`` once the request
    completes. An exception escaping the block is reported to the checker as
    a failure only when the preflight passed, i.e. when a request could
    actually have been attempted.

    Example:
        >>> async with PreflightContext(checker) as ctx:
        ...     if ctx.passed:
        ...         response = await make_request()
        ...         ctx.on_success()
        ...     else:
        ...         print(f"Preflight failed: {ctx.errors}")
    """

    def __init__(self, checker: PreflightChecker) -> None:
        """Initialize context.

        Args:
            checker: PreflightChecker instance
        """
        self._checker = checker
        self._result: PreflightResult | None = None

    async def __aenter__(self) -> PreflightContext:
        """Enter context and perform checks."""
        self._result = await self._checker.check()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Exit context: release the permit and report failures."""
        if self._result is None:
            return

        self._result.release_permit()

        # When the preflight failed (e.g. the breaker was already open), no
        # request was made, so an exception raised by the caller says nothing
        # about the provider and must not be charged to the checker.
        if exc_val is not None and self._result.passed:
            self._checker.on_failure()

    @property
    def passed(self) -> bool:
        """Whether preflight passed (False before the context is entered)."""
        return self._result.passed if self._result else False

    @property
    def signals(self) -> SignalsSnapshot | None:
        """Signals snapshot from the preflight run, if any."""
        return self._result.signals if self._result else None

    @property
    def errors(self) -> list[PreflightError]:
        """Preflight errors (empty before the context is entered)."""
        return self._result.errors if self._result else []

    def on_success(self) -> None:
        """Report successful completion."""
        self._checker.on_success()

    def on_failure(self) -> None:
        """Report failure."""
        self._checker.on_failure()
|
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rate limiter using token bucket algorithm.
|
|
3
|
+
|
|
4
|
+
Provides both static and adaptive rate limiting based on provider response headers.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import time
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class RateLimiterConfig:
    """Configuration for rate limiter.

    Attributes:
        requests_per_second: Maximum requests per second (0 = unlimited)
        burst_size: Maximum burst size (tokens in bucket)
        initial_tokens: Initial tokens in bucket
    """

    requests_per_second: float = 0.0
    burst_size: int | None = None
    initial_tokens: int | None = None

    @classmethod
    def from_rps(cls, rps: float, burst_multiplier: float = 1.5) -> RateLimiterConfig:
        """Create config from requests per second.

        The burst size (also used as the initial token count) is
        ``int(rps * burst_multiplier)``, but never less than one token. For
        slow rates (e.g. 10 requests/minute) the raw product truncates to 0.
        That would describe a bucket unable to hold a single request and
        would make the very first call wait a full refill period.

        Args:
            rps: Requests per second (<= 0 means unlimited)
            burst_multiplier: Multiplier for burst size

        Returns:
            RateLimiterConfig instance
        """
        if rps <= 0:
            return cls(requests_per_second=rps, burst_size=None, initial_tokens=None)
        burst = max(1, int(rps * burst_multiplier))
        return cls(
            requests_per_second=rps,
            burst_size=burst,
            initial_tokens=burst,
        )

    @classmethod
    def from_rpm(cls, rpm: float, burst_multiplier: float = 1.5) -> RateLimiterConfig:
        """Create config from requests per minute.

        Args:
            rpm: Requests per minute
            burst_multiplier: Multiplier for burst size

        Returns:
            RateLimiterConfig instance
        """
        return cls.from_rps(rpm / 60.0, burst_multiplier)

    @classmethod
    def unlimited(cls) -> RateLimiterConfig:
        """Create an unlimited rate limiter config."""
        return cls(requests_per_second=0.0)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class RateLimiter:
|
|
67
|
+
"""Token bucket rate limiter.
|
|
68
|
+
|
|
69
|
+
Implements the token bucket algorithm for rate limiting:
|
|
70
|
+
- Tokens are added at a fixed rate
|
|
71
|
+
- Requests consume tokens
|
|
72
|
+
- If no tokens available, requests wait
|
|
73
|
+
|
|
74
|
+
Example:
|
|
75
|
+
>>> limiter = RateLimiter(RateLimiterConfig.from_rps(10))
|
|
76
|
+
>>> await limiter.acquire() # Wait if needed
|
|
77
|
+
>>> # Make request
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
def __init__(self, config: RateLimiterConfig | None = None) -> None:
|
|
81
|
+
"""Initialize rate limiter.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
config: Rate limiter configuration
|
|
85
|
+
"""
|
|
86
|
+
self._config = config or RateLimiterConfig()
|
|
87
|
+
self._lock = asyncio.Lock()
|
|
88
|
+
|
|
89
|
+
# Token bucket state
|
|
90
|
+
self._tokens = float(
|
|
91
|
+
self._config.initial_tokens
|
|
92
|
+
if self._config.initial_tokens is not None
|
|
93
|
+
else (self._config.burst_size or 1)
|
|
94
|
+
)
|
|
95
|
+
self._max_tokens = float(self._config.burst_size or 1)
|
|
96
|
+
self._last_refill = time.monotonic()
|
|
97
|
+
|
|
98
|
+
# Rate (tokens per second)
|
|
99
|
+
self._rate = self._config.requests_per_second
|
|
100
|
+
|
|
101
|
+
def _refill(self) -> None:
|
|
102
|
+
"""Refill tokens based on elapsed time."""
|
|
103
|
+
if self._rate <= 0:
|
|
104
|
+
return
|
|
105
|
+
|
|
106
|
+
now = time.monotonic()
|
|
107
|
+
elapsed = now - self._last_refill
|
|
108
|
+
self._last_refill = now
|
|
109
|
+
|
|
110
|
+
# Add tokens based on elapsed time
|
|
111
|
+
new_tokens = elapsed * self._rate
|
|
112
|
+
self._tokens = min(self._tokens + new_tokens, self._max_tokens)
|
|
113
|
+
|
|
114
|
+
async def acquire(self, tokens: int = 1) -> float:
|
|
115
|
+
"""Acquire tokens, waiting if necessary.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
tokens: Number of tokens to acquire
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
Wait time in seconds (0 if no wait)
|
|
122
|
+
"""
|
|
123
|
+
if self._rate <= 0:
|
|
124
|
+
return 0.0 # Unlimited
|
|
125
|
+
|
|
126
|
+
async with self._lock:
|
|
127
|
+
self._refill()
|
|
128
|
+
|
|
129
|
+
wait_time = 0.0
|
|
130
|
+
|
|
131
|
+
if self._tokens < tokens:
|
|
132
|
+
# Calculate wait time
|
|
133
|
+
deficit = tokens - self._tokens
|
|
134
|
+
wait_time = deficit / self._rate
|
|
135
|
+
|
|
136
|
+
# Wait for tokens
|
|
137
|
+
await asyncio.sleep(wait_time)
|
|
138
|
+
self._refill()
|
|
139
|
+
|
|
140
|
+
# Consume tokens
|
|
141
|
+
self._tokens -= tokens
|
|
142
|
+
return wait_time
|
|
143
|
+
|
|
144
|
+
async def try_acquire(self, tokens: int = 1) -> bool:
|
|
145
|
+
"""Try to acquire tokens without waiting.
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
tokens: Number of tokens to acquire
|
|
149
|
+
|
|
150
|
+
Returns:
|
|
151
|
+
True if acquired, False if would need to wait
|
|
152
|
+
"""
|
|
153
|
+
if self._rate <= 0:
|
|
154
|
+
return True # Unlimited
|
|
155
|
+
|
|
156
|
+
async with self._lock:
|
|
157
|
+
self._refill()
|
|
158
|
+
|
|
159
|
+
if self._tokens >= tokens:
|
|
160
|
+
self._tokens -= tokens
|
|
161
|
+
return True
|
|
162
|
+
return False
|
|
163
|
+
|
|
164
|
+
def get_wait_time(self, tokens: int = 1) -> float:
|
|
165
|
+
"""Get estimated wait time without acquiring.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
tokens: Number of tokens needed
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
Estimated wait time in seconds
|
|
172
|
+
"""
|
|
173
|
+
if self._rate <= 0:
|
|
174
|
+
return 0.0
|
|
175
|
+
|
|
176
|
+
self._refill()
|
|
177
|
+
|
|
178
|
+
if self._tokens >= tokens:
|
|
179
|
+
return 0.0
|
|
180
|
+
|
|
181
|
+
deficit = tokens - self._tokens
|
|
182
|
+
return deficit / self._rate
|
|
183
|
+
|
|
184
|
+
@property
|
|
185
|
+
def available_tokens(self) -> float:
|
|
186
|
+
"""Get current available tokens."""
|
|
187
|
+
self._refill()
|
|
188
|
+
return self._tokens
|
|
189
|
+
|
|
190
|
+
@property
|
|
191
|
+
def is_limited(self) -> bool:
|
|
192
|
+
"""Check if rate limiting is enabled."""
|
|
193
|
+
return self._rate > 0
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class AdaptiveRateLimiter(RateLimiter):
    """Adaptive rate limiter that adjusts based on server responses.

    Monitors rate limit headers from API responses and adjusts
    the rate limit dynamically.

    Example:
        >>> limiter = AdaptiveRateLimiter()
        >>> await limiter.acquire()
        >>> response = await make_request()
        >>> limiter.update_from_headers(response.headers)
    """

    def __init__(
        self,
        config: RateLimiterConfig | None = None,
        header_config: dict[str, str] | None = None,
    ) -> None:
        """Initialize adaptive rate limiter.

        Args:
            config: Base rate limiter configuration
            header_config: Optional header-name overrides keyed by
                "requests_limit", "requests_remaining" and "requests_reset".
                The first two default to "x-ratelimit-limit-requests" and
                "x-ratelimit-remaining-requests". The reset header has no
                default, so the rate is only adapted when it is configured.
        """
        super().__init__(config)
        self._header_config = header_config or {}

        # Last values reported by the server (None until seen).
        self._server_limit: int | None = None
        self._server_remaining: int | None = None
        self._server_reset: float | None = None

    @staticmethod
    def _parse_reset_seconds(raw: str) -> float:
        """Convert a reset header value into seconds.

        Accepts plain numbers ("1.5") and duration strings made of
        ``<number><unit>`` parts with units ms/s/m/h, e.g. "20ms", "1s",
        "1m" or "6m0s".

        Raises:
            ValueError: If the value cannot be parsed.
        """
        import re

        text = raw.strip()
        try:
            return float(text)
        except ValueError:
            pass

        units = {"ms": 0.001, "s": 1.0, "m": 60.0, "h": 3600.0}
        # "ms" is listed before "m" so "20ms" is not read as 20 minutes.
        parts = re.findall(r"(\d+(?:\.\d+)?)(ms|h|m|s)", text)
        if not parts or "".join(num + unit for num, unit in parts) != text:
            raise ValueError(f"unparseable reset value: {raw!r}")
        return sum(float(num) * units[unit] for num, unit in parts)

    def update_from_headers(self, headers: dict[str, str]) -> None:
        """Update rate limit state from response headers.

        Header names are matched case-insensitively. A remaining count also
        overwrites the current token count. When both a positive limit and a
        positive reset delay are known, the refill rate becomes
        ``limit / reset`` and the bucket capacity becomes ``limit``.

        Args:
            headers: Response headers
        """
        # HTTP field names are case-insensitive; plain dicts are not.
        lowered = {str(name).lower(): value for name, value in headers.items()}

        # Extract limit
        limit_header = self._header_config.get(
            "requests_limit", "x-ratelimit-limit-requests"
        )
        limit_raw = lowered.get(limit_header.lower())
        if limit_raw is not None:
            try:
                self._server_limit = int(limit_raw)
            except ValueError:
                pass

        # Extract remaining
        remaining_header = self._header_config.get(
            "requests_remaining", "x-ratelimit-remaining-requests"
        )
        remaining_raw = lowered.get(remaining_header.lower())
        if remaining_raw is not None:
            try:
                self._server_remaining = int(remaining_raw)
            except ValueError:
                pass
            else:
                # Update tokens to match server state
                self._tokens = float(self._server_remaining)

        # Extract reset time (only when a header name is configured)
        reset_header = self._header_config.get("requests_reset")
        if reset_header:
            reset_raw = lowered.get(reset_header.lower())
            if reset_raw is not None:
                try:
                    self._server_reset = self._parse_reset_seconds(str(reset_raw))
                except ValueError:
                    pass

        # Adjust rate based on server limit. A zero limit is skipped: a rate
        # of 0 means "unlimited" to RateLimiter, the opposite of what the
        # server reported.
        if (
            self._server_limit is not None
            and self._server_limit > 0
            and self._server_reset is not None
            and self._server_reset > 0
        ):
            self._rate = self._server_limit / self._server_reset
            self._max_tokens = float(self._server_limit)

    def get_server_state(self) -> dict[str, Any]:
        """Get current server-reported rate limit state.

        Returns:
            Dict with limit, remaining, and reset values
        """
        return {
            "limit": self._server_limit,
            "remaining": self._server_remaining,
            "reset": self._server_reset,
        }
|