adaptive-memory-multi-model-router 1.2.2 → 1.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/LICENSE +21 -0
  2. package/README.md +146 -66
  3. package/dist/index.d.ts +1 -1
  4. package/dist/index.js +1 -1
  5. package/dist/integrations/airtable.js +20 -0
  6. package/dist/integrations/discord.js +18 -0
  7. package/dist/integrations/github.js +23 -0
  8. package/dist/integrations/gmail.js +19 -0
  9. package/dist/integrations/google-calendar.js +18 -0
  10. package/dist/integrations/index.js +61 -0
  11. package/dist/integrations/jira.js +21 -0
  12. package/dist/integrations/linear.js +19 -0
  13. package/dist/integrations/notion.js +19 -0
  14. package/dist/integrations/slack.js +18 -0
  15. package/dist/integrations/telegram.js +19 -0
  16. package/dist/providers/registry.js +7 -3
  17. package/docs/ARCHITECTURAL-IMPROVEMENTS-2025.md +1391 -0
  18. package/docs/ARCHITECTURAL-IMPROVEMENTS-REVISED-2025.md +1051 -0
  19. package/docs/CONFIGURATION.md +476 -0
  20. package/docs/COUNCIL_DECISION.json +308 -0
  21. package/docs/COUNCIL_SUMMARY.md +265 -0
  22. package/docs/COUNCIL_V2.2_DECISION.md +416 -0
  23. package/docs/IMPROVEMENT_ROADMAP.md +515 -0
  24. package/docs/LLM_COUNCIL_DECISION.md +508 -0
  25. package/docs/QUICK_START_VISIBILITY.md +782 -0
  26. package/docs/REDDIT_GAP_ANALYSIS.md +299 -0
  27. package/docs/RESEARCH_BACKED_IMPROVEMENTS.md +1180 -0
  28. package/docs/TMLPD_QNA.md +751 -0
  29. package/docs/TMLPD_V2.1_COMPLETE.md +763 -0
  30. package/docs/TMLPD_V2.2_RESEARCH_ROADMAP.md +754 -0
  31. package/docs/V2.2_IMPLEMENTATION_COMPLETE.md +446 -0
  32. package/docs/V2_IMPLEMENTATION_GUIDE.md +388 -0
  33. package/docs/VISIBILITY_ADOPTION_PLAN.md +1005 -0
  34. package/docs/launch-content/LAUNCH_EXECUTION_CHECKLIST.md +421 -0
  35. package/docs/launch-content/README.md +457 -0
  36. package/docs/launch-content/assets/cost_comparison_100_tasks.png +0 -0
  37. package/docs/launch-content/assets/cumulative_savings.png +0 -0
  38. package/docs/launch-content/assets/parallel_speedup.png +0 -0
  39. package/docs/launch-content/assets/provider_pricing_comparison.png +0 -0
  40. package/docs/launch-content/assets/task_breakdown_comparison.png +0 -0
  41. package/docs/launch-content/generate_charts.py +313 -0
  42. package/docs/launch-content/hn_show_post.md +139 -0
  43. package/docs/launch-content/partner_outreach_templates.md +745 -0
  44. package/docs/launch-content/reddit_posts.md +467 -0
  45. package/docs/launch-content/twitter_thread.txt +460 -0
  46. package/examples/QUICKSTART.md +1 -1
  47. package/openclaw-alexa-bridge/ALL_REMAINING_FIXES_PLAN.md +313 -0
  48. package/openclaw-alexa-bridge/REMAINING_FIXES_SUMMARY.md +277 -0
  49. package/openclaw-alexa-bridge/src/alexa_handler_no_tmlpd.js +1234 -0
  50. package/openclaw-alexa-bridge/test_fixes.js +77 -0
  51. package/package.json +120 -29
  52. package/package.json.tmp +0 -0
  53. package/qna/TMLPD_QNA.md +3 -3
  54. package/skill/SKILL.md +2 -2
  55. package/src/__tests__/integration/tmpld_integration.test.py +540 -0
  56. package/src/agents/skill_enhanced_agent.py +318 -0
  57. package/src/memory/__init__.py +15 -0
  58. package/src/memory/agentic_memory.py +353 -0
  59. package/src/memory/semantic_memory.py +444 -0
  60. package/src/memory/simple_memory.py +466 -0
  61. package/src/memory/working_memory.py +447 -0
  62. package/src/orchestration/__init__.py +52 -0
  63. package/src/orchestration/execution_engine.py +353 -0
  64. package/src/orchestration/halo_orchestrator.py +367 -0
  65. package/src/orchestration/mcts_workflow.py +498 -0
  66. package/src/orchestration/role_assigner.py +473 -0
  67. package/src/orchestration/task_planner.py +522 -0
  68. package/src/providers/__init__.py +67 -0
  69. package/src/providers/anthropic.py +304 -0
  70. package/src/providers/base.py +241 -0
  71. package/src/providers/cerebras.py +373 -0
  72. package/src/providers/registry.py +476 -0
  73. package/src/routing/__init__.py +30 -0
  74. package/src/routing/universal_router.py +621 -0
  75. package/src/skills/TMLPD-QUICKREF.md +210 -0
  76. package/src/skills/TMLPD-SETUP-SUMMARY.md +157 -0
  77. package/src/skills/TMLPD.md +540 -0
  78. package/src/skills/__tests__/skill_manager.test.ts +328 -0
  79. package/src/skills/skill_manager.py +385 -0
  80. package/src/skills/test-tmlpd.sh +108 -0
  81. package/src/skills/tmlpd-category.yaml +67 -0
  82. package/src/skills/tmlpd-monitoring.yaml +188 -0
  83. package/src/skills/tmlpd-phase.yaml +132 -0
  84. package/src/state/__init__.py +17 -0
  85. package/src/state/simple_checkpoint.py +508 -0
  86. package/src/tmlpd_agent.py +464 -0
  87. package/src/tmpld_v2.py +427 -0
  88. package/src/workflows/__init__.py +18 -0
  89. package/src/workflows/advanced_difficulty_classifier.py +377 -0
  90. package/src/workflows/chaining_executor.py +417 -0
  91. package/src/workflows/difficulty_integration.py +209 -0
  92. package/src/workflows/orchestrator.py +469 -0
  93. package/src/workflows/orchestrator_executor.py +456 -0
  94. package/src/workflows/parallelization_executor.py +382 -0
  95. package/src/workflows/router.py +311 -0
  96. package/test_integration_simple.py +86 -0
  97. package/test_mcts_workflow.py +150 -0
  98. package/test_templd_integration.py +262 -0
  99. package/test_universal_router.py +275 -0
  100. package/tmlpd-pi-extension/README.md +36 -0
  101. package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts +114 -0
  102. package/tmlpd-pi-extension/dist/cache/prefixCache.d.ts.map +1 -0
  103. package/tmlpd-pi-extension/dist/cache/prefixCache.js +285 -0
  104. package/tmlpd-pi-extension/dist/cache/prefixCache.js.map +1 -0
  105. package/tmlpd-pi-extension/dist/cache/responseCache.d.ts +58 -0
  106. package/tmlpd-pi-extension/dist/cache/responseCache.d.ts.map +1 -0
  107. package/tmlpd-pi-extension/dist/cache/responseCache.js +153 -0
  108. package/tmlpd-pi-extension/dist/cache/responseCache.js.map +1 -0
  109. package/tmlpd-pi-extension/dist/cli.js +59 -0
  110. package/tmlpd-pi-extension/dist/cost/costTracker.d.ts +95 -0
  111. package/tmlpd-pi-extension/dist/cost/costTracker.d.ts.map +1 -0
  112. package/tmlpd-pi-extension/dist/cost/costTracker.js +240 -0
  113. package/tmlpd-pi-extension/dist/cost/costTracker.js.map +1 -0
  114. package/tmlpd-pi-extension/dist/index.d.ts +723 -0
  115. package/tmlpd-pi-extension/dist/index.d.ts.map +1 -0
  116. package/tmlpd-pi-extension/dist/index.js +239 -0
  117. package/tmlpd-pi-extension/dist/index.js.map +1 -0
  118. package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts +82 -0
  119. package/tmlpd-pi-extension/dist/memory/episodicMemory.d.ts.map +1 -0
  120. package/tmlpd-pi-extension/dist/memory/episodicMemory.js +145 -0
  121. package/tmlpd-pi-extension/dist/memory/episodicMemory.js.map +1 -0
  122. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts +102 -0
  123. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.d.ts.map +1 -0
  124. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js +207 -0
  125. package/tmlpd-pi-extension/dist/orchestration/haloOrchestrator.js.map +1 -0
  126. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts +85 -0
  127. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.d.ts.map +1 -0
  128. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js +210 -0
  129. package/tmlpd-pi-extension/dist/orchestration/mctsWorkflow.js.map +1 -0
  130. package/tmlpd-pi-extension/dist/providers/localProvider.d.ts +102 -0
  131. package/tmlpd-pi-extension/dist/providers/localProvider.d.ts.map +1 -0
  132. package/tmlpd-pi-extension/dist/providers/localProvider.js +338 -0
  133. package/tmlpd-pi-extension/dist/providers/localProvider.js.map +1 -0
  134. package/tmlpd-pi-extension/dist/providers/registry.d.ts +55 -0
  135. package/tmlpd-pi-extension/dist/providers/registry.d.ts.map +1 -0
  136. package/tmlpd-pi-extension/dist/providers/registry.js +138 -0
  137. package/tmlpd-pi-extension/dist/providers/registry.js.map +1 -0
  138. package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts +68 -0
  139. package/tmlpd-pi-extension/dist/routing/advancedRouter.d.ts.map +1 -0
  140. package/tmlpd-pi-extension/dist/routing/advancedRouter.js +332 -0
  141. package/tmlpd-pi-extension/dist/routing/advancedRouter.js.map +1 -0
  142. package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts +101 -0
  143. package/tmlpd-pi-extension/dist/tools/tmlpdTools.d.ts.map +1 -0
  144. package/tmlpd-pi-extension/dist/tools/tmlpdTools.js +368 -0
  145. package/tmlpd-pi-extension/dist/tools/tmlpdTools.js.map +1 -0
  146. package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts +96 -0
  147. package/tmlpd-pi-extension/dist/utils/batchProcessor.d.ts.map +1 -0
  148. package/tmlpd-pi-extension/dist/utils/batchProcessor.js +170 -0
  149. package/tmlpd-pi-extension/dist/utils/batchProcessor.js.map +1 -0
  150. package/tmlpd-pi-extension/dist/utils/compression.d.ts +61 -0
  151. package/tmlpd-pi-extension/dist/utils/compression.d.ts.map +1 -0
  152. package/tmlpd-pi-extension/dist/utils/compression.js +281 -0
  153. package/tmlpd-pi-extension/dist/utils/compression.js.map +1 -0
  154. package/tmlpd-pi-extension/dist/utils/reliability.d.ts +74 -0
  155. package/tmlpd-pi-extension/dist/utils/reliability.d.ts.map +1 -0
  156. package/tmlpd-pi-extension/dist/utils/reliability.js +177 -0
  157. package/tmlpd-pi-extension/dist/utils/reliability.js.map +1 -0
  158. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts +117 -0
  159. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.d.ts.map +1 -0
  160. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js +246 -0
  161. package/tmlpd-pi-extension/dist/utils/speculativeDecoding.js.map +1 -0
  162. package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts +50 -0
  163. package/tmlpd-pi-extension/dist/utils/tokenUtils.d.ts.map +1 -0
  164. package/tmlpd-pi-extension/dist/utils/tokenUtils.js +124 -0
  165. package/tmlpd-pi-extension/dist/utils/tokenUtils.js.map +1 -0
  166. package/tmlpd-pi-extension/examples/QUICKSTART.md +183 -0
  167. package/tmlpd-pi-extension/package-lock.json +75 -0
  168. package/tmlpd-pi-extension/package.json +172 -0
  169. package/tmlpd-pi-extension/python/examples.py +53 -0
  170. package/tmlpd-pi-extension/python/integrations.py +330 -0
  171. package/tmlpd-pi-extension/python/setup.py +28 -0
  172. package/tmlpd-pi-extension/python/tmlpd.py +369 -0
  173. package/tmlpd-pi-extension/qna/REDDIT_GAP_ANALYSIS.md +299 -0
  174. package/tmlpd-pi-extension/qna/TMLPD_QNA.md +751 -0
  175. package/tmlpd-pi-extension/skill/SKILL.md +238 -0
  176. package/{src → tmlpd-pi-extension/src}/index.ts +1 -1
  177. package/tmlpd-pi-extension/tsconfig.json +18 -0
  178. package/demo/research-demo.js +0 -266
  179. package/notebooks/quickstart.ipynb +0 -157
  180. package/rust/tmlpd.h +0 -268
  181. package/src/cache/prefixCache.ts +0 -365
  182. package/src/routing/advancedRouter.ts +0 -406
  183. package/src/utils/speculativeDecoding.ts +0 -344
  184. /package/{src → tmlpd-pi-extension/src}/cache/responseCache.ts +0 -0
  185. /package/{src → tmlpd-pi-extension/src}/cost/costTracker.ts +0 -0
  186. /package/{src → tmlpd-pi-extension/src}/memory/episodicMemory.ts +0 -0
  187. /package/{src → tmlpd-pi-extension/src}/orchestration/haloOrchestrator.ts +0 -0
  188. /package/{src → tmlpd-pi-extension/src}/orchestration/mctsWorkflow.ts +0 -0
  189. /package/{src → tmlpd-pi-extension/src}/providers/localProvider.ts +0 -0
  190. /package/{src → tmlpd-pi-extension/src}/providers/registry.ts +0 -0
  191. /package/{src → tmlpd-pi-extension/src}/tools/tmlpdTools.ts +0 -0
  192. /package/{src → tmlpd-pi-extension/src}/utils/batchProcessor.ts +0 -0
  193. /package/{src → tmlpd-pi-extension/src}/utils/compression.ts +0 -0
  194. /package/{src → tmlpd-pi-extension/src}/utils/reliability.ts +0 -0
  195. /package/{src → tmlpd-pi-extension/src}/utils/tokenUtils.ts +0 -0
