adaptive-memory-multi-model-router 1.2.2 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +146 -66
- package/dist/index.d.ts +1 -1
- package/dist/index.js +1 -1
- package/dist/integrations/airtable.js +20 -0
- package/dist/integrations/discord.js +18 -0
- package/dist/integrations/github.js +23 -0
- package/dist/integrations/gmail.js +19 -0
- package/dist/integrations/google-calendar.js +18 -0
- package/dist/integrations/index.js +61 -0
- package/dist/integrations/jira.js +21 -0
- package/dist/integrations/linear.js +19 -0
- package/dist/integrations/notion.js +19 -0
- package/dist/integrations/slack.js +18 -0
- package/dist/integrations/telegram.js +19 -0
- package/dist/providers/registry.js +7 -3
- package/docs/ARCHITECTURAL-IMPROVEMENTS-2025.md +1391 -0
- package/docs/ARCHITECTURAL-IMPROVEMENTS-REVISED-2025.md +1051 -0
- package/docs/CONFIGURATION.md +476 -0
- package/docs/COUNCIL_DECISION.json +308 -0
- package/docs/COUNCIL_SUMMARY.md +265 -0
- package/docs/COUNCIL_V2.2_DECISION.md +416 -0
- package/docs/IMPROVEMENT_ROADMAP.md +515 -0
- package/docs/LLM_COUNCIL_DECISION.md +508 -0
- package/docs/QUICK_START_VISIBILITY.md +782 -0
- package/docs/REDDIT_GAP_ANALYSIS.md +299 -0
- package/docs/RESEARCH_BACKED_IMPROVEMENTS.md +1180 -0
- package/docs/TMLPD_QNA.md +751 -0
- package/docs/TMLPD_V2.1_COMPLETE.md +763 -0
- package/docs/TMLPD_V2.2_RESEARCH_ROADMAP.md +754 -0
- package/docs/V2.2_IMPLEMENTATION_COMPLETE.md +446 -0
- package/docs/V2_IMPLEMENTATION_GUIDE.md +388 -0
- package/docs/VISIBILITY_ADOPTION_PLAN.md +1005 -0
- package/docs/launch-content/LAUNCH_EXECUTION_CHECKLIST.md +421 -0
- package/docs/launch-content/README.md +457 -0
- package/docs/launch-content/assets/cost_comparison_100_tasks.png +0 -0
- package/docs/launch-content/assets/cumulative_savings.png +0 -0
- package/docs/launch-content/assets/parallel_speedup.png +0 -0
- package/docs/launch-content/assets/provider_pricing_comparison.png +0 -0
- package/docs/launch-content/assets/task_breakdown_comparison.png +0 -0
- package/docs/launch-content/generate_charts.py +313 -0
- package/docs/launch-content/hn_show_post.md +139 -0
- package/docs/launch-content/partner_outreach_templates.md +745 -0
- package/docs/launch-content/reddit_posts.md +467 -0
- package/docs/launch-content/twitter_thread.txt +460 -0
- package/examples/QUICKSTART.md +1 -1
- package/openclaw-alexa-bridge/ALL_REMAINING_FIXES_PLAN.md +313 -0
- package/openclaw-alexa-bridge/REMAINING_FIXES_SUMMARY.md +277 -0
- package/openclaw-alexa-bridge/src/alexa_handler_no_tmlpd.js +1234 -0
- package/openclaw-alexa-bridge/test_fixes.js +77 -0
- package/package.json +120 -29
- package/package.json.tmp +0 -0
- package/qna/TMLPD_QNA.md +3 -3
- package/skill/SKILL.md +2 -2
- package/src/__tests__/integration/tmpld_integration.test.py +540 -0
- package/src/agents/skill_enhanced_agent.py +318 -0
- package/src/memory/__init__.py +15 -0
- package/src/memory/agentic_memory.py +353 -0
- package/src/memory/semantic_memory.py +444 -0
- package/src/memory/simple_memory.py +466 -0
- package/src/memory/working_memory.py +447 -0
- package/src/orchestration/__init__.py +52 -0
- package/src/orchestration/execution_engine.py +353 -0
- package/src/orchestration/halo_orchestrator.py +367 -0
- package/src/orchestration/mcts_workflow.py +498 -0
- package/src/orchestration/role_assigner.py +473 -0
- package/src/orchestration/task_planner.py +522 -0
- package/src/providers/__init__.py +67 -0
- package/src/providers/anthropic.py +304 -0
- package/src/providers/base.py +241 -0
- package/src/providers/cerebras.py +373 -0
- package/src/providers/registry.py +476 -0
- package/src/routing/__init__.py +30 -0
- package/src/routing/universal_router.py +621 -0
- package/src/skills/TMLPD-QUICKREF.md +210 -0
- package/src/skills/TMLPD-SETUP-SUMMARY.md +157 -0
- package/src/skills/TMLPD.md +540 -0
- package/src/skills/__tests__/skill_manager.test.ts +328 -0
- package/src/skills/skill_manager.py +385 -0
- package/src/skills/test-tmlpd.sh +108 -0
- package/src/skills/tmlpd-category.yaml +67 -0
- package/src/skills/tmlpd-monitoring.yaml +188 -0
- package/src/skills/tmlpd-phase.yaml +132 -0
- package/src/state/__init__.py +17 -0
- package/src/state/simple_checkpoint.py +508 -0
- package/src/tmlpd_agent.py +464 -0
- package/src/tmpld_v2.py +427 -0
- package/src/workflows/__init__.py +18 -0
- package/src/workflows/advanced_difficulty_classifier.py +377 -0
- package/src/workflows/chaining_executor.py +417 -0
- package/src/workflows/difficulty_integration.py +209 -0
- package/src/workflows/orchestrator.py +469 -0
- package/src/workflows/orchestrator_executor.py +456 -0
- package/src/workflows/parallelization_executor.py +382 -0
- package/src/workflows/router.py +311 -0
- package/test_integration_simple.py +86 -0
- package/test_mcts_workflow.py +150 -0
- package/test_templd_integration.py +262 -0
- package/test_universal_router.py +275 -0
- package/tmlpd-pi-extension/README.md +36 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts +114 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.js +285 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.js.map +1 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.d.ts +58 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.js +153 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.js.map +1 -0
- package/tmlpd-pi-extension/dist/cli.js +59 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.d.ts +95 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.js +240 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.js.map +1 -0
- package/tmlpd-pi-extension/dist/index.d.ts +723 -0
- package/tmlpd-pi-extension/dist/index.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/index.js +239 -0
- package/tmlpd-pi-extension/dist/index.js.map +1 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts +82 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.js +145 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.js.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts +102 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js +207 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts +85 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js +210 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js.map +1 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.d.ts +102 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.js +338 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.js.map +1 -0
- package/tmlpd-pi-extension/dist/providers/registry.d.ts +55 -0
- package/tmlpd-pi-extension/dist/providers/registry.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/providers/registry.js +138 -0
- package/tmlpd-pi-extension/dist/providers/registry.js.map +1 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts +68 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.js +332 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.js.map +1 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts +101 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.js +368 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts +96 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.js +170 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/compression.d.ts +61 -0
- package/tmlpd-pi-extension/dist/utils/compression.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/compression.js +281 -0
- package/tmlpd-pi-extension/dist/utils/compression.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/reliability.d.ts +74 -0
- package/tmlpd-pi-extension/dist/utils/reliability.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/reliability.js +177 -0
- package/tmlpd-pi-extension/dist/utils/reliability.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts +117 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js +246 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts +50 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.js +124 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.js.map +1 -0
- package/tmlpd-pi-extension/examples/QUICKSTART.md +183 -0
- package/tmlpd-pi-extension/package-lock.json +75 -0
- package/tmlpd-pi-extension/package.json +172 -0
- package/tmlpd-pi-extension/python/examples.py +53 -0
- package/tmlpd-pi-extension/python/integrations.py +330 -0
- package/tmlpd-pi-extension/python/setup.py +28 -0
- package/tmlpd-pi-extension/python/tmlpd.py +369 -0
- package/tmlpd-pi-extension/qna/REDDIT_GAP_ANALYSIS.md +299 -0
- package/tmlpd-pi-extension/qna/TMLPD_QNA.md +751 -0
- package/tmlpd-pi-extension/skill/SKILL.md +238 -0
- package/{src → tmlpd-pi-extension/src}/index.ts +1 -1
- package/tmlpd-pi-extension/tsconfig.json +18 -0
- package/demo/research-demo.js +0 -266
- package/notebooks/quickstart.ipynb +0 -157
- package/rust/tmlpd.h +0 -268
- package/src/cache/prefixCache.ts +0 -365
- package/src/routing/advancedRouter.ts +0 -406
- package/src/utils/speculativeDecoding.ts +0 -344
- /package/{src → tmlpd-pi-extension/src}/cache/responseCache.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/cost/costTracker.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/memory/episodicMemory.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/orchestration/haloOrchestrator.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/orchestration/mctsWorkflow.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/providers/localProvider.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/providers/registry.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/tools/tmlpdTools.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/batchProcessor.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/compression.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/reliability.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/tokenUtils.ts +0 -0
|
@@ -0,0 +1,621 @@
|
|
|
1
|
+
"""
|
|
2
|
+
UniversalModelRouter - Learned Routing with Online Adaptation
|
|
3
|
+
|
|
4
|
+
Based on:
|
|
5
|
+
- arXiv:2502.08773 (UniRoute - Universal Routing)
|
|
6
|
+
- ICLR 2024 (Hybrid LLM - 40% fewer calls to expensive models)
|
|
7
|
+
- ICML 2025 (BEST-Route - 60% cost reduction with <1% quality drop)
|
|
8
|
+
|
|
9
|
+
Key Innovation: Learns model profiles from execution data and adapts
|
|
10
|
+
to new unseen models automatically.
|
|
11
|
+
|
|
12
|
+
Features:
|
|
13
|
+
- Learns feature vectors for each model from execution history
|
|
14
|
+
- Routes based on learned quality profiles, not static rules
|
|
15
|
+
- Online learning: updates profiles from actual outcomes
|
|
16
|
+
- Adapts to new unseen models via clustering + similarity
|
|
17
|
+
- Dynamic quality-cost tradeoff at runtime
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import asyncio
import json
import logging
from collections import defaultdict
from dataclasses import dataclass, field, fields
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class ModelProfile:
    """Learned profile of an LLM model.

    Combines static provider characteristics (cost, latency) with quality
    metrics that are updated online from observed execution outcomes.
    """
    model_id: str               # e.g. "anthropic/claude-3-5-sonnet-20241022"
    provider: str               # provider prefix of model_id
    cost_per_1k_tokens: float   # USD per 1K tokens
    latency_ms: float           # typical response latency in milliseconds

    # Learned metrics
    avg_quality_score: float = 0.5  # running quality estimate, 0-1
    quality_variance: float = 0.1   # variance of recent quality outcomes
    total_executions: int = 0
    successful_executions: int = 0

    # Learned quality keyed by difficulty label, e.g.
    # {"trivial": 0.95, "simple": 0.90, "medium": 0.85, ...}
    quality_by_difficulty: Dict[str, float] = field(default_factory=dict)

    # Feature vector used to relate unseen models to known ones
    feature_vector: Optional[List[float]] = None

    # Sliding window of recent quality scores for online learning
    # (intentionally not serialized by to_dict)
    recent_outcomes: List[float] = field(default_factory=list)

    def to_dict(self) -> Dict:
        """Serialize the profile to a JSON-compatible dict.

        recent_outcomes is excluded on purpose: only aggregated metrics
        are persisted.
        """
        return {
            "model_id": self.model_id,
            "provider": self.provider,
            "cost_per_1k_tokens": self.cost_per_1k_tokens,
            "latency_ms": self.latency_ms,
            "avg_quality_score": self.avg_quality_score,
            "quality_variance": self.quality_variance,
            "total_executions": self.total_executions,
            "successful_executions": self.successful_executions,
            "quality_by_difficulty": self.quality_by_difficulty,
            "feature_vector": self.feature_vector
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "ModelProfile":
        """Create ModelProfile from dict (for loading from storage).

        Unknown keys are ignored so profiles written by newer versions of
        this module (or dicts carrying extra metadata) load cleanly instead
        of raising TypeError from the dataclass constructor.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})


