isage_benchmark_agent-0.1.0.1-cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
@@ -0,0 +1,406 @@
+ #!/usr/bin/env python3
+ """
+ Section 5.4: Cross-Dataset Generalization
+
+ Evaluates how well each method generalizes across different datasets.
+
+ Datasets:
+ - SAGE-Bench (ours): internal benchmark
+ - ACE-Bench: external tool-selection dataset
+ - ToolBench: tool selection (Qin et al.)
+ - API-Bank: API-calling dataset
+ - BFCL: Berkeley Function Calling Leaderboard (Gorilla)
+
+ Outputs:
+ - figures/fig9_generalization_cross_dataset.pdf
+ - tables/table_cross_dataset_results.tex
+
+ Usage (run as a module so the relative imports resolve):
+     python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_cross_dataset
+     python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_cross_dataset --datasets sage,acebench
+     python -m sage.benchmark.benchmark_agent.scripts.experiments.exp_cross_dataset --strategies keyword,embedding,hybrid
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ from typing import Optional
+
+ from .exp_utils import (
+     get_figures_dir,
+     load_benchmark_data,
+     print_section_header,
+     print_subsection_header,
+     save_results,
+     setup_experiment_env,
+ )
+
+ # =============================================================================
+ # Dataset Configuration
+ # =============================================================================
+
+ DATASETS = {
+     "sage": {
+         "name": "SAGE-Bench",
+         "source": "internal",
+         "challenge": "selection",
+         "loader": "load_benchmark_data",
+     },
+     "acebench": {
+         "name": "ACE-Bench",
+         "source": "external",
+         "path": "acebench",
+         "loader": "load_acebench_data",
+     },
+     "toolbench": {
+         "name": "ToolBench",
+         "source": "external",
+         "path": "toolbench",
+         "loader": "load_toolbench_data",
+     },
+     "apibank": {
+         "name": "API-Bank",
+         "source": "external",
+         "path": "apibank",
+         "loader": "load_apibank_data",
+     },
+     "bfcl": {
+         "name": "BFCL",
+         "source": "external",
+         "path": "bfcl",
+         "loader": "load_bfcl_data",
+     },
+ }
+
+ DEFAULT_STRATEGIES = [
+     "selector.keyword",
+     "selector.embedding",
+     "selector.hybrid",
+ ]
+
+
+ # =============================================================================
+ # Data Loaders
+ # =============================================================================
+
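+ # Each loader returns samples normalized to the common schema consumed by
+ # evaluate_on_dataset below:
+ #   {"instruction": str, "candidate_tools": list, "ground_truth": list}
+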
+ def load_dataset(dataset_id: str, max_samples: int = 100) -> list[dict]:
+     """
+     Load the specified dataset.
+
+     Args:
+         dataset_id: Dataset ID (a key of DATASETS).
+         max_samples: Maximum number of samples to load.
+
+     Returns:
+         A list of samples normalized to the common schema.
+     """
+     if dataset_id not in DATASETS:
+         print(f" ⚠️ Unknown dataset: {dataset_id}")
+         return []
+
+     if dataset_id == "sage":
+         return load_benchmark_data("selection", split="test", max_samples=max_samples)
+
+     elif dataset_id == "acebench":
+         return _load_acebench_data(max_samples)
+
+     elif dataset_id == "toolbench":
+         return _load_toolbench_data(max_samples)
+
+     elif dataset_id == "apibank":
+         return _load_apibank_data(max_samples)
+
+     elif dataset_id == "bfcl":
+         return _load_bfcl_data(max_samples)
+
+     return []
+
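+ # Example: load_dataset("sage", max_samples=50) returns up to 50 test-split
+ # SAGE-Bench samples via load_benchmark_data.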
+
+ def _load_acebench_data(max_samples: int) -> list[dict]:
+     """Load ACE-Bench data via the packaged loader."""
+     # Uses the real ACE-Bench loader when available; returns an empty list
+     # if the loader module cannot be imported.
+     try:
+         from sage.benchmark.benchmark_agent.acebench_loader import load_acebench_samples
+
+         samples = load_acebench_samples(max_samples=max_samples)
+         # Normalize dataset-specific field names onto the common schema.
+         return [
+             {
+                 "instruction": s.get("query", s.get("instruction", "")),
+                 "candidate_tools": s.get("tools", s.get("candidate_tools", [])),
+                 "ground_truth": s.get("expected", s.get("ground_truth", [])),
+             }
+             for s in samples
+         ]
+     except ImportError:
+         print(" ⚠️ ACE-Bench loader not available")
+         return []
+
+
+ def _load_toolbench_data(max_samples: int) -> list[dict]:
+     """Load ToolBench data."""
+     # TODO: implement the actual ToolBench loader.
+     print(" ⚠️ ToolBench loader not implemented")
+     return []
+
+
+ def _load_apibank_data(max_samples: int) -> list[dict]:
+     """Load API-Bank data."""
+     # TODO: implement the actual API-Bank loader.
+     print(" ⚠️ API-Bank loader not implemented")
+     return []
+
+
+ def _load_bfcl_data(max_samples: int) -> list[dict]:
+     """Load BFCL data."""
+     # TODO: implement the actual BFCL loader.
+     print(" ⚠️ BFCL loader not implemented")
+     return []
+
+
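+ # A future loader should follow the pattern of _load_acebench_data: read the
+ # raw records, then map dataset-specific field names onto the common schema.
+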
+ # =============================================================================
+ # Evaluation
+ # =============================================================================
+
+
+ def evaluate_on_dataset(
+     strategy_name: str,
+     samples: list[dict],
+     top_k: int = 5,
+     verbose: bool = True,
+ ) -> dict[str, float]:
+     """
+     Evaluate a single strategy on a single dataset.
+
+     Returns:
+         {metric: value}
+     """
+     if not samples:
+         return {"top_k_accuracy": 0.0, "mrr": 0.0}
+
+     try:
+         from sage.benchmark.benchmark_agent import get_adapter_registry
+
+         registry = get_adapter_registry()
+         selector = registry.get(strategy_name)
+     except Exception as e:
+         if verbose:
+             print(f" ⚠️ Failed to create selector: {e}")
+         return {"top_k_accuracy": 0.0, "mrr": 0.0}
+
+     hits = 0
+     rr_sum = 0.0
+
+     for sample in samples:
+         query = sample.get("instruction", "")
+         candidate_tools = sample.get("candidate_tools", [])
+         ground_truth = sample.get("ground_truth", [])
+
+         try:
+             preds = selector.select(query, candidate_tools=candidate_tools, top_k=top_k)
+             pred_ids = (
+                 [p.tool_id if hasattr(p, "tool_id") else str(p) for p in preds] if preds else []
+             )
+
+             ref_set = set(ground_truth) if isinstance(ground_truth, list) else {ground_truth}
+
+             # Top-K accuracy: a hit if any reference tool appears in the top-k predictions.
+             if set(pred_ids[:top_k]) & ref_set:
+                 hits += 1
+
+             # MRR: reciprocal rank of the first correct prediction.
+             for i, p in enumerate(pred_ids):
+                 if p in ref_set:
+                     rr_sum += 1.0 / (i + 1)
+                     break
+
+         except Exception:
+             # A failing sample simply contributes 0 to both metrics.
+             pass
+
+     n = len(samples)
+     return {
+         "top_k_accuracy": hits / n if n > 0 else 0.0,
+         "mrr": rr_sum / n if n > 0 else 0.0,
+     }
+
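+ # Worked example for the metrics above: with ground_truth = ["t2"] and
+ # pred_ids = ["t1", "t2", "t3"], the first hit is at rank 2, so the sample
+ # adds 1/2 to rr_sum and counts as a top-k hit for any k >= 2.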
+
+ # =============================================================================
+ # Main Experiment
+ # =============================================================================
+
+
+ def run_cross_dataset_evaluation(
+     datasets: Optional[list[str]] = None,
+     strategies: Optional[list[str]] = None,
+     max_samples: int = 100,
+     top_k: int = 5,
+     verbose: bool = True,
+ ) -> dict[str, dict[str, dict[str, float]]]:
+     """
+     Run the cross-dataset evaluation.
+
+     Args:
+         datasets: Dataset IDs to evaluate.
+         strategies: Strategy names to evaluate.
+         max_samples: Maximum samples per dataset.
+         top_k: Top-K cutoff.
+         verbose: Whether to print per-dataset details.
+
+     Returns:
+         {strategy: {dataset: {metric: value}}}
+     """
+     setup_experiment_env(verbose=verbose)
+
+     print_section_header("Section 5.4: Cross-Dataset Generalization")
+
+     if datasets is None:
+         datasets = ["sage", "acebench"]  # By default, only datasets with working loaders.
+
+     if strategies is None:
+         strategies = DEFAULT_STRATEGIES
+
+     print(f" Datasets: {datasets}")
+     print(f" Strategies: {[s.split('.')[-1] for s in strategies]}")
+     print(f" Max samples per dataset: {max_samples}")
+
+     all_results: dict[str, dict[str, dict[str, float]]] = {}
+
+     for strategy_name in strategies:
+         strategy_short = strategy_name.split(".")[-1]
+         print_subsection_header(f"Strategy: {strategy_short}")
+
+         all_results[strategy_name] = {}
+
+         for dataset_id in datasets:
+             dataset_config = DATASETS.get(dataset_id, {})
+             dataset_name = dataset_config.get("name", dataset_id)
+
+             print(f"\n Dataset: {dataset_name}")
+
+             # Load data
+             samples = load_dataset(dataset_id, max_samples=max_samples)
+
+             if not samples:
+                 print(" No data available")
+                 all_results[strategy_name][dataset_id] = {"top_k_accuracy": 0.0, "mrr": 0.0}
+                 continue
+
+             print(f" Samples: {len(samples)}")
+
+             # Evaluate
+             metrics = evaluate_on_dataset(strategy_name, samples, top_k=top_k, verbose=verbose)
+             all_results[strategy_name][dataset_id] = metrics
+
+             if verbose:
+                 print(f" Top-{top_k} Accuracy: {metrics['top_k_accuracy'] * 100:.1f}%")
+                 print(f" MRR: {metrics['mrr'] * 100:.1f}%")
+
+     # Save results
+     output_file = save_results(all_results, "5_4_generalization", "cross_dataset")
+     print(f"\n Results saved to: {output_file}")
+
+     # Generate figures
+     _generate_cross_dataset_figures(all_results, datasets, top_k)
+
+     # Print the summary table
+     _print_summary_table(all_results, datasets)
+
+     return all_results
+
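+ # Shape of the returned results (illustrative values):
+ #   {"selector.keyword": {"sage": {"top_k_accuracy": 0.62, "mrr": 0.48}}}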
+
+ def _generate_cross_dataset_figures(results: dict, datasets: list[str], top_k: int) -> None:
+     """Generate the cross-dataset comparison figure."""
+     try:
+         # Relative import, matching the .exp_utils import above.
+         from .figure_generator import plot_cross_dataset_comparison
+
+         figures_dir = get_figures_dir()
+
+         # Reshape results for plotting: {strategy: {dataset: top-k accuracy}}.
+         plot_data = {}
+         for strategy, dataset_results in results.items():
+             strategy_short = strategy.split(".")[-1]
+             plot_data[strategy_short] = {
+                 d: dataset_results.get(d, {}).get("top_k_accuracy", 0) for d in datasets
+             }
+
+         plot_cross_dataset_comparison(
+             plot_data,
+             metric=f"top_{top_k}_accuracy",
+             output_path=figures_dir / "fig9_generalization_cross_dataset.pdf",
+         )
+         print(" Figure saved: fig9_generalization_cross_dataset.pdf")
+
+     except Exception as e:
+         print(f" Warning: Could not generate figures: {e}")
+
+
+ def _print_summary_table(results: dict, datasets: list[str]) -> None:
+     """Print the summary table."""
+     print("\n" + "=" * 70)
+     print(" Cross-Dataset Generalization Summary")
+     print("=" * 70)
+
+     # Header row
+     header = f"{'Strategy':15s}"
+     for d in datasets:
+         header += f" | {DATASETS.get(d, {}).get('name', d):12s}"
+     print(header)
+     print("-" * 70)
+
+     # One row per strategy
+     for strategy, dataset_results in results.items():
+         strategy_short = strategy.split(".")[-1]
+         row = f"{strategy_short:15s}"
+         for d in datasets:
+             acc = dataset_results.get(d, {}).get("top_k_accuracy", 0)
+             row += f" | {acc * 100:10.1f}%"
+         print(row)
+
+     print("-" * 70)
+
+     # Generalization score: mean accuracy plus cross-dataset variance
+     # (lower variance = more consistent generalization).
+     print("\n Generalization Scores (lower variance = better):")
+     for strategy, dataset_results in results.items():
+         strategy_short = strategy.split(".")[-1]
+         accs = [dataset_results.get(d, {}).get("top_k_accuracy", 0) for d in datasets]
+         if accs:
+             mean_acc = sum(accs) / len(accs)
+             variance = sum((a - mean_acc) ** 2 for a in accs) / len(accs)
+             print(f" {strategy_short:12s}: mean={mean_acc * 100:.1f}%, var={variance * 100:.2f}")
+
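+ # Example: accuracies [0.70, 0.60] give mean 0.65 and population variance
+ # ((0.05)**2 + (0.05)**2) / 2 = 0.0025, printed as "var=0.25".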
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Section 5.4: Cross-Dataset Generalization")
+     parser.add_argument(
+         "--datasets", type=str, default="sage,acebench", help="Comma-separated dataset IDs"
+     )
+     parser.add_argument(
+         "--strategies", type=str, default=None, help="Comma-separated strategy names"
+     )
+     parser.add_argument("--max-samples", type=int, default=100, help="Maximum samples per dataset")
+     parser.add_argument("--top-k", type=int, default=5, help="Top-K parameter")
+     # BooleanOptionalAction keeps verbose on by default while allowing --no-verbose.
+     parser.add_argument(
+         "--verbose",
+         action=argparse.BooleanOptionalAction,
+         default=True,
+         help="Verbose output (disable with --no-verbose)",
+     )
+     args = parser.parse_args()
+
+     datasets = args.datasets.split(",") if args.datasets else None
+     strategies = args.strategies.split(",") if args.strategies else None
+
+     run_cross_dataset_evaluation(
+         datasets=datasets,
+         strategies=strategies,
+         max_samples=args.max_samples,
+         top_k=args.top_k,
+         verbose=args.verbose,
+     )
+
+     print("\n" + "=" * 70)
+     print("📊 Cross-Dataset Evaluation Complete")
+     print("=" * 70)
+
+
+ if __name__ == "__main__":
+     main()