mcal-ai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,372 @@
+ """
+ Rate limiting utilities for LLM API calls (Issue #39).
+
+ Provides a token bucket rate limiter to prevent excessive API costs
+ and ensure compliance with provider rate limits.
+ """
+
+ import asyncio
+ import logging
+ import time
+ from dataclasses import dataclass, field
+ from datetime import datetime, timezone
+ from typing import Optional, Dict, Any
+ from collections import deque
+
+ logger = logging.getLogger(__name__)
+
+
+ def _utc_now() -> datetime:
+     """Get current UTC time (timezone-aware)."""
+     return datetime.now(timezone.utc)
+
+
+ # =============================================================================
+ # Cost Estimation
+ # =============================================================================
+
+ @dataclass
+ class ModelPricing:
+     """Pricing information for a model (per 1M tokens)."""
+     input_cost_per_1m: float  # Cost per 1M input tokens
+     output_cost_per_1m: float  # Cost per 1M output tokens
+
+     def estimate_cost(self, input_tokens: int, output_tokens: int) -> float:
+         """Estimate cost for given token counts."""
+         input_cost = (input_tokens / 1_000_000) * self.input_cost_per_1m
+         output_cost = (output_tokens / 1_000_000) * self.output_cost_per_1m
+         return input_cost + output_cost
+
+
+ # Bedrock pricing (as of 2024 - update as needed)
+ # https://aws.amazon.com/bedrock/pricing/
+ BEDROCK_PRICING: Dict[str, ModelPricing] = {
+     # Llama 3.1 8B - Fast tier
+     "llama-3.1-8b": ModelPricing(input_cost_per_1m=0.22, output_cost_per_1m=0.22),
+     "llama-3.2-3b": ModelPricing(input_cost_per_1m=0.15, output_cost_per_1m=0.15),
+     # Llama 3.3 70B - Smart tier
+     "llama-3.3-70b": ModelPricing(input_cost_per_1m=0.72, output_cost_per_1m=0.72),
+     "llama-3.1-70b": ModelPricing(input_cost_per_1m=0.72, output_cost_per_1m=0.72),
+     "llama-4-maverick": ModelPricing(input_cost_per_1m=0.50, output_cost_per_1m=1.50),
+ }
+
+ # OpenAI pricing
+ OPENAI_PRICING: Dict[str, ModelPricing] = {
+     "gpt-4o": ModelPricing(input_cost_per_1m=2.50, output_cost_per_1m=10.00),
+     "gpt-4o-mini": ModelPricing(input_cost_per_1m=0.15, output_cost_per_1m=0.60),
+     "gpt-4-turbo": ModelPricing(input_cost_per_1m=10.00, output_cost_per_1m=30.00),
+ }
+
+ # Anthropic pricing
+ ANTHROPIC_PRICING: Dict[str, ModelPricing] = {
+     "claude-sonnet-4-20250514": ModelPricing(input_cost_per_1m=3.00, output_cost_per_1m=15.00),
+     "claude-3-5-sonnet-20241022": ModelPricing(input_cost_per_1m=3.00, output_cost_per_1m=15.00),
+     "claude-3-haiku-20240307": ModelPricing(input_cost_per_1m=0.25, output_cost_per_1m=1.25),
+ }
+
+
+ def get_pricing(model: str, provider: str = "bedrock") -> Optional[ModelPricing]:
+     """Get pricing for a model."""
+     pricing_maps = {
+         "bedrock": BEDROCK_PRICING,
+         "openai": OPENAI_PRICING,
+         "anthropic": ANTHROPIC_PRICING,
+     }
+     return pricing_maps.get(provider, {}).get(model)
+
+
+ # =============================================================================
+ # Token Bucket Rate Limiter
+ # =============================================================================
+
+ @dataclass
+ class RateLimitConfig:
+     """Configuration for rate limiting."""
+     requests_per_minute: int = 60  # RPM limit
+     tokens_per_minute: int = 100_000  # TPM limit
+     max_cost_per_hour: Optional[float] = None  # Cost limit (USD)
+     enable_cost_logging: bool = True  # Log cost estimates
+     warn_at_cost_percent: float = 0.8  # Warn when this % of hourly limit used
+
+
+ @dataclass
+ class RateLimitStats:
+     """Statistics from rate limiter."""
+     total_requests: int = 0
+     total_input_tokens: int = 0
+     total_output_tokens: int = 0
+     total_estimated_cost: float = 0.0
+     requests_delayed: int = 0
+     total_delay_seconds: float = 0.0
+     start_time: datetime = field(default_factory=_utc_now)
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary."""
+         runtime = (_utc_now() - self.start_time).total_seconds()
+         return {
+             "total_requests": self.total_requests,
+             "total_input_tokens": self.total_input_tokens,
+             "total_output_tokens": self.total_output_tokens,
+             "total_tokens": self.total_input_tokens + self.total_output_tokens,
+             "total_estimated_cost_usd": round(self.total_estimated_cost, 4),
+             "requests_delayed": self.requests_delayed,
+             "total_delay_seconds": round(self.total_delay_seconds, 2),
+             "runtime_seconds": round(runtime, 2),
+             "avg_requests_per_minute": round(self.total_requests / (runtime / 60), 2) if runtime > 0 else 0,
+         }
+
+
+ class TokenBucketRateLimiter:
+     """
+     Token bucket rate limiter for LLM API calls.
+
+     Implements both RPM (requests per minute) and TPM (tokens per minute)
+     limits using the token bucket algorithm with smooth refilling.
+
+     Example:
+         limiter = TokenBucketRateLimiter(
+             config=RateLimitConfig(
+                 requests_per_minute=60,
+                 tokens_per_minute=100_000,
+                 max_cost_per_hour=10.0
+             ),
+             model="llama-3.3-70b",
+             provider="bedrock"
+         )
+
+         # Before each API call
+         await limiter.acquire(estimated_tokens=1000)
+
+         # After each API call
+         limiter.record_usage(input_tokens=500, output_tokens=200)
+     """
+
+     def __init__(
+         self,
+         config: Optional[RateLimitConfig] = None,
+         model: str = "llama-3.3-70b",
+         provider: str = "bedrock",
+     ):
+         self.config = config or RateLimitConfig()
+         self.model = model
+         self.provider = provider
+         self.pricing = get_pricing(model, provider)
+
+         # Token buckets
+         self._request_bucket = float(self.config.requests_per_minute)
+         self._token_bucket = float(self.config.tokens_per_minute)
+         self._last_refill = time.monotonic()
+
+         # Cost tracking (sliding window for hourly limit)
+         self._hourly_costs: deque = deque()  # (timestamp, cost) pairs
+
+         # Lock to serialize concurrent acquire() calls (asyncio coroutines, not OS threads)
+         self._lock = asyncio.Lock()
+
+         # Statistics
+         self.stats = RateLimitStats()
+
+     def _refill_buckets(self) -> None:
+         """Refill token buckets based on elapsed time."""
+         now = time.monotonic()
+         elapsed = now - self._last_refill
+         self._last_refill = now
+
+         # Refill at rate of limit per minute
+         request_refill = (elapsed / 60.0) * self.config.requests_per_minute
+         token_refill = (elapsed / 60.0) * self.config.tokens_per_minute
+
+         self._request_bucket = min(
+             self._request_bucket + request_refill,
+             float(self.config.requests_per_minute)
+         )
+         self._token_bucket = min(
+             self._token_bucket + token_refill,
+             float(self.config.tokens_per_minute)
+         )
+
+     def _get_hourly_cost(self) -> float:
+         """Get total cost in the last hour."""
+         now = time.time()
+         hour_ago = now - 3600
+
+         # Remove old entries
+         while self._hourly_costs and self._hourly_costs[0][0] < hour_ago:
+             self._hourly_costs.popleft()
+
+         return sum(cost for _, cost in self._hourly_costs)
+
+     def _check_cost_limit(self, estimated_cost: float) -> bool:
+         """Check if we're within the hourly cost limit."""
+         if self.config.max_cost_per_hour is None:
+             return True
+
+         current_hourly = self._get_hourly_cost()
+         projected = current_hourly + estimated_cost
+
+         # Warn if approaching limit
+         if self.config.enable_cost_logging:
+             usage_percent = current_hourly / self.config.max_cost_per_hour
+             if usage_percent >= self.config.warn_at_cost_percent:
+                 logger.warning(
+                     f"Rate limiter: {usage_percent:.1%} of hourly cost limit used "
+                     f"(${current_hourly:.4f} / ${self.config.max_cost_per_hour:.2f})"
+                 )
+
+         return projected <= self.config.max_cost_per_hour
+
+     async def acquire(self, estimated_tokens: int = 1000) -> float:
+         """
+         Acquire permission to make an API call.
+
+         Blocks until rate limits allow the request, or raises if cost limit exceeded.
+
+         Args:
+             estimated_tokens: Estimated total tokens for the request
+
+         Returns:
+             Delay in seconds that was waited (0 if no delay)
+
+         Raises:
+             RuntimeError: If hourly cost limit would be exceeded
+         """
+         async with self._lock:
+             total_delay = 0.0
+
+             while True:
+                 self._refill_buckets()
+
+                 # Check if we have capacity
+                 if self._request_bucket >= 1.0 and self._token_bucket >= estimated_tokens:
+                     # Check cost limit
+                     if self.pricing:
+                         # Estimate cost (assume 50/50 input/output split for estimation)
+                         estimated_cost = self.pricing.estimate_cost(
+                             estimated_tokens // 2,
+                             estimated_tokens // 2
+                         )
+                         if not self._check_cost_limit(estimated_cost):
+                             raise RuntimeError(
+                                 f"Hourly cost limit (${self.config.max_cost_per_hour:.2f}) "
+                                 f"would be exceeded. Current: ${self._get_hourly_cost():.4f}"
+                             )
+
+                     # Consume from buckets
+                     self._request_bucket -= 1.0
+                     self._token_bucket -= estimated_tokens
+
+                     if total_delay > 0:
+                         self.stats.requests_delayed += 1
+                         self.stats.total_delay_seconds += total_delay
+                         logger.debug(f"Rate limiter: delayed {total_delay:.2f}s")
+
+                     return total_delay
+
+                 # Calculate wait time
+                 wait_for_request = 0.0
+                 wait_for_tokens = 0.0
+
+                 if self._request_bucket < 1.0:
+                     # Time to refill 1 request
+                     wait_for_request = (1.0 - self._request_bucket) * 60.0 / self.config.requests_per_minute
+
+                 if self._token_bucket < estimated_tokens:
+                     # Time to refill needed tokens
+                     needed = estimated_tokens - self._token_bucket
+                     wait_for_tokens = needed * 60.0 / self.config.tokens_per_minute
+
+                 wait_time = max(wait_for_request, wait_for_tokens, 0.1)  # Min 100ms
+                 wait_time = min(wait_time, 60.0)  # Max 60s
+
+                 logger.debug(
+                     f"Rate limiter: waiting {wait_time:.2f}s "
+                     f"(requests: {self._request_bucket:.1f}, tokens: {self._token_bucket:.0f})"
+                 )
+
+                 # Release lock while waiting
+                 self._lock.release()
+                 try:
+                     await asyncio.sleep(wait_time)
+                     total_delay += wait_time
+                 finally:
+                     await self._lock.acquire()
+
+     def record_usage(self, input_tokens: int, output_tokens: int) -> None:
+         """
+         Record actual token usage after an API call.
+
+         Args:
+             input_tokens: Actual input tokens used
+             output_tokens: Actual output tokens used
+         """
+         self.stats.total_requests += 1
+         self.stats.total_input_tokens += input_tokens
+         self.stats.total_output_tokens += output_tokens
+
+         # Calculate and record cost
+         if self.pricing:
+             cost = self.pricing.estimate_cost(input_tokens, output_tokens)
+             self.stats.total_estimated_cost += cost
+             self._hourly_costs.append((time.time(), cost))
+
+             if self.config.enable_cost_logging:
+                 logger.info(
+                     f"LLM call: {input_tokens} in + {output_tokens} out tokens, "
+                     f"est. cost: ${cost:.6f}, total: ${self.stats.total_estimated_cost:.4f}"
+                 )
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get rate limiter statistics."""
+         stats = self.stats.to_dict()
+         stats["model"] = self.model
+         stats["provider"] = self.provider
+         stats["config"] = {
+             "requests_per_minute": self.config.requests_per_minute,
+             "tokens_per_minute": self.config.tokens_per_minute,
+             "max_cost_per_hour": self.config.max_cost_per_hour,
+         }
+         stats["current_buckets"] = {
+             "requests": round(self._request_bucket, 2),
+             "tokens": round(self._token_bucket, 0),
+         }
+         return stats
+
+     def reset_stats(self) -> None:
+         """Reset statistics."""
+         self.stats = RateLimitStats()
+         self._hourly_costs.clear()
+
+
+ # =============================================================================
+ # Convenience Functions
+ # =============================================================================
+
+ def create_rate_limiter(
+     model: str = "llama-3.3-70b",
+     provider: str = "bedrock",
+     rpm: int = 60,
+     tpm: int = 100_000,
+     max_cost_per_hour: Optional[float] = None,
+     enable_cost_logging: bool = True,
+ ) -> TokenBucketRateLimiter:
+     """
+     Create a rate limiter with the specified configuration.
+
+     Args:
+         model: Model name for cost estimation
+         provider: Provider name ("bedrock", "openai", "anthropic")
+         rpm: Requests per minute limit
+         tpm: Tokens per minute limit
+         max_cost_per_hour: Maximum cost per hour in USD (None = unlimited)
+         enable_cost_logging: Log cost estimates for each call
+
+     Returns:
+         Configured TokenBucketRateLimiter
+     """
+     config = RateLimitConfig(
+         requests_per_minute=rpm,
+         tokens_per_minute=tpm,
+         max_cost_per_hour=max_cost_per_hour,
+         enable_cost_logging=enable_cost_logging,
+     )
+     return TokenBucketRateLimiter(config=config, model=model, provider=provider)
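
Below is a minimal usage sketch, not part of the packaged file above, assembled from the module's own docstrings. The import path, cost budget, and token counts are illustrative assumptions; only create_rate_limiter, acquire, record_usage, and get_stats come from the code in the diff.

    import asyncio

    # Assumed import path -- the diff does not show where this module sits inside the mcal_ai package.
    from mcal_ai.rate_limiting import create_rate_limiter


    async def main() -> None:
        limiter = create_rate_limiter(
            model="llama-3.3-70b",
            provider="bedrock",
            rpm=60,
            tpm=100_000,
            max_cost_per_hour=5.0,  # assumed budget; None disables the hourly cost cap
        )

        # Before the LLM call: block until the RPM/TPM buckets have capacity.
        # Raises RuntimeError if the hourly cost cap would be exceeded.
        delay = await limiter.acquire(estimated_tokens=1_500)

        # ... make the provider API call here ...

        # After the call: report actual usage so cost tracking and the hourly window stay accurate.
        limiter.record_usage(input_tokens=900, output_tokens=400)

        print(f"throttled for {delay:.2f}s")
        print(limiter.get_stats())


    asyncio.run(main())

Note that acquire() only needs a rough token estimate to debit the buckets and pre-check the cost cap; the hourly cost window and the statistics are driven by the actual counts passed to record_usage().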