llm-cost-guard 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. llm_cost_guard/__init__.py +39 -0
  2. llm_cost_guard/backends/__init__.py +52 -0
  3. llm_cost_guard/backends/base.py +121 -0
  4. llm_cost_guard/backends/memory.py +265 -0
  5. llm_cost_guard/backends/sqlite.py +425 -0
  6. llm_cost_guard/budget.py +306 -0
  7. llm_cost_guard/cli.py +464 -0
  8. llm_cost_guard/clients/__init__.py +11 -0
  9. llm_cost_guard/clients/anthropic.py +231 -0
  10. llm_cost_guard/clients/openai.py +262 -0
  11. llm_cost_guard/exceptions.py +71 -0
  12. llm_cost_guard/integrations/__init__.py +12 -0
  13. llm_cost_guard/integrations/cache.py +189 -0
  14. llm_cost_guard/integrations/langchain.py +257 -0
  15. llm_cost_guard/models.py +123 -0
  16. llm_cost_guard/pricing/__init__.py +7 -0
  17. llm_cost_guard/pricing/anthropic.yaml +88 -0
  18. llm_cost_guard/pricing/bedrock.yaml +215 -0
  19. llm_cost_guard/pricing/loader.py +221 -0
  20. llm_cost_guard/pricing/openai.yaml +148 -0
  21. llm_cost_guard/pricing/vertex.yaml +133 -0
  22. llm_cost_guard/providers/__init__.py +69 -0
  23. llm_cost_guard/providers/anthropic.py +115 -0
  24. llm_cost_guard/providers/base.py +72 -0
  25. llm_cost_guard/providers/bedrock.py +135 -0
  26. llm_cost_guard/providers/openai.py +110 -0
  27. llm_cost_guard/rate_limit.py +233 -0
  28. llm_cost_guard/span.py +143 -0
  29. llm_cost_guard/tokenizers/__init__.py +7 -0
  30. llm_cost_guard/tokenizers/base.py +207 -0
  31. llm_cost_guard/tracker.py +718 -0
  32. llm_cost_guard-0.1.0.dist-info/METADATA +357 -0
  33. llm_cost_guard-0.1.0.dist-info/RECORD +36 -0
  34. llm_cost_guard-0.1.0.dist-info/WHEEL +4 -0
  35. llm_cost_guard-0.1.0.dist-info/entry_points.txt +2 -0
  36. llm_cost_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,425 @@
1
+ """
2
+ SQLite storage backend for LLM Cost Guard.
3
+ """
4
+
5
import json
import sqlite3
import threading
from contextlib import contextmanager, nullcontext
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional

