isage-benchmark-agent 0.1.0.1__cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
|
@@ -0,0 +1,3036 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Strategy Adapter Registry
|
|
3
|
+
|
|
4
|
+
Provides a unified registry for mapping strategy names (e.g., "selector.keyword")
|
|
5
|
+
to actual selector/planner/timing implementations from sage-libs.
|
|
6
|
+
|
|
7
|
+
This bridges the benchmark experiments with the runtime components.
|
|
8
|
+
|
|
9
|
+
Method Classification:
|
|
10
|
+
=====================
|
|
11
|
+
|
|
12
|
+
Paper 1 (Benchmark) - Existing SOTA Methods (Runtime Evaluation Only):
|
|
13
|
+
Tool Selection: keyword, embedding, hybrid, gorilla, dfsdt/toolllm
|
|
14
|
+
Planning: simple, hierarchical, llm_based, react, tot
|
|
15
|
+
Timing: rule_based, llm_based, hybrid, embedding
|
|
16
|
+
|
|
17
|
+
Note: Paper 1 is a BENCHMARK paper that evaluates existing methods
|
|
18
|
+
using pre-trained models. It does NOT involve training.
|
|
19
|
+
|
|
20
|
+
Paper 2 (SIAS Method) - Training Strategies (in sage.libs.sias):
|
|
21
|
+
Core Components:
|
|
22
|
+
- CoresetSelector: Intelligent sample selection (loss_topk, diversity, hybrid)
|
|
23
|
+
- OnlineContinualLearner: Experience replay buffer with importance weighting
|
|
24
|
+
- SSIS: Streaming Sample Importance Scorer (TODO)
|
|
25
|
+
- Priority Replay: Importance-Weighted Experience Buffer (TODO)
|
|
26
|
+
|
|
27
|
+
Training Configurations:
|
|
28
|
+
- SIAS_sft_baseline: Standard SFT (ablation baseline)
|
|
29
|
+
- SIAS_coreset: + Coreset selection
|
|
30
|
+
- SIAS_continual: + Continual learning with replay
|
|
31
|
+
- SIAS_full: Complete SIAS framework
|
|
32
|
+
|
|
33
|
+
Import: from sage.libs.sias import CoresetSelector, OnlineContinualLearner
|
|
34
|
+
|
|
35
|
+
Usage:
|
|
36
|
+
>>> registry = get_adapter_registry()
|
|
37
|
+
>>> selector = registry.get("selector.keyword", resources)
|
|
38
|
+
>>> planner = registry.get("planner.react", resources)
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
from typing import Any, Callable, Optional, Protocol
|
|
42
|
+
|
|
43
|
+
# =============================================================================
|
|
44
|
+
# Experiment Configuration Constants (for controlled variable experiments)
|
|
45
|
+
# =============================================================================
|
|
46
|
+
# These constants ensure all methods use the same underlying models/parameters
|
|
47
|
+
# to enable fair comparison (控制变量法)
|
|
48
|
+
|
|
49
|
+
# Embedding model: all embedding-based methods use the same model
|
|
50
|
+
BENCHMARK_EMBEDDING_MODEL = "BAAI/bge-small-zh-v1.5"
|
|
51
|
+
|
|
52
|
+
# LLM temperature: low value for reproducibility in benchmark evaluation
|
|
53
|
+
BENCHMARK_LLM_TEMPERATURE = 0.1
|
|
54
|
+
|
|
55
|
+
# =============================================================================
|
|
56
|
+
|
|
57
|
+
# Lazy imports to avoid circular dependencies
|
|
58
|
+
_SELECTOR_REGISTRY = None
|
|
59
|
+
_PLANNER_CLASS = None
|
|
60
|
+
_TIMING_DECIDER_CLASS = None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class StrategyProtocol(Protocol):
    """Structural interface that every strategy adapter satisfies."""

    def predict(self, query: Any, **kwargs) -> Any:
        """Produce a prediction for *query*."""
        ...
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class SelectorAdapter:
    """
    Adapter wrapping tool selectors to provide unified predict()/select() interface.

    Maps the selector's select() method to predict() for benchmark compatibility.
    Also provides select() for run_all_experiments.py compatibility. The four
    query-normalization branches previously duplicated in select() are factored
    into a single _normalize_query() helper.
    """

    def __init__(self, selector: Any):
        """
        Initialize adapter.

        Args:
            selector: A tool selector instance with select() method
        """
        self.selector = selector

    def predict(self, query: Any, top_k: Optional[int] = None, **kwargs) -> list:
        """
        Make tool selection prediction.

        Args:
            query: ToolSelectionQuery from experiment (must have .sample_id
                and .instruction)
            top_k: Number of tools to select (defaults to 5)

        Returns:
            List of ToolPrediction objects
        """
        from sage.libs.agentic.agents.action.tool_selection.schemas import (
            ToolSelectionQuery as SelectorQuery,
        )

        # NOTE: unlike select(), candidate_tools may be None here; selectors
        # that maintain their own tool corpus ignore it.
        selector_query = SelectorQuery(
            sample_id=query.sample_id,
            instruction=query.instruction,
            candidate_tools=getattr(query, "candidate_tools", None),
            context=getattr(query, "context", {}),
        )

        k = top_k if top_k is not None else 5
        return self.selector.select(selector_query, top_k=k)

    @staticmethod
    def _normalize_query(query: Any, tools: list) -> Any:
        """Convert a str/dict/object query into a SelectorQuery.

        Explicitly supplied *tools* take precedence over any candidate tools
        carried by the query itself; the result always has a list (never
        None) of candidates.
        """
        from sage.libs.agentic.agents.action.tool_selection.schemas import (
            ToolSelectionQuery as SelectorQuery,
        )

        if isinstance(query, str):
            # Plain instruction string (from run_all_experiments.py)
            return SelectorQuery(
                sample_id="runtime",
                instruction=query,
                candidate_tools=tools,
                context={},
            )
        if isinstance(query, dict):
            fallback = query.get("candidate_tools", []) or []
            return SelectorQuery(
                sample_id=query.get("sample_id", "runtime"),
                instruction=query.get("instruction", str(query)),
                candidate_tools=tools if tools else fallback,
                context=query.get("context", {}),
            )
        if hasattr(query, "instruction"):
            fallback = getattr(query, "candidate_tools", []) or []
            return SelectorQuery(
                sample_id=getattr(query, "sample_id", "runtime"),
                instruction=query.instruction,
                candidate_tools=tools if tools else fallback,
                context=getattr(query, "context", {}),
            )
        # Fallback: treat the query object itself as the instruction text.
        return SelectorQuery(
            sample_id="runtime",
            instruction=str(query),
            candidate_tools=tools,
            context={},
        )

    def select(self, query: Any, candidate_tools: Optional[list] = None, top_k: int = 5) -> list:
        """
        Select tools for a query (alias for predict with simpler interface).

        This method is provided for compatibility with run_all_experiments.py
        which calls selector.select(query, candidate_tools, top_k=top_k).

        Args:
            query: Either a string (instruction), a dict, or a
                ToolSelectionQuery-like object
            candidate_tools: Optional list of candidate tools (may be ignored
                if the selector has its own tool corpus)
            top_k: Number of tools to select

        Returns:
            List of ToolPrediction objects or tool IDs; empty list if the
            underlying selector raises.
        """
        # Ensure candidate_tools is always a list (never None)
        tools = candidate_tools if candidate_tools is not None else []
        selector_query = self._normalize_query(query, tools)

        try:
            return self.selector.select(selector_query, top_k=top_k)
        except Exception as e:
            # Best-effort: selector failures degrade to "no tools selected"
            import logging

            logging.getLogger(__name__).warning(
                f"Selector failed for query '{str(query)[:50]}...': {e}"
            )
            return []
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class PlannerAdapter:
    """Adapter exposing a unified plan() interface over a wrapped planner."""

    def __init__(self, planner: Any):
        """
        Initialize adapter.

        Args:
            planner: A planner instance with plan() method
        """
        self.planner = planner

    def plan(self, task: Any, **kwargs) -> Any:
        """
        Generate a plan for the task by delegating to the wrapped planner.

        Args:
            task: PlanningTask from experiment

        Returns:
            PlanningPrediction with steps and tool_sequence
        """
        delegate = self.planner
        return delegate.plan(task)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class UnifiedTimingMessage:
    """
    Timing message exposing both .message and .user_message attributes.

    This ensures compatibility with:
    - Local deciders in adapter_registry.py (expect .message)
    - sage-libs timing_decider.py (expects .user_message)
    """

    def __init__(
        self,
        user_message: str,
        conversation_history: Optional[list[Any]] = None,
        last_tool_call: Any = None,
        context: Optional[dict[Any, Any]] = None,
    ):
        # .message mirrors .user_message so local deciders work unchanged.
        self.user_message = user_message
        self.message = user_message
        self.conversation_history = conversation_history or []
        self.last_tool_call = last_tool_call
        self.context = context or {}
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class TimingAdapter:
    """
    Adapter wrapping timing deciders to provide a unified decide() interface.

    Handles conversion between the experiment's TimingMessage format and
    the sage-libs schemas.TimingMessage format.
    """

    def __init__(self, decider: Any):
        """
        Initialize adapter.

        Args:
            decider: A timing decider instance with decide() method
        """
        self.decider = decider

    @staticmethod
    def _unify(message: Any) -> Any:
        """Normalize *message* into something with both .message and
        .user_message attributes (usually a UnifiedTimingMessage)."""
        if isinstance(message, dict):
            # Dict input (e.g., from tests or direct API calls)
            text = (
                message.get("instruction")
                or message.get("message")
                or message.get("user_message", "")
            )
            return UnifiedTimingMessage(
                user_message=text,
                conversation_history=message.get("conversation_history", []),
                last_tool_call=message.get("last_tool_call"),
                context=message.get("context", {}),
            )
        if hasattr(message, "user_message") and hasattr(message, "message"):
            # Already carries both attributes (e.g., UnifiedTimingMessage)
            return message
        # sage-libs schemas.TimingMessage has .user_message; the experiment's
        # TimingMessage has .message — handle whichever is present.
        for attr in ("user_message", "message"):
            if hasattr(message, attr):
                return UnifiedTimingMessage(
                    user_message=getattr(message, attr),
                    conversation_history=getattr(message, "conversation_history", []),
                    last_tool_call=getattr(message, "last_tool_call", None),
                    context=getattr(message, "context", {}),
                )
        # Fallback: treat the argument itself as the message text.
        return UnifiedTimingMessage(
            user_message=str(message),
            conversation_history=[],
            last_tool_call=None,
            context={},
        )

    def decide(self, message: Any, **kwargs) -> Any:
        """
        Make timing decision.

        Args:
            message: TimingMessage from experiment (has .message attribute),
                sage-libs message (has .user_message), dict with
                'instruction'/'message' key, or plain string

        Returns:
            TimingDecision with should_call_tool, confidence, reasoning
        """
        return self.decider.decide(self._unify(message))
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
class AdapterRegistry:
|
|
309
|
+
"""
|
|
310
|
+
Registry for strategy adapters.
|
|
311
|
+
|
|
312
|
+
Maps string names like "baseline.keyword" to actual implementations.
|
|
313
|
+
"""
|
|
314
|
+
|
|
315
|
+
def __init__(self):
    """Initialize registry with built-in strategies."""
    import logging

    self.logger = logging.getLogger(__name__)

    # Name -> cached adapter instance, one dict per strategy family.
    self._selectors: dict[str, Any] = {}
    self._planners: dict[str, Any] = {}
    self._timing_deciders: dict[str, Any] = {}
    # Name -> factory callable; get() invokes these with the resources arg.
    self._factories: dict[str, Callable] = {}

    # Register built-in strategies
    self._register_builtins()