@@ -0,0 +1,498 @@
1
+ """
2
+ MCTS Workflow Search - Optimize HALO orchestration strategies
3
+
4
+ Based on:
5
+ - arXiv:2505.13516 (HALO - MCTS-based workflow search)
6
+ - NeurIPS 2023 (MCTS for LLM agent planning)
7
+ - ICLR 2024 (Tree of Thoughts + MCTS)
8
+
9
+ Key Innovation: Uses Monte Carlo Tree Search to explore different
10
+ execution strategies and learn the best workflow for each task type.
11
+
12
+ Features:
13
+ - Explores different agent assignment strategies
14
+ - Learns from execution outcomes
15
+ - Balances exploration vs exploitation
16
+ - Adapts workflow to task complexity
17
+ """
18
+
19
+ import asyncio
20
+ import random
21
+ import math
22
+ from typing import Dict, List, Any, Optional, Tuple
23
+ from dataclasses import dataclass, field
24
+ from datetime import datetime
25
+ from collections import defaultdict
26
+ import logging
27
+
28
+ from .task_planner import SubTask, TaskDecomposition
29
+ from .role_assigner import RoleAssigner, AgentAssignment, AgentRole
30
+ from .execution_engine import ExecutionEngine, ExecutionSummary
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
@dataclass
class WorkflowNode:
    """A single state in the MCTS search tree over workflow assignments.

    Each node records which subtasks are still pending, which are done,
    and the model assignments chosen so far, plus the visit/reward
    statistics that UCB1 selection needs.
    """
    state_id: str                      # unique identifier for this state
    subtasks_remaining: List[SubTask]  # subtasks not yet assigned
    completed_subtasks: List[str]      # ids of subtasks already assigned
    current_strategy: Dict[str, str]   # subtask_id -> model_id assignments
    results_so_far: Dict[str, Any] = field(default_factory=dict)

    # MCTS bookkeeping
    visits: int = 0
    total_reward: float = 0.0
    parent: Optional["WorkflowNode"] = None
    children: List["WorkflowNode"] = field(default_factory=list)

    @property
    def average_reward(self) -> float:
        """Mean reward accumulated per visit (0.0 before any visit)."""
        if self.visits > 0:
            return self.total_reward / self.visits
        return 0.0

    @property
    def is_fully_expanded(self) -> bool:
        """Whether expansion has produced at least one child.

        NOTE(review): this only checks that *some* child exists, not that
        every candidate action was tried — confirm this matches the
        intended expansion policy before relying on the name.
        """
        return bool(self.children)

    @property
    def is_terminal(self) -> bool:
        """True once no subtasks remain to be assigned."""
        return not self.subtasks_remaining