from llm_cost_guard.backends.base import Backend
from llm_cost_guard.models import CostRecord, CostReport, ModelType
14
+
15
+
16
class SQLiteBackend(Backend):
    """SQLite storage backend.

    Connection handling:

    * File databases: each thread lazily opens its own connection, kept in a
      ``threading.local``.
    * ``:memory:`` databases: a plain in-memory database exists only inside
      the connection that opened it, so per-thread connections would each see
      a different, empty database. A single shared connection is used instead,
      and all access to it is serialized with a re-entrant lock.

    Every operation runs inside a transaction that commits on success and
    rolls back on error.

    Timestamps are stored as ``datetime.isoformat(" ")`` text. This is the
    same format sqlite3's legacy default adapter produced, so existing rows
    still compare correctly. They are parsed back with
    ``datetime.fromisoformat``, which also accepts UTC offsets.
    """

    # Shared by save_record() and save_records().
    _INSERT_SQL = """
        INSERT INTO cost_records (
            timestamp, provider, model, model_type,
            input_tokens, output_tokens, input_cost, output_cost, total_cost,
            latency_ms, tags, metadata, success, error_type,
            cached, cache_savings, span_id
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """

    # group_by fields that are real table columns; any other field name is
    # looked up as a key inside the JSON ``tags`` column.
    _COLUMN_GROUP_FIELDS = ("provider", "model")

    def __init__(self, db_url: str, **kwargs):
        """
        Initialize the SQLite backend and create the schema if needed.

        Args:
            db_url: SQLite database URL (e.g., "sqlite:///costs.db" or
                "sqlite:///:memory:"). A bare path is also accepted; an empty
                path means an in-memory database.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        self._db_path = self._parse_db_path(db_url)
        self._is_memory = self._db_path == ":memory:"
        self._local = threading.local()
        # Only used when self._is_memory is True (see _get_connection).
        self._shared_conn: Optional[sqlite3.Connection] = None
        self._lock = threading.RLock()
        self._init_database()

    @staticmethod
    def _parse_db_path(db_url: str) -> str:
        """Extract the database path (or ":memory:") from a SQLite URL."""
        for prefix in ("sqlite:///", "sqlite://"):
            if db_url.startswith(prefix):
                db_url = db_url[len(prefix):]
                break
        # An empty path (e.g. "sqlite://") means an in-memory database.
        return db_url or ":memory:"

    def _connect(self, check_same_thread: bool = True) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name."""
        # No detect_types: timestamps are parsed explicitly in
        # _parse_timestamp. sqlite3's default timestamp converter is
        # deprecated and cannot parse values that carry a UTC offset.
        conn = sqlite3.connect(self._db_path, check_same_thread=check_same_thread)
        conn.row_factory = sqlite3.Row
        return conn

    def _get_connection(self) -> sqlite3.Connection:
        """Return the connection to use from the calling thread.

        For ``:memory:`` this is the single shared connection; callers must
        hold ``self._lock`` while using it (``_cursor`` does this).
        """
        if self._is_memory:
            if self._shared_conn is None:
                self._shared_conn = self._connect(check_same_thread=False)
            return self._shared_conn

        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = self._connect()
            self._local.conn = conn
        return conn

    @contextmanager
    def _cursor(self) -> Iterator[sqlite3.Cursor]:
        """Yield a cursor inside a transaction.

        The transaction commits when the block exits normally and rolls back
        if it raises. This means a failed multi-row insert never leaves rows
        pending for a later, unrelated commit. For in-memory databases, the
        shared-connection lock is held for the whole block.
        """
        guard = self._lock if self._is_memory else nullcontext()
        with guard:
            conn = self._get_connection()
            with conn:
                cursor = conn.cursor()
                try:
                    yield cursor
                finally:
                    cursor.close()

    def _init_database(self) -> None:
        """Create the cost_records table and its indexes if they do not exist."""
        with self._cursor() as cursor:
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS cost_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TIMESTAMP NOT NULL,
                    provider TEXT NOT NULL,
                    model TEXT NOT NULL,
                    model_type TEXT NOT NULL DEFAULT 'chat',
                    input_tokens INTEGER NOT NULL DEFAULT 0,
                    output_tokens INTEGER NOT NULL DEFAULT 0,
                    input_cost REAL NOT NULL DEFAULT 0.0,
                    output_cost REAL NOT NULL DEFAULT 0.0,
                    total_cost REAL NOT NULL DEFAULT 0.0,
                    latency_ms INTEGER NOT NULL DEFAULT 0,
                    tags TEXT NOT NULL DEFAULT '{}',
                    metadata TEXT NOT NULL DEFAULT '{}',
                    success INTEGER NOT NULL DEFAULT 1,
                    error_type TEXT,
                    cached INTEGER NOT NULL DEFAULT 0,
                    cache_savings REAL NOT NULL DEFAULT 0.0,
                    span_id TEXT
                )
                """
            )
            # Indexes for the columns used by common filters and lookups.
            # The column names are fixed literals, not user input.
            for column in ("timestamp", "provider", "model", "span_id"):
                cursor.execute(
                    f"CREATE INDEX IF NOT EXISTS idx_{column} "
                    f"ON cost_records({column})"
                )

    @staticmethod
    def _adapt_param(value: Any) -> Any:
        """Convert a datetime to the text form stored in the timestamp column.

        ``isoformat(" ")`` matches what sqlite3's legacy default datetime
        adapter wrote, so comparisons against existing rows are unchanged.
        Any other value is passed to sqlite3 as-is.
        """
        if isinstance(value, datetime):
            return value.isoformat(" ")
        return value

    @staticmethod
    def _parse_timestamp(value: Any) -> Any:
        """Convert a stored timestamp back into a ``datetime``.

        Text (or bytes) is parsed with ``datetime.fromisoformat``, which
        accepts an optional UTC offset. Any other value is returned unchanged.
        """
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value

    @staticmethod
    def _tag_path(key: str) -> str:
        """Return a JSON path selecting the top-level ``key`` of ``tags``.

        The key is double-quoted, so characters such as ``.``, ``[`` or spaces
        are treated as part of the key name rather than as path syntax. The
        path is always bound as a query parameter, never spliced into SQL.
        NOTE(review): a key containing a double quote still produces an
        invalid JSON path; presumably SQLite raises an error for it.
        """
        return f'$."{key}"'

    def _record_to_row(self, record: CostRecord) -> tuple:
        """Convert a CostRecord into a parameter tuple matching _INSERT_SQL."""
        return (
            self._adapt_param(record.timestamp),
            record.provider,
            record.model,
            record.model_type.value,
            record.input_tokens,
            record.output_tokens,
            record.input_cost,
            record.output_cost,
            record.total_cost,
            record.latency_ms,
            json.dumps(record.tags),
            json.dumps(record.metadata),
            1 if record.success else 0,
            record.error_type,
            1 if record.cached else 0,
            record.cache_savings,
            record.span_id,
        )

    def _row_to_record(self, row: sqlite3.Row) -> CostRecord:
        """Convert a cost_records row back into a CostRecord."""
        return CostRecord(
            timestamp=self._parse_timestamp(row["timestamp"]),
            provider=row["provider"],
            model=row["model"],
            model_type=ModelType(row["model_type"]),
            input_tokens=row["input_tokens"],
            output_tokens=row["output_tokens"],
            input_cost=row["input_cost"],
            output_cost=row["output_cost"],
            total_cost=row["total_cost"],
            latency_ms=row["latency_ms"],
            tags=json.loads(row["tags"]),
            metadata=json.loads(row["metadata"]),
            success=bool(row["success"]),
            error_type=row["error_type"],
            cached=bool(row["cached"]),
            cache_savings=row["cache_savings"],
            span_id=row["span_id"],
        )

    def save_record(self, record: CostRecord) -> None:
        """Insert a single cost record and commit it."""
        row = self._record_to_row(record)
        with self._cursor() as cursor:
            cursor.execute(self._INSERT_SQL, row)

    def save_records(self, records: List[CostRecord]) -> None:
        """Insert several cost records atomically.

        Either all records are committed, or, if any insert fails, none are.
        """
        if not records:
            return

        rows = [self._record_to_row(r) for r in records]
        with self._cursor() as cursor:
            cursor.executemany(self._INSERT_SQL, rows)

    def _build_where_clause(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> tuple[str, list]:
        """Build a parameterized WHERE clause and its parameter list.

        The returned clause contains only fixed SQL fragments and ``?``
        placeholders. Dates, tag keys (as JSON paths) and tag values are all
        bound through the parameter list. With no filters the clause is "1=1".
        """
        conditions: List[str] = []
        params: List[Any] = []

        if start_date:
            conditions.append("timestamp >= ?")
            params.append(self._adapt_param(start_date))

        if end_date:
            conditions.append("timestamp <= ?")
            params.append(self._adapt_param(end_date))

        if tags:
            for key, value in tags.items():
                conditions.append("json_extract(tags, ?) = ?")
                params.append(self._tag_path(key))
                params.append(value)

        where_clause = " AND ".join(conditions) if conditions else "1=1"
        return where_clause, params

    def get_records(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        limit: Optional[int] = None,
        offset: int = 0,
    ) -> List[CostRecord]:
        """Retrieve cost records, newest first.

        Args:
            start_date: Inclusive lower bound on the timestamp, if given.
            end_date: Inclusive upper bound on the timestamp, if given.
            tags: Tag key/value pairs that must all match.
            limit: Maximum number of records; None or 0 means no limit.
            offset: Number of matching records to skip.
        """
        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        # SQLite accepts OFFSET only as part of a LIMIT clause, so LIMIT is
        # always present. A negative LIMIT means "no limit".
        query = (
            "SELECT * FROM cost_records "
            f"WHERE {where_clause} "
            "ORDER BY timestamp DESC "
            "LIMIT ? OFFSET ?"
        )
        params.append(int(limit) if limit else -1)
        params.append(int(offset) if offset else 0)

        with self._cursor() as cursor:
            rows = cursor.execute(query, params).fetchall()

        return [self._row_to_record(row) for row in rows]

    def get_total_cost(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> float:
        """Return the summed total_cost of matching records (0.0 if none)."""
        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        with self._cursor() as cursor:
            result = cursor.execute(
                "SELECT COALESCE(SUM(total_cost), 0) FROM cost_records "
                f"WHERE {where_clause}",
                params,
            ).fetchone()

        return float(result[0]) if result else 0.0

    def get_aggregated_costs(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        group_by: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Get aggregated costs, optionally grouped by the given fields.

        Args:
            start_date: Inclusive lower bound on the timestamp, if given.
            end_date: Inclusive upper bound on the timestamp, if given.
            tags: Tag key/value pairs that must all match.
            group_by: Field names to group on. "provider" and "model" are table
                columns; any other name is looked up as a key in ``tags``.

        Returns:
            Without ``group_by``: a dict with ``total_cost``, ``total_calls``,
            ``total_input_tokens`` and ``total_output_tokens``.
            With ``group_by``: ``{"groups": [...], "group_by": group_by}``.
            Each group maps every group_by field to its value, plus ``cost``,
            ``calls``, ``input_tokens`` and ``output_tokens``. Groups are
            ordered by descending cost.
        """
        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        if not group_by:
            with self._cursor() as cursor:
                row = cursor.execute(
                    f"""
                    SELECT
                        COALESCE(SUM(total_cost), 0) AS total_cost,
                        COUNT(*) AS total_calls,
                        COALESCE(SUM(input_tokens), 0) AS total_input_tokens,
                        COALESCE(SUM(output_tokens), 0) AS total_output_tokens
                    FROM cost_records
                    WHERE {where_clause}
                    """,
                    params,
                ).fetchone()
            return {
                "total_cost": float(row[0]),
                "total_calls": int(row[1]),
                "total_input_tokens": int(row[2]),
                "total_output_tokens": int(row[3]),
            }

        # One select expression per requested field. Tag names are bound as
        # JSON-path parameters instead of being spliced into the SQL text.
        select_exprs: List[str] = []
        select_params: List[Any] = []
        for field in group_by:
            if field in self._COLUMN_GROUP_FIELDS:
                select_exprs.append(field)
            else:
                select_exprs.append("json_extract(tags, ?)")
                select_params.append(self._tag_path(field))

        select_sql = ", ".join(select_exprs)
        # Group by result-column position. This avoids repeating the
        # parameterized tag expressions in the GROUP BY clause.
        group_sql = ", ".join(str(pos) for pos in range(1, len(group_by) + 1))

        query = f"""
            SELECT
                {select_sql},
                SUM(total_cost) AS cost,
                COUNT(*) AS calls,
                SUM(input_tokens) AS input_tokens,
                SUM(output_tokens) AS output_tokens
            FROM cost_records
            WHERE {where_clause}
            GROUP BY {group_sql}
            ORDER BY cost DESC
        """

        # SELECT-list placeholders precede the WHERE placeholders in the SQL.
        with self._cursor() as cursor:
            rows = cursor.execute(query, select_params + params).fetchall()

        key_count = len(group_by)
        groups = []
        for row in rows:
            group: Dict[str, Any] = {
                field: row[i] for i, field in enumerate(group_by)
            }
            group["cost"] = float(row[key_count])
            group["calls"] = int(row[key_count + 1])
            group["input_tokens"] = int(row[key_count + 2])
            group["output_tokens"] = int(row[key_count + 3])
            groups.append(group)

        return {"groups": groups, "group_by": group_by}

    def get_report(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        group_by: Optional[List[str]] = None,
    ) -> CostReport:
        """Generate a summary CostReport for matching records.

        Individual records are not included (``records=[]``). When
        ``group_by`` is given, ``grouped_data["groups"]`` holds the output of
        get_aggregated_costs.
        """
        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        # Every SUM is wrapped in COALESCE because SUM over zero rows is NULL,
        # which would make the int()/float() conversions below fail.
        with self._cursor() as cursor:
            row = cursor.execute(
                f"""
                SELECT
                    COALESCE(SUM(total_cost), 0) AS total_cost,
                    COALESCE(SUM(input_tokens), 0) AS total_input_tokens,
                    COALESCE(SUM(output_tokens), 0) AS total_output_tokens,
                    COUNT(*) AS total_calls,
                    COALESCE(SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END), 0)
                        AS successful_calls,
                    COALESCE(SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END), 0)
                        AS failed_calls,
                    COALESCE(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END), 0)
                        AS cache_hits,
                    COALESCE(SUM(cache_savings), 0) AS cache_savings
                FROM cost_records
                WHERE {where_clause}
                """,
                params,
            ).fetchone()

        grouped_data: Dict[str, Any] = {}
        if group_by:
            agg = self.get_aggregated_costs(start_date, end_date, tags, group_by)
            grouped_data = {"groups": agg.get("groups", [])}

        total_cost = float(row[0])
        cache_savings = float(row[7])

        return CostReport(
            start_date=start_date,
            end_date=end_date,
            total_cost=total_cost,
            total_input_tokens=int(row[1]),
            total_output_tokens=int(row[2]),
            total_calls=int(row[3]),
            successful_calls=int(row[4]),
            failed_calls=int(row[5]),
            cache_hits=int(row[6]),
            cache_savings=cache_savings,
            effective_cost=total_cost - cache_savings,
            records=[],  # Don't include all records in report for performance
            grouped_data=grouped_data,
        )

    def delete_records(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> int:
        """Delete records matching the filters and return how many were removed."""
        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        with self._cursor() as cursor:
            cursor.execute(f"DELETE FROM cost_records WHERE {where_clause}", params)
            deleted = cursor.rowcount

        return deleted

    def health_check(self) -> bool:
        """Return True if a trivial query succeeds, False on any error."""
        try:
            with self._cursor() as cursor:
                cursor.execute("SELECT 1")
        except Exception:
            # Deliberately best-effort: a health probe reports failure
            # instead of raising.
            return False
        return True

    def close(self) -> None:
        """Close the calling thread's connection and any shared in-memory one.

        Connections opened by other threads for a file database are not
        closed here. Closing the shared in-memory connection discards its data.
        """
        conn = getattr(self._local, "conn", None)
        if conn is not None:
            conn.close()
            self._local.conn = None

        with self._lock:
            if self._shared_conn is not None:
                self._shared_conn.close()
                self._shared_conn = None