headroom-ai 0.2.13__py3-none-any.whl
This diff shows the contents of a publicly released package version as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/tokenizers/registry.py
@@ -0,0 +1,398 @@

"""Tokenizer registry for universal model support.

Provides automatic tokenizer selection based on model name with
support for multiple backends and custom tokenizers.
"""

from __future__ import annotations

import logging
import re
from collections.abc import Callable
from typing import TYPE_CHECKING

from .base import TokenCounter
from .estimator import EstimatingTokenCounter

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)


# Model pattern matching for tokenizer selection
# Order matters - more specific patterns first
MODEL_PATTERNS: list[tuple[str, str]] = [
    # OpenAI models -> tiktoken
    (r"^gpt-4o", "tiktoken"),
    (r"^gpt-4", "tiktoken"),
    (r"^gpt-3\.5", "tiktoken"),
    (r"^o1", "tiktoken"),
    (r"^o3", "tiktoken"),
    (r"^text-embedding", "tiktoken"),
    (r"^text-davinci", "tiktoken"),
    (r"^code-", "tiktoken"),
    (r"^davinci", "tiktoken"),
    (r"^curie", "tiktoken"),
    (r"^babbage", "tiktoken"),
    (r"^ada", "tiktoken"),
    # Anthropic models -> estimation (Claude uses custom tokenizer)
    (r"^claude-", "anthropic"),
    # Llama family -> huggingface (when available)
    (r"^llama", "huggingface"),
    (r"^meta-llama", "huggingface"),
    (r"^codellama", "huggingface"),
    # Mistral family -> official mistral tokenizer
    (r"^mistral", "mistral"),
    (r"^mixtral", "mistral"),
    (r"^codestral", "mistral"),
    (r"^ministral", "mistral"),
    (r"^pixtral", "mistral"),
    # Google models -> estimation (Gemini uses SentencePiece)
    (r"^gemini", "google"),
    (r"^palm", "google"),
    # Cohere models -> estimation
    (r"^command", "cohere"),
    # Open models commonly served via OpenAI-compatible APIs
    (r"^phi-", "huggingface"),
    (r"^qwen", "huggingface"),
    (r"^deepseek", "huggingface"),
    (r"^yi-", "huggingface"),
    (r"^falcon", "huggingface"),
    (r"^mpt-", "huggingface"),
    (r"^starcoder", "huggingface"),
    (r"^codegen", "huggingface"),
]


class TokenizerRegistry:
    """Registry for tokenizer instances and factories.

    Supports:
    - Automatic tokenizer selection based on model name
    - Custom tokenizer registration
    - Multiple backends (tiktoken, huggingface, estimation)
    - Lazy loading of tokenizer dependencies

    Example:
        # Auto-detect tokenizer
        tokenizer = TokenizerRegistry.get("gpt-4o")

        # Register custom tokenizer
        TokenizerRegistry.register("my-model", my_tokenizer)

        # Use specific backend
        tokenizer = TokenizerRegistry.get("llama-3", backend="huggingface")
    """

    # Singleton registry instance
    _instance: TokenizerRegistry | None = None

    # Registered tokenizers (model -> tokenizer instance)
    _tokenizers: dict[str, TokenCounter] = {}

    # Registered factories (backend -> factory function)
    _factories: dict[str, Callable[[str], TokenCounter]] = {}

    # Cache for auto-detected tokenizers
    _cache: dict[str, TokenCounter] = {}

    def __new__(cls) -> TokenizerRegistry:
        """Singleton pattern."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_factories()
        return cls._instance

    def _init_factories(self) -> None:
        """Initialize default tokenizer factories."""
        self._factories = {
            "tiktoken": self._create_tiktoken,
            "huggingface": self._create_huggingface,
            "anthropic": self._create_anthropic,
            "google": self._create_google,
            "cohere": self._create_cohere,
            "mistral": self._create_mistral,
            "estimation": self._create_estimation,
        }

    @classmethod
    def get(
        cls,
        model: str,
        backend: str | None = None,
        fallback: bool = True,
    ) -> TokenCounter:
        """Get tokenizer for a model.

        Args:
            model: Model name (e.g., 'gpt-4o', 'claude-3-sonnet').
            backend: Force specific backend ('tiktoken', 'huggingface', etc.).
                If None, auto-detects based on model name.
            fallback: If True, fall back to estimation on errors.

        Returns:
            TokenCounter instance for the model.

        Raises:
            ValueError: If backend not found and fallback=False.
        """
        registry = cls()
        model_lower = model.lower()

        # Check for explicitly registered tokenizer
        if model_lower in registry._tokenizers:
            return registry._tokenizers[model_lower]

        # Check cache
        cache_key = f"{model_lower}:{backend or 'auto'}"
        if cache_key in registry._cache:
            return registry._cache[cache_key]

        # Create tokenizer
        try:
            tokenizer = registry._create_tokenizer(model, backend)
            registry._cache[cache_key] = tokenizer
            return tokenizer
        except Exception as e:
            if fallback:
                logger.warning(
                    f"Failed to create tokenizer for {model}: {e}. Falling back to estimation."
                )
                tokenizer = EstimatingTokenCounter()
                registry._cache[cache_key] = tokenizer
                return tokenizer
            raise ValueError(f"No tokenizer available for {model}: {e}") from e

    @classmethod
    def register(
        cls,
        model: str,
        tokenizer: TokenCounter | None = None,
        factory: Callable[[str], TokenCounter] | None = None,
    ) -> None:
        """Register a tokenizer or factory for a model.

        Args:
            model: Model name to register.
            tokenizer: Pre-instantiated tokenizer instance.
            factory: Factory function that creates tokenizer for model.

        Raises:
            ValueError: If neither tokenizer nor factory provided.
        """
        registry = cls()
        model_lower = model.lower()

        if tokenizer is not None:
            registry._tokenizers[model_lower] = tokenizer
        elif factory is not None:
            registry._factories[model_lower] = factory
        else:
            raise ValueError("Must provide either tokenizer or factory")

        # Clear cache for this model
        keys_to_remove = [k for k in registry._cache if k.startswith(model_lower)]
        for key in keys_to_remove:
            del registry._cache[key]

    @classmethod
    def register_backend(
        cls,
        backend: str,
        factory: Callable[[str], TokenCounter],
    ) -> None:
        """Register a backend factory.

        Args:
            backend: Backend name.
            factory: Factory function (model: str) -> TokenCounter.
        """
        registry = cls()
        registry._factories[backend] = factory

    @classmethod
    def list_backends(cls) -> list[str]:
        """List available backends."""
        registry = cls()
        return list(registry._factories.keys())

    @classmethod
    def list_registered(cls) -> list[str]:
        """List explicitly registered models."""
        registry = cls()
        return list(registry._tokenizers.keys())

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the tokenizer cache."""
        registry = cls()
        registry._cache.clear()

    def _create_tokenizer(
        self,
        model: str,
        backend: str | None,
    ) -> TokenCounter:
        """Create tokenizer for model.

        Args:
            model: Model name.
            backend: Backend to use (or None for auto-detect).

        Returns:
            TokenCounter instance.
        """
        if backend is None:
            backend = self._detect_backend(model)

        factory = self._factories.get(backend)
        if factory is None:
            raise ValueError(f"Unknown backend: {backend}")

        return factory(model)

    def _create_mistral(self, model: str) -> TokenCounter:
        """Create Mistral tokenizer using official mistral-common."""
        try:
            from .mistral import MistralTokenizer, is_mistral_available

            if is_mistral_available():
                return MistralTokenizer(model)
        except ImportError:
            pass

        logger.warning(
            "mistral-common not installed for Mistral tokenizer. "
            "Install with: pip install mistral-common"
        )
        return EstimatingTokenCounter()

    def _detect_backend(self, model: str) -> str:
        """Detect best backend for model.

        Args:
            model: Model name.

        Returns:
            Backend name.
        """
        model_lower = model.lower()

        for pattern, backend in MODEL_PATTERNS:
            if re.match(pattern, model_lower):
                return backend

        # Default to estimation for unknown models
        return "estimation"

    def _create_tiktoken(self, model: str) -> TokenCounter:
        """Create tiktoken-based tokenizer."""
        try:
            from .tiktoken_counter import TiktokenCounter

            return TiktokenCounter(model)
        except ImportError:
            logger.warning("tiktoken not installed. Install with: pip install tiktoken")
            return EstimatingTokenCounter()

    def _create_huggingface(self, model: str) -> TokenCounter:
        """Create HuggingFace-based tokenizer."""
        try:
            from .huggingface import HuggingFaceTokenizer

            return HuggingFaceTokenizer(model)
        except ImportError:
            logger.warning(
                "transformers not installed for HuggingFace tokenizer. "
                "Install with: pip install transformers"
            )
            return EstimatingTokenCounter()
        except Exception as e:
            logger.warning(f"Failed to load HuggingFace tokenizer for {model}: {e}")
            return EstimatingTokenCounter()

    def _create_anthropic(self, model: str) -> TokenCounter:
        """Create Anthropic tokenizer.

        Anthropic uses a custom tokenizer that's not publicly available.
        We use estimation calibrated for Claude models.
        """
        # Claude models use ~3.5 chars per token on average
        return EstimatingTokenCounter(chars_per_token=3.5)

    def _create_google(self, model: str) -> TokenCounter:
        """Create Google tokenizer.

        Gemini uses SentencePiece which isn't easily accessible.
        We use estimation calibrated for Gemini models.
        """
        # Gemini models use ~4 chars per token
        return EstimatingTokenCounter(chars_per_token=4.0)

    def _create_cohere(self, model: str) -> TokenCounter:
        """Create Cohere tokenizer.

        Cohere has its own tokenizer; we use estimation.
        """
        return EstimatingTokenCounter(chars_per_token=4.0)

    def _create_estimation(self, model: str) -> TokenCounter:
        """Create estimation-based tokenizer."""
        return EstimatingTokenCounter()


# Convenience functions
def get_tokenizer(
    model: str,
    backend: str | None = None,
    fallback: bool = True,
) -> TokenCounter:
    """Get tokenizer for a model.

    This is the main entry point for getting tokenizers.

    Args:
        model: Model name (e.g., 'gpt-4o', 'claude-3-sonnet').
        backend: Force specific backend ('tiktoken', 'huggingface', etc.).
        fallback: If True, fall back to estimation on errors.

    Returns:
        TokenCounter instance.

    Example:
        tokenizer = get_tokenizer("gpt-4o")
        tokens = tokenizer.count_text("Hello, world!")
    """
    return TokenizerRegistry.get(model, backend, fallback)


def register_tokenizer(
    model: str,
    tokenizer: TokenCounter | None = None,
    factory: Callable[[str], TokenCounter] | None = None,
) -> None:
    """Register a custom tokenizer for a model.

    Args:
        model: Model name.
        tokenizer: Tokenizer instance.
        factory: Factory function.

    Example:
        # Register instance
        register_tokenizer("my-model", MyTokenizer())

        # Register factory
        register_tokenizer("my-model", factory=lambda m: MyTokenizer(m))
    """
    TokenizerRegistry.register(model, tokenizer, factory)


def list_supported_models() -> dict[str, str]:
    """List models with known tokenizer mappings.

    Returns:
        Dict mapping model pattern to backend.
    """
    return dict(MODEL_PATTERNS)
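
For orientation, here is a minimal usage sketch of the surface this file adds. It assumes the wheel is installed; the import path mirrors the file layout above, and count_text is the method the module's own docstring examples use.

# Usage sketch for headroom.tokenizers.registry (assumes the wheel is installed).
from headroom.tokenizers.registry import TokenizerRegistry, get_tokenizer

# "gpt-4o" matches the ^gpt-4o pattern, so this resolves to the tiktoken
# backend, quietly falling back to estimation if tiktoken is not installed.
counter = get_tokenizer("gpt-4o")
print(counter.count_text("Hello, world!"))

# A name matching no pattern falls through MODEL_PATTERNS to "estimation".
rough = get_tokenizer("my-internal-model")

# The registry is a singleton with a per-(model, backend) cache, so a
# repeated lookup returns the same instance.
assert get_tokenizer("gpt-4o") is counter

print(TokenizerRegistry.list_backends())
# ['tiktoken', 'huggingface', 'anthropic', 'google', 'cohere', 'mistral', 'estimation']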

headroom/tokenizers/tiktoken_counter.py
@@ -0,0 +1,248 @@

"""Tiktoken-based token counter for OpenAI models.

Tiktoken is OpenAI's fast BPE tokenizer used by GPT models.
It supports multiple encodings:
- cl100k_base: GPT-4, GPT-3.5-turbo, text-embedding-ada-002
- o200k_base: GPT-4o, GPT-4o-mini
- p50k_base: Codex models, text-davinci-002/003
- r50k_base: GPT-3 models (davinci, curie, etc.)
"""

from __future__ import annotations

from functools import lru_cache
from typing import Any

from .base import BaseTokenizer

# Model to encoding mapping
MODEL_TO_ENCODING = {
    # GPT-4o family (o200k_base)
    "gpt-4o": "o200k_base",
    "gpt-4o-mini": "o200k_base",
    "gpt-4o-2024-05-13": "o200k_base",
    "gpt-4o-2024-08-06": "o200k_base",
    "gpt-4o-2024-11-20": "o200k_base",
    "gpt-4o-mini-2024-07-18": "o200k_base",
    # o1 reasoning models (o200k_base)
    "o1": "o200k_base",
    "o1-mini": "o200k_base",
    "o1-preview": "o200k_base",
    "o3-mini": "o200k_base",
    # GPT-4 family (cl100k_base)
    "gpt-4": "cl100k_base",
    "gpt-4-turbo": "cl100k_base",
    "gpt-4-turbo-preview": "cl100k_base",
    "gpt-4-0314": "cl100k_base",
    "gpt-4-0613": "cl100k_base",
    "gpt-4-32k": "cl100k_base",
    "gpt-4-32k-0314": "cl100k_base",
    "gpt-4-32k-0613": "cl100k_base",
    "gpt-4-1106-preview": "cl100k_base",
    "gpt-4-0125-preview": "cl100k_base",
    "gpt-4-turbo-2024-04-09": "cl100k_base",
    # GPT-3.5 family (cl100k_base)
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-3.5-turbo-0301": "cl100k_base",
    "gpt-3.5-turbo-0613": "cl100k_base",
    "gpt-3.5-turbo-1106": "cl100k_base",
    "gpt-3.5-turbo-0125": "cl100k_base",
    "gpt-3.5-turbo-16k": "cl100k_base",
    "gpt-3.5-turbo-16k-0613": "cl100k_base",
    "gpt-3.5-turbo-instruct": "cl100k_base",
    # Embeddings (cl100k_base)
    "text-embedding-ada-002": "cl100k_base",
    "text-embedding-3-small": "cl100k_base",
    "text-embedding-3-large": "cl100k_base",
    # Codex (p50k_base)
    "code-davinci-002": "p50k_base",
    "code-davinci-001": "p50k_base",
    "code-cushman-002": "p50k_base",
    "code-cushman-001": "p50k_base",
    # text-davinci-002/003 (p50k_base)
    "text-davinci-003": "p50k_base",
    "text-davinci-002": "p50k_base",
    # Legacy GPT-3 (r50k_base)
    "text-davinci-001": "r50k_base",
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
    "davinci": "r50k_base",
    "curie": "r50k_base",
    "babbage": "r50k_base",
    "ada": "r50k_base",
}

# Default encoding for unknown models
DEFAULT_ENCODING = "cl100k_base"


@lru_cache(maxsize=8)
def _get_encoding(encoding_name: str):
    """Get tiktoken encoding, cached for performance."""
    import tiktoken

    return tiktoken.get_encoding(encoding_name)


def get_encoding_for_model(model: str) -> str:
    """Get the tiktoken encoding name for a model.

    Args:
        model: Model name (e.g., 'gpt-4o', 'gpt-3.5-turbo').

    Returns:
        Encoding name (e.g., 'o200k_base', 'cl100k_base').
    """
    # Direct lookup
    if model in MODEL_TO_ENCODING:
        return MODEL_TO_ENCODING[model]

    # Try prefix matching for versioned models
    for prefix in ["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5", "o1", "o3"]:
        if model.startswith(prefix):
            # Find any model with this prefix
            for known_model, encoding in MODEL_TO_ENCODING.items():
                if known_model.startswith(prefix):
                    return encoding

    return DEFAULT_ENCODING


class TiktokenCounter(BaseTokenizer):
    """Token counter using tiktoken (OpenAI's tokenizer).

    This is the most accurate tokenizer for OpenAI models and provides
    a good approximation for many other models that use similar BPE
    tokenization.

    Example:
        counter = TiktokenCounter("gpt-4o")
        tokens = counter.count_text("Hello, world!")
        print(f"Token count: {tokens}")
    """

    # OpenAI-specific message overhead
    MESSAGE_OVERHEAD = 3
    REPLY_OVERHEAD = 3

    def __init__(self, model: str = "gpt-4o"):
        """Initialize tiktoken counter.

        Args:
            model: Model name to determine encoding.
                Defaults to 'gpt-4o' (o200k_base encoding).
        """
        self.model = model
        self.encoding_name = get_encoding_for_model(model)
        self._encoding = None  # Lazy load

    @property
    def encoding(self):
        """Lazy-load the encoding."""
        if self._encoding is None:
            self._encoding = _get_encoding(self.encoding_name)
        return self._encoding

    def count_text(self, text: str) -> int:
        """Count tokens in text using tiktoken.

        Args:
            text: Text to tokenize.

        Returns:
            Number of tokens.
        """
        if not text:
            return 0
        return len(self.encoding.encode(text))

    def count_messages(self, messages: list[dict[str, Any]]) -> int:
        """Count tokens in messages using OpenAI's exact formula.

        This matches OpenAI's token counting for chat completions.

        Args:
            messages: List of chat messages.

        Returns:
            Total token count.
        """
        total = 0

        for message in messages:
            # Every message has overhead for role and formatting
            total += self.MESSAGE_OVERHEAD

            for key, value in message.items():
                if value is None:
                    continue

                if key == "content":
                    if isinstance(value, str):
                        total += self.count_text(value)
                    elif isinstance(value, list):
                        # Multi-part content
                        for part in value:
                            if isinstance(part, dict):
                                if part.get("type") == "text":
                                    total += self.count_text(part.get("text", ""))
                                elif part.get("type") == "image_url":
                                    # Image tokens vary by detail level
                                    detail = part.get("image_url", {}).get("detail", "auto")
                                    if detail == "low":
                                        total += 85
                                    else:
                                        total += 170  # Base for high detail
                                else:
                                    total += self.count_text(str(part))
                            elif isinstance(part, str):
                                total += self.count_text(part)
                elif key == "role":
                    total += self.count_text(value)
                elif key == "name":
                    total += self.count_text(value)
                    total += 1  # Name adds 1 token
                elif key == "tool_calls":
                    for tool_call in value:
                        total += 3  # Tool call overhead
                        if "function" in tool_call:
                            func = tool_call["function"]
                            total += self.count_text(func.get("name", ""))
                            total += self.count_text(func.get("arguments", ""))
                        if "id" in tool_call:
                            total += self.count_text(tool_call["id"])
                elif key == "tool_call_id":
                    total += self.count_text(value)
                elif key == "function_call":
                    total += self.count_text(value.get("name", ""))
                    total += self.count_text(value.get("arguments", ""))

        # Every reply is primed with assistant
        total += self.REPLY_OVERHEAD

        return total

    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs.

        Args:
            text: Text to encode.

        Returns:
            List of token IDs.
        """
        return self.encoding.encode(text)

    def decode(self, tokens: list[int]) -> str:
        """Decode token IDs to text.

        Args:
            tokens: List of token IDs.

        Returns:
            Decoded text.
        """
        return self.encoding.decode(tokens)

    def __repr__(self) -> str:
        return f"TiktokenCounter(model={self.model!r}, encoding={self.encoding_name!r})"