stratifyai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
@@ -0,0 +1,442 @@
1
+ """Provider model validation utility.
2
+
3
+ Validates model availability for all providers using their respective APIs.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ from typing import Dict, List, Any, Optional
9
+
10
+
11
def validate_provider_models(
    provider: str,
    model_ids: List[str],
    api_key: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Validate which models are available for a given provider.

    Args:
        provider: Provider name (openai, anthropic, google, etc.)
        model_ids: List of model IDs to validate
        api_key: Optional API key (will use env var if not provided)

    Returns:
        Dict containing:
            - valid_models: List of model IDs that are available
            - invalid_models: List of model IDs that are NOT available
            - validation_time_ms: Time taken to validate in milliseconds
            - error: Error message if validation failed (None if successful)
    """
    # Per-provider dispatch table; each entry shares the same call signature.
    dispatch = {
        "openai": _validate_openai,
        "anthropic": _validate_anthropic,
        "google": _validate_google,
        "deepseek": _validate_deepseek,
        "groq": _validate_groq,
        "grok": _validate_grok,
        "openrouter": _validate_openrouter,
        "ollama": _validate_ollama,
        "bedrock": _validate_bedrock,
    }

    check = dispatch.get(provider)
    if check is None:
        # Fail open: an unknown provider keeps all requested models usable.
        return {
            "valid_models": model_ids,
            "invalid_models": [],
            "validation_time_ms": 0,
            "error": f"No validator for provider: {provider}",
        }

    return check(model_ids, api_key)