|
|
328
|
+
|
|
329
|
+
def _register_builtins(self) -> None:
    """Register built-in baseline strategies.

    Method Classification:
    =====================

    Paper 1 (Benchmark) - Existing SOTA Methods:
    - Tool Selection: keyword, embedding, hybrid, gorilla, dfsdt/toolllm
    - Planning: simple, hierarchical, llm_based, react, tot
    - Timing: rule_based, llm_based, hybrid, embedding

    Paper 2 (Method) - SAGE Original Methods:
    - Training: SAGE_baseline_sft, SAGE_coreset_loss, SAGE_coreset_diversity,
      SAGE_coreset_hybrid, SAGE_continual, SAGE_combined
    - (Defined in run_full_training_comparison.py, not runtime adapters)
    """
    # =================================================================
    # Paper 1: Existing SOTA Selector Strategies
    # =================================================================
    self._factories.update(
        {
            # BM25/TF-IDF keyword-based selection (plus bare and
            # "selector."-prefixed aliases used by benchmark scripts)
            "baseline.keyword": self._create_keyword_selector,
            "keyword": self._create_keyword_selector,
            "selector.keyword": self._create_keyword_selector,
            "baseline.embedding": self._create_embedding_selector,
            "embedding": self._create_embedding_selector,
            "selector.embedding": self._create_embedding_selector,
            "baseline.hybrid": self._create_hybrid_selector,
            "hybrid": self._create_hybrid_selector,
            "selector.hybrid": self._create_hybrid_selector,
            # Gorilla: LLM-augmented retrieval (Patil et al., 2023)
            "selector.gorilla": self._create_gorilla_selector,
            "gorilla": self._create_gorilla_selector,
            # ToolLLM DFSDT: Depth-First Search Decision Tree (Qin et al., 2023)
            "selector.dfsdt": self._create_dfsdt_selector,
            "selector.toolllm": self._create_dfsdt_selector,  # Alias
            "dfsdt": self._create_dfsdt_selector,
            "toolllm": self._create_dfsdt_selector,  # Alias
        }
    )

    # =================================================================
    # Paper 1: Existing SOTA Planner Strategies
    # =================================================================
    self._factories.update(
        {
            "baseline.template": self._create_template_planner,
            "baseline.hierarchical": self._create_hierarchical_planner,
            "cot": self._create_hierarchical_planner,
            "baseline.sequence": self._create_sequence_planner,
            # Challenge 2 planner strategies
            "planner.simple": self._create_simple_planner,
            "planner.hierarchical": self._create_hierarchical_planning_strategy,
            "planner.llm_based": self._create_llm_planning_strategy,
            # ReAct: Reasoning + Acting (Yao et al., 2023)
            "planner.react": self._create_react_planner,
            "react": self._create_react_planner,  # Alias
            # Tree-of-Thoughts: Multi-path reasoning (Yao et al., 2023)
            "planner.tot": self._create_tot_planner,
            "planner.tree_of_thoughts": self._create_tot_planner,  # Alias
        }
    )

    # =================================================================
    # Paper 1: Existing SOTA Timing Strategies
    # =================================================================
    self._factories.update(
        {
            "baseline.threshold": self._create_threshold_decider,
            "llm_based": self._create_llm_timing_decider,
            # New timing strategies for benchmark
            "timing.rule_based": self._create_rule_based_decider,
            "timing.llm_based": self._create_llm_timing_decider,
            "timing.hybrid": self._create_hybrid_timing_decider,
            "timing.embedding": self._create_embedding_timing_decider,
        }
    )
|
|
396
|
+
|
|
397
|
+
def register(self, name: str, strategy: Any) -> None:
    """
    Register a strategy by name.

    Callables without a predict() method are treated as factories and will
    later be invoked by get() as ``factory(resources)``. Instances are
    wrapped in the matching adapter based on which method they expose
    (select/plan/decide); anything else is stored behind a trivial factory.

    Args:
        name: Strategy name (e.g., "my_selector")
        strategy: Strategy instance or factory
    """
    if callable(strategy) and not hasattr(strategy, "predict"):
        self._factories[name] = strategy
    else:
        # Store instance directly, wrapped in the matching adapter
        if hasattr(strategy, "select"):
            self._selectors[name] = SelectorAdapter(strategy)
        elif hasattr(strategy, "plan"):
            self._planners[name] = PlannerAdapter(strategy)
        elif hasattr(strategy, "decide"):
            self._timing_deciders[name] = TimingAdapter(strategy)
        else:
            # BUG FIX: get() invokes factories as factory(resources); the
            # previous zero-argument ``lambda: strategy`` raised TypeError
            # on retrieval. Accept (and ignore) the resources argument, and
            # bind the strategy as a default to avoid late-binding issues.
            self._factories[name] = lambda resources=None, _s=strategy: _s
|
|
417
|
+
|
|
418
|
+
def get(self, name: str, resources: Optional[Any] = None) -> Any:
    """
    Get a strategy by name.

    Args:
        name: Strategy name
        resources: Optional SelectorResources for initialization

    Returns:
        Strategy adapter instance

    Raises:
        ValueError: If strategy not found
    """
    # Explicitly registered instances take precedence over factories.
    for cache in (self._selectors, self._planners, self._timing_deciders):
        if name in cache:
            return cache[name]

    # Fall back to a factory; a fresh instance is built per call.
    factory = self._factories.get(name)
    if factory is not None:
        return factory(resources)

    raise ValueError(f"Unknown strategy: {name}. Available: {self.list_strategies()}")
|
|
446
|
+
|
|
447
|
+
def list_strategies(self) -> list:
    """Return the sorted names of every registered strategy."""
    names: set = set()
    for registry in (
        self._selectors,
        self._planners,
        self._timing_deciders,
        self._factories,
    ):
        names.update(registry)
    return sorted(names)
|
|
454
|
+
|
|
455
|
+
# --- Factory methods for built-in strategies ---
|
|
456
|
+
|
|
457
|
+
def _create_keyword_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
    """
    Create keyword-based selector using BM25/TF-IDF.

    This implementation uses dynamic indexing - it processes the candidate_tools
    provided in each query, enabling cross-dataset evaluation.

    Args:
        resources: Unused; accepted for factory-signature uniformity.

    Returns:
        SelectorAdapter wrapping a BM25 keyword selector.
    """

    class DynamicKeywordSelector:
        """Dynamic keyword selector using BM25 on candidate_tools."""

        def __init__(self):
            self.name = "keyword"

        def select(self, query, top_k=5):
            """Select tools using BM25 keyword matching."""
            import math
            import re

            from sage.libs.agentic.agents.action.tool_selection.schemas import (
                ToolPrediction,
            )

            candidate_tools = getattr(query, "candidate_tools", []) or []
            if not candidate_tools:
                return []

            instruction = getattr(query, "instruction", str(query))

            # Tokenize query (lowercase alphanumeric runs)
            query_tokens = set(re.findall(r"[a-z0-9]+", instruction.lower()))
            if not query_tokens:
                return []

            # Build tool texts and tokenize; tools may be plain strings,
            # objects with .name, or dicts — anything else is skipped.
            tool_data = []
            for tool in candidate_tools:
                if isinstance(tool, str):
                    tool_id = tool
                    tool_text = tool.replace("_", " ")
                elif hasattr(tool, "name"):
                    tool_id = getattr(tool, "tool_id", getattr(tool, "id", tool.name))
                    tool_text = f"{tool.name} {getattr(tool, 'description', '')}"
                elif isinstance(tool, dict):
                    tool_id = tool.get("tool_id", tool.get("id", tool.get("name", "")))
                    tool_text = f"{tool.get('name', '')} {tool.get('description', '')}"
                else:
                    continue
                tool_tokens = set(re.findall(r"[a-z0-9]+", tool_text.lower()))
                tool_data.append((tool_id, tool_tokens, len(tool_tokens)))

            if not tool_data:
                return []

            # Compute smoothed IDF over the candidate pool
            num_docs = len(tool_data)
            doc_freq = {}
            for _, tokens, _ in tool_data:
                for token in tokens:
                    doc_freq[token] = doc_freq.get(token, 0) + 1
            idf = {t: math.log((num_docs + 1) / (df + 1)) for t, df in doc_freq.items()}

            # BM25 scoring; doc length is the distinct-token count since
            # token sets (not multisets) are used above.
            k1, b = 1.5, 0.75
            avg_dl = sum(dl for _, _, dl in tool_data) / num_docs if num_docs else 1

            scores = []
            for tool_id, tool_tokens, doc_len in tool_data:
                score = 0.0
                for token in query_tokens:
                    if token in tool_tokens:
                        tf = 1  # Binary TF
                        score += (
                            idf.get(token, 0)
                            * (tf * (k1 + 1))
                            / (tf + k1 * (1 - b + b * doc_len / avg_dl))
                        )
                scores.append((tool_id, score))

            # Sort and return top-k; scores are squashed into [0, 1] by an
            # ad-hoc /10 normalization
            scores.sort(key=lambda x: x[1], reverse=True)
            return [
                ToolPrediction(tool_id=tid, score=min(s / 10, 1.0)) for tid, s in scores[:top_k]
            ]

    return SelectorAdapter(DynamicKeywordSelector())
|
|
543
|
+
|
|
544
|
+
def _create_embedding_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
    """
    Create embedding-based selector using cosine similarity.

    This implementation uses dynamic indexing - it embeds the candidate_tools
    provided in each query, enabling cross-dataset evaluation.

    Args:
        resources: Unused; accepted for factory-signature uniformity.

    Returns:
        SelectorAdapter wrapping a cosine-similarity embedding selector.
    """

    class DynamicEmbeddingSelector:
        """Dynamic embedding selector on candidate_tools."""

        def __init__(self):
            self.name = "embedding"
            # Lazily created embedding client; stays None if creation fails.
            self._embedding_client = None

        def _init_client(self):
            # Best-effort client creation: any failure leaves the client as
            # None, and select() then falls back to uniform scores.
            if self._embedding_client is None:
                try:
                    from sage.common.components.sage_embedding import (
                        EmbeddingClientAdapter,
                        EmbeddingFactory,
                    )

                    raw_embedder = EmbeddingFactory.create(
                        "hf", model=BENCHMARK_EMBEDDING_MODEL
                    )
                    self._embedding_client = EmbeddingClientAdapter(raw_embedder)
                except Exception:
                    pass

        def select(self, query, top_k=5):
            """Select tools using embedding similarity."""
            import numpy as np
            from sage.libs.agentic.agents.action.tool_selection.schemas import (
                ToolPrediction,
            )

            self._init_client()

            candidate_tools = getattr(query, "candidate_tools", []) or []
            if not candidate_tools:
                return []

            instruction = getattr(query, "instruction", str(query))

            # Build tool id/text pairs; tools may be strings, objects with
            # .name, or dicts (other shapes are silently skipped).
            tool_ids = []
            tool_texts = []
            for tool in candidate_tools:
                if isinstance(tool, str):
                    tool_ids.append(tool)
                    tool_texts.append(tool.replace("_", " "))
                elif hasattr(tool, "name"):
                    tool_ids.append(getattr(tool, "tool_id", getattr(tool, "id", tool.name)))
                    tool_texts.append(f"{tool.name}: {getattr(tool, 'description', '')}")
                elif isinstance(tool, dict):
                    tool_ids.append(tool.get("tool_id", tool.get("id", tool.get("name", ""))))
                    tool_texts.append(f"{tool.get('name', '')}: {tool.get('description', '')}")

            if not tool_texts or self._embedding_client is None:
                # Fallback to simple matching: uniform neutral scores
                return [ToolPrediction(tool_id=tid, score=0.5) for tid in tool_ids[:top_k]]

            try:
                # Embed query and tools in a single batch
                all_texts = [instruction] + tool_texts
                embeddings = self._embedding_client.embed(all_texts)

                query_emb = np.asarray(embeddings[0])
                tool_embs = np.asarray(embeddings[1:])

                # Cosine similarity; 1e-8 guards against zero-norm vectors
                query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-8)
                tool_norms = tool_embs / (
                    np.linalg.norm(tool_embs, axis=1, keepdims=True) + 1e-8
                )
                scores = np.dot(tool_norms, query_norm)

                # Sort and return top-k
                top_indices = np.argsort(scores)[::-1][:top_k]
                return [
                    ToolPrediction(tool_id=tool_ids[i], score=float(scores[i]))
                    for i in top_indices
                ]
            except Exception:
                # Embedding failure degrades to uniform neutral scores
                return [ToolPrediction(tool_id=tid, score=0.5) for tid in tool_ids[:top_k]]

    return SelectorAdapter(DynamicEmbeddingSelector())
