isage-benchmark-agent 0.1.0.1__cp311-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py
@@ -0,0 +1,315 @@
#!/usr/bin/env python3
"""
Section 5.2.2: Task Planning Experiment (RQ2)

Research question: How well can existing methods decompose complex tasks into execution steps?

Methods under test:
- planner.simple       : Simple greedy matching (baseline)
- planner.hierarchical : Hierarchical decomposition (HuggingGPT, ICML'23)
- planner.llm_based    : LLM-generated plans (CoT prompting)
- planner.react        : Interleaved reasoning and acting (ReAct, Yao et al., ICLR'23)
- planner.tot          : Tree search (Tree-of-Thoughts, NeurIPS'23)

Target metrics:
- Primary: Plan Success Rate ≥ 90%
- Secondary: Step Accuracy, Tool Coverage
- Tertiary: Average Plan Length

Usage:
    python exp_main_planning.py
    python exp_main_planning.py --max-samples 100
    python exp_main_planning.py --skip-llm
"""

from __future__ import annotations

import argparse
import time

from .exp_utils import (
    ExperimentResult,
    ExperimentSummary,
    create_progress_bar,
    load_benchmark_data,
    print_metrics_detail,
    print_result_row,
    print_section_header,
    print_subsection_header,
    save_results,
    setup_experiment_env,
)


def compute_step_accuracy(predicted_steps: list[str], reference_steps: list[str]) -> float:
    """Compute step-level accuracy."""
    if not reference_steps:
        return 1.0 if not predicted_steps else 0.0

    correct = sum(1 for p, r in zip(predicted_steps, reference_steps) if p == r)
    return correct / len(reference_steps)


def compute_tool_coverage(predicted_steps: list[str], reference_steps: list[str]) -> float:
    """Compute tool coverage."""
    if not reference_steps:
        return 1.0

    ref_set = set(reference_steps)
    pred_set = set(predicted_steps)
    return len(pred_set & ref_set) / len(ref_set)


def is_plan_success(
    predicted_steps: list[str], reference_steps: list[str], threshold: float = 0.8
) -> bool:
    """Determine whether a plan succeeds (step accuracy >= threshold)."""
    if not reference_steps:
        return not predicted_steps

    step_acc = compute_step_accuracy(predicted_steps, reference_steps)
    return step_acc >= threshold

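# A minimal sanity check of the three metric helpers above (hypothetical tool IDs,
# not part of the published file): step accuracy is positional, tool coverage is
# set-based, and success applies the 0.8 threshold.
#
# >>> compute_step_accuracy(["search", "summarize"], ["search", "translate"])
# 0.5
# >>> compute_tool_coverage(["search", "summarize"], ["search", "translate"])
# 0.5
# >>> is_plan_success(["search", "translate"], ["search", "translate"])
# True
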
def run_planning_experiment(
    max_samples: int = 100,
    skip_llm: bool = False,
    verbose: bool = True,
) -> ExperimentSummary:
    """
    Run the Task Planning experiment.

    Args:
        max_samples: Maximum number of test samples.
        skip_llm: Whether to skip LLM-based methods.
        verbose: Whether to print detailed information.

    Returns:
        An ExperimentSummary object.
    """
    setup_experiment_env(verbose=verbose)

    print_section_header("Section 5.2.2: Task Planning (RQ2)")
    print(" Target: Plan Success Rate ≥ 90%")
    print(f" Max samples: {max_samples}")

    # Load data
    samples = load_benchmark_data("planning", split="test", max_samples=max_samples)
    if not samples:
        print(" ❌ No planning data available")
        return ExperimentSummary(section="5_2_main", challenge="planning")

    print(f" Loaded {len(samples)} samples")

    # Get the strategy registry
    try:
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
    except ImportError as e:
        print(f" ❌ Failed to import adapter registry: {e}")
        return ExperimentSummary(section="5_2_main", challenge="planning")

    # Define the strategies under test, covering mainstream SOTA methods
    strategies = [
        ("planner.simple", "Simple (Greedy)"),
        ("planner.hierarchical", "Hierarchical (HuggingGPT)"),
        ("planner.llm_based", "LLM-based (CoT)"),
        ("planner.react", "ReAct (Yao et al.)"),
        ("planner.tot", "Tree-of-Thoughts"),
    ]

    # Skip LLM strategies if requested
    LLM_STRATEGIES = {"planner.llm_based", "planner.react", "planner.tot"}
    if skip_llm:
        strategies = [(name, display) for name, display in strategies if name not in LLM_STRATEGIES]
        print(" ⚠️ Skipping LLM-based strategies")

    results = []
    target = 0.90

    for strategy_name, display_name in strategies:
        print_subsection_header(f"Testing: {display_name}")

        try:
            planner = registry.get(strategy_name)
        except Exception as e:
            print(f" ⚠️ Failed to create planner: {e}")
            continue

        # Run evaluation
        start_time = time.time()
        success_count = 0
        step_accuracies = []
        tool_coverages = []
        plan_lengths = []

        with create_progress_bar(len(samples), desc=f" {display_name}") as pbar:
            for sample in samples:
                try:
                    task_description = sample.get("task", sample.get("instruction", ""))
                    available_tools = sample.get("available_tools", sample.get("tools", []))
                    reference_plan = sample.get("ground_truth", sample.get("expected_plan", []))

                    # Normalize the reference
                    if isinstance(reference_plan, list) and reference_plan:
                        if isinstance(reference_plan[0], dict):
                            ref_steps = [
                                s.get("tool_id", s.get("tool", "")) for s in reference_plan
                            ]
                        else:
                            ref_steps = reference_plan
                    else:
                        ref_steps = []

                    # Invoke the planner
                    plan_result = planner.plan(task_description, available_tools=available_tools)

                    # Extract predicted steps
                    if hasattr(plan_result, "tool_sequence"):
                        pred_steps = plan_result.tool_sequence
                    elif hasattr(plan_result, "steps"):
                        pred_steps = [
                            s.tool_id if hasattr(s, "tool_id") else str(s)
                            for s in plan_result.steps
                        ]
                    elif isinstance(plan_result, list):
                        pred_steps = [
                            s.get("tool_id", str(s)) if isinstance(s, dict) else str(s)
                            for s in plan_result
                        ]
                    else:
                        pred_steps = []

                    # Compute metrics
                    if is_plan_success(pred_steps, ref_steps):
                        success_count += 1

                    step_acc = compute_step_accuracy(pred_steps, ref_steps)
                    tool_cov = compute_tool_coverage(pred_steps, ref_steps)

                    step_accuracies.append(step_acc)
                    tool_coverages.append(tool_cov)
                    plan_lengths.append(len(pred_steps))

                except Exception as e:
                    if verbose:
                        print(f" Error: {e}")
                    step_accuracies.append(0.0)
                    tool_coverages.append(0.0)
                    plan_lengths.append(0)

                pbar.update(1)

        elapsed = time.time() - start_time

        # Compute summary metrics
        n = len(samples)
        success_rate = success_count / n if n > 0 else 0
        avg_step_acc = sum(step_accuracies) / n if n > 0 else 0
        avg_tool_cov = sum(tool_coverages) / n if n > 0 else 0
        avg_plan_len = sum(plan_lengths) / n if n > 0 else 0

        exp_result = ExperimentResult(
            challenge="planning",
            strategy=strategy_name,
            metrics={
                "plan_success_rate": success_rate,
                "step_accuracy": avg_step_acc,
                "tool_coverage": avg_tool_cov,
            },
            metadata={
                "total_samples": n,
                "success_count": success_count,
                "avg_plan_length": avg_plan_len,
                "latency_ms": elapsed * 1000 / n if n > 0 else 0,
            },
            passed=success_rate >= target,
            target=target,
        )
        results.append(exp_result)

        # Print results
        print_result_row(
            display_name, {"plan_success_rate": success_rate}, exp_result.passed, target
        )
        if verbose:
            print_metrics_detail(exp_result.metrics)

    # Find the best strategy
    best_result = max(results, key=lambda r: r.metrics["plan_success_rate"]) if results else None

    summary = ExperimentSummary(
        section="5_2_main",
        challenge="planning",
        results=results,
        best_strategy=best_result.strategy if best_result else None,
        best_metric=best_result.metrics["plan_success_rate"] if best_result else None,
        target_met=any(r.passed for r in results),
    )

    # Save results
    output_file = save_results(summary.to_dict(), "5_2_main", "planning")
    print(f"\n Results saved to: {output_file}")

    # Generate figures
    if results:
        from figure_generator import (
            generate_detailed_table,
            get_figures_dir,
            get_tables_dir,
            plot_challenge_comparison,
        )

        figures_dir = get_figures_dir()
        tables_dir = get_tables_dir()

        # Generate the comparison figure
        plot_challenge_comparison(
            [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
            challenge="planning",
            metrics=["plan_success_rate", "step_accuracy", "tool_coverage"],
            target=target,
            output_path=figures_dir / "fig2_main_planning_comparison.pdf",
            title="Task Planning: Strategy Comparison",
        )

        # Generate the table
        generate_detailed_table(
            [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
            challenge="planning",
            metrics=["plan_success_rate", "step_accuracy", "tool_coverage"],
            output_path=tables_dir / "table_planning_detailed.tex",
        )

        print(f" Figure saved to: {figures_dir / 'fig2_main_planning_comparison.pdf'}")

    return summary

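# Programmatic usage sketch (hypothetical invocation, not part of the published file):
#
# >>> summary = run_planning_experiment(max_samples=10, skip_llm=True, verbose=False)
# >>> print(summary.best_strategy, summary.target_met)
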
def main():
    parser = argparse.ArgumentParser(description="Section 5.2.2: Task Planning Experiment")
    parser.add_argument("--max-samples", type=int, default=100, help="Maximum samples to test")
    parser.add_argument("--skip-llm", action="store_true", help="Skip LLM-based methods")
    parser.add_argument("--verbose", action="store_true", default=True, help="Verbose output")
    args = parser.parse_args()

    summary = run_planning_experiment(
        max_samples=args.max_samples,
        skip_llm=args.skip_llm,
        verbose=args.verbose,
    )

    # Print the summary
    print("\n" + "=" * 70)
    print("📊 Summary")
    print("=" * 70)
    print(f" Best strategy: {summary.best_strategy}")
    print(
        f" Best success rate: {summary.best_metric * 100:.1f}%" if summary.best_metric else " N/A"
    )
    print(f" Target met: {'✅ YES' if summary.target_met else '❌ NO'}")


if __name__ == "__main__":
    main()
sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py
@@ -0,0 +1,344 @@
#!/usr/bin/env python3
"""
Section 5.2.3: Tool Selection Experiment (RQ3)

Research question: How well can existing methods select the correct tools from a large-scale tool library?

Methods under test:
- selector.keyword   : BM25 keyword matching
- selector.embedding : Dense-retrieval semantic matching
- selector.hybrid    : Fusion of 40% BM25 + 60% dense scores
- selector.gorilla   : Embedding retrieval + LLM reranking
- selector.dfsdt     : LLM scoring of each candidate (ToolLLM method)

Target metrics:
- Primary: Top-K Accuracy ≥ 95% (K=5)
- Secondary: MRR, Recall@K, Precision@K
- Tertiary: Latency (ms)

Usage:
    python exp_main_selection.py
    python exp_main_selection.py --max-samples 100 --top-k 5
    python exp_main_selection.py --skip-llm
"""

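# The hybrid selector's weighting, per the docstring above, presumably amounts to a
# linear fusion of normalized per-tool retrieval scores (a sketch, not the published
# implementation):
#
#     score(tool) = 0.4 * bm25_score(tool) + 0.6 * dense_score(tool)
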
from __future__ import annotations

import argparse
import time

from .exp_utils import (
    ExperimentResult,
    ExperimentSummary,
    create_progress_bar,
    load_benchmark_data,
    print_metrics_detail,
    print_result_row,
    print_section_header,
    print_subsection_header,
    save_results,
    setup_experiment_env,
)


def compute_mrr(predictions: list[list[str]], references: list[list[str]]) -> float:
    """Compute Mean Reciprocal Rank."""
    rr_sum = 0.0
    for pred, ref in zip(predictions, references):
        ref_set = set(ref)
        for i, p in enumerate(pred):
            if p in ref_set:
                rr_sum += 1.0 / (i + 1)
                break
    return rr_sum / len(predictions) if predictions else 0.0


def compute_top_k_accuracy(
    predictions: list[list[str]], references: list[list[str]], k: int
) -> float:
    """Compute Top-K Accuracy."""
    hits = 0
    for pred, ref in zip(predictions, references):
        pred_top_k = set(pred[:k])
        ref_set = set(ref)
        if pred_top_k & ref_set:
            hits += 1
    return hits / len(predictions) if predictions else 0.0


def compute_recall_at_k(predictions: list[list[str]], references: list[list[str]], k: int) -> float:
    """Compute Recall@K."""
    recalls = []
    for pred, ref in zip(predictions, references):
        pred_top_k = set(pred[:k])
        ref_set = set(ref)
        if ref_set:
            recalls.append(len(pred_top_k & ref_set) / len(ref_set))
        else:
            recalls.append(0.0)
    return sum(recalls) / len(recalls) if recalls else 0.0


def compute_precision_at_k(
    predictions: list[list[str]], references: list[list[str]], k: int
) -> float:
    """Compute Precision@K."""
    precisions = []
    for pred, ref in zip(predictions, references):
        pred_top_k = set(pred[:k])
        ref_set = set(ref)
        if pred_top_k:
            precisions.append(len(pred_top_k & ref_set) / len(pred_top_k))
        else:
            precisions.append(0.0)
    return sum(precisions) / len(precisions) if precisions else 0.0

def normalize_ground_truth(ground_truth: object) -> list[str]:
    """Normalize ground-truth labels into a list of strings."""

    if ground_truth is None:
        return []

    if isinstance(ground_truth, str):
        return [ground_truth]

    if isinstance(ground_truth, list):
        return [str(item) for item in ground_truth]

    if isinstance(ground_truth, dict):
        for key in ("top_k", "tool_ids", "tools", "ids"):
            value = ground_truth.get(key)
            if isinstance(value, list):
                return [str(item) for item in value]
            if isinstance(value, str):
                return [value]
        # fall back to any string-like values
        values = [str(value) for value in ground_truth.values() if value]
        if values:
            return values

    return [str(ground_truth)]

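# Examples of the normalization above (hypothetical labels, not part of the
# published file):
#
# >>> normalize_ground_truth({"tool_ids": ["t1", "t2"]})
# ['t1', 't2']
# >>> normalize_ground_truth("t1")
# ['t1']
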
def run_selection_experiment(
    max_samples: int = 100,
    top_k: int = 5,
    skip_llm: bool = False,
    verbose: bool = True,
) -> ExperimentSummary:
    """
    Run the Tool Selection experiment.

    Args:
        max_samples: Maximum number of test samples.
        top_k: Top-K parameter.
        skip_llm: Whether to skip LLM-based methods.
        verbose: Whether to print detailed information.

    Returns:
        An ExperimentSummary object.
    """
    setup_experiment_env(verbose=verbose)

    print_section_header("Section 5.2.3: Tool Selection (RQ3)")
    print(f" Target: Top-{top_k} Accuracy ≥ 95%")
    print(f" Max samples: {max_samples}")

    # Load data
    samples = load_benchmark_data("selection", split="test", max_samples=max_samples)
    if not samples:
        print(" ❌ No selection data available")
        return ExperimentSummary(section="5_2_main", challenge="selection")

    print(f" Loaded {len(samples)} samples")

    # Get the strategy registry
    try:
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
    except ImportError as e:
        print(f" ❌ Failed to import adapter registry: {e}")
        return ExperimentSummary(section="5_2_main", challenge="selection")

    # Define the strategies under test
    strategies = [
        ("selector.keyword", "Keyword (BM25)"),
        ("selector.embedding", "Embedding"),
        ("selector.hybrid", "Hybrid"),
        ("selector.gorilla", "Gorilla"),
        ("selector.dfsdt", "DFSDT"),
    ]

    # Skip LLM strategies if requested
    LLM_STRATEGIES = {"selector.gorilla", "selector.dfsdt"}
    if skip_llm:
        strategies = [(name, display) for name, display in strategies if name not in LLM_STRATEGIES]
        print(" ⚠️ Skipping LLM-based strategies")

    results = []
    target = 0.95

    for strategy_name, display_name in strategies:
        print_subsection_header(f"Testing: {display_name}")

        try:
            selector = registry.get(strategy_name)
        except Exception as e:
            print(f" ⚠️ Failed to create selector: {e}")
            continue

        # Run evaluation
        start_time = time.time()
        all_predictions = []
        all_references = []

        with create_progress_bar(len(samples), desc=f" {display_name}") as pbar:
            for sample in samples:
                try:
                    query = sample.get("instruction", sample.get("query", ""))
                    candidate_tools = sample.get("candidate_tools", [])
                    ground_truth = sample.get("ground_truth", sample.get("expected_tools", []))

                    # Invoke the selector
                    predictions = selector.select(
                        query, candidate_tools=candidate_tools, top_k=top_k
                    )

                    # Extract tool IDs
                    if predictions and hasattr(predictions[0], "tool_id"):
                        pred_ids = [p.tool_id for p in predictions]
                    elif predictions and isinstance(predictions[0], dict):
                        pred_ids = [p.get("tool_id", p.get("id", str(p))) for p in predictions]
                    else:
                        pred_ids = [str(p) for p in predictions] if predictions else []

                    # Normalize ground truth
                    ref_ids = normalize_ground_truth(ground_truth)

                    all_predictions.append(pred_ids)
                    all_references.append(ref_ids)

                except Exception as e:
                    if verbose:
                        print(f" Error: {e}")
                    all_predictions.append([])
                    all_references.append(sample.get("ground_truth", []))

                pbar.update(1)

        elapsed = time.time() - start_time

        # Compute metrics
        n = len(samples)
        top_k_acc = compute_top_k_accuracy(all_predictions, all_references, top_k)
        mrr = compute_mrr(all_predictions, all_references)
        recall_k = compute_recall_at_k(all_predictions, all_references, top_k)
        precision_k = compute_precision_at_k(all_predictions, all_references, top_k)

        exp_result = ExperimentResult(
            challenge="selection",
            strategy=strategy_name,
            metrics={
                "top_k_accuracy": top_k_acc,
                "mrr": mrr,
                f"recall@{top_k}": recall_k,
                f"precision@{top_k}": precision_k,
            },
            metadata={
                "total_samples": n,
                "top_k": top_k,
                "latency_ms": elapsed * 1000 / n if n > 0 else 0,
            },
            passed=top_k_acc >= target,
            target=target,
        )
        results.append(exp_result)

        # Print results
        print_result_row(display_name, {"top_k_accuracy": top_k_acc}, exp_result.passed, target)
        if verbose:
            print_metrics_detail(exp_result.metrics)

    # Find the best strategy
    best_result = max(results, key=lambda r: r.metrics["top_k_accuracy"]) if results else None

    summary = ExperimentSummary(
        section="5_2_main",
        challenge="selection",
        results=results,
        best_strategy=best_result.strategy if best_result else None,
        best_metric=best_result.metrics["top_k_accuracy"] if best_result else None,
        target_met=any(r.passed for r in results),
    )

    # Save results
    output_file = save_results(summary.to_dict(), "5_2_main", "selection")
    print(f"\n Results saved to: {output_file}")

    # Generate figures
    if results:
        from figure_generator import (
            generate_detailed_table,
            get_figures_dir,
            get_tables_dir,
            plot_challenge_comparison,
        )

        figures_dir = get_figures_dir()
        tables_dir = get_tables_dir()

        # Generate the comparison figure
        plot_challenge_comparison(
            [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
            challenge="selection",
            metrics=["top_k_accuracy", "mrr"],
            target=target,
            output_path=figures_dir / "fig3_main_selection_comparison.pdf",
            title=f"Tool Selection: Strategy Comparison (Top-{top_k})",
        )

        # Generate the table
        generate_detailed_table(
            [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
            challenge="selection",
            metrics=["top_k_accuracy", "mrr", f"recall@{top_k}", f"precision@{top_k}"],
            output_path=tables_dir / "table_selection_detailed.tex",
        )

        print(f" Figure saved to: {figures_dir / 'fig3_main_selection_comparison.pdf'}")

    return summary

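# Programmatic usage sketch (hypothetical invocation, not part of the published file):
#
# >>> summary = run_selection_experiment(max_samples=10, top_k=5, skip_llm=True, verbose=False)
# >>> print(summary.best_strategy, summary.best_metric)
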
def main():
    parser = argparse.ArgumentParser(description="Section 5.2.3: Tool Selection Experiment")
    parser.add_argument("--max-samples", type=int, default=100, help="Maximum samples to test")
    parser.add_argument("--top-k", type=int, default=5, help="Top-K parameter")
    parser.add_argument("--skip-llm", action="store_true", help="Skip LLM-based methods")
    parser.add_argument("--verbose", action="store_true", default=True, help="Verbose output")
    args = parser.parse_args()

    summary = run_selection_experiment(
        max_samples=args.max_samples,
        top_k=args.top_k,
        skip_llm=args.skip_llm,
        verbose=args.verbose,
    )

    # Print the summary
    print("\n" + "=" * 70)
    print("📊 Summary")
    print("=" * 70)
    print(f" Best strategy: {summary.best_strategy}")
    print(
        f" Best Top-K accuracy: {summary.best_metric * 100:.1f}%"
        if summary.best_metric
        else " N/A"
    )
    print(f" Target met: {'✅ YES' if summary.target_met else '❌ NO'}")


if __name__ == "__main__":
    main()