53
+
54
+
55
def _validate_openai(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate OpenAI models using models.list() API."""
    result = _init_result()
    started = time.time()

    try:
        import openai

        key = api_key or os.getenv("OPENAI_API_KEY")
        if not key:
            # Fail open when no credentials are configured.
            result["error"] = "OPENAI_API_KEY not configured"
            result["valid_models"] = model_ids
            return result

        listing = openai.OpenAI(api_key=key).models.list()
        known = {entry.id for entry in listing.data}

        # Partition requested IDs into valid / invalid buckets.
        for candidate in model_ids:
            bucket = "valid_models" if candidate in known else "invalid_models"
            result[bucket].append(candidate)

    except ImportError:
        result["error"] = "openai package not installed"
        result["valid_models"] = model_ids
    except Exception as exc:
        result["error"] = f"Validation failed: {str(exc)}"
        result["valid_models"] = model_ids
    finally:
        # Always record elapsed time, even on the error paths.
        result["validation_time_ms"] = int((time.time() - started) * 1000)

    return result
89
+
90
+
91
def _validate_anthropic(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Anthropic credentials (no models list API available)."""
    result = _init_result()
    started = time.time()

    try:
        import anthropic  # noqa: F401 -- only checking the package is installed

        key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not key:
            result["error"] = "ANTHROPIC_API_KEY not configured"
            result["valid_models"] = model_ids
            return result

        # Anthropic exposes no models-list endpoint here, so the only check
        # is the key prefix; either way all requested models stay usable.
        if not key.startswith("sk-ant-"):
            result["error"] = "Invalid API key format"
        result["valid_models"] = model_ids

    except ImportError:
        result["error"] = "anthropic package not installed"
        result["valid_models"] = model_ids
    except Exception as exc:
        result["error"] = f"Validation failed: {str(exc)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - started) * 1000)

    return result
123
+
124
+
125
def _validate_google(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Google models using client.models.list()."""
    result = _init_result()
    started = time.time()

    try:
        from google import genai

        key = api_key or os.getenv("GOOGLE_API_KEY")
        if not key:
            result["error"] = "GOOGLE_API_KEY not configured"
            result["valid_models"] = model_ids
            return result

        client = genai.Client(api_key=key)

        # API returns names like "models/gemini-2.5-pro"; index both the
        # stripped name and the name without its trailing version segment.
        catalog = set()
        for entry in client.models.list():
            short = entry.name.replace("models/", "")
            catalog.add(short)
            if "-" in short:
                catalog.add(short.rsplit("-", 1)[0])

        for wanted in model_ids:
            # Exact hit, or substring hit against any catalog entry.
            hit = wanted in catalog or any(wanted in known for known in catalog)
            target = "valid_models" if hit else "invalid_models"
            result[target].append(wanted)

    except ImportError:
        result["error"] = "google-genai package not installed"
        result["valid_models"] = model_ids
    except Exception as exc:
        result["error"] = f"Validation failed: {str(exc)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - started) * 1000)

    return result
168
+
169
+
170
def _validate_openai_compatible(
    model_ids: List[str],
    base_url: str,
    api_key: Optional[str],
    env_var: str,
) -> Dict[str, Any]:
    """Generic validator for OpenAI-compatible APIs."""
    result = _init_result()
    started = time.time()

    try:
        import httpx

        key = api_key or os.getenv(env_var)
        if not key:
            result["error"] = f"{env_var} not configured"
            result["valid_models"] = model_ids
            return result

        # Standard OpenAI-style listing endpoint: GET {base_url}/models.
        with httpx.Client(timeout=10.0) as http:
            reply = http.get(
                f"{base_url}/models",
                headers={"Authorization": f"Bearer {key}"},
            )
            reply.raise_for_status()
            payload = reply.json()

        known = {entry.get("id", "") for entry in payload.get("data", [])}

        for candidate in model_ids:
            target = "valid_models" if candidate in known else "invalid_models"
            result[target].append(candidate)

    except ImportError:
        result["error"] = "httpx package not installed"
        result["valid_models"] = model_ids
    except Exception as exc:
        result["error"] = f"Validation failed: {str(exc)}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - started) * 1000)

    return result
217
+
218
+
219
def _validate_deepseek(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate DeepSeek models using OpenAI-compatible API."""
    return _validate_openai_compatible(
        model_ids,
        base_url="https://api.deepseek.com/v1",
        api_key=api_key,
        env_var="DEEPSEEK_API_KEY",
    )
227
+
228
+
229
def _validate_groq(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Groq models using OpenAI-compatible API."""
    return _validate_openai_compatible(
        model_ids,
        base_url="https://api.groq.com/openai/v1",
        api_key=api_key,
        env_var="GROQ_API_KEY",
    )
237
+
238
+
239
def _validate_grok(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Grok (X.AI) models using OpenAI-compatible API."""
    return _validate_openai_compatible(
        model_ids,
        base_url="https://api.x.ai/v1",
        api_key=api_key,
        env_var="GROK_API_KEY",
    )
247
+
248
+
249
def _validate_openrouter(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate OpenRouter models using their models API.

    OpenRouter exposes an OpenAI-compatible listing endpoint at
    https://openrouter.ai/api/v1/models, so this delegates to the shared
    generic validator instead of duplicating its httpx logic. The previous
    inline implementation was a byte-for-byte copy of
    _validate_openai_compatible (same /models GET, Bearer auth header,
    10-second timeout, and error strings), which this removes while keeping
    identical behavior — matching the deepseek/groq/grok wrappers.

    Args:
        model_ids: Model IDs to check against the OpenRouter catalog.
        api_key: Optional key; falls back to the OPENROUTER_API_KEY env var.

    Returns:
        Standard validation result dict (valid_models / invalid_models /
        validation_time_ms / error).
    """
    return _validate_openai_compatible(
        model_ids,
        "https://openrouter.ai/api/v1",
        api_key,
        "OPENROUTER_API_KEY",
    )
294
+
295
+
296
def _validate_ollama(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Ollama models using local API."""
    result = _init_result()
    started = time.time()

    try:
        import httpx

        host = os.getenv("OLLAMA_HOST", "http://localhost:11434")

        # Shorter timeout than remote providers: this is a local daemon.
        with httpx.Client(timeout=5.0) as http:
            reply = http.get(f"{host}/api/tags")
            reply.raise_for_status()
            payload = reply.json()

        installed = set()
        for entry in payload.get("models", []):
            tag = entry.get("name", "")
            installed.add(tag)
            # "llama3.2:latest" should also match a plain "llama3.2" request.
            if ":" in tag:
                installed.add(tag.split(":")[0])

        for candidate in model_ids:
            bucket = "valid_models" if candidate in installed else "invalid_models"
            result[bucket].append(candidate)

    except ImportError:
        result["error"] = "httpx package not installed"
        result["valid_models"] = model_ids
    except Exception as exc:
        detail = str(exc)
        # Give a friendlier hint when the daemon simply isn't up.
        if "Connection refused" in detail or "ConnectError" in detail:
            result["error"] = "Ollama not running (start with: ollama serve)"
        else:
            result["error"] = f"Validation failed: {detail}"
        result["valid_models"] = model_ids
    finally:
        result["validation_time_ms"] = int((time.time() - started) * 1000)

    return result
340
+
341
+
342
def _validate_bedrock(model_ids: List[str], api_key: Optional[str] = None) -> Dict[str, Any]:
    """Validate Bedrock models using boto3."""
    # Bedrock authenticates via AWS credentials rather than an API key, so
    # api_key is unused; delegate to the dedicated boto3-based validator.
    from .bedrock_validator import validate_bedrock_models as _bedrock_check

    return _bedrock_check(model_ids)
347
+
348
+
349
+ def _init_result() -> Dict[str, Any]:
350
+ """Initialize empty result dict."""
351
+ return {
352
+ "valid_models": [],
353
+ "invalid_models": [],
354
+ "validation_time_ms": 0,
355
+ "error": None,
356
+ }
357
+
358
+
359
def get_validated_interactive_models(provider: str, all_models: bool = False) -> Dict[str, Any]:
    """
    Get validated models for a provider with metadata.

    Args:
        provider: Provider name
        all_models: If True, validate ALL models for the provider (not just curated)

    Returns:
        Dict containing:
            - models: Dict mapping model_id to metadata
            - validation_result: Full validation result dict
    """
    from ..config import (
        INTERACTIVE_OPENAI_MODELS,
        INTERACTIVE_ANTHROPIC_MODELS,
        INTERACTIVE_GOOGLE_MODELS,
        INTERACTIVE_DEEPSEEK_MODELS,
        INTERACTIVE_GROQ_MODELS,
        INTERACTIVE_GROK_MODELS,
        INTERACTIVE_OPENROUTER_MODELS,
        INTERACTIVE_OLLAMA_MODELS,
        INTERACTIVE_BEDROCK_MODELS,
        OPENAI_MODELS,
        ANTHROPIC_MODELS,
        GOOGLE_MODELS,
        DEEPSEEK_MODELS,
        GROQ_MODELS,
        GROK_MODELS,
        OPENROUTER_MODELS,
        OLLAMA_MODELS,
        BEDROCK_MODELS,
    )

    # (curated interactive subset, full catalog) per provider.
    catalog = {
        "openai": (INTERACTIVE_OPENAI_MODELS, OPENAI_MODELS),
        "anthropic": (INTERACTIVE_ANTHROPIC_MODELS, ANTHROPIC_MODELS),
        "google": (INTERACTIVE_GOOGLE_MODELS, GOOGLE_MODELS),
        "deepseek": (INTERACTIVE_DEEPSEEK_MODELS, DEEPSEEK_MODELS),
        "groq": (INTERACTIVE_GROQ_MODELS, GROQ_MODELS),
        "grok": (INTERACTIVE_GROK_MODELS, GROK_MODELS),
        "openrouter": (INTERACTIVE_OPENROUTER_MODELS, OPENROUTER_MODELS),
        "ollama": (INTERACTIVE_OLLAMA_MODELS, OLLAMA_MODELS),
        "bedrock": (INTERACTIVE_BEDROCK_MODELS, BEDROCK_MODELS),
    }

    if provider not in catalog:
        return {
            "models": {},
            "validation_result": {
                "valid_models": [],
                "invalid_models": [],
                "validation_time_ms": 0,
                "error": f"Unknown provider: {provider}",
            },
        }

    curated, full = catalog[provider]

    # Either the entire catalog or just the curated interactive picks.
    candidates = list((full if all_models else curated).keys())

    outcome = validate_provider_models(provider, candidates)

    # For each confirmed model, merge the full config with the curated
    # interactive metadata; curated entries win on key conflicts.
    verified = {
        model_id: {**full.get(model_id, {}), **curated.get(model_id, {})}
        for model_id in outcome["valid_models"]
    }

    return {
        "models": verified,
        "validation_result": outcome,
    }
@@ -0,0 +1,186 @@
1
+ """Token counting utilities for estimating LLM token usage."""
2
+
3
+ from typing import List, Optional
4
+ import tiktoken
5
+
6
+ from ..models import Message
7
+
8
+
9
def estimate_tokens(text: str, provider: str = "openai", model: Optional[str] = None) -> int:
    """
    Estimate the number of tokens in a text string.

    Uses provider-specific tokenizers when available, falls back to
    character-based estimation (1 token ≈ 4 characters).

    Args:
        text: The text to estimate tokens for
        provider: The LLM provider (openai, anthropic, google, etc.)
        model: Optional specific model name for more accurate counting

    Returns:
        Estimated number of tokens

    Examples:
        >>> estimate_tokens("Hello, world!", provider="openai")
        4
        >>> estimate_tokens("Hello, world!", provider="anthropic")
        3
    """
    if not text:
        return 0

    # OpenAI: prefer an exact tiktoken count.
    if provider == "openai":
        try:
            if model and model.startswith(("gpt-4", "gpt-3.5")):
                codec = tiktoken.encoding_for_model(model)
            elif model and model.startswith(("o1", "o3")):
                # o1/o3 families share the gpt-4 encoding.
                codec = tiktoken.encoding_for_model("gpt-4")
            else:
                # cl100k_base covers gpt-4 / gpt-3.5-turbo.
                codec = tiktoken.get_encoding("cl100k_base")
            return len(codec.encode(text))
        except Exception:
            # tiktoken unavailable or unknown model -> char-based fallback below.
            pass

    # Anthropic: Claude's tokenizer is denser, ~3.5 chars per token.
    if provider == "anthropic":
        return int(len(text) / 3.5)

    # Google: ~4 chars per token, same as the generic fallback.
    if provider == "google":
        return int(len(text) / 4)

    # OpenAI-compatible providers: approximate with the cl100k_base encoding.
    if provider in ["deepseek", "groq", "grok", "openrouter"]:
        try:
            return len(tiktoken.get_encoding("cl100k_base").encode(text))
        except Exception:
            pass

    # Ollama (local models) and anything unrecognized: ~4 chars per token.
    return int(len(text) / 4)
80
+
81
+
82
def count_tokens_for_messages(messages: List[Message], provider: str = "openai", model: Optional[str] = None) -> int:
    """
    Count tokens for a list of messages, including formatting overhead.

    Different models have different formatting requirements that add tokens.
    This function accounts for those overheads.

    Args:
        messages: List of Message objects
        provider: The LLM provider
        model: Optional specific model name

    Returns:
        Total estimated token count including formatting

    Examples:
        >>> from stratifyai.models import Message
        >>> messages = [Message(role="user", content="Hello")]
        >>> count_tokens_for_messages(messages, provider="openai")
        7  # Content tokens + formatting tokens
    """
    if not messages:
        return 0

    # Content + role tokens for every message.
    total = sum(
        estimate_tokens(msg.content, provider, model)
        + estimate_tokens(msg.role, provider, model)
        for msg in messages
    )

    count = len(messages)
    if provider == "openai":
        # OpenAI chat format: <|start|>role\ncontent<|end|>\n adds a few
        # tokens per message; gpt-4 uses slightly fewer than gpt-3.5.
        per_message = 3 if (model and model.startswith("gpt-4")) else 4
        total += per_message * count
        # Every reply is primed with <|start|>assistant<|message|>.
        total += 3
    elif provider == "anthropic":
        # Anthropic Messages API has minimal per-message overhead.
        total += 2 * count
    else:
        # Other providers: approximate 3 tokens of framing per message.
        total += 3 * count

    return total
136
+
137
+
138
def get_context_window(provider: str, model: str) -> int:
    """
    Get the context window size for a specific model.

    Args:
        provider: The LLM provider
        model: The model name

    Returns:
        Context window size in tokens
    """
    from ..config import MODEL_CATALOG

    # Unknown provider/model pairs fall back to a 128k window.
    catalog_entry = MODEL_CATALOG.get(provider, {}).get(model, {})
    return catalog_entry.get("context", 128000)
153
+
154
+
155
def check_token_limit(token_count: int, provider: str, model: str, threshold: float = 0.8) -> tuple[bool, int, float]:
    """
    Check if token count is approaching the model's context limit.

    Args:
        token_count: Number of tokens
        provider: The LLM provider
        model: The model name
        threshold: Warning threshold (default 0.8 = 80%)

    Returns:
        Tuple of (exceeds_threshold, context_window, percentage_used)

    Examples:
        >>> check_token_limit(100000, "openai", "gpt-4o", threshold=0.8)
        (False, 128000, 0.78125)
        >>> check_token_limit(110000, "openai", "gpt-4o", threshold=0.8)
        (True, 128000, 0.859375)
    """
    from ..config import MODEL_CATALOG

    window = get_context_window(provider, model)

    # Some models cap per-request input below the advertised context window
    # (e.g. Claude Opus 4.5); honor the tighter limit when present.
    entry = MODEL_CATALOG.get(provider, {}).get(model, {})
    api_cap = entry.get("api_max_input")
    if api_cap and api_cap < window:
        window = api_cap

    # Guard against a zero-sized window; treat it as fully used.
    used = token_count / window if window > 0 else 1.0
    return used > threshold, window, used