|
|
632
|
+
|
|
633
|
+
def _create_hybrid_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
    """
    Create hybrid selector combining keyword (BM25) and embedding similarity.

    This implementation uses dynamic indexing - it processes the candidate_tools
    provided in each query, enabling cross-dataset evaluation.
    """

    class DynamicHybridSelector:
        """Dynamic hybrid selector: 40% keyword + 60% embedding."""

        def __init__(self):
            self.name = "hybrid"
            self._embedding_client = None
            self._keyword_weight = 0.4
            self._embedding_weight = 0.6

        def _init_client(self):
            # Lazy init; on any failure the selector silently degrades to
            # keyword-only scoring (embedding_scores stay 0.0 in select()).
            if self._embedding_client is None:
                try:
                    from sage.common.components.sage_embedding import (
                        EmbeddingClientAdapter,
                        EmbeddingFactory,
                    )

                    raw_embedder = EmbeddingFactory.create(
                        "hf", model=BENCHMARK_EMBEDDING_MODEL
                    )
                    self._embedding_client = EmbeddingClientAdapter(raw_embedder)
                except Exception:
                    pass

        def select(self, query, top_k=5):
            """Select tools using hybrid scoring.

            Returns up to ``top_k`` ToolPrediction objects ranked by
            ``0.4 * normalized-BM25 + 0.6 * cosine-similarity``, with the
            final score clamped to [0, 1].
            """
            import math
            import re

            import numpy as np
            from sage.libs.agentic.agents.action.tool_selection.schemas import (
                ToolPrediction,
            )

            self._init_client()

            candidate_tools = getattr(query, "candidate_tools", []) or []
            if not candidate_tools:
                return []

            instruction = getattr(query, "instruction", str(query))

            # Build tool data: an id plus a text blob used for both BM25 and
            # embedding scoring. Unrecognized candidate shapes are skipped.
            tool_ids = []
            tool_texts = []
            for tool in candidate_tools:
                if isinstance(tool, str):
                    tool_ids.append(tool)
                    tool_texts.append(tool.replace("_", " "))
                elif hasattr(tool, "name"):
                    tool_ids.append(getattr(tool, "tool_id", getattr(tool, "id", tool.name)))
                    tool_texts.append(f"{tool.name}: {getattr(tool, 'description', '')}")
                elif isinstance(tool, dict):
                    tool_ids.append(tool.get("tool_id", tool.get("id", tool.get("name", ""))))
                    tool_texts.append(f"{tool.get('name', '')}: {tool.get('description', '')}")

            if not tool_ids:
                return []

            # === Keyword scores (BM25) ===
            # Tokens are deduplicated via sets, so term frequency is binary.
            query_tokens = set(re.findall(r"[a-z0-9]+", instruction.lower()))
            tool_tokens_list = [set(re.findall(r"[a-z0-9]+", t.lower())) for t in tool_texts]

            # IDF over the candidate set only (each query carries its own corpus)
            num_docs = len(tool_ids)
            doc_freq = {}
            for tokens in tool_tokens_list:
                for token in tokens:
                    doc_freq[token] = doc_freq.get(token, 0) + 1
            idf = {t: math.log((num_docs + 1) / (df + 1)) for t, df in doc_freq.items()}

            # BM25 with the conventional parameters
            k1, b = 1.5, 0.75
            avg_dl = sum(len(t) for t in tool_tokens_list) / num_docs if num_docs else 1

            keyword_scores = []
            for tool_tokens in tool_tokens_list:
                score = 0.0
                doc_len = len(tool_tokens)
                for token in query_tokens:
                    if token in tool_tokens:
                        tf = 1  # binary TF: tool tokens are stored in a set
                        score += (
                            idf.get(token, 0)
                            * (tf * (k1 + 1))
                            / (tf + k1 * (1 - b + b * doc_len / avg_dl))
                        )
                keyword_scores.append(score)

            # Normalize keyword scores to [0, 1].
            # Fix: compute max(keyword_scores) once instead of twice.
            peak_kw = max(keyword_scores) if keyword_scores else 0.0
            max_kw = peak_kw if peak_kw > 0 else 1
            keyword_scores = [s / max_kw for s in keyword_scores]

            # === Embedding scores (cosine similarity) ===
            embedding_scores = [0.0] * len(tool_ids)
            if self._embedding_client is not None:
                try:
                    all_texts = [instruction] + tool_texts
                    embeddings = self._embedding_client.embed(all_texts)

                    query_emb = np.asarray(embeddings[0])
                    tool_embs = np.asarray(embeddings[1:])

                    # 1e-8 guards against division by zero on null vectors
                    query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-8)
                    tool_norms = tool_embs / (
                        np.linalg.norm(tool_embs, axis=1, keepdims=True) + 1e-8
                    )
                    embedding_scores = list(np.dot(tool_norms, query_norm))
                except Exception:
                    pass  # best-effort: fall back to keyword-only scores

            # === Combine scores ===
            combined = [
                (
                    tool_ids[i],
                    self._keyword_weight * keyword_scores[i]
                    + self._embedding_weight * embedding_scores[i],
                )
                for i in range(len(tool_ids))
            ]
            combined.sort(key=lambda x: x[1], reverse=True)

            # Fix: clamp on both sides (cosine similarity can be negative, so
            # min(s, 1.0) alone could emit negative scores) and cast numpy
            # scalars to builtin float, matching the other selectors.
            return [
                ToolPrediction(tool_id=tid, score=float(min(max(s, 0.0), 1.0)))
                for tid, s in combined[:top_k]
            ]

    return SelectorAdapter(DynamicHybridSelector())
|
|
768
|
+
|
|
769
|
+
def _create_gorilla_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
    """
    Create Gorilla-style retrieval-augmented selector.

    Gorilla uses a two-stage approach:
    1. Embedding retrieval to find candidate tools
    2. LLM selection from retrieved candidates

    Reference: Patil et al. (2023) "Gorilla: Large Language Model Connected with Massive APIs"

    This implementation uses dynamic indexing - it builds embeddings from the
    candidate_tools provided in each query, enabling cross-dataset evaluation.
    """

    class DynamicGorillaSelector:
        """Gorilla selector with dynamic tool indexing."""

        def __init__(self):
            # Both clients are created lazily on first select() call.
            self._embedding_client = None
            self._llm_client = None
            self.name = "gorilla"

        def _init_clients(self):
            """Lazy initialization of embedding and LLM clients."""
            if self._embedding_client is None:
                # Try local HuggingFace embedding first
                try:
                    from sage.common.components.sage_embedding import (
                        EmbeddingClientAdapter,
                        EmbeddingFactory,
                    )

                    raw_embedder = EmbeddingFactory.create(
                        "hf", model=BENCHMARK_EMBEDDING_MODEL
                    )
                    self._embedding_client = EmbeddingClientAdapter(raw_embedder)
                except Exception:
                    # Stay client-less; select() then returns neutral 0.5 scores.
                    pass

            if self._llm_client is None:
                try:
                    from sage.llm import UnifiedInferenceClient

                    self._llm_client = UnifiedInferenceClient.create()
                except Exception:
                    # No LLM: select() skips reranking and uses embeddings only.
                    pass

        def select(self, query, top_k=5):
            """Select tools using embedding retrieval + LLM reranking."""
            import logging

            import numpy as np
            from sage.libs.agentic.agents.action.tool_selection.schemas import (
                ToolPrediction,
            )

            logger = logging.getLogger(__name__)
            self._init_clients()

            candidate_tools = getattr(query, "candidate_tools", []) or []
            if not candidate_tools:
                return []

            instruction = getattr(query, "instruction", str(query))

            # Parse candidate tools into (tool_id, tool_text) pairs
            # Accepts str, objects with .name, or dicts; other shapes are skipped.
            tool_ids, tool_texts = [], []
            for tool in candidate_tools:
                if isinstance(tool, str):
                    tool_ids.append(tool)
                    tool_texts.append(tool)
                elif hasattr(tool, "name"):
                    tid = getattr(tool, "tool_id", getattr(tool, "id", tool.name))
                    tool_ids.append(tid)
                    tool_texts.append(f"{tool.name}: {getattr(tool, 'description', '')}")
                elif isinstance(tool, dict):
                    tid = tool.get("tool_id", tool.get("id", tool.get("name", "")))
                    tool_ids.append(tid)
                    tool_texts.append(f"{tool.get('name', tid)}: {tool.get('description', '')}")

            if not tool_ids:
                return []

            # Fallback if no embedding client: neutral 0.5 score, original order
            if self._embedding_client is None:
                return [ToolPrediction(tool_id=tid, score=0.5) for tid in tool_ids[:top_k]]

            try:
                # Embed query and tools in a single batch call
                embeddings = self._embedding_client.embed([instruction] + tool_texts)
                query_emb = np.asarray(embeddings[0])
                tool_embs = np.asarray(embeddings[1:])

                # Cosine similarity (1e-8 avoids division by zero)
                query_norm = query_emb / (np.linalg.norm(query_emb) + 1e-8)
                tool_norms = tool_embs / (
                    np.linalg.norm(tool_embs, axis=1, keepdims=True) + 1e-8
                )
                scores = np.dot(tool_norms, query_norm)

                # Get top candidates for LLM reranking (stage-1 retrieval)
                retrieve_k = min(15, len(tool_ids))
                top_indices = np.argsort(scores)[::-1][:retrieve_k]

                # LLM reranking if available (stage 2)
                if self._llm_client is not None:
                    reranked = self._llm_rerank(
                        instruction,
                        [(tool_ids[i], tool_texts[i]) for i in top_indices],
                        top_k,
                    )
                    if reranked:
                        # Rank-derived scores: 1.0, 0.9, 0.8, ... by position
                        return [
                            ToolPrediction(tool_id=tid, score=1.0 - i * 0.1)
                            for i, tid in enumerate(reranked)
                        ]

                # Fallback to embedding-only (rerank unavailable or empty)
                return [
                    ToolPrediction(tool_id=tool_ids[i], score=float(scores[i]))
                    for i in top_indices[:top_k]
                ]

            except Exception as e:
                # NOTE(review): unlike the no-client path above, a runtime
                # failure here returns [] rather than neutral predictions.
                logger.warning(f"Gorilla selector failed: {e}")
                return []

        def _llm_rerank(self, query, tools, top_k):
            """Use LLM to rerank retrieved tools.

            ``tools`` is a list of (tool_id, description) pairs; returns the
            selected tool ids (possibly empty on any parse/LLM failure).
            """
            import json
            import re

            tools_text = "\n".join(
                f"{i + 1}. {tid}: {desc}" for i, (tid, desc) in enumerate(tools)
            )
            prompt = f"""Select the {top_k} most relevant tools for this task. Return ONLY a JSON array of tool IDs.

Task: {query}

Tools:
{tools_text}

Output (JSON array only):"""

            try:
                response = self._llm_client.chat(
                    [{"role": "user", "content": prompt}], temperature=BENCHMARK_LLM_TEMPERATURE
                )
                response = response.strip()
                # Strip a markdown code fence if the model wrapped its output
                if response.startswith("```"):
                    response = "\n".join(response.split("\n")[1:-1]).strip()
                # Non-greedy: grabs the first bracketed span in the reply
                match = re.search(r"\[.*?\]", response, re.DOTALL)
                if match:
                    selected = json.loads(match.group())
                    # Keep only ids the model was actually offered
                    valid_ids = {tid for tid, _ in tools}
                    return [tid for tid in selected if tid in valid_ids][:top_k]
            except Exception:
                pass
            return []

    return SelectorAdapter(DynamicGorillaSelector())
|
|
930
|
+
|
|
931
|
+
def _create_dfsdt_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
    """
    Create DFSDT (Depth-First Search-based Decision Tree) selector.

    Based on ToolLLM paper (Qin et al., 2023):
    "ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs"

    This implementation uses dynamic scoring - it evaluates the candidate_tools
    provided in each query, enabling cross-dataset evaluation.
    """

    class DynamicDFSDTSelector:
        """DFSDT-style selector that scores each candidate tool per query."""

        def __init__(self):
            self._llm_client = None
            self._score_threshold = 0.3
            self.name = "dfsdt"

        def _init_client(self):
            """Lazy initialization of LLM client."""
            if self._llm_client is not None:
                return
            try:
                from sage.llm import UnifiedInferenceClient

                self._llm_client = UnifiedInferenceClient.create()
            except Exception:
                # Stay LLM-less; scoring falls back to keyword overlap.
                pass

        @staticmethod
        def _unpack(candidate):
            """Normalize a candidate into (tool_id, name, description), or None."""
            if isinstance(candidate, str):
                return candidate, candidate, ""
            if hasattr(candidate, "name"):
                tid = getattr(candidate, "tool_id", getattr(candidate, "id", candidate.name))
                return tid, candidate.name, getattr(candidate, "description", "")
            if isinstance(candidate, dict):
                tid = candidate.get("tool_id", candidate.get("id", candidate.get("name", "")))
                return tid, candidate.get("name", tid), candidate.get("description", "")
            return None

        def select(self, query, top_k=5):
            """Select tools using LLM-based scoring."""

            from sage.libs.agentic.agents.action.tool_selection.schemas import (
                ToolPrediction,
            )

            self._init_client()

            candidates = getattr(query, "candidate_tools", []) or []
            if not candidates:
                return []

            instruction = getattr(query, "instruction", str(query))

            # Score every candidate; keep only those above the threshold.
            ranked = []
            for candidate in candidates:
                unpacked = self._unpack(candidate)
                if unpacked is None:
                    continue
                tid, tool_name, tool_desc = unpacked
                relevance = self._score_tool(instruction, tool_name, tool_desc)
                if relevance >= self._score_threshold:
                    ranked.append((tid, relevance))

            ranked.sort(key=lambda pair: pair[1], reverse=True)
            return [
                ToolPrediction(tool_id=tid, score=rel) for tid, rel in ranked[:top_k]
            ]

        def _score_tool(self, query, tool_name, tool_desc):
            """Score tool relevance in [0, 1]: LLM rating first, keyword overlap otherwise."""
            import re

            if self._llm_client is not None:
                prompt = f"""Rate relevance (0-10): Query: {query} | Tool: {tool_name} - {tool_desc}
Output only a number:"""
                try:
                    reply = self._llm_client.chat(
                        [{"role": "user", "content": prompt}],
                        temperature=BENCHMARK_LLM_TEMPERATURE,
                    )
                    found = re.findall(r"(\d+(?:\.\d+)?)", reply.strip())
                    if found:
                        # Clamp to the 0-10 rating scale, then normalize.
                        return min(max(float(found[0]), 0.0), 10.0) / 10.0
                except Exception:
                    pass

            # Keyword fallback: fraction of query words present in the tool text.
            query_terms = set(query.lower().split())
            tool_terms = set(f"{tool_name} {tool_desc}".lower().split())
            if not query_terms:
                return 0.0
            return min(len(query_terms & tool_terms) / len(query_terms), 1.0)

    return SelectorAdapter(DynamicDFSDTSelector())
