stratifyai-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/router.py
ADDED
@@ -0,0 +1,643 @@
"""Intelligent model router for selecting optimal LLM providers."""

from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple
import re

from .config import MODEL_CATALOG
from .models import Message
from .utils.file_analyzer import FileType


class RoutingStrategy(str, Enum):
    """Routing strategy types."""
    COST = "cost"
    QUALITY = "quality"
    LATENCY = "latency"
    HYBRID = "hybrid"


@dataclass
class ModelMetadata:
    """Metadata for a specific model."""
    provider: str
    model: str
    quality_score: float  # 0.0 - 1.0
    cost_per_1m_input: float  # USD per 1M tokens
    cost_per_1m_output: float  # USD per 1M tokens
    avg_latency_ms: float  # Average response time in milliseconds
    context_window: int  # Maximum context tokens
    capabilities: List[str]  # e.g., ["reasoning", "vision", "tools"]
    reasoning_model: bool = False
    supports_streaming: bool = True


class Router:
    """Intelligent model router for selecting optimal providers."""

    def __init__(
        self,
        strategy: RoutingStrategy = RoutingStrategy.HYBRID,
        preferred_providers: Optional[List[str]] = None,
        excluded_providers: Optional[List[str]] = None,
    ):
        """
        Initialize router.

        Args:
            strategy: Routing strategy to use
            preferred_providers: List of preferred providers (prioritized)
            excluded_providers: List of providers to exclude
        """
        self.strategy = strategy
        self.preferred_providers = preferred_providers or []
        self.excluded_providers = excluded_providers or []
        self._load_model_metadata()

    def _load_model_metadata(self) -> None:
        """Load model metadata from MODEL_CATALOG."""
        self.model_metadata: Dict[str, ModelMetadata] = {}

        # Quality scores (estimated based on benchmarks)
        # These are rough estimates - should be updated with real benchmark data
        quality_scores = {
            # OpenAI
            "gpt-5": 0.98,
            "o3-mini": 0.95,
            "gpt-4.5-turbo-20250205": 0.93,
            "gpt-4.1-turbo": 0.90,
            "gpt-4.1-mini": 0.82,
            "o1-mini": 0.88,
            "o1": 0.96,

            # Anthropic
            "claude-sonnet-4-5-20250929": 0.94,
            "claude-3-5-sonnet-20241022": 0.92,
            "claude-3-5-haiku-20241022": 0.80,

            # Google
            "gemini-2.5-pro": 0.91,
            "gemini-2.5-flash": 0.85,
            "gemini-2.5-flash-lite": 0.78,

            # DeepSeek
            "deepseek-chat": 0.85,
            "deepseek-reasoner": 0.90,

            # Groq
            "llama-3.1-70b-versatile": 0.83,
            "llama-3.1-8b-instant": 0.75,
            "mixtral-8x7b-32768": 0.80,

            # Grok
            "grok-beta": 0.87,

            # OpenRouter (using same as original providers)
            "anthropic/claude-3-5-sonnet": 0.92,
            "openai/gpt-4-turbo": 0.90,

            # Ollama (local models)
            "llama3.2": 0.70,
            "mistral": 0.68,
            "codellama": 0.72,
        }

        # Average latency (ms) - rough estimates
        latency_estimates = {
            # OpenAI
            "gpt-5": 3500,
            "o3-mini": 8000,
            "gpt-4.5-turbo-20250205": 2500,
            "gpt-4.1-turbo": 2000,
            "gpt-4.1-mini": 800,
            "o1-mini": 5000,
            "o1": 10000,

            # Anthropic
            "claude-sonnet-4-5-20250929": 2800,
            "claude-3-5-sonnet-20241022": 2200,
            "claude-3-5-haiku-20241022": 1200,

            # Google
            "gemini-2.5-pro": 2000,
            "gemini-2.5-flash": 1000,
            "gemini-2.5-flash-lite": 600,

            # DeepSeek
            "deepseek-chat": 1500,
            "deepseek-reasoner": 6000,

            # Groq (known for speed)
            "llama-3.1-70b-versatile": 400,
            "llama-3.1-8b-instant": 200,
            "mixtral-8x7b-32768": 350,

            # Grok
            "grok-beta": 2500,

            # OpenRouter
            "anthropic/claude-3-5-sonnet": 2500,
            "openai/gpt-4-turbo": 2200,

            # Ollama (local, very fast)
            "llama3.2": 100,
            "mistral": 120,
            "codellama": 110,
        }

        # Build metadata from MODEL_CATALOG
        for provider, models in MODEL_CATALOG.items():
            # Skip if provider is excluded
            if provider in self.excluded_providers:
                continue

            for model_name, model_info in models.items():
                key = f"{provider}/{model_name}"

                # Extract capabilities
                capabilities = []
                if model_info.get("supports_vision"):
                    capabilities.append("vision")
                if model_info.get("supports_tools"):
                    capabilities.append("tools")
                if model_info.get("reasoning_model"):
                    capabilities.append("reasoning")

                self.model_metadata[key] = ModelMetadata(
                    provider=provider,
                    model=model_name,
                    quality_score=quality_scores.get(model_name, 0.75),
                    cost_per_1m_input=model_info.get("cost_input", 0.0),
                    cost_per_1m_output=model_info.get("cost_output", 0.0),
                    avg_latency_ms=latency_estimates.get(model_name, 2000),
                    context_window=model_info.get("context", 8192),
                    capabilities=capabilities,
                    reasoning_model=model_info.get("reasoning_model", False),
                    supports_streaming=True,
                )

    def route(
        self,
        messages: List[Message],
        required_capabilities: Optional[List[str]] = None,
        max_cost_per_1k_tokens: Optional[float] = None,
        max_latency_ms: Optional[float] = None,
        min_context_window: Optional[int] = None,
    ) -> Tuple[str, str]:
        """
        Select the best model for the given request.

        Args:
            messages: Conversation messages
            required_capabilities: Required model capabilities (e.g., ["vision", "tools"])
            max_cost_per_1k_tokens: Maximum acceptable cost per 1k tokens
            max_latency_ms: Maximum acceptable latency in milliseconds
            min_context_window: Minimum required context window

        Returns:
            Tuple of (provider, model) for the selected model
        """
        # Analyze prompt complexity
        complexity = self._analyze_complexity(messages)

        # Filter models by requirements
        candidates = self._filter_candidates(
            required_capabilities,
            max_cost_per_1k_tokens,
            max_latency_ms,
            min_context_window,
        )

        if not candidates:
            raise ValueError(
                "No models meet the specified requirements. "
                "Consider relaxing constraints."
            )

        # Select model based on strategy
        if self.strategy == RoutingStrategy.COST:
            selected_key = self._select_by_cost(candidates)
        elif self.strategy == RoutingStrategy.QUALITY:
            selected_key = self._select_by_quality(candidates, complexity)
        elif self.strategy == RoutingStrategy.LATENCY:
            selected_key = self._select_by_latency(candidates)
        elif self.strategy == RoutingStrategy.HYBRID:
            selected_key = self._select_hybrid(candidates, complexity)
        else:
            selected_key = list(candidates.keys())[0]

        metadata = self.model_metadata[selected_key]
        return metadata.provider, metadata.model

    def _analyze_complexity(self, messages: List[Message]) -> float:
        """
        Analyze prompt complexity (0.0 - 1.0).

        Higher complexity scores indicate more difficult tasks requiring
        better models.

        Args:
            messages: Conversation messages

        Returns:
            Complexity score from 0.0 (simple) to 1.0 (complex)
        """
        # Combine all message content
        text = " ".join(msg.content for msg in messages)

        # Initialize score
        complexity_score = 0.0

        # 1. Check for reasoning keywords (40% weight)
        reasoning_keywords = [
            r'\banalyze\b', r'\bexplain\b', r'\breasoning\b', r'\bproof\b',
            r'\bstep by step\b', r'\bcomplex\b', r'\bcalculate\b',
            r'\bderive\b', r'\bsolve\b', r'\bprove\b', r'\bthink\b',
            r'\bcompare\b', r'\bevaluate\b', r'\bdetailed\b',
        ]

        reasoning_matches = sum(
            1 for pattern in reasoning_keywords
            if re.search(pattern, text.lower())
        )
        complexity_score += min(reasoning_matches / len(reasoning_keywords), 1.0) * 0.4

        # 2. Length-based complexity (20% weight)
        # Longer prompts often indicate more complex tasks
        text_length = len(text)
        length_score = min(text_length / 2000, 1.0)  # Normalize to 2000 chars
        complexity_score += length_score * 0.2

        # 3. Check for code/technical content (20% weight)
        code_indicators = [
            r'```', r'function\s+\w+', r'class\s+\w+', r'def\s+\w+',
            r'import\s+\w+', r'\/\/.*', r'\/\*.*\*\/', r'\bcode\b',
        ]

        code_matches = sum(
            1 for pattern in code_indicators
            if re.search(pattern, text, re.IGNORECASE)
        )
        complexity_score += min(code_matches / len(code_indicators), 1.0) * 0.2

        # 4. Multi-turn conversation (10% weight)
        # More messages can indicate more complex context
        message_count = len(messages)
        conversation_score = min(message_count / 10, 1.0)
        complexity_score += conversation_score * 0.1

        # 5. Mathematical content (10% weight)
        math_indicators = [
            r'\d+\s*[+\-*/]\s*\d+', r'\bequation\b', r'\bformula\b',
            r'\bcalculus\b', r'\balgebra\b', r'\bintegral\b',
        ]

        math_matches = sum(
            1 for pattern in math_indicators
            if re.search(pattern, text, re.IGNORECASE)
        )
        complexity_score += min(math_matches / len(math_indicators), 1.0) * 0.1

        return min(complexity_score, 1.0)

    def _filter_candidates(
        self,
        required_capabilities: Optional[List[str]],
        max_cost_per_1k: Optional[float],
        max_latency_ms: Optional[float],
        min_context_window: Optional[int],
    ) -> Dict[str, ModelMetadata]:
        """Filter models based on requirements."""
        candidates = {}

        for key, meta in self.model_metadata.items():
            # Check capabilities
            if required_capabilities:
                if not all(cap in meta.capabilities for cap in required_capabilities):
                    continue

            # Check cost constraint
            if max_cost_per_1k:
                avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
                if avg_cost_per_1k > max_cost_per_1k:
                    continue

            # Check latency constraint
            if max_latency_ms and meta.avg_latency_ms > max_latency_ms:
                continue

            # Check context window constraint
            if min_context_window and meta.context_window < min_context_window:
                continue

            candidates[key] = meta

        # Prioritize preferred providers
        if self.preferred_providers:
            prioritized = {}
            for key, meta in candidates.items():
                if meta.provider in self.preferred_providers:
                    prioritized[key] = meta
            if prioritized:
                return prioritized

        return candidates

    def _select_by_cost(self, candidates: Dict[str, ModelMetadata]) -> str:
        """Select cheapest model."""
        def cost_score(meta: ModelMetadata) -> float:
            # Average cost per 1M tokens (input + output weighted equally)
            return (meta.cost_per_1m_input + meta.cost_per_1m_output) / 2

        return min(candidates.keys(), key=lambda k: cost_score(candidates[k]))

    def _select_by_quality(
        self,
        candidates: Dict[str, ModelMetadata],
        complexity: float
    ) -> str:
        """
        Select highest quality model.

        For high complexity tasks, prioritize models with reasoning capabilities.
        """
        def quality_score(meta: ModelMetadata) -> float:
            base_score = meta.quality_score

            # Boost reasoning models for complex tasks
            if complexity > 0.6 and meta.reasoning_model:
                base_score += 0.05

            return base_score

        return max(candidates.keys(), key=lambda k: quality_score(candidates[k]))

    def _select_by_latency(self, candidates: Dict[str, ModelMetadata]) -> str:
        """Select fastest model."""
        return min(candidates.keys(), key=lambda k: candidates[k].avg_latency_ms)

    def _select_hybrid(
        self,
        candidates: Dict[str, ModelMetadata],
        complexity: float
    ) -> str:
        """
        Hybrid selection balancing cost, quality, and latency.

        Adjusts weights based on task complexity:
        - Low complexity: Prioritize cost (60%) and speed (30%), quality (10%)
        - High complexity: Prioritize quality (60%) and cost (30%), speed (10%)
        """
        scores = {}

        # Adjust weights based on complexity
        quality_weight = 0.1 + (complexity * 0.5)  # 0.1 to 0.6
        cost_weight = 0.6 - (complexity * 0.3)  # 0.6 to 0.3
        latency_weight = 0.3 - (complexity * 0.2)  # 0.3 to 0.1

        for key, meta in candidates.items():
            # Quality score (0-1, higher is better)
            quality_score = meta.quality_score

            # Cost score (0-1, lower cost is better)
            # Normalize to typical range: $0.001 to $0.050 per 1k tokens
            avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
            cost_score = max(0, 1 - (avg_cost_per_1k / 0.050))

            # Latency score (0-1, lower latency is better)
            # Normalize to typical range: 100ms to 10000ms
            latency_score = max(0, 1 - (meta.avg_latency_ms / 10000))

            # Calculate weighted score
            scores[key] = (
                quality_weight * quality_score +
                cost_weight * cost_score +
                latency_weight * latency_score
            )

        return max(scores, key=scores.get)

    def get_model_info(self, provider: str, model: str) -> Optional[ModelMetadata]:
        """Get metadata for a specific model."""
        key = f"{provider}/{model}"
        return self.model_metadata.get(key)

    def route_for_extraction(
        self,
        file_type: FileType,
        extraction_mode: str = "schema",
        max_cost_per_1k_tokens: Optional[float] = None,
    ) -> Tuple[str, str]:
        """
        Select the best model for file extraction tasks.

        Prioritizes quality over cost/latency for extraction tasks,
        with task-specific optimizations.

        Args:
            file_type: Type of file being extracted
            extraction_mode: Mode of extraction ("schema", "errors", "structure", "summary")
            max_cost_per_1k_tokens: Optional cost constraint

        Returns:
            Tuple of (provider, model) optimized for extraction
        """
        # Define extraction-specific quality weights
        if extraction_mode == "schema":
            quality_weight = 0.90
            cost_weight = 0.10
        elif extraction_mode == "errors":
            quality_weight = 0.80
            cost_weight = 0.20
        elif extraction_mode == "structure":
            quality_weight = 0.85
            cost_weight = 0.15
        else:  # summary
            quality_weight = 0.70
            cost_weight = 0.30

        # Filter candidates
        candidates = {}
        for key, meta in self.model_metadata.items():
            # Apply cost constraint if specified
            if max_cost_per_1k_tokens:
                avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
                if avg_cost_per_1k > max_cost_per_1k_tokens:
                    continue

            candidates[key] = meta

        if not candidates:
            raise ValueError(
                "No models meet cost constraints. "
                "Consider increasing max_cost_per_1k_tokens."
            )

        # Score candidates with extraction-focused weights
        scores = {}
        for key, meta in candidates.items():
            # Quality score (0-1, higher is better)
            quality_score = meta.quality_score

            # Boost for reasoning models on error extraction
            if extraction_mode == "errors" and meta.reasoning_model:
                quality_score += 0.05

            # Cost score (0-1, lower cost is better)
            avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
            cost_score = max(0, 1 - (avg_cost_per_1k / 0.050))

            # Calculate weighted score (no latency for extraction)
            scores[key] = (
                quality_weight * quality_score +
                cost_weight * cost_score
            )

        # Select highest scoring model
        best_key = max(scores, key=scores.get)
        metadata = self.model_metadata[best_key]
        return metadata.provider, metadata.model

    def get_fallback_chain(
        self,
        messages: List[Message],
        count: int = 3,
        required_capabilities: Optional[List[str]] = None,
        max_cost_per_1k_tokens: Optional[float] = None,
        max_latency_ms: Optional[float] = None,
        min_context_window: Optional[int] = None,
        exclude_models: Optional[List[str]] = None,
    ) -> List[Tuple[str, str]]:
        """
        Get a ranked list of fallback models for resilient routing.

        Returns models ranked by the current strategy, allowing automatic
        fallback if the primary model fails.

        Args:
            messages: Conversation messages (used for complexity analysis)
            count: Number of fallback models to return (default: 3)
            required_capabilities: Required model capabilities
            max_cost_per_1k_tokens: Maximum acceptable cost
            max_latency_ms: Maximum acceptable latency
            min_context_window: Minimum required context window
            exclude_models: Models to exclude from results

        Returns:
            List of (provider, model) tuples ranked by strategy

        Example:
            >>> router = Router(strategy=RoutingStrategy.HYBRID)
            >>> fallbacks = router.get_fallback_chain(messages, count=3)
            >>> # Returns: [("openai", "gpt-4o"), ("anthropic", "claude-3-5-sonnet"), ...]
        """
        # Analyze complexity for scoring
        complexity = self._analyze_complexity(messages)

        # Filter candidates
        candidates = self._filter_candidates(
            required_capabilities,
            max_cost_per_1k_tokens,
            max_latency_ms,
            min_context_window,
        )

        # Exclude specified models
        if exclude_models:
            candidates = {
                k: v for k, v in candidates.items()
                if v.model not in exclude_models
            }

        if not candidates:
            return []

        # Score all candidates based on strategy
        scores = self._score_candidates(candidates, complexity)

        # Sort by score (descending) and return top N
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)

        result = []
        for key, _ in ranked[:count]:
            meta = self.model_metadata[key]
            result.append((meta.provider, meta.model))

        return result

    def _score_candidates(
        self,
        candidates: Dict[str, ModelMetadata],
        complexity: float,
    ) -> Dict[str, float]:
        """
        Score all candidates based on current routing strategy.

        Args:
            candidates: Filtered model candidates
            complexity: Task complexity score (0.0 - 1.0)

        Returns:
            Dictionary mapping model keys to scores
        """
        scores = {}

        for key, meta in candidates.items():
            if self.strategy == RoutingStrategy.COST:
                # Lower cost = higher score
                avg_cost = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 2
                scores[key] = max(0, 1 - (avg_cost / 100))  # Normalize to $100/1M

            elif self.strategy == RoutingStrategy.QUALITY:
                scores[key] = meta.quality_score
                if complexity > 0.6 and meta.reasoning_model:
                    scores[key] += 0.05

            elif self.strategy == RoutingStrategy.LATENCY:
                # Lower latency = higher score
                scores[key] = max(0, 1 - (meta.avg_latency_ms / 10000))

            else:  # HYBRID
                # Dynamic weights based on complexity
                quality_weight = 0.1 + (complexity * 0.5)
                cost_weight = 0.6 - (complexity * 0.3)
                latency_weight = 0.3 - (complexity * 0.2)

                quality_score = meta.quality_score
                avg_cost = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
                cost_score = max(0, 1 - (avg_cost / 0.050))
                latency_score = max(0, 1 - (meta.avg_latency_ms / 10000))

                scores[key] = (
                    quality_weight * quality_score +
                    cost_weight * cost_score +
                    latency_weight * latency_score
                )

        return scores

    def list_models(
        self,
        required_capabilities: Optional[List[str]] = None,
    ) -> List[Tuple[str, str, ModelMetadata]]:
        """
        List all available models with their metadata.

        Args:
            required_capabilities: Filter by required capabilities

        Returns:
            List of (provider, model, metadata) tuples
        """
        results = []

        for key, meta in self.model_metadata.items():
            # Check capabilities
            if required_capabilities:
                if not all(cap in meta.capabilities for cap in required_capabilities):
                    continue

            results.append((meta.provider, meta.model, meta))

        return results