llm-cost-guard 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,557 @@
1
+ """
2
+ Redis backend for LLM Cost Guard with distributed budget enforcement.
3
+ """
4
+
5
import json
import logging
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

from llm_cost_guard.backends.base import Backend
from llm_cost_guard.models import CostRecord, CostReport, ModelType
12
+
13
logger = logging.getLogger(__name__)

# NOTE: Redis converts Lua numbers in script replies to *integers*, which
# silently truncates fractional costs (0.75 -> 0). The scripts below therefore
# return monetary values as strings (``tostring`` or INCRBYFLOAT's own string
# reply). The Python side parses them with ``float()``. The ``-1`` "rejected"
# sentinel stays a plain integer, so the reply shape is unchanged.

# Lua script for atomic budget check and spend.
# KEYS: [1] spending counter, [2] period key (accepted for compatibility, unused)
# ARGV: [1] amount, [2] limit, [3] period_seconds, [4] warning_threshold
# Reply: {-1, current, limit} when rejected, else {new_total, previous, warning_flag}
BUDGET_CHECK_SCRIPT = """
local budget_key = KEYS[1]
local amount = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local period_seconds = tonumber(ARGV[3])
local warning_threshold = tonumber(ARGV[4])

-- Get current spending (GET yields false for a missing key)
local current = tonumber(redis.call('GET', budget_key) or '0')

-- Reject without mutating anything if we'd exceed the limit
local new_total = current + amount
if new_total > limit then
    return {-1, tostring(current), tostring(limit)}
end

-- Check if we're at warning threshold
local warning = 0
if new_total >= (limit * warning_threshold) then
    warning = 1
end

-- Atomically increment spending; the reply is a full-precision string
local updated = redis.call('INCRBYFLOAT', budget_key, amount)

-- Start the period clock on first spend (TTL -1: key exists, no expiry)
if redis.call('TTL', budget_key) == -1 then
    redis.call('EXPIRE', budget_key, period_seconds)
end

return {updated, tostring(current), warning}
"""

# Lua script for atomic budget reservation (pessimistic).
# KEYS: [1] spending counter, [2] reservation total, [3] optional per-reservation
#       record key (defaults to 'reservation:<id>')
# ARGV: [1] amount, [2] limit, [3] reservation_id, [4] period_seconds (unused)
# Reply: {-1, effective, limit} when rejected, else {new_effective, effective, 0}
BUDGET_RESERVE_SCRIPT = """
local budget_key = KEYS[1]
local reservation_key = KEYS[2]
local amount = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
local reservation_id = ARGV[3]
local record_key = KEYS[3] or ('reservation:' .. reservation_id)

-- Get current spending + active reservations
local current = tonumber(redis.call('GET', budget_key) or '0')
local reserved = tonumber(redis.call('GET', reservation_key) or '0')

-- Heal a reservation total driven negative by an unmatched release; a
-- negative total would otherwise permanently loosen the budget.
if reserved < 0 then
    redis.call('DEL', reservation_key)
    reserved = 0
end
local effective = current + reserved

-- Check if we'd exceed the limit
if effective + amount > limit then
    return {-1, tostring(effective), tostring(limit)}
end

-- Add to reservations (5 minute hold, refreshed by every new reservation)
redis.call('INCRBYFLOAT', reservation_key, amount)
redis.call('EXPIRE', reservation_key, 300)

-- Store individual reservation so finalize/release know what to give back
redis.call('SETEX', record_key, 300, amount)

return {tostring(effective + amount), tostring(effective), 0}
"""

# Lua script for finalizing a reservation with the actual cost.
# KEYS: [1] spending counter, [2] reservation total, [3] optional record key
# ARGV: [1] reserved amount (ignored: the stored record amount is authoritative),
#       [2] actual amount, [3] reservation_id, [4] period_seconds
# Reply: new spending total as a string
BUDGET_FINALIZE_SCRIPT = """
local budget_key = KEYS[1]
local reservation_key = KEYS[2]
local actual_amount = tonumber(ARGV[2])
local reservation_id = ARGV[3]
local period_seconds = tonumber(ARGV[4])
local record_key = KEYS[3] or ('reservation:' .. reservation_id)

-- Give back the hold only while the reservation record is live. Once it has
-- expired the total may have expired too, and decrementing then would create
-- a negative counter with no TTL.
local held = redis.call('GET', record_key)
if held then
    redis.call('DEL', record_key)
    local remaining = tonumber(redis.call('INCRBYFLOAT', reservation_key, -tonumber(held)))
    if remaining < 1e-9 then
        redis.call('DEL', reservation_key)
    end
end

-- Add actual amount to spending
local updated = redis.call('INCRBYFLOAT', budget_key, actual_amount)

-- Set expiry for period reset
if redis.call('TTL', budget_key) == -1 then
    redis.call('EXPIRE', budget_key, period_seconds)
end

return updated
"""

