stratifyai 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/utils/provider_validator.py
@@ -0,0 +1,442 @@
```python
"""Provider model validation utility.

Validates model availability for all providers using their respective APIs.
"""

import os
import time
from typing import Dict, List, Any, Optional


def validate_provider_models(
    provider: str,
    model_ids: List[str],
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Validate which models are available for a given provider.

    Args:
        provider: Provider name (openai, anthropic, google, etc.)
        model_ids: List of model IDs to validate
        api_key: Optional API key (will use env var if not provided)

    Returns:
        Dict containing:
        - valid_models: List of model IDs that are available
        - invalid_models: List of model IDs that are NOT available
        - validation_time_ms: Time taken to validate in milliseconds
        - error: Error message if validation failed (None if successful)
    """
    validators = {
        "openai": _validate_openai,
        "anthropic": _validate_anthropic,
        "google": _validate_google,
        "deepseek": _validate_deepseek,
        "groq": _validate_groq,
        "grok": _validate_grok,
        "openrouter": _validate_openrouter,
        "ollama": _validate_ollama,
        "bedrock": _validate_bedrock,
    }

    validator = validators.get(provider)
    if not validator:
        return {
            "valid_models": model_ids,
            "invalid_models": [],
            "validation_time_ms": 0,
            "error": f"No validator for provider: {provider}",
        }

    return validator(model_ids, api_key)


def _validate_openai(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate OpenAI models using models.list() API."""
    result = _init_result()
    start_time = time.time()

    try:
        import openai

        key = api_key or os.getenv("OPENAI_API_KEY")
        if not key:
            result["error"] = "OPENAI_API_KEY not configured"
            result["valid_models"] = model_ids
            return result

        client = openai.OpenAI(api_key=key)
        response = client.models.list()
        available_ids = {model.id for model in response.data}

        for model_id in model_ids:
            if model_id in available_ids:
                result["valid_models"].append(model_id)
            else:
                result["invalid_models"].append(model_id)

    except ImportError:
        result["error"] = "openai package not installed"
        result["valid_models"] = model_ids
    except Exception as e:
        result["error"] = f"Validation failed: {str(e)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - start_time) * 1000)

    return result


def _validate_anthropic(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Anthropic credentials (no models list API available)."""
    result = _init_result()
    start_time = time.time()

    try:
        import anthropic

        key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not key:
            result["error"] = "ANTHROPIC_API_KEY not configured"
            result["valid_models"] = model_ids
            return result

        # Anthropic doesn't have a models list API, so we just verify auth
        # by checking if the key format looks valid
        if key.startswith("sk-ant-"):
            result["valid_models"] = model_ids
        else:
            result["error"] = "Invalid API key format"
            result["valid_models"] = model_ids

    except ImportError:
        result["error"] = "anthropic package not installed"
        result["valid_models"] = model_ids
    except Exception as e:
        result["error"] = f"Validation failed: {str(e)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - start_time) * 1000)

    return result


def _validate_google(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Google models using client.models.list()."""
    result = _init_result()
    start_time = time.time()

    try:
        from google import genai

        key = api_key or os.getenv("GOOGLE_API_KEY")
        if not key:
            result["error"] = "GOOGLE_API_KEY not configured"
            result["valid_models"] = model_ids
            return result

        client = genai.Client(api_key=key)
        models = client.models.list()

        # Extract model names (they come as "models/gemini-2.5-pro" format)
        available_ids = set()
        for model in models:
            name = model.name.replace("models/", "")
            available_ids.add(name)
            # Also add without version suffix for matching
            if "-" in name:
                available_ids.add(name.rsplit("-", 1)[0])

        for model_id in model_ids:
            # Check exact match or substring match
            if model_id in available_ids or any(model_id in aid for aid in available_ids):
                result["valid_models"].append(model_id)
            else:
                result["invalid_models"].append(model_id)

    except ImportError:
        result["error"] = "google-genai package not installed"
        result["valid_models"] = model_ids
    except Exception as e:
        result["error"] = f"Validation failed: {str(e)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - start_time) * 1000)

    return result


def _validate_openai_compatible(
    model_ids: List[str],
    base_url: str,
    api_key: Optional[str],
    env_var: str,
) -> Dict[str, Any]:
    """Generic validator for OpenAI-compatible APIs."""
    result = _init_result()
    start_time = time.time()

    try:
        import httpx

        key = api_key or os.getenv(env_var)
        if not key:
            result["error"] = f"{env_var} not configured"
            result["valid_models"] = model_ids
            return result

        headers = {"Authorization": f"Bearer {key}"}

        with httpx.Client(timeout=10.0) as client:
            response = client.get(f"{base_url}/models", headers=headers)
            response.raise_for_status()
            data = response.json()

        # Extract model IDs from response
        available_ids = set()
        for model in data.get("data", []):
            available_ids.add(model.get("id", ""))

        for model_id in model_ids:
            if model_id in available_ids:
                result["valid_models"].append(model_id)
            else:
                result["invalid_models"].append(model_id)

    except ImportError:
        result["error"] = "httpx package not installed"
        result["valid_models"] = model_ids
    except Exception as e:
        result["error"] = f"Validation failed: {str(e)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - start_time) * 1000)

    return result


def _validate_deepseek(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate DeepSeek models using OpenAI-compatible API."""
    return _validate_openai_compatible(
        model_ids,
        "https://api.deepseek.com/v1",
        api_key,
        "DEEPSEEK_API_KEY",
    )


def _validate_groq(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Groq models using OpenAI-compatible API."""
    return _validate_openai_compatible(
        model_ids,
        "https://api.groq.com/openai/v1",
        api_key,
        "GROQ_API_KEY",
    )


def _validate_grok(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Grok (X.AI) models using OpenAI-compatible API."""
    return _validate_openai_compatible(
        model_ids,
        "https://api.x.ai/v1",
        api_key,
        "GROK_API_KEY",
    )


def _validate_openrouter(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate OpenRouter models using their models API."""
    result = _init_result()
    start_time = time.time()

    try:
        import httpx

        key = api_key or os.getenv("OPENROUTER_API_KEY")
        if not key:
            result["error"] = "OPENROUTER_API_KEY not configured"
            result["valid_models"] = model_ids
            return result

        headers = {"Authorization": f"Bearer {key}"}

        with httpx.Client(timeout=10.0) as client:
            response = client.get(
                "https://openrouter.ai/api/v1/models",
                headers=headers,
            )
            response.raise_for_status()
            data = response.json()

        # Extract model IDs from response
        available_ids = set()
        for model in data.get("data", []):
            available_ids.add(model.get("id", ""))

        for model_id in model_ids:
            if model_id in available_ids:
                result["valid_models"].append(model_id)
            else:
                result["invalid_models"].append(model_id)

    except ImportError:
        result["error"] = "httpx package not installed"
        result["valid_models"] = model_ids
    except Exception as e:
        result["error"] = f"Validation failed: {str(e)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - start_time) * 1000)

    return result


def _validate_ollama(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Ollama models using local API."""
    result = _init_result()
    start_time = time.time()

    try:
        import httpx

        base_url = os.getenv("OLLAMA_HOST", "http://localhost:11434")

        with httpx.Client(timeout=5.0) as client:
            response = client.get(f"{base_url}/api/tags")
            response.raise_for_status()
            data = response.json()

        # Extract model names from response
        available_ids = set()
        for model in data.get("models", []):
            name = model.get("name", "")
            available_ids.add(name)
            # Also add without tag suffix (e.g., "llama3.2" from "llama3.2:latest")
            if ":" in name:
                available_ids.add(name.split(":")[0])

        for model_id in model_ids:
            if model_id in available_ids:
                result["valid_models"].append(model_id)
            else:
                result["invalid_models"].append(model_id)

    except ImportError:
        result["error"] = "httpx package not installed"
        result["valid_models"] = model_ids
    except Exception as e:
        error_msg = str(e)
        if "Connection refused" in error_msg or "ConnectError" in error_msg:
            result["error"] = "Ollama not running (start with: ollama serve)"
        else:
            result["error"] = f"Validation failed: {error_msg}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - start_time) * 1000)

    return result


def _validate_bedrock(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Bedrock models using boto3."""
    # Import the existing bedrock validator
    from .bedrock_validator import validate_bedrock_models
    return validate_bedrock_models(model_ids)


def _init_result() -> Dict[str, Any]:
    """Initialize empty result dict."""
    return {
        "valid_models": [],
        "invalid_models": [],
        "validation_time_ms": 0,
        "error": None,
    }


def get_validated_interactive_models(provider: str, all_models: bool = False) -> Dict[str, Any]:
    """
    Get validated models for a provider with metadata.

    Args:
        provider: Provider name
        all_models: If True, validate ALL models for the provider (not just curated)

    Returns:
        Dict containing:
        - models: Dict mapping model_id to metadata
        - validation_result: Full validation result dict
    """
    from ..config import (
        INTERACTIVE_OPENAI_MODELS,
        INTERACTIVE_ANTHROPIC_MODELS,
        INTERACTIVE_GOOGLE_MODELS,
        INTERACTIVE_DEEPSEEK_MODELS,
        INTERACTIVE_GROQ_MODELS,
        INTERACTIVE_GROK_MODELS,
        INTERACTIVE_OPENROUTER_MODELS,
        INTERACTIVE_OLLAMA_MODELS,
        INTERACTIVE_BEDROCK_MODELS,
        OPENAI_MODELS,
        ANTHROPIC_MODELS,
        GOOGLE_MODELS,
        DEEPSEEK_MODELS,
        GROQ_MODELS,
        GROK_MODELS,
        OPENROUTER_MODELS,
        OLLAMA_MODELS,
        BEDROCK_MODELS,
    )

    # Map provider to interactive and full model configs
    provider_configs = {
        "openai": (INTERACTIVE_OPENAI_MODELS, OPENAI_MODELS),
        "anthropic": (INTERACTIVE_ANTHROPIC_MODELS, ANTHROPIC_MODELS),
        "google": (INTERACTIVE_GOOGLE_MODELS, GOOGLE_MODELS),
        "deepseek": (INTERACTIVE_DEEPSEEK_MODELS, DEEPSEEK_MODELS),
        "groq": (INTERACTIVE_GROQ_MODELS, GROQ_MODELS),
        "grok": (INTERACTIVE_GROK_MODELS, GROK_MODELS),
        "openrouter": (INTERACTIVE_OPENROUTER_MODELS, OPENROUTER_MODELS),
        "ollama": (INTERACTIVE_OLLAMA_MODELS, OLLAMA_MODELS),
        "bedrock": (INTERACTIVE_BEDROCK_MODELS, BEDROCK_MODELS),
    }

    if provider not in provider_configs:
        return {
            "models": {},
            "validation_result": {
                "valid_models": [],
                "invalid_models": [],
                "validation_time_ms": 0,
                "error": f"Unknown provider: {provider}",
            },
        }

    interactive_models, full_models = provider_configs[provider]

    # Use all models or just curated interactive models
    if all_models:
        model_ids = list(full_models.keys())
    else:
        model_ids = list(interactive_models.keys())

    # Validate
    validation_result = validate_provider_models(provider, model_ids)

    # Build validated models dict with full metadata
    models = {}
    for model_id in validation_result["valid_models"]:
        interactive_meta = interactive_models.get(model_id, {})
        full_config = full_models.get(model_id, {})

        models[model_id] = {
            **full_config,
            **interactive_meta,
        }

    return {
        "models": models,
        "validation_result": validation_result,
    }
```
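Every validator above returns the same `_init_result()` shape and fails open: a missing key, a missing client package, or an API error reports the error string while echoing the requested IDs back as `valid_models`, so validation never blocks a call. A minimal sketch of how the module could be driven (the model IDs are illustrative, "gpt-9000-ultra" deliberately fake, and it assumes the `openai` package is installed and `OPENAI_API_KEY` is exported):

```python
from stratifyai.utils.provider_validator import validate_provider_models

# Hypothetical IDs, checked live against the OpenAI models.list() endpoint
report = validate_provider_models("openai", ["gpt-4o", "gpt-9000-ultra"])

if report["error"]:
    # Fail-open path: the requested IDs were echoed back as valid
    print(f"Validation skipped: {report['error']}")
else:
    print(f"valid={report['valid_models']} invalid={report['invalid_models']}")
print(f"checked in {report['validation_time_ms']} ms")
```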
stratifyai/utils/token_counter.py
@@ -0,0 +1,186 @@
```python
"""Token counting utilities for estimating LLM token usage."""

from typing import List, Optional
import tiktoken

from ..models import Message


def estimate_tokens(text: str, provider: str = "openai", model: Optional[str] = None) -> int:
    """
    Estimate the number of tokens in a text string.

    Uses provider-specific tokenizers when available, falls back to
    character-based estimation (1 token ≈ 4 characters).

    Args:
        text: The text to estimate tokens for
        provider: The LLM provider (openai, anthropic, google, etc.)
        model: Optional specific model name for more accurate counting

    Returns:
        Estimated number of tokens

    Examples:
        >>> estimate_tokens("Hello, world!", provider="openai")
        4
        >>> estimate_tokens("Hello, world!", provider="anthropic")
        3
    """
    if not text:
        return 0

    # OpenAI models - use tiktoken
    if provider == "openai":
        try:
            # Use model-specific encoding if provided
            if model:
                # Map common model names to encodings
                if model.startswith(("gpt-4", "gpt-3.5")):
                    encoding = tiktoken.encoding_for_model(model)
                elif model.startswith(("o1", "o3")):
                    # o1/o3 models use the same encoding as gpt-4
                    encoding = tiktoken.encoding_for_model("gpt-4")
                else:
                    # Default to cl100k_base (gpt-4, gpt-3.5-turbo)
                    encoding = tiktoken.get_encoding("cl100k_base")
            else:
                # Default to cl100k_base
                encoding = tiktoken.get_encoding("cl100k_base")

            return len(encoding.encode(text))
        except Exception:
            # Fall back to character-based if tiktoken fails
            pass

    # Anthropic models - approximate with character count.
    # The Claude tokenizer is more aggressive than OpenAI's:
    # roughly 1 token ≈ 3.5 characters for English text.
    if provider == "anthropic":
        return int(len(text) / 3.5)

    # Google Gemini - similar to OpenAI
    if provider == "google":
        return int(len(text) / 4)

    # DeepSeek, Groq, Grok, OpenRouter - approximate with OpenAI encoding
    if provider in ["deepseek", "groq", "grok", "openrouter"]:
        try:
            encoding = tiktoken.get_encoding("cl100k_base")
            return len(encoding.encode(text))
        except Exception:
            pass

    # Ollama - local models, use conservative estimate
    if provider == "ollama":
        return int(len(text) / 4)

    # Default fallback: 1 token ≈ 4 characters
    return int(len(text) / 4)


def count_tokens_for_messages(messages: List[Message], provider: str = "openai", model: Optional[str] = None) -> int:
    """
    Count tokens for a list of messages, including formatting overhead.

    Different models have different formatting requirements that add tokens.
    This function accounts for those overheads.

    Args:
        messages: List of Message objects
        provider: The LLM provider
        model: Optional specific model name

    Returns:
        Total estimated token count including formatting

    Examples:
        >>> from stratifyai.models import Message
        >>> messages = [Message(role="user", content="Hello")]
        >>> count_tokens_for_messages(messages, provider="openai")
        9  # Content tokens + role tokens + formatting tokens
    """
    if not messages:
        return 0

    # Count content tokens
    total_tokens = 0
    for message in messages:
        # Count message content
        total_tokens += estimate_tokens(message.content, provider, model)

        # Add role tokens
        total_tokens += estimate_tokens(message.role, provider, model)

    # Add formatting overhead per message
    # OpenAI format: <|start|>role\ncontent<|end|>\n
    # Roughly 3-4 tokens per message for formatting
    if provider == "openai":
        tokens_per_message = 4
        if model and model.startswith("gpt-3.5"):
            tokens_per_message = 4
        elif model and model.startswith("gpt-4"):
            tokens_per_message = 3
        total_tokens += tokens_per_message * len(messages)
        total_tokens += 3  # Every reply is primed with <|start|>assistant<|message|>

    # Anthropic Messages API has minimal overhead
    elif provider == "anthropic":
        total_tokens += 2 * len(messages)  # Minimal formatting overhead

    # Other providers - approximate 3 tokens per message
    else:
        total_tokens += 3 * len(messages)

    return total_tokens


def get_context_window(provider: str, model: str) -> int:
    """
    Get the context window size for a specific model.

    Args:
        provider: The LLM provider
        model: The model name

    Returns:
        Context window size in tokens
    """
    from ..config import MODEL_CATALOG

    model_info = MODEL_CATALOG.get(provider, {}).get(model, {})
    return model_info.get("context", 128000)  # Default to 128k


def check_token_limit(token_count: int, provider: str, model: str, threshold: float = 0.8) -> tuple[bool, int, float]:
    """
    Check if token count is approaching the model's context limit.

    Args:
        token_count: Number of tokens
        provider: The LLM provider
        model: The model name
        threshold: Warning threshold (default 0.8 = 80%)

    Returns:
        Tuple of (exceeds_threshold, context_window, percentage_used)

    Examples:
        >>> check_token_limit(100000, "openai", "gpt-4o", threshold=0.8)
        (False, 128000, 0.78125)
        >>> check_token_limit(110000, "openai", "gpt-4o", threshold=0.8)
        (True, 128000, 0.859375)
    """
    context_window = get_context_window(provider, model)

    # Check for API-imposed input limits (e.g., Claude Opus 4.5)
    from ..config import MODEL_CATALOG
    model_info = MODEL_CATALOG.get(provider, {}).get(model, {})
    api_max_input = model_info.get("api_max_input")
    if api_max_input and api_max_input < context_window:
        context_window = api_max_input

    percentage_used = token_count / context_window if context_window > 0 else 1.0
    exceeds_threshold = percentage_used > threshold

    return exceeds_threshold, context_window, percentage_used
```
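The helpers in this module compose: `count_tokens_for_messages` builds on `estimate_tokens`, and `check_token_limit` compares a total against the catalog's context window, falling back to 128k for unknown models. A short usage sketch (the model name is illustrative, and exact counts vary with the installed tiktoken version):

```python
from stratifyai.models import Message
from stratifyai.utils.token_counter import (
    check_token_limit,
    count_tokens_for_messages,
)

messages = [
    Message(role="system", content="You are a terse assistant."),
    Message(role="user", content="Summarize the attached report."),
]

# tiktoken-backed content counts plus per-message formatting overhead
total = count_tokens_for_messages(messages, provider="openai", model="gpt-4o")

# warn once usage crosses 80% of the effective context window
over, window, used = check_token_limit(total, "openai", "gpt-4o", threshold=0.8)
print(f"{total} tokens = {used:.2%} of {window}; warn={over}")
```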