|
|
1025
|
+
|
|
1026
|
+
def _create_inline_hybrid_selector(self, resources: Optional[Any] = None) -> SelectorAdapter:
    """Create inline hybrid selector when library version unavailable."""
    from sage.libs.agentic.agents.action.tool_selection import (
        KeywordSelector,
        KeywordSelectorConfig,
        ToolPrediction,
    )

    class InlineHybridSelector:
        """Inline hybrid selector for benchmark fallback.

        Retrieves 2*top_k candidates with a BM25 keyword selector, boosts
        each candidate's score by query-word matches in its description,
        re-ranks, and optionally LLM-reranks (opt-in via env var).
        """

        def __init__(self, resources):
            self.resources = resources
            # Create keyword selector as base
            config = KeywordSelectorConfig(
                name="keyword",
                method="bm25",
                top_k=10,
            )
            self._keyword_selector = KeywordSelector(config, resources)
            self.name = "hybrid"

        def select(self, query, top_k=5):
            """Select using keyword with boosting."""
            # Retrieve extra candidates so boosting can promote tools that the
            # keyword selector alone ranked below top_k.
            results = self._keyword_selector.select(query, top_k=top_k * 2)

            # Simple score boosting based on query-tool match
            boosted = []
            query_lower = query.instruction.lower()

            # Fix: boost ALL retrieved candidates. Previously only
            # results[:top_k] were boosted, which made the 2*top_k retrieval
            # pointless and prevented boosting from changing the final set.
            for pred in results:
                # Check if tool matches query keywords better
                tool_text = self.resources.tools_loader.get_tool(pred.tool_id)
                tool_desc = getattr(tool_text, "description", "") or ""

                # Simple boost: increase score if query words in description
                words = query_lower.split()
                matches = sum(1 for w in words if w in tool_desc.lower())
                boost = 1.0 + (matches * 0.1)

                boosted.append(
                    ToolPrediction(
                        tool_id=pred.tool_id,
                        score=min(pred.score * boost, 1.0),
                        metadata={"method": "hybrid_inline", "boost": boost},
                    )
                )

            boosted.sort(key=lambda x: x.score, reverse=True)

            # Optionally use LLM reranker via UnifiedInferenceClient.create()
            # Uses LOCAL-FIRST strategy: local services (SagePorts) -> cloud API fallback
            # Set SAGE_HYBRID_ENABLE_LLM_RERANK=1 to enable
            import os

            enable_llm_rerank = os.environ.get("SAGE_HYBRID_ENABLE_LLM_RERANK", "0") == "1"

            # Probe for the LLM client only once per selector instance.
            if enable_llm_rerank and not hasattr(self, "_llm_client_checked"):
                self._llm_client_checked = True
                try:
                    from sage.llm import UnifiedInferenceClient

                    # Use singleton to avoid repeated model loading
                    self._llm_client = UnifiedInferenceClient.get_instance(
                        instance_key="benchmark_hybrid"
                    )
                except Exception:
                    self._llm_client = None

            if enable_llm_rerank and getattr(self, "_llm_client", None) is not None:
                try:
                    reranked = self._llm_rerank(self._llm_client, query, boosted, top_k)
                    if reranked:
                        return reranked
                except Exception:
                    pass  # Silently fall back to keyword-boosted results

            return boosted[:top_k]

        def _llm_rerank(self, llm, query, boosted, top_k):
            """Use LLM to rerank candidates.

            Returns the reranked top_k predictions, or None when the LLM
            reply cannot be parsed (caller falls back to boosted order).
            """
            import json
            import re

            # Prepare prompt: id + truncated description per candidate
            cand_infos = []
            for p in boosted[: max(10, top_k)]:
                try:
                    t = self.resources.tools_loader.get_tool(p.tool_id)
                    desc = getattr(t, "description", "") or ""
                except Exception:
                    desc = ""
                cand_infos.append(f"{p.tool_id}: {desc[:100]}")

            messages = [
                {
                    "role": "system",
                    "content": "You are an assistant that ranks candidate tools by relevance. Return a JSON array of tool ids sorted most->least relevant. Only output the JSON array.",
                },
                {
                    "role": "user",
                    "content": f"Instruction: {query.instruction}\n\nCandidates:\n"
                    + "\n".join(cand_infos),
                },
            ]

            # UnifiedInferenceClient.chat() returns string directly
            resp = llm.chat(messages)

            # Parse the first bracketed JSON array out of the reply
            txt = resp if isinstance(resp, str) else str(resp)
            m = re.search(r"\[.*\]", txt, re.S)
            if not m:
                return None

            try:
                ranked = json.loads(m.group(0))
            except Exception:
                return None

            if not isinstance(ranked, list):
                return None

            # Build final predictions: LLM order first, then any candidates
            # the LLM omitted, in their original boosted order.
            id_to_pred = {p.tool_id: p for p in boosted}
            final_preds = []
            for tid in ranked:
                if tid in id_to_pred:
                    final_preds.append(id_to_pred[tid])
            for p in boosted:
                if p.tool_id not in {fp.tool_id for fp in final_preds}:
                    final_preds.append(p)

            return final_preds[:top_k]

    if resources is None:
        resources = self._create_mock_resources()

    selector = InlineHybridSelector(resources)
    return SelectorAdapter(selector)
|
|
1167
|
+
|
|
1168
|
+
def _create_template_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
    """Create template-based planner."""

    # Placeholder planner: emits a single generic step.
    class MockPlanner:
        def plan(self, task):
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            # First available tool, or a sentinel when the task offers none.
            chosen = task.available_tools[0] if task.available_tools else "unknown"
            single_step = PlanStep(
                step_id=0,
                description="Execute task",
                tool_id=chosen,
                confidence=0.5,
            )
            return PlanningPrediction(steps=[single_step], tool_sequence=[chosen])

    return PlannerAdapter(MockPlanner())
|
|
1192
|
+
|
|
1193
|
+
def _create_hierarchical_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
    """Create hierarchical planner."""
    try:
        from sage.libs.agentic.agents.planning import HierarchicalPlanner

        # Library planner available: wrap it directly.
        return PlannerAdapter(HierarchicalPlanner())
    except ImportError:
        # Planning library missing: degrade to the template planner.
        return self._create_template_planner(resources)
|
|
1202
|
+
|
|
1203
|
+
def _create_react_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
    """
    Create ReAct planner implementing Thought-Action-Observation loop.

    ReAct (Reasoning + Acting) generates plans by interleaving:
    1. Thought: Reasoning about current state
    2. Action: Selecting tool to use
    3. Observation: Expected result (predicted in planning mode)

    Reference: "ReAct: Synergizing Reasoning and Acting in Language Models" (Yao et al., 2023)
    """

    class ReActPlannerWrapper:
        """Wrapper for ReAct planner with benchmark-compatible interface."""

        def __init__(self):
            # Library planner, created lazily in _ensure_initialized().
            self._planner = None
            # NOTE(review): _llm_client is set here but never read in this
            # class; the client is passed straight into ReActPlanner instead.
            self._llm_client = None
            self._initialized = False

        def _ensure_initialized(self):
            """Lazy initialization of ReAct planner."""
            if self._initialized:
                return

            # Mark initialized up front so a failed import is not retried
            # on every plan() call.
            self._initialized = True

            try:
                from sage.libs.agentic.agents.planning import (
                    ReActConfig,
                    ReActPlanner,
                )

                # Try to get LLM client
                llm_client = self._get_llm_client()

                config = ReActConfig(
                    min_steps=5,
                    max_steps=10,
                    max_iterations=12,
                    temperature=BENCHMARK_LLM_TEMPERATURE,
                )

                self._planner = ReActPlanner(
                    config=config,
                    llm_client=llm_client,
                )
            except ImportError as e:
                import logging

                logging.warning(f"ReActPlanner import failed: {e}")
                self._planner = None

        def _get_llm_client(self):
            """Get LLM client with local-first strategy.

            Returns None when no client can be created; ReActPlanner is
            then constructed with llm_client=None.
            """
            # Use UnifiedInferenceClient.create() which implements local-first strategy
            try:
                from sage.llm import UnifiedInferenceClient

                return UnifiedInferenceClient.create()
            except Exception:
                pass

            return None

        def plan(self, task):
            """Generate plan using ReAct strategy.

            Falls back to keyword heuristics when the library planner is
            unavailable, fails, or produces no usable steps.
            """
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            self._ensure_initialized()

            instruction = getattr(task, "instruction", "") or ""
            available_tools = getattr(task, "available_tools", []) or []

            # No tools at all -> empty plan.
            if not available_tools:
                return PlanningPrediction(steps=[], tool_sequence=[])

            # Try ReAct planner if available
            if self._planner is not None:
                try:
                    from sage.libs.agentic.agents.planning import (
                        PlanRequest,
                        ToolMetadata,
                    )

                    # Convert available tools to ToolMetadata
                    # (only tool ids are known here, so name/description are synthetic)
                    tools = [
                        ToolMetadata(
                            tool_id=t,
                            name=t,
                            description=f"Tool: {t}",
                            category="general",
                        )
                        for t in available_tools
                    ]

                    request = PlanRequest(
                        goal=instruction,
                        tools=tools,
                        constraints=[],
                        min_steps=5,
                        max_steps=10,
                    )

                    result = self._planner.plan(request)

                    if result.success and result.steps:
                        steps = []
                        tool_sequence = []

                        for i, step in enumerate(result.steps):
                            tool_id = step.tool_id or step.action
                            # Drop steps referencing tools the task does not offer
                            if tool_id in available_tools:
                                steps.append(
                                    PlanStep(
                                        step_id=i,
                                        description=step.description or step.action,
                                        tool_id=tool_id,
                                        confidence=0.8,
                                    )
                                )
                                tool_sequence.append(tool_id)

                        if steps:
                            return PlanningPrediction(
                                steps=steps,
                                tool_sequence=tool_sequence,
                            )
                except Exception:
                    # Any planner failure silently falls through to heuristics
                    pass

            # Fallback: use heuristic-based planning
            return self._fallback_plan(instruction, available_tools)

        def _fallback_plan(self, instruction: str, available_tools: list[str]):
            """Heuristic-based fallback planning.

            Scores tools by word overlap with the instruction plus a few
            action-verb affinity bonuses, then emits up to 10 steps.
            """
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            steps = []
            tool_sequence = []
            instruction_lower = instruction.lower()

            # Score tools by relevance
            tool_scores = []
            for tool in available_tools:
                score = 0
                tool_lower = tool.lower()
                tool_words = set(tool_lower.replace("_", " ").split())
                instruction_words = set(instruction_lower.replace(",", " ").split())

                # Word overlap
                overlap = len(tool_words & instruction_words)
                score += overlap * 2

                # Action keyword matching: +2 when tool and instruction
                # share an action family (read/write/transform)
                if any(w in tool_lower for w in ["read", "get", "fetch", "load"]):
                    if any(w in instruction_lower for w in ["read", "get", "load", "fetch"]):
                        score += 2
                if any(w in tool_lower for w in ["write", "save", "send", "post"]):
                    if any(w in instruction_lower for w in ["write", "save", "send", "post"]):
                        score += 2
                if any(w in tool_lower for w in ["process", "transform", "convert"]):
                    if any(w in instruction_lower for w in ["process", "convert", "transform"]):
                        score += 2

                tool_scores.append((tool, score))

            tool_scores.sort(key=lambda x: x[1], reverse=True)

            # Select top tools: up to 8 with positive score, padded to at
            # least 5 with the next-best zero-score tools
            selected = [t for t, s in tool_scores[:8] if s > 0]
            if len(selected) < 5:
                for t, _ in tool_scores:
                    if t not in selected:
                        selected.append(t)
                        if len(selected) >= 5:
                            break

            for i, tool in enumerate(selected[:10]):
                steps.append(
                    PlanStep(
                        step_id=i,
                        description=f"ReAct step {i + 1}: Use {tool}",
                        tool_id=tool,
                        confidence=0.6,
                    )
                )
                tool_sequence.append(tool)

            return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)

    return PlannerAdapter(ReActPlannerWrapper())
|
|
1401
|
+
|
|
1402
|
+
def _create_sequence_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
    """Create sequence-based planner using selector for tool ordering."""

    class SequencePlanner:
        """Plan by selecting tools in sequence based on task steps."""

        def __init__(self, selector_factory):
            # Kept for interface compatibility; the heuristic below does not
            # currently instantiate a selector.
            self._selector_factory = selector_factory

        @staticmethod
        def _match_tool(sub_task, available, round_robin_idx):
            """Return the available tool with the most word overlap with
            *sub_task*; fall back to round-robin when nothing overlaps."""
            if not available:
                return "unknown"
            sub_words = set(sub_task.lower().split())
            best_tool, best_overlap = None, 0
            for tool in available:
                tool_words = set(tool.lower().replace("_", " ").split())
                overlap = len(sub_words & tool_words)
                if overlap > best_overlap:
                    best_tool, best_overlap = tool, overlap
            if best_tool is not None:
                return best_tool
            # No keyword signal: preserve the previous round-robin behavior.
            return available[round_robin_idx % len(available)]

        def plan(self, task):
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            # Use task instruction (not description) to select relevant tools
            instruction = (
                getattr(task, "instruction", "") or getattr(task, "description", "") or ""
            )
            steps = []
            tool_sequence = []

            # Parse task for steps (simple heuristic: sentence split)
            sub_tasks = [s.strip() for s in instruction.split(".") if s.strip()]

            available = task.available_tools if task.available_tools else []

            # Match each sub-task to a tool
            for i, sub_task in enumerate(sub_tasks[:5]):  # Max 5 steps
                # Fix: the previous code claimed keyword matching but always
                # assigned tools round-robin; now it actually picks the tool
                # with the most keyword overlap (round-robin only as fallback).
                best_tool = self._match_tool(sub_task, available, i)

                steps.append(
                    PlanStep(
                        step_id=i,
                        description=sub_task,
                        tool_id=best_tool,
                        confidence=0.6,
                    )
                )
                tool_sequence.append(best_tool)

            if not steps and available:
                # Fallback: at least one step
                steps.append(
                    PlanStep(
                        step_id=0,
                        description=instruction[:100] if instruction else "Execute task",
                        tool_id=available[0],
                        confidence=0.5,
                    )
                )
                tool_sequence.append(available[0])

            return PlanningPrediction(
                steps=steps,
                tool_sequence=tool_sequence,
            )

    return PlannerAdapter(SequencePlanner(self._create_keyword_selector))