65
+
66
+
67
@dataclass
class WorkflowStrategy:
    """Fully specified workflow: one model assignment per subtask.

    Carries the planner's expectations (quality/cost) alongside the
    measured outcome fields, which stay None until the strategy is
    actually executed.
    """
    subtask_assignments: Dict[str, str]  # subtask_id -> model_id
    expected_quality: float
    expected_cost: float
    execution_order: List[str]

    # Filled in after execution, if any
    actual_quality: Optional[float] = None
    actual_cost: Optional[float] = None
    execution_time_seconds: Optional[float] = None

    def to_dict(self) -> Dict:
        """Serialize to a plain dict (e.g. for logging or JSON dumping)."""
        keys = (
            "subtask_assignments",
            "expected_quality",
            "expected_cost",
            "execution_order",
            "actual_quality",
            "actual_cost",
            "execution_time_seconds",
        )
        return {k: getattr(self, k) for k in keys}
90
+
91
+
92
class MCTSWorkflowSearch:
    """
    Monte Carlo Tree Search for workflow optimization.

    Explores different agent assignment strategies to find an optimal
    workflow for a decomposed task:

    - Treats workflow optimization as a search problem
    - Uses MCTS with UCB1 to balance exploration/exploitation
    - Learns from execution outcomes
    - Adapts to task complexity
    """

    def __init__(
        self,
        role_assigner: RoleAssigner,
        execution_engine: ExecutionEngine,
        exploration_weight: float = 1.414,  # sqrt(2), the classic UCB1 constant
        max_simulations: int = 100,
        max_depth: int = 10
    ):
        """
        Initialize MCTS workflow search.

        Args:
            role_assigner: RoleAssigner for getting candidate agents
            execution_engine: ExecutionEngine for simulating workflows
            exploration_weight: UCB1 exploration parameter (higher = more exploration)
            max_simulations: Maximum MCTS simulations per search
            max_depth: Maximum search depth (currently informational only —
                the tree depth is naturally bounded by the number of subtasks)
        """
        self.role_assigner = role_assigner
        self.execution_engine = execution_engine
        self.exploration_weight = exploration_weight
        self.max_simulations = max_simulations
        self.max_depth = max_depth

        # Strategy history for learning across searches
        self.strategy_history: List[WorkflowStrategy] = []
        self.performance_cache: Dict[str, float] = {}  # strategy_hash -> reward

    async def search(
        self,
        decomposition: TaskDecomposition,
        optimization_target: str = "balanced"
    ) -> WorkflowStrategy:
        """
        Search for the best workflow strategy using MCTS.

        Args:
            decomposition: Task decomposition with subtasks
            optimization_target: "quality", "cost", or "balanced"

        Returns:
            Best WorkflowStrategy found
        """
        logger.info(f"Starting MCTS workflow search for {len(decomposition.subtasks)} subtasks")
        start_time = datetime.now()

        # Root: nothing assigned yet, all subtasks pending
        root = WorkflowNode(
            state_id="root",
            subtasks_remaining=decomposition.subtasks.copy(),
            completed_subtasks=[],
            current_strategy={}
        )

        for simulation in range(self.max_simulations):
            if simulation % 20 == 0:
                logger.debug(f"MCTS simulation {simulation}/{self.max_simulations}")

            # Selection: traverse the tree to a promising node (UCB1)
            node = self._select(root)

            # Expansion: add a child only after the node has been visited once
            if not node.is_terminal and node.visits > 0:
                node = self._expand(node, decomposition, optimization_target)

            # Simulation: estimate the reward for this (partial) workflow
            reward = await self._simulate(node, decomposition, optimization_target)

            # Backpropagation: update visit/reward statistics up to the root
            self._backpropagate(node, reward)

        # FIX: with max_simulations < 2 the root may have no children yet,
        # and max() over an empty sequence raised ValueError. Fall back to
        # scoring the root itself in that case.
        if root.children:
            best_child = max(root.children, key=lambda c: c.average_reward)
        else:
            best_child = root
        best_strategy = self._node_to_strategy(best_child, decomposition)

        elapsed = (datetime.now() - start_time).total_seconds()
        logger.info(f"MCTS search complete in {elapsed:.2f}s")
        logger.info(f"  Simulations: {self.max_simulations}")
        logger.info(f"  Tree depth: {self._tree_depth(root)}")
        logger.info(f"  Best strategy reward: {best_child.average_reward:.3f}")

        # Record for cross-search statistics
        self.strategy_history.append(best_strategy)

        return best_strategy

    def _select(self, node: WorkflowNode) -> WorkflowNode:
        """
        Selection phase: descend via UCB1 to balance exploration/exploitation.

        UCB1 = average_reward + exploration_weight * sqrt(ln(parent_visits) / visits)
        """
        current = node

        while current.is_fully_expanded and not current.is_terminal:
            # Follow the child with the highest UCB1 score
            current = max(
                current.children,
                key=lambda child: self._ucb1(child, current)
            )

        return current

    def _ucb1(self, child: WorkflowNode, parent: WorkflowNode) -> float:
        """UCB1 score for node selection (infinite for unvisited nodes)."""
        if child.visits == 0:
            return float('inf')  # Prioritize unvisited nodes

        exploitation = child.average_reward
        exploration = self.exploration_weight * math.sqrt(
            math.log(parent.visits) / child.visits
        )

        return exploitation + exploration

    def _expand(
        self,
        node: WorkflowNode,
        decomposition: TaskDecomposition,
        optimization_target: str
    ) -> WorkflowNode:
        """
        Expansion phase: add a new child node with an untried action.

        An action assigns one candidate model to the next pending subtask.

        Returns:
            The newly created child, or an existing child/the node itself
            when every candidate has already been explored.
        """
        if not node.subtasks_remaining:
            return node  # Terminal — nothing to expand

        next_subtask = node.subtasks_remaining[0]

        # Candidate models for this subtask, ranked by optimization target
        candidates = self._get_model_candidates(next_subtask, optimization_target)

        for candidate_model in candidates:
            # FIX: the original built a "subtask->model" key and tested it
            # against current_strategy.values() (which holds bare model ids),
            # so the membership check never matched and duplicate children
            # were created. Compare the recorded assignment directly instead.
            already_tried = any(
                child.current_strategy.get(next_subtask.id) == candidate_model
                for child in node.children
            )
            if already_tried:
                continue

            new_strategy = node.current_strategy.copy()
            new_strategy[next_subtask.id] = candidate_model

            child = WorkflowNode(
                state_id=f"{node.state_id}_{len(node.children)}",
                subtasks_remaining=node.subtasks_remaining[1:],  # consume subtask
                completed_subtasks=node.completed_subtasks + [next_subtask.id],
                current_strategy=new_strategy,
                parent=node
            )
            node.children.append(child)
            return child  # Return the first newly expanded child

        # All candidates already explored — reuse an existing child
        return node.children[0] if node.children else node

    def _get_model_candidates(
        self,
        subtask: SubTask,
        optimization_target: str
    ) -> List[str]:
        """
        Top 3 candidate model ids ("provider/model") for a subtask,
        ranked according to the optimization target.

        NOTE(review): relies on RoleAssigner's private _map_task_type_to_role
        and its model_registry layout — confirm this stays in sync.
        """
        role = self.role_assigner._map_task_type_to_role(subtask.task_type)
        candidates = self.role_assigner.model_registry.get(role, [])

        if optimization_target == "quality":
            scored = sorted(candidates, key=lambda m: m["quality_score"], reverse=True)
        elif optimization_target == "cost":
            scored = sorted(candidates, key=lambda m: m["cost_per_1k_tokens"])
        else:  # balanced: quality per unit cost (epsilon avoids div-by-zero)
            scored = sorted(
                candidates,
                key=lambda m: m["quality_score"] / (m["cost_per_1k_tokens"] + 0.0001),
                reverse=True
            )

        return [f"{c['provider']}/{c['model']}" for c in scored[:3]]

    async def _simulate(
        self,
        node: WorkflowNode,
        decomposition: TaskDecomposition,
        optimization_target: str
    ) -> float:
        """
        Simulation phase: estimate the reward for the node's workflow.

        Terminal nodes use recorded results; partial workflows are scored
        with registry-based heuristics (see _calculate_reward for weighting).
        """
        if node.is_terminal:
            # Fully executed workflow — use actual results when recorded
            return self._calculate_reward(
                node.results_so_far.get("quality", 0.5),
                node.results_so_far.get("cost", 1.0),
                node.results_so_far.get("time", 1.0),
                optimization_target
            )

        # Partial workflow — estimate reward using heuristics
        total_quality = 0.0
        total_cost = 0.0
        total_time = 0.0

        for subtask_id, model_id in node.current_strategy.items():
            subtask = next(st for st in decomposition.subtasks if st.id == subtask_id)

            base_quality = 0.8  # default when the model is not in the registry

            # Look up the assigned model's registry profile
            if "/" in model_id:
                provider, model = model_id.split("/", 1)
                role = self.role_assigner._map_task_type_to_role(subtask.task_type)
                candidates = self.role_assigner.model_registry.get(role, [])
                for c in candidates:
                    # NOTE(review): substring match on the model name —
                    # confirm registry naming makes this unambiguous
                    if c["provider"] == provider and c["model"] in model:
                        base_quality = c["quality_score"]
                        total_cost += c["cost_per_1k_tokens"]
                        break

            # Adjust for subtask difficulty
            difficulty_factor = 1.0
            if subtask.difficulty > 70:
                difficulty_factor = 1.2  # harder tasks reduce expected quality
            elif subtask.difficulty < 30:
                difficulty_factor = 0.9  # easier tasks boost expected quality

            total_quality += base_quality / difficulty_factor

            # Time estimate comes straight from the planner
            total_time += subtask.estimated_duration_seconds

        # Average quality across assigned subtasks (0.5 when none assigned)
        avg_quality = total_quality / len(node.current_strategy) if node.current_strategy else 0.5

        return self._calculate_reward(
            avg_quality,
            total_cost,
            total_time,
            optimization_target
        )

    def _calculate_reward(
        self,
        quality: float,
        cost: float,
        time: float,
        optimization_target: str
    ) -> float:
        """Scalar reward for a (quality, cost, time) triple per target."""
        if optimization_target == "quality":
            # Maximize quality, ignore cost/time
            return quality

        elif optimization_target == "cost":
            # Minimize cost (inverted so lower cost = higher reward)
            return 1.0 / (cost + 0.001)

        else:  # balanced
            # Weighted combination of quality, inverted cost, inverted time
            quality_weight = 0.6
            cost_weight = 0.3
            time_weight = 0.1

            normalized_quality = quality
            normalized_cost = 1.0 / (cost + 0.001)
            normalized_time = 1.0 / (time + 0.1)

            return (
                quality_weight * normalized_quality +
                cost_weight * normalized_cost +
                time_weight * normalized_time
            )

    def _backpropagate(self, node: WorkflowNode, reward: float):
        """Backpropagation phase: add the reward along the path to the root."""
        current = node
        while current is not None:
            current.visits += 1
            current.total_reward += reward
            current = current.parent

    def _top_profile(self, st_id: str, decomposition: TaskDecomposition) -> Optional[Dict]:
        """Best-ranked registry profile for a subtask's role, or None."""
        subtask = next(st for st in decomposition.subtasks if st.id == st_id)
        role = self.role_assigner._map_task_type_to_role(subtask.task_type)
        candidates = self.role_assigner.model_registry.get(role, [])
        return candidates[0] if candidates else None

    def _node_to_strategy(
        self,
        node: WorkflowNode,
        decomposition: TaskDecomposition
    ) -> WorkflowStrategy:
        """Convert a search node into a WorkflowStrategy summary."""
        # Execution order: completed first, then whatever is still pending
        execution_order = node.completed_subtasks + [st.id for st in node.subtasks_remaining]

        # Expected quality: average of the top registry profile per subtask
        # (0.8 default for roles with no registered models, 0.5 when empty)
        if execution_order:
            profiles = [self._top_profile(st_id, decomposition) for st_id in execution_order]
            avg_quality = sum(p["quality_score"] if p else 0.8 for p in profiles) / len(execution_order)
        else:
            avg_quality = 0.5

        # FIX: the original indexed candidates[0] unconditionally here and
        # raised IndexError when a role had no registered models; missing
        # profiles now contribute zero cost instead.
        total_cost = 0.0
        for st_id in node.current_strategy:
            profile = self._top_profile(st_id, decomposition)
            if profile:
                total_cost += profile["cost_per_1k_tokens"]

        return WorkflowStrategy(
            subtask_assignments=node.current_strategy,
            expected_quality=avg_quality,
            expected_cost=total_cost,
            execution_order=execution_order
        )

    def _tree_depth(self, node: WorkflowNode) -> int:
        """Maximum depth of the subtree rooted at node (0 for a leaf)."""
        if not node.children:
            return 0
        return 1 + max(self._tree_depth(child) for child in node.children)

    def get_search_stats(self) -> Dict[str, Any]:
        """Aggregate statistics over all searches performed so far."""
        if not self.strategy_history:
            return {"total_searches": 0}

        return {
            "total_searches": len(self.strategy_history),
            "cache_size": len(self.performance_cache),
            "avg_strategy_quality": sum(s.expected_quality for s in self.strategy_history) / len(self.strategy_history),
            "avg_strategy_cost": sum(s.expected_cost for s in self.strategy_history) / len(self.strategy_history)
        }
