headroom_ai-0.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/tokenizers/huggingface.py (new file)
@@ -0,0 +1,317 @@
"""HuggingFace tokenizer wrapper for open models.

Supports Llama, Mistral, Falcon, and other models with HuggingFace
tokenizers. Requires the `transformers` library.
"""

from __future__ import annotations

import logging
from functools import lru_cache
from typing import Any

from .base import BaseTokenizer

logger = logging.getLogger(__name__)


# Model name to HuggingFace tokenizer mapping
# Maps common model names to their HuggingFace tokenizer identifiers
MODEL_TO_TOKENIZER: dict[str, str] = {
    # Llama 3 family
    "llama-3": "meta-llama/Meta-Llama-3-8B",
    "llama-3-8b": "meta-llama/Meta-Llama-3-8B",
    "llama-3-70b": "meta-llama/Meta-Llama-3-70B",
    "llama-3.1-8b": "meta-llama/Llama-3.1-8B",
    "llama-3.1-70b": "meta-llama/Llama-3.1-70B",
    "llama-3.1-405b": "meta-llama/Llama-3.1-405B",
    "llama-3.2-1b": "meta-llama/Llama-3.2-1B",
    "llama-3.2-3b": "meta-llama/Llama-3.2-3B",
    "llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct",
    # Llama 2 family
    "llama-2": "meta-llama/Llama-2-7b-hf",
    "llama-2-7b": "meta-llama/Llama-2-7b-hf",
    "llama-2-13b": "meta-llama/Llama-2-13b-hf",
    "llama-2-70b": "meta-llama/Llama-2-70b-hf",
    # CodeLlama
    "codellama": "codellama/CodeLlama-7b-hf",
    "codellama-7b": "codellama/CodeLlama-7b-hf",
    "codellama-13b": "codellama/CodeLlama-13b-hf",
    "codellama-34b": "codellama/CodeLlama-34b-hf",
    # Mistral family
    "mistral": "mistralai/Mistral-7B-v0.1",
    "mistral-7b": "mistralai/Mistral-7B-v0.1",
    "mistral-7b-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
    "mistral-7b-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "mistral-nemo": "mistralai/Mistral-Nemo-Base-2407",
    "mistral-small": "mistralai/Mistral-Small-Instruct-2409",
    "mistral-large": "mistralai/Mistral-Large-Instruct-2407",
    # Mixtral
    "mixtral": "mistralai/Mixtral-8x7B-v0.1",
    "mixtral-8x7b": "mistralai/Mixtral-8x7B-v0.1",
    "mixtral-8x22b": "mistralai/Mixtral-8x22B-v0.1",
    # Qwen family
    "qwen": "Qwen/Qwen-7B",
    "qwen-7b": "Qwen/Qwen-7B",
    "qwen-14b": "Qwen/Qwen-14B",
    "qwen-72b": "Qwen/Qwen-72B",
    "qwen2": "Qwen/Qwen2-7B",
    "qwen2-7b": "Qwen/Qwen2-7B",
    "qwen2-72b": "Qwen/Qwen2-72B",
    "qwen2.5": "Qwen/Qwen2.5-7B",
    "qwen2.5-7b": "Qwen/Qwen2.5-7B",
    "qwen2.5-72b": "Qwen/Qwen2.5-72B",
    # DeepSeek
    "deepseek": "deepseek-ai/deepseek-llm-7b-base",
    "deepseek-7b": "deepseek-ai/deepseek-llm-7b-base",
    "deepseek-67b": "deepseek-ai/deepseek-llm-67b-base",
    "deepseek-coder": "deepseek-ai/deepseek-coder-6.7b-base",
    "deepseek-v2": "deepseek-ai/DeepSeek-V2",
    "deepseek-v3": "deepseek-ai/DeepSeek-V3",
    # Yi family
    "yi": "01-ai/Yi-6B",
    "yi-6b": "01-ai/Yi-6B",
    "yi-34b": "01-ai/Yi-34B",
    "yi-1.5": "01-ai/Yi-1.5-6B",
    # Phi family
    "phi-2": "microsoft/phi-2",
    "phi-3": "microsoft/Phi-3-mini-4k-instruct",
    "phi-3-mini": "microsoft/Phi-3-mini-4k-instruct",
    "phi-3-small": "microsoft/Phi-3-small-8k-instruct",
    "phi-3-medium": "microsoft/Phi-3-medium-4k-instruct",
    # Falcon
    "falcon": "tiiuae/falcon-7b",
    "falcon-7b": "tiiuae/falcon-7b",
    "falcon-40b": "tiiuae/falcon-40b",
    "falcon-180b": "tiiuae/falcon-180B",
    # StarCoder
    "starcoder": "bigcode/starcoder",
    "starcoder2": "bigcode/starcoder2-15b",
    "starcoder2-3b": "bigcode/starcoder2-3b",
    "starcoder2-7b": "bigcode/starcoder2-7b",
    "starcoder2-15b": "bigcode/starcoder2-15b",
    # MPT
    "mpt-7b": "mosaicml/mpt-7b",
    "mpt-30b": "mosaicml/mpt-30b",
    # Gemma
    "gemma": "google/gemma-7b",
    "gemma-2b": "google/gemma-2b",
    "gemma-7b": "google/gemma-7b",
    "gemma-2": "google/gemma-2-9b",
    "gemma-2-9b": "google/gemma-2-9b",
    "gemma-2-27b": "google/gemma-2-27b",
}


@lru_cache(maxsize=16)
def _load_tokenizer(tokenizer_name: str):
    """Load and cache HuggingFace tokenizer.

    Args:
        tokenizer_name: HuggingFace model/tokenizer name.

    Returns:
        Loaded tokenizer, or None if unavailable.
    """
    from transformers import AutoTokenizer

    try:
        return AutoTokenizer.from_pretrained(
            tokenizer_name,
            trust_remote_code=True,
        )
    except Exception as e:
        logger.warning(f"Failed to load tokenizer {tokenizer_name}: {e}")
        return None


def get_tokenizer_name(model: str) -> str:
    """Get HuggingFace tokenizer name for a model.

    Args:
        model: Model name.

    Returns:
        HuggingFace tokenizer identifier.
    """
    model_lower = model.lower()

    # Direct lookup
    if model_lower in MODEL_TO_TOKENIZER:
        return MODEL_TO_TOKENIZER[model_lower]

    # Try prefix matching
    for key, value in MODEL_TO_TOKENIZER.items():
        if model_lower.startswith(key):
            return value

    # Assume model name is the tokenizer name
    return model


class HuggingFaceTokenizer(BaseTokenizer):
    """Token counter using HuggingFace tokenizers.

    Supports any model with a HuggingFace tokenizer, including:
    - Llama family (Llama 2, Llama 3, CodeLlama)
    - Mistral family (Mistral, Mixtral)
    - Qwen family
    - DeepSeek family
    - Phi family
    - Falcon, StarCoder, MPT, Gemma, etc.

    Requires the `transformers` library:
        pip install transformers

    Some models may require authentication:
        huggingface-cli login

    Example:
        counter = HuggingFaceTokenizer("llama-3-8b")
        tokens = counter.count_text("Hello, world!")
    """

    # Overhead per message (varies by model, this is a reasonable default)
    MESSAGE_OVERHEAD = 4
    REPLY_OVERHEAD = 3

    def __init__(self, model: str):
        """Initialize HuggingFace tokenizer.

        Args:
            model: Model name (e.g., 'llama-3-8b', 'mistral-7b').
        """
        self.model = model
        self.tokenizer_name = get_tokenizer_name(model)
        self._tokenizer = None  # Lazy load

    @property
    def tokenizer(self):
        """Lazy-load the tokenizer."""
        if self._tokenizer is None:
            loaded = _load_tokenizer(self.tokenizer_name)
            if loaded is not None:
                self._tokenizer = loaded
            else:
                # Mark as unavailable
                self._tokenizer = False
        return self._tokenizer if self._tokenizer is not False else None

    def _use_fallback(self) -> bool:
        """Check if we need to use fallback estimation."""
        return self.tokenizer is None

    def count_text(self, text: str) -> int:
        """Count tokens in text.

        Falls back to estimation if tokenizer unavailable.

        Args:
            text: Text to tokenize.

        Returns:
            Number of tokens.
        """
        if not text:
            return 0
        if self._use_fallback():
            # Fall back to ~4 chars per token estimation
            return max(1, int(len(text) / 4 + 0.5))
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
        return len(tokens)

    def count_messages(self, messages: list[dict[str, Any]]) -> int:
        """Count tokens in chat messages.

        Uses the model's chat template if available, otherwise
        falls back to base class implementation.

        Args:
            messages: List of chat messages.

        Returns:
            Total token count.
        """
        if self._use_fallback():
            # Use base class implementation with estimation
            return super().count_messages(messages)

        # Try to use chat template for accurate counting
        if hasattr(self.tokenizer, "apply_chat_template"):
            try:
                # Apply chat template and count
                formatted = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=True,
                )
                return len(formatted)
            except Exception:
                # Fall back to base implementation
                pass

        return super().count_messages(messages)

    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs.

        Args:
            text: Text to encode.

        Returns:
            List of token IDs.

        Raises:
            NotImplementedError: If tokenizer not available.
        """
        if self._use_fallback():
            raise NotImplementedError(
                f"Encoding not available for {self.model} - "
                f"tokenizer {self.tokenizer_name} could not be loaded"
            )
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, tokens: list[int]) -> str:
        """Decode token IDs to text.

        Args:
            tokens: List of token IDs.

        Returns:
            Decoded text.

        Raises:
            NotImplementedError: If tokenizer not available.
        """
        if self._use_fallback():
            raise NotImplementedError(
                f"Decoding not available for {self.model} - "
                f"tokenizer {self.tokenizer_name} could not be loaded"
            )
        return self.tokenizer.decode(tokens)

    @classmethod
    def is_available(cls) -> bool:
        """Check if HuggingFace tokenizers are available.

        Returns:
            True if transformers is installed.
        """
        try:
            import transformers  # noqa: F401

            return True
        except ImportError:
            return False

    @classmethod
    def list_supported_models(cls) -> list[str]:
        """List models with known tokenizer mappings.

        Returns:
            List of supported model names.
        """
        return list(MODEL_TO_TOKENIZER.keys())

    def __repr__(self) -> str:
        return f"HuggingFaceTokenizer(model={self.model!r}, tokenizer={self.tokenizer_name!r})"
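For orientation, here is a minimal usage sketch of the HuggingFaceTokenizer added above. It assumes `headroom-ai` and `transformers` are installed (and, for gated models such as Llama, a prior `huggingface-cli login`); the model name and sample messages are illustrative only.

# Minimal sketch: token counting with headroom/tokenizers/huggingface.py.
from headroom.tokenizers.huggingface import HuggingFaceTokenizer

counter = HuggingFaceTokenizer("llama-3-8b")

# Plain-text counting; falls back to ~4 chars/token if the tokenizer cannot load.
print(counter.count_text("Hello, world!"))

# Chat-message counting; uses the model's chat template when one is available.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the release notes."},
]
print(counter.count_messages(messages))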
headroom/tokenizers/mistral.py (new file)
@@ -0,0 +1,245 @@
"""Mistral tokenizer using the official mistral-common package.

Mistral AI released their tokenizer publicly, making accurate
token counting possible without API calls.

Requires: pip install mistral-common
"""

from __future__ import annotations

import logging
from functools import lru_cache
from typing import Any

from .base import BaseTokenizer

logger = logging.getLogger(__name__)

# Check if mistral-common is available
try:
    from mistral_common.protocol.instruct.messages import (
        AssistantMessage,
        SystemMessage,
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer as _MistralTokenizer

    MISTRAL_AVAILABLE = True
except ImportError:
    MISTRAL_AVAILABLE = False
    _MistralTokenizer = None


def is_mistral_available() -> bool:
    """Check if mistral-common is installed."""
    return MISTRAL_AVAILABLE


# Model to tokenizer version mapping
MODEL_TO_VERSION = {
    # Mistral models use v3 tokenizer (tekken)
    "mistral-large": "v3",
    "mistral-large-latest": "v3",
    "mistral-small": "v3",
    "mistral-small-latest": "v3",
    "ministral-8b": "v3",
    "ministral-3b": "v3",
    "mistral-nemo": "v3",
    "pixtral-12b": "v3",
    "codestral": "v3",
    "codestral-latest": "v3",
    # Mixtral uses v1
    "mixtral-8x7b": "v1",
    "mixtral-8x22b": "v1",
    "open-mixtral-8x7b": "v1",
    "open-mixtral-8x22b": "v1",
    # Mistral 7B uses v1
    "mistral-7b": "v1",
    "open-mistral-7b": "v1",
    "mistral-7b-instruct": "v1",
}


@lru_cache(maxsize=4)
def _get_tokenizer(version: str):
    """Get and cache Mistral tokenizer by version."""
    if not MISTRAL_AVAILABLE:
        raise RuntimeError(
            "mistral-common is required for MistralTokenizer. "
            "Install with: pip install mistral-common"
        )

    if version == "v3":
        return _MistralTokenizer.v3(is_tekken=True)
    elif version == "v2":
        return _MistralTokenizer.v2()
    else:  # v1
        return _MistralTokenizer.v1()


def get_tokenizer_version(model: str) -> str:
    """Get tokenizer version for a model."""
    model_lower = model.lower()

    # Direct lookup
    if model_lower in MODEL_TO_VERSION:
        return MODEL_TO_VERSION[model_lower]

    # Prefix matching
    for prefix, version in [
        ("mistral-large", "v3"),
        ("mistral-small", "v3"),
        ("ministral", "v3"),
        ("codestral", "v3"),
        ("pixtral", "v3"),
        ("mistral-nemo", "v3"),
        ("mixtral", "v1"),
        ("mistral-7b", "v1"),
        ("open-mistral", "v1"),
    ]:
        if model_lower.startswith(prefix):
            return version

    # Default to v3 for newer models
    return "v3"


class MistralTokenizer(BaseTokenizer):
    """Token counter using Mistral's official tokenizer.

    Uses mistral-common package for accurate token counting.

    Requires: pip install mistral-common

    Example:
        counter = MistralTokenizer("mistral-large")
        tokens = counter.count_text("Hello, world!")
    """

    MESSAGE_OVERHEAD = 4
    REPLY_OVERHEAD = 3

    def __init__(self, model: str = "mistral-large"):
        """Initialize Mistral tokenizer.

        Args:
            model: Mistral model name.
        """
        if not MISTRAL_AVAILABLE:
            raise RuntimeError(
                "mistral-common is required for MistralTokenizer. "
                "Install with: pip install mistral-common"
            )

        self.model = model
        self.version = get_tokenizer_version(model)
        self._tokenizer = None  # Lazy load

    @property
    def tokenizer(self):
        """Lazy-load the tokenizer (MistralTokenizer object)."""
        if self._tokenizer is None:
            self._tokenizer = _get_tokenizer(self.version)
        return self._tokenizer

    @property
    def _text_tokenizer(self):
        """Get the underlying text tokenizer for encode/decode."""
        return self.tokenizer.instruct_tokenizer.tokenizer

    def count_text(self, text: str) -> int:
        """Count tokens in text.

        Args:
            text: Text to tokenize.

        Returns:
            Number of tokens.
        """
        if not text:
            return 0
        tokens = self._text_tokenizer.encode(text, bos=False, eos=False)
        return len(tokens)

    def count_messages(self, messages: list[dict[str, Any]]) -> int:
        """Count tokens in chat messages.

        Uses Mistral's chat template for accurate counting.

        Args:
            messages: List of chat messages.

        Returns:
            Total token count.
        """
        if not messages:
            return 0

        try:
            # Convert to Mistral message format
            mistral_messages = []
            for msg in messages:
                role = msg.get("role", "user")
                content = msg.get("content", "")

                if isinstance(content, list):
                    # Multi-part content - extract text
                    text_parts = []
                    for part in content:
                        if isinstance(part, dict) and part.get("type") == "text":
                            text_parts.append(part.get("text", ""))
                        elif isinstance(part, str):
                            text_parts.append(part)
                    content = "\n".join(text_parts)

                if role == "user":
                    mistral_messages.append(UserMessage(content=content))
                elif role == "assistant":
                    mistral_messages.append(AssistantMessage(content=content))
                elif role == "system":
                    mistral_messages.append(SystemMessage(content=content))
                else:
                    # Tool messages etc - treat as user
                    mistral_messages.append(UserMessage(content=content))

            # Encode with chat template
            request = ChatCompletionRequest(messages=mistral_messages)
            tokenized = self.tokenizer.encode_chat_completion(request)
            return len(tokenized.tokens)

        except Exception as e:
            logger.debug(f"Mistral chat encoding failed: {e}, falling back to text counting")
            # Fallback to base implementation
            return super().count_messages(messages)

    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs.

        Args:
            text: Text to encode.

        Returns:
            List of token IDs.
        """
        return self._text_tokenizer.encode(text, bos=False, eos=False)

    def decode(self, tokens: list[int]) -> str:
        """Decode token IDs to text.

        Args:
            tokens: List of token IDs.

        Returns:
            Decoded text.
        """
        return self._text_tokenizer.decode(tokens)

    @classmethod
    def is_available(cls) -> bool:
        """Check if Mistral tokenizer is available."""
        return MISTRAL_AVAILABLE

    def __repr__(self) -> str:
        return f"MistralTokenizer(model={self.model!r}, version={self.version!r})"
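And a short sketch of how the two backends might be combined, using only the availability helpers shown in these files. This assumes `headroom-ai` is installed; whether `mistral-common` is present decides which counter is used, and the model names are illustrative.

# Minimal sketch: prefer Mistral's official tokenizer when mistral-common is
# installed, otherwise fall back to the HuggingFace-based counter.
from headroom.tokenizers.huggingface import HuggingFaceTokenizer
from headroom.tokenizers.mistral import MistralTokenizer, is_mistral_available

if is_mistral_available():
    counter = MistralTokenizer("mistral-large")   # v3 (tekken) tokenizer
else:
    counter = HuggingFaceTokenizer("mistral-7b")  # HF tokenizer, or estimation

print(counter.count_text("Hello, world!"))
print(counter.encode("Hello, world!"))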