|
|
1462
|
+
|
|
1463
|
+
def _create_threshold_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
    """Create threshold-based timing decider."""

    class ThresholdDecider:
        """Simple keyword-based timing decider."""

        # Keywords indicating tool invocation is needed
        ACTION_KEYWORDS = frozenset(
            [
                "search",
                "find",
                "calculate",
                "analyze",
                "create",
                "update",
                "delete",
                "get",
                "fetch",
                "query",
                "look up",
                "current",
                "now",
                "today",
                "latest",
                "real-time",
                "weather",
                "time",
                "stock",
                "price",
                "news",
                "check",
                "verify",
                "compare",
                "list",
                "show",
            ]
        )

        # Keywords indicating direct answer (no tool needed)
        # NOTE(review): "+ " and "multiply" presumably mark simple arithmetic
        # as directly answerable — confirm against the benchmark data.
        FACTUAL_KEYWORDS = frozenset(
            [
                "what is",
                "how many",
                "define",
                "explain",
                "meaning of",
                "who was",
                "when was",
                "where is",
                "+ ",
                "multiply",
                "capital of",
                "population of",
                "history of",
            ]
        )

        def decide(self, message):
            """Decide whether *message* needs a tool call.

            Substring matching over the lowercased message text; assumes
            ``message.message`` is the user utterance (str).
            """
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                TimingDecision,
            )

            text = message.message.lower()

            # Check for action keywords
            has_action = any(kw in text for kw in self.ACTION_KEYWORDS)

            # Check for factual (no-tool) keywords
            has_factual = any(kw in text for kw in self.FACTUAL_KEYWORDS)

            # Heuristic: action keywords > factual keywords
            # Real-time or current info needs tools
            # i.e. call a tool when an action keyword is present, UNLESS the
            # message also looks factual AND contains no real-time marker.
            should_call = has_action and not (
                has_factual
                and not any(
                    kw in text
                    for kw in [
                        "current",
                        "now",
                        "today",
                        "latest",
                        "real-time",
                        "weather",
                        "time",
                        "stock",
                        "news",
                    ]
                )
            )

            return TimingDecision(
                should_call_tool=should_call,
                # Higher confidence whenever any keyword class matched at all
                confidence=0.8 if (has_action or has_factual) else 0.5,
                reasoning=(
                    "Detected action keywords"
                    if should_call
                    else "No action keywords or factual query"
                ),
            )

    return TimingAdapter(ThresholdDecider())
|
|
1564
|
+
|
|
1565
|
+
def _create_llm_timing_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
    """Create LLM-based timing decider using UnifiedInferenceClient.

    Uses UnifiedInferenceClient.create() which handles:
    1. Environment variables (SAGE_UNIFIED_BASE_URL)
    2. Local services (ports from SagePorts)
    3. Cloud API fallback (OpenAI-compatible)
    """

    class LLMTimingDecider:
        """
        LLM-based timing decider using UnifiedInferenceClient.

        Asks the model whether the message needs a tool call and expects a
        JSON-only reply; on any initialization or parsing failure it falls
        back to a simple keyword heuristic (see _fallback_decide).
        """

        # Prompt template; {message} is filled with the raw user message.
        # Doubled braces ({{ }}) survive str.format as literal JSON braces.
        TIMING_PROMPT = """You are an AI assistant that determines whether a user's message requires tool invocation or can be answered directly from your knowledge.

Analyze the following user message and determine:
1. Does this message require real-time information (weather, stock prices, current time, etc.)?
2. Does this message require performing an action (search, calculate, create file, send email, etc.)?
3. Does this message require accessing external data or APIs?

If ANY of the above is true, the user needs a tool call.
If the message is asking for factual knowledge, explanations, opinions, creative writing, or general conversation, it can be answered directly without tools.

User message: "{message}"

Respond in JSON format:
{{
"should_call_tool": true/false,
"confidence": 0.0-1.0,
"reasoning": "brief explanation"
}}

Only output the JSON, nothing else."""

        def __init__(self):
            # Client is created lazily on first decide(); _initialized
            # distinguishes "not yet attempted" from "attempted and failed"
            # (in which case _client stays None and we never retry).
            self._client = None
            self._initialized = False

        def _ensure_client(self):
            """Lazy initialization using UnifiedInferenceClient.

            Returns True when a usable client is available, False otherwise.
            """
            if self._initialized:
                return self._client is not None

            # Mark before attempting so a failure is not retried per call.
            self._initialized = True

            try:
                from sage.llm import UnifiedInferenceClient

                self._client = UnifiedInferenceClient.create()
                return True
            except Exception as e:
                import logging

                logging.getLogger(__name__).warning(
                    f"Failed to initialize LLM client: {e}. Falling back to rule-based."
                )
                self._client = None
                return False

        def decide(self, message):
            # Lazy import — presumably avoids a circular import at module
            # load time; confirm against project layout.
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                TimingDecision,
            )

            # Ensure LLM client is initialized
            if not self._ensure_client():
                return self._fallback_decide(message)

            try:
                import json

                prompt = self.TIMING_PROMPT.format(message=message.message)
                messages = [{"role": "user", "content": prompt}]

                # UnifiedInferenceClient.chat() returns string directly
                content = self._client.chat(messages)

                # Parse JSON response
                # Handle potential markdown code blocks
                if content.startswith("```"):
                    content = content.split("```")[1]
                    if content.startswith("json"):
                        content = content[4:]
                content = content.strip()

                result = json.loads(content)
                return TimingDecision(
                    should_call_tool=result.get("should_call_tool", False),
                    confidence=float(result.get("confidence", 0.7)),
                    reasoning=f"[LLM] {result.get('reasoning', 'LLM decision')}",
                )
            except Exception:
                # Any failure (transport error, malformed JSON, bad
                # confidence value) silently degrades to the rule fallback.
                return self._fallback_decide(message)

        def _fallback_decide(self, message):
            """Fallback to simple rule-based for failed LLM calls."""
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                TimingDecision,
            )

            text = message.message.lower()
            # Minimal action-cue list; a substring hit means "call a tool".
            action_keywords = [
                "search",
                "find",
                "calculate",
                "weather",
                "stock",
                "price",
                "current",
                "now",
                "today",
                "create",
                "send",
                "schedule",
            ]
            has_action = any(kw in text for kw in action_keywords)

            return TimingDecision(
                should_call_tool=has_action,
                confidence=0.6,
                reasoning="[Fallback] Simple keyword match",
            )

    return TimingAdapter(LLMTimingDecider())
|
|
1690
|
+
|
|
1691
|
+
def _create_rule_based_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
    """Create rule-based timing decider with comprehensive keyword matching."""

    class RuleBasedDecider:
        """
        Enhanced rule-based timing decider v2.

        Uses multi-layer pattern matching and keyword analysis to determine
        if a message requires tool invocation.

        Decision order (see decide()):
        1. NO_TOOL regex patterns (advice/opinion) — veto tools outright.
        2. TOOL regex patterns — strong positive signal.
        3. Weighted keyword scoring with high-weight boosts.
        4. Knowledge-query heuristic, then default "no tool".

        Improvements over v1:
        - Added patterns for prediction/forecast queries (stock prices, future events)
        - Added patterns for data lookup queries (calories, nutritional info)
        - Added NO_TOOL patterns for advice/opinion questions
        - Improved confidence scoring with weighted categories
        - Better handling of edge cases (e.g., "should I" questions)
        """

        # Patterns that strongly indicate tool invocation is needed
        # (compiled case-insensitively in __init__).
        TOOL_REQUIRED_PATTERNS = [
            # Real-time information
            r"\b(current|real-time|live|now|today|right now)\b.*\b(weather|temperature|stock|price|news|time|traffic)\b",
            r"\bwhat('s| is) the (current|latest)\b",
            r"\b(weather|temperature|forecast)\s+(in|for|at)\b",
            r"\b(stock|share) (price|value|chart)\b",
            # Time queries
            r"\bwhat('s| is) the\s+(current\s+)?time\s+(in|at)\b",
            r"\bwhat time is it in\b",
            # Prediction/Forecast queries (needs tool for data)
            r"\bwhat will\b.*\b(price|stock|weather|be)\b.*\b(next|tomorrow|week|month)\b",
            r"\bwill it\b.*\b(rain|snow|be sunny|be cold|be hot)\b",
            r"\b(forecast|predict|projection)\s+for\b",
            # Actions and operations
            r"\b(search|find|look up|fetch|retrieve|get|query)\s+(for|the|about)?\b",
            r"\b(calculate|compute|convert)\s+",
            r"\b(create|generate|make|build)\s+(a|an|the)?\s*(file|document|spreadsheet|chart|report)\b",
            r"\b(send|email|schedule|book|reserve|cancel)\s+",
            # File operations (enhanced patterns)
            r"\b(open|close|save|delete|rename|move|copy)\s+(the|a|this)?\s*(file|document|report|spreadsheet)\b",
            r"\bsave\s+(this|the|a)\s+(document|file|data)\b",
            r"\bsave\s+.*\s+as\b",  # "save X as Y"
            r"\bopen\s+(the|a)\s+file\b",
            r"\bdelete\s+(the|a)\s+file\b",
            # Code execution (enhanced patterns)
            r"\b(run|execute|compile|debug|test)\s+(this|the|a)?\s*(code|script|program|command|python|javascript)\b",
            r"\brun\s+this\s+\w+\s+code\b",  # "run this Python code"
            r"\bexecute\s+(this|the)\s+(script|code|program)\b",
            # Database/API operations
            r"\b(select|insert|update|delete)\s+.*\b(from|into|where)\b",
            r"\bapi\s+(call|request|endpoint)\b",
            # Data lookup queries (needs external database)
            r"\bhow many calories\b",
            r"\b(calories|carbs|protein|fat|nutrition)\s+(in|of)\b",
            r"\b(nutritional|nutrition)\s+(info|information|value|data)\b",
            r"\bexchange rate\s+(for|of|between)\b",
            r"\bconvert\s+\d+\s*\w+\s+to\b",
        ]

        # Patterns that indicate NO tool needed (advice, opinion, philosophical)
        NO_TOOL_PATTERNS = [
            # Advice/Opinion questions
            r"\bshould i\b",
            r"\bwhat do you think\b",
            r"\bwhat('s| is) your opinion\b",
            r"\bdo you recommend\b",
            r"\bany (tips|advice|suggestions)\b",
            r"\bis it (a good|worth|better)\b",
            # Personal/Philosophical questions
            r"\bwhat('s| is) the meaning of life\b",
            r"\bwhy (do|should) (we|i|people)\b",
            r"\bhow (do|can) i (feel|cope|deal)\b",
        ]

        # Keywords indicating tool invocation (with weights)
        # Each substring hit contributes 1 to the tool score.
        TOOL_KEYWORDS = frozenset(
            [
                # Search/Retrieve
                "search",
                "find",
                "look up",
                "lookup",
                "fetch",
                "retrieve",
                "query",
                # Actions
                "calculate",
                "compute",
                "convert",
                "translate",
                "analyze",
                # CRUD operations
                "create",
                "update",
                "delete",
                "modify",
                "edit",
                # Real-time indicators
                "current",
                "live",
                "real-time",
                "realtime",
                "latest",
                "now",
                "today",
                "right now",
                # Specific domains requiring tools
                "weather",
                "stock",
                "price",
                "exchange rate",
                "traffic",
                "flight",
                "news",
                "calories",
                "nutritional",
                # File/System operations
                "open",
                "save",
                "download",
                "upload",
                "export",
                "import",
                # Scheduling
                "schedule",
                "book",
                "reserve",
                "remind",
                "alarm",
                # Communication
                "send",
                "email",
                "message",
                "notify",
                "call",
                # Code execution
                "run",
                "execute",
                "compile",
                "debug",
            ]
        )

        # High-weight keywords that strongly suggest tool usage
        # (each hit adds 2 on top of any TOOL_KEYWORDS hit for the same word).
        HIGH_WEIGHT_TOOL_KEYWORDS = frozenset(
            [
                "search",
                "calculate",
                "weather",
                "stock",
                "price",
                "calories",
                "execute",
                "run code",
                "compile",
                "exchange rate",
                "schedule",
                "book",
                "send email",
            ]
        )

        # Keywords indicating NO tool needed (direct answer)
        NO_TOOL_KEYWORDS = frozenset(
            [
                # Definitions
                "what is",
                "what are",
                "define",
                "definition",
                "meaning of",
                "explain",
                "describe",
                # Factual (static knowledge)
                "who was",
                "who is",
                "who invented",
                "who wrote",
                "when was",
                "when did",
                "where is",
                "where was",
                "capital of",
                "population of",
                "history of",
                # Knowledge/Educational
                "how does",
                "how do",
                "why do",
                "why does",
                "what causes",
                # Conversational
                "hello",
                "hi",
                "thanks",
                "thank you",
                "goodbye",
                "bye",
                # Opinion/Advice (important: these do NOT need tools)
                "what do you think",
                "your opinion",
                "any tips",
                "advice",
                "suggest",
                "recommend",
                "should i",
                "is it worth",
                "is it a good idea",
                "pros and cons",
                # Creative writing
                "write a",
                "compose",
                "create a story",
                "create a poem",
                "tell me a story",
                # Math/Science knowledge (not calculations)
                "pythagorean theorem",
                "quadratic formula",
                "what is pi",
            ]
        )

        # High-weight keywords that strongly suggest NO tool
        HIGH_WEIGHT_NO_TOOL_KEYWORDS = frozenset(
            [
                "should i",
                "what do you think",
                "your opinion",
                "recommend",
                "advice",
                "capital of",
                "who invented",
                "who wrote",
                "explain",
                "define",
            ]
        )

        def __init__(self):
            import re

            # Pre-compile both pattern sets once per decider instance.
            self._compiled_tool_patterns = [
                re.compile(p, re.IGNORECASE) for p in self.TOOL_REQUIRED_PATTERNS
            ]
            self._compiled_no_tool_patterns = [
                re.compile(p, re.IGNORECASE) for p in self.NO_TOOL_PATTERNS
            ]

        def decide(self, message):
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                TimingDecision,
            )

            text = message.message.lower()

            # Priority 1: Check for strong NO-TOOL patterns first (advice/opinion)
            # These override tool patterns because the question is fundamentally
            # asking for advice, not data retrieval
            no_tool_pattern_match = any(p.search(text) for p in self._compiled_no_tool_patterns)
            if no_tool_pattern_match:
                # Even if there are tool-like keywords, the core question is advice
                return TimingDecision(
                    should_call_tool=False,
                    confidence=0.9,
                    reasoning="Strong pattern match for advice/opinion question (no tool needed)",
                )

            # Priority 2: Check for strong TOOL patterns
            tool_pattern_match = any(p.search(text) for p in self._compiled_tool_patterns)
            if tool_pattern_match:
                return TimingDecision(
                    should_call_tool=True,
                    confidence=0.95,
                    reasoning="Strong pattern match for tool invocation",
                )

            # Priority 3: Weighted keyword analysis
            # Calculate weighted scores (high-weight hits count double,
            # stacking with the base keyword hit when both sets match).
            tool_score = sum(1 for kw in self.TOOL_KEYWORDS if kw in text)
            high_tool_score = sum(2 for kw in self.HIGH_WEIGHT_TOOL_KEYWORDS if kw in text)
            total_tool_score = tool_score + high_tool_score

            no_tool_score = sum(1 for kw in self.NO_TOOL_KEYWORDS if kw in text)
            high_no_tool_score = sum(
                2 for kw in self.HIGH_WEIGHT_NO_TOOL_KEYWORDS if kw in text
            )
            total_no_tool_score = no_tool_score + high_no_tool_score

            # Decision based on weighted scores
            if total_tool_score > 0 and total_no_tool_score == 0:
                # Unanimous "tool": confidence scales with score, capped at 0.95.
                confidence = min(0.7 + total_tool_score * 0.08, 0.95)
                return TimingDecision(
                    should_call_tool=True,
                    confidence=confidence,
                    reasoning=f"Tool keywords detected (score: {total_tool_score})",
                )
            elif total_no_tool_score > 0 and total_tool_score == 0:
                # Unanimous "no tool": symmetric confidence scaling.
                confidence = min(0.7 + total_no_tool_score * 0.08, 0.95)
                return TimingDecision(
                    should_call_tool=False,
                    confidence=confidence,
                    reasoning=f"No-tool keywords detected (score: {total_no_tool_score})",
                )
            elif total_tool_score > total_no_tool_score:
                # Mixed signals, tool side wins: lower base confidence.
                score_diff = total_tool_score - total_no_tool_score
                confidence = min(0.55 + score_diff * 0.08, 0.85)
                return TimingDecision(
                    should_call_tool=True,
                    confidence=confidence,
                    reasoning=f"More tool keywords ({total_tool_score} vs {total_no_tool_score})",
                )
            elif total_no_tool_score > total_tool_score:
                # Mixed signals, no-tool side wins.
                score_diff = total_no_tool_score - total_tool_score
                confidence = min(0.55 + score_diff * 0.08, 0.85)
                return TimingDecision(
                    should_call_tool=False,
                    confidence=confidence,
                    reasoning=f"More no-tool keywords ({total_no_tool_score} vs {total_tool_score})",
                )
            else:
                # Ambiguous case: use heuristics
                # Check for question words that might indicate knowledge queries
                knowledge_indicators = ["what is", "who is", "where is", "when was"]
                is_knowledge_query = any(ind in text for ind in knowledge_indicators)

                if is_knowledge_query:
                    return TimingDecision(
                        should_call_tool=False,
                        confidence=0.55,
                        reasoning="Ambiguous - appears to be knowledge query, defaulting to no tool",
                    )

                # Default: assume no tool needed for truly ambiguous cases
                return TimingDecision(
                    should_call_tool=False,
                    confidence=0.5,
                    reasoning="Ambiguous - defaulting to no tool",
                )

    return TimingAdapter(RuleBasedDecider())
