headroom-ai 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/providers/google.py (new file)

@@ -0,0 +1,427 @@

```python
"""Google Gemini provider for Headroom SDK.

Supports Google's Gemini models through two interfaces:
1. OpenAI-compatible endpoint (recommended for Headroom)
2. Native Google AI SDK (for advanced features)

Token counting uses Google's official countTokens API when a client
is provided. This gives accurate counts for all content types.

Usage:
    import google.generativeai as genai
    from headroom import GoogleProvider

    genai.configure(api_key="your-api-key")
    provider = GoogleProvider(client=genai)  # Accurate counting via API

    # Or without client (uses estimation - less accurate)
    provider = GoogleProvider()  # Warning: approximate counting
"""

from __future__ import annotations

import logging
import warnings
from datetime import date
from typing import Any

from headroom.tokenizers import EstimatingTokenCounter

from .base import Provider, TokenCounter

# Check if litellm is available for pricing/context limit lookups
try:
    import litellm
    from litellm import get_model_info as litellm_get_model_info

    LITELLM_AVAILABLE = True
except ImportError:
    LITELLM_AVAILABLE = False
    litellm = None  # type: ignore[assignment]
    litellm_get_model_info = None  # type: ignore[assignment]

logger = logging.getLogger(__name__)

# Warning flags
_FALLBACK_WARNING_SHOWN = False

# Pricing metadata
_PRICING_LAST_UPDATED = date(2025, 1, 6)

# Google model context limits
_CONTEXT_LIMITS: dict[str, int] = {
    # Gemini 2.0
    "gemini-2.0-flash": 1000000,
    "gemini-2.0-flash-exp": 1000000,
    "gemini-2.0-flash-thinking": 1000000,
    # Gemini 1.5
    "gemini-1.5-pro": 2000000,
    "gemini-1.5-pro-latest": 2000000,
    "gemini-1.5-flash": 1000000,
    "gemini-1.5-flash-latest": 1000000,
    "gemini-1.5-flash-8b": 1000000,
    # Gemini 1.0
    "gemini-1.0-pro": 32768,
    "gemini-pro": 32768,
}

# Fallback pricing - LiteLLM is preferred source
# Pricing per 1M tokens (input, output)
# Note: Google has different pricing tiers based on context length
_PRICING: dict[str, tuple[float, float]] = {
    "gemini-2.0-flash": (0.10, 0.40),
    "gemini-2.0-flash-exp": (0.10, 0.40),  # Experimental, may change
    "gemini-1.5-pro": (1.25, 5.00),  # Up to 128K context
    "gemini-1.5-flash": (0.075, 0.30),  # Up to 128K context
    "gemini-1.5-flash-8b": (0.0375, 0.15),
    "gemini-1.0-pro": (0.50, 1.50),
}


class GeminiTokenCounter:
    """Token counter for Gemini models.

    When a google.generativeai client is provided, uses the official
    countTokens API for accurate counting. Falls back to estimation
    when no client is available.

    Usage:
        import google.generativeai as genai
        genai.configure(api_key="...")

        # With API (accurate)
        counter = GeminiTokenCounter("gemini-2.0-flash", client=genai)

        # Without API (estimation)
        counter = GeminiTokenCounter("gemini-2.0-flash")
    """

    def __init__(self, model: str, client: Any = None):
        """Initialize Gemini token counter.

        Args:
            model: Gemini model name.
            client: Optional google.generativeai module for API-based counting.
        """
        global _FALLBACK_WARNING_SHOWN

        self.model = model
        self._client = client
        self._use_api = client is not None
        self._genai_model = None

        # Gemini uses ~4 chars per token (similar to GPT models)
        self._estimator = EstimatingTokenCounter(chars_per_token=4.0)

        if not self._use_api and not _FALLBACK_WARNING_SHOWN:
            warnings.warn(
                "GoogleProvider: No client provided, using estimation. "
                "For accurate counting, pass google.generativeai: "
                "GoogleProvider(client=genai)",
                UserWarning,
                stacklevel=4,
            )
            _FALLBACK_WARNING_SHOWN = True

    def _get_model(self):
        """Lazy-load the GenerativeModel for API calls."""
        if self._genai_model is None and self._client is not None:
            self._genai_model = self._client.GenerativeModel(self.model)
        return self._genai_model

    def count_text(self, text: str) -> int:
        """Count tokens in text.

        Uses countTokens API if client available, otherwise estimates.
        """
        if not text:
            return 0

        if self._use_api:
            try:
                model = self._get_model()
                response = model.count_tokens(text)
                return response.total_tokens
            except Exception as e:
                logger.debug(f"Google countTokens API failed: {e}, using estimation")

        return self._estimator.count_text(text)

    def count_message(self, message: dict[str, Any]) -> int:
        """Count tokens in a message."""
        # For API-based counting, convert message to content and count
        if self._use_api:
            try:
                content = self._message_to_content(message)
                model = self._get_model()
                response = model.count_tokens(content)
                return response.total_tokens
            except Exception as e:
                logger.debug(f"Google countTokens API failed: {e}, using estimation")

        # Fallback to estimation
        return self._estimate_message(message)

    def count_messages(self, messages: list[dict[str, Any]]) -> int:
        """Count tokens in messages.

        Uses countTokens API with full conversation if available.
        """
        if not messages:
            return 0

        if self._use_api:
            try:
                # Convert to Gemini content format
                contents = [self._message_to_content(msg) for msg in messages]
                model = self._get_model()
                response = model.count_tokens(contents)
                return response.total_tokens
            except Exception as e:
                logger.debug(f"Google countTokens API failed: {e}, using estimation")

        # Fallback to estimation
        total = sum(self._estimate_message(msg) for msg in messages)
        total += 3  # Priming tokens
        return total

    def _message_to_content(self, message: dict[str, Any]) -> str:
        """Convert OpenAI-format message to text content for counting."""
        content = message.get("content", "")
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    parts.append(part.get("text", ""))
                elif isinstance(part, str):
                    parts.append(part)
            return "\n".join(parts)
        return str(content)

    def _estimate_message(self, message: dict[str, Any]) -> int:
        """Estimate tokens in a message without API."""
        tokens = 4  # Message overhead

        role = message.get("role", "")
        tokens += self._estimator.count_text(role)

        content = message.get("content")
        if content:
            if isinstance(content, str):
                tokens += self._estimator.count_text(content)
            elif isinstance(content, list):
                for part in content:
                    if isinstance(part, dict):
                        if part.get("type") == "text":
                            tokens += self._estimator.count_text(part.get("text", ""))
                    elif isinstance(part, str):
                        tokens += self._estimator.count_text(part)

        return tokens


class GoogleProvider(Provider):
    """Provider for Google Gemini models.

    Supports Gemini 1.5 and 2.0 model families through:
    - OpenAI-compatible endpoint (generativelanguage.googleapis.com)
    - Native Google AI SDK (for accurate token counting)

    Example:
        import google.generativeai as genai
        genai.configure(api_key="...")

        # With client (accurate token counting via API)
        provider = GoogleProvider(client=genai)

        # Without client (estimation-based counting)
        provider = GoogleProvider()

        # Token counting
        counter = provider.get_token_counter("gemini-2.0-flash")
        tokens = counter.count_text("Hello, world!")

        # Context limits
        limit = provider.get_context_limit("gemini-1.5-pro")  # 2M tokens!

        # Cost estimation
        cost = provider.estimate_cost(
            input_tokens=100000,
            output_tokens=10000,
            model="gemini-1.5-pro",
        )
    """

    # OpenAI-compatible endpoint for Gemini
    OPENAI_COMPATIBLE_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"

    def __init__(self, client: Any = None):
        """Initialize Google provider.

        Args:
            client: Optional google.generativeai module for API-based token counting.
                If provided, uses countTokens API for accurate counts.
        """
        self._client = client

    @property
    def name(self) -> str:
        return "google"

    def supports_model(self, model: str) -> bool:
        """Check if model is a known Gemini model."""
        model_lower = model.lower()
        if model_lower in _CONTEXT_LIMITS:
            return True
        # Check prefix match
        for prefix in ["gemini-2", "gemini-1.5", "gemini-1.0", "gemini-pro"]:
            if model_lower.startswith(prefix):
                return True
        return False

    def get_token_counter(self, model: str) -> TokenCounter:
        """Get token counter for a Gemini model.

        Uses countTokens API if client was provided, otherwise estimates.
        """
        if not self.supports_model(model):
            raise ValueError(
                f"Model '{model}' is not recognized as a Google model. "
                f"Supported models: {list(_CONTEXT_LIMITS.keys())}"
            )
        return GeminiTokenCounter(model, client=self._client)

    def get_context_limit(self, model: str) -> int:
        """Get context limit for a Gemini model.

        Tries LiteLLM first for up-to-date limits, falls back to hardcoded values.
        Note: Gemini 1.5 Pro has 2M token context!
        """
        model_lower = model.lower()

        # Try LiteLLM first for up-to-date context limits
        if LITELLM_AVAILABLE and litellm_get_model_info is not None:
            # Try different model name formats that LiteLLM might recognize
            model_variants = [
                f"gemini/{model_lower}",  # gemini/gemini-1.5-pro
                model_lower,  # gemini-1.5-pro
            ]
            for variant in model_variants:
                try:
                    info = litellm_get_model_info(variant)
                    if info:
                        if "max_input_tokens" in info and info["max_input_tokens"]:
                            return info["max_input_tokens"]
                        if "max_tokens" in info and info["max_tokens"]:
                            return info["max_tokens"]
                except Exception:
                    continue

        # Fallback to hardcoded limits
        # Direct match
        if model_lower in _CONTEXT_LIMITS:
            return _CONTEXT_LIMITS[model_lower]

        # Prefix match
        for prefix, limit in [
            ("gemini-2.0", 1000000),
            ("gemini-1.5-pro", 2000000),
            ("gemini-1.5-flash", 1000000),
            ("gemini-1.0", 32768),
            ("gemini-pro", 32768),
        ]:
            if model_lower.startswith(prefix):
                return limit

        raise ValueError(
            f"Unknown context limit for model '{model}'. "
            f"Known models: {list(_CONTEXT_LIMITS.keys())}"
        )

    def estimate_cost(
        self,
        input_tokens: int,
        output_tokens: int,
        model: str,
        cached_tokens: int = 0,
    ) -> float | None:
        """Estimate cost for Gemini API call.

        Tries LiteLLM first for up-to-date pricing, falls back to hardcoded values.

        Note: Google has tiered pricing based on context length.
        This uses the standard pricing (up to 128K context).
        For >128K context, actual costs may be higher.

        Args:
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            model: Model name.
            cached_tokens: Number of cached tokens (not used by Google).

        Returns:
            Estimated cost in USD, or None if pricing unknown.
        """
        model_lower = model.lower()

        # Try LiteLLM first for up-to-date pricing
        if LITELLM_AVAILABLE and litellm is not None:
            # Try different model name formats that LiteLLM might recognize
            model_variants = [
                f"gemini/{model_lower}",  # gemini/gemini-1.5-pro
                model_lower,  # gemini-1.5-pro
            ]
            for variant in model_variants:
                try:
                    cost = litellm.completion_cost(
                        model=variant,
                        prompt="",
                        completion="",
                        prompt_tokens=input_tokens,
                        completion_tokens=output_tokens,
                    )
                    if cost is not None:
                        return cost
                except Exception:
                    continue

        # Fallback to hardcoded pricing
        input_price, output_price = None, None
        for model_prefix, (inp, outp) in _PRICING.items():
            if model_lower.startswith(model_prefix):
                input_price, output_price = inp, outp
                break

        if input_price is None:
            return None

        input_cost = (input_tokens / 1_000_000) * input_price
        output_cost = (output_tokens / 1_000_000) * (output_price or 0)

        return input_cost + output_cost

    def get_output_buffer(self, model: str, default: int = 4000) -> int:
        """Get recommended output buffer."""
        # Gemini models can output up to 8K tokens
        return min(8192, default)

    @classmethod
    def get_openai_compatible_url(cls, api_key: str) -> str:
        """Get OpenAI-compatible endpoint URL.

        Use this with the OpenAI client:
            from openai import OpenAI
            client = OpenAI(
                api_key=api_key,
                base_url=GoogleProvider.get_openai_compatible_url(api_key),
            )

        Args:
            api_key: Google AI API key.

        Returns:
            Base URL for OpenAI-compatible requests.
        """
        return cls.OPENAI_COMPATIBLE_BASE_URL
```