@dataclass
class RoutingDecision:
    """Result of a routing decision."""
    selected_model: str       # model_id chosen for the task
    reasoning: str            # human-readable explanation of the choice
    predicted_quality: float  # expected quality, 0-1
    estimated_cost: float     # estimated task cost in USD
    alternative_models: List[str] = field(default_factory=list)  # runners-up
    confidence: float = 0.0   # confidence in the prediction, 0-1


class UniversalModelRouter:
    """
    Universal Learned Router that adapts to new models

    Key innovations:
    1. Learns model quality profiles from execution data
    2. Routes based on learned profiles (not static rules)
    3. Adapts to new unseen models via clustering
    4. Online learning from feedback
    5. Dynamic quality-cost tradeoff
    """

    def __init__(
        self,
        quality_target: float = 0.95,
        cost_weight: float = 0.5,
        learning_rate: float = 0.1
    ):
        """
        Initialize universal router

        Args:
            quality_target: Minimum quality threshold (0-1)
            cost_weight: Balance between quality and cost (0-1)
            learning_rate: EMA coefficient used when updating profiles
                from feedback (0-1); higher means faster adaptation
        """
        self.quality_target = quality_target
        self.cost_weight = cost_weight
        self.learning_rate = learning_rate

        # Model registry: model_id -> learned profile
        self.model_profiles: Dict[str, ModelProfile] = {}
        self._initialize_default_profiles()

        # Routing history for learning and statistics
        self.routing_history: List[Dict] = []

    def _initialize_default_profiles(self):
        """Seed the registry with profiles for well-known models."""
        default_models = [
            {
                "model_id": "anthropic/claude-3-5-sonnet-20241022",
                "provider": "anthropic",
                "cost_per_1k_tokens": 0.003,
                "latency_ms": 800,
                "avg_quality_score": 0.98,
                "quality_variance": 0.02
            },
            {
                "model_id": "openai/gpt-4o",
                "provider": "openai",
                "cost_per_1k_tokens": 0.0025,
                "latency_ms": 600,
                "avg_quality_score": 0.95,
                "quality_variance": 0.03
            },
            {
                "model_id": "cerebras/llama-3.3-70b",
                "provider": "cerebras",
                "cost_per_1k_tokens": 0.0001,
                "latency_ms": 200,
                "avg_quality_score": 0.75,
                "quality_variance": 0.10
            },
            {
                "model_id": "groq/llama-3.3-70b",
                "provider": "groq",
                "cost_per_1k_tokens": 0.0003,
                "latency_ms": 250,
                "avg_quality_score": 0.78,
                "quality_variance": 0.08
            },
            {
                "model_id": "together/mixtral-8x7b",
                "provider": "together",
                "cost_per_1k_tokens": 0.0009,
                "latency_ms": 400,
                "avg_quality_score": 0.82,
                "quality_variance": 0.07
            }
        ]

        for model_data in default_models:
            profile = ModelProfile(**model_data)
            self.model_profiles[profile.model_id] = profile

    async def route(
        self,
        task: Dict[str, Any],
        available_models: Optional[List[str]] = None,
        quality_threshold: Optional[float] = None,
        budget_cap_cents: Optional[float] = None
    ) -> RoutingDecision:
        """
        Route task to optimal model using learned profiles

        Args:
            task: Task to route with 'description' and context
            available_models: List of available models (if None, use all)
            quality_threshold: Minimum quality (overrides default)
            budget_cap_cents: Maximum cost in cents

        Returns:
            RoutingDecision with selected model and reasoning

        Raises:
            ValueError: if no models are available, or every candidate
                exceeds budget_cap_cents
        """
        # Use defaults if not specified
        if available_models is None:
            available_models = list(self.model_profiles.keys())
        if quality_threshold is None:
            quality_threshold = self.quality_target

        # Extract task features
        task_features = self._extract_task_features(task)

        # Score each available model
        model_scores = {}
        for model_id in available_models:
            # Get or lazily infer a profile for unseen models
            if model_id not in self.model_profiles:
                profile = await self._infer_profile(model_id)
                self.model_profiles[model_id] = profile
            else:
                profile = self.model_profiles[model_id]

            # Predict quality for this task
            quality_score = self._predict_quality(task_features, profile)

            # Calculate combined quality-cost score
            combined_score = self._calculate_combined_score(
                quality_score,
                profile.cost_per_1k_tokens,
                quality_threshold,
                self.cost_weight
            )

            # Knock out models that would exceed the budget cap
            if budget_cap_cents is not None:
                estimated_cost = self._estimate_cost(task, profile)
                if estimated_cost * 100 > budget_cap_cents:
                    combined_score = -float('inf')  # Over budget

            model_scores[model_id] = {
                "quality": quality_score,
                "cost": profile.cost_per_1k_tokens,
                "combined": combined_score,
                "profile": profile
            }

        # Select best model
        if not model_scores:
            raise ValueError("No models available or all over budget")

        best_model_id = max(model_scores.items(), key=lambda x: x[1]["combined"])[0]
        best_score = model_scores[best_model_id]

        # Over-budget models remain in model_scores with a -inf combined
        # score, so the emptiness check above cannot catch the case where
        # every candidate is over budget — reject it explicitly here.
        if best_score["combined"] == -float('inf'):
            raise ValueError("No models available or all over budget")

        # Create decision
        decision = RoutingDecision(
            selected_model=best_model_id,
            reasoning=self._generate_reasoning(best_score, task_features),
            predicted_quality=best_score["quality"],
            estimated_cost=self._estimate_cost(task, best_score["profile"]),
            alternative_models=sorted(
                model_scores.keys(),
                key=lambda m: model_scores[m]["combined"],
                reverse=True
            )[1:4],  # Top 3 alternatives
            confidence=best_score["quality"]
        )

        # Log routing decision for learning
        self._log_routing_decision(task, best_model_id, model_scores)

        return decision

    def _extract_task_features(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract feature vector from task for learned routing

        Uses keyword heuristics that correlate with model performance.
        Returns Dict[str, Any] (not Dict[str, float]) because the
        "domain" entry is itself a dict of per-domain scores.
        """
        description = task.get("description", "")
        context = task.get("context", {})

        features = {}

        # Feature 1: Length (rough token count estimate, normalized)
        features["length"] = len(description.split()) / 100.0

        # Feature 2: Technical complexity — fraction of keywords present
        technical_keywords = ["api", "database", "algorithm", "code", "implement"]
        features["technical"] = sum(
            1 for kw in technical_keywords if kw in description.lower()
        ) / len(technical_keywords)

        # Feature 3: Domain specificity — per-domain keyword hit rates
        domain_keywords = {
            "web": ["web", "frontend", "html", "css", "javascript"],
            "data": ["data", "sql", "query", "analytics", "database"],
            "ml": ["machine learning", "model", "training", "inference"],
            "general": ["explain", "describe", "what", "how"]
        }
        features["domain"] = {}
        for domain, keywords in domain_keywords.items():
            features["domain"][domain] = sum(
                1 for kw in keywords if kw in description.lower()
            ) / len(keywords)

        # Feature 4: Requirement constraints (normalized count)
        features["constraints"] = len(context.get("requirements", [])) / 10.0

        # Feature 5: Complexity estimate from sequencing/integration words
        complexity_indicators = ["then", "after", "before", "integrate", "using"]
        features["complexity"] = sum(
            1 for word in complexity_indicators if word in description.lower()
        ) / len(complexity_indicators)

        return features

    def _predict_quality(
        self,
        task_features: Dict[str, Any],
        profile: ModelProfile
    ) -> float:
        """
        Predict quality score for model on this task

        Uses learned quality-by-difficulty if available,
        otherwise falls back to average quality with a heuristic
        adjustment for very easy / very hard tasks.
        """
        # Estimate task difficulty from the complexity feature
        if task_features["complexity"] < 0.3:
            difficulty = "trivial"
        elif task_features["complexity"] < 0.6:
            difficulty = "simple"
        elif task_features["complexity"] < 0.8:
            difficulty = "medium"
        else:
            difficulty = "complex"

        # Use learned quality by difficulty if available
        if profile.quality_by_difficulty and difficulty in profile.quality_by_difficulty:
            return profile.quality_by_difficulty[difficulty]

        # Fallback to average quality with adjustment
        base_quality = profile.avg_quality_score

        if difficulty == "trivial" and base_quality > 0.9:
            # High-quality model might be overkill for trivial task
            return base_quality * 0.95
        elif difficulty == "complex" and base_quality < 0.8:
            # Lower-quality model might struggle
            return base_quality * 0.85

        return base_quality

    def _calculate_combined_score(
        self,
        quality_score: float,
        cost_per_1k: float,
        quality_threshold: float,
        cost_weight: float
    ) -> float:
        """
        Calculate combined quality-cost score (higher is better)

        Quality below the threshold is halved, then the score is
        (quality ** quality_threshold) / (cost + epsilon), finally
        rescaled by a factor derived from cost_weight.
        """
        # Quality penalty if below threshold
        quality_adjusted = quality_score if quality_score >= quality_threshold else quality_score * 0.5

        # Small epsilon avoids division by zero for free models
        cost_adjusted = cost_per_1k + 0.0001

        # Combined score: maximize quality, minimize cost
        combined = (quality_adjusted ** quality_threshold) / cost_adjusted

        # cost_weight < 0.5 boosts quality-leaning picks, otherwise boosts
        # cost-leaning ones; both branches only rescale the score
        if cost_weight < 0.5:
            combined = combined * (1 + (1 - cost_weight))
        else:
            combined = combined * (1 + cost_weight)

        return combined

    def _estimate_cost(self, task: Dict[str, Any], profile: ModelProfile) -> float:
        """Estimate cost in USD for this task.

        Uses a crude chars*2 token heuristic with a 500-token floor —
        TODO confirm against actual tokenizer counts.
        """
        description_length = len(task.get("description", ""))
        estimated_tokens = max(500, description_length * 2)

        return (estimated_tokens / 1000.0) * profile.cost_per_1k_tokens

    async def _infer_profile(self, model_id: str) -> ModelProfile:
        """
        Infer profile for unseen model

        Strategy: Cluster with similar models based on known provider
        patterns; unknown providers get conservative defaults.
        """
        # Parse model_id to get the provider prefix
        parts = model_id.split("/")
        provider = parts[0] if len(parts) > 1 else "unknown"

        # Use heuristic defaults based on provider
        if provider == "anthropic":
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.003,
                latency_ms=800,
                avg_quality_score=0.95,
                quality_variance=0.05
            )
        elif provider == "openai":
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.002,
                latency_ms=600,
                avg_quality_score=0.92,
                quality_variance=0.06
            )
        elif provider in ["cerebras", "groq", "together"]:
            # Fast/cheap inference providers share one default cluster
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.0005,
                latency_ms=300,
                avg_quality_score=0.75,
                quality_variance=0.10
            )
        else:
            # Unknown provider - use conservative defaults
            return ModelProfile(
                model_id=model_id,
                provider=provider,
                cost_per_1k_tokens=0.001,
                latency_ms=500,
                avg_quality_score=0.80,
                quality_variance=0.10
            )

    def _generate_reasoning(
        self,
        score_data: Dict,
        task_features: Dict[str, Any]
    ) -> str:
        """Generate human-readable reasoning for routing decision"""
        profile = score_data["profile"]
        quality = score_data["quality"]

        return (
            f"Selected {profile.model_id} (quality: {quality:.2f}, "
            f"cost: ${profile.cost_per_1k_tokens:.4f}/1K tokens). "
            f"Task features: complexity={task_features.get('complexity', 0):.2f}, "
            f"technical={task_features.get('technical', 0):.2f}"
        )

    def _log_routing_decision(
        self,
        task: Dict[str, Any],
        selected_model: str,
        all_scores: Dict[str, Dict]
    ):
        """Log routing decision for online learning"""
        self.routing_history.append({
            "timestamp": datetime.now().isoformat(),
            "task_description": task.get("description", "")[:100],
            "selected_model": selected_model,
            "all_scores": {
                model_id: scores["combined"]
                for model_id, scores in all_scores.items()
            }
        })

    async def learn_from_feedback(
        self,
        outcomes: List[Dict[str, Any]]
    ):
        """
        Online learning: Update model profiles based on actual outcomes

        Args:
            outcomes: List of dicts with:
                - model: str (model_id used)
                - task: dict (original task)
                - actual_quality: float (0-1, from user feedback or auto-eval)
                - success: bool (whether execution succeeded)
                - cost_usd: float (actual cost)
        """
        for outcome in outcomes:
            model_id = outcome["model"]
            actual_quality = outcome["actual_quality"]

            # Get (or infer) the profile for this model
            if model_id not in self.model_profiles:
                profile = await self._infer_profile(model_id)
                self.model_profiles[model_id] = profile
            else:
                profile = self.model_profiles[model_id]

            # Update execution counters
            profile.total_executions += 1
            if outcome.get("success", True):
                profile.successful_executions += 1

            # Add to recent outcomes (sliding window of 100)
            profile.recent_outcomes.append(actual_quality)
            if len(profile.recent_outcomes) > 100:
                profile.recent_outcomes.pop(0)

            # Update average quality (exponential moving average)
            old_quality = profile.avg_quality_score
            new_quality = (1 - self.learning_rate) * old_quality + self.learning_rate * actual_quality
            profile.avg_quality_score = new_quality

            # Update quality variance once we have enough samples;
            # float() keeps the profile JSON-serializable regardless of
            # the numpy scalar type np.var returns
            if len(profile.recent_outcomes) > 10:
                profile.quality_variance = float(np.var(profile.recent_outcomes))

            logging.getLogger(__name__).info(
                "Updated profile for %s: quality %.3f → %.3f (total executions: %d)",
                model_id, old_quality, new_quality, profile.total_executions
            )

    def get_routing_stats(self) -> Dict[str, Any]:
        """Get routing statistics: totals, per-model usage, profile counts."""
        if not self.routing_history:
            return {"total_routes": 0}

        model_usage = defaultdict(int)
        for route in self.routing_history:
            model_usage[route["selected_model"]] += 1

        return {
            "total_routes": len(self.routing_history),
            "model_usage": dict(model_usage),
            "num_models": len(self.model_profiles),
            "models_with_profiles": sum(
                1 for p in self.model_profiles.values()
                if p.total_executions > 0
            )
        }

    def save_profiles(self, filepath: str):
        """Save model profiles to a JSON file"""
        data = {
            model_id: profile.to_dict()
            for model_id, profile in self.model_profiles.items()
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)

    def load_profiles(self, filepath: str):
        """Load model profiles from a JSON file (replaces the registry)"""
        with open(filepath, 'r') as f:
            data = json.load(f)

        self.model_profiles = {
            model_id: ModelProfile.from_dict(profile_data)
            for model_id, profile_data in data.items()
        }
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
# Example usage
async def main():
    """Walk through routing, feedback learning, and stats reporting."""
    router = UniversalModelRouter(quality_target=0.90, cost_weight=0.5)

    # Route a trivial arithmetic question
    simple_task = {"description": "What is 2+2?", "context": {}}
    outcome = await router.route(simple_task)

    print("Routing Decision:")
    print(f"  Selected: {outcome.selected_model}")
    print(f"  Reasoning: {outcome.reasoning}")
    print(f"  Predicted Quality: {outcome.predicted_quality:.2f}")
    print(f"  Estimated Cost: ${outcome.estimated_cost:.6f}")
    print(f"  Confidence: {outcome.confidence:.2f}")

    # Route a multi-requirement engineering task
    complex_task = {
        "description": "Build a REST API with user authentication and database integration",
        "context": {"requirements": ["JWT", "PostgreSQL"]},
    }
    outcome = await router.route(complex_task)

    print("\nRouting Decision (Complex):")
    print(f"  Selected: {outcome.selected_model}")
    print(f"  Reasoning: {outcome.reasoning}")
    print(f"  Predicted Quality: {outcome.predicted_quality:.2f}")
    print(f"  Estimated Cost: ${outcome.estimated_cost:.6f}")

    # Feed one observed outcome back into the router's online learner
    feedback = [
        {
            "model": "anthropic/claude-3-5-sonnet-20241022",
            "task": simple_task,
            "actual_quality": 0.95,
            "success": True,
            "cost_usd": 0.002,
        }
    ]
    await router.learn_from_feedback(feedback)

    # Summarize routing activity
    stats = router.get_routing_stats()
    print("\nRouter Stats:")
    print(f"  Total routes: {stats['total_routes']}")
    print(f"  Models available: {stats['num_models']}")
    print(f"  Model usage: {stats['model_usage']}")


if __name__ == "__main__":
    asyncio.run(main())
|