|
|
2030
|
+
|
|
2031
|
+
def _create_hybrid_timing_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
    """Create hybrid timing decider combining rule-based and LLM-based approaches."""

    class HybridDecider:
        """Two-stage timing decider.

        The rule-based decider runs first; when its confidence clears the
        threshold its verdict is returned as-is. Otherwise the optional LLM
        decider is consulted and the two verdicts are reconciled.
        """

        def __init__(self, rule_decider, llm_decider=None, confidence_threshold=0.7):
            # Rule decider is mandatory; the LLM decider may be None.
            self._rule_decider = rule_decider
            self._llm_decider = llm_decider
            self._threshold = confidence_threshold

        def decide(self, message):
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                TimingDecision,
            )

            primary = self._rule_decider.decide(message)

            # Confident rule verdicts win outright.
            if primary.confidence >= self._threshold:
                return TimingDecision(
                    should_call_tool=primary.should_call_tool,
                    confidence=primary.confidence,
                    reasoning=f"[Rule-based] {primary.reasoning}",
                )

            # Low confidence: consult the LLM decider when one is available.
            if self._llm_decider is not None:
                try:
                    secondary = self._llm_decider.decide(message)
                    if secondary.should_call_tool == primary.should_call_tool:
                        # Agreement: average confidences plus a small bonus.
                        boosted = min(
                            (primary.confidence + secondary.confidence) / 2 + 0.1, 1.0
                        )
                        return TimingDecision(
                            should_call_tool=primary.should_call_tool,
                            confidence=boosted,
                            reasoning=f"[Hybrid-agree] Rule: {primary.reasoning}, LLM: {secondary.reasoning}",
                        )
                    # Disagreement: side with the more confident decider,
                    # discounting its confidence slightly.
                    if secondary.confidence > primary.confidence:
                        return TimingDecision(
                            should_call_tool=secondary.should_call_tool,
                            confidence=secondary.confidence * 0.9,
                            reasoning=f"[Hybrid-LLM] {secondary.reasoning}",
                        )
                    return TimingDecision(
                        should_call_tool=primary.should_call_tool,
                        confidence=primary.confidence * 0.9,
                        reasoning=f"[Hybrid-Rule] {primary.reasoning}",
                    )
                except Exception:
                    pass  # Fall back to rule-based

            # No LLM (or it failed): return the rule verdict unchanged.
            return TimingDecision(
                should_call_tool=primary.should_call_tool,
                confidence=primary.confidence,
                reasoning=f"[Rule-fallback] {primary.reasoning}",
            )

    # Rule-based decider is always available.
    rule_decider = self._create_rule_based_decider(resources).decider

    # LLM decider is best-effort only.
    llm_decider = None
    try:
        llm_decider = self._create_llm_timing_decider(resources).decider
    except Exception:
        llm_decider = None

    return TimingAdapter(HybridDecider(rule_decider, llm_decider))
|
|
2115
|
+
|
|
2116
|
+
def _create_embedding_timing_decider(self, resources: Optional[Any] = None) -> TimingAdapter:
    """Create embedding-based timing decider using SAGE's EmbeddingService."""

    class EmbeddingTimingDecider:
        """Semantic-similarity timing decider.

        Embeds two fixed banks of example messages ("tool needed" vs
        "no tool needed") once, then classifies each query by whichever
        bank it is closer to on average cosine similarity.
        """

        # Representative examples for the "tool needed" class.
        TOOL_NEEDED_EXAMPLES = [
            "What's the weather like in New York right now?",
            "Search for the latest news about AI",
            "Calculate the compound interest on $10000",
            "What's the current stock price of AAPL?",
            "Send an email to my team",
            "Create a new spreadsheet with sales data",
            "What time is it in Tokyo?",
            "Book a flight from NYC to London",
            "Find restaurants near me",
            "Download the latest report",
        ]

        # Representative examples for the "no tool needed" class.
        NO_TOOL_EXAMPLES = [
            "What is the capital of France?",
            "Explain quantum computing to me",
            "Who invented the telephone?",
            "What does photosynthesis mean?",
            "Write a poem about nature",
            "What are the pros and cons of remote work?",
            "How does machine learning work?",
            "Tell me a story about a brave knight",
            "What is 2 + 2?",
            "Thank you for your help!",
        ]

        def __init__(self):
            # Embedder and example vectors are built lazily; _initialized
            # records whether initialization has been attempted.
            self._embedder = None
            self._tool_needed_embeddings = None
            self._no_tool_embeddings = None
            self._initialized = False

        def _ensure_initialized(self):
            """Lazy initialization of embedder and example embeddings."""
            if self._initialized:
                return self._embedder is not None

            # One attempt only: failures are not retried on later calls.
            self._initialized = True
            try:
                import os

                from sage.common.components.sage_embedding import (
                    get_embedding_model,
                )

                # Embedding backend chosen via environment; hash is default.
                backend = os.getenv("SAGE_EMBEDDING_METHOD", "hash")

                if backend == "hf":
                    try:
                        embedder = get_embedding_model(
                            "hf", model=BENCHMARK_EMBEDDING_MODEL
                        )
                    except Exception:
                        embedder = get_embedding_model("hash", dim=384)
                else:
                    embedder = get_embedding_model("hash", dim=384)
                self._embedder = embedder

                # Embed both example banks up front.
                self._tool_needed_embeddings = [
                    embedder.embed(sample) for sample in self.TOOL_NEEDED_EXAMPLES
                ]
                self._no_tool_embeddings = [
                    embedder.embed(sample) for sample in self.NO_TOOL_EXAMPLES
                ]
                return True
            except Exception as e:
                import logging

                logging.getLogger(__name__).warning(
                    f"Failed to initialize embedding decider: {e}"
                )
                self._embedder = None
                return False

        def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
            """Compute cosine similarity between two vectors."""
            import math

            numerator = sum(x * y for x, y in zip(vec1, vec2))
            mag1 = math.sqrt(sum(x * x for x in vec1))
            mag2 = math.sqrt(sum(y * y for y in vec2))
            # Zero-magnitude vectors have undefined angle; report 0.
            if mag1 == 0 or mag2 == 0:
                return 0.0
            return numerator / (mag1 * mag2)

        def decide(self, message):
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                TimingDecision,
            )

            if not self._ensure_initialized() or self._embedder is None:
                # Degraded mode: embedding stack unavailable, match keywords.
                lowered = message.message.lower()
                cues = ["search", "find", "weather", "stock", "current", "now"]
                return TimingDecision(
                    should_call_tool=any(cue in lowered for cue in cues),
                    confidence=0.5,
                    reasoning="[Fallback] Simple keyword match",
                )

            try:
                query_vec = self._embedder.embed(message.message)

                def _mean_similarity(bank):
                    # Average similarity of the query to one example bank.
                    scores = [self._cosine_similarity(query_vec, v) for v in bank]
                    return sum(scores) / len(scores)

                tool_avg = _mean_similarity(self._tool_needed_embeddings)
                no_tool_avg = _mean_similarity(self._no_tool_embeddings)

                # Closer bank wins; confidence grows with the margin,
                # clamped to [0.5, 0.95].
                verdict = tool_avg > no_tool_avg
                margin = min(max(abs(tool_avg - no_tool_avg) + 0.5, 0.5), 0.95)

                return TimingDecision(
                    should_call_tool=verdict,
                    confidence=margin,
                    reasoning=f"[Embedding] tool_sim={tool_avg:.3f}, no_tool_sim={no_tool_avg:.3f}",
                )
            except Exception as e:
                return TimingDecision(
                    should_call_tool=False,
                    confidence=0.5,
                    reasoning=f"[Embedding-error] {str(e)[:50]}",
                )

    return TimingAdapter(EmbeddingTimingDecider())
|
|
2265
|
+
|
|
2266
|
+
def _create_mock_resources(self) -> Any:
    """Create mock resources for testing."""
    from sage.libs.agentic.agents.action.tool_selection import SelectorResources

    class MockToolsLoader:
        """Mock tools loader for testing."""

        # Fixed category rotation for the generated tools.
        _CATEGORIES = ["search", "calculate", "data", "communication"]

        def iter_all(self):
            """Yield mock tools."""

            class MockTool:
                """Lightweight stand-in for a catalog tool entry."""

                def __init__(self, tool_id, name, description, category):
                    self.tool_id = tool_id
                    self.name = name
                    self.description = description
                    self.category = category

            # Emit 50 synthetic tools cycling through the categories.
            n_categories = len(self._CATEGORIES)
            for idx in range(50):
                label = self._CATEGORIES[idx % n_categories]
                yield MockTool(
                    tool_id=f"tool_{idx:03d}",
                    name=f"{label}_tool_{idx}",
                    description=f"A tool for {label} operations, variant {idx}",
                    category=label,
                )

    # No embedding client: keyword-only selectors still work.
    return SelectorResources(
        tools_loader=MockToolsLoader(),
        embedding_client=None,
    )
