isage-benchmark-agent 0.1.0.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
@@ -0,0 +1,3036 @@
+ """
+ Strategy Adapter Registry
+
+ Provides a unified registry for mapping strategy names (e.g., "selector.keyword")
+ to actual selector/planner/timing implementations from sage-libs.
+
+ This bridges the benchmark experiments with the runtime components.
+
+ Method Classification:
+ =====================
+
+ Paper 1 (Benchmark) - Existing SOTA Methods (Runtime Evaluation Only):
+     Tool Selection: keyword, embedding, hybrid, gorilla, dfsdt/toolllm
+     Planning: simple, hierarchical, llm_based, react, tot
+     Timing: rule_based, llm_based, hybrid, embedding
+
+     Note: Paper 1 is a BENCHMARK paper that evaluates existing methods
+     using pre-trained models. It does NOT involve training.
+
+ Paper 2 (SIAS Method) - Training Strategies (in sage.libs.sias):
+     Core Components:
+     - CoresetSelector: Intelligent sample selection (loss_topk, diversity, hybrid)
+     - OnlineContinualLearner: Experience replay buffer with importance weighting
+     - SSIS: Streaming Sample Importance Scorer (TODO)
+     - Priority Replay: Importance-Weighted Experience Buffer (TODO)
+
+     Training Configurations:
+     - SIAS_sft_baseline: Standard SFT (ablation baseline)
+     - SIAS_coreset: + Coreset selection
+     - SIAS_continual: + Continual learning with replay
+     - SIAS_full: Complete SIAS framework
+
+     Import: from sage.libs.sias import CoresetSelector, OnlineContinualLearner
+
+ Usage:
+     >>> registry = get_adapter_registry()
+     >>> selector = registry.get("selector.keyword", resources)
+     >>> planner = registry.get("planner.react", resources)
+ """
+
+ from typing import Any, Callable, Optional, Protocol
+
+ # =============================================================================
+ # Experiment Configuration Constants (for controlled-variable experiments)
+ # =============================================================================
+ # These constants ensure all methods use the same underlying models/parameters
+ # to enable fair comparison (controlled-variable methodology)
+
+ # Embedding model: all embedding-based methods use the same model
+ BENCHMARK_EMBEDDING_MODEL = "BAAI/bge-small-zh-v1.5"
+
+ # LLM temperature: low value for reproducibility in benchmark evaluation
+ BENCHMARK_LLM_TEMPERATURE = 0.1
+
+ # =============================================================================
+
+ # Lazy imports to avoid circular dependencies
+ _SELECTOR_REGISTRY = None
+ _PLANNER_CLASS = None
+ _TIMING_DECIDER_CLASS = None
+
+
+ class StrategyProtocol(Protocol):
+     """Protocol for strategy adapters."""
+
+     def predict(self, query: Any, **kwargs) -> Any:
+         """Make a prediction."""
+         ...
+
+
+ class SelectorAdapter:
+     """
+     Adapter wrapping tool selectors to provide unified predict()/select() interface.
+
+     Maps the selector's select() method to predict() for benchmark compatibility.
+     Also provides select() method for run_all_experiments.py compatibility.
+     """
+
+     def __init__(self, selector: Any):
+         """
+         Initialize adapter.
+
+         Args:
+             selector: A tool selector instance with select() method
+         """
+         self.selector = selector
+
+     def predict(self, query: Any, top_k: Optional[int] = None, **kwargs) -> list:
+         """
+         Make tool selection prediction.
+
+         Args:
+             query: ToolSelectionQuery from experiment
+             top_k: Number of tools to select
+
+         Returns:
+             List of ToolPrediction objects
+         """
+         # Convert experiment query to selector query format
+         from sage.libs.agentic.agents.action.tool_selection.schemas import (
+             ToolSelectionQuery as SelectorQuery,
+         )
+
+         selector_query = SelectorQuery(
+             sample_id=query.sample_id,
+             instruction=query.instruction,
+             candidate_tools=getattr(query, "candidate_tools", None),
+             context=getattr(query, "context", {}),
+         )
+
+         k = top_k if top_k is not None else 5
+         return self.selector.select(selector_query, top_k=k)
+
+     def select(self, query: Any, candidate_tools: Optional[list] = None, top_k: int = 5) -> list:
+         """
+         Select tools for a query (alias for predict with simpler interface).
+
+         This method is provided for compatibility with run_all_experiments.py
+         which calls selector.select(query, candidate_tools, top_k=top_k).
+
+         Args:
+             query: Either a string (instruction) or ToolSelectionQuery object
+             candidate_tools: Optional list of candidate tools (may be ignored if
+                 selector has its own tool corpus)
+             top_k: Number of tools to select
+
+         Returns:
+             List of ToolPrediction objects or tool IDs
+         """
+         from sage.libs.agentic.agents.action.tool_selection.schemas import (
+             ToolSelectionQuery as SelectorQuery,
+         )
+
+         # Ensure candidate_tools is always a list (never None)
+         tools = candidate_tools if candidate_tools is not None else []
+
+         # Handle string query (from run_all_experiments.py)
+         if isinstance(query, str):
+             selector_query = SelectorQuery(
+                 sample_id="runtime",
+                 instruction=query,
+                 candidate_tools=tools,
+                 context={},
+             )
+         # Handle dict query
+         elif isinstance(query, dict):
+             tools_from_dict = query.get("candidate_tools", [])
+             selector_query = SelectorQuery(
+                 sample_id=query.get("sample_id", "runtime"),
+                 instruction=query.get("instruction", str(query)),
+                 candidate_tools=tools if tools else (tools_from_dict or []),
+                 context=query.get("context", {}),
+             )
+         # Handle query object with attributes
+         elif hasattr(query, "instruction"):
+             tools_from_obj = getattr(query, "candidate_tools", [])
+             selector_query = SelectorQuery(
+                 sample_id=getattr(query, "sample_id", "runtime"),
+                 instruction=query.instruction,
+                 candidate_tools=tools if tools else (tools_from_obj or []),
+                 context=getattr(query, "context", {}),
+             )
+         else:
+             # Fallback: treat query as instruction string
+             selector_query = SelectorQuery(
+                 sample_id="runtime",
+                 instruction=str(query),
+                 candidate_tools=tools,
+                 context={},
+             )
+
+         try:
+             result = self.selector.select(selector_query, top_k=top_k)
+             return result
+         except Exception as e:
+             # If selector fails, return empty list with debug info
+             import logging
+
+             logging.getLogger(__name__).warning(
+                 f"Selector failed for query '{str(query)[:50]}...': {e}"
+             )
+             return []
+
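+ # Illustrative usage of SelectorAdapter (names are hypothetical; any object
+ # exposing select(query, top_k=...) can be wrapped):
+ #     adapter = SelectorAdapter(MyBM25Selector())
+ #     preds = adapter.select("book a flight to Tokyo", top_k=3)
+ #     top_ids = [p.tool_id for p in preds]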
+
+ class PlannerAdapter:
+     """
+     Adapter wrapping planners to provide unified plan() interface.
+     """
+
+     def __init__(self, planner: Any):
+         """
+         Initialize adapter.
+
+         Args:
+             planner: A planner instance with plan() method
+         """
+         self.planner = planner
+
+     def plan(self, task: Any, **kwargs) -> Any:
+         """
+         Generate a plan for the task.
+
+         Args:
+             task: PlanningTask from experiment
+
+         Returns:
+             PlanningPrediction with steps and tool_sequence
+         """
+         return self.planner.plan(task)
+
+
+ class UnifiedTimingMessage:
+     """
+     Unified timing message that provides both .message and .user_message attributes.
+
+     This ensures compatibility with:
+     - Local deciders in adapter_registry.py (expect .message)
+     - sage-libs timing_decider.py (expects .user_message)
+     """
+
+     def __init__(
+         self,
+         user_message: str,
+         conversation_history: Optional[list[Any]] = None,
+         last_tool_call: Any = None,
+         context: Optional[dict[Any, Any]] = None,
+     ):
+         self.user_message = user_message
+         self.message = user_message  # Alias for compatibility with local deciders
+         self.conversation_history = conversation_history or []
+         self.last_tool_call = last_tool_call
+         self.context = context or {}
+
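+ # Illustrative check of the alias (both attributes reference the same text):
+ #     msg = UnifiedTimingMessage("check the weather in Paris")
+ #     assert msg.message == msg.user_message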
+
+ class TimingAdapter:
+     """
+     Adapter wrapping timing deciders to provide unified decide() interface.
+
+     Handles conversion between experiment TimingMessage format and
+     sage-libs schemas.TimingMessage format.
+     """
+
+     def __init__(self, decider: Any):
+         """
+         Initialize adapter.
+
+         Args:
+             decider: A timing decider instance with decide() method
+         """
+         self.decider = decider
+
+     def decide(self, message: Any, **kwargs) -> Any:
+         """
+         Make timing decision.
+
+         Args:
+             message: TimingMessage from experiment (has .message attribute)
+                 or dict with 'instruction'/'message' key
+
+         Returns:
+             TimingDecision with should_call_tool, confidence, reasoning
+         """
+         # Convert to UnifiedTimingMessage which has both .message and .user_message
+         if isinstance(message, dict):
+             # Handle dict input (e.g., from tests or direct API calls)
+             user_msg = (
+                 message.get("instruction")
+                 or message.get("message")
+                 or message.get("user_message", "")
+             )
+             unified_message = UnifiedTimingMessage(
+                 user_message=user_msg,
+                 conversation_history=message.get("conversation_history", []),
+                 last_tool_call=message.get("last_tool_call"),
+                 context=message.get("context", {}),
+             )
+         elif hasattr(message, "user_message") and hasattr(message, "message"):
+             # Already has both attributes (e.g., UnifiedTimingMessage)
+             unified_message = message
+         elif hasattr(message, "user_message"):
+             # schemas.TimingMessage format (sage-libs)
+             unified_message = UnifiedTimingMessage(
+                 user_message=message.user_message,
+                 conversation_history=getattr(message, "conversation_history", []),
+                 last_tool_call=getattr(message, "last_tool_call", None),
+                 context=getattr(message, "context", {}),
+             )
+         elif hasattr(message, "message"):
+             # Experiment's TimingMessage (has .message instead of .user_message)
+             unified_message = UnifiedTimingMessage(
+                 user_message=message.message,
+                 conversation_history=getattr(message, "conversation_history", []),
+                 last_tool_call=getattr(message, "last_tool_call", None),
+                 context=getattr(message, "context", {}),
+             )
+         else:
+             # Fallback: treat as string
+             unified_message = UnifiedTimingMessage(
+                 user_message=str(message),
+                 conversation_history=[],
+                 last_tool_call=None,
+                 context={},
+             )
+
+         return self.decider.decide(unified_message)
+
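+ # Illustrative dict input to TimingAdapter.decide(); the adapter normalizes it
+ # to a UnifiedTimingMessage before calling the wrapped decider:
+ #     decision = adapter.decide({"instruction": "what time is it in UTC?"})
+ #     decision.should_call_tool  # -> bool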
+
+ class AdapterRegistry:
+     """
+     Registry for strategy adapters.
+
+     Maps string names like "baseline.keyword" to actual implementations.
+     """
+
+     def __init__(self):
+         """Initialize registry with built-in strategies."""
+         import logging
+
+         self.logger = logging.getLogger(__name__)
+
+         self._selectors: dict[str, Any] = {}
+         self._planners: dict[str, Any] = {}
+         self._timing_deciders: dict[str, Any] = {}
+         self._factories: dict[str, Callable] = {}
+
+         # Register built-in strategies
+         self._register_builtins()
+
+     def _register_builtins(self) -> None:
+         """Register built-in baseline strategies.
+
+         Method Classification:
+         =====================
+
+         Paper 1 (Benchmark) - Existing SOTA Methods:
+         - Tool Selection: keyword, embedding, hybrid, gorilla, dfsdt/toolllm
+         - Planning: simple, hierarchical, llm_based, react, tot
+         - Timing: rule_based, llm_based, hybrid, embedding
+
+         Paper 2 (Method) - SAGE Original Methods:
+         - Training: SAGE_baseline_sft, SAGE_coreset_loss, SAGE_coreset_diversity,
+           SAGE_coreset_hybrid, SAGE_continual, SAGE_combined
+         - (Defined in run_full_training_comparison.py, not runtime adapters)
+         """
+         # =================================================================
+         # Paper 1: Existing SOTA Selector Strategies
+         # =================================================================
+         # BM25/TF-IDF keyword-based selection
+         self._factories["baseline.keyword"] = self._create_keyword_selector
+         self._factories["baseline.embedding"] = self._create_embedding_selector
+         self._factories["baseline.hybrid"] = self._create_hybrid_selector
+         self._factories["keyword"] = self._create_keyword_selector
+         self._factories["embedding"] = self._create_embedding_selector
+         self._factories["hybrid"] = self._create_hybrid_selector
+         # Aliased names for benchmark scripts
+         self._factories["selector.keyword"] = self._create_keyword_selector
+         self._factories["selector.embedding"] = self._create_embedding_selector
+         self._factories["selector.hybrid"] = self._create_hybrid_selector
+         # Gorilla: LLM-augmented retrieval (Patil et al., 2023)
+         self._factories["selector.gorilla"] = self._create_gorilla_selector
+         self._factories["gorilla"] = self._create_gorilla_selector
+         # ToolLLM DFSDT: Depth-First Search Decision Tree (Qin et al., 2023)
+         self._factories["selector.dfsdt"] = self._create_dfsdt_selector
+         self._factories["selector.toolllm"] = self._create_dfsdt_selector  # Alias
+         self._factories["dfsdt"] = self._create_dfsdt_selector
+         self._factories["toolllm"] = self._create_dfsdt_selector  # Alias
+
+         # =================================================================
+         # Paper 1: Existing SOTA Planner Strategies
+         # =================================================================
+         self._factories["baseline.template"] = self._create_template_planner
+         self._factories["baseline.hierarchical"] = self._create_hierarchical_planner
+         self._factories["cot"] = self._create_hierarchical_planner
+         self._factories["baseline.sequence"] = self._create_sequence_planner
+         # Challenge 2 planner strategies
+         self._factories["planner.simple"] = self._create_simple_planner
+         self._factories["planner.hierarchical"] = self._create_hierarchical_planning_strategy
+         self._factories["planner.llm_based"] = self._create_llm_planning_strategy
+         # ReAct: Reasoning + Acting (Yao et al., 2023)
+         self._factories["planner.react"] = self._create_react_planner
+         self._factories["react"] = self._create_react_planner  # Alias
+         # Tree-of-Thoughts: Multi-path reasoning (Yao et al., 2023)
+         self._factories["planner.tot"] = self._create_tot_planner
+         self._factories["planner.tree_of_thoughts"] = self._create_tot_planner  # Alias
+
+         # =================================================================
+         # Paper 1: Existing SOTA Timing Strategies
+         # =================================================================
+         self._factories["baseline.threshold"] = self._create_threshold_decider
+         self._factories["llm_based"] = self._create_llm_timing_decider
+         # New timing strategies for benchmark
+         self._factories["timing.rule_based"] = self._create_rule_based_decider
+         self._factories["timing.llm_based"] = self._create_llm_timing_decider
+         self._factories["timing.hybrid"] = self._create_hybrid_timing_decider
+         self._factories["timing.embedding"] = self._create_embedding_timing_decider
+
+     def register(self, name: str, strategy: Any) -> None:
+         """
+         Register a strategy by name.
+
+         Args:
+             name: Strategy name (e.g., "my_selector")
+             strategy: Strategy instance or factory
+         """
+         if callable(strategy) and not hasattr(strategy, "predict"):
+             self._factories[name] = strategy
+         else:
+             # Store instance directly
+             if hasattr(strategy, "select"):
+                 self._selectors[name] = SelectorAdapter(strategy)
+             elif hasattr(strategy, "plan"):
+                 self._planners[name] = PlannerAdapter(strategy)
+             elif hasattr(strategy, "decide"):
+                 self._timing_deciders[name] = TimingAdapter(strategy)
+             else:
+                 # Factory must accept the optional `resources` argument passed by get()
+                 self._factories[name] = lambda resources=None: strategy
+
+     def get(self, name: str, resources: Optional[Any] = None) -> Any:
+         """
+         Get a strategy by name.
+
+         Args:
+             name: Strategy name
+             resources: Optional SelectorResources for initialization
+
+         Returns:
+             Strategy adapter instance
+
+         Raises:
+             ValueError: If strategy not found
+         """
+         # Check cached instances
+         if name in self._selectors:
+             return self._selectors[name]
+         if name in self._planners:
+             return self._planners[name]
+         if name in self._timing_deciders:
+             return self._timing_deciders[name]
+
+         # Try factory
+         if name in self._factories:
+             strategy = self._factories[name](resources)
+             return strategy
+
+         raise ValueError(f"Unknown strategy: {name}. Available: {self.list_strategies()}")
+
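+     # Illustrative lookup flow (assumes the default factories registered above):
+     #     registry = AdapterRegistry()
+     #     selector = registry.get("selector.keyword")  # built on demand via factory
+     #     planner = registry.get("planner.react")
+     #     registry.get("no.such.strategy")  # raises ValueError
+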
+     def list_strategies(self) -> list:
+         """List all registered strategy names."""
+         all_names = set(self._selectors.keys())
+         all_names.update(self._planners.keys())
+         all_names.update(self._timing_deciders.keys())
+         all_names.update(self._factories.keys())
+         return sorted(all_names)
+
+     # --- Factory methods for built-in strategies ---
+
+     def _create_keyword_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
+         """
+         Create keyword-based selector using BM25/TF-IDF.
+
+         This implementation uses dynamic indexing - it processes the candidate_tools
+         provided in each query, enabling cross-dataset evaluation.
+         """
+
+         class DynamicKeywordSelector:
+             """Dynamic keyword selector using BM25 on candidate_tools."""
+
+             def __init__(self):
+                 self.name = "keyword"
+
+             def select(self, query, top_k=5):
+                 """Select tools using BM25 keyword matching."""
+                 import math
+                 import re
+
+                 from sage.libs.agentic.agents.action.tool_selection.schemas import (
+                     ToolPrediction,
+                 )
+
+                 candidate_tools = getattr(query, "candidate_tools", []) or []
+                 if not candidate_tools:
+                     return []
+
+                 instruction = getattr(query, "instruction", str(query))
+
+                 # Tokenize query
+                 query_tokens = set(re.findall(r"[a-z0-9]+", instruction.lower()))
+                 if not query_tokens:
+                     return []
+
+                 # Build tool texts and tokenize
+                 tool_data = []
+                 for tool in candidate_tools:
+                     if isinstance(tool, str):
+                         tool_id = tool
+                         tool_text = tool.replace("_", " ")
+                     elif hasattr(tool, "name"):
+                         tool_id = getattr(tool, "tool_id", getattr(tool, "id", tool.name))
+                         tool_text = f"{tool.name} {getattr(tool, 'description', '')}"
+                     elif isinstance(tool, dict):
+                         tool_id = tool.get("tool_id", tool.get("id", tool.get("name", "")))
+                         tool_text = f"{tool.get('name', '')} {tool.get('description', '')}"
+                     else:
+                         continue
+                     tool_tokens = set(re.findall(r"[a-z0-9]+", tool_text.lower()))
+                     tool_data.append((tool_id, tool_tokens, len(tool_tokens)))
+
+                 if not tool_data:
+                     return []
+
+                 # Compute IDF
+                 num_docs = len(tool_data)
+                 doc_freq = {}
+                 for _, tokens, _ in tool_data:
+                     for token in tokens:
+                         doc_freq[token] = doc_freq.get(token, 0) + 1
+                 idf = {t: math.log((num_docs + 1) / (df + 1)) for t, df in doc_freq.items()}
+
+                 # BM25 scoring
+                 k1, b = 1.5, 0.75
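+                 # With binary term frequency (tf = 1), each matching term contributes
+                 #   IDF(t) * (k1 + 1) / (1 + k1 * (1 - b + b * doc_len / avg_dl))
+                 # k1 = 1.5 and b = 0.75 are the standard BM25 defaults.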
+                 avg_dl = sum(dl for _, _, dl in tool_data) / num_docs if num_docs else 1
+
+                 scores = []
+                 for tool_id, tool_tokens, doc_len in tool_data:
+                     score = 0.0
+                     for token in query_tokens:
+                         if token in tool_tokens:
+                             tf = 1  # Binary TF
+                             score += (
+                                 idf.get(token, 0)
+                                 * (tf * (k1 + 1))
+                                 / (tf + k1 * (1 - b + b * doc_len / avg_dl))
+                             )
+                     scores.append((tool_id, score))
+
+                 # Sort and return top-k
+                 scores.sort(key=lambda x: x[1], reverse=True)
+                 return [
+                     ToolPrediction(tool_id=tid, score=min(s / 10, 1.0)) for tid, s in scores[:top_k]
+                 ]
+
+         return SelectorAdapter(DynamicKeywordSelector())
+
+     def _create_embedding_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
+         """
+         Create embedding-based selector using cosine similarity.
+
+         This implementation uses dynamic indexing - it embeds the candidate_tools
+         provided in each query, enabling cross-dataset evaluation.
+         """
+
+         class DynamicEmbeddingSelector:
+             """Dynamic embedding selector on candidate_tools."""
+
+             def __init__(self):
+                 self.name = "embedding"
+                 self._embedding_client = None
+
+             def _init_client(self):
+                 if self._embedding_client is None:
+                     try:
+                         from sage.common.components.sage_embedding import (
+                             EmbeddingClientAdapter,
+                             EmbeddingFactory,
+                         )
+
+                         raw_embedder = EmbeddingFactory.create(
+                             "hf", model=BENCHMARK_EMBEDDING_MODEL
+                         )
+                         self._embedding_client = EmbeddingClientAdapter(raw_embedder)
+                     except Exception:
+                         pass
+
+             def select(self, query, top_k=5):
+                 """Select tools using embedding similarity."""
+                 import numpy as np
+                 from sage.libs.agentic.agents.action.tool_selection.schemas import (
+                     ToolPrediction,
+                 )
+
+                 self._init_client()
+
+                 candidate_tools = getattr(query, "candidate_tools", []) or []
+                 if not candidate_tools:
+                     return []
+
+                 instruction = getattr(query, "instruction", str(query))
+
+                 # Build tool texts
+                 tool_ids = []
+                 tool_texts = []
+                 for tool in candidate_tools:
+                     if isinstance(tool, str):
+                         tool_ids.append(tool)
+                         tool_texts.append(tool.replace("_", " "))
+                     elif hasattr(tool, "name"):
+                         tool_ids.append(getattr(tool, "tool_id", getattr(tool, "id", tool.name)))
+                         tool_texts.append(f"{tool.name}: {getattr(tool, 'description', '')}")
+                     elif isinstance(tool, dict):
+                         tool_ids.append(tool.get("tool_id", tool.get("id", tool.get("name", ""))))
+                         tool_texts.append(f"{tool.get('name', '')}: {tool.get('description', '')}")
+
+                 if not tool_texts or self._embedding_client is None:
+                     # Fallback to simple matching
+                     return [ToolPrediction(tool_id=tid, score=0.5) for tid in tool_ids[:top_k]]
+
+                 try:
+                     # Embed query and tools
+                     all_texts = [instruction] + tool_texts
+                     embeddings = self._embedding_client.embed(all_texts)
+
+                     query_emb = np.asarray(embeddings[0])
+                     tool_embs = np.asarray(embeddings[1:])
+
+                     # Cosine similarity
+                     query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-8)
+                     tool_norms = tool_embs / (
+                         np.linalg.norm(tool_embs, axis=1, keepdims=True) + 1e-8
+                     )
+                     scores = np.dot(tool_norms, query_norm)
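+                     # cos(q, t) = (q . t) / (|q||t|); the 1e-8 epsilon above guards
+                     # against division by zero for all-zero embeddings.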
+
+                     # Sort and return top-k
+                     top_indices = np.argsort(scores)[::-1][:top_k]
+                     return [
+                         ToolPrediction(tool_id=tool_ids[i], score=float(scores[i]))
+                         for i in top_indices
+                     ]
+                 except Exception:
+                     return [ToolPrediction(tool_id=tid, score=0.5) for tid in tool_ids[:top_k]]
+
+         return SelectorAdapter(DynamicEmbeddingSelector())
+
+     def _create_hybrid_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
+         """
+         Create hybrid selector combining keyword (BM25) and embedding similarity.
+
+         This implementation uses dynamic indexing - it processes the candidate_tools
+         provided in each query, enabling cross-dataset evaluation.
+         """
+
+         class DynamicHybridSelector:
+             """Dynamic hybrid selector: 40% keyword + 60% embedding."""
+
+             def __init__(self):
+                 self.name = "hybrid"
+                 self._embedding_client = None
+                 self._keyword_weight = 0.4
+                 self._embedding_weight = 0.6
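+                 # Final ranking in select() uses a fixed convex combination:
+                 #   score(tool) = 0.4 * normalized_bm25 + 0.6 * cosine_similarity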
+
+             def _init_client(self):
+                 if self._embedding_client is None:
+                     try:
+                         from sage.common.components.sage_embedding import (
+                             EmbeddingClientAdapter,
+                             EmbeddingFactory,
+                         )
+
+                         raw_embedder = EmbeddingFactory.create(
+                             "hf", model=BENCHMARK_EMBEDDING_MODEL
+                         )
+                         self._embedding_client = EmbeddingClientAdapter(raw_embedder)
+                     except Exception:
+                         pass
+
+             def select(self, query, top_k=5):
+                 """Select tools using hybrid scoring."""
+                 import math
+                 import re
+
+                 import numpy as np
+                 from sage.libs.agentic.agents.action.tool_selection.schemas import (
+                     ToolPrediction,
+                 )
+
+                 self._init_client()
+
+                 candidate_tools = getattr(query, "candidate_tools", []) or []
+                 if not candidate_tools:
+                     return []
+
+                 instruction = getattr(query, "instruction", str(query))
+
+                 # Build tool data
+                 tool_ids = []
+                 tool_texts = []
+                 for tool in candidate_tools:
+                     if isinstance(tool, str):
+                         tool_ids.append(tool)
+                         tool_texts.append(tool.replace("_", " "))
+                     elif hasattr(tool, "name"):
+                         tool_ids.append(getattr(tool, "tool_id", getattr(tool, "id", tool.name)))
+                         tool_texts.append(f"{tool.name}: {getattr(tool, 'description', '')}")
+                     elif isinstance(tool, dict):
+                         tool_ids.append(tool.get("tool_id", tool.get("id", tool.get("name", ""))))
+                         tool_texts.append(f"{tool.get('name', '')}: {tool.get('description', '')}")
+
+                 if not tool_ids:
+                     return []
+
+                 # === Keyword scores (BM25) ===
+                 query_tokens = set(re.findall(r"[a-z0-9]+", instruction.lower()))
+                 tool_tokens_list = [set(re.findall(r"[a-z0-9]+", t.lower())) for t in tool_texts]
+
+                 # IDF
+                 num_docs = len(tool_ids)
+                 doc_freq = {}
+                 for tokens in tool_tokens_list:
+                     for token in tokens:
+                         doc_freq[token] = doc_freq.get(token, 0) + 1
+                 idf = {t: math.log((num_docs + 1) / (df + 1)) for t, df in doc_freq.items()}
+
+                 # BM25
+                 k1, b = 1.5, 0.75
+                 avg_dl = sum(len(t) for t in tool_tokens_list) / num_docs if num_docs else 1
+
+                 keyword_scores = []
+                 for tool_tokens in tool_tokens_list:
+                     score = 0.0
+                     doc_len = len(tool_tokens)
+                     for token in query_tokens:
+                         if token in tool_tokens:
+                             tf = 1
+                             score += (
+                                 idf.get(token, 0)
+                                 * (tf * (k1 + 1))
+                                 / (tf + k1 * (1 - b + b * doc_len / avg_dl))
+                             )
+                     keyword_scores.append(score)
+
+                 # Normalize keyword scores
+                 max_kw = max(keyword_scores) if keyword_scores and max(keyword_scores) > 0 else 1
+                 keyword_scores = [s / max_kw for s in keyword_scores]
+
+                 # === Embedding scores ===
+                 embedding_scores = [0.0] * len(tool_ids)
+                 if self._embedding_client is not None:
+                     try:
+                         all_texts = [instruction] + tool_texts
+                         embeddings = self._embedding_client.embed(all_texts)
+
+                         query_emb = np.asarray(embeddings[0])
+                         tool_embs = np.asarray(embeddings[1:])
+
+                         query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-8)
+                         tool_norms = tool_embs / (
+                             np.linalg.norm(tool_embs, axis=1, keepdims=True) + 1e-8
+                         )
+                         embedding_scores = list(np.dot(tool_norms, query_norm))
+                     except Exception:
+                         pass
+
+                 # === Combine scores ===
+                 combined = [
+                     (
+                         tool_ids[i],
+                         self._keyword_weight * keyword_scores[i]
+                         + self._embedding_weight * embedding_scores[i],
+                     )
+                     for i in range(len(tool_ids))
+                 ]
+                 combined.sort(key=lambda x: x[1], reverse=True)
+
+                 return [
+                     ToolPrediction(tool_id=tid, score=min(s, 1.0)) for tid, s in combined[:top_k]
+                 ]
+
+         return SelectorAdapter(DynamicHybridSelector())
+
+     def _create_gorilla_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
+         """
+         Create Gorilla-style retrieval-augmented selector.
+
+         Gorilla uses a two-stage approach:
+         1. Embedding retrieval to find candidate tools
+         2. LLM selection from retrieved candidates
+
+         Reference: Patil et al. (2023) "Gorilla: Large Language Model Connected with Massive APIs"
+
+         This implementation uses dynamic indexing - it builds embeddings from the
+         candidate_tools provided in each query, enabling cross-dataset evaluation.
+         """
+
+         class DynamicGorillaSelector:
+             """Gorilla selector with dynamic tool indexing."""
+
+             def __init__(self):
+                 self._embedding_client = None
+                 self._llm_client = None
+                 self.name = "gorilla"
+
+             def _init_clients(self):
+                 """Lazy initialization of embedding and LLM clients."""
+                 if self._embedding_client is None:
+                     # Try local HuggingFace embedding first
+                     try:
+                         from sage.common.components.sage_embedding import (
+                             EmbeddingClientAdapter,
+                             EmbeddingFactory,
+                         )
+
+                         raw_embedder = EmbeddingFactory.create(
+                             "hf", model=BENCHMARK_EMBEDDING_MODEL
+                         )
+                         self._embedding_client = EmbeddingClientAdapter(raw_embedder)
+                     except Exception:
+                         pass
+
+                 if self._llm_client is None:
+                     try:
+                         from sage.llm import UnifiedInferenceClient
+
+                         self._llm_client = UnifiedInferenceClient.create()
+                     except Exception:
+                         pass
+
+             def select(self, query, top_k=5):
+                 """Select tools using embedding retrieval + LLM reranking."""
+                 import logging
+
+                 import numpy as np
+                 from sage.libs.agentic.agents.action.tool_selection.schemas import (
+                     ToolPrediction,
+                 )
+
+                 logger = logging.getLogger(__name__)
+                 self._init_clients()
+
+                 candidate_tools = getattr(query, "candidate_tools", []) or []
+                 if not candidate_tools:
+                     return []
+
+                 instruction = getattr(query, "instruction", str(query))
+
+                 # Parse candidate tools into (tool_id, tool_text) pairs
+                 tool_ids, tool_texts = [], []
+                 for tool in candidate_tools:
+                     if isinstance(tool, str):
+                         tool_ids.append(tool)
+                         tool_texts.append(tool)
+                     elif hasattr(tool, "name"):
+                         tid = getattr(tool, "tool_id", getattr(tool, "id", tool.name))
+                         tool_ids.append(tid)
+                         tool_texts.append(f"{tool.name}: {getattr(tool, 'description', '')}")
+                     elif isinstance(tool, dict):
+                         tid = tool.get("tool_id", tool.get("id", tool.get("name", "")))
+                         tool_ids.append(tid)
+                         tool_texts.append(f"{tool.get('name', tid)}: {tool.get('description', '')}")
+
+                 if not tool_ids:
+                     return []
+
+                 # Fallback if no embedding client
+                 if self._embedding_client is None:
+                     return [ToolPrediction(tool_id=tid, score=0.5) for tid in tool_ids[:top_k]]
+
+                 try:
+                     # Embed query and tools
+                     embeddings = self._embedding_client.embed([instruction] + tool_texts)
+                     query_emb = np.asarray(embeddings[0])
+                     tool_embs = np.asarray(embeddings[1:])
+
+                     # Cosine similarity
+                     query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-8)
+                     tool_norms = tool_embs / (
+                         np.linalg.norm(tool_embs, axis=1, keepdims=True) + 1e-8
+                     )
+                     scores = np.dot(tool_norms, query_norm)
+
+                     # Get top candidates for LLM reranking
+                     retrieve_k = min(15, len(tool_ids))
+                     top_indices = np.argsort(scores)[::-1][:retrieve_k]
+
+                     # LLM reranking if available
+                     if self._llm_client is not None:
+                         reranked = self._llm_rerank(
+                             instruction,
+                             [(tool_ids[i], tool_texts[i]) for i in top_indices],
+                             top_k,
+                         )
+                         if reranked:
+                             return [
+                                 ToolPrediction(tool_id=tid, score=1.0 - i * 0.1)
+                                 for i, tid in enumerate(reranked)
+                             ]
+
+                     # Fallback to embedding-only
+                     return [
+                         ToolPrediction(tool_id=tool_ids[i], score=float(scores[i]))
+                         for i in top_indices[:top_k]
+                     ]
+
+                 except Exception as e:
+                     logger.warning(f"Gorilla selector failed: {e}")
+                     return []
+
+             def _llm_rerank(self, query, tools, top_k):
+                 """Use LLM to rerank retrieved tools."""
+                 import json
+                 import re
+
+                 tools_text = "\n".join(
+                     f"{i + 1}. {tid}: {desc}" for i, (tid, desc) in enumerate(tools)
+                 )
+                 prompt = f"""Select the {top_k} most relevant tools for this task. Return ONLY a JSON array of tool IDs.
+
+ Task: {query}
+
+ Tools:
+ {tools_text}
+
+ Output (JSON array only):"""
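+                 # A well-formed reply is a flat JSON array of tool ids, e.g.
+                 #   ["weather_api", "geocode_lookup"]  (ids illustrative)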
+
+                 try:
+                     response = self._llm_client.chat(
+                         [{"role": "user", "content": prompt}], temperature=BENCHMARK_LLM_TEMPERATURE
+                     )
+                     response = response.strip()
+                     if response.startswith("```"):
+                         response = "\n".join(response.split("\n")[1:-1]).strip()
+                     match = re.search(r"\[.*?\]", response, re.DOTALL)
+                     if match:
+                         selected = json.loads(match.group())
+                         valid_ids = {tid for tid, _ in tools}
+                         return [tid for tid in selected if tid in valid_ids][:top_k]
+                 except Exception:
+                     pass
+                 return []
+
+         return SelectorAdapter(DynamicGorillaSelector())
+
+     def _create_dfsdt_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
+         """
+         Create DFSDT (Depth-First Search-based Decision Tree) selector.
+
+         Based on ToolLLM paper (Qin et al., 2023):
+         "ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs"
+
+         This implementation uses dynamic scoring - it evaluates the candidate_tools
+         provided in each query, enabling cross-dataset evaluation.
+         """
+
+         class DynamicDFSDTSelector:
+             """DFSDT selector with dynamic tool scoring."""
+
+             def __init__(self):
+                 self._llm_client = None
+                 self._score_threshold = 0.3
+                 self.name = "dfsdt"
+
+             def _init_client(self):
+                 """Lazy initialization of LLM client."""
+                 if self._llm_client is None:
+                     try:
+                         from sage.llm import UnifiedInferenceClient
+
+                         self._llm_client = UnifiedInferenceClient.create()
+                     except Exception:
+                         pass
+
+             def select(self, query, top_k=5):
+                 """Select tools using LLM-based scoring."""
+
+                 from sage.libs.agentic.agents.action.tool_selection.schemas import (
+                     ToolPrediction,
+                 )
+
+                 self._init_client()
+
+                 candidate_tools = getattr(query, "candidate_tools", []) or []
+                 if not candidate_tools:
+                     return []
+
+                 instruction = getattr(query, "instruction", str(query))
+
+                 # Parse and score each tool
+                 scored_tools = []
+                 for tool in candidate_tools:
+                     if isinstance(tool, str):
+                         tool_id, tool_name, tool_desc = tool, tool, ""
+                     elif hasattr(tool, "name"):
+                         tool_id = getattr(tool, "tool_id", getattr(tool, "id", tool.name))
+                         tool_name, tool_desc = tool.name, getattr(tool, "description", "")
+                     elif isinstance(tool, dict):
+                         tool_id = tool.get("tool_id", tool.get("id", tool.get("name", "")))
+                         tool_name = tool.get("name", tool_id)
+                         tool_desc = tool.get("description", "")
+                     else:
+                         continue
+
+                     score = self._score_tool(instruction, tool_name, tool_desc)
+                     if score >= self._score_threshold:
+                         scored_tools.append((tool_id, score))
+
+                 scored_tools.sort(key=lambda x: x[1], reverse=True)
+                 return [
+                     ToolPrediction(tool_id=tid, score=score) for tid, score in scored_tools[:top_k]
+                 ]
+
+             def _score_tool(self, query, tool_name, tool_desc):
+                 """Score tool relevance using LLM or keyword fallback."""
+                 import re
+
+                 if self._llm_client is not None:
+                     prompt = f"""Rate relevance (0-10): Query: {query} | Tool: {tool_name} - {tool_desc}
+ Output only a number:"""
+                     try:
+                         response = self._llm_client.chat(
+                             [{"role": "user", "content": prompt}],
+                             temperature=BENCHMARK_LLM_TEMPERATURE,
+                         )
+                         numbers = re.findall(r"(\d+(?:\.\d+)?)", response.strip())
+                         if numbers:
+                             return min(max(float(numbers[0]), 0.0), 10.0) / 10.0
+                     except Exception:
+                         pass
+
+                 # Keyword fallback
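+                 # e.g. query "plot stock prices" vs tool "stock_chart: plot stock
+                 # price charts" overlaps on {plot, stock} -> 2/3 ≈ 0.67 (names illustrative).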
+                 query_words = set(query.lower().split())
+                 tool_words = set(f"{tool_name} {tool_desc}".lower().split())
+                 if not query_words:
+                     return 0.0
+                 return min(len(query_words & tool_words) / len(query_words), 1.0)
+
+         return SelectorAdapter(DynamicDFSDTSelector())
+
+     def _create_inline_hybrid_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
+         """Create inline hybrid selector when library version unavailable."""
+         from sage.libs.agentic.agents.action.tool_selection import (
+             KeywordSelector,
+             KeywordSelectorConfig,
+             ToolPrediction,
+         )
+
+         class InlineHybridSelector:
+             """Inline hybrid selector for benchmark fallback."""
+
+             def __init__(self, resources):
+                 self.resources = resources
+                 # Create keyword selector as base
+                 config = KeywordSelectorConfig(
+                     name="keyword",
+                     method="bm25",
+                     top_k=10,
+                 )
+                 self._keyword_selector = KeywordSelector(config, resources)
+                 self.name = "hybrid"
+
+             def select(self, query, top_k=5):
+                 """Select using keyword with boosting."""
+                 # Use keyword selector
+                 results = self._keyword_selector.select(query, top_k=top_k * 2)
+
+                 # Simple score boosting based on query-tool match
+                 boosted = []
+                 query_lower = query.instruction.lower()
+
+                 for pred in results[:top_k]:
+                     # Check if tool matches query keywords better
+                     tool_text = self.resources.tools_loader.get_tool(pred.tool_id)
+                     tool_desc = getattr(tool_text, "description", "") or ""
+
+                     # Simple boost: increase score if query words in description
+                     words = query_lower.split()
+                     matches = sum(1 for w in words if w in tool_desc.lower())
+                     boost = 1.0 + (matches * 0.1)
+
+                     boosted.append(
+                         ToolPrediction(
+                             tool_id=pred.tool_id,
+                             score=min(pred.score * boost, 1.0),
+                             metadata={"method": "hybrid_inline", "boost": boost},
+                         )
+                     )
+
+                 boosted.sort(key=lambda x: x.score, reverse=True)
+
+                 # Optionally use LLM reranker via UnifiedInferenceClient.create()
+                 # Uses LOCAL-FIRST strategy: local services (SagePorts) -> cloud API fallback
+                 # Set SAGE_HYBRID_ENABLE_LLM_RERANK=1 to enable
+                 import os
+
+                 enable_llm_rerank = os.environ.get("SAGE_HYBRID_ENABLE_LLM_RERANK", "0") == "1"
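+                 # Example (shell; illustrative invocation) to opt in for one run:
+                 #   SAGE_HYBRID_ENABLE_LLM_RERANK=1 python -m sage.benchmark.benchmark_agent ...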
+
+                 if enable_llm_rerank and not hasattr(self, "_llm_client_checked"):
+                     self._llm_client_checked = True
+                     try:
+                         from sage.llm import UnifiedInferenceClient
+
+                         # Use singleton to avoid repeated model loading
+                         self._llm_client = UnifiedInferenceClient.get_instance(
+                             instance_key="benchmark_hybrid"
+                         )
+                     except Exception:
+                         self._llm_client = None
+
+                 if enable_llm_rerank and getattr(self, "_llm_client", None) is not None:
+                     try:
+                         reranked = self._llm_rerank(self._llm_client, query, boosted, top_k)
+                         if reranked:
+                             return reranked
+                     except Exception:
+                         pass  # Silently fall back to keyword-boosted results
+
+                 return boosted[:top_k]
+
+             def _llm_rerank(self, llm, query, boosted, top_k):
+                 """Use LLM to rerank candidates."""
+                 import json
+                 import re
+
+                 # Prepare prompt
+                 cand_infos = []
+                 for p in boosted[: max(10, top_k)]:
+                     try:
+                         t = self.resources.tools_loader.get_tool(p.tool_id)
+                         desc = getattr(t, "description", "") or ""
+                     except Exception:
+                         desc = ""
+                     cand_infos.append(f"{p.tool_id}: {desc[:100]}")
+
+                 messages = [
+                     {
+                         "role": "system",
+                         "content": "You are an assistant that ranks candidate tools by relevance. Return a JSON array of tool ids sorted most->least relevant. Only output the JSON array.",
+                     },
+                     {
+                         "role": "user",
+                         "content": f"Instruction: {query.instruction}\n\nCandidates:\n"
+                         + "\n".join(cand_infos),
+                     },
+                 ]
+
+                 # UnifiedInferenceClient.chat() returns string directly
+                 resp = llm.chat(messages)
+
+                 # Parse response
+                 txt = resp if isinstance(resp, str) else str(resp)
+                 m = re.search(r"\[.*\]", txt, re.S)
+                 if not m:
+                     return None
+
+                 try:
+                     ranked = json.loads(m.group(0))
+                 except Exception:
+                     return None
+
+                 if not isinstance(ranked, list):
+                     return None
+
+                 # Build final predictions
+                 id_to_pred = {p.tool_id: p for p in boosted}
+                 final_preds = []
+                 for tid in ranked:
+                     if tid in id_to_pred:
+                         final_preds.append(id_to_pred[tid])
+                 for p in boosted:
+                     if p.tool_id not in {fp.tool_id for fp in final_preds}:
+                         final_preds.append(p)
+
+                 return final_preds[:top_k]
+
+         if resources is None:
+             resources = self._create_mock_resources()
+
+         selector = InlineHybridSelector(resources)
+         return SelectorAdapter(selector)
+
+     def _create_template_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
+         """Create template-based planner."""
+
+         # Return a simple mock planner for now
+         class MockPlanner:
+             def plan(self, task):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 return PlanningPrediction(
+                     steps=[
+                         PlanStep(
+                             step_id=0,
+                             description="Execute task",
+                             tool_id=task.available_tools[0] if task.available_tools else "unknown",
+                             confidence=0.5,
+                         )
+                     ],
+                     tool_sequence=[task.available_tools[0] if task.available_tools else "unknown"],
+                 )
+
+         return PlannerAdapter(MockPlanner())
+
+     def _create_hierarchical_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
+         """Create hierarchical planner."""
+         try:
+             from sage.libs.agentic.agents.planning import HierarchicalPlanner
+
+             planner = HierarchicalPlanner()
+             return PlannerAdapter(planner)
+         except ImportError:
+             return self._create_template_planner(resources)
+
+     def _create_react_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
+         """
+         Create ReAct planner implementing Thought-Action-Observation loop.
+
+         ReAct (Reasoning + Acting) generates plans by interleaving:
+         1. Thought: Reasoning about current state
+         2. Action: Selecting tool to use
+         3. Observation: Expected result (predicted in planning mode)
+
+         Reference: "ReAct: Synergizing Reasoning and Acting in Language Models" (Yao et al., 2023)
+         """
+
+         class ReActPlannerWrapper:
+             """Wrapper for ReAct planner with benchmark-compatible interface."""
+
+             def __init__(self):
+                 self._planner = None
+                 self._llm_client = None
+                 self._initialized = False
+
+             def _ensure_initialized(self):
+                 """Lazy initialization of ReAct planner."""
+                 if self._initialized:
+                     return
+
+                 self._initialized = True
+
+                 try:
+                     from sage.libs.agentic.agents.planning import (
+                         ReActConfig,
+                         ReActPlanner,
+                     )
+
+                     # Try to get LLM client
+                     llm_client = self._get_llm_client()
+
+                     config = ReActConfig(
+                         min_steps=5,
+                         max_steps=10,
+                         max_iterations=12,
+                         temperature=BENCHMARK_LLM_TEMPERATURE,
+                     )
+
+                     self._planner = ReActPlanner(
+                         config=config,
+                         llm_client=llm_client,
+                     )
+                 except ImportError as e:
+                     import logging
+
+                     logging.warning(f"ReActPlanner import failed: {e}")
+                     self._planner = None
+
+             def _get_llm_client(self):
+                 """Get LLM client with local-first strategy."""
+                 # Use UnifiedInferenceClient.create() which implements local-first strategy
+                 try:
+                     from sage.llm import UnifiedInferenceClient
+
+                     return UnifiedInferenceClient.create()
+                 except Exception:
+                     pass
+
+                 return None
+
+             def plan(self, task):
+                 """Generate plan using ReAct strategy."""
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 self._ensure_initialized()
+
+                 instruction = getattr(task, "instruction", "") or ""
+                 available_tools = getattr(task, "available_tools", []) or []
+
+                 if not available_tools:
+                     return PlanningPrediction(steps=[], tool_sequence=[])
+
+                 # Try ReAct planner if available
+                 if self._planner is not None:
+                     try:
+                         from sage.libs.agentic.agents.planning import (
+                             PlanRequest,
+                             ToolMetadata,
+                         )
+
+                         # Convert available tools to ToolMetadata
+                         tools = [
+                             ToolMetadata(
+                                 tool_id=t,
+                                 name=t,
+                                 description=f"Tool: {t}",
+                                 category="general",
+                             )
+                             for t in available_tools
+                         ]
+
+                         request = PlanRequest(
+                             goal=instruction,
+                             tools=tools,
+                             constraints=[],
+                             min_steps=5,
+                             max_steps=10,
+                         )
+
+                         result = self._planner.plan(request)
+
+                         if result.success and result.steps:
+                             steps = []
+                             tool_sequence = []
+
+                             for i, step in enumerate(result.steps):
+                                 tool_id = step.tool_id or step.action
+                                 if tool_id in available_tools:
+                                     steps.append(
+                                         PlanStep(
+                                             step_id=i,
+                                             description=step.description or step.action,
+                                             tool_id=tool_id,
+                                             confidence=0.8,
+                                         )
+                                     )
+                                     tool_sequence.append(tool_id)
+
+                             if steps:
+                                 return PlanningPrediction(
+                                     steps=steps,
+                                     tool_sequence=tool_sequence,
+                                 )
+                     except Exception:
+                         pass
+
+                 # Fallback: use heuristic-based planning
+                 return self._fallback_plan(instruction, available_tools)
+
+             def _fallback_plan(self, instruction: str, available_tools: list[str]):
+                 """Heuristic-based fallback planning."""
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 steps = []
+                 tool_sequence = []
+                 instruction_lower = instruction.lower()
+
+                 # Score tools by relevance
+                 tool_scores = []
+                 for tool in available_tools:
+                     score = 0
+                     tool_lower = tool.lower()
+                     tool_words = set(tool_lower.replace("_", " ").split())
+                     instruction_words = set(instruction_lower.replace(",", " ").split())
+
+                     # Word overlap
+                     overlap = len(tool_words & instruction_words)
+                     score += overlap * 2
+
+                     # Action keyword matching
+                     if any(w in tool_lower for w in ["read", "get", "fetch", "load"]):
+                         if any(w in instruction_lower for w in ["read", "get", "load", "fetch"]):
+                             score += 2
+                     if any(w in tool_lower for w in ["write", "save", "send", "post"]):
+                         if any(w in instruction_lower for w in ["write", "save", "send", "post"]):
+                             score += 2
+                     if any(w in tool_lower for w in ["process", "transform", "convert"]):
+                         if any(w in instruction_lower for w in ["process", "convert", "transform"]):
+                             score += 2
+
+                     tool_scores.append((tool, score))
+
+                 tool_scores.sort(key=lambda x: x[1], reverse=True)
+
+                 # Select top tools
+                 selected = [t for t, s in tool_scores[:8] if s > 0]
+                 if len(selected) < 5:
+                     for t, _ in tool_scores:
+                         if t not in selected:
+                             selected.append(t)
+                             if len(selected) >= 5:
+                                 break
+
+                 for i, tool in enumerate(selected[:10]):
+                     steps.append(
+                         PlanStep(
+                             step_id=i,
+                             description=f"ReAct step {i + 1}: Use {tool}",
+                             tool_id=tool,
+                             confidence=0.6,
+                         )
+                     )
+                     tool_sequence.append(tool)
+
+                 return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)
+
+         return PlannerAdapter(ReActPlannerWrapper())
+
+     def _create_sequence_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
+         """Create sequence-based planner using selector for tool ordering."""
+
+         class SequencePlanner:
+             """Plan by selecting tools in sequence based on task steps."""
+
+             def __init__(self, selector_factory):
+                 self._selector_factory = selector_factory
+
+             def plan(self, task):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 # Use task instruction (not description) to select relevant tools
+                 instruction = (
+                     getattr(task, "instruction", "") or getattr(task, "description", "") or ""
+                 )
+                 steps = []
+                 tool_sequence = []
+
+                 # Parse task for steps (simple heuristic)
+                 sub_tasks = [s.strip() for s in instruction.split(".") if s.strip()]
+
+                 available = task.available_tools if task.available_tools else []
+
+                 # Match each sub-task to a tool
+                 for i, sub_task in enumerate(sub_tasks[:5]):  # Max 5 steps
+                     # Simple matching: pick tool with most keyword overlap
+                     best_tool = available[i % len(available)] if available else "unknown"
+
+                     steps.append(
+                         PlanStep(
+                             step_id=i,
+                             description=sub_task,
+                             tool_id=best_tool,
+                             confidence=0.6,
+                         )
+                     )
+                     tool_sequence.append(best_tool)
+
+                 if not steps and available:
+                     # Fallback: at least one step
+                     steps.append(
+                         PlanStep(
+                             step_id=0,
+                             description=instruction[:100] if instruction else "Execute task",
+                             tool_id=available[0],
+                             confidence=0.5,
+                         )
+                     )
+                     tool_sequence.append(available[0])
+
+                 return PlanningPrediction(
+                     steps=steps,
+                     tool_sequence=tool_sequence,
+                 )
+
+         return PlannerAdapter(SequencePlanner(self._create_keyword_selector))
+
+     def _create_threshold_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
+         """Create threshold-based timing decider."""
+
+         class ThresholdDecider:
+             """Simple keyword-based timing decider."""
+
+             # Keywords indicating tool invocation is needed
+             ACTION_KEYWORDS = frozenset(
+                 [
+                     "search",
+                     "find",
+                     "calculate",
+                     "analyze",
+                     "create",
+                     "update",
+                     "delete",
+                     "get",
+                     "fetch",
+                     "query",
+                     "look up",
+                     "current",
+                     "now",
+                     "today",
+                     "latest",
+                     "real-time",
+                     "weather",
+                     "time",
+                     "stock",
+                     "price",
+                     "news",
+                     "check",
+                     "verify",
+                     "compare",
+                     "list",
+                     "show",
+                 ]
+             )
+
+             # Keywords indicating direct answer (no tool needed)
+             FACTUAL_KEYWORDS = frozenset(
+                 [
+                     "what is",
+                     "how many",
+                     "define",
+                     "explain",
+                     "meaning of",
+                     "who was",
+                     "when was",
+                     "where is",
+                     "+ ",
+                     "multiply",
+                     "capital of",
+                     "population of",
+                     "history of",
+                 ]
+             )
+
+             def decide(self, message):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     TimingDecision,
+                 )
+
+                 text = message.message.lower()
+
+                 # Check for action keywords
+                 has_action = any(kw in text for kw in self.ACTION_KEYWORDS)
+
+                 # Check for factual (no-tool) keywords
+                 has_factual = any(kw in text for kw in self.FACTUAL_KEYWORDS)
+
+                 # Heuristic: action keywords > factual keywords
+                 # Real-time or current info needs tools
+                 should_call = has_action and not (
+                     has_factual
+                     and not any(
+                         kw in text
+                         for kw in [
+                             "current",
+                             "now",
+                             "today",
+                             "latest",
+                             "real-time",
+                             "weather",
+                             "time",
+                             "stock",
+                             "news",
+                         ]
+                     )
+                 )
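+                 # e.g. "what is the weather now": factual cue ("what is") is present,
+                 # but the real-time terms ("weather", "now") override it, so
+                 # should_call stays True.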
+
+                 return TimingDecision(
+                     should_call_tool=should_call,
+                     confidence=0.8 if (has_action or has_factual) else 0.5,
+                     reasoning=(
+                         "Detected action keywords"
+                         if should_call
+                         else "No action keywords or factual query"
+                     ),
+                 )
+
+         return TimingAdapter(ThresholdDecider())
+
+     def _create_llm_timing_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
+         """Create LLM-based timing decider using UnifiedInferenceClient.
+
+         Uses UnifiedInferenceClient.create() which handles:
+         1. Environment variables (SAGE_UNIFIED_BASE_URL)
+         2. Local services (ports from SagePorts)
+         3. Cloud API fallback (OpenAI-compatible)
+         """
+
+         class LLMTimingDecider:
+             """
+             LLM-based timing decider using UnifiedInferenceClient.
+             """
+
+             TIMING_PROMPT = """You are an AI assistant that determines whether a user's message requires tool invocation or can be answered directly from your knowledge.
+
+ Analyze the following user message and determine:
+ 1. Does this message require real-time information (weather, stock prices, current time, etc.)?
+ 2. Does this message require performing an action (search, calculate, create file, send email, etc.)?
+ 3. Does this message require accessing external data or APIs?
+
+ If ANY of the above is true, the user needs a tool call.
+ If the message is asking for factual knowledge, explanations, opinions, creative writing, or general conversation, it can be answered directly without tools.
+
+ User message: "{message}"
+
+ Respond in JSON format:
+ {{
+     "should_call_tool": true/false,
+     "confidence": 0.0-1.0,
+     "reasoning": "brief explanation"
+ }}
+
+ Only output the JSON, nothing else."""
1599
+
1600
+ def __init__(self):
1601
+ self._client = None
1602
+ self._initialized = False
1603
+
1604
+ def _ensure_client(self):
1605
+ """Lazy initialization using UnifiedInferenceClient."""
1606
+ if self._initialized:
1607
+ return self._client is not None
1608
+
1609
+ self._initialized = True
1610
+
1611
+ try:
1612
+ from sage.llm import UnifiedInferenceClient
1613
+
1614
+ self._client = UnifiedInferenceClient.create()
1615
+ return True
1616
+ except Exception as e:
1617
+ import logging
1618
+
1619
+ logging.getLogger(__name__).warning(
1620
+ f"Failed to initialize LLM client: {e}. Falling back to rule-based."
1621
+ )
1622
+ self._client = None
1623
+ return False
1624
+
1625
+ def decide(self, message):
1626
+ from sage.benchmark.benchmark_agent.experiments.base_experiment import (
1627
+ TimingDecision,
1628
+ )
1629
+
1630
+ # Ensure LLM client is initialized
1631
+ if not self._ensure_client():
1632
+ return self._fallback_decide(message)
1633
+
1634
+ try:
1635
+ import json
1636
+
1637
+ prompt = self.TIMING_PROMPT.format(message=message.message)
1638
+ messages = [{"role": "user", "content": prompt}]
1639
+
1640
+ # UnifiedInferenceClient.chat() returns string directly
1641
+ content = self._client.chat(messages)
1642
+
1643
+ # Parse JSON response
1644
+ # Handle potential markdown code blocks
1645
+ if content.startswith("```"):
1646
+ content = content.split("```")[1]
1647
+ if content.startswith("json"):
1648
+ content = content[4:]
1649
+ content = content.strip()
1650
+
1651
+ result = json.loads(content)
1652
+ return TimingDecision(
1653
+ should_call_tool=result.get("should_call_tool", False),
1654
+ confidence=float(result.get("confidence", 0.7)),
1655
+ reasoning=f"[LLM] {result.get('reasoning', 'LLM decision')}",
1656
+ )
1657
+ except Exception:
1658
+ return self._fallback_decide(message)
1659
+
1660
+ def _fallback_decide(self, message):
1661
+ """Fallback to simple rule-based for failed LLM calls."""
1662
+ from sage.benchmark.benchmark_agent.experiments.base_experiment import (
1663
+ TimingDecision,
1664
+ )
1665
+
1666
+ text = message.message.lower()
1667
+ action_keywords = [
1668
+ "search",
1669
+ "find",
1670
+ "calculate",
1671
+ "weather",
1672
+ "stock",
1673
+ "price",
1674
+ "current",
1675
+ "now",
1676
+ "today",
1677
+ "create",
1678
+ "send",
1679
+ "schedule",
1680
+ ]
1681
+ has_action = any(kw in text for kw in action_keywords)
1682
+
1683
+ return TimingDecision(
1684
+ should_call_tool=has_action,
1685
+ confidence=0.6,
1686
+ reasoning="[Fallback] Simple keyword match",
1687
+ )
1688
+
1689
+ return TimingAdapter(LLMTimingDecider())
1690
+
1691
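The parser above tolerates responses wrapped in markdown fences. A minimal standalone copy of that unwrapping step:

    import json

    def parse_llm_json(content: str) -> dict:
        """Strip an optional ```json fence before decoding, as decide() does above."""
        if content.startswith("```"):
            content = content.split("```")[1]
            if content.startswith("json"):
                content = content[4:]
        return json.loads(content.strip())

    raw = '```json\n{"should_call_tool": true, "confidence": 0.9, "reasoning": "real-time"}\n```'
    print(parse_llm_json(raw)["should_call_tool"])  # True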
+     def _create_rule_based_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
+         """Create rule-based timing decider with comprehensive keyword matching."""
+
+         class RuleBasedDecider:
+             """
+             Enhanced rule-based timing decider v2.
+
+             Uses multi-layer pattern matching and keyword analysis to determine
+             if a message requires tool invocation.
+
+             Improvements over v1:
+             - Added patterns for prediction/forecast queries (stock prices, future events)
+             - Added patterns for data lookup queries (calories, nutritional info)
+             - Added NO_TOOL patterns for advice/opinion questions
+             - Improved confidence scoring with weighted categories
+             - Better handling of edge cases (e.g., "should I" questions)
+             """
+
+             # Patterns that strongly indicate tool invocation is needed
+             TOOL_REQUIRED_PATTERNS = [
+                 # Real-time information
+                 r"\b(current|real-time|live|now|today|right now)\b.*\b(weather|temperature|stock|price|news|time|traffic)\b",
+                 r"\bwhat('s| is) the (current|latest)\b",
+                 r"\b(weather|temperature|forecast)\s+(in|for|at)\b",
+                 r"\b(stock|share) (price|value|chart)\b",
+                 # Time queries
+                 r"\bwhat('s| is) the\s+(current\s+)?time\s+(in|at)\b",
+                 r"\bwhat time is it in\b",
+                 # Prediction/Forecast queries (needs tool for data)
+                 r"\bwhat will\b.*\b(price|stock|weather|be)\b.*\b(next|tomorrow|week|month)\b",
+                 r"\bwill it\b.*\b(rain|snow|be sunny|be cold|be hot)\b",
+                 r"\b(forecast|predict|projection)\s+for\b",
+                 # Actions and operations
+                 r"\b(search|find|look up|fetch|retrieve|get|query)\s+(for|the|about)?\b",
+                 r"\b(calculate|compute|convert)\s+",
+                 r"\b(create|generate|make|build)\s+(a|an|the)?\s*(file|document|spreadsheet|chart|report)\b",
+                 r"\b(send|email|schedule|book|reserve|cancel)\s+",
+                 # File operations (enhanced patterns)
+                 r"\b(open|close|save|delete|rename|move|copy)\s+(the|a|this)?\s*(file|document|report|spreadsheet)\b",
+                 r"\bsave\s+(this|the|a)\s+(document|file|data)\b",
+                 r"\bsave\s+.*\s+as\b",  # "save X as Y"
+                 r"\bopen\s+(the|a)\s+file\b",
+                 r"\bdelete\s+(the|a)\s+file\b",
+                 # Code execution (enhanced patterns)
+                 r"\b(run|execute|compile|debug|test)\s+(this|the|a)?\s*(code|script|program|command|python|javascript)\b",
+                 r"\brun\s+this\s+\w+\s+code\b",  # "run this Python code"
+                 r"\bexecute\s+(this|the)\s+(script|code|program)\b",
+                 # Database/API operations
+                 r"\b(select|insert|update|delete)\s+.*\b(from|into|where)\b",
+                 r"\bapi\s+(call|request|endpoint)\b",
+                 # Data lookup queries (needs external database)
+                 r"\bhow many calories\b",
+                 r"\b(calories|carbs|protein|fat|nutrition)\s+(in|of)\b",
+                 r"\b(nutritional|nutrition)\s+(info|information|value|data)\b",
+                 r"\bexchange rate\s+(for|of|between)\b",
+                 r"\bconvert\s+\d+\s*\w+\s+to\b",
+             ]
+
+             # Patterns that indicate NO tool needed (advice, opinion, philosophical)
+             NO_TOOL_PATTERNS = [
+                 # Advice/Opinion questions
+                 r"\bshould i\b",
+                 r"\bwhat do you think\b",
+                 r"\bwhat('s| is) your opinion\b",
+                 r"\bdo you recommend\b",
+                 r"\bany (tips|advice|suggestions)\b",
+                 r"\bis it (a good|worth|better)\b",
+                 # Personal/Philosophical questions
+                 r"\bwhat('s| is) the meaning of life\b",
+                 r"\bwhy (do|should) (we|i|people)\b",
+                 r"\bhow (do|can) i (feel|cope|deal)\b",
+             ]
+
+             # Keywords indicating tool invocation (with weights)
+             TOOL_KEYWORDS = frozenset(
+                 [
+                     # Search/Retrieve
+                     "search", "find", "look up", "lookup", "fetch", "retrieve", "query",
+                     # Actions
+                     "calculate", "compute", "convert", "translate", "analyze",
+                     # CRUD operations
+                     "create", "update", "delete", "modify", "edit",
+                     # Real-time indicators
+                     "current", "live", "real-time", "realtime", "latest", "now",
+                     "today", "right now",
+                     # Specific domains requiring tools
+                     "weather", "stock", "price", "exchange rate", "traffic",
+                     "flight", "news", "calories", "nutritional",
+                     # File/System operations
+                     "open", "save", "download", "upload", "export", "import",
+                     # Scheduling
+                     "schedule", "book", "reserve", "remind", "alarm",
+                     # Communication
+                     "send", "email", "message", "notify", "call",
+                     # Code execution
+                     "run", "execute", "compile", "debug",
+                 ]
+             )
+
+             # High-weight keywords that strongly suggest tool usage
+             HIGH_WEIGHT_TOOL_KEYWORDS = frozenset(
+                 [
+                     "search", "calculate", "weather", "stock", "price", "calories",
+                     "execute", "run code", "compile", "exchange rate", "schedule",
+                     "book", "send email",
+                 ]
+             )
+
+             # Keywords indicating NO tool needed (direct answer)
+             NO_TOOL_KEYWORDS = frozenset(
+                 [
+                     # Definitions
+                     "what is", "what are", "define", "definition", "meaning of",
+                     "explain", "describe",
+                     # Factual (static knowledge)
+                     "who was", "who is", "who invented", "who wrote", "when was",
+                     "when did", "where is", "where was", "capital of",
+                     "population of", "history of",
+                     # Knowledge/Educational
+                     "how does", "how do", "why do", "why does", "what causes",
+                     # Conversational
+                     "hello", "hi", "thanks", "thank you", "goodbye", "bye",
+                     # Opinion/Advice (important: these do NOT need tools)
+                     "what do you think", "your opinion", "any tips", "advice",
+                     "suggest", "recommend", "should i", "is it worth",
+                     "is it a good idea", "pros and cons",
+                     # Creative writing
+                     "write a", "compose", "create a story", "create a poem",
+                     "tell me a story",
+                     # Math/Science knowledge (not calculations)
+                     "pythagorean theorem", "quadratic formula", "what is pi",
+                 ]
+             )
+
+             # High-weight keywords that strongly suggest NO tool
+             HIGH_WEIGHT_NO_TOOL_KEYWORDS = frozenset(
+                 [
+                     "should i", "what do you think", "your opinion", "recommend",
+                     "advice", "capital of", "who invented", "who wrote",
+                     "explain", "define",
+                 ]
+             )
+
+             def __init__(self):
+                 import re
+
+                 self._compiled_tool_patterns = [
+                     re.compile(p, re.IGNORECASE) for p in self.TOOL_REQUIRED_PATTERNS
+                 ]
+                 self._compiled_no_tool_patterns = [
+                     re.compile(p, re.IGNORECASE) for p in self.NO_TOOL_PATTERNS
+                 ]
+
+             def decide(self, message):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     TimingDecision,
+                 )
+
+                 text = message.message.lower()
+
+                 # Priority 1: Check for strong NO-TOOL patterns first (advice/opinion)
+                 # These override tool patterns because the question is fundamentally
+                 # asking for advice, not data retrieval
+                 no_tool_pattern_match = any(p.search(text) for p in self._compiled_no_tool_patterns)
+                 if no_tool_pattern_match:
+                     # Even if there are tool-like keywords, the core question is advice
+                     return TimingDecision(
+                         should_call_tool=False,
+                         confidence=0.9,
+                         reasoning="Strong pattern match for advice/opinion question (no tool needed)",
+                     )
+
+                 # Priority 2: Check for strong TOOL patterns
+                 tool_pattern_match = any(p.search(text) for p in self._compiled_tool_patterns)
+                 if tool_pattern_match:
+                     return TimingDecision(
+                         should_call_tool=True,
+                         confidence=0.95,
+                         reasoning="Strong pattern match for tool invocation",
+                     )
+
+                 # Priority 3: Weighted keyword analysis
+                 # Calculate weighted scores
+                 tool_score = sum(1 for kw in self.TOOL_KEYWORDS if kw in text)
+                 high_tool_score = sum(2 for kw in self.HIGH_WEIGHT_TOOL_KEYWORDS if kw in text)
+                 total_tool_score = tool_score + high_tool_score
+
+                 no_tool_score = sum(1 for kw in self.NO_TOOL_KEYWORDS if kw in text)
+                 high_no_tool_score = sum(
+                     2 for kw in self.HIGH_WEIGHT_NO_TOOL_KEYWORDS if kw in text
+                 )
+                 total_no_tool_score = no_tool_score + high_no_tool_score
+
+                 # Decision based on weighted scores
+                 if total_tool_score > 0 and total_no_tool_score == 0:
+                     confidence = min(0.7 + total_tool_score * 0.08, 0.95)
+                     return TimingDecision(
+                         should_call_tool=True,
+                         confidence=confidence,
+                         reasoning=f"Tool keywords detected (score: {total_tool_score})",
+                     )
+                 elif total_no_tool_score > 0 and total_tool_score == 0:
+                     confidence = min(0.7 + total_no_tool_score * 0.08, 0.95)
+                     return TimingDecision(
+                         should_call_tool=False,
+                         confidence=confidence,
+                         reasoning=f"No-tool keywords detected (score: {total_no_tool_score})",
+                     )
+                 elif total_tool_score > total_no_tool_score:
+                     score_diff = total_tool_score - total_no_tool_score
+                     confidence = min(0.55 + score_diff * 0.08, 0.85)
+                     return TimingDecision(
+                         should_call_tool=True,
+                         confidence=confidence,
+                         reasoning=f"More tool keywords ({total_tool_score} vs {total_no_tool_score})",
+                     )
+                 elif total_no_tool_score > total_tool_score:
+                     score_diff = total_no_tool_score - total_tool_score
+                     confidence = min(0.55 + score_diff * 0.08, 0.85)
+                     return TimingDecision(
+                         should_call_tool=False,
+                         confidence=confidence,
+                         reasoning=f"More no-tool keywords ({total_no_tool_score} vs {total_tool_score})",
+                     )
+                 else:
+                     # Ambiguous case: use heuristics
+                     # Check for question words that might indicate knowledge queries
+                     knowledge_indicators = ["what is", "who is", "where is", "when was"]
+                     is_knowledge_query = any(ind in text for ind in knowledge_indicators)
+
+                     if is_knowledge_query:
+                         return TimingDecision(
+                             should_call_tool=False,
+                             confidence=0.55,
+                             reasoning="Ambiguous - appears to be knowledge query, defaulting to no tool",
+                         )
+
+                     # Default: assume no tool needed for truly ambiguous cases
+                     return TimingDecision(
+                         should_call_tool=False,
+                         confidence=0.5,
+                         reasoning="Ambiguous - defaulting to no tool",
+                     )
+
+         return TimingAdapter(RuleBasedDecider())
+
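The Priority 3 tier above is a plain weighted count: every keyword hit scores 1 and every high-weight hit scores an extra 2, on each side. A standalone sketch with a trimmed keyword table:

    TOOL_KW = {"search", "weather", "calculate"}
    HIGH_TOOL_KW = {"weather", "calculate"}      # counted again at weight 2
    NO_TOOL_KW = {"explain", "capital of"}
    HIGH_NO_TOOL_KW = {"explain"}

    def scores(text: str) -> tuple[int, int]:
        t = text.lower()
        tool = sum(1 for k in TOOL_KW if k in t) + sum(2 for k in HIGH_TOOL_KW if k in t)
        no_tool = sum(1 for k in NO_TOOL_KW if k in t) + sum(2 for k in HIGH_NO_TOOL_KW if k in t)
        return tool, no_tool

    print(scores("calculate the tip"))        # (3, 0) -> tool, confidence min(0.7 + 3 * 0.08, 0.95)
    print(scores("explain recursion to me"))  # (0, 3) -> no tool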
+     def _create_hybrid_timing_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
+         """Create hybrid timing decider combining rule-based and LLM-based approaches."""
+
+         class HybridDecider:
+             """
+             Hybrid timing decider.
+
+             Uses rule-based detection first for high-confidence cases,
+             falls back to LLM for ambiguous cases (if available).
+             """
+
+             def __init__(self, rule_decider, llm_decider=None, confidence_threshold=0.7):
+                 self._rule_decider = rule_decider
+                 self._llm_decider = llm_decider
+                 self._threshold = confidence_threshold
+
+             def decide(self, message):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     TimingDecision,
+                 )
+
+                 # First try rule-based
+                 rule_decision = self._rule_decider.decide(message)
+
+                 # If high confidence, return rule-based result
+                 if rule_decision.confidence >= self._threshold:
+                     return TimingDecision(
+                         should_call_tool=rule_decision.should_call_tool,
+                         confidence=rule_decision.confidence,
+                         reasoning=f"[Rule-based] {rule_decision.reasoning}",
+                     )
+
+                 # For low confidence, try LLM if available
+                 if self._llm_decider is not None:
+                     try:
+                         llm_decision = self._llm_decider.decide(message)
+                         # Combine results
+                         if rule_decision.should_call_tool == llm_decision.should_call_tool:
+                             # Agreement - higher confidence
+                             combined_conf = min(
+                                 (rule_decision.confidence + llm_decision.confidence) / 2 + 0.1, 1.0
+                             )
+                             return TimingDecision(
+                                 should_call_tool=rule_decision.should_call_tool,
+                                 confidence=combined_conf,
+                                 reasoning=f"[Hybrid-agree] Rule: {rule_decision.reasoning}, LLM: {llm_decision.reasoning}",
+                             )
+                         else:
+                             # Disagreement - prefer higher confidence
+                             if llm_decision.confidence > rule_decision.confidence:
+                                 return TimingDecision(
+                                     should_call_tool=llm_decision.should_call_tool,
+                                     confidence=llm_decision.confidence * 0.9,
+                                     reasoning=f"[Hybrid-LLM] {llm_decision.reasoning}",
+                                 )
+                             else:
+                                 return TimingDecision(
+                                     should_call_tool=rule_decision.should_call_tool,
+                                     confidence=rule_decision.confidence * 0.9,
+                                     reasoning=f"[Hybrid-Rule] {rule_decision.reasoning}",
+                                 )
+                     except Exception:
+                         pass  # Fall back to rule-based
+
+                 # Return rule-based result
+                 return TimingDecision(
+                     should_call_tool=rule_decision.should_call_tool,
+                     confidence=rule_decision.confidence,
+                     reasoning=f"[Rule-fallback] {rule_decision.reasoning}",
+                 )
+
+         # Create rule-based decider
+         rule_adapter = self._create_rule_based_decider(resources)
+         rule_decider = rule_adapter.decider
+
+         # Try to create LLM decider
+         llm_decider = None
+         try:
+             llm_adapter = self._create_llm_timing_decider(resources)
+             llm_decider = llm_adapter.decider
+         except Exception:
+             pass
+
+         return TimingAdapter(HybridDecider(rule_decider, llm_decider))
+
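The combination logic above reduces to a small function over (should_call, confidence) pairs: a confident rule answer short-circuits, agreement boosts confidence, and disagreement defers to the more confident side with a 0.9 discount. A hedged sketch:

    def combine(rule, llm, threshold=0.7):
        """rule/llm are (should_call, confidence) pairs; llm may be None."""
        if rule[1] >= threshold or llm is None:
            return rule
        if rule[0] == llm[0]:                        # agreement: boost confidence
            return rule[0], min((rule[1] + llm[1]) / 2 + 0.1, 1.0)
        winner = llm if llm[1] > rule[1] else rule   # disagreement: trust the more confident
        return winner[0], winner[1] * 0.9

    print(combine((True, 0.9), (False, 0.8)))  # (True, 0.9): confident rule wins outright
    print(combine((True, 0.6), (True, 0.7)))   # (True, 0.75): agreement bonus
    print(combine((True, 0.6), (False, 0.8)))  # (False, 0.72): LLM more confident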
+     def _create_embedding_timing_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
+         """Create embedding-based timing decider using SAGE's EmbeddingService."""
+
+         class EmbeddingTimingDecider:
+             """
+             Embedding-based timing decider using semantic similarity.
+
+             Uses pre-computed embeddings of typical "tool-needed" and "no-tool-needed"
+             messages to classify new queries via cosine similarity.
+             """
+
+             # Representative examples for each class
+             TOOL_NEEDED_EXAMPLES = [
+                 "What's the weather like in New York right now?",
+                 "Search for the latest news about AI",
+                 "Calculate the compound interest on $10000",
+                 "What's the current stock price of AAPL?",
+                 "Send an email to my team",
+                 "Create a new spreadsheet with sales data",
+                 "What time is it in Tokyo?",
+                 "Book a flight from NYC to London",
+                 "Find restaurants near me",
+                 "Download the latest report",
+             ]
+
+             NO_TOOL_EXAMPLES = [
+                 "What is the capital of France?",
+                 "Explain quantum computing to me",
+                 "Who invented the telephone?",
+                 "What does photosynthesis mean?",
+                 "Write a poem about nature",
+                 "What are the pros and cons of remote work?",
+                 "How does machine learning work?",
+                 "Tell me a story about a brave knight",
+                 "What is 2 + 2?",
+                 "Thank you for your help!",
+             ]
+
+             def __init__(self):
+                 self._embedder = None
+                 self._tool_needed_embeddings = None
+                 self._no_tool_embeddings = None
+                 self._initialized = False
+
+             def _ensure_initialized(self):
+                 """Lazy initialization of embedder and example embeddings."""
+                 if self._initialized:
+                     return self._embedder is not None
+
+                 self._initialized = True
+                 try:
+                     import os
+
+                     from sage.common.components.sage_embedding import (
+                         get_embedding_model,
+                     )
+
+                     # Choose embedding method based on environment
+                     method = os.getenv("SAGE_EMBEDDING_METHOD", "hash")
+
+                     if method == "hf":
+                         try:
+                             self._embedder = get_embedding_model(
+                                 "hf", model=BENCHMARK_EMBEDDING_MODEL
+                             )
+                         except Exception:
+                             self._embedder = get_embedding_model("hash", dim=384)
+                     else:
+                         self._embedder = get_embedding_model("hash", dim=384)
+
+                     # Pre-compute example embeddings
+                     self._tool_needed_embeddings = [
+                         self._embedder.embed(ex) for ex in self.TOOL_NEEDED_EXAMPLES
+                     ]
+                     self._no_tool_embeddings = [
+                         self._embedder.embed(ex) for ex in self.NO_TOOL_EXAMPLES
+                     ]
+                     return True
+                 except Exception as e:
+                     import logging
+
+                     logging.getLogger(__name__).warning(
+                         f"Failed to initialize embedding decider: {e}"
+                     )
+                     self._embedder = None
+                     return False
+
+             def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
+                 """Compute cosine similarity between two vectors."""
+                 import math
+
+                 dot = sum(a * b for a, b in zip(vec1, vec2))
+                 norm1 = math.sqrt(sum(a * a for a in vec1))
+                 norm2 = math.sqrt(sum(b * b for b in vec2))
+                 if norm1 == 0 or norm2 == 0:
+                     return 0.0
+                 return dot / (norm1 * norm2)
+
+             def decide(self, message):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     TimingDecision,
+                 )
+
+                 if not self._ensure_initialized() or self._embedder is None:
+                     # Fallback to simple rule-based
+                     text = message.message.lower()
+                     action_kws = ["search", "find", "weather", "stock", "current", "now"]
+                     has_action = any(kw in text for kw in action_kws)
+                     return TimingDecision(
+                         should_call_tool=has_action,
+                         confidence=0.5,
+                         reasoning="[Fallback] Simple keyword match",
+                     )
+
+                 try:
+                     # Embed the query
+                     query_embedding = self._embedder.embed(message.message)
+
+                     # Compute average similarity to each class
+                     tool_sims = [
+                         self._cosine_similarity(query_embedding, ex_emb)
+                         for ex_emb in self._tool_needed_embeddings
+                     ]
+                     no_tool_sims = [
+                         self._cosine_similarity(query_embedding, ex_emb)
+                         for ex_emb in self._no_tool_embeddings
+                     ]
+
+                     avg_tool_sim = sum(tool_sims) / len(tool_sims)
+                     avg_no_tool_sim = sum(no_tool_sims) / len(no_tool_sims)
+
+                     # Decision based on higher average similarity
+                     should_call = avg_tool_sim > avg_no_tool_sim
+                     confidence = abs(avg_tool_sim - avg_no_tool_sim) + 0.5
+                     confidence = min(max(confidence, 0.5), 0.95)
+
+                     return TimingDecision(
+                         should_call_tool=should_call,
+                         confidence=confidence,
+                         reasoning=f"[Embedding] tool_sim={avg_tool_sim:.3f}, no_tool_sim={avg_no_tool_sim:.3f}",
+                     )
+                 except Exception as e:
+                     return TimingDecision(
+                         should_call_tool=False,
+                         confidence=0.5,
+                         reasoning=f"[Embedding-error] {str(e)[:50]}",
+                     )
+
+         return TimingAdapter(EmbeddingTimingDecider())
+
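Classification above is nearest-class-by-average-cosine: the query embedding is compared against every exemplar of each class, the higher mean similarity wins, and the margin drives confidence. A toy 2-D sketch of that decision:

    import math

    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        na, nb = math.sqrt(sum(x * x for x in a)), math.sqrt(sum(y * y for y in b))
        return dot / (na * nb) if na and nb else 0.0

    tool_ex = [[1.0, 0.1], [0.9, 0.2]]   # toy embeddings of tool-needing queries
    chat_ex = [[0.1, 1.0], [0.2, 0.9]]   # toy embeddings of direct-answer queries
    query = [0.8, 0.3]

    avg_tool = sum(cos(query, e) for e in tool_ex) / len(tool_ex)
    avg_chat = sum(cos(query, e) for e in chat_ex) / len(chat_ex)
    confidence = min(max(abs(avg_tool - avg_chat) + 0.5, 0.5), 0.95)
    print(avg_tool > avg_chat, round(confidence, 2))  # True 0.95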
+     def _create_mock_resources(self) -> Any:
+         """Create mock resources for testing."""
+         from sage.libs.agentic.agents.action.tool_selection import SelectorResources
+
+         class MockToolsLoader:
+             """Mock tools loader for testing."""
+
+             def iter_all(self):
+                 """Yield mock tools."""
+
+                 class MockTool:
+                     def __init__(self, tool_id, name, description, category):
+                         self.tool_id = tool_id
+                         self.name = name
+                         self.description = description
+                         self.category = category
+
+                 # Generate some mock tools
+                 categories = ["search", "calculate", "data", "communication"]
+                 for i in range(50):
+                     cat = categories[i % len(categories)]
+                     yield MockTool(
+                         tool_id=f"tool_{i:03d}",
+                         name=f"{cat}_tool_{i}",
+                         description=f"A tool for {cat} operations, variant {i}",
+                         category=cat,
+                     )
+
+         return SelectorResources(
+             tools_loader=MockToolsLoader(),
+             embedding_client=None,
+         )
+
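The mock loader above fabricates 50 tools by cycling four categories, which is the shape selectors expect from iter_all(). A standalone re-creation of that catalogue as plain dicts:

    categories = ["search", "calculate", "data", "communication"]
    tools = [
        {"tool_id": f"tool_{i:03d}", "name": f"{categories[i % 4]}_tool_{i}",
         "category": categories[i % 4]}
        for i in range(50)
    ]
    print(tools[0]["name"], tools[5]["name"])  # search_tool_0 calculate_tool_5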
+     def _create_sage_resources(self, embedding_client: Optional[Any] = None) -> Any:
+         """
+         Create SelectorResources with real SAGE-Bench tools.
+
+         This loads the 1,200 tools from tool_catalog.jsonl for use with
+         Gorilla and DFSDT selectors that need to build a proper tool index.
+
+         Args:
+             embedding_client: Optional embedding client for semantic search
+
+         Returns:
+             SelectorResources with SageToolsLoader
+         """
+         from sage.libs.agentic.agents.action.tool_selection import SelectorResources
+
+         try:
+             from sage.benchmark.benchmark_agent.tools_loader import (
+                 get_sage_tools_loader,
+             )
+
+             tools_loader = get_sage_tools_loader()
+             self.logger.info(f"Using SAGE tools loader with {len(tools_loader)} tools")
+         except Exception as e:
+             self.logger.warning(f"Failed to load SAGE tools: {e}. Falling back to mock tools.")
+             return self._create_mock_resources()
+
+         return SelectorResources(
+             tools_loader=tools_loader,
+             embedding_client=embedding_client,
+         )
+
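The load-or-degrade pattern above (real catalogue first, mock fallback on any failure) generalizes to a few lines; the loader names here are hypothetical:

    import logging

    def make_resources(load_real, make_mock):
        """Try the real loader; on any failure, log and fall back to the mock."""
        try:
            return load_real()
        except Exception as e:
            logging.getLogger("resources").warning("Falling back to mock tools: %s", e)
            return make_mock()

    def broken_loader():
        raise FileNotFoundError("tool_catalog.jsonl missing")

    print(make_resources(broken_loader, lambda: "mock-resources"))  # mock-resources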
+     def _create_simple_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
+         """
+         Create simple planner using embedding-based tool matching.
+
+         Uses EmbeddingFactory for semantic similarity matching.
+         """
+
+         class EmbeddingBasedPlanner:
+             """Planner using embedding similarity for tool selection."""
+
+             def __init__(self):
+                 self._embedder = None
+                 self._tool_embeddings_cache: dict[str, list[float]] = {}
+
+             def _get_embedder(self):
+                 """Lazy initialization of embedder."""
+                 if self._embedder is None:
+                     try:
+                         from sage.common.components.sage_embedding import (
+                             get_embedding_model,
+                         )
+
+                         # Try to use HF model, fallback to hash
+                         try:
+                             self._embedder = get_embedding_model(
+                                 "hf", model=BENCHMARK_EMBEDDING_MODEL
+                             )
+                         except Exception:
+                             self._embedder = get_embedding_model("hash", dim=384)
+                     except Exception:
+                         pass
+                 return self._embedder
+
+             def _compute_similarity(self, vec1: list[float], vec2: list[float]) -> float:
+                 """Compute cosine similarity."""
+                 import math
+
+                 dot = sum(a * b for a, b in zip(vec1, vec2))
+                 norm1 = math.sqrt(sum(a * a for a in vec1))
+                 norm2 = math.sqrt(sum(b * b for b in vec2))
+                 if norm1 == 0 or norm2 == 0:
+                     return 0.0
+                 return dot / (norm1 * norm2)
+
+             def plan(self, task):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 instruction = getattr(task, "instruction", "") or ""
+                 available_tools = getattr(task, "available_tools", []) or []
+
+                 if not available_tools:
+                     return PlanningPrediction(steps=[], tool_sequence=[])
+
+                 embedder = self._get_embedder()
+
+                 # Decompose instruction into sub-tasks
+                 sub_tasks = self._decompose_instruction(instruction)
+
+                 steps = []
+                 tool_sequence = []
+                 used_tools = set()
+
+                 for i, sub_task in enumerate(sub_tasks):
+                     best_tool = self._match_tool_semantic(
+                         sub_task, available_tools, used_tools, embedder
+                     )
+                     if best_tool:
+                         steps.append(
+                             PlanStep(
+                                 step_id=i,
+                                 description=sub_task,
+                                 tool_id=best_tool,
+                                 confidence=0.7,
+                             )
+                         )
+                         tool_sequence.append(best_tool)
+                         used_tools.add(best_tool)
+
+                 return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)
+
+             def _decompose_instruction(self, instruction: str) -> list[str]:
+                 """Decompose instruction into sub-tasks."""
+                 delimiters = [", and ", " and ", ", then ", ", ", ". ", "; "]
+                 sub_tasks = [instruction]
+
+                 for delim in delimiters:
+                     new_tasks = []
+                     for task_item in sub_tasks:
+                         parts = task_item.split(delim)
+                         new_tasks.extend([p.strip() for p in parts if p.strip()])
+                     sub_tasks = new_tasks
+
+                 sub_tasks = [t for t in sub_tasks if len(t) > 5]
+                 return sub_tasks[:10] if sub_tasks else [instruction]
+
+             def _match_tool_semantic(
+                 self,
+                 sub_task: str,
+                 available_tools: list[str],
+                 used_tools: set[str],
+                 embedder,
+             ) -> str | None:
+                 """Match sub-task to tool using semantic similarity."""
+                 if embedder is None:
+                     # Fallback to keyword matching
+                     return self._match_tool_keyword(sub_task, available_tools, used_tools)
+
+                 try:
+                     # Get sub-task embedding
+                     task_vec = embedder.embed(sub_task)
+
+                     best_tool = None
+                     best_score = -1.0
+
+                     for tool in available_tools:
+                         # Get tool embedding (with cache)
+                         if tool not in self._tool_embeddings_cache:
+                             # Create description from tool name
+                             tool_desc = tool.replace("_", " ")
+                             self._tool_embeddings_cache[tool] = embedder.embed(tool_desc)
+
+                         tool_vec = self._tool_embeddings_cache[tool]
+                         similarity = self._compute_similarity(task_vec, tool_vec)
+
+                         # Penalty for already used tools
+                         if tool in used_tools:
+                             similarity *= 0.5
+
+                         if similarity > best_score:
+                             best_score = similarity
+                             best_tool = tool
+
+                     return best_tool
+                 except Exception:
+                     return self._match_tool_keyword(sub_task, available_tools, used_tools)
+
+             def _match_tool_keyword(
+                 self, sub_task: str, available_tools: list[str], used_tools: set[str]
+             ) -> str | None:
+                 """Fallback keyword matching."""
+                 sub_task_lower = sub_task.lower()
+                 best_tool = None
+                 best_score: float = 0.0
+
+                 for tool in available_tools:
+                     tool_lower = tool.lower()
+                     score: float = 0.0
+
+                     for part in tool_lower.split("_"):
+                         if part in sub_task_lower and len(part) > 2:
+                             score += 2
+
+                     if tool in used_tools:
+                         score *= 0.5
+
+                     if score > best_score:
+                         best_score = score
+                         best_tool = tool
+
+                 if best_tool is None:
+                     for tool in available_tools:
+                         if tool not in used_tools:
+                             return tool
+
+                 return best_tool
+
+         return PlannerAdapter(EmbeddingBasedPlanner())
+
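The delimiter cascade in _decompose_instruction applies each separator in turn to every fragment produced so far, then drops fragments of five characters or fewer. A standalone copy showing the effect:

    def decompose(instruction: str) -> list[str]:
        sub_tasks = [instruction]
        for delim in [", and ", " and ", ", then ", ", ", ". ", "; "]:
            sub_tasks = [p.strip() for t in sub_tasks for p in t.split(delim) if p.strip()]
        sub_tasks = [t for t in sub_tasks if len(t) > 5]
        return sub_tasks[:10] if sub_tasks else [instruction]

    print(decompose("Read the config file, then parse the JSON and send a summary email"))
    # ['Read the config file', 'parse the JSON', 'send a summary email']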
+     def _create_hierarchical_planning_strategy(
+         self, resources: Optional[Any] = None
+     ) -> PlannerAdapter:
+         """
+         Create hierarchical planner for Challenge 2.
+
+         Uses task decomposition and dependency analysis.
+         """
+
+         class HierarchicalPlanningStrategy:
+             """Hierarchical planner with dependency management."""
+
+             def plan(self, task):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 instruction = getattr(task, "instruction", "") or ""
+                 available_tools = getattr(task, "available_tools", []) or []
+
+                 if not available_tools:
+                     return PlanningPrediction(steps=[], tool_sequence=[])
+
+                 # Decompose instruction into sub-tasks
+                 sub_tasks = self._decompose_instruction(instruction)
+
+                 # Map sub-tasks to tools
+                 steps = []
+                 tool_sequence = []
+                 used_tools = set()
+
+                 for i, sub_task in enumerate(sub_tasks):
+                     best_tool = self._match_tool(sub_task, available_tools, used_tools)
+                     if best_tool:
+                         steps.append(
+                             PlanStep(
+                                 step_id=i,
+                                 description=sub_task,
+                                 tool_id=best_tool,
+                                 confidence=0.75,
+                             )
+                         )
+                         tool_sequence.append(best_tool)
+                         used_tools.add(best_tool)
+
+                 return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)
+
+             def _decompose_instruction(self, instruction: str) -> list[str]:
+                 """Decompose instruction into sub-tasks."""
+                 # Split by common delimiters
+                 delimiters = [", and ", " and ", ", then ", ", ", ". ", "; "]
+                 sub_tasks = [instruction]
+
+                 for delim in delimiters:
+                     new_tasks = []
+                     for task in sub_tasks:
+                         parts = task.split(delim)
+                         new_tasks.extend([p.strip() for p in parts if p.strip()])
+                     sub_tasks = new_tasks
+
+                 # Filter out very short tasks
+                 sub_tasks = [t for t in sub_tasks if len(t) > 5]
+
+                 # Limit to reasonable number
+                 return sub_tasks[:10] if sub_tasks else [instruction]
+
+             def _match_tool(
+                 self, sub_task: str, available_tools: list[str], used_tools: set[str]
+             ) -> str | None:
+                 """Match a sub-task to the best available tool."""
+                 sub_task_lower = sub_task.lower()
+
+                 # Tool type indicators
+                 type_keywords = {
+                     "file_read": ["read", "load", "open file"],
+                     "file_write": ["write", "save", "store file"],
+                     "file_list": ["list files", "directory"],
+                     "file_copy": ["copy", "backup"],
+                     "file_delete": ["delete file", "remove file"],
+                     "data_parse_json": ["parse", "json"],
+                     "data_transform": ["transform", "convert data"],
+                     "data_filter": ["filter", "select records"],
+                     "data_aggregate": ["aggregate", "sum", "total"],
+                     "data_validate": ["validate", "check schema"],
+                     "http_get": ["fetch", "get data", "download", "api get"],
+                     "http_post": ["post", "submit", "send data"],
+                     "api_authenticate": ["authenticate", "login", "auth"],
+                     "web_scrape": ["scrape", "extract from web"],
+                     "db_connect": ["connect database", "db connect"],
+                     "db_query": ["query", "select from", "database query"],
+                     "db_insert": ["insert", "add record"],
+                     "db_update": ["update record", "modify"],
+                     "db_delete": ["delete record"],
+                     "email_send": ["email", "send mail"],
+                     "notification_send": ["notify", "notification"],
+                     "slack_post": ["slack", "post message"],
+                     "text_analyze": ["analyze text", "text analysis"],
+                     "sentiment_analyze": ["sentiment"],
+                     "stats_compute": ["statistics", "compute stats"],
+                     "math_calculate": ["calculate", "math"],
+                     "format_json": ["format json", "to json"],
+                     "format_csv": ["csv", "format csv"],
+                     "format_html": ["html", "format html", "report"],
+                     "cache_get": ["get cache", "retrieve cache"],
+                     "cache_set": ["cache", "set cache", "store cache"],
+                     "log_write": ["log", "write log"],
+                     "metrics_record": ["metrics", "record metric"],
+                     "image_resize": ["resize", "thumbnail"],
+                     "image_convert": ["convert image", "png", "jpg"],
+                     "schedule_task": ["schedule"],
+                     "get_calendar": ["calendar", "events"],
+                     "code_lint": ["lint"],
+                     "code_format": ["format code"],
+                     "code_execute": ["execute", "run code"],
+                     "search_web": ["search web", "web search"],
+                     "search_documents": ["search document"],
+                     "search_database": ["search database"],
+                     "convert_units": ["convert units"],
+                 }
+
+                 best_tool = None
+                 best_score: float = 0.0
+
+                 for tool in available_tools:
+                     if tool in used_tools:
+                         # Prefer unused tools but allow reuse with penalty
+                         penalty = 0.5
+                     else:
+                         penalty = 1.0
+
+                     score: float = 0.0
+                     tool_lower = tool.lower()
+
+                     # Check against type keywords
+                     if tool in type_keywords:
+                         for kw in type_keywords[tool]:
+                             if kw in sub_task_lower:
+                                 score += 3 * penalty
+
+                     # Check tool name parts
+                     for part in tool_lower.split("_"):
+                         if part in sub_task_lower and len(part) > 2:
+                             score += 2 * penalty
+
+                     if score > best_score:
+                         best_score = score
+                         best_tool = tool
+
+                 # Fallback: pick first unused tool
+                 if best_tool is None:
+                     for tool in available_tools:
+                         if tool not in used_tools:
+                             return tool
+
+                 return best_tool
+
+         return PlannerAdapter(HierarchicalPlanningStrategy())
+
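Scoring in _match_tool combines the keyword table (weight 3) with tool-name-part matches (weight 2), halving both for already-used tools. A sketch over a two-entry subset of the table:

    TYPE_KEYWORDS = {"file_read": ["read", "load"], "email_send": ["email", "send mail"]}

    def match(sub_task: str, tools: list[str], used: set[str]) -> str | None:
        text, best, best_score = sub_task.lower(), None, 0.0
        for tool in tools:
            penalty = 0.5 if tool in used else 1.0
            score = sum(3 * penalty for kw in TYPE_KEYWORDS.get(tool, []) if kw in text)
            score += sum(2 * penalty for part in tool.split("_") if len(part) > 2 and part in text)
            if score > best_score:
                best, best_score = tool, score
        return best

    print(match("read the config file", ["file_read", "email_send"], set()))  # file_read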
+     def _create_llm_planning_strategy(self, resources: Optional[Any] = None) -> PlannerAdapter:
+         """
+         Create LLM-based planner using UnifiedInferenceClient.
+
+         Uses real LLM for plan generation with semantic understanding.
+         """
+
+         class LLMPlanningStrategy:
+             """LLM-based planner using UnifiedInferenceClient."""
+
+             def __init__(self, fallback_planner):
+                 self._fallback = fallback_planner
+                 self._llm_client = None
+                 self._client_initialized = False
+
+             def _get_llm_client(self):
+                 """Lazy initialization of LLM client using UnifiedInferenceClient.
+
+                 Uses UnifiedInferenceClient.create() which handles:
+                 1. Local vLLM API service detection (via SagePorts)
+                 2. Cloud API fallback (via SAGE_CHAT_* env vars)
+                 """
+                 if not self._client_initialized:
+                     self._client_initialized = True
+
+                     from sage.llm import UnifiedInferenceClient
+
+                     try:
+                         self._llm_client = UnifiedInferenceClient.create()
+                         # Log which mode we're using
+                         if self._llm_client._llm_base_url:
+                             if "localhost" in self._llm_client._llm_base_url:
+                                 print(f"✅ Using local LLM: {self._llm_client._llm_model}")
+                             else:
+                                 print(f"☁️ Using cloud API: {self._llm_client._llm_model}")
+                         else:
+                             print("⚠️ LLM client initialized, but no endpoint is available")
+                     except Exception as e:
+                         print(f"⚠️ No LLM service available: {e}")
+                         print("   Start a local service: sage studio start")
+                         print("   Or configure cloud access: export SAGE_CHAT_API_KEY=your_key")
+
+                 return self._llm_client
+
+             def plan(self, task):
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 instruction = getattr(task, "instruction", "") or ""
+                 available_tools = getattr(task, "available_tools", []) or []
+
+                 if not available_tools:
+                     return PlanningPrediction(steps=[], tool_sequence=[])
+
+                 # Initialize LLM client
+                 client = self._get_llm_client()
+                 if client is None:
+                     return self._fallback.plan(task)
+
+                 # Build prompt for plan generation
+                 tools_desc = ", ".join(available_tools[:20])  # Limit for prompt
+                 prompt = f"""Generate a step-by-step plan to accomplish this task.
+
+ Task: {instruction}
+
+ Available tools: {tools_desc}
+
+ Return a JSON array of steps, each with:
+ - "tool_id": the tool to use (must be from available tools)
+ - "description": brief description of what this step does
+
+ Return ONLY the JSON array, no explanation. Example:
+ [{{"tool_id": "file_read", "description": "Read config file"}}]"""
+
+                 try:
+                     messages = [
+                         {
+                             "role": "system",
+                             "content": "You are a task planning assistant. Return only valid JSON.",
+                         },
+                         {"role": "user", "content": prompt},
+                     ]
+                     response = client.chat(messages)
+                 except Exception:
+                     return self._fallback.plan(task)
+
+                 # Parse response
+                 if response:
+                     try:
+                         import json
+                         import re
+
+                         # Extract JSON from response
+                         text = response if isinstance(response, str) else str(response)
+
+                         # Try to find JSON array
+                         json_match = re.search(r"\[.*\]", text, re.DOTALL)
+                         if json_match:
+                             plan_data = json.loads(json_match.group())
+
+                             steps = []
+                             tool_sequence = []
+                             for i, step in enumerate(plan_data):
+                                 tool_id = step.get("tool_id", "")
+                                 # Validate tool is available
+                                 if tool_id in available_tools:
+                                     steps.append(
+                                         PlanStep(
+                                             step_id=i,
+                                             description=step.get("description", f"Step {i + 1}"),
+                                             tool_id=tool_id,
+                                             confidence=0.85,
+                                         )
+                                     )
+                                     tool_sequence.append(tool_id)
+
+                             if steps:
+                                 return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)
+                     except Exception:
+                         pass
+
+                 # Fallback to hierarchical planner
+                 return self._fallback.plan(task)
+
+         # Create hierarchical fallback
+         hierarchical = self._create_hierarchical_planning_strategy(resources)
+         return PlannerAdapter(LLMPlanningStrategy(hierarchical.planner))
+
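Plan extraction above is deliberately forgiving: grab the first [...] span in the reply, decode it, and keep only steps whose tool_id is actually available. A standalone copy:

    import json
    import re

    def parse_plan(text: str, available: list[str]) -> list[dict]:
        m = re.search(r"\[.*\]", text, re.DOTALL)
        if not m:
            return []
        return [s for s in json.loads(m.group()) if s.get("tool_id") in available]

    reply = 'Sure! [{"tool_id": "file_read", "description": "Read config"}, {"tool_id": "bogus", "description": "x"}]'
    print(parse_plan(reply, ["file_read", "http_get"]))
    # [{'tool_id': 'file_read', 'description': 'Read config'}]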
+     def _create_tot_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
+         """
+         Create Tree-of-Thoughts planner for Challenge 2.
+
+         Uses tree search to explore multiple reasoning paths.
+         Based on "Tree of Thoughts: Deliberate Problem Solving with LLMs" (Yao et al., 2023)
+         """
+
+         class ToTPlanningStrategy:
+             """
+             Tree-of-Thoughts planning strategy.
+
+             Explores multiple reasoning paths via BFS/DFS tree search,
+             using LLM to generate and evaluate thought candidates.
+             """
+
+             def __init__(self, fallback_planner):
+                 self._fallback = fallback_planner
+                 self._llm_client = None
+                 self._client_initialized = False
+                 # ToT configuration
+                 self._max_depth = 3
+                 self._branch_factor = 3
+                 self._beam_width = 5
+                 self._min_score = 0.3
+
+             def _get_llm_client(self):
+                 """Lazy initialization of LLM client."""
+                 if self._client_initialized:
+                     return self._llm_client
+
+                 self._client_initialized = True
+                 try:
+                     from sage.llm import UnifiedInferenceClient
+
+                     # Use singleton to avoid repeated model loading
+                     self._llm_client = UnifiedInferenceClient.get_instance(
+                         instance_key="benchmark_planner"
+                     )
+                 except Exception:
+                     self._llm_client = None
+
+                 return self._llm_client
+
+             def plan(self, task):
+                 """Generate plan using Tree-of-Thoughts search."""
+                 from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+                     PlanningPrediction,
+                     PlanStep,
+                 )
+
+                 instruction = getattr(task, "instruction", "") or ""
+                 available_tools = getattr(task, "available_tools", []) or []
+
+                 if not available_tools:
+                     return PlanningPrediction(steps=[], tool_sequence=[])
+
+                 # Get LLM client
+                 llm_client = self._get_llm_client()
+
+                 # If no LLM, fall back to hierarchical
+                 if llm_client is None:
+                     return self._fallback.plan(task)
+
+                 try:
+                     # Run ToT search
+                     best_path = self._tot_search(instruction, available_tools, llm_client)
+
+                     if best_path:
+                         steps = []
+                         tool_sequence = []
+                         for i, (thought, tool_id) in enumerate(best_path):
+                             if tool_id and tool_id in available_tools:
+                                 steps.append(
+                                     PlanStep(
+                                         step_id=i,
+                                         description=thought,
+                                         tool_id=tool_id,
+                                         confidence=0.8,
+                                     )
+                                 )
+                                 tool_sequence.append(tool_id)
+
+                         if steps:
+                             return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)
+                 except Exception:
+                     pass
+
+                 # Fallback to hierarchical planner
+                 return self._fallback.plan(task)
+
+             def _tot_search(
+                 self, instruction: str, available_tools: list[str], llm_client
+             ) -> list[tuple[str, str]]:
+                 """
+                 Perform Tree-of-Thoughts BFS search.
+
+                 Returns list of (thought, tool_id) tuples.
+                 """
+                 # Initialize queue with empty path
+                 queue: list[list[tuple[str, str, float]]] = [[]]  # Paths
+
+                 tools_str = ", ".join(available_tools[:15])
+
+                 for depth in range(self._max_depth):
+                     next_queue: list[list[tuple[str, str, float]]] = []
+
+                     for path in queue:
+                         # Generate candidate thoughts
+                         candidates = self._generate_thoughts(
+                             instruction, path, available_tools, llm_client, tools_str
+                         )
+
+                         for thought, tool_id, score in candidates:
+                             if score >= self._min_score:
+                                 new_path = path + [(thought, tool_id, score)]
+                                 next_queue.append(new_path)
+
+                     if not next_queue:
+                         break
+
+                     # Keep top-k paths by average score
+                     next_queue.sort(
+                         key=lambda p: sum(s for _, _, s in p) / len(p) if p else 0, reverse=True
+                     )
+                     queue = next_queue[: self._beam_width]
+
+                 # Return best path
+                 if queue:
+                     best_path = max(
+                         queue, key=lambda p: sum(s for _, _, s in p) / len(p) if p else 0
+                     )
+                     return [(t, tid) for t, tid, _ in best_path]
+
+                 return []
+
+             def _generate_thoughts(
+                 self,
+                 instruction: str,
+                 path: list[tuple[str, str, float]],
+                 available_tools: list[str],
+                 llm_client,
+                 tools_str: str,
+             ) -> list[tuple[str, str, float]]:
+                 """Generate and evaluate candidate thoughts."""
+                 import json
+                 import re
+
+                 # Format current progress
+                 progress = ""
+                 if path:
+                     progress = "\n".join(
+                         f"Step {i + 1}: {t} (tool: {tid})" for i, (t, tid, _) in enumerate(path)
+                     )
+                 else:
+                     progress = "No steps taken yet."
+
+                 # Get used tools
+                 used_tools = {tid for _, tid, _ in path if tid}
+
+                 # Generate prompt
+                 prompt = f"""You are a planning assistant. Generate {self._branch_factor} different possible next steps.
+
+ Task: {instruction}
+ Available tools: {tools_str}
+ Current progress:
+ {progress}
+
+ Generate {self._branch_factor} different next steps. Each should use an available tool.
+ Avoid tools already used: {", ".join(used_tools) if used_tools else "none"}
+
+ Output as JSON array:
+ [{{"thought": "step description", "tool_id": "tool_name", "score": 0-10}}]
+
+ Only output JSON, nothing else."""
+
+                 try:
+                     response = llm_client.chat(
+                         [{"role": "user", "content": prompt}],
+                         max_tokens=512,
+                         temperature=BENCHMARK_LLM_TEMPERATURE,
+                     )
+
+                     # Parse response
+                     text = response if isinstance(response, str) else str(response)
+                     json_match = re.search(r"\[.*\]", text, re.DOTALL)
+                     if json_match:
+                         candidates = json.loads(json_match.group())
+                         result = []
+                         for c in candidates:
+                             thought = c.get("thought", "")
+                             tool_id = c.get("tool_id", "")
+                             score = c.get("score", 5) / 10.0
+
+                             # Validate tool
+                             if tool_id not in available_tools:
+                                 # Try to find closest match
+                                 for tool in available_tools:
+                                     if tool not in used_tools:
+                                         tool_id = tool
+                                         break
+                                 else:
+                                     continue
+
+                             result.append((thought, tool_id, score))
+
+                         return result[: self._branch_factor]
+                 except Exception:
+                     pass
+
+                 # Fallback: return one thought per unused tool
+                 result = []
+                 for tool in available_tools:
+                     if tool not in used_tools and len(result) < self._branch_factor:
+                         result.append((f"Use {tool} for task", tool, 0.5))
+                 return result
+
+         # Create hierarchical fallback
+         hierarchical = self._create_hierarchical_planning_strategy(resources)
+         return PlannerAdapter(ToTPlanningStrategy(hierarchical.planner))
+
+
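The beam step in _tot_search ranks candidate paths by the mean of their per-step scores and keeps the top beam_width. A toy run of that ranking:

    paths = [
        [("read config", "file_read", 0.9), ("parse json", "data_parse_json", 0.7)],  # mean 0.8
        [("search web", "search_web", 0.4)],                                          # mean 0.4
        [("read config", "file_read", 0.9), ("send mail", "email_send", 0.9)],        # mean 0.9
    ]
    beam_width = 2
    paths.sort(key=lambda p: sum(s for _, _, s in p) / len(p), reverse=True)
    print([round(sum(s for _, _, s in p) / len(p), 2) for p in paths[:beam_width]])  # [0.9, 0.8]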
+ # Global registry instance
+ _global_registry: Optional[AdapterRegistry] = None
+
+
+ def get_adapter_registry() -> AdapterRegistry:
+     """Get the global adapter registry instance."""
+     global _global_registry
+     if _global_registry is None:
+         _global_registry = AdapterRegistry()
+     return _global_registry
+
+
+ def register_strategy(name: str, strategy: Any) -> None:
+     """Register a strategy in the global registry."""
+     get_adapter_registry().register(name, strategy)
+
+
+ __all__ = [
+     "AdapterRegistry",
+     "SelectorAdapter",
+     "PlannerAdapter",
+     "TimingAdapter",
+     "get_adapter_registry",
+     "register_strategy",
+ ]
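A hedged usage sketch of the module-level API: get_adapter_registry() lazily constructs a process-wide singleton, and register_strategy() delegates to its register() method (any richer lookup API is not shown above):

    from sage.benchmark.benchmark_agent.adapter_registry import (
        get_adapter_registry,
        register_strategy,
    )

    registry = get_adapter_registry()            # first call constructs the singleton
    register_strategy("my_decider", object())    # same as registry.register("my_decider", ...)
    assert registry is get_adapter_registry()    # every later call returns the same instance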