# Lua script for releasing a reservation without recording spend.
# KEYS: [1] reservation total, [2] per-reservation record key
# Reply: 1 if a live reservation was released, 0 if it was already gone
BUDGET_RELEASE_SCRIPT = """
local reservation_key = KEYS[1]
local record_key = KEYS[2]

local held = redis.call('GET', record_key)
if not held then
    return 0
end

redis.call('DEL', record_key)
local remaining = tonumber(redis.call('INCRBYFLOAT', reservation_key, -tonumber(held)))
if remaining < 1e-9 then
    redis.call('DEL', reservation_key)
end
return 1
"""


class RedisBackend(Backend):
    """
    Redis backend with distributed budget enforcement.

    Features:
    - Atomic budget checks using Lua scripts
    - Pessimistic reservation for distributed consistency
    - Automatic period reset via TTL
    - Cost record storage with configurable retention

    Key layout (all under ``prefix`` unless noted):
    - ``budget:<name>`` is the spending counter.
    - ``budget_reserved:<name>`` is the outstanding reservation total.
    - ``record:<timestamp>:<uuid>`` holds a JSON cost record.
    - ``records`` is a sorted-set index scored by record timestamp.
    - ``agg:...`` keys are aggregate hashes.
    - Per-reservation records use the unprefixed key ``reservation:<id>``.
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379/0",
        prefix: str = "llm_cost_guard:",
        retention_days: int = 90,
        **kwargs: Any,
    ):
        """
        Initialize Redis backend.

        Args:
            url: Redis connection URL
            prefix: Key prefix for all Redis keys. It is used verbatim, so
                include a trailing separator such as ":".
            retention_days: How long to retain cost records
            **kwargs: Extra keyword arguments forwarded to ``redis.from_url``

        Raises:
            ImportError: If the ``redis`` package is not installed.
        """
        try:
            import redis
        except ImportError as exc:
            raise ImportError(
                "redis package is required for Redis backend. "
                "Install with: pip install llm-cost-guard[redis]"
            ) from exc

        self._prefix = prefix
        self._retention_days = retention_days

        # decode_responses=True: replies arrive as str, ready for float()/json.
        self._client = redis.from_url(url, decode_responses=True, **kwargs)

        # Register Lua scripts
        self._budget_check_script = self._client.register_script(BUDGET_CHECK_SCRIPT)
        self._budget_reserve_script = self._client.register_script(BUDGET_RESERVE_SCRIPT)
        self._budget_finalize_script = self._client.register_script(BUDGET_FINALIZE_SCRIPT)
        self._budget_release_script = self._client.register_script(BUDGET_RELEASE_SCRIPT)

        # Metrics for graceful degradation. Only "backend_failures" is
        # updated by this class; the other counters are exposed as-is.
        self._metrics = {
            "backend_failures": 0,
            "fallback_activations": 0,
            "records_pending_sync": 0,
        }

    def _key(self, *parts: str) -> str:
        """Return ``prefix`` followed by ``parts`` joined with ":"."""
        return self._prefix + ":".join(parts)

    def _reservation_record_key(self, reservation_id: str) -> str:
        """Return the per-reservation record key.

        Deliberately unprefixed: it matches the scripts' fallback
        ``'reservation:' .. id`` name, so 2-key callers and in-flight
        reservations resolve to the same key.
        """
        return f"reservation:{reservation_id}"

    def _retention_seconds(self) -> int:
        """Return the record/aggregate retention period in seconds."""
        return self._retention_days * 24 * 60 * 60

    # =========================================================================
    # Distributed Budget Enforcement
    # =========================================================================

    def check_budget_atomic(
        self,
        budget_name: str,
        amount: float,
        limit: float,
        period_seconds: int,
        warning_threshold: float = 0.8,
    ) -> Tuple[bool, float, bool]:
        """
        Atomically check and record spending against a budget.

        The check only reads the spending counter. Outstanding reservations
        made via ``reserve_budget`` are not counted here.

        Args:
            budget_name: Name of the budget
            amount: Amount to add
            limit: Budget limit
            period_seconds: Period duration in seconds (TTL applied on first spend)
            warning_threshold: Threshold for warning (0-1)

        Returns:
            Tuple of (allowed, current_spending, is_warning). When rejected,
            current_spending is the total before this call and nothing is
            recorded.

        Raises:
            Exception: Any Redis error is logged, counted and re-raised.
        """
        budget_key = self._key("budget", budget_name)
        period_key = self._key("budget_period", budget_name)

        try:
            result = self._budget_check_script(
                keys=[budget_key, period_key],
                args=[amount, limit, period_seconds, warning_threshold],
            )

            # Values arrive as strings (see note on the Lua scripts); the
            # rejection sentinel arrives as the integer -1.
            first, current, flag = result
            if float(first) == -1:
                return False, float(current), False  # Exceeded

            return True, float(first), int(flag) == 1

        except Exception as e:
            logger.error("Redis budget check failed: %s", e)
            self._metrics["backend_failures"] += 1
            raise

    def reserve_budget(
        self,
        budget_name: str,
        estimated_amount: float,
        limit: float,
        reservation_id: str,
        period_seconds: int,
    ) -> Tuple[bool, float]:
        """
        Reserve budget before making an LLM call (pessimistic locking).

        The hold expires after 300 seconds unless finalized or released.

        Args:
            budget_name: Name of the budget
            estimated_amount: Estimated cost to reserve
            limit: Budget limit
            reservation_id: Unique ID for this reservation
            period_seconds: Period duration in seconds (passed through; the
                reserve script does not use it)

        Returns:
            Tuple of (allowed, effective_spending). effective_spending is
            spending plus holds, including this one when allowed.

        Raises:
            Exception: Any Redis error is logged, counted and re-raised.
        """
        budget_key = self._key("budget", budget_name)
        reservation_key = self._key("budget_reserved", budget_name)
        record_key = self._reservation_record_key(reservation_id)

        try:
            result = self._budget_reserve_script(
                keys=[budget_key, reservation_key, record_key],
                args=[estimated_amount, limit, reservation_id, period_seconds],
            )

            first, effective, _ = result
            if float(first) == -1:
                return False, float(effective)  # Would exceed

            return True, float(first)

        except Exception as e:
            logger.error("Redis budget reservation failed: %s", e)
            self._metrics["backend_failures"] += 1
            raise

    def finalize_budget(
        self,
        budget_name: str,
        reserved_amount: float,
        actual_amount: float,
        reservation_id: str,
        period_seconds: int,
    ) -> float:
        """
        Finalize a budget reservation with actual cost.

        The hold is returned only if the reservation record is still live. In
        that case the amount stored with the record is used, not
        ``reserved_amount``. This keeps the reservation total from going
        negative after expiry.

        Args:
            budget_name: Name of the budget
            reserved_amount: Originally reserved amount (kept for interface
                compatibility; the stored reservation amount is authoritative)
            actual_amount: Actual cost incurred
            reservation_id: Reservation ID
            period_seconds: Period duration in seconds

        Returns:
            New total spending

        Raises:
            Exception: Any Redis error is logged, counted and re-raised.
        """
        budget_key = self._key("budget", budget_name)
        reservation_key = self._key("budget_reserved", budget_name)
        record_key = self._reservation_record_key(reservation_id)

        try:
            result = self._budget_finalize_script(
                keys=[budget_key, reservation_key, record_key],
                args=[reserved_amount, actual_amount, reservation_id, period_seconds],
            )
            return float(result)

        except Exception as e:
            logger.error("Redis budget finalization failed: %s", e)
            self._metrics["backend_failures"] += 1
            raise

    def release_reservation(
        self,
        budget_name: str,
        reserved_amount: float,
        reservation_id: str,
    ) -> None:
        """
        Release a reservation (on failure or cancellation).

        This is best-effort: Redis errors are logged and counted, not raised.
        Releasing is atomic, and it is a no-op if the reservation record has
        already expired or been finalized.

        Args:
            budget_name: Name of the budget
            reserved_amount: Amount that was reserved (kept for interface
                compatibility; the stored reservation amount is authoritative)
            reservation_id: Reservation ID
        """
        reservation_key = self._key("budget_reserved", budget_name)
        record_key = self._reservation_record_key(reservation_id)

        try:
            self._budget_release_script(keys=[reservation_key, record_key])
        except Exception as e:
            logger.error("Redis reservation release failed: %s", e)
            self._metrics["backend_failures"] += 1

    def get_budget_spending(self, budget_name: str) -> float:
        """Get current spending for a budget.

        Returns 0.0 both when nothing has been spent and when Redis fails;
        a failure is logged and counted.
        """
        budget_key = self._key("budget", budget_name)
        try:
            value = self._client.get(budget_key)
            return float(value) if value else 0.0
        except Exception as e:
            logger.error("Redis get budget spending failed: %s", e)
            self._metrics["backend_failures"] += 1
            return 0.0

    def reset_budget(self, budget_name: str) -> None:
        """Reset a budget (for testing or manual reset).

        Deletes the spending counter and the reservation total. Errors are
        logged and counted, not raised.
        """
        budget_key = self._key("budget", budget_name)
        reservation_key = self._key("budget_reserved", budget_name)
        try:
            self._client.delete(budget_key, reservation_key)
        except Exception as e:
            logger.error("Redis budget reset failed: %s", e)
            self._metrics["backend_failures"] += 1

    # =========================================================================
    # Cost Record Storage
    # =========================================================================

    def save_record(self, record: CostRecord) -> None:
        """Save a cost record, index it by timestamp and update aggregates.

        Raises:
            Exception: Any Redis error is logged, counted and re-raised.
        """
        # The random suffix keeps records written in the same microsecond
        # (e.g. by different workers) from overwriting each other.
        record_key = self._key(
            "record",
            record.timestamp.strftime("%Y%m%d%H%M%S%f"),
            uuid.uuid4().hex,
        )
        record_data = self._serialize_record(record)

        try:
            pipe = self._client.pipeline()

            # Save record with TTL
            ttl_seconds = self._retention_seconds()
            pipe.setex(record_key, ttl_seconds, json.dumps(record_data))

            # Trim index entries older than the retention window. Their data
            # keys have expired, and without trimming they pile up at the
            # low end of the index and crowd live records out of
            # get_records(limit=...).
            records_key = self._key("records")
            cutoff = datetime.now().timestamp() - ttl_seconds
            pipe.zremrangebyscore(records_key, "-inf", cutoff)

            # Add to sorted set for range queries (score = timestamp)
            pipe.zadd(records_key, {record_key: record.timestamp.timestamp()})

            # Update aggregates for quick reporting
            self._update_aggregates(pipe, record)

            pipe.execute()

        except Exception as e:
            logger.error("Redis save record failed: %s", e)
            self._metrics["backend_failures"] += 1
            raise

    def _update_aggregates(self, pipe: Any, record: CostRecord) -> None:
        """Queue daily, per-model and per-tag aggregate updates on ``pipe``."""
        date_str = record.timestamp.strftime("%Y-%m-%d")
        ttl_seconds = self._retention_seconds()

        # Daily aggregates
        daily_key = self._key("agg", "daily", date_str)
        pipe.hincrbyfloat(daily_key, "total_cost", record.total_cost)
        pipe.hincrby(daily_key, "total_calls", 1)
        pipe.hincrby(daily_key, "input_tokens", record.input_tokens)
        pipe.hincrby(daily_key, "output_tokens", record.output_tokens)
        pipe.expire(daily_key, ttl_seconds)

        # Model aggregates
        model_key = self._key("agg", "model", date_str, record.model)
        pipe.hincrbyfloat(model_key, "total_cost", record.total_cost)
        pipe.hincrby(model_key, "total_calls", 1)
        pipe.expire(model_key, ttl_seconds)

        # Tag aggregates
        for tag_key, tag_value in record.tags.items():
            tag_agg_key = self._key("agg", "tag", date_str, tag_key, tag_value)
            pipe.hincrbyfloat(tag_agg_key, "total_cost", record.total_cost)
            pipe.hincrby(tag_agg_key, "total_calls", 1)
            pipe.expire(tag_agg_key, ttl_seconds)

    def get_records(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        limit: int = 1000,
    ) -> List[CostRecord]:
        """Get cost records with optional filtering, oldest first.

        ``limit`` caps the index entries read *before* tag filtering, so
        fewer than ``limit`` records may be returned. Returns [] on Redis
        failure; the failure is logged and counted.
        """
        records_key = self._key("records")

        try:
            # Get record keys from sorted set
            min_score = start_date.timestamp() if start_date else "-inf"
            max_score = end_date.timestamp() if end_date else "+inf"

            record_keys = self._client.zrangebyscore(
                records_key, min_score, max_score, start=0, num=limit
            )

            if not record_keys:
                return []

            # Fetch records
            pipe = self._client.pipeline()
            for key in record_keys:
                pipe.get(key)

            records = []
            for data in pipe.execute():
                if not data:
                    continue  # data expired; index entry not yet trimmed
                record = self._deserialize_record(json.loads(data))

                # Filter by tags if specified
                if tags and not all(record.tags.get(k) == v for k, v in tags.items()):
                    continue
                records.append(record)

            return records

        except Exception as e:
            logger.error("Redis get records failed: %s", e)
            self._metrics["backend_failures"] += 1
            return []

    def get_report(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        group_by: Optional[List[str]] = None,
    ) -> CostReport:
        """Get aggregated cost report.

        Built from ``get_records`` with its default limit, so at most 1000
        records contribute to the totals.
        """
        records = self.get_records(start_date, end_date, tags)

        total_cost = sum(r.total_cost for r in records)
        total_tokens = sum(r.input_tokens + r.output_tokens for r in records)

        return CostReport(
            total_cost=total_cost,
            total_tokens=total_tokens,
            total_calls=len(records),
            records=records,
            start_date=start_date,
            end_date=end_date,
            grouped_data={} if not group_by else self._group_records(records, group_by),
        )

    def _group_records(
        self, records: List[CostRecord], group_by: List[str]
    ) -> Dict[str, Any]:
        """Group records by the given fields.

        Supported fields are "model", "provider", "tag:<name>", or a bare tag
        name. Group keys are the field values joined with "|", and missing
        tags become "unknown".
        """
        groups: Dict[str, Dict[str, float]] = {}

        for record in records:
            key_parts = []
            for field in group_by:
                if field == "model":
                    key_parts.append(record.model)
                elif field == "provider":
                    key_parts.append(record.provider)
                elif field.startswith("tag:"):
                    tag_name = field[4:]
                    key_parts.append(record.tags.get(tag_name, "unknown"))
                else:
                    key_parts.append(record.tags.get(field, "unknown"))

            key = "|".join(key_parts)

            if key not in groups:
                groups[key] = {"cost": 0.0, "calls": 0, "tokens": 0}

            groups[key]["cost"] += record.total_cost
            groups[key]["calls"] += 1
            groups[key]["tokens"] += record.input_tokens + record.output_tokens

        return groups

    def _serialize_record(self, record: CostRecord) -> Dict[str, Any]:
        """Serialize a CostRecord to a JSON-compatible dict."""
        return {
            "timestamp": record.timestamp.isoformat(),
            "provider": record.provider,
            "model": record.model,
            "model_type": record.model_type.value if record.model_type else "chat",
            "input_tokens": record.input_tokens,
            "output_tokens": record.output_tokens,
            "input_cost": record.input_cost,
            "output_cost": record.output_cost,
            "total_cost": record.total_cost,
            "latency_ms": record.latency_ms,
            "tags": record.tags,
            "metadata": record.metadata,
            "success": record.success,
            "error_type": record.error_type,
            "cached": record.cached,
            "cache_savings": record.cache_savings,
            "span_id": record.span_id,
        }

    def _deserialize_record(self, data: Dict[str, Any]) -> CostRecord:
        """Deserialize a dict produced by ``_serialize_record`` into a CostRecord."""
        return CostRecord(
            timestamp=datetime.fromisoformat(data["timestamp"]),
            provider=data["provider"],
            model=data["model"],
            model_type=ModelType(data.get("model_type", "chat")),
            input_tokens=data["input_tokens"],
            output_tokens=data["output_tokens"],
            input_cost=data["input_cost"],
            output_cost=data["output_cost"],
            total_cost=data["total_cost"],
            latency_ms=data["latency_ms"],
            tags=data.get("tags", {}),
            metadata=data.get("metadata", {}),
            success=data.get("success", True),
            error_type=data.get("error_type"),
            cached=data.get("cached", False),
            cache_savings=data.get("cache_savings", 0.0),
            span_id=data.get("span_id"),
        )

    # =========================================================================
    # Health & Metrics
    # =========================================================================

    def health_check(self) -> bool:
        """Return True if Redis answers PING, False on any error."""
        try:
            self._client.ping()
            return True
        except Exception:
            return False

    def get_metrics(self) -> Dict[str, Any]:
        """Return backend metrics plus a live ``connected`` flag (issues a PING)."""
        return {
            **self._metrics,
            "connected": self.health_check(),
        }

    def close(self) -> None:
        """Close the Redis connection."""
        self._client.close()