|
|
2298
|
+
|
|
2299
|
+
def _create_sage_resources(self, embedding_client: Optional[Any] = None) -> Any:
    """
    Create SelectorResources backed by the real SAGE-Bench tool catalog.

    Loads the 1,200 tools from tool_catalog.jsonl so selectors such as
    Gorilla and DFSDT can build a proper tool index. If loading fails for
    any reason, falls back to the mock resources.

    Args:
        embedding_client: Optional embedding client for semantic search

    Returns:
        SelectorResources with SageToolsLoader (or mock resources on failure)
    """
    from sage.libs.agentic.agents.action.tool_selection import SelectorResources

    try:
        from sage.benchmark.benchmark_agent.tools_loader import (
            get_sage_tools_loader,
        )

        loader = get_sage_tools_loader()
        # len() stays inside the try: a broken loader also triggers fallback.
        self.logger.info(f"Using SAGE tools loader with {len(loader)} tools")
    except Exception as e:
        self.logger.warning(f"Failed to load SAGE tools: {e}. Falling back to mock tools.")
        return self._create_mock_resources()

    return SelectorResources(
        tools_loader=loader,
        embedding_client=embedding_client,
    )
|
|
2329
|
+
|
|
2330
|
+
def _create_simple_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
    """
    Create simple planner using embedding-based tool matching.

    Uses EmbeddingFactory for semantic similarity matching.

    Args:
        resources: Unused here; accepted for signature parity with the
            other ``_create_*`` factory methods.

    Returns:
        PlannerAdapter wrapping an EmbeddingBasedPlanner instance.
    """

    class EmbeddingBasedPlanner:
        """Planner using embedding similarity for tool selection.

        Decomposes the task instruction into sub-tasks, then greedily
        assigns each sub-task the most similar available tool (cosine
        similarity of embeddings, with keyword matching as fallback).
        """

        def __init__(self):
            # Embedder is created lazily on first use; stays None if
            # initialization fails (keyword fallback is used instead).
            self._embedder = None
            # tool name -> embedding vector, filled on demand.
            self._tool_embeddings_cache: dict[str, list[float]] = {}

        def _get_embedder(self):
            """Lazy initialization of embedder."""
            if self._embedder is None:
                try:
                    from sage.common.components.sage_embedding import (
                        get_embedding_model,
                    )

                    # Try to use HF model, fallback to hash
                    try:
                        self._embedder = get_embedding_model(
                            "hf", model=BENCHMARK_EMBEDDING_MODEL
                        )
                    except Exception:
                        self._embedder = get_embedding_model("hash", dim=384)
                except Exception:
                    # No embedding backend available at all; callers must
                    # handle the returned None (keyword matching fallback).
                    pass
            return self._embedder

        def _compute_similarity(self, vec1: list[float], vec2: list[float]) -> float:
            """Compute cosine similarity; returns 0.0 for zero-norm vectors."""
            import math

            dot = sum(a * b for a, b in zip(vec1, vec2))
            norm1 = math.sqrt(sum(a * a for a in vec1))
            norm2 = math.sqrt(sum(b * b for b in vec2))
            if norm1 == 0 or norm2 == 0:
                return 0.0
            return dot / (norm1 * norm2)

        def plan(self, task):
            """Produce a PlanningPrediction for *task*.

            Reads ``task.instruction`` and ``task.available_tools`` via
            getattr; missing attributes are treated as empty.
            """
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            instruction = getattr(task, "instruction", "") or ""
            available_tools = getattr(task, "available_tools", []) or []

            if not available_tools:
                return PlanningPrediction(steps=[], tool_sequence=[])

            # May be None; _match_tool_semantic falls back to keywords.
            embedder = self._get_embedder()

            # Decompose instruction into sub-tasks
            sub_tasks = self._decompose_instruction(instruction)

            steps = []
            tool_sequence = []
            used_tools = set()

            for i, sub_task in enumerate(sub_tasks):
                best_tool = self._match_tool_semantic(
                    sub_task, available_tools, used_tools, embedder
                )
                if best_tool:
                    steps.append(
                        PlanStep(
                            step_id=i,
                            description=sub_task,
                            tool_id=best_tool,
                            # Fixed heuristic confidence for this planner.
                            confidence=0.7,
                        )
                    )
                    tool_sequence.append(best_tool)
                    used_tools.add(best_tool)

            return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)

        def _decompose_instruction(self, instruction: str) -> list[str]:
            """Decompose instruction into sub-tasks.

            Splits repeatedly on common conjunction/sequence delimiters,
            drops fragments of <= 5 characters, and caps the result at 10
            sub-tasks (falling back to the whole instruction if nothing
            survives filtering).
            """
            delimiters = [", and ", " and ", ", then ", ", ", ". ", "; "]
            sub_tasks = [instruction]

            for delim in delimiters:
                new_tasks = []
                for task_item in sub_tasks:
                    parts = task_item.split(delim)
                    new_tasks.extend([p.strip() for p in parts if p.strip()])
                sub_tasks = new_tasks

            sub_tasks = [t for t in sub_tasks if len(t) > 5]
            return sub_tasks[:10] if sub_tasks else [instruction]

        def _match_tool_semantic(
            self,
            sub_task: str,
            available_tools: list[str],
            used_tools: set[str],
            embedder,
        ) -> str | None:
            """Match sub-task to tool using semantic similarity.

            Already-used tools get a 0.5 similarity penalty so plans
            prefer distinct tools. Falls back to keyword matching when
            the embedder is unavailable or raises.
            """
            if embedder is None:
                # Fallback to keyword matching
                return self._match_tool_keyword(sub_task, available_tools, used_tools)

            try:
                # Get sub-task embedding
                task_vec = embedder.embed(sub_task)

                best_tool = None
                best_score = -1.0

                for tool in available_tools:
                    # Get tool embedding (with cache)
                    if tool not in self._tool_embeddings_cache:
                        # Create description from tool name
                        tool_desc = tool.replace("_", " ")
                        self._tool_embeddings_cache[tool] = embedder.embed(tool_desc)

                    tool_vec = self._tool_embeddings_cache[tool]
                    similarity = self._compute_similarity(task_vec, tool_vec)

                    # Penalty for already used tools
                    if tool in used_tools:
                        similarity *= 0.5

                    if similarity > best_score:
                        best_score = similarity
                        best_tool = tool

                return best_tool
            except Exception:
                # Any embedding failure degrades to keyword matching.
                return self._match_tool_keyword(sub_task, available_tools, used_tools)

        def _match_tool_keyword(
            self, sub_task: str, available_tools: list[str], used_tools: set[str]
        ) -> str | None:
            """Fallback keyword matching.

            Scores each tool by how many of its underscore-separated name
            parts (length > 2) appear in the sub-task; halves the score
            for used tools; if nothing scores, returns the first unused
            tool (or None if all are used).
            """
            sub_task_lower = sub_task.lower()
            best_tool = None
            best_score: float = 0.0

            for tool in available_tools:
                tool_lower = tool.lower()
                score: float = 0.0

                for part in tool_lower.split("_"):
                    if part in sub_task_lower and len(part) > 2:
                        score += 2

                if tool in used_tools:
                    score *= 0.5

                if score > best_score:
                    best_score = score
                    best_tool = tool

            if best_tool is None:
                for tool in available_tools:
                    if tool not in used_tools:
                        return tool

            return best_tool

    return PlannerAdapter(EmbeddingBasedPlanner())
|
|
2500
|
+
|
|
2501
|
+
def _create_hierarchical_planning_strategy(
    self, resources: Optional[Any] = None
) -> PlannerAdapter:
    """
    Create hierarchical planner for Challenge 2.

    Uses task decomposition and dependency analysis.

    Args:
        resources: Unused here; accepted for signature parity with the
            other ``_create_*`` factory methods.

    Returns:
        PlannerAdapter wrapping a HierarchicalPlanningStrategy instance.
    """

    class HierarchicalPlanningStrategy:
        """Hierarchical planner with dependency management.

        Purely heuristic (no LLM, no embeddings): decomposes the
        instruction by delimiters, then greedily maps each sub-task to a
        tool via a hand-written keyword table plus tool-name matching.
        """

        def plan(self, task):
            # Reads task.instruction / task.available_tools via getattr;
            # missing attributes are treated as empty.
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            instruction = getattr(task, "instruction", "") or ""
            available_tools = getattr(task, "available_tools", []) or []

            if not available_tools:
                return PlanningPrediction(steps=[], tool_sequence=[])

            # Decompose instruction into sub-tasks
            sub_tasks = self._decompose_instruction(instruction)

            # Map sub-tasks to tools
            steps = []
            tool_sequence = []
            used_tools = set()

            for i, sub_task in enumerate(sub_tasks):
                best_tool = self._match_tool(sub_task, available_tools, used_tools)
                if best_tool:
                    steps.append(
                        PlanStep(
                            step_id=i,
                            description=sub_task,
                            tool_id=best_tool,
                            # Fixed heuristic confidence for this planner.
                            confidence=0.75,
                        )
                    )
                    tool_sequence.append(best_tool)
                    used_tools.add(best_tool)

            return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)

        def _decompose_instruction(self, instruction: str) -> list[str]:
            """Decompose instruction into sub-tasks."""
            # Split by common delimiters
            delimiters = [", and ", " and ", ", then ", ", ", ". ", "; "]
            sub_tasks = [instruction]

            for delim in delimiters:
                new_tasks = []
                for task in sub_tasks:
                    parts = task.split(delim)
                    new_tasks.extend([p.strip() for p in parts if p.strip()])
                sub_tasks = new_tasks

            # Filter out very short tasks
            sub_tasks = [t for t in sub_tasks if len(t) > 5]

            # Limit to reasonable number
            return sub_tasks[:10] if sub_tasks else [instruction]

        def _match_tool(
            self, sub_task: str, available_tools: list[str], used_tools: set[str]
        ) -> str | None:
            """Match a sub-task to the best available tool.

            Scoring: +3 per keyword hit from the type_keywords table,
            +2 per tool-name part found in the sub-task; both halved when
            the tool was already used. Returns the first unused tool if
            nothing scores, or None if every tool is used and unscored.
            """
            sub_task_lower = sub_task.lower()

            # Tool type indicators: tool_id -> trigger phrases expected in
            # the sub-task text. Keys mirror the SAGE-Bench tool ids.
            type_keywords = {
                "file_read": ["read", "load", "open file"],
                "file_write": ["write", "save", "store file"],
                "file_list": ["list files", "directory"],
                "file_copy": ["copy", "backup"],
                "file_delete": ["delete file", "remove file"],
                "data_parse_json": ["parse", "json"],
                "data_transform": ["transform", "convert data"],
                "data_filter": ["filter", "select records"],
                "data_aggregate": ["aggregate", "sum", "total"],
                "data_validate": ["validate", "check schema"],
                "http_get": ["fetch", "get data", "download", "api get"],
                "http_post": ["post", "submit", "send data"],
                "api_authenticate": ["authenticate", "login", "auth"],
                "web_scrape": ["scrape", "extract from web"],
                "db_connect": ["connect database", "db connect"],
                "db_query": ["query", "select from", "database query"],
                "db_insert": ["insert", "add record"],
                "db_update": ["update record", "modify"],
                "db_delete": ["delete record"],
                "email_send": ["email", "send mail"],
                "notification_send": ["notify", "notification"],
                "slack_post": ["slack", "post message"],
                "text_analyze": ["analyze text", "text analysis"],
                "sentiment_analyze": ["sentiment"],
                "stats_compute": ["statistics", "compute stats"],
                "math_calculate": ["calculate", "math"],
                "format_json": ["format json", "to json"],
                "format_csv": ["csv", "format csv"],
                "format_html": ["html", "format html", "report"],
                "cache_get": ["get cache", "retrieve cache"],
                "cache_set": ["cache", "set cache", "store cache"],
                "log_write": ["log", "write log"],
                "metrics_record": ["metrics", "record metric"],
                "image_resize": ["resize", "thumbnail"],
                "image_convert": ["convert image", "png", "jpg"],
                "schedule_task": ["schedule"],
                "get_calendar": ["calendar", "events"],
                "code_lint": ["lint"],
                "code_format": ["format code"],
                "code_execute": ["execute", "run code"],
                "search_web": ["search web", "web search"],
                "search_documents": ["search document"],
                "search_database": ["search database"],
                "convert_units": ["convert units"],
            }

            best_tool = None
            best_score: float = 0.0

            for tool in available_tools:
                if tool in used_tools:
                    # Prefer unused tools but allow reuse with penalty
                    penalty = 0.5
                else:
                    penalty = 1.0

                score: float = 0.0
                tool_lower = tool.lower()

                # Check against type keywords
                if tool in type_keywords:
                    for kw in type_keywords[tool]:
                        if kw in sub_task_lower:
                            score += 3 * penalty

                # Check tool name parts
                for part in tool_lower.split("_"):
                    if part in sub_task_lower and len(part) > 2:
                        score += 2 * penalty

                if score > best_score:
                    best_score = score
                    best_tool = tool

            # Fallback: pick first unused tool
            if best_tool is None:
                for tool in available_tools:
                    if tool not in used_tools:
                        return tool

            return best_tool

    return PlannerAdapter(HierarchicalPlanningStrategy())
