llm_cost_guard-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. llm_cost_guard/__init__.py +39 -0
  2. llm_cost_guard/backends/__init__.py +52 -0
  3. llm_cost_guard/backends/base.py +121 -0
  4. llm_cost_guard/backends/memory.py +265 -0
  5. llm_cost_guard/backends/sqlite.py +425 -0
  6. llm_cost_guard/budget.py +306 -0
  7. llm_cost_guard/cli.py +464 -0
  8. llm_cost_guard/clients/__init__.py +11 -0
  9. llm_cost_guard/clients/anthropic.py +231 -0
  10. llm_cost_guard/clients/openai.py +262 -0
  11. llm_cost_guard/exceptions.py +71 -0
  12. llm_cost_guard/integrations/__init__.py +12 -0
  13. llm_cost_guard/integrations/cache.py +189 -0
  14. llm_cost_guard/integrations/langchain.py +257 -0
  15. llm_cost_guard/models.py +123 -0
  16. llm_cost_guard/pricing/__init__.py +7 -0
  17. llm_cost_guard/pricing/anthropic.yaml +88 -0
  18. llm_cost_guard/pricing/bedrock.yaml +215 -0
  19. llm_cost_guard/pricing/loader.py +221 -0
  20. llm_cost_guard/pricing/openai.yaml +148 -0
  21. llm_cost_guard/pricing/vertex.yaml +133 -0
  22. llm_cost_guard/providers/__init__.py +69 -0
  23. llm_cost_guard/providers/anthropic.py +115 -0
  24. llm_cost_guard/providers/base.py +72 -0
  25. llm_cost_guard/providers/bedrock.py +135 -0
  26. llm_cost_guard/providers/openai.py +110 -0
  27. llm_cost_guard/rate_limit.py +233 -0
  28. llm_cost_guard/span.py +143 -0
  29. llm_cost_guard/tokenizers/__init__.py +7 -0
  30. llm_cost_guard/tokenizers/base.py +207 -0
  31. llm_cost_guard/tracker.py +718 -0
  32. llm_cost_guard-0.1.0.dist-info/METADATA +357 -0
  33. llm_cost_guard-0.1.0.dist-info/RECORD +36 -0
  34. llm_cost_guard-0.1.0.dist-info/WHEEL +4 -0
  35. llm_cost_guard-0.1.0.dist-info/entry_points.txt +2 -0
  36. llm_cost_guard-0.1.0.dist-info/licenses/LICENSE +21 -0
llm_cost_guard/clients/openai.py
@@ -0,0 +1,262 @@
+"""
+Wrapped OpenAI client with automatic cost tracking.
+"""
+
+import time
+from typing import Any, Dict, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from llm_cost_guard import CostTracker
+
+
+class TrackedOpenAI:
+    """
+    OpenAI client wrapper with automatic cost tracking.
+
+    Usage:
+        from llm_cost_guard import CostTracker
+        from llm_cost_guard.clients import TrackedOpenAI
+
+        tracker = CostTracker()
+        client = TrackedOpenAI(tracker=tracker)
+
+        response = client.chat.completions.create(
+            model="gpt-4o",
+            messages=[{"role": "user", "content": "Hello!"}]
+        )
+        # Cost is automatically tracked
+    """
+
+    def __init__(
+        self,
+        tracker: "CostTracker",
+        client: Optional[Any] = None,
+        tags: Optional[Dict[str, str]] = None,
+        **openai_kwargs: Any,
+    ):
+        """
+        Initialize the tracked OpenAI client.
+
+        Args:
+            tracker: CostTracker instance
+            client: Optional existing OpenAI client to wrap
+            tags: Default tags for all calls
+            **openai_kwargs: Arguments passed to the OpenAI client
+        """
+        try:
+            from openai import OpenAI
+        except ImportError:
+            raise ImportError(
+                "OpenAI is required for this client. Install with: pip install openai"
+            )
+
+        self._tracker = tracker
+        self._default_tags = tags or {}
+        self._client = client or OpenAI(**openai_kwargs)
+
+        # Create wrapped interface
+        self.chat = _TrackedChat(self._client.chat, self._tracker, self._default_tags)
+        self.completions = _TrackedCompletions(
+            self._client.completions, self._tracker, self._default_tags
+        )
+        self.embeddings = _TrackedEmbeddings(
+            self._client.embeddings, self._tracker, self._default_tags
+        )
+
+    @property
+    def models(self):
+        """Access the models API."""
+        return self._client.models
+
+    def close(self) -> None:
+        """Close the underlying client."""
+        self._client.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+
+class _TrackedChat:
+    """Wrapped chat API namespace."""
+
+    def __init__(self, chat, tracker: "CostTracker", default_tags: Dict[str, str]):
+        self._chat = chat
+        self._tracker = tracker
+        self._default_tags = default_tags
+        self.completions = _TrackedChatCompletions(
+            chat.completions, tracker, default_tags
+        )
+
+
+class _TrackedChatCompletions:
+    """Wrapped chat.completions API."""
+
+    def __init__(self, completions, tracker: "CostTracker", default_tags: Dict[str, str]):
+        self._completions = completions
+        self._tracker = tracker
+        self._default_tags = default_tags
+
+    def create(
+        self,
+        *,
+        tags: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Create a chat completion with tracking."""
+        start_time = time.time()
+        success = True
+        error_type = None
+        response = None
+
+        try:
+            response = self._completions.create(**kwargs)
+            return response
+        except Exception as e:
+            success = False
+            error_type = type(e).__name__
+            raise
+        finally:
+            latency_ms = int((time.time() - start_time) * 1000)
+
+            if response is not None:
+                self._record_response(response, tags, success, error_type, latency_ms)
+
+    def _record_response(
+        self,
+        response: Any,
+        tags: Optional[Dict[str, str]],
+        success: bool,
+        error_type: Optional[str],
+        latency_ms: int,
+    ) -> None:
+        """Record the response with the tracker."""
+        from llm_cost_guard.providers.openai import OpenAIProvider
+
+        provider = OpenAIProvider()
+        usage = provider.extract_usage(response)
+        model = provider.extract_model(response)
+
+        all_tags = dict(self._default_tags)
+        if tags:
+            all_tags.update(tags)
+
+        self._tracker.record(
+            provider="openai",
+            model=model,
+            input_tokens=usage.input_tokens,
+            output_tokens=usage.output_tokens,
+            tags=all_tags,
+            success=success,
+            error_type=error_type,
+            latency_ms=latency_ms,
+            cached_tokens=usage.cached_tokens,
+        )
+
+
+class _TrackedCompletions:
+    """Wrapped completions API (legacy)."""
+
+    def __init__(self, completions, tracker: "CostTracker", default_tags: Dict[str, str]):
+        self._completions = completions
+        self._tracker = tracker
+        self._default_tags = default_tags
+
+    def create(
+        self,
+        *,
+        tags: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Create a completion with tracking."""
+        start_time = time.time()
+        success = True
+        error_type = None
+        response = None
+
+        try:
+            response = self._completions.create(**kwargs)
+            return response
+        except Exception as e:
+            success = False
+            error_type = type(e).__name__
+            raise
+        finally:
+            latency_ms = int((time.time() - start_time) * 1000)
+
+            if response is not None:
+                from llm_cost_guard.providers.openai import OpenAIProvider
+
+                provider = OpenAIProvider()
+                usage = provider.extract_usage(response)
+                model = provider.extract_model(response)
+
+                all_tags = dict(self._default_tags)
+                if tags:
+                    all_tags.update(tags)
+
+                self._tracker.record(
+                    provider="openai",
+                    model=model,
+                    input_tokens=usage.input_tokens,
+                    output_tokens=usage.output_tokens,
+                    tags=all_tags,
+                    success=success,
+                    error_type=error_type,
+                    latency_ms=latency_ms,
+                )
+
+
+class _TrackedEmbeddings:
+    """Wrapped embeddings API."""
+
+    def __init__(self, embeddings, tracker: "CostTracker", default_tags: Dict[str, str]):
+        self._embeddings = embeddings
+        self._tracker = tracker
+        self._default_tags = default_tags
+
+    def create(
+        self,
+        *,
+        tags: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Create embeddings with tracking."""
+        start_time = time.time()
+        success = True
+        error_type = None
+        response = None
+
+        try:
+            response = self._embeddings.create(**kwargs)
+            return response
+        except Exception as e:
+            success = False
+            error_type = type(e).__name__
+            raise
+        finally:
+            latency_ms = int((time.time() - start_time) * 1000)
+
+            if response is not None:
+                from llm_cost_guard.providers.openai import OpenAIProvider
+
+                provider = OpenAIProvider()
+                usage = provider.extract_usage(response)
+                model = kwargs.get("model", "unknown")
+
+                all_tags = dict(self._default_tags)
+                if tags:
+                    all_tags.update(tags)
+
+                self._tracker.record(
+                    provider="openai",
+                    model=model,
+                    input_tokens=usage.input_tokens,
+                    output_tokens=0,  # Embeddings don't have output tokens
+                    tags=all_tags,
+                    success=success,
+                    error_type=error_type,
+                    latency_ms=latency_ms,
+                )
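A minimal usage sketch for the wrapper above, based only on the API visible in this diff (CostTracker and the per-call tags keyword come from the docstrings; the tag names and the embedding model are illustrative):

    from llm_cost_guard import CostTracker
    from llm_cost_guard.clients import TrackedOpenAI

    tracker = CostTracker()

    # Default tags apply to every call; per-call tags are merged on top.
    client = TrackedOpenAI(tracker=tracker, tags={"team": "search"})

    with client:  # __enter__/__exit__ close the underlying client
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Hello!"}],
            tags={"feature": "autocomplete"},  # merged with {"team": "search"}
        )

        # Embeddings are tracked too, with output_tokens recorded as 0.
        client.embeddings.create(
            model="text-embedding-3-small",
            input="Hello!",
        )

Note the design choice in create(): a call is recorded only when a response was obtained, so a call that raises leaves response as None and writes nothing to the tracker, even though success and error_type are computed.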
llm_cost_guard/exceptions.py
@@ -0,0 +1,71 @@
+"""
+Custom exceptions for LLM Cost Guard.
+"""
+
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from llm_cost_guard.budget import Budget
+
+
+class LLMCostGuardError(Exception):
+    """Base exception for LLM Cost Guard."""
+
+    pass
+
+
+class BudgetExceededError(LLMCostGuardError):
+    """Raised when a budget limit is exceeded."""
+
+    def __init__(
+        self,
+        message: str,
+        budget: Optional["Budget"] = None,
+        current: float = 0.0,
+        limit: float = 0.0,
+    ):
+        super().__init__(message)
+        self.budget = budget
+        self.current = current
+        self.limit = limit
+
+
+class PricingNotFoundError(LLMCostGuardError):
+    """Raised when pricing information for a model is not found."""
+
+    def __init__(self, message: str, provider: str = "", model: str = ""):
+        super().__init__(message)
+        self.provider = provider
+        self.model = model
+
+
+class TokenCountError(LLMCostGuardError):
+    """Raised when token counting fails."""
+
+    pass
+
+
+class TrackingUnavailableError(LLMCostGuardError):
+    """Raised when the tracking backend is unavailable."""
+
+    def __init__(self, message: str, backend: str = ""):
+        super().__init__(message)
+        self.backend = backend
+
+
+class RateLimitExceededError(LLMCostGuardError):
+    """Raised when a rate limit is exceeded."""
+
+    def __init__(
+        self,
+        message: str,
+        limit_name: str = "",
+        current: int = 0,
+        limit: int = 0,
+        retry_after_seconds: Optional[float] = None,
+    ):
+        super().__init__(message)
+        self.limit_name = limit_name
+        self.current = current
+        self.limit = limit
+        self.retry_after_seconds = retry_after_seconds
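A short sketch of how the exception metadata is meant to be consumed. This hunk does not show where the package raises these, so the raise below is purely illustrative:

    import time

    from llm_cost_guard.exceptions import RateLimitExceededError

    try:
        raise RateLimitExceededError(
            "requests-per-minute limit hit",
            limit_name="rpm",
            current=61,
            limit=60,
            retry_after_seconds=2.5,
        )
    except RateLimitExceededError as e:
        # retry_after_seconds is Optional; check before sleeping.
        if e.retry_after_seconds is not None:
            time.sleep(e.retry_after_seconds)

Because every class derives from LLMCostGuardError, callers can also catch the base type to handle budget, pricing, token-counting, and rate-limit failures uniformly.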
llm_cost_guard/integrations/__init__.py
@@ -0,0 +1,12 @@
+"""
+Integrations with external tools and frameworks.
+"""
+
+from llm_cost_guard.integrations.langchain import CostTrackingCallback, track_chain
+from llm_cost_guard.integrations.cache import CacheTracker
+
+__all__ = [
+    "CostTrackingCallback",
+    "track_chain",
+    "CacheTracker",
+]
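Because this __init__ imports the langchain integration eagerly, importing anything from llm_cost_guard.integrations pulls that module in as well. Assuming those imports succeed, the re-exported and direct paths resolve to the same object:

    from llm_cost_guard.integrations import CacheTracker
    from llm_cost_guard.integrations.cache import CacheTracker as Direct

    assert CacheTracker is Direct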
llm_cost_guard/integrations/cache.py
@@ -0,0 +1,189 @@
+"""
+Cache integration for LLM Cost Guard.
+"""
+
+import functools
+import logging
+from typing import Any, Callable, Dict, Optional, TypeVar
+
+logger = logging.getLogger(__name__)
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+class CacheTracker:
+    """
+    Tracks cache hits and savings for cached LLM calls.
+
+    Usage:
+        from llm_cost_guard import CostTracker
+        from llm_cost_guard.integrations.cache import CacheTracker
+
+        tracker = CostTracker()
+        cache_tracker = CacheTracker(tracker)
+
+        @cache_tracker.track
+        @your_cache_decorator
+        def cached_llm_call(prompt):
+            return llm.invoke(prompt)
+    """
+
+    def __init__(
+        self,
+        tracker: Any,  # CostTracker
+        default_tags: Optional[Dict[str, str]] = None,
+    ):
+        """
+        Initialize the cache tracker.
+
+        Args:
+            tracker: CostTracker instance
+            default_tags: Default tags to apply to all tracked calls
+        """
+        self._tracker = tracker
+        self._default_tags = default_tags or {}
+        self._cache_hits = 0
+        self._cache_misses = 0
+        self._estimated_savings = 0.0
+
+    def track(
+        self,
+        func: Optional[F] = None,
+        *,
+        tags: Optional[Dict[str, str]] = None,
+        cache_indicator: str = "_from_cache",
+    ) -> F:
+        """
+        Decorator to track cache hits and savings.
+
+        The decorated function should set a `_from_cache` attribute on
+        the result if it came from cache, or return a tuple (result, from_cache).
+
+        Args:
+            func: Function to decorate
+            tags: Additional tags
+            cache_indicator: Attribute name to check for a cache hit
+
+        Returns:
+            Decorated function
+        """
+
+        def decorator(f: F) -> F:
+            @functools.wraps(f)
+            def wrapper(*args: Any, **kwargs: Any) -> Any:
+                result = f(*args, **kwargs)
+
+                # Check whether the result came from cache
+                from_cache = False
+
+                # Check for a (result, from_cache) tuple return
+                if isinstance(result, tuple) and len(result) == 2:
+                    actual_result, from_cache = result
+                    result = actual_result
+
+                # Check for the indicator attribute on the result
+                elif hasattr(result, cache_indicator):
+                    from_cache = getattr(result, cache_indicator, False)
+
+                # Update cache stats
+                if from_cache:
+                    self._cache_hits += 1
+                else:
+                    self._cache_misses += 1
+
+                return result
+
+            return wrapper  # type: ignore
+
+        if func is not None:
+            return decorator(func)
+        return decorator  # type: ignore
+
+    def record_cache_hit(
+        self,
+        estimated_cost: float,
+        provider: str = "unknown",
+        model: str = "unknown",
+        tags: Optional[Dict[str, str]] = None,
+    ) -> None:
+        """
+        Manually record a cache hit with estimated savings.
+
+        Args:
+            estimated_cost: Estimated cost that was saved
+            provider: Provider name
+            model: Model name
+            tags: Attribution tags
+        """
+        self._cache_hits += 1
+        self._estimated_savings += estimated_cost
+
+        # Record in the main tracker as a zero-cost call
+        all_tags = dict(self._default_tags)
+        if tags:
+            all_tags.update(tags)
+        all_tags["cache_hit"] = "true"
+
+        self._tracker.record(
+            provider=provider,
+            model=model,
+            input_tokens=0,
+            output_tokens=0,
+            tags=all_tags,
+            success=True,
+            metadata={"cache_savings": estimated_cost},
+        )
+
+    def record_cache_miss(
+        self,
+        provider: str = "unknown",
+        model: str = "unknown",
+        tags: Optional[Dict[str, str]] = None,
+    ) -> None:
+        """
+        Manually record a cache miss.
+
+        Args:
+            provider: Provider name
+            model: Model name
+            tags: Attribution tags
+        """
+        self._cache_misses += 1
+
+    @property
+    def cache_hits(self) -> int:
+        """Get total cache hits."""
+        return self._cache_hits
+
+    @property
+    def cache_misses(self) -> int:
+        """Get total cache misses."""
+        return self._cache_misses
+
+    @property
+    def cache_hit_rate(self) -> float:
+        """Get cache hit rate (0.0 to 1.0)."""
+        total = self._cache_hits + self._cache_misses
+        if total == 0:
+            return 0.0
+        return self._cache_hits / total
+
+    @property
+    def estimated_savings(self) -> float:
+        """Get estimated cost savings from cache hits."""
+        return self._estimated_savings
+
+    def reset(self) -> None:
+        """Reset cache statistics."""
+        self._cache_hits = 0
+        self._cache_misses = 0
+        self._estimated_savings = 0.0
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get cache statistics."""
+        return {
+            "cache_hits": self._cache_hits,
+            "cache_misses": self._cache_misses,
+            "cache_hit_rate": self.cache_hit_rate,
+            "estimated_savings": self._estimated_savings,
+        }
+ }