flashlite 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flashlite/__init__.py +169 -0
- flashlite/cache/__init__.py +14 -0
- flashlite/cache/base.py +194 -0
- flashlite/cache/disk.py +285 -0
- flashlite/cache/memory.py +157 -0
- flashlite/client.py +671 -0
- flashlite/config.py +154 -0
- flashlite/conversation/__init__.py +30 -0
- flashlite/conversation/context.py +319 -0
- flashlite/conversation/manager.py +385 -0
- flashlite/conversation/multi_agent.py +378 -0
- flashlite/core/__init__.py +13 -0
- flashlite/core/completion.py +145 -0
- flashlite/core/messages.py +130 -0
- flashlite/middleware/__init__.py +18 -0
- flashlite/middleware/base.py +90 -0
- flashlite/middleware/cache.py +121 -0
- flashlite/middleware/logging.py +159 -0
- flashlite/middleware/rate_limit.py +211 -0
- flashlite/middleware/retry.py +149 -0
- flashlite/observability/__init__.py +34 -0
- flashlite/observability/callbacks.py +155 -0
- flashlite/observability/inspect_compat.py +266 -0
- flashlite/observability/logging.py +293 -0
- flashlite/observability/metrics.py +221 -0
- flashlite/py.typed +0 -0
- flashlite/structured/__init__.py +31 -0
- flashlite/structured/outputs.py +189 -0
- flashlite/structured/schema.py +165 -0
- flashlite/templating/__init__.py +11 -0
- flashlite/templating/engine.py +217 -0
- flashlite/templating/filters.py +143 -0
- flashlite/templating/registry.py +165 -0
- flashlite/tools/__init__.py +74 -0
- flashlite/tools/definitions.py +382 -0
- flashlite/tools/execution.py +353 -0
- flashlite/types.py +233 -0
- flashlite-0.1.0.dist-info/METADATA +173 -0
- flashlite-0.1.0.dist-info/RECORD +41 -0
- flashlite-0.1.0.dist-info/WHEEL +4 -0
- flashlite-0.1.0.dist-info/licenses/LICENSE.md +21 -0
flashlite/middleware/base.py
@@ -0,0 +1,90 @@
"""Base middleware protocol and chain implementation."""

from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable

from ..types import CompletionRequest, CompletionResponse

# Type alias for the completion handler
CompletionHandler = Callable[[CompletionRequest], Awaitable[CompletionResponse]]


class Middleware(ABC):
    """
    Abstract base class for middleware.

    Middleware can intercept requests before they're sent and responses
    after they're received. They form a chain where each middleware
    calls the next one (or short-circuits).
    """

    @abstractmethod
    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: CompletionHandler,
    ) -> CompletionResponse:
        """
        Process a request.

        Args:
            request: The completion request
            next_handler: The next middleware or final handler to call

        Returns:
            The completion response
        """
        pass
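
# Editor's note: a minimal sketch of a concrete middleware, added here for
# illustration only; it is not part of base.py. It relies on nothing beyond
# the Middleware base class and the CompletionHandler alias defined above.
class TimingMiddleware(Middleware):
    """Example: measure how long the rest of the chain takes."""

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: CompletionHandler,
    ) -> CompletionResponse:
        import time

        start = time.perf_counter()
        response = await next_handler(request)  # delegate down the chain
        print(f"chain completed in {time.perf_counter() - start:.3f}s")
        return response
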

class MiddlewareChain:
    """
    Chains multiple middleware together.

    Middleware are executed in order, with each one wrapping the next.
    The final handler (the actual API call) is at the end of the chain.
    """

    def __init__(
        self,
        middleware: list[Middleware],
        final_handler: CompletionHandler,
    ):
        """
        Initialize the middleware chain.

        Args:
            middleware: List of middleware to apply (in order)
            final_handler: The final handler (actual completion call)
        """
        self._middleware = middleware
        self._final_handler = final_handler

    async def __call__(self, request: CompletionRequest) -> CompletionResponse:
        """Execute the middleware chain."""
        return await self._execute(request, 0)

    async def _execute(self, request: CompletionRequest, index: int) -> CompletionResponse:
        """Recursively execute the middleware chain."""
        if index >= len(self._middleware):
            # No more middleware; call the final handler
            return await self._final_handler(request)

        # Get the current middleware and create the next handler
        current = self._middleware[index]

        async def next_handler(req: CompletionRequest) -> CompletionResponse:
            return await self._execute(req, index + 1)

        return await current(request, next_handler)
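
# Editor's note: an illustrative composition of the chain, not part of
# base.py. `call_api` stands in for a hypothetical final handler; only
# MiddlewareChain and the Middleware protocol come from this file.
async def _example_compose_chain(
    call_api: CompletionHandler,
    request: CompletionRequest,
) -> CompletionResponse:
    chain = MiddlewareChain(
        middleware=[TimingMiddleware()],  # outermost first; each wraps the next
        final_handler=call_api,  # innermost: the actual completion call
    )
    return await chain(request)
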

class PassthroughMiddleware(Middleware):
    """A middleware that does nothing - useful for testing."""

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: CompletionHandler,
    ) -> CompletionResponse:
        return await next_handler(request)

flashlite/middleware/cache.py
@@ -0,0 +1,121 @@
"""Cache middleware for flashlite."""

import logging
from collections.abc import Awaitable, Callable

from ..cache.base import CacheBackend, generate_cache_key, is_cacheable_request
from ..types import CompletionRequest, CompletionResponse
from .base import Middleware

logger = logging.getLogger(__name__)


class CacheMiddleware(Middleware):
    """
    Middleware that caches completion responses.

    Caches responses based on a hash of the request parameters.
    Emits warnings when caching is used with non-deterministic settings
    (temperature > 0 or reasoning models).

    Example:
        cache = MemoryCache(max_size=1000)
        middleware = CacheMiddleware(cache)
    """

    def __init__(
        self,
        backend: CacheBackend,
        ttl: float | None = None,
        warn_non_deterministic: bool = True,
    ):
        """
        Initialize the cache middleware.

        Args:
            backend: The cache backend to use
            ttl: Default TTL for cached entries (seconds)
            warn_non_deterministic: Whether to warn about non-deterministic caching
        """
        self._backend = backend
        self._ttl = ttl
        self._warn_non_deterministic = warn_non_deterministic
        self._warned_keys: set[str] = set()  # Track warnings to avoid spam

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: Callable[[CompletionRequest], Awaitable[CompletionResponse]],
    ) -> CompletionResponse:
        """Process request with caching."""
        # Generate cache key
        cache_key = generate_cache_key(request)

        # Check if the request is suitable for caching and emit warnings
        if self._warn_non_deterministic:
            _, warning = is_cacheable_request(request)
            if warning and cache_key not in self._warned_keys:
                logger.warning(
                    f"Caching enabled but request may be non-deterministic: {warning}. "
                    "Consider disabling cache for this request with force_refresh=True, "
                    "or set temperature=0 for deterministic outputs."
                )
                self._warned_keys.add(cache_key)

        # Try to get from cache
        cached_response = await self._backend.get(cache_key)
        if cached_response is not None:
            logger.debug(f"Cache hit for key {cache_key[:16]}...")
            return cached_response

        # Cache miss - call the next handler
        logger.debug(f"Cache miss for key {cache_key[:16]}...")
        response = await next_handler(request)

        # Store in cache
        await self._backend.set(cache_key, response, self._ttl)

        return response

    @property
    def backend(self) -> CacheBackend:
        """Get the cache backend."""
        return self._backend
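
# Editor's note: an illustrative wiring of the cache middleware, not part of
# cache.py. `call_api` is a hypothetical final handler, and the MemoryCache
# import path is an assumption based on the file list above.
async def _example_cached_call(
    call_api: Callable[[CompletionRequest], Awaitable[CompletionResponse]],
    request: CompletionRequest,
) -> CompletionResponse:
    from ..cache.memory import MemoryCache  # assumed location of MemoryCache

    middleware = CacheMiddleware(MemoryCache(max_size=1000), ttl=3600.0)
    # A repeated, identical request within the TTL is served from the cache.
    return await middleware(request, call_api)
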

class CacheConfig:
    """
    Configuration for caching behavior.

    Note: Caching is disabled by default. When enabled, warnings are emitted
    for non-deterministic requests (temperature > 0 or reasoning models).
    """

    def __init__(
        self,
        enabled: bool = False,
        backend: CacheBackend | None = None,
        ttl: float | None = None,
        warn_non_deterministic: bool = True,
    ):
        """
        Initialize cache configuration.

        Args:
            enabled: Whether caching is enabled (default: False)
            backend: The cache backend to use
            ttl: Default TTL for cached entries (seconds)
            warn_non_deterministic: Whether to warn about non-deterministic caching
        """
        self.enabled = enabled
        self.backend = backend
        self.ttl = ttl
        self.warn_non_deterministic = warn_non_deterministic

        # Emit info message about caching status
        if not enabled:
            logger.info(
                "Caching is disabled by default. To enable, pass "
                "cache=CacheConfig(enabled=True, backend=...) or "
                "cache=MemoryCache(...) to the Flashlite client."
            )
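
For reference, a sketch of enabling caching via CacheConfig. This is the editor's illustration, not code from the package: the top-level Flashlite client and its cache keyword are assumptions inferred from the log-message hint above and the file list (flashlite/client.py), and the MemoryCache module path is likewise assumed.

from flashlite import Flashlite  # assumed top-level export
from flashlite.cache.memory import MemoryCache  # assumed module path
from flashlite.middleware.cache import CacheConfig

config = CacheConfig(
    enabled=True,
    backend=MemoryCache(max_size=1000),
    ttl=3600.0,  # seconds
    warn_non_deterministic=True,
)
client = Flashlite(cache=config)  # keyword per the log-message hint
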
flashlite/middleware/logging.py
@@ -0,0 +1,159 @@
"""Logging middleware for flashlite."""

import logging
import time
import uuid
from collections.abc import Awaitable, Callable

from ..observability.callbacks import CallbackManager
from ..observability.logging import StructuredLogger
from ..observability.metrics import CostTracker
from ..types import CompletionRequest, CompletionResponse
from .base import Middleware

logger = logging.getLogger(__name__)


class LoggingMiddleware(Middleware):
    """
    Middleware that logs requests and responses.

    Supports structured logging to files, callback-based logging,
    and cost tracking.

    Example:
        structured_logger = StructuredLogger(log_file="./logs/completions.jsonl")
        cost_tracker = CostTracker(budget_limit=10.0)

        middleware = LoggingMiddleware(
            structured_logger=structured_logger,
            cost_tracker=cost_tracker,
            log_level="INFO",
        )
    """

    def __init__(
        self,
        structured_logger: StructuredLogger | None = None,
        cost_tracker: CostTracker | None = None,
        callbacks: CallbackManager | None = None,
        log_level: str = "INFO",
    ):
        """
        Initialize the logging middleware.

        Args:
            structured_logger: Structured logger for file logging
            cost_tracker: Cost tracker for budget monitoring
            callbacks: Callback manager for event hooks
            log_level: Minimum log level for standard logging
        """
        self._structured_logger = structured_logger
        self._cost_tracker = cost_tracker
        self._callbacks = callbacks
        self._log_level = getattr(logging, log_level.upper())

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: Callable[[CompletionRequest], Awaitable[CompletionResponse]],
    ) -> CompletionResponse:
        """Process request with logging."""
        request_id = str(uuid.uuid4())
        start_time = time.perf_counter()

        # Log request
        if self._structured_logger:
            self._structured_logger.log_request(request, request_id)

        if self._callbacks:
            await self._callbacks.emit_request(request, request_id)

        logger.log(
            self._log_level,
            f"[{request_id[:8]}] Starting request: model={request.model}",
        )

        try:
            # Call the next handler
            response = await next_handler(request)

            # Calculate latency
            latency_ms = (time.perf_counter() - start_time) * 1000

            # Log response
            if self._structured_logger:
                self._structured_logger.log_response(
                    response, request_id, latency_ms, cached=False
                )

            if self._callbacks:
                await self._callbacks.emit_response(
                    response, request_id, latency_ms, cached=False
                )

            # Track cost
            if self._cost_tracker:
                cost = self._cost_tracker.track(response)
                logger.debug(
                    f"[{request_id[:8]}] Cost: ${cost:.6f}, "
                    f"Total: ${self._cost_tracker.total_cost:.4f}"
                )

            logger.log(
                self._log_level,
                f"[{request_id[:8]}] Completed: {latency_ms:.1f}ms, "
                f"tokens={response.usage.total_tokens if response.usage else 'N/A'}",
            )

            return response

        except Exception as e:
            latency_ms = (time.perf_counter() - start_time) * 1000

            if self._structured_logger:
                self._structured_logger.log_error(request_id, e, latency_ms)

            if self._callbacks:
                await self._callbacks.emit_error(e, request_id, latency_ms)

            logger.error(
                f"[{request_id[:8]}] Error after {latency_ms:.1f}ms: {e}",
            )
            raise
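
# Editor's note: an illustrative sketch (not part of logging.py) showing this
# middleware in a chain. `call_api` is a hypothetical final handler; the
# MiddlewareChain import follows this package's own layout.
async def _example_logged_call(
    call_api: Callable[[CompletionRequest], Awaitable[CompletionResponse]],
    request: CompletionRequest,
) -> CompletionResponse:
    from .base import MiddlewareChain  # same module as Middleware above

    tracker = CostTracker(budget_limit=10.0)
    chain = MiddlewareChain(
        middleware=[LoggingMiddleware(cost_tracker=tracker)],
        final_handler=call_api,
    )
    response = await chain(request)
    print(f"total spend so far: ${tracker.total_cost:.4f}")
    return response
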

class CostTrackingMiddleware(Middleware):
    """
    Lightweight middleware that only tracks costs.

    Use this when you want cost tracking without full logging.

    Example:
        tracker = CostTracker(budget_limit=10.0)
        middleware = CostTrackingMiddleware(tracker)
    """

    def __init__(self, cost_tracker: CostTracker):
        """
        Initialize the cost tracking middleware.

        Args:
            cost_tracker: The cost tracker to use
        """
        self._cost_tracker = cost_tracker

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: Callable[[CompletionRequest], Awaitable[CompletionResponse]],
    ) -> CompletionResponse:
        """Process request with cost tracking."""
        response = await next_handler(request)
        self._cost_tracker.track(response)
        return response

    @property
    def tracker(self) -> CostTracker:
        """Get the cost tracker."""
        return self._cost_tracker
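
As a quick illustration of the lightweight variant (editor's sketch, not package code): the `call_api` handler and request list are hypothetical, while CostTracker, CostTrackingMiddleware, and the total_cost attribute all appear in the files above.

from flashlite.middleware.logging import CostTrackingMiddleware
from flashlite.observability.metrics import CostTracker

async def track_costs(call_api, requests):
    tracker = CostTracker(budget_limit=10.0)
    middleware = CostTrackingMiddleware(tracker)
    for request in requests:
        await middleware(request, call_api)  # each response is priced by the tracker
    return tracker.total_cost
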
flashlite/middleware/rate_limit.py
@@ -0,0 +1,211 @@
"""Rate limiting middleware using the token bucket algorithm."""

import asyncio
import logging
import time
from dataclasses import dataclass, field

from ..types import CompletionRequest, CompletionResponse, RateLimitConfig, RateLimitError
from .base import CompletionHandler, Middleware

logger = logging.getLogger(__name__)


@dataclass
class TokenBucket:
    """
    Token bucket rate limiter.

    Tokens are added at a constant rate up to a maximum capacity.
    Each request consumes one or more tokens. If not enough tokens
    are available, the request waits.
    """

    rate: float  # Tokens added per second
    capacity: float  # Maximum tokens in bucket
    tokens: float = field(init=False)
    last_update: float = field(init=False)
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False)

    def __post_init__(self) -> None:
        self.tokens = self.capacity
        self.last_update = time.monotonic()

    def _refill(self) -> None:
        """Refill tokens based on elapsed time."""
        now = time.monotonic()
        elapsed = now - self.last_update
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now

    async def acquire(self, tokens: float = 1.0, timeout: float | None = None) -> float:
        """
        Acquire tokens from the bucket.

        Args:
            tokens: Number of tokens to acquire
            timeout: Maximum time to wait (None = wait forever)

        Returns:
            Time waited in seconds

        Raises:
            RateLimitError: If the timeout is exceeded
        """
        start_time = time.monotonic()
        # Note: check against None so that timeout=0 fails fast instead of
        # being treated as "wait forever".
        deadline = start_time + timeout if timeout is not None else None

        async with self._lock:
            while True:
                self._refill()

                if self.tokens >= tokens:
                    self.tokens -= tokens
                    return time.monotonic() - start_time

                # Calculate wait time for enough tokens
                tokens_needed = tokens - self.tokens
                wait_time = tokens_needed / self.rate

                # Check timeout
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        raise RateLimitError(
                            f"Rate limit timeout after {timeout}s",
                            retry_after=wait_time,
                        )
                    wait_time = min(wait_time, remaining)

                # Release the lock while waiting so other tasks can proceed
                self._lock.release()
                try:
                    await asyncio.sleep(wait_time)
                finally:
                    await self._lock.acquire()

    @property
    def available_tokens(self) -> float:
        """Get currently available tokens (without acquiring the lock)."""
        elapsed = time.monotonic() - self.last_update
        return min(self.capacity, self.tokens + elapsed * self.rate)
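
# Editor's note: a worked illustration of the bucket math, not part of
# rate_limit.py. At 120 requests/minute the refill rate is 120/60 = 2
# tokens/second, so after draining a burst of 5 the next acquire waits ~0.5s.
async def _example_token_bucket() -> None:
    bucket = TokenBucket(rate=2.0, capacity=5.0)  # starts full at 5 tokens
    for _ in range(5):
        await bucket.acquire()  # burst: consumes the full capacity
    waited = await bucket.acquire()  # must wait ~0.5s for 1 token to refill
    print(f"waited {waited:.2f}s, {bucket.available_tokens:.1f} tokens left")
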

class RateLimitMiddleware(Middleware):
    """
    Middleware that enforces rate limits using the token bucket algorithm.

    Supports both requests-per-minute (RPM) and tokens-per-minute (TPM) limits.
    """

    def __init__(self, config: RateLimitConfig | None = None):
        """
        Initialize rate limit middleware.

        Args:
            config: Rate limit configuration
        """
        self.config = config or RateLimitConfig()
        self._rpm_bucket: TokenBucket | None = None
        self._tpm_bucket: TokenBucket | None = None
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Lazily initialize the buckets."""
        if self._initialized:
            return

        if self.config.requests_per_minute:
            # Convert RPM to requests per second
            rate = self.config.requests_per_minute / 60.0
            # Capacity allows small bursts (10% of the per-minute rate)
            capacity = max(1.0, self.config.requests_per_minute * 0.1)
            self._rpm_bucket = TokenBucket(rate=rate, capacity=capacity)
            logger.debug(
                f"Rate limiter initialized: {self.config.requests_per_minute} RPM "
                f"(rate={rate:.2f}/s, capacity={capacity:.1f})"
            )

        if self.config.tokens_per_minute:
            rate = self.config.tokens_per_minute / 60.0
            capacity = max(1000.0, self.config.tokens_per_minute * 0.1)
            self._tpm_bucket = TokenBucket(rate=rate, capacity=capacity)
            logger.debug(
                f"Token rate limiter initialized: {self.config.tokens_per_minute} TPM"
            )

        self._initialized = True

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: CompletionHandler,
    ) -> CompletionResponse:
        """Execute with rate limiting."""
        self._ensure_initialized()

        # Acquire an RPM token before making the request
        if self._rpm_bucket:
            wait_time = await self._rpm_bucket.acquire()
            if wait_time > 0.1:  # Only log significant waits
                logger.debug(f"Rate limit: waited {wait_time:.2f}s for RPM token")

        # Make the request
        response = await next_handler(request)

        # For TPM limiting, consume tokens based on actual usage.
        # This is post-hoc: we can't know the token count before the request.
        if self._tpm_bucket and response.usage:
            total_tokens = response.usage.total_tokens
            if total_tokens > 0:
                # Record the usage; if the bucket is exhausted this waits,
                # creating backpressure for subsequent requests.
                await self._tpm_bucket.acquire(tokens=float(total_tokens))

        return response

    @property
    def rpm_available(self) -> float | None:
        """Get available RPM tokens."""
        if self._rpm_bucket:
            return self._rpm_bucket.available_tokens
        return None

    @property
    def tpm_available(self) -> float | None:
        """Get available TPM tokens."""
        if self._tpm_bucket:
            return self._tpm_bucket.available_tokens
        return None
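
# Editor's note: an illustrative configuration, not part of rate_limit.py.
# RateLimitConfig's field names are taken from their use in
# _ensure_initialized above; keyword construction is an assumption, and
# `call_api` is a hypothetical final handler.
async def _example_rate_limited_call(
    call_api: CompletionHandler,
    request: CompletionRequest,
) -> CompletionResponse:
    config = RateLimitConfig(requests_per_minute=60, tokens_per_minute=90_000)
    middleware = RateLimitMiddleware(config)
    return await middleware(request, call_api)
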

class ConcurrencyLimiter:
    """
    Limits concurrent requests using a semaphore.

    This is separate from rate limiting - it controls how many
    requests can be in-flight simultaneously.
    """

    def __init__(self, max_concurrency: int):
        """
        Initialize concurrency limiter.

        Args:
            max_concurrency: Maximum concurrent requests
        """
        self.max_concurrency = max_concurrency
        self._semaphore = asyncio.Semaphore(max_concurrency)

    async def __aenter__(self) -> "ConcurrencyLimiter":
        await self._semaphore.acquire()
        return self

    async def __aexit__(self, *args: object) -> None:
        self._semaphore.release()

    @property
    def available_slots(self) -> int:
        """Get number of available concurrency slots."""
        # Semaphore._value is the internal counter
        return self._semaphore._value  # noqa: SLF001
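
Finally, a sketch of the ConcurrencyLimiter's async-context usage (editor's illustration: the `call_api` stub and request list are hypothetical, while the limiter API is exactly as defined above).

import asyncio
from flashlite.middleware.rate_limit import ConcurrencyLimiter

async def run_all(call_api, requests):
    limiter = ConcurrencyLimiter(max_concurrency=5)

    async def one(request):
        async with limiter:  # at most 5 requests in flight at once
            return await call_api(request)

    return await asyncio.gather(*(one(r) for r in requests))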