adaptive-memory-multi-model-router 1.2.2 → 1.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +146 -66
- package/dist/index.d.ts +1 -1
- package/dist/index.js +1 -1
- package/dist/integrations/airtable.js +20 -0
- package/dist/integrations/discord.js +18 -0
- package/dist/integrations/github.js +23 -0
- package/dist/integrations/gmail.js +19 -0
- package/dist/integrations/google-calendar.js +18 -0
- package/dist/integrations/index.js +61 -0
- package/dist/integrations/jira.js +21 -0
- package/dist/integrations/linear.js +19 -0
- package/dist/integrations/notion.js +19 -0
- package/dist/integrations/slack.js +18 -0
- package/dist/integrations/telegram.js +19 -0
- package/dist/providers/registry.js +7 -3
- package/docs/ARCHITECTURAL-IMPROVEMENTS-2025.md +1391 -0
- package/docs/ARCHITECTURAL-IMPROVEMENTS-REVISED-2025.md +1051 -0
- package/docs/CONFIGURATION.md +476 -0
- package/docs/COUNCIL_DECISION.json +308 -0
- package/docs/COUNCIL_SUMMARY.md +265 -0
- package/docs/COUNCIL_V2.2_DECISION.md +416 -0
- package/docs/IMPROVEMENT_ROADMAP.md +515 -0
- package/docs/LLM_COUNCIL_DECISION.md +508 -0
- package/docs/QUICK_START_VISIBILITY.md +782 -0
- package/docs/REDDIT_GAP_ANALYSIS.md +299 -0
- package/docs/RESEARCH_BACKED_IMPROVEMENTS.md +1180 -0
- package/docs/TMLPD_QNA.md +751 -0
- package/docs/TMLPD_V2.1_COMPLETE.md +763 -0
- package/docs/TMLPD_V2.2_RESEARCH_ROADMAP.md +754 -0
- package/docs/V2.2_IMPLEMENTATION_COMPLETE.md +446 -0
- package/docs/V2_IMPLEMENTATION_GUIDE.md +388 -0
- package/docs/VISIBILITY_ADOPTION_PLAN.md +1005 -0
- package/docs/launch-content/LAUNCH_EXECUTION_CHECKLIST.md +421 -0
- package/docs/launch-content/README.md +457 -0
- package/docs/launch-content/assets/cost_comparison_100_tasks.png +0 -0
- package/docs/launch-content/assets/cumulative_savings.png +0 -0
- package/docs/launch-content/assets/parallel_speedup.png +0 -0
- package/docs/launch-content/assets/provider_pricing_comparison.png +0 -0
- package/docs/launch-content/assets/task_breakdown_comparison.png +0 -0
- package/docs/launch-content/generate_charts.py +313 -0
- package/docs/launch-content/hn_show_post.md +139 -0
- package/docs/launch-content/partner_outreach_templates.md +745 -0
- package/docs/launch-content/reddit_posts.md +467 -0
- package/docs/launch-content/twitter_thread.txt +460 -0
- package/examples/QUICKSTART.md +1 -1
- package/openclaw-alexa-bridge/ALL_REMAINING_FIXES_PLAN.md +313 -0
- package/openclaw-alexa-bridge/REMAINING_FIXES_SUMMARY.md +277 -0
- package/openclaw-alexa-bridge/src/alexa_handler_no_tmlpd.js +1234 -0
- package/openclaw-alexa-bridge/test_fixes.js +77 -0
- package/package.json +120 -29
- package/package.json.tmp +0 -0
- package/qna/TMLPD_QNA.md +3 -3
- package/skill/SKILL.md +2 -2
- package/src/__tests__/integration/tmpld_integration.test.py +540 -0
- package/src/agents/skill_enhanced_agent.py +318 -0
- package/src/memory/__init__.py +15 -0
- package/src/memory/agentic_memory.py +353 -0
- package/src/memory/semantic_memory.py +444 -0
- package/src/memory/simple_memory.py +466 -0
- package/src/memory/working_memory.py +447 -0
- package/src/orchestration/__init__.py +52 -0
- package/src/orchestration/execution_engine.py +353 -0
- package/src/orchestration/halo_orchestrator.py +367 -0
- package/src/orchestration/mcts_workflow.py +498 -0
- package/src/orchestration/role_assigner.py +473 -0
- package/src/orchestration/task_planner.py +522 -0
- package/src/providers/__init__.py +67 -0
- package/src/providers/anthropic.py +304 -0
- package/src/providers/base.py +241 -0
- package/src/providers/cerebras.py +373 -0
- package/src/providers/registry.py +476 -0
- package/src/routing/__init__.py +30 -0
- package/src/routing/universal_router.py +621 -0
- package/src/skills/TMLPD-QUICKREF.md +210 -0
- package/src/skills/TMLPD-SETUP-SUMMARY.md +157 -0
- package/src/skills/TMLPD.md +540 -0
- package/src/skills/__tests__/skill_manager.test.ts +328 -0
- package/src/skills/skill_manager.py +385 -0
- package/src/skills/test-tmlpd.sh +108 -0
- package/src/skills/tmlpd-category.yaml +67 -0
- package/src/skills/tmlpd-monitoring.yaml +188 -0
- package/src/skills/tmlpd-phase.yaml +132 -0
- package/src/state/__init__.py +17 -0
- package/src/state/simple_checkpoint.py +508 -0
- package/src/tmlpd_agent.py +464 -0
- package/src/tmpld_v2.py +427 -0
- package/src/workflows/__init__.py +18 -0
- package/src/workflows/advanced_difficulty_classifier.py +377 -0
- package/src/workflows/chaining_executor.py +417 -0
- package/src/workflows/difficulty_integration.py +209 -0
- package/src/workflows/orchestrator.py +469 -0
- package/src/workflows/orchestrator_executor.py +456 -0
- package/src/workflows/parallelization_executor.py +382 -0
- package/src/workflows/router.py +311 -0
- package/test_integration_simple.py +86 -0
- package/test_mcts_workflow.py +150 -0
- package/test_templd_integration.py +262 -0
- package/test_universal_router.py +275 -0
- package/tmlpd-pi-extension/README.md +36 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts +114 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.js +285 -0
- package/tmlpd-pi-extension/dist/cache/prefixCache.js.map +1 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.d.ts +58 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.js +153 -0
- package/tmlpd-pi-extension/dist/cache/responseCache.js.map +1 -0
- package/tmlpd-pi-extension/dist/cli.js +59 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.d.ts +95 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.js +240 -0
- package/tmlpd-pi-extension/dist/cost/costTracker.js.map +1 -0
- package/tmlpd-pi-extension/dist/index.d.ts +723 -0
- package/tmlpd-pi-extension/dist/index.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/index.js +239 -0
- package/tmlpd-pi-extension/dist/index.js.map +1 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts +82 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.js +145 -0
- package/tmlpd-pi-extension/dist/memory/episodicMemory.js.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts +102 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js +207 -0
- package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts +85 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js +210 -0
- package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js.map +1 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.d.ts +102 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.js +338 -0
- package/tmlpd-pi-extension/dist/providers/localProvider.js.map +1 -0
- package/tmlpd-pi-extension/dist/providers/registry.d.ts +55 -0
- package/tmlpd-pi-extension/dist/providers/registry.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/providers/registry.js +138 -0
- package/tmlpd-pi-extension/dist/providers/registry.js.map +1 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts +68 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.js +332 -0
- package/tmlpd-pi-extension/dist/routing/advancedRouter.js.map +1 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts +101 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.js +368 -0
- package/tmlpd-pi-extension/dist/tools/tmlpdTools.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts +96 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.js +170 -0
- package/tmlpd-pi-extension/dist/utils/batchProcessor.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/compression.d.ts +61 -0
- package/tmlpd-pi-extension/dist/utils/compression.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/compression.js +281 -0
- package/tmlpd-pi-extension/dist/utils/compression.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/reliability.d.ts +74 -0
- package/tmlpd-pi-extension/dist/utils/reliability.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/reliability.js +177 -0
- package/tmlpd-pi-extension/dist/utils/reliability.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts +117 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js +246 -0
- package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js.map +1 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts +50 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts.map +1 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.js +124 -0
- package/tmlpd-pi-extension/dist/utils/tokenUtils.js.map +1 -0
- package/tmlpd-pi-extension/examples/QUICKSTART.md +183 -0
- package/tmlpd-pi-extension/package-lock.json +75 -0
- package/tmlpd-pi-extension/package.json +172 -0
- package/tmlpd-pi-extension/python/examples.py +53 -0
- package/tmlpd-pi-extension/python/integrations.py +330 -0
- package/tmlpd-pi-extension/python/setup.py +28 -0
- package/tmlpd-pi-extension/python/tmlpd.py +369 -0
- package/tmlpd-pi-extension/qna/REDDIT_GAP_ANALYSIS.md +299 -0
- package/tmlpd-pi-extension/qna/TMLPD_QNA.md +751 -0
- package/tmlpd-pi-extension/skill/SKILL.md +238 -0
- package/{src → tmlpd-pi-extension/src}/index.ts +1 -1
- package/tmlpd-pi-extension/tsconfig.json +18 -0
- package/demo/research-demo.js +0 -266
- package/notebooks/quickstart.ipynb +0 -157
- package/rust/tmlpd.h +0 -268
- package/src/cache/prefixCache.ts +0 -365
- package/src/routing/advancedRouter.ts +0 -406
- package/src/utils/speculativeDecoding.ts +0 -344
- /package/{src → tmlpd-pi-extension/src}/cache/responseCache.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/cost/costTracker.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/memory/episodicMemory.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/orchestration/haloOrchestrator.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/orchestration/mctsWorkflow.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/providers/localProvider.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/providers/registry.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/tools/tmlpdTools.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/batchProcessor.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/compression.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/reliability.ts +0 -0
- /package/{src → tmlpd-pi-extension/src}/utils/tokenUtils.ts +0 -0
|
@@ -0,0 +1,498 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCTS Workflow Search - Optimize HALO orchestration strategies
|
|
3
|
+
|
|
4
|
+
Based on:
|
|
5
|
+
- arXiv:2505.13516 (HALO - MCTS-based workflow search)
|
|
6
|
+
- NeurIPS 2023 (MCTS for LLM agent planning)
|
|
7
|
+
- ICLR 2024 (Tree of Thoughts + MCTS)
|
|
8
|
+
|
|
9
|
+
Key Innovation: Uses Monte Carlo Tree Search to explore different
|
|
10
|
+
execution strategies and learn the best workflow for each task type.
|
|
11
|
+
|
|
12
|
+
Features:
|
|
13
|
+
- Explores different agent assignment strategies
|
|
14
|
+
- Learns from execution outcomes
|
|
15
|
+
- Balances exploration vs exploitation
|
|
16
|
+
- Adapts workflow to task complexity
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import asyncio
|
|
20
|
+
import random
|
|
21
|
+
import math
|
|
22
|
+
from typing import Dict, List, Any, Optional, Tuple
|
|
23
|
+
from dataclasses import dataclass, field
|
|
24
|
+
from datetime import datetime
|
|
25
|
+
from collections import defaultdict
|
|
26
|
+
import logging
|
|
27
|
+
|
|
28
|
+
from .task_planner import SubTask, TaskDecomposition
|
|
29
|
+
from .role_assigner import RoleAssigner, AgentAssignment, AgentRole
|
|
30
|
+
from .execution_engine import ExecutionEngine, ExecutionSummary
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class WorkflowNode:
    """Node in MCTS search tree representing a workflow state.

    A node captures a partial workflow: which subtasks are already assigned
    (and to which model), which remain, plus the MCTS visit/reward statistics
    used by UCB1 selection.
    """
    state_id: str  # Unique identifier for this state
    subtasks_remaining: List[SubTask]  # subtasks not yet assigned a model
    completed_subtasks: List[str]  # ids of subtasks already assigned along this path
    current_strategy: Dict[str, str]  # subtask_id -> model_id assignments
    results_so_far: Dict[str, Any] = field(default_factory=dict)  # quality/cost/time of executed work, if any

    # MCTS statistics
    visits: int = 0  # number of times this node was visited during backpropagation
    total_reward: float = 0.0  # cumulative reward; divided by visits for the mean
    parent: Optional["WorkflowNode"] = None  # None only for the root
    children: List["WorkflowNode"] = field(default_factory=list)

    @property
    def average_reward(self) -> float:
        """Average reward per visit"""
        # Guard against division by zero for freshly created (unvisited) nodes.
        return self.total_reward / self.visits if self.visits > 0 else 0.0

    @property
    def is_fully_expanded(self) -> bool:
        """Whether all possible actions have been tried"""
        # NOTE(review): this is an approximation — it is True as soon as ANY
        # child exists, not when every candidate action has been tried. The
        # node does not track untried actions, so the selection phase treats
        # "has a child" as "fully expanded". Confirm this matches the intended
        # exploration behavior before relying on it.
        return len(self.children) > 0

    @property
    def is_terminal(self) -> bool:
        """Whether this is a terminal state (all subtasks complete)"""
        return len(self.subtasks_remaining) == 0
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class WorkflowStrategy:
    """A complete workflow plan: one model assignment per subtask.

    Holds the predicted quality/cost of the plan and, once the plan has
    actually been run, the observed quality, cost, and wall-clock time.
    """
    subtask_assignments: Dict[str, str]  # subtask_id -> model_id
    expected_quality: float
    expected_cost: float
    execution_order: List[str]

    # Filled in only after the strategy has been executed; None until then.
    actual_quality: Optional[float] = None
    actual_cost: Optional[float] = None
    execution_time_seconds: Optional[float] = None

    def to_dict(self) -> Dict:
        """Serialize every field of the strategy into a plain dict."""
        keys = (
            "subtask_assignments",
            "expected_quality",
            "expected_cost",
            "execution_order",
            "actual_quality",
            "actual_cost",
            "execution_time_seconds",
        )
        return {key: getattr(self, key) for key in keys}
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class MCTSWorkflowSearch:
    """
    Monte Carlo Tree Search for workflow optimization.

    Explores different agent assignment strategies to find optimal workflow.

    Key Innovation:
    - Treats workflow optimization as a search problem
    - Uses MCTS to balance exploration/exploitation
    - Learns from execution outcomes
    - Adapts to task complexity
    """

    def __init__(
        self,
        role_assigner: RoleAssigner,
        execution_engine: ExecutionEngine,
        exploration_weight: float = 1.414,  # sqrt(2) for UCB1
        max_simulations: int = 100,
        max_depth: int = 10
    ):
        """
        Initialize MCTS workflow search.

        Args:
            role_assigner: RoleAssigner for getting candidate agents
            execution_engine: ExecutionEngine for simulating workflows
            exploration_weight: UCB1 exploration parameter (higher = more exploration)
            max_simulations: Maximum MCTS simulations per search
            max_depth: Maximum search depth
                (NOTE(review): not currently consulted by the search loop)
        """
        self.role_assigner = role_assigner
        self.execution_engine = execution_engine
        self.exploration_weight = exploration_weight
        self.max_simulations = max_simulations
        self.max_depth = max_depth

        # Strategy history for learning across calls to search()
        self.strategy_history: List[WorkflowStrategy] = []
        # strategy_hash -> reward; reserved for memoizing simulation results
        # (read by get_search_stats, not yet populated anywhere).
        self.performance_cache: Dict[str, float] = {}

    async def search(
        self,
        decomposition: TaskDecomposition,
        optimization_target: str = "balanced"
    ) -> WorkflowStrategy:
        """
        Search for optimal workflow strategy using MCTS.

        Args:
            decomposition: Task decomposition with subtasks
            optimization_target: "quality", "cost", or "balanced"

        Returns:
            Best WorkflowStrategy found
        """
        logger.info(f"Starting MCTS workflow search for {len(decomposition.subtasks)} subtasks")
        start_time = datetime.now()

        # Initialize root node: nothing assigned yet, all subtasks remaining.
        root = WorkflowNode(
            state_id="root",
            subtasks_remaining=decomposition.subtasks.copy(),
            completed_subtasks=[],
            current_strategy={}
        )

        # Run MCTS simulations: select -> expand -> simulate -> backpropagate.
        for simulation in range(self.max_simulations):
            if simulation % 20 == 0:
                logger.debug(f"MCTS simulation {simulation}/{self.max_simulations}")

            # Selection: Traverse tree to find promising node
            node = self._select(root)

            # Expansion: Add child nodes if not terminal (only after the node
            # has been visited at least once, per standard MCTS).
            if not node.is_terminal and node.visits > 0:
                node = self._expand(node, decomposition, optimization_target)

            # Simulation: Execute workflow and get reward
            reward = await self._simulate(node, decomposition, optimization_target)

            # Backpropagation: Update statistics
            self._backpropagate(node, reward)

        # Select best strategy from root. Guard the degenerate case where the
        # tree never expanded (no subtasks, or no candidate models): max() on
        # an empty children list would raise ValueError.
        if root.children:
            best_node = max(root.children, key=lambda c: c.average_reward)
        else:
            best_node = root
        best_strategy = self._node_to_strategy(best_node, decomposition)

        elapsed = (datetime.now() - start_time).total_seconds()
        logger.info(f"MCTS search complete in {elapsed:.2f}s")
        logger.info(f"  Simulations: {self.max_simulations}")
        logger.info(f"  Tree depth: {self._tree_depth(root)}")
        logger.info(f"  Best strategy reward: {best_node.average_reward:.3f}")

        # Cache and store
        self.strategy_history.append(best_strategy)

        return best_strategy

    def _select(self, node: WorkflowNode) -> WorkflowNode:
        """
        Selection phase: UCB1 policy to balance exploration/exploitation.

        UCB1 = average_reward + exploration_weight * sqrt(ln(parent_visits) / visits)
        """
        current = node

        while current.is_fully_expanded and not current.is_terminal:
            # Select child with highest UCB1 score
            current = max(
                current.children,
                key=lambda child: self._ucb1(child, current)
            )

        return current

    def _ucb1(self, child: WorkflowNode, parent: WorkflowNode) -> float:
        """UCB1 formula for node selection."""
        if child.visits == 0:
            return float('inf')  # Prioritize unvisited nodes

        exploitation = child.average_reward
        exploration = self.exploration_weight * math.sqrt(
            math.log(parent.visits) / child.visits
        )

        return exploitation + exploration

    def _expand(
        self,
        node: WorkflowNode,
        decomposition: TaskDecomposition,
        optimization_target: str
    ) -> WorkflowNode:
        """
        Expansion phase: Add new child node with different action.

        Action = assign a different model to the next remaining subtask.

        Returns the newly created child, an existing child when every
        candidate is already explored, or the node itself when terminal.
        """
        # Get next subtask to execute
        if not node.subtasks_remaining:
            return node  # Terminal

        next_subtask = node.subtasks_remaining[0]

        # Candidate models for this subtask, ranked by the optimization target.
        candidates = self._get_model_candidates(next_subtask, optimization_target)

        # Try unexplored assignments
        for candidate_model in candidates:
            # Skip assignments already represented by an existing child.
            # (Bug fix: the original built a "subtask->model" key and tested it
            # against current_strategy.values(), which hold bare model ids —
            # that membership test could never match, so duplicate children
            # were created for the same assignment on every expansion.)
            if any(
                child.current_strategy.get(next_subtask.id) == candidate_model
                for child in node.children
            ):
                continue  # Already explored

            # Create new child node
            new_strategy = node.current_strategy.copy()
            new_strategy[next_subtask.id] = candidate_model

            new_remaining = node.subtasks_remaining[1:]  # Remove this subtask
            new_completed = node.completed_subtasks + [next_subtask.id]

            child = WorkflowNode(
                state_id=f"{node.state_id}_{len(node.children)}",
                subtasks_remaining=new_remaining,
                completed_subtasks=new_completed,
                current_strategy=new_strategy,
                parent=node
            )

            node.children.append(child)
            return child  # Return first newly expanded child

        # If all candidates explored, return existing child
        return node.children[0] if node.children else node

    def _get_model_candidates(
        self,
        subtask: SubTask,
        optimization_target: str
    ) -> List[str]:
        """Get up to 3 candidate model ids ("provider/model") for a subtask."""
        # Use role assigner's model registry
        role = self.role_assigner._map_task_type_to_role(subtask.task_type)
        candidates = self.role_assigner.model_registry.get(role, [])

        # Rank by optimization target, then keep the top 3.
        if optimization_target == "quality":
            scored = sorted(candidates, key=lambda m: m["quality_score"], reverse=True)
        elif optimization_target == "cost":
            scored = sorted(candidates, key=lambda m: m["cost_per_1k_tokens"])
        else:  # balanced: quality per unit cost (epsilon avoids div-by-zero)
            scored = sorted(
                candidates,
                key=lambda m: m["quality_score"] / (m["cost_per_1k_tokens"] + 0.0001),
                reverse=True
            )

        return [f"{c['provider']}/{c['model']}" for c in scored[:3]]

    async def _simulate(
        self,
        node: WorkflowNode,
        decomposition: TaskDecomposition,
        optimization_target: str
    ) -> float:
        """
        Simulation phase: Execute workflow and estimate reward.

        Reward = weighted combination of quality, cost, speed.
        """
        if node.is_terminal:
            # Fully executed workflow - use actual results
            return self._calculate_reward(
                node.results_so_far.get("quality", 0.5),
                node.results_so_far.get("cost", 1.0),
                node.results_so_far.get("time", 1.0),
                optimization_target
            )

        # Partial workflow - estimate reward using heuristics
        total_quality = 0.0
        total_cost = 0.0
        total_time = 0.0

        for subtask_id, model_id in node.current_strategy.items():
            # Look up the subtask; skip unknown ids instead of letting a bare
            # next() raise StopIteration (which becomes RuntimeError in an
            # async function).
            subtask = next(
                (st for st in decomposition.subtasks if st.id == subtask_id),
                None
            )
            if subtask is None:
                continue

            # Simple heuristic: match quality score to subtask difficulty
            base_quality = 0.8  # Default

            # Get model profile from the registry for this subtask's role.
            if "/" in model_id:
                provider, model = model_id.split("/", 1)
                role = self.role_assigner._map_task_type_to_role(subtask.task_type)
                candidates = self.role_assigner.model_registry.get(role, [])
                for c in candidates:
                    # Substring match on the model name — presumably to
                    # tolerate version suffixes in model_id; confirm against
                    # registry naming.
                    if c["provider"] == provider and c["model"] in model:
                        base_quality = c["quality_score"]
                        total_cost += c["cost_per_1k_tokens"]
                        break

            # Adjust for subtask difficulty
            difficulty_factor = 1.0
            if subtask.difficulty > 70:
                difficulty_factor = 1.2  # Harder tasks reduce quality
            elif subtask.difficulty < 30:
                difficulty_factor = 0.9  # Easier tasks boost quality

            adjusted_quality = base_quality / difficulty_factor
            total_quality += adjusted_quality

            # Estimate time (faster models = less time)
            total_time += subtask.estimated_duration_seconds

        # Average quality across subtasks
        avg_quality = total_quality / len(node.current_strategy) if node.current_strategy else 0.5

        return self._calculate_reward(
            avg_quality,
            total_cost,
            total_time,
            optimization_target
        )

    def _calculate_reward(
        self,
        quality: float,
        cost: float,
        time: float,
        optimization_target: str
    ) -> float:
        """Calculate reward based on optimization target."""
        if optimization_target == "quality":
            # Maximize quality, ignore cost/time
            return quality

        elif optimization_target == "cost":
            # Minimize cost (invert for reward; epsilon avoids div-by-zero)
            return 1.0 / (cost + 0.001)

        else:  # balanced
            # Weighted combination
            quality_weight = 0.6
            cost_weight = 0.3
            time_weight = 0.1

            normalized_quality = quality
            normalized_cost = 1.0 / (cost + 0.001)
            normalized_time = 1.0 / (time + 0.1)

            return (
                quality_weight * normalized_quality +
                cost_weight * normalized_cost +
                time_weight * normalized_time
            )

    def _backpropagate(self, node: WorkflowNode, reward: float):
        """Backpropagation phase: Update statistics up the tree."""
        current = node
        while current is not None:
            current.visits += 1
            current.total_reward += reward
            current = current.parent

    def _best_registry_entry(self, subtask: SubTask) -> Optional[Dict[str, Any]]:
        """Return the first registry candidate for a subtask's role, or None."""
        role = self.role_assigner._map_task_type_to_role(subtask.task_type)
        candidates = self.role_assigner.model_registry.get(role, [])
        return candidates[0] if candidates else None

    def _node_to_strategy(
        self,
        node: WorkflowNode,
        decomposition: TaskDecomposition
    ) -> WorkflowStrategy:
        """Convert a search node into a WorkflowStrategy with estimates."""
        subtask_by_id = {st.id: st for st in decomposition.subtasks}

        # Execution order: already-assigned subtasks first, then the rest.
        execution_order = node.completed_subtasks + [st.id for st in node.subtasks_remaining]

        # Estimate quality: mean of each subtask's top candidate quality
        # score, defaulting to 0.8 when the registry has no entry for a role.
        qualities = []
        for st_id in execution_order:
            entry = self._best_registry_entry(subtask_by_id[st_id]) if st_id in subtask_by_id else None
            qualities.append(entry["quality_score"] if entry else 0.8)
        avg_quality = sum(qualities) / len(qualities) if qualities else 0.5

        # Estimate cost: sum of top-candidate cost over assigned subtasks.
        # (Bug fix: the original indexed [0] without checking for an empty
        # candidate list, raising IndexError — the quality sum guarded this
        # case but the cost sum did not.)
        total_cost = 0.0
        for st_id in node.current_strategy:
            entry = self._best_registry_entry(subtask_by_id[st_id]) if st_id in subtask_by_id else None
            if entry is not None:
                total_cost += entry["cost_per_1k_tokens"]

        return WorkflowStrategy(
            subtask_assignments=node.current_strategy,
            expected_quality=avg_quality,
            expected_cost=total_cost,
            execution_order=execution_order
        )

    def _tree_depth(self, node: WorkflowNode) -> int:
        """Calculate maximum depth of tree from node."""
        if not node.children:
            return 0
        return 1 + max(self._tree_depth(child) for child in node.children)

    def get_search_stats(self) -> Dict[str, Any]:
        """Get statistics about searches performed."""
        if not self.strategy_history:
            return {"total_searches": 0}

        return {
            "total_searches": len(self.strategy_history),
            "cache_size": len(self.performance_cache),
            "avg_strategy_quality": sum(s.expected_quality for s in self.strategy_history) / len(self.strategy_history),
            "avg_strategy_cost": sum(s.expected_cost for s in self.strategy_history) / len(self.strategy_history)
        }
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
# Example usage
async def main():
    """Demonstrate an end-to-end MCTS workflow search on a sample task."""
    from .task_planner import TaskPlanner
    from .role_assigner import RoleAssigner
    from .execution_engine import ExecutionEngine

    # Wire up the pipeline components.
    task_planner = TaskPlanner()
    role_assigner = RoleAssigner()
    exec_engine = ExecutionEngine(max_concurrent=3)
    searcher = MCTSWorkflowSearch(role_assigner, exec_engine, max_simulations=50)

    # The example task to plan for.
    sample_task = {
        "description": "Build a REST API with authentication, database integration, and automated testing",
        "context": {"requirements": ["JWT", "PostgreSQL", "Jest"]}
    }

    # Break the task into subtasks.
    print("Decomposing task...")
    decomposition = await task_planner.decompose(sample_task)

    print(f"\nDecomposed into {len(decomposition.subtasks)} subtasks:")
    for subtask in decomposition.subtasks:
        print(f"  {subtask.id}: {subtask.description} (difficulty: {subtask.difficulty})")

    # Run the search for the best assignment strategy.
    print("\nRunning MCTS workflow search...")
    best = await searcher.search(decomposition, optimization_target="balanced")

    print("\nBest Workflow Strategy:")
    print(f"  Assignments: {best.subtask_assignments}")
    print(f"  Expected Quality: {best.expected_quality:.2f}")
    print(f"  Expected Cost: ${best.expected_cost:.4f}")
    print(f"  Execution Order: {best.execution_order}")

    # Report search statistics.
    search_stats = searcher.get_search_stats()
    print("\nMCTS Stats:")
    print(f"  Total searches: {search_stats['total_searches']}")
    print(f"  Cache size: {search_stats['cache_size']}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
|