stratifyai 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. cli/__init__.py +5 -0
  2. cli/stratifyai_cli.py +1753 -0
  3. stratifyai/__init__.py +113 -0
  4. stratifyai/api_key_helper.py +372 -0
  5. stratifyai/caching.py +279 -0
  6. stratifyai/chat/__init__.py +54 -0
  7. stratifyai/chat/builder.py +366 -0
  8. stratifyai/chat/stratifyai_anthropic.py +194 -0
  9. stratifyai/chat/stratifyai_bedrock.py +200 -0
  10. stratifyai/chat/stratifyai_deepseek.py +194 -0
  11. stratifyai/chat/stratifyai_google.py +194 -0
  12. stratifyai/chat/stratifyai_grok.py +194 -0
  13. stratifyai/chat/stratifyai_groq.py +195 -0
  14. stratifyai/chat/stratifyai_ollama.py +201 -0
  15. stratifyai/chat/stratifyai_openai.py +209 -0
  16. stratifyai/chat/stratifyai_openrouter.py +201 -0
  17. stratifyai/chunking.py +158 -0
  18. stratifyai/client.py +292 -0
  19. stratifyai/config.py +1273 -0
  20. stratifyai/cost_tracker.py +257 -0
  21. stratifyai/embeddings.py +245 -0
  22. stratifyai/exceptions.py +91 -0
  23. stratifyai/models.py +59 -0
  24. stratifyai/providers/__init__.py +5 -0
  25. stratifyai/providers/anthropic.py +330 -0
  26. stratifyai/providers/base.py +183 -0
  27. stratifyai/providers/bedrock.py +634 -0
  28. stratifyai/providers/deepseek.py +39 -0
  29. stratifyai/providers/google.py +39 -0
  30. stratifyai/providers/grok.py +39 -0
  31. stratifyai/providers/groq.py +39 -0
  32. stratifyai/providers/ollama.py +43 -0
  33. stratifyai/providers/openai.py +344 -0
  34. stratifyai/providers/openai_compatible.py +372 -0
  35. stratifyai/providers/openrouter.py +39 -0
  36. stratifyai/py.typed +2 -0
  37. stratifyai/rag.py +381 -0
  38. stratifyai/retry.py +185 -0
  39. stratifyai/router.py +643 -0
  40. stratifyai/summarization.py +179 -0
  41. stratifyai/utils/__init__.py +11 -0
  42. stratifyai/utils/bedrock_validator.py +136 -0
  43. stratifyai/utils/code_extractor.py +327 -0
  44. stratifyai/utils/csv_extractor.py +197 -0
  45. stratifyai/utils/file_analyzer.py +192 -0
  46. stratifyai/utils/json_extractor.py +219 -0
  47. stratifyai/utils/log_extractor.py +267 -0
  48. stratifyai/utils/model_selector.py +324 -0
  49. stratifyai/utils/provider_validator.py +442 -0
  50. stratifyai/utils/token_counter.py +186 -0
  51. stratifyai/vectordb.py +344 -0
  52. stratifyai-0.1.0.dist-info/METADATA +263 -0
  53. stratifyai-0.1.0.dist-info/RECORD +57 -0
  54. stratifyai-0.1.0.dist-info/WHEEL +5 -0
  55. stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
  56. stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
  57. stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/router.py ADDED
@@ -0,0 +1,643 @@
+ """Intelligent model router for selecting optimal LLM providers."""
+
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Dict, List, Optional, Tuple
+ import re
+
+ from .config import MODEL_CATALOG
+ from .models import Message
+ from .utils.file_analyzer import FileType
+
+
+ class RoutingStrategy(str, Enum):
+     """Routing strategy types."""
+     COST = "cost"
+     QUALITY = "quality"
+     LATENCY = "latency"
+     HYBRID = "hybrid"
+
+
+ @dataclass
+ class ModelMetadata:
+     """Metadata for a specific model."""
+     provider: str
+     model: str
+     quality_score: float  # 0.0 - 1.0
+     cost_per_1m_input: float  # USD per 1M tokens
+     cost_per_1m_output: float  # USD per 1M tokens
+     avg_latency_ms: float  # Average response time in milliseconds
+     context_window: int  # Maximum context tokens
+     capabilities: List[str]  # e.g., ["reasoning", "vision", "tools"]
+     reasoning_model: bool = False
+     supports_streaming: bool = True
+
+
+ class Router:
+     """Intelligent model router for selecting optimal providers."""
+
+     def __init__(
+         self,
+         strategy: RoutingStrategy = RoutingStrategy.HYBRID,
+         preferred_providers: Optional[List[str]] = None,
+         excluded_providers: Optional[List[str]] = None,
+     ):
+         """
+         Initialize router.
+
+         Args:
+             strategy: Routing strategy to use
+             preferred_providers: List of preferred providers (prioritized)
+             excluded_providers: List of providers to exclude
+         """
+         self.strategy = strategy
+         self.preferred_providers = preferred_providers or []
+         self.excluded_providers = excluded_providers or []
+         self._load_model_metadata()
+
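A minimal construction sketch (it assumes the package is importable as `stratifyai.router`, which follows from the file's location in the wheel):

    from stratifyai.router import Router, RoutingStrategy

    # Prefer Anthropic and OpenAI models; never consider local Ollama models.
    router = Router(
        strategy=RoutingStrategy.HYBRID,
        preferred_providers=["anthropic", "openai"],
        excluded_providers=["ollama"],
    )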
+     def _load_model_metadata(self) -> None:
+         """Load model metadata from MODEL_CATALOG."""
+         self.model_metadata: Dict[str, ModelMetadata] = {}
+
+         # Quality scores (estimated based on benchmarks)
+         # These are rough estimates - should be updated with real benchmark data
+         quality_scores = {
+             # OpenAI
+             "gpt-5": 0.98,
+             "o3-mini": 0.95,
+             "gpt-4.5-turbo-20250205": 0.93,
+             "gpt-4.1-turbo": 0.90,
+             "gpt-4.1-mini": 0.82,
+             "o1-mini": 0.88,
+             "o1": 0.96,
+
+             # Anthropic
+             "claude-sonnet-4-5-20250929": 0.94,
+             "claude-3-5-sonnet-20241022": 0.92,
+             "claude-3-5-haiku-20241022": 0.80,
+
+             # Google
+             "gemini-2.5-pro": 0.91,
+             "gemini-2.5-flash": 0.85,
+             "gemini-2.5-flash-lite": 0.78,
+
+             # DeepSeek
+             "deepseek-chat": 0.85,
+             "deepseek-reasoner": 0.90,
+
+             # Groq
+             "llama-3.1-70b-versatile": 0.83,
+             "llama-3.1-8b-instant": 0.75,
+             "mixtral-8x7b-32768": 0.80,
+
+             # Grok
+             "grok-beta": 0.87,
+
+             # OpenRouter (same scores as the underlying providers)
+             "anthropic/claude-3-5-sonnet": 0.92,
+             "openai/gpt-4-turbo": 0.90,
+
+             # Ollama (local models)
+             "llama3.2": 0.70,
+             "mistral": 0.68,
+             "codellama": 0.72,
+         }
+
+         # Average latency (ms) - rough estimates
+         latency_estimates = {
+             # OpenAI
+             "gpt-5": 3500,
+             "o3-mini": 8000,
+             "gpt-4.5-turbo-20250205": 2500,
+             "gpt-4.1-turbo": 2000,
+             "gpt-4.1-mini": 800,
+             "o1-mini": 5000,
+             "o1": 10000,
+
+             # Anthropic
+             "claude-sonnet-4-5-20250929": 2800,
+             "claude-3-5-sonnet-20241022": 2200,
+             "claude-3-5-haiku-20241022": 1200,
+
+             # Google
+             "gemini-2.5-pro": 2000,
+             "gemini-2.5-flash": 1000,
+             "gemini-2.5-flash-lite": 600,
+
+             # DeepSeek
+             "deepseek-chat": 1500,
+             "deepseek-reasoner": 6000,
+
+             # Groq (known for speed)
+             "llama-3.1-70b-versatile": 400,
+             "llama-3.1-8b-instant": 200,
+             "mixtral-8x7b-32768": 350,
+
+             # Grok
+             "grok-beta": 2500,
+
+             # OpenRouter
+             "anthropic/claude-3-5-sonnet": 2500,
+             "openai/gpt-4-turbo": 2200,
+
+             # Ollama (local, very fast)
+             "llama3.2": 100,
+             "mistral": 120,
+             "codellama": 110,
+         }
+
+         # Build metadata from MODEL_CATALOG
+         for provider, models in MODEL_CATALOG.items():
+             # Skip if provider is excluded
+             if provider in self.excluded_providers:
+                 continue
+
+             for model_name, model_info in models.items():
+                 key = f"{provider}/{model_name}"
+
+                 # Extract capabilities
+                 capabilities = []
+                 if model_info.get("supports_vision"):
+                     capabilities.append("vision")
+                 if model_info.get("supports_tools"):
+                     capabilities.append("tools")
+                 if model_info.get("reasoning_model"):
+                     capabilities.append("reasoning")
+
+                 self.model_metadata[key] = ModelMetadata(
+                     provider=provider,
+                     model=model_name,
+                     quality_score=quality_scores.get(model_name, 0.75),
+                     cost_per_1m_input=model_info.get("cost_input", 0.0),
+                     cost_per_1m_output=model_info.get("cost_output", 0.0),
+                     avg_latency_ms=latency_estimates.get(model_name, 2000),
+                     context_window=model_info.get("context", 8192),
+                     capabilities=capabilities,
+                     reasoning_model=model_info.get("reasoning_model", False),
+                     supports_streaming=True,
+                 )
+
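The loop above reads six keys from each catalog entry. A hypothetical entry shaped the way this method expects (field values invented for illustration; the authoritative catalog lives in stratifyai/config.py):

    MODEL_CATALOG = {
        "openai": {
            "gpt-4.1-mini": {
                "cost_input": 0.40,    # USD per 1M input tokens (illustrative)
                "cost_output": 1.60,   # USD per 1M output tokens (illustrative)
                "context": 128000,
                "supports_vision": True,
                "supports_tools": True,
                "reasoning_model": False,
            },
        },
    }

Models missing from the hardcoded tables fall back to a 0.75 quality score and 2000 ms latency.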
+     def route(
+         self,
+         messages: List[Message],
+         required_capabilities: Optional[List[str]] = None,
+         max_cost_per_1k_tokens: Optional[float] = None,
+         max_latency_ms: Optional[float] = None,
+         min_context_window: Optional[int] = None,
+     ) -> Tuple[str, str]:
+         """
+         Select the best model for the given request.
+
+         Args:
+             messages: Conversation messages
+             required_capabilities: Required model capabilities (e.g., ["vision", "tools"])
+             max_cost_per_1k_tokens: Maximum acceptable cost per 1k tokens
+             max_latency_ms: Maximum acceptable latency in milliseconds
+             min_context_window: Minimum required context window
+
+         Returns:
+             Tuple of (provider, model) for the selected model
+         """
+         # Analyze prompt complexity
+         complexity = self._analyze_complexity(messages)
+
+         # Filter models by requirements
+         candidates = self._filter_candidates(
+             required_capabilities,
+             max_cost_per_1k_tokens,
+             max_latency_ms,
+             min_context_window,
+         )
+
+         if not candidates:
+             raise ValueError(
+                 "No models meet the specified requirements. "
+                 "Consider relaxing constraints."
+             )
+
+         # Select model based on strategy
+         if self.strategy == RoutingStrategy.COST:
+             selected_key = self._select_by_cost(candidates)
+         elif self.strategy == RoutingStrategy.QUALITY:
+             selected_key = self._select_by_quality(candidates, complexity)
+         elif self.strategy == RoutingStrategy.LATENCY:
+             selected_key = self._select_by_latency(candidates)
+         elif self.strategy == RoutingStrategy.HYBRID:
+             selected_key = self._select_hybrid(candidates, complexity)
+         else:
+             selected_key = list(candidates.keys())[0]
+
+         metadata = self.model_metadata[selected_key]
+         return metadata.provider, metadata.model
+
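Routing a single request (a sketch; it assumes `Message` accepts role/content keyword arguments, consistent with how `_analyze_complexity` reads `msg.content` below):

    from stratifyai.models import Message
    from stratifyai.router import Router, RoutingStrategy

    router = Router(strategy=RoutingStrategy.QUALITY)
    messages = [Message(role="user", content="Prove step by step that sqrt(2) is irrational.")]
    provider, model = router.route(
        messages,
        required_capabilities=["reasoning"],
        min_context_window=32000,
    )
    # Likely a high-quality reasoning model such as openai/o1, depending on
    # the catalog contents and any excluded providers.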
+     def _analyze_complexity(self, messages: List[Message]) -> float:
+         """
+         Analyze prompt complexity (0.0 - 1.0).
+
+         Higher complexity scores indicate more difficult tasks requiring
+         better models.
+
+         Args:
+             messages: Conversation messages
+
+         Returns:
+             Complexity score from 0.0 (simple) to 1.0 (complex)
+         """
+         # Combine all message content
+         text = " ".join(msg.content for msg in messages)
+
+         # Initialize score
+         complexity_score = 0.0
+
+         # 1. Check for reasoning keywords (40% weight)
+         reasoning_keywords = [
+             r'\banalyze\b', r'\bexplain\b', r'\breasoning\b', r'\bproof\b',
+             r'\bstep by step\b', r'\bcomplex\b', r'\bcalculate\b',
+             r'\bderive\b', r'\bsolve\b', r'\bprove\b', r'\bthink\b',
+             r'\bcompare\b', r'\bevaluate\b', r'\bdetailed\b',
+         ]
+
+         reasoning_matches = sum(
+             1 for pattern in reasoning_keywords
+             if re.search(pattern, text.lower())
+         )
+         complexity_score += min(reasoning_matches / len(reasoning_keywords), 1.0) * 0.4
+
+         # 2. Length-based complexity (20% weight)
+         # Longer prompts often indicate more complex tasks
+         text_length = len(text)
+         length_score = min(text_length / 2000, 1.0)  # Normalize to 2000 chars
+         complexity_score += length_score * 0.2
+
+         # 3. Check for code/technical content (20% weight)
+         code_indicators = [
+             r'```', r'function\s+\w+', r'class\s+\w+', r'def\s+\w+',
+             r'import\s+\w+', r'\/\/.*', r'\/\*.*\*\/', r'\bcode\b',
+         ]
+
+         code_matches = sum(
+             1 for pattern in code_indicators
+             if re.search(pattern, text, re.IGNORECASE)
+         )
+         complexity_score += min(code_matches / len(code_indicators), 1.0) * 0.2
+
+         # 4. Multi-turn conversation (10% weight)
+         # More messages can indicate more complex context
+         message_count = len(messages)
+         conversation_score = min(message_count / 10, 1.0)
+         complexity_score += conversation_score * 0.1
+
+         # 5. Mathematical content (10% weight)
+         math_indicators = [
+             r'\d+\s*[+\-*/]\s*\d+', r'\bequation\b', r'\bformula\b',
+             r'\bcalculus\b', r'\balgebra\b', r'\bintegral\b',
+         ]
+
+         math_matches = sum(
+             1 for pattern in math_indicators
+             if re.search(pattern, text, re.IGNORECASE)
+         )
+         complexity_score += min(math_matches / len(math_indicators), 1.0) * 0.1
+
+         return min(complexity_score, 1.0)
+
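To make the weighting concrete, a worked example for the single user prompt "Analyze this code and explain the bug step by step" (50 characters):

    # reasoning keywords: "analyze", "explain", "step by step" match
    #   -> (3 / 14) * 0.4  = 0.086
    # length: 50 chars    -> (50 / 2000) * 0.2 = 0.005
    # code indicators: r'\bcode\b' matches
    #   -> (1 / 8) * 0.2   = 0.025
    # one message         -> (1 / 10) * 0.1 = 0.010
    # no math indicators  -> 0.0
    # complexity ~= 0.13: a simple task, so HYBRID routing will weight cost.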
+     def _filter_candidates(
+         self,
+         required_capabilities: Optional[List[str]],
+         max_cost_per_1k: Optional[float],
+         max_latency_ms: Optional[float],
+         min_context_window: Optional[int],
+     ) -> Dict[str, ModelMetadata]:
+         """Filter models based on requirements."""
+         candidates = {}
+
+         for key, meta in self.model_metadata.items():
+             # Check capabilities
+             if required_capabilities:
+                 if not all(cap in meta.capabilities for cap in required_capabilities):
+                     continue
+
+             # Check cost constraint ("is not None" so an explicit 0 limit is honored)
+             if max_cost_per_1k is not None:
+                 avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
+                 if avg_cost_per_1k > max_cost_per_1k:
+                     continue
+
+             # Check latency constraint
+             if max_latency_ms is not None and meta.avg_latency_ms > max_latency_ms:
+                 continue
+
+             # Check context window constraint
+             if min_context_window is not None and meta.context_window < min_context_window:
+                 continue
+
+             candidates[key] = meta
+
+         # Prioritize preferred providers
+         if self.preferred_providers:
+             prioritized = {}
+             for key, meta in candidates.items():
+                 if meta.provider in self.preferred_providers:
+                     prioritized[key] = meta
+             if prioritized:
+                 return prioritized
+
+         return candidates
+
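Note that the preferred-provider step is a hard filter, not a tie-breaker: if any preferred-provider model survives the constraints, every other provider is dropped. For example (illustration only; `_filter_candidates` is private):

    router = Router(preferred_providers=["groq"])
    candidates = router._filter_candidates(None, None, None, None)
    # Contains only groq models, assuming at least one groq entry in MODEL_CATALOG.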
+     def _select_by_cost(self, candidates: Dict[str, ModelMetadata]) -> str:
+         """Select cheapest model."""
+         def cost_score(meta: ModelMetadata) -> float:
+             # Average cost per 1M tokens (input + output weighted equally)
+             return (meta.cost_per_1m_input + meta.cost_per_1m_output) / 2
+
+         return min(candidates.keys(), key=lambda k: cost_score(candidates[k]))
+
+     def _select_by_quality(
+         self,
+         candidates: Dict[str, ModelMetadata],
+         complexity: float
+     ) -> str:
+         """
+         Select highest quality model.
+
+         For high complexity tasks, prioritize models with reasoning capabilities.
+         """
+         def quality_score(meta: ModelMetadata) -> float:
+             base_score = meta.quality_score
+
+             # Boost reasoning models for complex tasks
+             if complexity > 0.6 and meta.reasoning_model:
+                 base_score += 0.05
+
+             return base_score
+
+         return max(candidates.keys(), key=lambda k: quality_score(candidates[k]))
+
+     def _select_by_latency(self, candidates: Dict[str, ModelMetadata]) -> str:
+         """Select fastest model."""
+         return min(candidates.keys(), key=lambda k: candidates[k].avg_latency_ms)
+
+     def _select_hybrid(
+         self,
+         candidates: Dict[str, ModelMetadata],
+         complexity: float
+     ) -> str:
+         """
+         Hybrid selection balancing cost, quality, and latency.
+
+         Adjusts weights based on task complexity:
+         - Low complexity: Prioritize cost (60%) and speed (30%), quality (10%)
+         - High complexity: Prioritize quality (60%) and cost (30%), speed (10%)
+         """
+         scores = {}
+
+         # Adjust weights based on complexity
+         quality_weight = 0.1 + (complexity * 0.5)  # 0.1 to 0.6
+         cost_weight = 0.6 - (complexity * 0.3)  # 0.6 to 0.3
+         latency_weight = 0.3 - (complexity * 0.2)  # 0.3 to 0.1
+
+         for key, meta in candidates.items():
+             # Quality score (0-1, higher is better)
+             quality_score = meta.quality_score
+
+             # Cost score (0-1, lower cost is better)
+             # Normalize to typical range: $0.001 to $0.050 per 1k tokens
+             avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
+             cost_score = max(0, 1 - (avg_cost_per_1k / 0.050))
+
+             # Latency score (0-1, lower latency is better)
+             # Normalize to typical range: 100ms to 10000ms
+             latency_score = max(0, 1 - (meta.avg_latency_ms / 10000))
+
+             # Calculate weighted score
+             scores[key] = (
+                 quality_weight * quality_score +
+                 cost_weight * cost_score +
+                 latency_weight * latency_score
+             )
+
+         return max(scores, key=scores.get)
+
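The three weights always sum to 1.0 regardless of complexity, since (0.1 + 0.5c) + (0.6 - 0.3c) + (0.3 - 0.2c) = 1.0. Two sample points:

    # complexity = 0.2 (simple):  quality 0.20, cost 0.54, latency 0.26
    # complexity = 0.8 (hard):    quality 0.50, cost 0.36, latency 0.14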
+     def get_model_info(self, provider: str, model: str) -> Optional[ModelMetadata]:
+         """Get metadata for a specific model."""
+         key = f"{provider}/{model}"
+         return self.model_metadata.get(key)
+
+     def route_for_extraction(
+         self,
+         file_type: FileType,
+         extraction_mode: str = "schema",
+         max_cost_per_1k_tokens: Optional[float] = None,
+     ) -> Tuple[str, str]:
+         """
+         Select the best model for file extraction tasks.
+
+         Prioritizes quality over cost/latency for extraction tasks,
+         with task-specific optimizations.
+
+         Args:
+             file_type: Type of file being extracted
+             extraction_mode: Mode of extraction ("schema", "errors", "structure", "summary")
+             max_cost_per_1k_tokens: Optional cost constraint
+
+         Returns:
+             Tuple of (provider, model) optimized for extraction
+         """
+         # Define extraction-specific quality weights
+         if extraction_mode == "schema":
+             quality_weight = 0.90
+             cost_weight = 0.10
+         elif extraction_mode == "errors":
+             quality_weight = 0.80
+             cost_weight = 0.20
+         elif extraction_mode == "structure":
+             quality_weight = 0.85
+             cost_weight = 0.15
+         else:  # summary
+             quality_weight = 0.70
+             cost_weight = 0.30
+
+         # Filter candidates
+         candidates = {}
+         for key, meta in self.model_metadata.items():
+             # Apply cost constraint if specified
+             if max_cost_per_1k_tokens is not None:
+                 avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
+                 if avg_cost_per_1k > max_cost_per_1k_tokens:
+                     continue
+
+             candidates[key] = meta
+
+         if not candidates:
+             raise ValueError(
+                 "No models meet cost constraints. "
+                 "Consider increasing max_cost_per_1k_tokens."
+             )
+
+         # Score candidates with extraction-focused weights
+         scores = {}
+         for key, meta in candidates.items():
+             # Quality score (0-1, higher is better)
+             quality_score = meta.quality_score
+
+             # Boost for reasoning models on error extraction
+             if extraction_mode == "errors" and meta.reasoning_model:
+                 quality_score += 0.05
+
+             # Cost score (0-1, lower cost is better)
+             avg_cost_per_1k = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
+             cost_score = max(0, 1 - (avg_cost_per_1k / 0.050))
+
+             # Calculate weighted score (no latency term for extraction)
+             scores[key] = (
+                 quality_weight * quality_score +
+                 cost_weight * cost_score
+             )
+
+         # Select highest scoring model
+         best_key = max(scores, key=scores.get)
+         metadata = self.model_metadata[best_key]
+         return metadata.provider, metadata.model
+
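Note that `file_type` is accepted but does not currently influence scoring; only `extraction_mode` and the cost constraint do. A usage sketch (the `FileType.JSON` member name is an assumption; the real enum lives in stratifyai/utils/file_analyzer.py):

    from stratifyai.utils.file_analyzer import FileType

    provider, model = router.route_for_extraction(
        file_type=FileType.JSON,          # assumed member name
        extraction_mode="schema",
        max_cost_per_1k_tokens=0.01,
    )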
+     def get_fallback_chain(
+         self,
+         messages: List[Message],
+         count: int = 3,
+         required_capabilities: Optional[List[str]] = None,
+         max_cost_per_1k_tokens: Optional[float] = None,
+         max_latency_ms: Optional[float] = None,
+         min_context_window: Optional[int] = None,
+         exclude_models: Optional[List[str]] = None,
+     ) -> List[Tuple[str, str]]:
+         """
+         Get a ranked list of fallback models for resilient routing.
+
+         Returns models ranked by the current strategy, allowing automatic
+         fallback if the primary model fails.
+
+         Args:
+             messages: Conversation messages (used for complexity analysis)
+             count: Number of fallback models to return (default: 3)
+             required_capabilities: Required model capabilities
+             max_cost_per_1k_tokens: Maximum acceptable cost
+             max_latency_ms: Maximum acceptable latency
+             min_context_window: Minimum required context window
+             exclude_models: Models to exclude from results
+
+         Returns:
+             List of (provider, model) tuples ranked by strategy
+
+         Example:
+             >>> router = Router(strategy=RoutingStrategy.HYBRID)
+             >>> fallbacks = router.get_fallback_chain(messages, count=3)
+             >>> # Returns: [("openai", "gpt-4o"), ("anthropic", "claude-3-5-sonnet"), ...]
+         """
+         # Analyze complexity for scoring
+         complexity = self._analyze_complexity(messages)
+
+         # Filter candidates
+         candidates = self._filter_candidates(
+             required_capabilities,
+             max_cost_per_1k_tokens,
+             max_latency_ms,
+             min_context_window,
+         )
+
+         # Exclude specified models
+         if exclude_models:
+             candidates = {
+                 k: v for k, v in candidates.items()
+                 if v.model not in exclude_models
+             }
+
+         if not candidates:
+             return []
+
+         # Score all candidates based on strategy
+         scores = self._score_candidates(candidates, complexity)
+
+         # Sort by score (descending) and return top N
+         ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
+
+         result = []
+         for key, _ in ranked[:count]:
+             meta = self.model_metadata[key]
+             result.append((meta.provider, meta.model))
+
+         return result
+
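A typical consumer loops over the chain until one call succeeds (a sketch; `client.chat` is a stand-in for whatever completion call the application uses, not a confirmed stratifyai API):

    fallbacks = router.get_fallback_chain(messages, count=3)
    response = None
    for provider, model in fallbacks:
        try:
            response = client.chat(provider=provider, model=model, messages=messages)
            break  # success; stop trying fallbacks
        except Exception:
            continue  # try the next-ranked model
    if response is None:
        raise RuntimeError("All fallback models failed")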
+     def _score_candidates(
+         self,
+         candidates: Dict[str, ModelMetadata],
+         complexity: float,
+     ) -> Dict[str, float]:
+         """
+         Score all candidates based on current routing strategy.
+
+         Args:
+             candidates: Filtered model candidates
+             complexity: Task complexity score (0.0 - 1.0)
+
+         Returns:
+             Dictionary mapping model keys to scores
+         """
+         scores = {}
+
+         for key, meta in candidates.items():
+             if self.strategy == RoutingStrategy.COST:
+                 # Lower cost = higher score
+                 avg_cost = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 2
+                 scores[key] = max(0, 1 - (avg_cost / 100))  # Normalize to $100/1M
+
+             elif self.strategy == RoutingStrategy.QUALITY:
+                 scores[key] = meta.quality_score
+                 if complexity > 0.6 and meta.reasoning_model:
+                     scores[key] += 0.05
+
+             elif self.strategy == RoutingStrategy.LATENCY:
+                 # Lower latency = higher score
+                 scores[key] = max(0, 1 - (meta.avg_latency_ms / 10000))
+
+             else:  # HYBRID
+                 # Dynamic weights based on complexity
+                 quality_weight = 0.1 + (complexity * 0.5)
+                 cost_weight = 0.6 - (complexity * 0.3)
+                 latency_weight = 0.3 - (complexity * 0.2)
+
+                 quality_score = meta.quality_score
+                 avg_cost = (meta.cost_per_1m_input + meta.cost_per_1m_output) / 1000
+                 cost_score = max(0, 1 - (avg_cost / 0.050))
+                 latency_score = max(0, 1 - (meta.avg_latency_ms / 10000))
+
+                 scores[key] = (
+                     quality_weight * quality_score +
+                     cost_weight * cost_score +
+                     latency_weight * latency_score
+                 )
+
+         return scores
+
+     def list_models(
+         self,
+         required_capabilities: Optional[List[str]] = None,
+     ) -> List[Tuple[str, str, ModelMetadata]]:
+         """
+         List all available models with their metadata.
+
+         Args:
+             required_capabilities: Filter by required capabilities
+
+         Returns:
+             List of (provider, model, metadata) tuples
+         """
+         results = []
+
+         for key, meta in self.model_metadata.items():
+             # Check capabilities
+             if required_capabilities:
+                 if not all(cap in meta.capabilities for cap in required_capabilities):
+                     continue
+
+             results.append((meta.provider, meta.model, meta))
+
+         return results
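Enumerating the catalog as the router sees it, for example to inspect vision-capable models (uses only fields defined on ModelMetadata above):

    for provider, model, meta in router.list_models(required_capabilities=["vision"]):
        print(f"{provider}/{model}: quality={meta.quality_score:.2f}, "
              f"context={meta.context_window}, ${meta.cost_per_1m_input}/1M input")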