450
+
451
+
452
# Example usage
async def main():
    """Example of MCTS workflow search"""
    from .task_planner import TaskPlanner
    from .role_assigner import RoleAssigner
    from .execution_engine import ExecutionEngine

    # Wire up the orchestration components
    task_planner = TaskPlanner()
    role_picker = RoleAssigner()
    exec_engine = ExecutionEngine(max_concurrent=3)
    searcher = MCTSWorkflowSearch(role_picker, exec_engine, max_simulations=50)

    # A demo task to decompose and optimize
    demo_task = {
        "description": "Build a REST API with authentication, database integration, and automated testing",
        "context": {"requirements": ["JWT", "PostgreSQL", "Jest"]}
    }

    # Step 1: break the task into subtasks
    print("Decomposing task...")
    plan = await task_planner.decompose(demo_task)

    print(f"\nDecomposed into {len(plan.subtasks)} subtasks:")
    for subtask in plan.subtasks:
        print(f"  {subtask.id}: {subtask.description} (difficulty: {subtask.difficulty})")

    # Step 2: let MCTS pick model assignments for each subtask
    print("\nRunning MCTS workflow search...")
    strategy = await searcher.search(plan, optimization_target="balanced")

    print(f"\nBest Workflow Strategy:")
    print(f"  Assignments: {strategy.subtask_assignments}")
    print(f"  Expected Quality: {strategy.expected_quality:.2f}")
    print(f"  Expected Cost: ${strategy.expected_cost:.4f}")
    print(f"  Execution Order: {strategy.execution_order}")

    # Step 3: report search statistics
    search_stats = searcher.get_search_stats()
    print(f"\nMCTS Stats:")
    print(f"  Total searches: {search_stats['total_searches']}")
    print(f"  Cache size: {search_stats['cache_size']}")
494
+
495
+
496
# Run the demo when executed as a script (not on import).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())