llm_cost_guard-0.1.0-py3-none-any.whl
- llm_cost_guard/__init__.py +39 -0
- llm_cost_guard/backends/__init__.py +52 -0
- llm_cost_guard/backends/base.py +121 -0
- llm_cost_guard/backends/memory.py +265 -0
- llm_cost_guard/backends/sqlite.py +425 -0
- llm_cost_guard/budget.py +306 -0
- llm_cost_guard/cli.py +464 -0
- llm_cost_guard/clients/__init__.py +11 -0
- llm_cost_guard/clients/anthropic.py +231 -0
- llm_cost_guard/clients/openai.py +262 -0
- llm_cost_guard/exceptions.py +71 -0
- llm_cost_guard/integrations/__init__.py +12 -0
- llm_cost_guard/integrations/cache.py +189 -0
- llm_cost_guard/integrations/langchain.py +257 -0
- llm_cost_guard/models.py +123 -0
- llm_cost_guard/pricing/__init__.py +7 -0
- llm_cost_guard/pricing/anthropic.yaml +88 -0
- llm_cost_guard/pricing/bedrock.yaml +215 -0
- llm_cost_guard/pricing/loader.py +221 -0
- llm_cost_guard/pricing/openai.yaml +148 -0
- llm_cost_guard/pricing/vertex.yaml +133 -0
- llm_cost_guard/providers/__init__.py +69 -0
- llm_cost_guard/providers/anthropic.py +115 -0
- llm_cost_guard/providers/base.py +72 -0
- llm_cost_guard/providers/bedrock.py +135 -0
- llm_cost_guard/providers/openai.py +110 -0
- llm_cost_guard/rate_limit.py +233 -0
- llm_cost_guard/span.py +143 -0
- llm_cost_guard/tokenizers/__init__.py +7 -0
- llm_cost_guard/tokenizers/base.py +207 -0
- llm_cost_guard/tracker.py +718 -0
- llm_cost_guard-0.1.0.dist-info/METADATA +357 -0
- llm_cost_guard-0.1.0.dist-info/RECORD +36 -0
- llm_cost_guard-0.1.0.dist-info/WHEEL +4 -0
- llm_cost_guard-0.1.0.dist-info/entry_points.txt +2 -0
- llm_cost_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
llm_cost_guard/backends/sqlite.py
@@ -0,0 +1,425 @@
"""
SQLite storage backend for LLM Cost Guard.
"""

import json
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import threading

from llm_cost_guard.backends.base import Backend
from llm_cost_guard.models import CostRecord, CostReport, ModelType


class SQLiteBackend(Backend):
    """SQLite storage backend with thread-safe connection management."""

    def __init__(self, db_url: str, **kwargs):
        """
        Initialize the SQLite backend.

        Args:
            db_url: SQLite database URL (e.g., "sqlite:///costs.db" or "sqlite:///:memory:")
        """
        # Parse database path from URL
        if db_url.startswith("sqlite:///"):
            self._db_path = db_url[10:]
        elif db_url.startswith("sqlite://"):
            self._db_path = db_url[9:]
        else:
            self._db_path = db_url

        # Handle in-memory database
        if self._db_path == ":memory:" or self._db_path == "":
            self._db_path = ":memory:"

        self._local = threading.local()
        self._init_database()

    def _get_connection(self) -> sqlite3.Connection:
        """Get a thread-local database connection."""
        if not hasattr(self._local, "conn") or self._local.conn is None:
            self._local.conn = sqlite3.connect(
                self._db_path, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
            )
            self._local.conn.row_factory = sqlite3.Row
        return self._local.conn

    def _init_database(self) -> None:
        """Initialize the database schema."""
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS cost_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TIMESTAMP NOT NULL,
                provider TEXT NOT NULL,
                model TEXT NOT NULL,
                model_type TEXT NOT NULL DEFAULT 'chat',
                input_tokens INTEGER NOT NULL DEFAULT 0,
                output_tokens INTEGER NOT NULL DEFAULT 0,
                input_cost REAL NOT NULL DEFAULT 0.0,
                output_cost REAL NOT NULL DEFAULT 0.0,
                total_cost REAL NOT NULL DEFAULT 0.0,
                latency_ms INTEGER NOT NULL DEFAULT 0,
                tags TEXT NOT NULL DEFAULT '{}',
                metadata TEXT NOT NULL DEFAULT '{}',
                success INTEGER NOT NULL DEFAULT 1,
                error_type TEXT,
                cached INTEGER NOT NULL DEFAULT 0,
                cache_savings REAL NOT NULL DEFAULT 0.0,
                span_id TEXT
            )
            """
        )

        # Create indexes for common queries
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_timestamp ON cost_records(timestamp)"
        )
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_provider ON cost_records(provider)"
        )
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_model ON cost_records(model)")
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_span_id ON cost_records(span_id)"
        )

        conn.commit()

    def _record_to_row(self, record: CostRecord) -> tuple:
        """Convert a CostRecord to a database row tuple."""
        return (
            record.timestamp,
            record.provider,
            record.model,
            record.model_type.value,
            record.input_tokens,
            record.output_tokens,
            record.input_cost,
            record.output_cost,
            record.total_cost,
            record.latency_ms,
            json.dumps(record.tags),
            json.dumps(record.metadata),
            1 if record.success else 0,
            record.error_type,
            1 if record.cached else 0,
            record.cache_savings,
            record.span_id,
        )

    def _row_to_record(self, row: sqlite3.Row) -> CostRecord:
        """Convert a database row to a CostRecord."""
        return CostRecord(
            timestamp=row["timestamp"],
            provider=row["provider"],
            model=row["model"],
            model_type=ModelType(row["model_type"]),
            input_tokens=row["input_tokens"],
            output_tokens=row["output_tokens"],
            input_cost=row["input_cost"],
            output_cost=row["output_cost"],
            total_cost=row["total_cost"],
            latency_ms=row["latency_ms"],
            tags=json.loads(row["tags"]),
            metadata=json.loads(row["metadata"]),
            success=bool(row["success"]),
            error_type=row["error_type"],
            cached=bool(row["cached"]),
            cache_savings=row["cache_savings"],
            span_id=row["span_id"],
        )

    def save_record(self, record: CostRecord) -> None:
        """Save a cost record."""
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """
            INSERT INTO cost_records (
                timestamp, provider, model, model_type,
                input_tokens, output_tokens, input_cost, output_cost, total_cost,
                latency_ms, tags, metadata, success, error_type,
                cached, cache_savings, span_id
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            self._record_to_row(record),
        )

        conn.commit()

    def save_records(self, records: List[CostRecord]) -> None:
        """Save multiple cost records."""
        if not records:
            return

        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.executemany(
            """
            INSERT INTO cost_records (
                timestamp, provider, model, model_type,
                input_tokens, output_tokens, input_cost, output_cost, total_cost,
                latency_ms, tags, metadata, success, error_type,
                cached, cache_savings, span_id
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [self._record_to_row(r) for r in records],
        )

        conn.commit()

    def _build_where_clause(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> tuple[str, list]:
        """Build WHERE clause and parameters."""
        conditions = []
        params = []

        if start_date:
            conditions.append("timestamp >= ?")
            params.append(start_date)

        if end_date:
            conditions.append("timestamp <= ?")
            params.append(end_date)

        if tags:
            for key, value in tags.items():
                # Use JSON extraction for tag filtering
                conditions.append(f"json_extract(tags, '$.{key}') = ?")
                params.append(value)

        where_clause = " AND ".join(conditions) if conditions else "1=1"
        return where_clause, params

    def get_records(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        limit: Optional[int] = None,
        offset: int = 0,
    ) -> List[CostRecord]:
        """Retrieve cost records with optional filters."""
        conn = self._get_connection()
        cursor = conn.cursor()

        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        query = f"""
            SELECT * FROM cost_records
            WHERE {where_clause}
            ORDER BY timestamp DESC
        """

        if limit:
            query += f" LIMIT {limit}"
        if offset:
            query += f" OFFSET {offset}"

        cursor.execute(query, params)
        rows = cursor.fetchall()

        return [self._row_to_record(row) for row in rows]

    def get_total_cost(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> float:
        """Get total cost for the given filters."""
        conn = self._get_connection()
        cursor = conn.cursor()

        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        cursor.execute(
            f"SELECT COALESCE(SUM(total_cost), 0) FROM cost_records WHERE {where_clause}",
            params,
        )

        result = cursor.fetchone()
        return float(result[0]) if result else 0.0

    def get_aggregated_costs(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        group_by: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Get aggregated costs grouped by specified fields."""
        conn = self._get_connection()
        cursor = conn.cursor()

        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        if not group_by:
            cursor.execute(
                f"""
                SELECT
                    COALESCE(SUM(total_cost), 0) as total_cost,
                    COUNT(*) as total_calls,
                    COALESCE(SUM(input_tokens), 0) as total_input_tokens,
                    COALESCE(SUM(output_tokens), 0) as total_output_tokens
                FROM cost_records
                WHERE {where_clause}
                """,
                params,
            )

            row = cursor.fetchone()
            return {
                "total_cost": float(row[0]),
                "total_calls": int(row[1]),
                "total_input_tokens": int(row[2]),
                "total_output_tokens": int(row[3]),
            }

        # Build GROUP BY clause
        group_columns = []
        for field in group_by:
            if field in ("provider", "model"):
                group_columns.append(field)
            else:
                # Assume it's a tag
                group_columns.append(f"json_extract(tags, '$.{field}') as {field}")

        select_cols = ", ".join(
            [f if " as " not in f else f.split(" as ")[1] for f in group_columns]
        )
        group_cols = ", ".join(
            [f.split(" as ")[0] if " as " in f else f for f in group_columns]
        )

        cursor.execute(
            f"""
            SELECT
                {', '.join(group_columns)},
                SUM(total_cost) as cost,
                COUNT(*) as calls,
                SUM(input_tokens) as input_tokens,
                SUM(output_tokens) as output_tokens
            FROM cost_records
            WHERE {where_clause}
            GROUP BY {group_cols}
            ORDER BY cost DESC
            """,
            params,
        )

        rows = cursor.fetchall()
        groups = []
        for row in rows:
            group_data = {}
            for i, field in enumerate(group_by):
                group_data[field] = row[i]
            group_data["cost"] = float(row[len(group_by)])
            group_data["calls"] = int(row[len(group_by) + 1])
            group_data["input_tokens"] = int(row[len(group_by) + 2])
            group_data["output_tokens"] = int(row[len(group_by) + 3])
            groups.append(group_data)

        return {"groups": groups, "group_by": group_by}

    def get_report(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
        group_by: Optional[List[str]] = None,
    ) -> CostReport:
        """Generate a cost report."""
        conn = self._get_connection()
        cursor = conn.cursor()

        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        cursor.execute(
            f"""
            SELECT
                COALESCE(SUM(total_cost), 0) as total_cost,
                COALESCE(SUM(input_tokens), 0) as total_input_tokens,
                COALESCE(SUM(output_tokens), 0) as total_output_tokens,
                COUNT(*) as total_calls,
                SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as successful_calls,
                SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) as failed_calls,
                SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END) as cache_hits,
                COALESCE(SUM(cache_savings), 0) as cache_savings
            FROM cost_records
            WHERE {where_clause}
            """,
            params,
        )

        row = cursor.fetchone()

        grouped_data = {}
        if group_by:
            agg = self.get_aggregated_costs(start_date, end_date, tags, group_by)
            grouped_data = {"groups": agg.get("groups", [])}

        total_cost = float(row[0])
        cache_savings = float(row[7])

        return CostReport(
            start_date=start_date,
            end_date=end_date,
            total_cost=total_cost,
            total_input_tokens=int(row[1]),
            total_output_tokens=int(row[2]),
            total_calls=int(row[3]),
            successful_calls=int(row[4]),
            failed_calls=int(row[5]),
            cache_hits=int(row[6]),
            cache_savings=cache_savings,
            effective_cost=total_cost - cache_savings,
            records=[],  # Don't include all records in report for performance
            grouped_data=grouped_data,
        )

    def delete_records(
        self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> int:
        """Delete records matching the filters."""
        conn = self._get_connection()
        cursor = conn.cursor()

        where_clause, params = self._build_where_clause(start_date, end_date, tags)

        cursor.execute(f"DELETE FROM cost_records WHERE {where_clause}", params)
        deleted = cursor.rowcount

        conn.commit()
        return deleted

    def health_check(self) -> bool:
        """Check if the backend is healthy."""
        try:
            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
            return True
        except Exception:
            return False

    def close(self) -> None:
        """Close the database connection."""
        if hasattr(self._local, "conn") and self._local.conn:
            self._local.conn.close()
            self._local.conn = None
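
For orientation, a minimal usage sketch of the backend above. It is hypothetical: the exact CostRecord constructor and the members of ModelType live in llm_cost_guard/models.py, which this hunk does not show. The field names below are inferred from _record_to_row/_row_to_record, and ModelType("chat") is assumed valid because the schema defaults model_type to 'chat'.

from datetime import datetime

from llm_cost_guard.backends.sqlite import SQLiteBackend
from llm_cost_guard.models import CostRecord, ModelType

# "sqlite:///costs.db" is parsed to the file path "costs.db";
# "sqlite:///:memory:" would select an in-memory database instead.
backend = SQLiteBackend("sqlite:///costs.db")

record = CostRecord(
    timestamp=datetime.now(),      # naive datetime suits sqlite3's default TIMESTAMP converter
    provider="openai",
    model="gpt-4o",                # hypothetical model name
    model_type=ModelType("chat"),  # assumed member; the schema defaults to 'chat'
    input_tokens=1200,
    output_tokens=350,
    input_cost=0.003,
    output_cost=0.00525,
    total_cost=0.00825,
    latency_ms=840,
    tags={"team": "search"},       # serialized to JSON text; queryable via json_extract
    metadata={},
    success=True,
    error_type=None,
    cached=False,
    cache_savings=0.0,
    span_id=None,
)
backend.save_record(record)

# Tag filters compile to json_extract(tags, '$.team') = ? in the WHERE clause.
total = backend.get_total_cost(tags={"team": "search"})
report = backend.get_report(group_by=["provider", "model"])
backend.close()

One consequence of the thread-local design worth noting: each worker thread transparently gets its own sqlite3.Connection, but with ":memory:" every such connection opens a separate, independent in-memory database, so only the thread that constructed the backend sees the initialized schema.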