llm-cost-guard 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_cost_guard/__init__.py +39 -0
- llm_cost_guard/backends/__init__.py +52 -0
- llm_cost_guard/backends/base.py +121 -0
- llm_cost_guard/backends/memory.py +265 -0
- llm_cost_guard/backends/sqlite.py +425 -0
- llm_cost_guard/budget.py +306 -0
- llm_cost_guard/cli.py +464 -0
- llm_cost_guard/clients/__init__.py +11 -0
- llm_cost_guard/clients/anthropic.py +231 -0
- llm_cost_guard/clients/openai.py +262 -0
- llm_cost_guard/exceptions.py +71 -0
- llm_cost_guard/integrations/__init__.py +12 -0
- llm_cost_guard/integrations/cache.py +189 -0
- llm_cost_guard/integrations/langchain.py +257 -0
- llm_cost_guard/models.py +123 -0
- llm_cost_guard/pricing/__init__.py +7 -0
- llm_cost_guard/pricing/anthropic.yaml +88 -0
- llm_cost_guard/pricing/bedrock.yaml +215 -0
- llm_cost_guard/pricing/loader.py +221 -0
- llm_cost_guard/pricing/openai.yaml +148 -0
- llm_cost_guard/pricing/vertex.yaml +133 -0
- llm_cost_guard/providers/__init__.py +69 -0
- llm_cost_guard/providers/anthropic.py +115 -0
- llm_cost_guard/providers/base.py +72 -0
- llm_cost_guard/providers/bedrock.py +135 -0
- llm_cost_guard/providers/openai.py +110 -0
- llm_cost_guard/rate_limit.py +233 -0
- llm_cost_guard/span.py +143 -0
- llm_cost_guard/tokenizers/__init__.py +7 -0
- llm_cost_guard/tokenizers/base.py +207 -0
- llm_cost_guard/tracker.py +718 -0
- llm_cost_guard-0.1.0.dist-info/METADATA +357 -0
- llm_cost_guard-0.1.0.dist-info/RECORD +36 -0
- llm_cost_guard-0.1.0.dist-info/WHEEL +4 -0
- llm_cost_guard-0.1.0.dist-info/entry_points.txt +2 -0
- llm_cost_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,718 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main CostTracker class for LLM Cost Guard.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import functools
|
|
7
|
+
import logging
|
|
8
|
+
import time
|
|
9
|
+
import threading
|
|
10
|
+
from contextlib import contextmanager
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union
|
|
13
|
+
|
|
14
|
+
from llm_cost_guard.backends import Backend, MemoryBackend, get_backend
|
|
15
|
+
from llm_cost_guard.budget import Budget, BudgetAction, BudgetTracker
|
|
16
|
+
from llm_cost_guard.exceptions import (
|
|
17
|
+
BudgetExceededError,
|
|
18
|
+
RateLimitExceededError,
|
|
19
|
+
TrackingUnavailableError,
|
|
20
|
+
)
|
|
21
|
+
from llm_cost_guard.models import CostRecord, CostReport, HealthStatus, ModelType, UsageData
|
|
22
|
+
from llm_cost_guard.pricing.loader import PricingLoader, get_pricing_loader
|
|
23
|
+
from llm_cost_guard.providers import detect_provider, get_provider
|
|
24
|
+
from llm_cost_guard.rate_limit import RateLimit, RateLimiter
|
|
25
|
+
from llm_cost_guard.span import Span, get_current_span
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CostTracker:
|
|
33
|
+
"""
|
|
34
|
+
Main entry point for cost tracking.
|
|
35
|
+
|
|
36
|
+
Provides decorator-based and context manager tracking for LLM API calls.
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
    def __init__(
        self,
        budgets: Optional[List[Budget]] = None,
        rate_limits: Optional[List[RateLimit]] = None,
        backend: str = "memory",
        auto_detect_provider: bool = True,
        pricing_update: bool = True,
        pricing_overrides: Optional[Dict[str, Dict[str, Any]]] = None,
        on_tracking_failure: Literal["block", "allow", "fallback"] = "allow",
        store_prompts: bool = False,
        track_failed_calls: bool = True,
        track_cache_savings: bool = True,
        max_unique_tag_values: int = 1000,
        budget_mode: Literal["local", "distributed"] = "local",
        streaming_budget_mode: Literal["estimate", "actual"] = "actual",
        streaming_max_output_estimate: int = 4096,
        **backend_kwargs: Any,
    ):
        """
        Initialize the CostTracker.

        Args:
            budgets: List of budget configurations
            rate_limits: List of rate limit configurations
            backend: Backend URL (memory, sqlite:///, postgresql://, redis://)
            auto_detect_provider: Automatically detect provider from model name
            pricing_update: Check for pricing updates on startup
            pricing_overrides: Custom pricing for models (e.g., negotiated rates)
            on_tracking_failure: Action when tracking fails (block/allow/fallback)
            store_prompts: Store prompts in records (default: False for security)
            track_failed_calls: Track costs for failed API calls
            track_cache_savings: Track cache hit savings
            max_unique_tag_values: Maximum unique values per tag key
            budget_mode: Budget enforcement mode (local or distributed)
            streaming_budget_mode: How to handle streaming budgets
            streaming_max_output_estimate: Max output tokens to estimate for streaming

        Raises:
            TrackingUnavailableError: If backend initialization fails and
                ``on_tracking_failure`` is ``"block"``.
        """
        self._auto_detect_provider = auto_detect_provider
        self._on_tracking_failure = on_tracking_failure
        self._store_prompts = store_prompts
        self._track_failed_calls = track_failed_calls
        self._track_cache_savings = track_cache_savings
        self._max_unique_tag_values = max_unique_tag_values
        self._budget_mode = budget_mode
        self._streaming_budget_mode = streaming_budget_mode
        self._streaming_max_output_estimate = streaming_max_output_estimate
        # NOTE(review): pricing_update is accepted but not referenced in this
        # initializer — presumably consumed elsewhere (e.g. PricingLoader);
        # confirm it is actually honored.

        # Initialize backend. On failure, the on_tracking_failure policy
        # decides between raising, falling back to memory, or just warning.
        self._backend_url = backend
        self._fallback_backend: Optional[MemoryBackend] = None
        try:
            self._backend: Backend = get_backend(backend, **backend_kwargs)
        except Exception as e:
            if on_tracking_failure == "block":
                raise TrackingUnavailableError(f"Failed to initialize backend: {e}", backend)
            elif on_tracking_failure == "fallback":
                logger.warning(f"Failed to initialize backend {backend}, using memory fallback: {e}")
                self._backend = MemoryBackend()
                # In fallback mode the fallback aliases the active backend.
                self._fallback_backend = self._backend
            else:
                logger.warning(f"Failed to initialize backend {backend}: {e}")
                self._backend = MemoryBackend()

        # Initialize pricing
        self._pricing = PricingLoader(pricing_overrides=pricing_overrides)

        # Initialize budget tracking
        self._budget_tracker = BudgetTracker(budgets)

        # Initialize rate limiting
        self._rate_limiter = RateLimiter(rate_limits)

        # Tag cardinality tracking: observed values per tag key, guarded by a lock.
        self._tag_values: Dict[str, set] = {}
        self._tag_lock = threading.Lock()

        # Last call tracking (read via last_call()), guarded by its own lock.
        self._last_record: Optional[CostRecord] = None
        self._lock = threading.Lock()
|
|
118
|
+
|
|
119
|
+
def track(
|
|
120
|
+
self,
|
|
121
|
+
func: Optional[F] = None,
|
|
122
|
+
*,
|
|
123
|
+
tags: Optional[Dict[str, str]] = None,
|
|
124
|
+
streaming: bool = False,
|
|
125
|
+
provider: Optional[str] = None,
|
|
126
|
+
model: Optional[str] = None,
|
|
127
|
+
) -> Union[F, Callable[[F], F]]:
|
|
128
|
+
"""
|
|
129
|
+
Decorator to track LLM call costs.
|
|
130
|
+
|
|
131
|
+
Can be used with or without arguments:
|
|
132
|
+
@tracker.track
|
|
133
|
+
def my_call(): ...
|
|
134
|
+
|
|
135
|
+
@tracker.track(tags={"team": "search"})
|
|
136
|
+
def my_call(): ...
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
func: Function to decorate (when used without arguments)
|
|
140
|
+
tags: Tags for attribution
|
|
141
|
+
streaming: Whether the function returns a streaming response
|
|
142
|
+
provider: Override provider detection
|
|
143
|
+
model: Override model detection
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
Decorated function
|
|
147
|
+
"""
|
|
148
|
+
|
|
149
|
+
def decorator(f: F) -> F:
|
|
150
|
+
if asyncio.iscoroutinefunction(f):
|
|
151
|
+
return self._wrap_async(f, tags, streaming, provider, model) # type: ignore
|
|
152
|
+
else:
|
|
153
|
+
return self._wrap_sync(f, tags, streaming, provider, model) # type: ignore
|
|
154
|
+
|
|
155
|
+
if func is not None:
|
|
156
|
+
return decorator(func)
|
|
157
|
+
return decorator
|
|
158
|
+
|
|
159
|
+
def _wrap_sync(
|
|
160
|
+
self,
|
|
161
|
+
func: F,
|
|
162
|
+
tags: Optional[Dict[str, str]],
|
|
163
|
+
streaming: bool,
|
|
164
|
+
provider_override: Optional[str],
|
|
165
|
+
model_override: Optional[str],
|
|
166
|
+
) -> F:
|
|
167
|
+
"""Wrap a synchronous function for tracking."""
|
|
168
|
+
|
|
169
|
+
@functools.wraps(func)
|
|
170
|
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
171
|
+
start_time = time.time()
|
|
172
|
+
success = True
|
|
173
|
+
error_type = None
|
|
174
|
+
response = None
|
|
175
|
+
|
|
176
|
+
try:
|
|
177
|
+
response = func(*args, **kwargs)
|
|
178
|
+
return response
|
|
179
|
+
except Exception as e:
|
|
180
|
+
success = False
|
|
181
|
+
error_type = type(e).__name__
|
|
182
|
+
raise
|
|
183
|
+
finally:
|
|
184
|
+
latency_ms = int((time.time() - start_time) * 1000)
|
|
185
|
+
|
|
186
|
+
if response is not None or (not success and self._track_failed_calls):
|
|
187
|
+
try:
|
|
188
|
+
self._record_call(
|
|
189
|
+
response=response,
|
|
190
|
+
tags=tags,
|
|
191
|
+
success=success,
|
|
192
|
+
error_type=error_type,
|
|
193
|
+
latency_ms=latency_ms,
|
|
194
|
+
provider_override=provider_override,
|
|
195
|
+
model_override=model_override,
|
|
196
|
+
)
|
|
197
|
+
except Exception as e:
|
|
198
|
+
self._handle_tracking_error(e)
|
|
199
|
+
|
|
200
|
+
return wrapper # type: ignore
|
|
201
|
+
|
|
202
|
+
def _wrap_async(
|
|
203
|
+
self,
|
|
204
|
+
func: F,
|
|
205
|
+
tags: Optional[Dict[str, str]],
|
|
206
|
+
streaming: bool,
|
|
207
|
+
provider_override: Optional[str],
|
|
208
|
+
model_override: Optional[str],
|
|
209
|
+
) -> F:
|
|
210
|
+
"""Wrap an asynchronous function for tracking."""
|
|
211
|
+
|
|
212
|
+
@functools.wraps(func)
|
|
213
|
+
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
214
|
+
start_time = time.time()
|
|
215
|
+
success = True
|
|
216
|
+
error_type = None
|
|
217
|
+
response = None
|
|
218
|
+
|
|
219
|
+
try:
|
|
220
|
+
response = await func(*args, **kwargs)
|
|
221
|
+
return response
|
|
222
|
+
except Exception as e:
|
|
223
|
+
success = False
|
|
224
|
+
error_type = type(e).__name__
|
|
225
|
+
raise
|
|
226
|
+
finally:
|
|
227
|
+
latency_ms = int((time.time() - start_time) * 1000)
|
|
228
|
+
|
|
229
|
+
if response is not None or (not success and self._track_failed_calls):
|
|
230
|
+
try:
|
|
231
|
+
self._record_call(
|
|
232
|
+
response=response,
|
|
233
|
+
tags=tags,
|
|
234
|
+
success=success,
|
|
235
|
+
error_type=error_type,
|
|
236
|
+
latency_ms=latency_ms,
|
|
237
|
+
provider_override=provider_override,
|
|
238
|
+
model_override=model_override,
|
|
239
|
+
)
|
|
240
|
+
except Exception as e:
|
|
241
|
+
self._handle_tracking_error(e)
|
|
242
|
+
|
|
243
|
+
return wrapper # type: ignore
|
|
244
|
+
|
|
245
|
+
@contextmanager
|
|
246
|
+
def track_context(
|
|
247
|
+
self,
|
|
248
|
+
tags: Optional[Dict[str, str]] = None,
|
|
249
|
+
provider: Optional[str] = None,
|
|
250
|
+
model: Optional[str] = None,
|
|
251
|
+
):
|
|
252
|
+
"""
|
|
253
|
+
Context manager for tracking LLM calls.
|
|
254
|
+
|
|
255
|
+
Usage:
|
|
256
|
+
with tracker.track_context(tags={"feature": "search"}):
|
|
257
|
+
response = openai.chat.completions.create(...)
|
|
258
|
+
|
|
259
|
+
Note: This context manager doesn't automatically extract usage from
|
|
260
|
+
responses. Use the decorator or manual recording for automatic tracking.
|
|
261
|
+
"""
|
|
262
|
+
start_time = time.time()
|
|
263
|
+
tags = tags or {}
|
|
264
|
+
|
|
265
|
+
try:
|
|
266
|
+
yield
|
|
267
|
+
finally:
|
|
268
|
+
pass # Context manager for grouping, actual tracking via decorator or record()
|
|
269
|
+
|
|
270
|
+
def span(
|
|
271
|
+
self,
|
|
272
|
+
name: str,
|
|
273
|
+
tags: Optional[Dict[str, str]] = None,
|
|
274
|
+
) -> Span:
|
|
275
|
+
"""
|
|
276
|
+
Create a tracking span for grouping multiple LLM calls.
|
|
277
|
+
|
|
278
|
+
Usage:
|
|
279
|
+
with tracker.span("rag_pipeline", tags={"user": "123"}) as span:
|
|
280
|
+
# Multiple LLM calls here
|
|
281
|
+
result = agent.run(query)
|
|
282
|
+
print(span.total_cost)
|
|
283
|
+
|
|
284
|
+
Args:
|
|
285
|
+
name: Name of the span
|
|
286
|
+
tags: Tags for the span
|
|
287
|
+
|
|
288
|
+
Returns:
|
|
289
|
+
Span context manager
|
|
290
|
+
"""
|
|
291
|
+
return Span(name=name, tags=tags or {})
|
|
292
|
+
|
|
293
|
+
    def record(
        self,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        tags: Optional[Dict[str, str]] = None,
        success: bool = True,
        error_type: Optional[str] = None,
        latency_ms: int = 0,
        cached_tokens: int = 0,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> CostRecord:
        """
        Manually record an LLM call.

        Use this for custom integrations or when automatic tracking isn't available.

        Enforcement runs before persistence: budget and rate-limit checks may
        raise, in which case nothing is saved or counted.

        Args:
            provider: Provider name (openai, anthropic, bedrock)
            model: Model name
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens
            tags: Attribution tags
            success: Whether the call succeeded
            error_type: Error type if call failed
            latency_ms: Call latency in milliseconds
            cached_tokens: Number of cached input tokens
            metadata: Additional metadata (high-cardinality data)

        Returns:
            The created CostRecord

        Raises:
            BudgetExceededError: If a BLOCK-action budget would be exceeded.
            RateLimitExceededError: If a configured rate limit is exhausted.
        """
        tags = tags or {}
        metadata = metadata or {}

        # Validate tag cardinality (warns on explosion; does not reject the call)
        self._check_tag_cardinality(tags)

        # Calculate cost
        input_cost, output_cost, total_cost = self._pricing.calculate_cost(
            provider, model, input_tokens, output_tokens, cached_tokens
        )

        # Calculate cache savings: the per-1k price delta between full and
        # cached input rates, when the model publishes a cached rate.
        cache_savings = 0.0
        if cached_tokens > 0 and self._track_cache_savings:
            pricing = self._pricing.get_pricing(provider, model)
            if pricing.cached_input_cost_per_1k is not None:
                cache_savings = (cached_tokens / 1000) * (
                    pricing.input_cost_per_1k - pricing.cached_input_cost_per_1k
                )

        # Check budgets — only BLOCK-action budgets abort the call here
        exceeded = self._budget_tracker.check_budget(total_cost, tags)
        for budget, action in exceeded:
            if action == BudgetAction.BLOCK:
                raise BudgetExceededError(
                    f"Budget '{budget.name}' would be exceeded",
                    budget=budget,
                    current=self._budget_tracker.get_spending(budget.name),
                    limit=budget.limit,
                )

        # Check rate limits — only the first exceeded limit is surfaced
        rate_exceeded = self._rate_limiter.check(model=model, provider=provider, tags=tags)
        if rate_exceeded:
            limit, current, retry_after = rate_exceeded[0]
            raise RateLimitExceededError(
                f"Rate limit '{limit.name}' exceeded",
                limit_name=limit.name,
                current=current,
                limit=limit.limit,
                retry_after_seconds=retry_after,
            )

        # Create record
        record = CostRecord(
            timestamp=datetime.now(),
            provider=provider,
            model=model,
            model_type=ModelType.CHAT,  # Default, can be enhanced
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_cost=input_cost,
            output_cost=output_cost,
            total_cost=total_cost,
            latency_ms=latency_ms,
            tags=tags,
            metadata=metadata,
            success=success,
            error_type=error_type,
            cached=cached_tokens > 0,
            cache_savings=cache_savings,
            span_id=get_current_span().span_id if get_current_span() else None,
        )

        # Save to backend; failures are routed through the on_tracking_failure
        # policy instead of aborting the caller.
        try:
            self._backend.save_record(record)
        except Exception as e:
            self._handle_tracking_error(e)

        # Record against budgets
        self._budget_tracker.record_cost(total_cost, tags)

        # Record rate limit usage
        self._rate_limiter.record(model=model, provider=provider, tags=tags)

        # Update last record
        with self._lock:
            self._last_record = record

        # Record in current span if any
        # NOTE(review): get_current_span() is also called twice above for
        # span_id — presumably cheap and stable within a call; confirm.
        current_span = get_current_span()
        if current_span:
            current_span.record_call(
                cost=total_cost,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                model=model,
                record=record,
            )

        return record
|
|
418
|
+
|
|
419
|
+
    def _record_call(
        self,
        response: Any,
        tags: Optional[Dict[str, str]],
        success: bool,
        error_type: Optional[str],
        latency_ms: int,
        provider_override: Optional[str],
        model_override: Optional[str],
    ) -> Optional[CostRecord]:
        """Record a call from a wrapped function.

        Resolves provider/model from overrides or from the response object,
        extracts token usage via the provider handler, and delegates to
        record(). Returns None when there is nothing to record (no response,
        or an unrecognized provider).
        """
        if response is None:
            # Failed calls reach here with response=None and no usage to extract.
            return None

        # Detect provider and model
        if model_override:
            model = model_override
            provider = provider_override or (
                detect_provider(model) if self._auto_detect_provider else "unknown"
            )
        else:
            # Try to extract from response
            # NOTE(review): falls back to "openai" when nothing identifies the
            # provider — presumably the most common client; confirm this default.
            provider = provider_override or "openai"  # Default
            model = "unknown"

            # Try to detect and extract; best-effort, never raises
            try:
                if self._auto_detect_provider:
                    # Try OpenAI-style response (object with a .model attribute)
                    if hasattr(response, "model"):
                        model = response.model
                        provider = detect_provider(model)
                    elif isinstance(response, dict) and "model" in response:
                        model = response["model"]
                        provider = detect_provider(model)
            except Exception:
                pass

        # Get provider handler; unknown providers are skipped, not fatal
        try:
            provider_handler = get_provider(provider)
        except ValueError:
            logger.warning(f"Unknown provider {provider}, skipping cost tracking")
            return None

        # Extract usage (token counts) from the raw response
        usage = provider_handler.extract_usage(response)
        if model == "unknown":
            model = provider_handler.extract_model(response)

        # Record
        return self.record(
            provider=provider,
            model=model,
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            tags=tags,
            success=success,
            error_type=error_type,
            latency_ms=latency_ms,
            cached_tokens=usage.cached_tokens,
        )
|
|
481
|
+
|
|
482
|
+
def _handle_tracking_error(self, error: Exception) -> None:
|
|
483
|
+
"""Handle errors during tracking based on configuration."""
|
|
484
|
+
if self._on_tracking_failure == "block":
|
|
485
|
+
raise TrackingUnavailableError(str(error), self._backend_url)
|
|
486
|
+
elif self._on_tracking_failure == "fallback":
|
|
487
|
+
logger.warning(f"Tracking error, using fallback: {error}")
|
|
488
|
+
if self._fallback_backend is None:
|
|
489
|
+
self._fallback_backend = MemoryBackend()
|
|
490
|
+
else:
|
|
491
|
+
logger.warning(f"Tracking error (allowing): {error}")
|
|
492
|
+
|
|
493
|
+
def _check_tag_cardinality(self, tags: Dict[str, str]) -> None:
|
|
494
|
+
"""Check and track tag cardinality."""
|
|
495
|
+
with self._tag_lock:
|
|
496
|
+
for key, value in tags.items():
|
|
497
|
+
if key not in self._tag_values:
|
|
498
|
+
self._tag_values[key] = set()
|
|
499
|
+
|
|
500
|
+
self._tag_values[key].add(value)
|
|
501
|
+
|
|
502
|
+
if len(self._tag_values[key]) > self._max_unique_tag_values:
|
|
503
|
+
logger.warning(
|
|
504
|
+
f"Tag '{key}' has exceeded cardinality limit "
|
|
505
|
+
f"({len(self._tag_values[key])} > {self._max_unique_tag_values})"
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
def last_call(self) -> Optional[CostRecord]:
|
|
509
|
+
"""Get the last recorded call."""
|
|
510
|
+
with self._lock:
|
|
511
|
+
return self._last_record
|
|
512
|
+
|
|
513
|
+
def get_costs(
|
|
514
|
+
self,
|
|
515
|
+
start_date: Optional[str] = None,
|
|
516
|
+
end_date: Optional[str] = None,
|
|
517
|
+
tags: Optional[Dict[str, str]] = None,
|
|
518
|
+
group_by: Optional[List[str]] = None,
|
|
519
|
+
) -> CostReport:
|
|
520
|
+
"""
|
|
521
|
+
Query tracked costs.
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
start_date: Start date (ISO format)
|
|
525
|
+
end_date: End date (ISO format)
|
|
526
|
+
tags: Filter by tags
|
|
527
|
+
group_by: Group results by fields (provider, model, or tag keys)
|
|
528
|
+
|
|
529
|
+
Returns:
|
|
530
|
+
CostReport with aggregated data
|
|
531
|
+
"""
|
|
532
|
+
start = datetime.fromisoformat(start_date) if start_date else None
|
|
533
|
+
end = datetime.fromisoformat(end_date) if end_date else None
|
|
534
|
+
|
|
535
|
+
return self._backend.get_report(
|
|
536
|
+
start_date=start,
|
|
537
|
+
end_date=end,
|
|
538
|
+
tags=tags,
|
|
539
|
+
group_by=group_by,
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
def daily_report(self) -> CostReport:
|
|
543
|
+
"""Get a report for today."""
|
|
544
|
+
today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
|
545
|
+
return self._backend.get_report(start_date=today)
|
|
546
|
+
|
|
547
|
+
def report_by_model(self, period: str = "day") -> CostReport:
|
|
548
|
+
"""Get a report grouped by model."""
|
|
549
|
+
start = self._get_period_start(period)
|
|
550
|
+
return self._backend.get_report(start_date=start, group_by=["model"])
|
|
551
|
+
|
|
552
|
+
def trend_analysis(
|
|
553
|
+
self,
|
|
554
|
+
metric: str = "cost",
|
|
555
|
+
granularity: str = "hour",
|
|
556
|
+
last_n_days: int = 7,
|
|
557
|
+
) -> Dict[str, Any]:
|
|
558
|
+
"""Get trend analysis for a metric."""
|
|
559
|
+
# This is a simplified implementation
|
|
560
|
+
from datetime import timedelta
|
|
561
|
+
|
|
562
|
+
end = datetime.now()
|
|
563
|
+
start = end - timedelta(days=last_n_days)
|
|
564
|
+
|
|
565
|
+
records = self._backend.get_records(start_date=start, end_date=end)
|
|
566
|
+
|
|
567
|
+
# Group by time bucket
|
|
568
|
+
buckets: Dict[str, float] = {}
|
|
569
|
+
for record in records:
|
|
570
|
+
if granularity == "hour":
|
|
571
|
+
bucket_key = record.timestamp.strftime("%Y-%m-%d %H:00")
|
|
572
|
+
elif granularity == "day":
|
|
573
|
+
bucket_key = record.timestamp.strftime("%Y-%m-%d")
|
|
574
|
+
else:
|
|
575
|
+
bucket_key = record.timestamp.strftime("%Y-%m-%d %H:00")
|
|
576
|
+
|
|
577
|
+
if bucket_key not in buckets:
|
|
578
|
+
buckets[bucket_key] = 0.0
|
|
579
|
+
|
|
580
|
+
if metric == "cost":
|
|
581
|
+
buckets[bucket_key] += record.total_cost
|
|
582
|
+
elif metric == "tokens":
|
|
583
|
+
buckets[bucket_key] += record.input_tokens + record.output_tokens
|
|
584
|
+
elif metric == "calls":
|
|
585
|
+
buckets[bucket_key] += 1
|
|
586
|
+
|
|
587
|
+
return {
|
|
588
|
+
"metric": metric,
|
|
589
|
+
"granularity": granularity,
|
|
590
|
+
"start_date": start.isoformat(),
|
|
591
|
+
"end_date": end.isoformat(),
|
|
592
|
+
"data": buckets,
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
def _get_period_start(self, period: str) -> datetime:
|
|
596
|
+
"""Get the start datetime for a period."""
|
|
597
|
+
from datetime import timedelta
|
|
598
|
+
|
|
599
|
+
now = datetime.now()
|
|
600
|
+
|
|
601
|
+
if period == "day":
|
|
602
|
+
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
603
|
+
elif period == "week":
|
|
604
|
+
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
605
|
+
return start - timedelta(days=now.weekday())
|
|
606
|
+
elif period == "month":
|
|
607
|
+
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
|
608
|
+
else:
|
|
609
|
+
return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
610
|
+
|
|
611
|
+
def to_dataframe(self):
|
|
612
|
+
"""Export records to a pandas DataFrame."""
|
|
613
|
+
try:
|
|
614
|
+
import pandas as pd
|
|
615
|
+
except ImportError:
|
|
616
|
+
raise ImportError("pandas is required for DataFrame export. Install with: pip install pandas")
|
|
617
|
+
|
|
618
|
+
records = self._backend.get_records()
|
|
619
|
+
|
|
620
|
+
data = []
|
|
621
|
+
for r in records:
|
|
622
|
+
row = {
|
|
623
|
+
"timestamp": r.timestamp,
|
|
624
|
+
"provider": r.provider,
|
|
625
|
+
"model": r.model,
|
|
626
|
+
"input_tokens": r.input_tokens,
|
|
627
|
+
"output_tokens": r.output_tokens,
|
|
628
|
+
"total_cost": r.total_cost,
|
|
629
|
+
"latency_ms": r.latency_ms,
|
|
630
|
+
"success": r.success,
|
|
631
|
+
"cached": r.cached,
|
|
632
|
+
}
|
|
633
|
+
# Add tags as columns
|
|
634
|
+
for key, value in r.tags.items():
|
|
635
|
+
row[f"tag_{key}"] = value
|
|
636
|
+
data.append(row)
|
|
637
|
+
|
|
638
|
+
return pd.DataFrame(data)
|
|
639
|
+
|
|
640
|
+
def health_check(self) -> HealthStatus:
|
|
641
|
+
"""Check tracker and backend health."""
|
|
642
|
+
errors = []
|
|
643
|
+
|
|
644
|
+
# Check backend
|
|
645
|
+
backend_connected = False
|
|
646
|
+
try:
|
|
647
|
+
backend_connected = self._backend.health_check()
|
|
648
|
+
except Exception as e:
|
|
649
|
+
errors.append(f"Backend health check failed: {e}")
|
|
650
|
+
|
|
651
|
+
# Check pricing freshness
|
|
652
|
+
pricing_fresh = not self._pricing.is_stale
|
|
653
|
+
if self._pricing.is_stale:
|
|
654
|
+
errors.append("Pricing data is stale")
|
|
655
|
+
|
|
656
|
+
# Get last record time
|
|
657
|
+
last_record_time = None
|
|
658
|
+
with self._lock:
|
|
659
|
+
if self._last_record:
|
|
660
|
+
last_record_time = self._last_record.timestamp
|
|
661
|
+
|
|
662
|
+
return HealthStatus(
|
|
663
|
+
healthy=backend_connected and pricing_fresh and len(errors) == 0,
|
|
664
|
+
backend_connected=backend_connected,
|
|
665
|
+
pricing_fresh=pricing_fresh,
|
|
666
|
+
last_record_time=last_record_time,
|
|
667
|
+
pending_records=0,
|
|
668
|
+
errors=errors,
|
|
669
|
+
pricing_version=str(self._pricing.pricing_version),
|
|
670
|
+
pricing_last_updated=self._pricing.last_updated,
|
|
671
|
+
)
|
|
672
|
+
|
|
673
|
+
def on_budget_warning(self, callback: Callable[[Budget, float], None]) -> None:
|
|
674
|
+
"""Register a callback for budget warnings."""
|
|
675
|
+
self._budget_tracker.on_warning(callback)
|
|
676
|
+
|
|
677
|
+
def on_budget_exceeded(self, callback: Callable[[Budget], None]) -> None:
|
|
678
|
+
"""Register a callback for budget exceeded events."""
|
|
679
|
+
self._budget_tracker.on_exceeded(callback)
|
|
680
|
+
|
|
681
|
+
def get_budget(self, name: str) -> Optional[Budget]:
|
|
682
|
+
"""Get a budget by name."""
|
|
683
|
+
return self._budget_tracker.get_budget(name)
|
|
684
|
+
|
|
685
|
+
def get_budget_utilization(self, name: str) -> float:
|
|
686
|
+
"""Get budget utilization percentage."""
|
|
687
|
+
return self._budget_tracker.get_utilization(name)
|
|
688
|
+
|
|
689
|
+
def reset_budget(self, name: Optional[str] = None) -> None:
|
|
690
|
+
"""Reset a budget or all budgets."""
|
|
691
|
+
self._budget_tracker.reset(name)
|
|
692
|
+
|
|
693
|
+
@property
|
|
694
|
+
def pricing_last_updated(self) -> Optional[datetime]:
|
|
695
|
+
"""Get when pricing was last updated."""
|
|
696
|
+
return self._pricing.last_updated
|
|
697
|
+
|
|
698
|
+
@property
|
|
699
|
+
def pricing_version(self) -> Dict[str, str]:
|
|
700
|
+
"""Get pricing versions for all providers."""
|
|
701
|
+
return self._pricing.pricing_version
|
|
702
|
+
|
|
703
|
+
@property
|
|
704
|
+
def pricing_is_stale(self) -> bool:
|
|
705
|
+
"""Check if pricing is stale."""
|
|
706
|
+
return self._pricing.is_stale
|
|
707
|
+
|
|
708
|
+
def close(self) -> None:
|
|
709
|
+
"""Close the tracker and backend."""
|
|
710
|
+
self._backend.close()
|
|
711
|
+
if self._fallback_backend:
|
|
712
|
+
self._fallback_backend.close()
|
|
713
|
+
|
|
714
|
+
    def __enter__(self) -> "CostTracker":
        # Support `with CostTracker(...) as tracker:` usage.
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Always release backend resources on context exit, even on error.
        self.close()
|