adaptive-memory-multi-model-router 1.2.2 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/LICENSE +21 -0
  2. package/README.md +146 -66
  3. package/dist/index.d.ts +1 -1
  4. package/dist/index.js +1 -1
  5. package/dist/integrations/airtable.js +20 -0
  6. package/dist/integrations/discord.js +18 -0
  7. package/dist/integrations/github.js +23 -0
  8. package/dist/integrations/gmail.js +19 -0
  9. package/dist/integrations/google-calendar.js +18 -0
  10. package/dist/integrations/index.js +61 -0
  11. package/dist/integrations/jira.js +21 -0
  12. package/dist/integrations/linear.js +19 -0
  13. package/dist/integrations/notion.js +19 -0
  14. package/dist/integrations/slack.js +18 -0
  15. package/dist/integrations/telegram.js +19 -0
  16. package/dist/providers/registry.js +7 -3
  17. package/docs/ARCHITECTURAL-IMPROVEMENTS-2025.md +1391 -0
  18. package/docs/ARCHITECTURAL-IMPROVEMENTS-REVISED-2025.md +1051 -0
  19. package/docs/CONFIGURATION.md +476 -0
  20. package/docs/COUNCIL_DECISION.json +308 -0
  21. package/docs/COUNCIL_SUMMARY.md +265 -0
  22. package/docs/COUNCIL_V2.2_DECISION.md +416 -0
  23. package/docs/IMPROVEMENT_ROADMAP.md +515 -0
  24. package/docs/LLM_COUNCIL_DECISION.md +508 -0
  25. package/docs/QUICK_START_VISIBILITY.md +782 -0
  26. package/docs/REDDIT_GAP_ANALYSIS.md +299 -0
  27. package/docs/RESEARCH_BACKED_IMPROVEMENTS.md +1180 -0
  28. package/docs/TMLPD_QNA.md +751 -0
  29. package/docs/TMLPD_V2.1_COMPLETE.md +763 -0
  30. package/docs/TMLPD_V2.2_RESEARCH_ROADMAP.md +754 -0
  31. package/docs/V2.2_IMPLEMENTATION_COMPLETE.md +446 -0
  32. package/docs/V2_IMPLEMENTATION_GUIDE.md +388 -0
  33. package/docs/VISIBILITY_ADOPTION_PLAN.md +1005 -0
  34. package/docs/launch-content/LAUNCH_EXECUTION_CHECKLIST.md +421 -0
  35. package/docs/launch-content/README.md +457 -0
  36. package/docs/launch-content/assets/cost_comparison_100_tasks.png +0 -0
  37. package/docs/launch-content/assets/cumulative_savings.png +0 -0
  38. package/docs/launch-content/assets/parallel_speedup.png +0 -0
  39. package/docs/launch-content/assets/provider_pricing_comparison.png +0 -0
  40. package/docs/launch-content/assets/task_breakdown_comparison.png +0 -0
  41. package/docs/launch-content/generate_charts.py +313 -0
  42. package/docs/launch-content/hn_show_post.md +139 -0
  43. package/docs/launch-content/partner_outreach_templates.md +745 -0
  44. package/docs/launch-content/reddit_posts.md +467 -0
  45. package/docs/launch-content/twitter_thread.txt +460 -0
  46. package/examples/QUICKSTART.md +1 -1
  47. package/openclaw-alexa-bridge/ALL_REMAINING_FIXES_PLAN.md +313 -0
  48. package/openclaw-alexa-bridge/REMAINING_FIXES_SUMMARY.md +277 -0
  49. package/openclaw-alexa-bridge/src/alexa_handler_no_tmlpd.js +1234 -0
  50. package/openclaw-alexa-bridge/test_fixes.js +77 -0
  51. package/package.json +120 -29
  52. package/package.json.tmp +0 -0
  53. package/qna/TMLPD_QNA.md +3 -3
  54. package/skill/SKILL.md +2 -2
  55. package/src/__tests__/integration/tmpld_integration.test.py +540 -0
  56. package/src/agents/skill_enhanced_agent.py +318 -0
  57. package/src/memory/__init__.py +15 -0
  58. package/src/memory/agentic_memory.py +353 -0
  59. package/src/memory/semantic_memory.py +444 -0
  60. package/src/memory/simple_memory.py +466 -0
  61. package/src/memory/working_memory.py +447 -0
  62. package/src/orchestration/__init__.py +52 -0
  63. package/src/orchestration/execution_engine.py +353 -0
  64. package/src/orchestration/halo_orchestrator.py +367 -0
  65. package/src/orchestration/mcts_workflow.py +498 -0
  66. package/src/orchestration/role_assigner.py +473 -0
  67. package/src/orchestration/task_planner.py +522 -0
  68. package/src/providers/__init__.py +67 -0
  69. package/src/providers/anthropic.py +304 -0
  70. package/src/providers/base.py +241 -0
  71. package/src/providers/cerebras.py +373 -0
  72. package/src/providers/registry.py +476 -0
  73. package/src/routing/__init__.py +30 -0
  74. package/src/routing/universal_router.py +621 -0
  75. package/src/skills/TMLPD-QUICKREF.md +210 -0
  76. package/src/skills/TMLPD-SETUP-SUMMARY.md +157 -0
  77. package/src/skills/TMLPD.md +540 -0
  78. package/src/skills/__tests__/skill_manager.test.ts +328 -0
  79. package/src/skills/skill_manager.py +385 -0
  80. package/src/skills/test-tmlpd.sh +108 -0
  81. package/src/skills/tmlpd-category.yaml +67 -0
  82. package/src/skills/tmlpd-monitoring.yaml +188 -0
  83. package/src/skills/tmlpd-phase.yaml +132 -0
  84. package/src/state/__init__.py +17 -0
  85. package/src/state/simple_checkpoint.py +508 -0
  86. package/src/tmlpd_agent.py +464 -0
  87. package/src/tmpld_v2.py +427 -0
  88. package/src/workflows/__init__.py +18 -0
  89. package/src/workflows/advanced_difficulty_classifier.py +377 -0
  90. package/src/workflows/chaining_executor.py +417 -0
  91. package/src/workflows/difficulty_integration.py +209 -0
  92. package/src/workflows/orchestrator.py +469 -0
  93. package/src/workflows/orchestrator_executor.py +456 -0
  94. package/src/workflows/parallelization_executor.py +382 -0
  95. package/src/workflows/router.py +311 -0
  96. package/test_integration_simple.py +86 -0
  97. package/test_mcts_workflow.py +150 -0
  98. package/test_templd_integration.py +262 -0
  99. package/test_universal_router.py +275 -0
  100. package/tmlpd-pi-extension/README.md +36 -0
  101. package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts +114 -0
  102. package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts.map +1 -0
  103. package/tmlpd-pi-extension/dist/cache/prefixCache.js +285 -0
  104. package/tmlpd-pi-extension/dist/cache/prefixCache.js.map +1 -0
  105. package/tmlpd-pi-extension/dist/cache/responseCache.d.ts +58 -0
  106. package/tmlpd-pi-extension/dist/cache/responseCache.d.ts.map +1 -0
  107. package/tmlpd-pi-extension/dist/cache/responseCache.js +153 -0
  108. package/tmlpd-pi-extension/dist/cache/responseCache.js.map +1 -0
  109. package/tmlpd-pi-extension/dist/cli.js +59 -0
  110. package/tmlpd-pi-extension/dist/cost/costTracker.d.ts +95 -0
  111. package/tmlpd-pi-extension/dist/cost/costTracker.d.ts.map +1 -0
  112. package/tmlpd-pi-extension/dist/cost/costTracker.js +240 -0
  113. package/tmlpd-pi-extension/dist/cost/costTracker.js.map +1 -0
  114. package/tmlpd-pi-extension/dist/index.d.ts +723 -0
  115. package/tmlpd-pi-extension/dist/index.d.ts.map +1 -0
  116. package/tmlpd-pi-extension/dist/index.js +239 -0
  117. package/tmlpd-pi-extension/dist/index.js.map +1 -0
  118. package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts +82 -0
  119. package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts.map +1 -0
  120. package/tmlpd-pi-extension/dist/memory/episodicMemory.js +145 -0
  121. package/tmlpd-pi-extension/dist/memory/episodicMemory.js.map +1 -0
  122. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts +102 -0
  123. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts.map +1 -0
  124. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js +207 -0
  125. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js.map +1 -0
  126. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts +85 -0
  127. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts.map +1 -0
  128. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js +210 -0
  129. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js.map +1 -0
  130. package/tmlpd-pi-extension/dist/providers/localProvider.d.ts +102 -0
  131. package/tmlpd-pi-extension/dist/providers/localProvider.d.ts.map +1 -0
  132. package/tmlpd-pi-extension/dist/providers/localProvider.js +338 -0
  133. package/tmlpd-pi-extension/dist/providers/localProvider.js.map +1 -0
  134. package/tmlpd-pi-extension/dist/providers/registry.d.ts +55 -0
  135. package/tmlpd-pi-extension/dist/providers/registry.d.ts.map +1 -0
  136. package/tmlpd-pi-extension/dist/providers/registry.js +138 -0
  137. package/tmlpd-pi-extension/dist/providers/registry.js.map +1 -0
  138. package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts +68 -0
  139. package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts.map +1 -0
  140. package/tmlpd-pi-extension/dist/routing/advancedRouter.js +332 -0
  141. package/tmlpd-pi-extension/dist/routing/advancedRouter.js.map +1 -0
  142. package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts +101 -0
  143. package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts.map +1 -0
  144. package/tmlpd-pi-extension/dist/tools/tmlpdTools.js +368 -0
  145. package/tmlpd-pi-extension/dist/tools/tmlpdTools.js.map +1 -0
  146. package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts +96 -0
  147. package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts.map +1 -0
  148. package/tmlpd-pi-extension/dist/utils/batchProcessor.js +170 -0
  149. package/tmlpd-pi-extension/dist/utils/batchProcessor.js.map +1 -0
  150. package/tmlpd-pi-extension/dist/utils/compression.d.ts +61 -0
  151. package/tmlpd-pi-extension/dist/utils/compression.d.ts.map +1 -0
  152. package/tmlpd-pi-extension/dist/utils/compression.js +281 -0
  153. package/tmlpd-pi-extension/dist/utils/compression.js.map +1 -0
  154. package/tmlpd-pi-extension/dist/utils/reliability.d.ts +74 -0
  155. package/tmlpd-pi-extension/dist/utils/reliability.d.ts.map +1 -0
  156. package/tmlpd-pi-extension/dist/utils/reliability.js +177 -0
  157. package/tmlpd-pi-extension/dist/utils/reliability.js.map +1 -0
  158. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts +117 -0
  159. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts.map +1 -0
  160. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js +246 -0
  161. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js.map +1 -0
  162. package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts +50 -0
  163. package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts.map +1 -0
  164. package/tmlpd-pi-extension/dist/utils/tokenUtils.js +124 -0
  165. package/tmlpd-pi-extension/dist/utils/tokenUtils.js.map +1 -0
  166. package/tmlpd-pi-extension/examples/QUICKSTART.md +183 -0
  167. package/tmlpd-pi-extension/package-lock.json +75 -0
  168. package/tmlpd-pi-extension/package.json +172 -0
  169. package/tmlpd-pi-extension/python/examples.py +53 -0
  170. package/tmlpd-pi-extension/python/integrations.py +330 -0
  171. package/tmlpd-pi-extension/python/setup.py +28 -0
  172. package/tmlpd-pi-extension/python/tmlpd.py +369 -0
  173. package/tmlpd-pi-extension/qna/REDDIT_GAP_ANALYSIS.md +299 -0
  174. package/tmlpd-pi-extension/qna/TMLPD_QNA.md +751 -0
  175. package/tmlpd-pi-extension/skill/SKILL.md +238 -0
  176. package/{src → tmlpd-pi-extension/src}/index.ts +1 -1
  177. package/tmlpd-pi-extension/tsconfig.json +18 -0
  178. package/demo/research-demo.js +0 -266
  179. package/notebooks/quickstart.ipynb +0 -157
  180. package/rust/tmlpd.h +0 -268
  181. package/src/cache/prefixCache.ts +0 -365
  182. package/src/routing/advancedRouter.ts +0 -406
  183. package/src/utils/speculativeDecoding.ts +0 -344
  184. /package/{src → tmlpd-pi-extension/src}/cache/responseCache.ts +0 -0
  185. /package/{src → tmlpd-pi-extension/src}/cost/costTracker.ts +0 -0
  186. /package/{src → tmlpd-pi-extension/src}/memory/episodicMemory.ts +0 -0
  187. /package/{src → tmlpd-pi-extension/src}/orchestration/haloOrchestrator.ts +0 -0
  188. /package/{src → tmlpd-pi-extension/src}/orchestration/mctsWorkflow.ts +0 -0
  189. /package/{src → tmlpd-pi-extension/src}/providers/localProvider.ts +0 -0
  190. /package/{src → tmlpd-pi-extension/src}/providers/registry.ts +0 -0
  191. /package/{src → tmlpd-pi-extension/src}/tools/tmlpdTools.ts +0 -0
  192. /package/{src → tmlpd-pi-extension/src}/utils/batchProcessor.ts +0 -0
  193. /package/{src → tmlpd-pi-extension/src}/utils/compression.ts +0 -0
  194. /package/{src → tmlpd-pi-extension/src}/utils/reliability.ts +0 -0
  195. /package/{src → tmlpd-pi-extension/src}/utils/tokenUtils.ts +0 -0
