headroom_ai-0.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/providers/base.py
@@ -0,0 +1,131 @@
+"""Base provider protocol for Headroom SDK.
+
+Providers are responsible for:
+- Token counting (model-specific)
+- Model context limits
+- Cost estimation (optional)
+
+This module defines the protocols that all providers must implement.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class TokenCounter(Protocol):
+    """Protocol for token counting implementations."""
+
+    def count_text(self, text: str) -> int:
+        """Count tokens in a text string."""
+        ...
+
+    def count_message(self, message: dict[str, Any]) -> int:
+        """Count tokens in a single message dict."""
+        ...
+
+    def count_messages(self, messages: list[dict[str, Any]]) -> int:
+        """Count tokens in a list of messages."""
+        ...
+
+
+class Provider(ABC):
+    """
+    Abstract base class for LLM providers.
+
+    Providers encapsulate all model-specific behavior:
+    - Token counting
+    - Context window limits
+    - Cost estimation
+
+    Implementations must be explicit - no silent fallbacks.
+    """
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Provider name (e.g., 'openai', 'anthropic')."""
+        ...
+
+    @abstractmethod
+    def get_token_counter(self, model: str) -> TokenCounter:
+        """
+        Get a token counter for a specific model.
+
+        Args:
+            model: The model name.
+
+        Returns:
+            TokenCounter instance for the model.
+
+        Raises:
+            ValueError: If model is not supported by this provider.
+        """
+        ...
+
+    @abstractmethod
+    def get_context_limit(self, model: str) -> int:
+        """
+        Get the context window limit for a model.
+
+        Args:
+            model: The model name.
+
+        Returns:
+            Maximum context tokens for the model.
+
+        Raises:
+            ValueError: If model is not recognized.
+        """
+        ...
+
+    @abstractmethod
+    def supports_model(self, model: str) -> bool:
+        """
+        Check if this provider supports a given model.
+
+        Args:
+            model: The model name.
+
+        Returns:
+            True if the model is supported.
+        """
+        ...
+
+    def estimate_cost(
+        self,
+        input_tokens: int,
+        output_tokens: int,
+        model: str,
+        cached_tokens: int = 0,
+    ) -> float | None:
+        """
+        Estimate API cost in USD.
+
+        Args:
+            input_tokens: Number of input tokens.
+            output_tokens: Number of output tokens.
+            model: Model name.
+            cached_tokens: Number of cached input tokens.
+
+        Returns:
+            Estimated cost in USD, or None if cost estimation not available.
+        """
+        return None
+
+    def get_output_buffer(self, model: str, default: int = 4000) -> int:
+        """
+        Get recommended output buffer for a model.
+
+        Some models (like reasoning models) produce longer outputs.
+
+        Args:
+            model: The model name.
+            default: Default buffer if no model-specific recommendation.
+
+        Returns:
+            Recommended output token buffer.
+        """
+        return default
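The Provider ABC above is the contract filled in by every provider module in the package (anthropic.py, openai.py, google.py, litellm.py, cohere.py, openai_compatible.py). Purely as an illustration of the interface, and not code that ships in the wheel, a minimal subclass might look like the sketch below; the model name "toy-chat-1", the 32,000-token limit, and the 4-characters-per-token ratio are made-up placeholder values.

    from headroom.providers.base import Provider, TokenCounter

    class _ApproxCounter:
        """Illustrative TokenCounter: assumes roughly 4 characters per token."""

        def count_text(self, text: str) -> int:
            return (len(text) + 3) // 4 if text else 0

        def count_message(self, message: dict) -> int:
            # content tokens plus a small per-message overhead for the role
            return self.count_text(str(message.get("content", ""))) + 4

        def count_messages(self, messages: list) -> int:
            # sum of per-message counts plus a small priming overhead
            return sum(self.count_message(m) for m in messages) + 3

    class ToyProvider(Provider):
        """Illustrative provider exposing one hypothetical model."""

        @property
        def name(self) -> str:
            return "toy"

        def supports_model(self, model: str) -> bool:
            return model == "toy-chat-1"

        def get_token_counter(self, model: str) -> TokenCounter:
            if not self.supports_model(model):
                raise ValueError(f"Model '{model}' is not supported by ToyProvider")
            return _ApproxCounter()

        def get_context_limit(self, model: str) -> int:
            if not self.supports_model(model):
                raise ValueError(f"Unknown model '{model}'")
            return 32_000

Because TokenCounter is a runtime_checkable Protocol, _ApproxCounter satisfies it structurally: isinstance(_ApproxCounter(), TokenCounter) holds without explicit subclassing, and estimate_cost / get_output_buffer fall back to the base-class defaults (None and 4000).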
headroom/providers/cohere.py
@@ -0,0 +1,362 @@
+"""Cohere provider for Headroom SDK.
+
+Token counting uses Cohere's official tokenize API when a client
+is provided. This gives accurate counts for all content types.
+
+Usage:
+    import cohere
+    from headroom import CohereProvider
+
+    client = cohere.ClientV2()  # Uses CO_API_KEY env var
+    provider = CohereProvider(client=client)  # Accurate counting via API
+
+    # Or without client (uses estimation - less accurate)
+    provider = CohereProvider()  # Warning: approximate counting
+"""
+
+from __future__ import annotations
+
+import logging
+import warnings
+from datetime import date
+from typing import Any
+
+from headroom.tokenizers import EstimatingTokenCounter
+
+from .base import Provider, TokenCounter
+
+try:
+    import litellm
+
+    LITELLM_AVAILABLE = True
+except ImportError:
+    LITELLM_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+# Warning flags
+_FALLBACK_WARNING_SHOWN = False
+
+# Pricing metadata
+_PRICING_LAST_UPDATED = date(2025, 1, 6)
+
+# Cohere model context limits
+_CONTEXT_LIMITS: dict[str, int] = {
+    # Command A (latest, 2025)
+    "command-a-03-2025": 256000,
+    "command-a": 256000,
+    # Command R+ (2024)
+    "command-r-plus-08-2024": 128000,
+    "command-r-plus": 128000,
+    # Command R (2024)
+    "command-r-08-2024": 128000,
+    "command-r": 128000,
+    # Command (legacy)
+    "command": 4096,
+    "command-light": 4096,
+    "command-nightly": 128000,
+    # Embed models
+    "embed-english-v3.0": 512,
+    "embed-multilingual-v3.0": 512,
+    "embed-english-light-v3.0": 512,
+    "embed-multilingual-light-v3.0": 512,
+}
+
+# Fallback pricing - LiteLLM is preferred source
+# Pricing per 1M tokens (input, output)
+_PRICING: dict[str, tuple[float, float]] = {
+    "command-a-03-2025": (2.50, 10.00),
+    "command-a": (2.50, 10.00),
+    "command-r-plus-08-2024": (2.50, 10.00),
+    "command-r-plus": (2.50, 10.00),
+    "command-r-08-2024": (0.15, 0.60),
+    "command-r": (0.15, 0.60),
+    "command": (1.00, 2.00),
+    "command-light": (0.30, 0.60),
+}
+
+
+class CohereTokenCounter:
+    """Token counter for Cohere models.
+
+    When a Cohere client is provided, uses the official tokenize API
+    for accurate counting. Falls back to estimation when no client
+    is available.
+
+    Usage:
+        import cohere
+        client = cohere.ClientV2()
+
+        # With API (accurate)
+        counter = CohereTokenCounter("command-r-plus", client=client)
+
+        # Without API (estimation)
+        counter = CohereTokenCounter("command-r-plus")
+    """
+
+    def __init__(self, model: str, client: Any = None):
+        """Initialize Cohere token counter.
+
+        Args:
+            model: Cohere model name.
+            client: Optional cohere.ClientV2 for API-based counting.
+        """
+        global _FALLBACK_WARNING_SHOWN
+
+        self.model = model
+        self._client = client
+        self._use_api = client is not None
+
+        # Cohere uses ~4 chars per token
+        self._estimator = EstimatingTokenCounter(chars_per_token=4.0)
+
+        if not self._use_api and not _FALLBACK_WARNING_SHOWN:
+            warnings.warn(
+                "CohereProvider: No client provided, using estimation. "
+                "For accurate counting, pass a Cohere client: "
+                "CohereProvider(client=cohere.ClientV2())",
+                UserWarning,
+                stacklevel=4,
+            )
+            _FALLBACK_WARNING_SHOWN = True
+
+    def count_text(self, text: str) -> int:
+        """Count tokens in text.
+
+        Uses tokenize API if client available, otherwise estimates.
+        """
+        if not text:
+            return 0
+
+        if self._use_api:
+            try:
+                response = self._client.tokenize(
+                    text=text,
+                    model=self.model,
+                )
+                return len(response.tokens)
+            except Exception as e:
+                logger.debug(f"Cohere tokenize API failed: {e}, using estimation")
+
+        return self._estimator.count_text(text)
+
+    def count_message(self, message: dict[str, Any]) -> int:
+        """Count tokens in a message."""
+        content = self._extract_content(message)
+        tokens = self.count_text(content)
+        tokens += 4  # Message overhead (role tokens, etc.)
+        return tokens
+
+    def count_messages(self, messages: list[dict[str, Any]]) -> int:
+        """Count tokens in messages."""
+        if not messages:
+            return 0
+
+        # For API-based counting, concatenate all content
+        if self._use_api:
+            try:
+                all_content = []
+                for msg in messages:
+                    content = self._extract_content(msg)
+                    role = msg.get("role", "user")
+                    all_content.append(f"{role}: {content}")
+
+                full_text = "\n".join(all_content)
+                response = self._client.tokenize(
+                    text=full_text,
+                    model=self.model,
+                )
+                return len(response.tokens)
+            except Exception as e:
+                logger.debug(f"Cohere tokenize API failed: {e}, using estimation")
+
+        # Fallback to estimation
+        total = sum(self.count_message(msg) for msg in messages)
+        total += 3  # Priming tokens
+        return total
+
+    def _extract_content(self, message: dict[str, Any]) -> str:
+        """Extract text content from message."""
+        content = message.get("content", "")
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            parts = []
+            for part in content:
+                if isinstance(part, dict) and part.get("type") == "text":
+                    parts.append(part.get("text", ""))
+                elif isinstance(part, str):
+                    parts.append(part)
+            return "\n".join(parts)
+        return str(content)
+
+
+class CohereProvider(Provider):
+    """Provider for Cohere Command models.
+
+    Supports Command R, Command R+, and Command A model families.
+
+    Example:
+        import cohere
+        client = cohere.ClientV2()
+
+        # With client (accurate token counting via API)
+        provider = CohereProvider(client=client)
+
+        # Without client (estimation-based counting)
+        provider = CohereProvider()
+
+        # Token counting
+        counter = provider.get_token_counter("command-r-plus")
+        tokens = counter.count_text("Hello, world!")
+
+        # Context limits
+        limit = provider.get_context_limit("command-a")  # 256K tokens
+
+        # Cost estimation
+        cost = provider.estimate_cost(
+            input_tokens=100000,
+            output_tokens=10000,
+            model="command-r-plus",
+        )
+    """
+
+    def __init__(self, client: Any = None):
+        """Initialize Cohere provider.
+
+        Args:
+            client: Optional cohere.ClientV2 for API-based token counting.
+                If provided, uses tokenize API for accurate counts.
+        """
+        self._client = client
+
+    @property
+    def name(self) -> str:
+        return "cohere"
+
+    def supports_model(self, model: str) -> bool:
+        """Check if model is a known Cohere model."""
+        model_lower = model.lower()
+        if model_lower in _CONTEXT_LIMITS:
+            return True
+        # Check prefix match
+        for prefix in ["command-a", "command-r", "command", "embed-"]:
+            if model_lower.startswith(prefix):
+                return True
+        return False
+
+    def get_token_counter(self, model: str) -> TokenCounter:
+        """Get token counter for a Cohere model.
+
+        Uses tokenize API if client was provided, otherwise estimates.
+        """
+        if not self.supports_model(model):
+            raise ValueError(
+                f"Model '{model}' is not recognized as a Cohere model. "
+                f"Supported models: {list(_CONTEXT_LIMITS.keys())}"
+            )
+        return CohereTokenCounter(model, client=self._client)
+
+    def get_context_limit(self, model: str) -> int:
+        """Get context limit for a Cohere model.
+
+        Tries LiteLLM first (with and without 'cohere/' prefix),
+        then falls back to built-in limits.
+        """
+        # Try LiteLLM first
+        if LITELLM_AVAILABLE:
+            for model_variant in [f"cohere/{model}", model]:
+                try:
+                    info = litellm.get_model_info(model_variant)
+                    if info and "max_input_tokens" in info:
+                        result = info["max_input_tokens"]
+                        if result is not None:
+                            return result
+                    if info and "max_tokens" in info:
+                        result = info["max_tokens"]
+                        if result is not None:
+                            return result
+                except Exception:
+                    pass
+
+        # Fallback to built-in limits
+        model_lower = model.lower()
+
+        # Direct match
+        if model_lower in _CONTEXT_LIMITS:
+            return _CONTEXT_LIMITS[model_lower]
+
+        # Prefix match
+        for prefix, limit in [
+            ("command-a", 256000),
+            ("command-r-plus", 128000),
+            ("command-r", 128000),
+            ("command", 4096),
+            ("embed-", 512),
+        ]:
+            if model_lower.startswith(prefix):
+                return limit
+
+        raise ValueError(
+            f"Unknown context limit for model '{model}'. "
+            f"Known models: {list(_CONTEXT_LIMITS.keys())}"
+        )
+
+    def estimate_cost(
+        self,
+        input_tokens: int,
+        output_tokens: int,
+        model: str,
+        cached_tokens: int = 0,
+    ) -> float | None:
+        """Estimate cost for Cohere API call.
+
+        Tries LiteLLM first (with and without 'cohere/' prefix),
+        then falls back to built-in pricing.
+
+        Args:
+            input_tokens: Number of input tokens.
+            output_tokens: Number of output tokens.
+            model: Model name.
+            cached_tokens: Not used by Cohere.
+
+        Returns:
+            Estimated cost in USD, or None if pricing unknown.
+        """
+        # Try LiteLLM first
+        if LITELLM_AVAILABLE:
+            for model_variant in [f"cohere/{model}", model]:
+                try:
+                    cost = litellm.completion_cost(
+                        model=model_variant,
+                        prompt="",
+                        completion="",
+                        prompt_tokens=input_tokens,
+                        completion_tokens=output_tokens,
+                    )
+                    if cost is not None:
+                        return cost
+                except Exception:
+                    pass
+
+        # Fallback to built-in pricing
+        model_lower = model.lower()
+
+        # Find pricing
+        input_price, output_price = None, None
+        for model_prefix, (inp, outp) in _PRICING.items():
+            if model_lower.startswith(model_prefix):
+                input_price, output_price = inp, outp
+                break
+
+        if input_price is None:
+            return None
+
+        input_cost = (input_tokens / 1_000_000) * input_price
+        output_cost = (output_tokens / 1_000_000) * (output_price or 0)
+
+        return input_cost + output_cost
+
+    def get_output_buffer(self, model: str, default: int = 4000) -> int:
+        """Get recommended output buffer."""
+        return default
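For a sense of the fallback pricing path in estimate_cost above: when LiteLLM pricing is unavailable, the built-in _PRICING table (USD per 1M tokens, last reviewed 2025-01-06 per _PRICING_LAST_UPDATED) is used. A rough worked example, assuming the table values rather than live LiteLLM data:

    from headroom import CohereProvider

    provider = CohereProvider()  # no client: estimated token counts, table-based pricing fallback
    cost = provider.estimate_cost(
        input_tokens=100_000,
        output_tokens=10_000,
        model="command-r-plus",
    )
    # With the table entry (2.50, 10.00):
    #   (100_000 / 1_000_000) * 2.50 + (10_000 / 1_000_000) * 10.00 = 0.25 + 0.10 = 0.35 USD
    # If litellm is installed and recognizes the model, its figure is returned instead and may differ.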