isage-benchmark-agent 0.1.0.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py
@@ -0,0 +1,315 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Section 5.2.2: Task Planning Experiment (RQ2)
4
+
5
+ Research question: How well can existing methods decompose a complex task into executable steps?
6
+
7
+ Methods under test:
8
+ - planner.simple : simple greedy matching (baseline)
9
+ - planner.hierarchical : hierarchical decomposition (HuggingGPT, NeurIPS'23)
10
+ - planner.llm_based : LLM-generated plans (CoT prompting)
11
+ - planner.react : interleaved reasoning and acting (ReAct, Yao et al., ICLR'23)
12
+ - planner.tot : tree search (Tree-of-Thoughts, NeurIPS'23)
13
+
14
+ Target metrics:
15
+ - Primary: Plan Success Rate ≥ 90%
16
+ - Secondary: Step Accuracy, Tool Coverage
17
+ - Tertiary: Average Plan Length
18
+
19
+ Usage:
20
+ python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_main_planning
21
+ python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_main_planning --max-samples 100
22
+ python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_main_planning --skip-llm
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import time
29
+
30
+ from .exp_utils import (
31
+ ExperimentResult,
32
+ ExperimentSummary,
33
+ create_progress_bar,
34
+ load_benchmark_data,
35
+ print_metrics_detail,
36
+ print_result_row,
37
+ print_section_header,
38
+ print_subsection_header,
39
+ save_results,
40
+ setup_experiment_env,
41
+ )
42
+
43
+
44
+ def compute_step_accuracy(predicted_steps: list[str], reference_steps: list[str]) -> float:
45
+ """计算步骤级别的准确率。"""
46
+ if not reference_steps:
47
+ return 1.0 if not predicted_steps else 0.0
48
+
49
+ correct = sum(1 for p, r in zip(predicted_steps, reference_steps) if p == r)
50
+ return correct / len(reference_steps)
51
+
52
+
53
+ def compute_tool_coverage(predicted_steps: list[str], reference_steps: list[str]) -> float:
54
+ """计算工具覆盖率。"""
55
+ if not reference_steps:
56
+ return 1.0
57
+
58
+ ref_set = set(reference_steps)
59
+ pred_set = set(predicted_steps)
60
+ return len(pred_set & ref_set) / len(ref_set)
61
+
62
+
63
+ def is_plan_success(
64
+ predicted_steps: list[str], reference_steps: list[str], threshold: float = 0.8
65
+ ) -> bool:
66
+ """判断计划是否成功 (步骤准确率 >= threshold)。"""
67
+ if not reference_steps:
68
+ return not predicted_steps
69
+
70
+ step_acc = compute_step_accuracy(predicted_steps, reference_steps)
71
+ return step_acc >= threshold
72
+
73
+
74
+ def run_planning_experiment(
75
+ max_samples: int = 100,
76
+ skip_llm: bool = False,
77
+ verbose: bool = True,
78
+ ) -> ExperimentSummary:
79
+ """
80
+ Run the Task Planning experiment.
81
+
82
+ Args:
83
+ max_samples: Maximum number of test samples
84
+ skip_llm: Whether to skip LLM-based methods
85
+ verbose: Whether to print detailed progress information
86
+
87
+ Returns:
88
+ An ExperimentSummary object
89
+ """
90
+ setup_experiment_env(verbose=verbose)
91
+
92
+ print_section_header("Section 5.2.2: Task Planning (RQ2)")
93
+ print(" Target: Plan Success Rate ≥ 90%")
94
+ print(f" Max samples: {max_samples}")
95
+
96
+ # Load data
97
+ samples = load_benchmark_data("planning", split="test", max_samples=max_samples)
98
+ if not samples:
99
+ print(" ❌ No planning data available")
100
+ return ExperimentSummary(section="5_2_main", challenge="planning")
101
+
102
+ print(f" Loaded {len(samples)} samples")
103
+
104
+ # Get the strategy registry
105
+ try:
106
+ from sage.benchmark.benchmark_agent import get_adapter_registry
107
+
108
+ registry = get_adapter_registry()
109
+ except ImportError as e:
110
+ print(f" ❌ Failed to import adapter registry: {e}")
111
+ return ExperimentSummary(section="5_2_main", challenge="planning")
112
+
113
+ # Define the strategies under test, including mainstream SOTA methods
114
+ strategies = [
115
+ ("planner.simple", "Simple (Greedy)"),
116
+ ("planner.hierarchical", "Hierarchical (HuggingGPT)"),
117
+ ("planner.llm_based", "LLM-based (CoT)"),
118
+ ("planner.react", "ReAct (Yao et al.)"),
119
+ ("planner.tot", "Tree-of-Thoughts"),
120
+ ]
121
+
122
+ # Optionally skip LLM-based strategies
123
+ LLM_STRATEGIES = {"planner.llm_based", "planner.react", "planner.tot"}
124
+ if skip_llm:
125
+ strategies = [(name, display) for name, display in strategies if name not in LLM_STRATEGIES]
126
+ print(" ⚠️ Skipping LLM-based strategies")
127
+
128
+ results = []
129
+ target = 0.90
130
+
131
+ for strategy_name, display_name in strategies:
132
+ print_subsection_header(f"Testing: {display_name}")
133
+
134
+ try:
135
+ planner = registry.get(strategy_name)
136
+ except Exception as e:
137
+ print(f" ⚠️ Failed to create planner: {e}")
138
+ continue
139
+
140
+ # Run evaluation
141
+ start_time = time.time()
142
+ success_count = 0
143
+ step_accuracies = []
144
+ tool_coverages = []
145
+ plan_lengths = []
146
+
147
+ with create_progress_bar(len(samples), desc=f" {display_name}") as pbar:
148
+ for sample in samples:
149
+ try:
150
+ task_description = sample.get("task", sample.get("instruction", ""))
151
+ available_tools = sample.get("available_tools", sample.get("tools", []))
152
+ reference_plan = sample.get("ground_truth", sample.get("expected_plan", []))
153
+
154
+ # Normalize the reference plan
155
+ if isinstance(reference_plan, list) and reference_plan:
156
+ if isinstance(reference_plan[0], dict):
157
+ ref_steps = [
158
+ s.get("tool_id", s.get("tool", "")) for s in reference_plan
159
+ ]
160
+ else:
161
+ ref_steps = reference_plan
162
+ else:
163
+ ref_steps = []
164
+
165
+ # Invoke the planner
166
+ plan_result = planner.plan(task_description, available_tools=available_tools)
167
+
168
+ # Extract predicted steps
169
+ if hasattr(plan_result, "tool_sequence"):
170
+ pred_steps = plan_result.tool_sequence
171
+ elif hasattr(plan_result, "steps"):
172
+ pred_steps = [
173
+ s.tool_id if hasattr(s, "tool_id") else str(s)
174
+ for s in plan_result.steps
175
+ ]
176
+ elif isinstance(plan_result, list):
177
+ pred_steps = [
178
+ s.get("tool_id", str(s)) if isinstance(s, dict) else str(s)
179
+ for s in plan_result
180
+ ]
181
+ else:
182
+ pred_steps = []
183
+
184
+ # Compute per-sample metrics
185
+ if is_plan_success(pred_steps, ref_steps):
186
+ success_count += 1
187
+
188
+ step_acc = compute_step_accuracy(pred_steps, ref_steps)
189
+ tool_cov = compute_tool_coverage(pred_steps, ref_steps)
190
+
191
+ step_accuracies.append(step_acc)
192
+ tool_coverages.append(tool_cov)
193
+ plan_lengths.append(len(pred_steps))
194
+
195
+ except Exception as e:
196
+ if verbose:
197
+ print(f" Error: {e}")
198
+ step_accuracies.append(0.0)
199
+ tool_coverages.append(0.0)
200
+ plan_lengths.append(0)
201
+
202
+ pbar.update(1)
203
+
204
+ elapsed = time.time() - start_time
205
+
206
+ # Compute aggregate metrics
207
+ n = len(samples)
208
+ success_rate = success_count / n if n > 0 else 0
209
+ avg_step_acc = sum(step_accuracies) / n if n > 0 else 0
210
+ avg_tool_cov = sum(tool_coverages) / n if n > 0 else 0
211
+ avg_plan_len = sum(plan_lengths) / n if n > 0 else 0
212
+
213
+ exp_result = ExperimentResult(
214
+ challenge="planning",
215
+ strategy=strategy_name,
216
+ metrics={
217
+ "plan_success_rate": success_rate,
218
+ "step_accuracy": avg_step_acc,
219
+ "tool_coverage": avg_tool_cov,
220
+ },
221
+ metadata={
222
+ "total_samples": n,
223
+ "success_count": success_count,
224
+ "avg_plan_length": avg_plan_len,
225
+ "latency_ms": elapsed * 1000 / n if n > 0 else 0,
226
+ },
227
+ passed=success_rate >= target,
228
+ target=target,
229
+ )
230
+ results.append(exp_result)
231
+
232
+ # Print results
233
+ print_result_row(
234
+ display_name, {"plan_success_rate": success_rate}, exp_result.passed, target
235
+ )
236
+ if verbose:
237
+ print_metrics_detail(exp_result.metrics)
238
+
239
+ # Find the best strategy
240
+ best_result = max(results, key=lambda r: r.metrics["plan_success_rate"]) if results else None
241
+
242
+ summary = ExperimentSummary(
243
+ section="5_2_main",
244
+ challenge="planning",
245
+ results=results,
246
+ best_strategy=best_result.strategy if best_result else None,
247
+ best_metric=best_result.metrics["plan_success_rate"] if best_result else None,
248
+ target_met=any(r.passed for r in results),
249
+ )
250
+
251
+ # Save results
252
+ output_file = save_results(summary.to_dict(), "5_2_main", "planning")
253
+ print(f"\n Results saved to: {output_file}")
254
+
255
+ # Generate figures and tables
256
+ if results:
257
+ from .figure_generator import (
258
+ generate_detailed_table,
259
+ get_figures_dir,
260
+ get_tables_dir,
261
+ plot_challenge_comparison,
262
+ )
263
+
264
+ figures_dir = get_figures_dir()
265
+ tables_dir = get_tables_dir()
266
+
267
+ # Generate comparison figure
268
+ plot_challenge_comparison(
269
+ [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
270
+ challenge="planning",
271
+ metrics=["plan_success_rate", "step_accuracy", "tool_coverage"],
272
+ target=target,
273
+ output_path=figures_dir / "fig2_main_planning_comparison.pdf",
274
+ title="Task Planning: Strategy Comparison",
275
+ )
276
+
277
+ # Generate table
278
+ generate_detailed_table(
279
+ [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
280
+ challenge="planning",
281
+ metrics=["plan_success_rate", "step_accuracy", "tool_coverage"],
282
+ output_path=tables_dir / "table_planning_detailed.tex",
283
+ )
284
+
285
+ print(f" Figure saved to: {figures_dir / 'fig2_main_planning_comparison.pdf'}")
286
+
287
+ return summary
288
+
289
+
290
+ def main():
291
+ parser = argparse.ArgumentParser(description="Section 5.2.2: Task Planning Experiment")
292
+ parser.add_argument("--max-samples", type=int, default=100, help="Maximum samples to test")
293
+ parser.add_argument("--skip-llm", action="store_true", help="Skip LLM-based methods")
294
+ parser.add_argument("--verbose", action="store_true", default=True, help="Verbose output")
295
+ args = parser.parse_args()
296
+
297
+ summary = run_planning_experiment(
298
+ max_samples=args.max_samples,
299
+ skip_llm=args.skip_llm,
300
+ verbose=args.verbose,
301
+ )
302
+
303
+ # Print summary
304
+ print("\n" + "=" * 70)
305
+ print("📊 Summary")
306
+ print("=" * 70)
307
+ print(f" Best strategy: {summary.best_strategy}")
308
+ print(
309
+ f" Best success rate: {summary.best_metric * 100:.1f}%" if summary.best_metric else " N/A"
310
+ )
311
+ print(f" Target met: {'✅ YES' if summary.target_met else '❌ NO'}")
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
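
To make the plan-level metrics above concrete, here is a minimal standalone sketch (illustrative only; the tool names are hypothetical and nothing below ships in the wheel). It mirrors the logic of compute_step_accuracy, compute_tool_coverage, and is_plan_success: step accuracy is positional, tool coverage is set-based, and a plan counts as a success at a 0.8 step-accuracy threshold.

predicted = ["search_web", "summarize", "send_email"]                # hypothetical planner output
reference = ["search_web", "summarize", "translate", "send_email"]   # hypothetical gold plan

# Positional match, as in compute_step_accuracy
step_acc = sum(p == r for p, r in zip(predicted, reference)) / len(reference)   # 2/4 = 0.50
# Set overlap, as in compute_tool_coverage
coverage = len(set(predicted) & set(reference)) / len(set(reference))           # 3/4 = 0.75
# Threshold check, as in is_plan_success
success = step_acc >= 0.8                                                        # False

print(f"step_acc={step_acc:.2f}  tool_coverage={coverage:.2f}  success={success}")

The example also shows why the two secondary metrics can diverge: omitting one reference tool shifts every later step out of position, so step accuracy drops faster than tool coverage.
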
sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py
@@ -0,0 +1,344 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Section 5.2.3: Tool Selection Experiment (RQ3)
4
+
5
+ Research question: How well can existing methods select the correct tools from a large-scale tool library?
6
+
7
+ Methods under test:
8
+ - selector.keyword : BM25 keyword matching
9
+ - selector.embedding : dense-retrieval semantic matching
10
+ - selector.hybrid : fusion of 40% BM25 + 60% dense scores
11
+ - selector.gorilla : embedding retrieval + LLM re-ranking
12
+ - selector.dfsdt : per-tool LLM scoring (ToolLLM approach)
13
+
14
+ Target metrics:
15
+ - Primary: Top-K Accuracy ≥ 95% (K=5)
16
+ - Secondary: MRR, Recall@K, Precision@K
17
+ - Tertiary: Latency (ms)
18
+
19
+ Usage:
20
+ python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_main_selection
21
+ python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_main_selection --max-samples 100 --top-k 5
22
+ python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_main_selection --skip-llm
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import time
29
+
30
+ from .exp_utils import (
31
+ ExperimentResult,
32
+ ExperimentSummary,
33
+ create_progress_bar,
34
+ load_benchmark_data,
35
+ print_metrics_detail,
36
+ print_result_row,
37
+ print_section_header,
38
+ print_subsection_header,
39
+ save_results,
40
+ setup_experiment_env,
41
+ )
42
+
43
+
44
+ def compute_mrr(predictions: list[list[str]], references: list[list[str]]) -> float:
45
+ """计算 Mean Reciprocal Rank。"""
46
+ rr_sum = 0.0
47
+ for pred, ref in zip(predictions, references):
48
+ ref_set = set(ref)
49
+ for i, p in enumerate(pred):
50
+ if p in ref_set:
51
+ rr_sum += 1.0 / (i + 1)
52
+ break
53
+ return rr_sum / len(predictions) if predictions else 0.0
54
+
55
+
56
+ def compute_top_k_accuracy(
57
+ predictions: list[list[str]], references: list[list[str]], k: int
58
+ ) -> float:
59
+ """计算 Top-K Accuracy。"""
60
+ hits = 0
61
+ for pred, ref in zip(predictions, references):
62
+ pred_top_k = set(pred[:k])
63
+ ref_set = set(ref)
64
+ if pred_top_k & ref_set:
65
+ hits += 1
66
+ return hits / len(predictions) if predictions else 0.0
67
+
68
+
69
+ def compute_recall_at_k(predictions: list[list[str]], references: list[list[str]], k: int) -> float:
70
+ """计算 Recall@K。"""
71
+ recalls = []
72
+ for pred, ref in zip(predictions, references):
73
+ pred_top_k = set(pred[:k])
74
+ ref_set = set(ref)
75
+ if ref_set:
76
+ recalls.append(len(pred_top_k & ref_set) / len(ref_set))
77
+ else:
78
+ recalls.append(0.0)
79
+ return sum(recalls) / len(recalls) if recalls else 0.0
80
+
81
+
82
+ def compute_precision_at_k(
83
+ predictions: list[list[str]], references: list[list[str]], k: int
84
+ ) -> float:
85
+ """计算 Precision@K。"""
86
+ precisions = []
87
+ for pred, ref in zip(predictions, references):
88
+ pred_top_k = set(pred[:k])
89
+ ref_set = set(ref)
90
+ if pred_top_k:
91
+ precisions.append(len(pred_top_k & ref_set) / len(pred_top_k))
92
+ else:
93
+ precisions.append(0.0)
94
+ return sum(precisions) / len(precisions) if precisions else 0.0
95
+
96
+
97
+ def normalize_ground_truth(ground_truth: object) -> list[str]:
98
+ """将地面真实标签统一为字符串列表。"""
99
+
100
+ if ground_truth is None:
101
+ return []
102
+
103
+ if isinstance(ground_truth, str):
104
+ return [ground_truth]
105
+
106
+ if isinstance(ground_truth, list):
107
+ return [str(item) for item in ground_truth]
108
+
109
+ if isinstance(ground_truth, dict):
110
+ for key in ("top_k", "tool_ids", "tools", "ids"):
111
+ value = ground_truth.get(key)
112
+ if isinstance(value, list):
113
+ return [str(item) for item in value]
114
+ if isinstance(value, str):
115
+ return [value]
116
+ # fall back to any string-like values
117
+ values = [str(value) for value in ground_truth.values() if value]
118
+ if values:
119
+ return values
120
+
121
+ return [str(ground_truth)]
122
+
123
+
124
+ def run_selection_experiment(
125
+ max_samples: int = 100,
126
+ top_k: int = 5,
127
+ skip_llm: bool = False,
128
+ verbose: bool = True,
129
+ ) -> ExperimentSummary:
130
+ """
131
+ Run the Tool Selection experiment.
132
+
133
+ Args:
134
+ max_samples: Maximum number of test samples
135
+ top_k: Top-K cutoff
136
+ skip_llm: Whether to skip LLM-based methods
137
+ verbose: Whether to print detailed progress information
138
+
139
+ Returns:
140
+ An ExperimentSummary object
141
+ """
142
+ setup_experiment_env(verbose=verbose)
143
+
144
+ print_section_header("Section 5.2.3: Tool Selection (RQ3)")
145
+ print(f" Target: Top-{top_k} Accuracy ≥ 95%")
146
+ print(f" Max samples: {max_samples}")
147
+
148
+ # Load data
149
+ samples = load_benchmark_data("selection", split="test", max_samples=max_samples)
150
+ if not samples:
151
+ print(" ❌ No selection data available")
152
+ return ExperimentSummary(section="5_2_main", challenge="selection")
153
+
154
+ print(f" Loaded {len(samples)} samples")
155
+
156
+ # Get the strategy registry
157
+ try:
158
+ from sage.benchmark.benchmark_agent import get_adapter_registry
159
+
160
+ registry = get_adapter_registry()
161
+ except ImportError as e:
162
+ print(f" ❌ Failed to import adapter registry: {e}")
163
+ return ExperimentSummary(section="5_2_main", challenge="selection")
164
+
165
+ # Define the strategies under test
166
+ strategies = [
167
+ ("selector.keyword", "Keyword (BM25)"),
168
+ ("selector.embedding", "Embedding"),
169
+ ("selector.hybrid", "Hybrid"),
170
+ ("selector.gorilla", "Gorilla"),
171
+ ("selector.dfsdt", "DFSDT"),
172
+ ]
173
+
174
+ # Optionally skip LLM-based strategies
175
+ LLM_STRATEGIES = {"selector.gorilla", "selector.dfsdt"}
176
+ if skip_llm:
177
+ strategies = [(name, display) for name, display in strategies if name not in LLM_STRATEGIES]
178
+ print(" ⚠️ Skipping LLM-based strategies")
179
+
180
+ results = []
181
+ target = 0.95
182
+
183
+ for strategy_name, display_name in strategies:
184
+ print_subsection_header(f"Testing: {display_name}")
185
+
186
+ try:
187
+ selector = registry.get(strategy_name)
188
+ except Exception as e:
189
+ print(f" ⚠️ Failed to create selector: {e}")
190
+ continue
191
+
192
+ # Run evaluation
193
+ start_time = time.time()
194
+ all_predictions = []
195
+ all_references = []
196
+
197
+ with create_progress_bar(len(samples), desc=f" {display_name}") as pbar:
198
+ for sample in samples:
199
+ try:
200
+ query = sample.get("instruction", sample.get("query", ""))
201
+ candidate_tools = sample.get("candidate_tools", [])
202
+ ground_truth = sample.get("ground_truth", sample.get("expected_tools", []))
203
+
204
+ # Invoke the selector
205
+ predictions = selector.select(
206
+ query, candidate_tools=candidate_tools, top_k=top_k
207
+ )
208
+
209
+ # Extract tool IDs
210
+ if predictions and hasattr(predictions[0], "tool_id"):
211
+ pred_ids = [p.tool_id for p in predictions]
212
+ elif predictions and isinstance(predictions[0], dict):
213
+ pred_ids = [p.get("tool_id", p.get("id", str(p))) for p in predictions]
214
+ else:
215
+ pred_ids = [str(p) for p in predictions] if predictions else []
216
+
217
+ # Normalize ground truth
218
+ ref_ids = normalize_ground_truth(ground_truth)
219
+
220
+ all_predictions.append(pred_ids)
221
+ all_references.append(ref_ids)
222
+
223
+ except Exception as e:
224
+ if verbose:
225
+ print(f" Error: {e}")
226
+ all_predictions.append([])
227
+ all_references.append(normalize_ground_truth(sample.get("ground_truth", sample.get("expected_tools", []))))
228
+
229
+ pbar.update(1)
230
+
231
+ elapsed = time.time() - start_time
232
+
233
+ # Compute metrics
234
+ n = len(samples)
235
+ top_k_acc = compute_top_k_accuracy(all_predictions, all_references, top_k)
236
+ mrr = compute_mrr(all_predictions, all_references)
237
+ recall_k = compute_recall_at_k(all_predictions, all_references, top_k)
238
+ precision_k = compute_precision_at_k(all_predictions, all_references, top_k)
239
+
240
+ exp_result = ExperimentResult(
241
+ challenge="selection",
242
+ strategy=strategy_name,
243
+ metrics={
244
+ "top_k_accuracy": top_k_acc,
245
+ "mrr": mrr,
246
+ f"recall@{top_k}": recall_k,
247
+ f"precision@{top_k}": precision_k,
248
+ },
249
+ metadata={
250
+ "total_samples": n,
251
+ "top_k": top_k,
252
+ "latency_ms": elapsed * 1000 / n if n > 0 else 0,
253
+ },
254
+ passed=top_k_acc >= target,
255
+ target=target,
256
+ )
257
+ results.append(exp_result)
258
+
259
+ # Print results
260
+ print_result_row(display_name, {"top_k_accuracy": top_k_acc}, exp_result.passed, target)
261
+ if verbose:
262
+ print_metrics_detail(exp_result.metrics)
263
+
264
+ # Find the best strategy
265
+ best_result = max(results, key=lambda r: r.metrics["top_k_accuracy"]) if results else None
266
+
267
+ summary = ExperimentSummary(
268
+ section="5_2_main",
269
+ challenge="selection",
270
+ results=results,
271
+ best_strategy=best_result.strategy if best_result else None,
272
+ best_metric=best_result.metrics["top_k_accuracy"] if best_result else None,
273
+ target_met=any(r.passed for r in results),
274
+ )
275
+
276
+ # Save results
277
+ output_file = save_results(summary.to_dict(), "5_2_main", "selection")
278
+ print(f"\n Results saved to: {output_file}")
279
+
280
+ # Generate figures and tables
281
+ if results:
282
+ from .figure_generator import (
283
+ generate_detailed_table,
284
+ get_figures_dir,
285
+ get_tables_dir,
286
+ plot_challenge_comparison,
287
+ )
288
+
289
+ figures_dir = get_figures_dir()
290
+ tables_dir = get_tables_dir()
291
+
292
+ # 生成对比图
293
+ plot_challenge_comparison(
294
+ [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
295
+ challenge="selection",
296
+ metrics=["top_k_accuracy", "mrr"],
297
+ target=target,
298
+ output_path=figures_dir / "fig3_main_selection_comparison.pdf",
299
+ title=f"Tool Selection: Strategy Comparison (Top-{top_k})",
300
+ )
301
+
302
+ # Generate table
303
+ generate_detailed_table(
304
+ [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
305
+ challenge="selection",
306
+ metrics=["top_k_accuracy", "mrr", f"recall@{top_k}", f"precision@{top_k}"],
307
+ output_path=tables_dir / "table_selection_detailed.tex",
308
+ )
309
+
310
+ print(f" Figure saved to: {figures_dir / 'fig3_main_selection_comparison.pdf'}")
311
+
312
+ return summary
313
+
314
+
315
+ def main():
316
+ parser = argparse.ArgumentParser(description="Section 5.2.3: Tool Selection Experiment")
317
+ parser.add_argument("--max-samples", type=int, default=100, help="Maximum samples to test")
318
+ parser.add_argument("--top-k", type=int, default=5, help="Top-K parameter")
319
+ parser.add_argument("--skip-llm", action="store_true", help="Skip LLM-based methods")
320
+ parser.add_argument("--verbose", action="store_true", default=True, help="Verbose output")
321
+ args = parser.parse_args()
322
+
323
+ summary = run_selection_experiment(
324
+ max_samples=args.max_samples,
325
+ top_k=args.top_k,
326
+ skip_llm=args.skip_llm,
327
+ verbose=args.verbose,
328
+ )
329
+
330
+ # Print summary
331
+ print("\n" + "=" * 70)
332
+ print("📊 Summary")
333
+ print("=" * 70)
334
+ print(f" Best strategy: {summary.best_strategy}")
335
+ print(
336
+ f" Best Top-K accuracy: {summary.best_metric * 100:.1f}%"
337
+ if summary.best_metric
338
+ else " N/A"
339
+ )
340
+ print(f" Target met: {'✅ YES' if summary.target_met else '❌ NO'}")
341
+
342
+
343
+ if __name__ == "__main__":
344
+ main()
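
Similarly, here is a minimal standalone sketch of the retrieval metrics used above (the tool IDs are hypothetical; it mirrors the logic of compute_mrr and compute_top_k_accuracy rather than importing them, and nothing below ships in the wheel).

predictions = [["tool_a", "tool_b", "tool_c"], ["tool_x", "tool_y", "tool_z"]]  # ranked lists per query
references = [["tool_b"], ["tool_q"]]                                            # gold tool IDs per query

# Mean Reciprocal Rank: query 1 hits at rank 2 -> 1/2, query 2 never hits -> 0
mrr = 0.0
for pred, ref in zip(predictions, references):
    for rank, tool_id in enumerate(pred, start=1):
        if tool_id in set(ref):
            mrr += 1.0 / rank
            break
mrr /= len(predictions)                                                          # 0.25

# Top-3 accuracy: a query scores 1 if any gold tool appears in its top 3
top3 = sum(
    bool(set(pred[:3]) & set(ref)) for pred, ref in zip(predictions, references)
) / len(predictions)                                                              # 0.50

print(f"mrr={mrr:.2f}  top3_accuracy={top3:.2f}")

Because Top-K accuracy only requires a single hit within the cut-off, it upper-bounds Recall@K; when each query has exactly one gold tool, as in this toy case, the two coincide.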