@@ -0,0 +1,621 @@
1
+ """
2
+ UniversalModelRouter - Learned Routing with Online Adaptation
3
+
4
+ Based on:
5
+ - arXiv:2502.08773 (UniRoute - Universal Routing)
6
+ - ICLR 2024 (Hybrid LLM - 40% fewer calls to expensive models)
7
+ - ICML 2025 (BEST-Route - 60% cost reduction with <1% quality drop)
8
+
9
+ Key Innovation: Learns model profiles from execution data and adapts
10
+ to new unseen models automatically.
11
+
12
+ Features:
13
+ - Learns feature vectors for each model from execution history
14
+ - Routes based on learned quality profiles, not static rules
15
+ - Online learning: updates profiles from actual outcomes
16
+ - Adapts to new unseen models via clustering + similarity
17
+ - Dynamic quality-cost tradeoff at runtime
18
+ """
19
+
20
+ import asyncio
21
+ import numpy as np
22
+ from typing import Dict, List, Any, Optional, Tuple
23
+ from dataclasses import dataclass, field
24
+ from datetime import datetime
25
+ from collections import defaultdict
26
+ import json
27
+ import logging
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
@dataclass
class ModelProfile:
    """Learned profile of an LLM model.

    Combines static metadata (provider, pricing, latency) with quality
    statistics that are updated online from execution outcomes.
    """
    model_id: str
    provider: str
    cost_per_1k_tokens: float
    latency_ms: float

    # Learned metrics
    avg_quality_score: float = 0.5  # 0-1
    quality_variance: float = 0.1
    total_executions: int = 0
    successful_executions: int = 0

    # Quality by difficulty
    quality_by_difficulty: Dict[str, float] = field(default_factory=dict)
    # "trivial": 0.95, "simple": 0.90, "medium": 0.85, etc.

    # Feature vector (for unseen models)
    feature_vector: Optional[List[float]] = None

    # Execution history for online learning (sliding window;
    # intentionally NOT serialized by to_dict)
    recent_outcomes: List[float] = field(default_factory=list)

    def to_dict(self) -> Dict:
        """Serialize the persistent fields (omits recent_outcomes)."""
        return {
            "model_id": self.model_id,
            "provider": self.provider,
            "cost_per_1k_tokens": self.cost_per_1k_tokens,
            "latency_ms": self.latency_ms,
            "avg_quality_score": self.avg_quality_score,
            "quality_variance": self.quality_variance,
            "total_executions": self.total_executions,
            "successful_executions": self.successful_executions,
            "quality_by_difficulty": self.quality_by_difficulty,
            "feature_vector": self.feature_vector
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "ModelProfile":
        """Create ModelProfile from dict (for loading from storage).

        Keys that are not dataclass fields are ignored, so profiles
        written by a newer version of this class (with extra fields)
        still load cleanly instead of raising TypeError.
        """
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in data.items() if k in known})
75
+
76
+
77
@dataclass
class RoutingDecision:
    """Result of routing decision.

    Produced by UniversalModelRouter.route(); immutable snapshot of the
    choice and the data that justified it.
    """
    selected_model: str  # model_id of the chosen model
    reasoning: str  # human-readable explanation of the choice
    predicted_quality: float  # predicted quality score for this task (0-1)
    estimated_cost: float  # estimated cost in USD for this task
    alternative_models: List[str] = field(default_factory=list)  # next-best model_ids
    confidence: float = 0.0  # router confidence; set from predicted quality