|
|
2659
|
+
|
|
2660
|
+
def _create_llm_planning_strategy(self, resources: Optional[Any] = None) -> PlannerAdapter:
    """
    Create LLM-based planner using UnifiedInferenceClient.

    Uses real LLM for plan generation with semantic understanding.
    Degrades gracefully to the hierarchical (heuristic) planner when no
    LLM endpoint is available or when the LLM response cannot be parsed.

    Args:
        resources: Forwarded to ``_create_hierarchical_planning_strategy``
            when constructing the fallback planner.

    Returns:
        PlannerAdapter wrapping an LLMPlanningStrategy instance.
    """

    class LLMPlanningStrategy:
        """LLM-based planner using UnifiedInferenceClient."""

        def __init__(self, fallback_planner):
            # Heuristic planner used whenever the LLM path fails.
            self._fallback = fallback_planner
            self._llm_client = None
            # Guard so client creation is attempted exactly once.
            self._client_initialized = False

        def _get_llm_client(self):
            """Lazy initialization of LLM client using UnifiedInferenceClient.

            Uses UnifiedInferenceClient.create() which handles:
            1. Local vLLM API service detection (via SagePorts)
            2. Cloud API fallback (via SAGE_CHAT_* env vars)

            Returns:
                The client, or None if no endpoint could be initialized.
            """
            if not self._client_initialized:
                self._client_initialized = True

                try:
                    # FIX: import moved inside the try block so a missing or
                    # broken ``sage.llm`` module leaves the client as None
                    # (triggering the fallback planner) instead of raising
                    # out of plan(). This matches the behavior of
                    # ToTPlanningStrategy._get_llm_client.
                    from sage.llm import UnifiedInferenceClient

                    self._llm_client = UnifiedInferenceClient.create()
                    # Log which mode we're using.
                    # NOTE(review): reads private client attributes
                    # (_llm_base_url/_llm_model) — confirm they remain stable.
                    if self._llm_client._llm_base_url:
                        if "localhost" in self._llm_client._llm_base_url:
                            print(f"✅ 使用本地 LLM: {self._llm_client._llm_model}")
                        else:
                            print(f"☁️ 使用云端 API: {self._llm_client._llm_model}")
                    else:
                        print("⚠️ LLM 客户端初始化但无可用端点")
                except Exception as e:
                    print(f"⚠️ 无可用 LLM 服务: {e}")
                    print(" 启动本地服务: sage studio start")
                    print(" 或配置云端: export SAGE_CHAT_API_KEY=your_key")

            return self._llm_client

        def plan(self, task):
            """Generate a plan for *task* via the LLM, or fall back.

            Reads ``task.instruction`` / ``task.available_tools`` via
            getattr; missing attributes are treated as empty.
            """
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            instruction = getattr(task, "instruction", "") or ""
            available_tools = getattr(task, "available_tools", []) or []

            if not available_tools:
                return PlanningPrediction(steps=[], tool_sequence=[])

            # Initialize LLM client
            client = self._get_llm_client()
            if client is None:
                return self._fallback.plan(task)

            # Build prompt for plan generation
            tools_desc = ", ".join(available_tools[:20])  # Limit for prompt
            prompt = f"""Generate a step-by-step plan to accomplish this task.

Task: {instruction}

Available tools: {tools_desc}

Return a JSON array of steps, each with:
- "tool_id": the tool to use (must be from available tools)
- "description": brief description of what this step does

Return ONLY the JSON array, no explanation. Example:
[{{"tool_id": "file_read", "description": "Read config file"}}]"""

            try:
                messages = [
                    {
                        "role": "system",
                        "content": "You are a task planning assistant. Return only valid JSON.",
                    },
                    {"role": "user", "content": prompt},
                ]
                response = client.chat(messages)
            except Exception:
                # Transport/endpoint failure: use the heuristic planner.
                return self._fallback.plan(task)

            # Parse response
            if response:
                try:
                    import json
                    import re

                    # Extract JSON from response
                    text = response if isinstance(response, str) else str(response)

                    # Try to find JSON array (greedy: outermost brackets)
                    json_match = re.search(r"\[.*\]", text, re.DOTALL)
                    if json_match:
                        plan_data = json.loads(json_match.group())

                        steps = []
                        tool_sequence = []
                        for i, step in enumerate(plan_data):
                            tool_id = step.get("tool_id", "")
                            # Validate tool is available; silently drop
                            # hallucinated tool ids.
                            if tool_id in available_tools:
                                steps.append(
                                    PlanStep(
                                        step_id=i,
                                        description=step.get("description", f"Step {i + 1}"),
                                        tool_id=tool_id,
                                        confidence=0.85,
                                    )
                                )
                                tool_sequence.append(tool_id)

                        if steps:
                            return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)
                except Exception:
                    # Malformed JSON or unexpected structure: fall through.
                    pass

            # Fallback to hierarchical planner
            return self._fallback.plan(task)

    # Create hierarchical fallback
    hierarchical = self._create_hierarchical_planning_strategy(resources)
    return PlannerAdapter(LLMPlanningStrategy(hierarchical.planner))
|
|
2789
|
+
|
|
2790
|
+
def _create_tot_planner(self, resources: Optional[Any] = None) -> PlannerAdapter:
    """
    Create Tree-of-Thoughts planner for Challenge 2.

    Uses tree search to explore multiple reasoning paths.
    Based on "Tree of Thoughts: Deliberate Problem Solving with LLMs" (Yao et al., 2023)

    Args:
        resources: Forwarded to ``_create_hierarchical_planning_strategy``
            when constructing the fallback planner.

    Returns:
        PlannerAdapter wrapping a ToTPlanningStrategy instance.
    """

    class ToTPlanningStrategy:
        """
        Tree-of-Thoughts planning strategy.

        Explores multiple reasoning paths via BFS/DFS tree search,
        using LLM to generate and evaluate thought candidates.
        """

        def __init__(self, fallback_planner):
            # Heuristic planner used whenever the LLM path fails.
            self._fallback = fallback_planner
            self._llm_client = None
            # Guard so client creation is attempted exactly once.
            self._client_initialized = False
            # ToT configuration
            self._max_depth = 3       # search depth (plan length cap)
            self._branch_factor = 3   # candidate thoughts per expansion
            self._beam_width = 5      # paths kept per depth level
            self._min_score = 0.3     # candidates below this are pruned

        def _get_llm_client(self):
            """Lazy initialization of LLM client."""
            if self._client_initialized:
                return self._llm_client

            self._client_initialized = True
            try:
                from sage.llm import UnifiedInferenceClient

                # Use singleton to avoid repeated model loading
                self._llm_client = UnifiedInferenceClient.get_instance(
                    instance_key="benchmark_planner"
                )
            except Exception:
                # No LLM available; plan() will use the fallback planner.
                self._llm_client = None

            return self._llm_client

        def plan(self, task):
            """Generate plan using Tree-of-Thoughts search."""
            from sage.benchmark.benchmark_agent.experiments.base_experiment import (
                PlanningPrediction,
                PlanStep,
            )

            instruction = getattr(task, "instruction", "") or ""
            available_tools = getattr(task, "available_tools", []) or []

            if not available_tools:
                return PlanningPrediction(steps=[], tool_sequence=[])

            # Get LLM client
            llm_client = self._get_llm_client()

            # If no LLM, fall back to hierarchical
            if llm_client is None:
                return self._fallback.plan(task)

            try:
                # Run ToT search
                best_path = self._tot_search(instruction, available_tools, llm_client)

                if best_path:
                    steps = []
                    tool_sequence = []
                    for i, (thought, tool_id) in enumerate(best_path):
                        # Drop steps whose tool is empty or not available.
                        if tool_id and tool_id in available_tools:
                            steps.append(
                                PlanStep(
                                    step_id=i,
                                    description=thought,
                                    tool_id=tool_id,
                                    confidence=0.8,
                                )
                            )
                            tool_sequence.append(tool_id)

                    if steps:
                        return PlanningPrediction(steps=steps, tool_sequence=tool_sequence)
            except Exception:
                # Any search/LLM failure degrades to the fallback planner.
                pass

            # Fallback to hierarchical planner
            return self._fallback.plan(task)

        def _tot_search(
            self, instruction: str, available_tools: list[str], llm_client
        ) -> list[tuple[str, str]]:
            """
            Perform Tree-of-Thoughts BFS search.

            Beam search over paths of (thought, tool_id, score) triples:
            each depth expands every surviving path by up to
            ``_branch_factor`` LLM-generated candidates, prunes those
            below ``_min_score``, and keeps the ``_beam_width`` paths
            with the highest average score.

            Returns list of (thought, tool_id) tuples.
            """
            # Initialize queue with empty path
            queue: list[list[tuple[str, str, float]]] = [[]]  # Paths

            tools_str = ", ".join(available_tools[:15])

            for depth in range(self._max_depth):
                next_queue: list[list[tuple[str, str, float]]] = []

                for path in queue:
                    # Generate candidate thoughts
                    candidates = self._generate_thoughts(
                        instruction, path, available_tools, llm_client, tools_str
                    )

                    for thought, tool_id, score in candidates:
                        if score >= self._min_score:
                            new_path = path + [(thought, tool_id, score)]
                            next_queue.append(new_path)

                if not next_queue:
                    # All candidates pruned: keep the paths from the
                    # previous depth as the final beam.
                    break

                # Keep top-k paths by average score
                next_queue.sort(
                    key=lambda p: sum(s for _, _, s in p) / len(p) if p else 0, reverse=True
                )
                queue = next_queue[: self._beam_width]

            # Return best path
            if queue:
                best_path = max(
                    queue, key=lambda p: sum(s for _, _, s in p) / len(p) if p else 0
                )
                return [(t, tid) for t, tid, _ in best_path]

            return []

        def _generate_thoughts(
            self,
            instruction: str,
            path: list[tuple[str, str, float]],
            available_tools: list[str],
            llm_client,
            tools_str: str,
        ) -> list[tuple[str, str, float]]:
            """Generate and evaluate candidate thoughts.

            Asks the LLM for ``_branch_factor`` next-step candidates as a
            JSON array; self-reported 0-10 scores are rescaled to 0-1.
            Hallucinated tool ids are remapped to the first unused tool
            (or the candidate is dropped if none remain). On any LLM or
            parse failure, synthesizes one candidate per unused tool with
            a neutral 0.5 score.
            """
            import json
            import re

            # Format current progress
            progress = ""
            if path:
                progress = "\n".join(
                    f"Step {i + 1}: {t} (tool: {tid})" for i, (t, tid, _) in enumerate(path)
                )
            else:
                progress = "No steps taken yet."

            # Get used tools
            used_tools = {tid for _, tid, _ in path if tid}

            # Generate prompt
            prompt = f"""You are a planning assistant. Generate {self._branch_factor} different possible next steps.

Task: {instruction}
Available tools: {tools_str}
Current progress:
{progress}

Generate {self._branch_factor} different next steps. Each should use an available tool.
Avoid tools already used: {", ".join(used_tools) if used_tools else "none"}

Output as JSON array:
[{{"thought": "step description", "tool_id": "tool_name", "score": 0-10}}]

Only output JSON, nothing else."""

            try:
                response = llm_client.chat(
                    [{"role": "user", "content": prompt}],
                    max_tokens=512,
                    temperature=BENCHMARK_LLM_TEMPERATURE,
                )

                # Parse response
                text = response if isinstance(response, str) else str(response)
                json_match = re.search(r"\[.*\]", text, re.DOTALL)
                if json_match:
                    candidates = json.loads(json_match.group())
                    result = []
                    for c in candidates:
                        thought = c.get("thought", "")
                        tool_id = c.get("tool_id", "")
                        score = c.get("score", 5) / 10.0

                        # Validate tool
                        if tool_id not in available_tools:
                            # Try to find closest match
                            for tool in available_tools:
                                if tool not in used_tools:
                                    tool_id = tool
                                    break
                            else:
                                continue

                        result.append((thought, tool_id, score))

                    return result[: self._branch_factor]
            except Exception:
                pass

            # Fallback: return one thought per unused tool
            result = []
            for tool in available_tools:
                if tool not in used_tools and len(result) < self._branch_factor:
                    result.append((f"Use {tool} for task", tool, 0.5))
            return result

    # Create hierarchical fallback
    hierarchical = self._create_hierarchical_planning_strategy(resources)
    return PlannerAdapter(ToTPlanningStrategy(hierarchical.planner))
|
|
3010
|
+
|
|
3011
|
+
|
|
3012
|
+
# Global registry instance
# Module-level singleton; created lazily by get_adapter_registry().
_global_registry: Optional[AdapterRegistry] = None
|
|
3014
|
+
|
|
3015
|
+
|
|
3016
|
+
def get_adapter_registry() -> AdapterRegistry:
    """Return the module-level AdapterRegistry, creating it on first use."""
    global _global_registry
    if _global_registry is not None:
        return _global_registry
    _global_registry = AdapterRegistry()
    return _global_registry
|
|
3022
|
+
|
|
3023
|
+
|
|
3024
|
+
def register_strategy(name: str, strategy: Any) -> None:
    """Register *strategy* under *name* in the global adapter registry."""
    registry = get_adapter_registry()
    registry.register(name, strategy)
|
|
3027
|
+
|
|
3028
|
+
|
|
3029
|
+
# Explicit public API of this module.
__all__ = [
    "AdapterRegistry",
    "SelectorAdapter",
    "PlannerAdapter",
    "TimingAdapter",
    "get_adapter_registry",
    "register_strategy",
]
|