86
+
87
+
88
class UniversalModelRouter:
    """
    Universal Learned Router that adapts to new models

    Key innovations:
    1. Learns model quality profiles from execution data
    2. Routes based on learned profiles (not static rules)
    3. Adapts to new unseen models via provider heuristics
    4. Online learning from feedback (exponential moving average)
    5. Dynamic quality-cost tradeoff at routing time
    """

    def __init__(
        self,
        quality_target: float = 0.95,
        cost_weight: float = 0.5,
        learning_rate: float = 0.1
    ):
        """
        Initialize universal router

        Args:
            quality_target: Minimum quality threshold (0-1)
            cost_weight: Balance between quality and cost (0-1)
            learning_rate: EMA weight given to each new outcome when
                updating a model's average quality (0-1)
        """
        self.quality_target = quality_target
        self.cost_weight = cost_weight
        self.learning_rate = learning_rate

        # Model registry: model_id -> learned ModelProfile
        self.model_profiles: Dict[str, ModelProfile] = {}
        self._initialize_default_profiles()

        # Routing history, kept for stats and offline analysis
        self.routing_history: List[Dict] = []

    def _initialize_default_profiles(self):
        """Seed the registry with profiles for well-known models."""
        default_models = [
            {
                "model_id": "anthropic/claude-3-5-sonnet-20241022",
                "provider": "anthropic",
                "cost_per_1k_tokens": 0.003,
                "latency_ms": 800,
                "avg_quality_score": 0.98,
                "quality_variance": 0.02
            },
            {
                "model_id": "openai/gpt-4o",
                "provider": "openai",
                "cost_per_1k_tokens": 0.0025,
                "latency_ms": 600,
                "avg_quality_score": 0.95,
                "quality_variance": 0.03
            },
            {
                "model_id": "cerebras/llama-3.3-70b",
                "provider": "cerebras",
                "cost_per_1k_tokens": 0.0001,
                "latency_ms": 200,
                "avg_quality_score": 0.75,
                "quality_variance": 0.10
            },
            {
                "model_id": "groq/llama-3.3-70b",
                "provider": "groq",
                "cost_per_1k_tokens": 0.0003,
                "latency_ms": 250,
                "avg_quality_score": 0.78,
                "quality_variance": 0.08
            },
            {
                "model_id": "together/mixtral-8x7b",
                "provider": "together",
                "cost_per_1k_tokens": 0.0009,
                "latency_ms": 400,
                "avg_quality_score": 0.82,
                "quality_variance": 0.07
            }
        ]

        for model_data in default_models:
            profile = ModelProfile(**model_data)
            self.model_profiles[profile.model_id] = profile

    async def route(
        self,
        task: Dict[str, Any],
        available_models: Optional[List[str]] = None,
        quality_threshold: Optional[float] = None,
        budget_cap_cents: Optional[float] = None
    ) -> RoutingDecision:
        """
        Route task to optimal model using learned profiles

        Args:
            task: Task to route with 'description' and context
            available_models: List of available models (if None, use all)
            quality_threshold: Minimum quality (overrides default)
            budget_cap_cents: Maximum cost in cents

        Returns:
            RoutingDecision with selected model and reasoning

        Raises:
            ValueError: if no models are available or every candidate
                exceeds budget_cap_cents.
        """
        # Use defaults if not specified
        if available_models is None:
            available_models = list(self.model_profiles.keys())
        if quality_threshold is None:
            quality_threshold = self.quality_target

        # Extract task features once; reused for every candidate model
        task_features = self._extract_task_features(task)

        # Score each available model
        model_scores = {}
        for model_id in available_models:
            # Get or create profile (unseen models get an inferred one)
            if model_id not in self.model_profiles:
                profile = await self._infer_profile(model_id)
                self.model_profiles[model_id] = profile
            else:
                profile = self.model_profiles[model_id]

            # Predict quality for this task
            quality_score = self._predict_quality(task_features, profile)

            # Calculate combined quality-cost score
            combined_score = self._calculate_combined_score(
                quality_score,
                profile.cost_per_1k_tokens,
                quality_threshold,
                self.cost_weight
            )

            # Check budget cap: disqualify over-budget candidates
            if budget_cap_cents is not None:
                estimated_cost = self._estimate_cost(task, profile)
                if estimated_cost * 100 > budget_cap_cents:  # USD -> cents
                    combined_score = -float('inf')  # Over budget

            model_scores[model_id] = {
                "quality": quality_score,
                "cost": profile.cost_per_1k_tokens,
                "combined": combined_score,
                "profile": profile
            }

        # Select best model
        if not model_scores:
            raise ValueError("No models available or all over budget")

        best_model_id = max(model_scores.items(), key=lambda x: x[1]["combined"])[0]
        best_score = model_scores[best_model_id]

        # If even the best candidate was disqualified by the budget cap,
        # fail loudly instead of silently returning an unaffordable model.
        if best_score["combined"] == -float('inf'):
            raise ValueError("No models available or all over budget")

        # Create decision
        decision = RoutingDecision(
            selected_model=best_model_id,
            reasoning=self._generate_reasoning(best_score, task_features),
            predicted_quality=best_score["quality"],
            estimated_cost=self._estimate_cost(task, best_score["profile"]),
            alternative_models=sorted(
                model_scores.keys(),
                key=lambda m: model_scores[m]["combined"],
                reverse=True
            )[1:4],  # Top 3 alternatives
            confidence=best_score["quality"]
        )

        # Log routing decision for learning
        self._log_routing_decision(task, best_model_id, model_scores)

        return decision

    def _extract_task_features(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract feature vector from task for learned routing

        Returns a dict of scalar features plus a nested "domain" dict of
        per-domain keyword hit ratios (hence the Dict[str, Any] return).
        """
        description = task.get("description", "")
        context = task.get("context", {})
        desc_lower = description.lower()  # hoisted: reused by every keyword scan

        features = {}

        # Feature 1: Length (rough token-count proxy, normalized)
        features["length"] = len(description.split()) / 100.0

        # Feature 2: Technical complexity (fraction of keywords present)
        technical_keywords = ["api", "database", "algorithm", "code", "implement"]
        features["technical"] = sum(
            1 for kw in technical_keywords if kw in desc_lower
        ) / len(technical_keywords)

        # Feature 3: Domain specificity
        domain_keywords = {
            "web": ["web", "frontend", "html", "css", "javascript"],
            "data": ["data", "sql", "query", "analytics", "database"],
            "ml": ["machine learning", "model", "training", "inference"],
            "general": ["explain", "describe", "what", "how"]
        }
        features["domain"] = {}
        for domain, keywords in domain_keywords.items():
            features["domain"][domain] = sum(
                1 for kw in keywords if kw in desc_lower
            ) / len(keywords)

        # Feature 4: Requirement constraints (normalized count)
        features["constraints"] = len(context.get("requirements", [])) / 10.0

        # Feature 5: Complexity estimate from sequencing/integration words
        complexity_indicators = ["then", "after", "before", "integrate", "using"]
        features["complexity"] = sum(
            1 for word in complexity_indicators if word in desc_lower
        ) / len(complexity_indicators)

        return features

    def _predict_quality(
        self,
        task_features: Dict[str, float],
        profile: ModelProfile
    ) -> float:
        """
        Predict quality score for model on this task

        Uses learned quality-by-difficulty if available,
        otherwise falls back to average quality
        """
        # Estimate task difficulty from the complexity feature
        if task_features["complexity"] < 0.3:
            difficulty = "trivial"
        elif task_features["complexity"] < 0.6:
            difficulty = "simple"
        elif task_features["complexity"] < 0.8:
            difficulty = "medium"
        else:
            difficulty = "complex"

        # Use learned quality by difficulty if available
        if profile.quality_by_difficulty and difficulty in profile.quality_by_difficulty:
            return profile.quality_by_difficulty[difficulty]
        else:
            # Fallback to average quality with adjustment
            base_quality = profile.avg_quality_score

            # Adjust for difficulty
            if difficulty == "trivial" and base_quality > 0.9:
                # High-quality model might be overkill for trivial task
                return base_quality * 0.95
            elif difficulty == "complex" and base_quality < 0.8:
                # Lower-quality model might struggle
                return base_quality * 0.85

            return base_quality

    def _calculate_combined_score(
        self,
        quality_score: float,
        cost_per_1k: float,
        quality_threshold: float,
        cost_weight: float
    ) -> float:
        """
        Calculate combined quality-cost score (higher is better)

        Formula: quality is halved when it falls below the threshold,
        then score = (quality_adjusted ** quality_threshold) / (cost + eps),
        finally scaled by a factor derived from cost_weight.
        """
        # Quality penalty if below threshold
        quality_adjusted = quality_score if quality_score >= quality_threshold else quality_score * 0.5

        # Cost penalty (lower cost is better); eps avoids division by zero
        cost_adjusted = cost_per_1k + 0.0001

        # Combined score: maximize quality, minimize cost
        combined = (quality_adjusted ** quality_threshold) / cost_adjusted

        # Apply cost weight
        if cost_weight < 0.5:
            # Prioritize quality
            combined = combined * (1 + (1 - cost_weight))
        else:
            # Prioritize cost
            combined = combined * (1 + cost_weight)

        return combined

    def _estimate_cost(self, task: Dict[str, Any], profile: ModelProfile) -> float:
        """Estimate cost in USD for this task."""
        # Rough token estimate: ~2 tokens per character, 500-token floor
        description_length = len(task.get("description", ""))
        estimated_tokens = max(500, description_length * 2)

        # Estimate cost from the per-1K-token price
        estimated_cost = (estimated_tokens / 1000.0) * profile.cost_per_1k_tokens

        return estimated_cost

    async def _infer_profile(self, model_id: str) -> ModelProfile:
        """
        Infer profile for unseen model

        Strategy: assign heuristic defaults based on the provider prefix.
        (async to match the router's interface; no awaits today)
        """
        # Parse model_id to get provider and base model
        parts = model_id.split("/")
        provider = parts[0] if len(parts) > 1 else "unknown"
        model_name = parts[1] if len(parts) > 1 else model_id

        # Use heuristic defaults based on provider
        if provider == "anthropic":
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.003,
                latency_ms=800,
                avg_quality_score=0.95,
                quality_variance=0.05
            )
        elif provider == "openai":
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.002,
                latency_ms=600,
                avg_quality_score=0.92,
                quality_variance=0.06
            )
        elif provider in ["cerebras", "groq", "together"]:
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.0005,
                latency_ms=300,
                avg_quality_score=0.75,
                quality_variance=0.10
            )
        else:
            # Unknown provider - use conservative defaults
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.001,
                latency_ms=500,
                avg_quality_score=0.80,
                quality_variance=0.10
            )

    def _generate_reasoning(
        self,
        score_data: Dict,
        task_features: Dict[str, float]
    ) -> str:
        """Generate human-readable reasoning for routing decision."""
        profile = score_data["profile"]
        quality = score_data["quality"]

        return (
            f"Selected {profile.model_id} (quality: {quality:.2f}, "
            f"cost: ${profile.cost_per_1k_tokens:.4f}/1K tokens). "
            f"Task features: complexity={task_features.get('complexity', 0):.2f}, "
            f"technical={task_features.get('technical', 0):.2f}"
        )

    def _log_routing_decision(
        self,
        task: Dict[str, Any],
        selected_model: str,
        all_scores: Dict[str, Dict]
    ):
        """Log routing decision for online learning / later analysis."""
        self.routing_history.append({
            "timestamp": datetime.now().isoformat(),
            "task_description": task.get("description", "")[:100],
            "selected_model": selected_model,
            # keep only scalar scores; profiles are not JSON-friendly
            "all_scores": {
                model_id: scores["combined"]
                for model_id, scores in all_scores.items()
            }
        })

    async def learn_from_feedback(
        self,
        outcomes: List[Dict[str, Any]]
    ):
        """
        Online learning: Update model profiles based on actual outcomes

        Args:
            outcomes: List of dicts with:
                - model: str (model_id used)
                - task: dict (original task)
                - actual_quality: float (0-1, from user feedback or auto-eval)
                - success: bool (whether execution succeeded)
                - cost_usd: float (actual cost)
        """
        for outcome in outcomes:
            model_id = outcome["model"]
            actual_quality = outcome["actual_quality"]

            # Get profile (inferring one for unseen models)
            if model_id not in self.model_profiles:
                profile = await self._infer_profile(model_id)
                self.model_profiles[model_id] = profile
            else:
                profile = self.model_profiles[model_id]

            # Update execution counters
            profile.total_executions += 1
            if outcome.get("success", True):
                profile.successful_executions += 1

            # Add to recent outcomes (sliding window of 100)
            profile.recent_outcomes.append(actual_quality)
            if len(profile.recent_outcomes) > 100:
                profile.recent_outcomes.pop(0)

            # Update average quality (exponential moving average)
            old_quality = profile.avg_quality_score
            new_quality = (1 - self.learning_rate) * old_quality + self.learning_rate * actual_quality

            profile.avg_quality_score = new_quality

            # Update quality variance once enough samples exist.
            # Cast np.var's numpy.float64 to a plain float so the profile
            # stays JSON-serializable in save_profiles().
            if len(profile.recent_outcomes) > 10:
                profile.quality_variance = float(np.var(profile.recent_outcomes))

            logger.info(
                f"Updated profile for {model_id}: "
                f"quality {old_quality:.3f} → {new_quality:.3f} "
                f"(total executions: {profile.total_executions})"
            )

    def get_routing_stats(self) -> Dict[str, Any]:
        """Get routing statistics.

        Always returns the same keys, even before any routing has happened.
        """
        model_usage = defaultdict(int)
        for route in self.routing_history:
            model_usage[route["selected_model"]] += 1

        return {
            "total_routes": len(self.routing_history),
            "model_usage": dict(model_usage),
            "num_models": len(self.model_profiles),
            "models_with_profiles": sum(
                1 for p in self.model_profiles.values()
                if p.total_executions > 0
            )
        }

    def save_profiles(self, filepath: str):
        """Save model profiles to a JSON file."""
        data = {
            model_id: profile.to_dict()
            for model_id, profile in self.model_profiles.items()
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    def load_profiles(self, filepath: str):
        """Load model profiles from a JSON file, replacing the registry."""
        with open(filepath, 'r') as f:
            data = json.load(f)

        self.model_profiles = {
            model_id: ModelProfile.from_dict(profile_data)
            for model_id, profile_data in data.items()
        }
562
+
563
+
564
+ # Example usage
565
# Example usage
async def main():
    """Example of UniversalModelRouter usage"""
    demo_router = UniversalModelRouter(
        quality_target=0.90,
        cost_weight=0.5
    )

    # Start with a trivial question to show cheap-model routing
    easy_task = {
        "description": "What is 2+2?",
        "context": {}
    }

    easy_decision = await demo_router.route(easy_task)

    print("Routing Decision:")
    print(f"  Selected: {easy_decision.selected_model}")
    print(f"  Reasoning: {easy_decision.reasoning}")
    print(f"  Predicted Quality: {easy_decision.predicted_quality:.2f}")
    print(f"  Estimated Cost: ${easy_decision.estimated_cost:.6f}")
    print(f"  Confidence: {easy_decision.confidence:.2f}")

    # A harder, multi-requirement task should route differently
    hard_task = {
        "description": "Build a REST API with user authentication and database integration",
        "context": {"requirements": ["JWT", "PostgreSQL"]}
    }

    hard_decision = await demo_router.route(hard_task)

    print("\nRouting Decision (Complex):")
    print(f"  Selected: {hard_decision.selected_model}")
    print(f"  Reasoning: {hard_decision.reasoning}")
    print(f"  Predicted Quality: {hard_decision.predicted_quality:.2f}")
    print(f"  Estimated Cost: ${hard_decision.estimated_cost:.6f}")

    # Feed one observed outcome back into the router's online learner
    feedback = [
        {
            "model": "anthropic/claude-3-5-sonnet-20241022",
            "task": easy_task,
            "actual_quality": 0.95,
            "success": True,
            "cost_usd": 0.002
        }
    ]
    await demo_router.learn_from_feedback(feedback)

    # Summarize what the router did
    stats = demo_router.get_routing_stats()
    print("\nRouter Stats:")
    print(f"  Total routes: {stats['total_routes']}")
    print(f"  Models available: {stats['num_models']}")
    print(f"  Model usage: {stats['model_usage']}")


if __name__ == "__main__":
    asyncio.run(main())