isage-benchmark-agent 0.1.0.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
@@ -0,0 +1,677 @@
+ """
+ Figure Generator - unified figure generation module
+
+ Generates consistently styled figures for all Paper 1 experiments:
+ - standard academic-paper styling
+ - dual PDF + PNG output
+ - colorblind-friendly color scheme
+ """
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Optional
+
+ try:  # Support both `python -m experiments...` and standalone usage
+     from . import exp_utils as _exp_utils
+ except ImportError:  # pragma: no cover - fallback for direct script execution
+     import exp_utils as _exp_utils  # type: ignore
+
+ import numpy as np
+
+ # =============================================================================
+ # Figure style configuration
+ # =============================================================================
+
+ # Matplotlib style settings
+ FIGURE_STYLE = {
+     "font.family": "serif",
+     "font.serif": ["Times New Roman", "DejaVu Serif"],
+     "font.size": 10,
+     "axes.titlesize": 12,
+     "axes.labelsize": 10,
+     "xtick.labelsize": 9,
+     "ytick.labelsize": 9,
+     "legend.fontsize": 9,
+     "figure.figsize": (8, 6),
+     "figure.dpi": 100,
+     "savefig.dpi": 300,
+     "savefig.format": "pdf",
+     "savefig.bbox": "tight",
+     "axes.grid": True,
+     "grid.alpha": 0.3,
+ }
+
+ # Color scheme (colorblind-friendly, based on ColorBrewer)
+ COLORS = {
+     # Main palette
+     "primary": "#1f77b4",  # blue
+     "secondary": "#ff7f0e",  # orange
+     "tertiary": "#2ca02c",  # green
+     "quaternary": "#d62728",  # red
+     "quinary": "#9467bd",  # purple
+     # Semantic colors
+     "success": "#2ca02c",
+     "warning": "#ff7f0e",
+     "danger": "#d62728",
+     "info": "#1f77b4",
+     # Special purposes
+     "target_line": "#7f7f7f",  # gray target line
+     "baseline": "#bcbd22",  # yellow-green baseline
+     "best": "#17becf",  # cyan for best performer
+ }
+
+ # Strategy-to-color mapping
+ STRATEGY_COLORS = {
+     # Timing
+     "rule_based": COLORS["primary"],
+     "llm_based": COLORS["secondary"],
+     "hybrid": COLORS["tertiary"],
+     "embedding": COLORS["quaternary"],
+     # Planning
+     "simple": COLORS["primary"],
+     "hierarchical": COLORS["secondary"],
+     "react": COLORS["quaternary"],
+     # Selection
+     "keyword": COLORS["primary"],
+     "gorilla": COLORS["quaternary"],
+     "dfsdt": COLORS["quinary"],
+ }
+
+ # Figure size presets
+ FIGURE_SIZES = {
+     "single": (6, 4),  # single column
+     "double": (10, 4),  # double column
+     "wide": (12, 4),  # wide format
+     "square": (6, 6),  # square
+     "tall": (6, 8),  # tall
+ }
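
Since FIGURE_STYLE is a plain rcParams dict, it can also be applied temporarily through matplotlib's standard rc_context API rather than via the global setup_matplotlib() below; a minimal sketch, not part of the packaged file:

import matplotlib.pyplot as plt

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    FIGURE_SIZES,
    FIGURE_STYLE,
    STRATEGY_COLORS,
)

# Editor's sketch: apply the module's style only inside this block,
# without mutating plt.rcParams globally. Data points are invented.
with plt.rc_context(FIGURE_STYLE):
    fig, ax = plt.subplots(figsize=FIGURE_SIZES["single"])
    ax.plot([10, 50, 100], [72.0, 81.5, 88.0], "o-", color=STRATEGY_COLORS["hybrid"])
    ax.set_xlabel("Number of Tools")
    ax.set_ylabel("Accuracy (%)")
    fig.savefig("style_demo.pdf")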


+ # =============================================================================
+ # Output directory helpers
+ # =============================================================================
+
+
+ def get_figures_dir(output_dir: Optional[Path] = None) -> Path:
+     """Delegate to exp_utils so legacy imports keep working."""
+
+     return _exp_utils.get_figures_dir(output_dir)
+
+
+ def get_tables_dir(output_dir: Optional[Path] = None) -> Path:
+     """Delegate to exp_utils so legacy imports keep working."""
+
+     return _exp_utils.get_tables_dir(output_dir)
+
+
+ # =============================================================================
+ # Figure generation functions
+ # =============================================================================
+
+
+ def setup_matplotlib():
+     """Set up the matplotlib style; return pyplot, or None if unavailable."""
+     try:
+         import matplotlib.pyplot as plt
+
+         plt.rcParams.update(FIGURE_STYLE)
+         return plt
+     except ImportError:
+         print(" Warning: matplotlib not available")
+         return None
+
+
+ def plot_challenge_comparison(
+     results: list[dict],
+     challenge: str,
+     metrics: list[str],
+     target: float,
+     output_path: Optional[Path] = None,
+     title: Optional[str] = None,
+ ) -> Any:
+     """
+     Plot a strategy-comparison bar chart for a single challenge.
+
+     Args:
+         results: [{"strategy": str, "metrics": {"accuracy": 0.9, ...}}, ...]
+         challenge: Challenge name
+         metrics: List of metrics to display
+         target: Target-line value (a fraction, e.g. 0.95)
+         output_path: Output path
+         title: Figure title
+
+     Returns:
+         matplotlib Figure object
+     """
+     plt = setup_matplotlib()
+     if plt is None:
+         return None
+
+     fig, ax = plt.subplots(figsize=FIGURE_SIZES["single"])
+
+     strategies = [r["strategy"] for r in results]
+     x = np.arange(len(strategies))
+     width = 0.8 / len(metrics)
+
+     # Draw one bar group per metric
+     for i, metric in enumerate(metrics):
+         values = [r["metrics"].get(metric, 0) * 100 for r in results]
+         offset = (i - len(metrics) / 2 + 0.5) * width
+         bars = ax.bar(x + offset, values, width, label=metric.replace("_", " ").title())
+
+         # Annotate each bar with its value
+         for bar, val in zip(bars, values):
+             height = bar.get_height()
+             ax.annotate(
+                 f"{val:.1f}",
+                 xy=(bar.get_x() + bar.get_width() / 2, height),
+                 xytext=(0, 3),
+                 textcoords="offset points",
+                 ha="center",
+                 va="bottom",
+                 fontsize=8,
+             )
+
+     # Target line
+     ax.axhline(
+         y=target * 100,
+         color=COLORS["target_line"],
+         linestyle="--",
+         linewidth=2,
+         label=f"Target ({target * 100:.0f}%)",
+     )
+
+     # Axis settings
+     ax.set_xlabel("Strategy")
+     ax.set_ylabel("Score (%)")
+     ax.set_title(title or f"{challenge.title()} Challenge: Strategy Comparison")
+     ax.set_xticks(x)
+     ax.set_xticklabels([s.replace("_", " ").title() for s in strategies], rotation=15, ha="right")
+     ax.set_ylim(0, 105)
+     ax.legend(loc="upper right")
+
+     plt.tight_layout()
+
+     # Save
+     if output_path:
+         fig.savefig(output_path, format="pdf", bbox_inches="tight")
+         # Also save a PNG copy
+         png_path = output_path.with_suffix(".png")
+         fig.savefig(png_path, format="png", dpi=300, bbox_inches="tight")
+
+     return fig
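
A usage sketch (not part of the packaged file), assuming the wheel is installed; the strategies and scores below are invented, but the result shape matches the docstring above:

from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    plot_challenge_comparison,
)

# Illustrative results only; not real benchmark numbers.
demo_results = [
    {"strategy": "rule_based", "metrics": {"accuracy": 0.91, "precision": 0.88}},
    {"strategy": "llm_based", "metrics": {"accuracy": 0.95, "precision": 0.93}},
    {"strategy": "hybrid", "metrics": {"accuracy": 0.97, "precision": 0.96}},
]
fig = plot_challenge_comparison(
    results=demo_results,
    challenge="timing",
    metrics=["accuracy", "precision"],
    target=0.95,
    output_path=Path("timing_comparison.pdf"),  # a PNG copy is written alongside
)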


+ def plot_scaling_curve(
+     results: dict[str, list[tuple[float, float]]],
+     x_label: str,
+     y_label: str,
+     title: str,
+     output_path: Optional[Path] = None,
+     log_x: bool = False,
+ ) -> Any:
+     """
+     Plot a scaling curve.
+
+     Args:
+         results: {"strategy_name": [(x1, y1), (x2, y2), ...], ...}
+         x_label: X-axis label
+         y_label: Y-axis label
+         title: Title
+         output_path: Output path
+         log_x: Whether to use a logarithmic X axis
+
+     Returns:
+         matplotlib Figure object
+     """
+     plt = setup_matplotlib()
+     if plt is None:
+         return None
+
+     fig, ax = plt.subplots(figsize=FIGURE_SIZES["single"])
+
+     for strategy, points in results.items():
+         x_vals = [p[0] for p in points]
+         y_vals = [p[1] * 100 for p in points]  # convert to percentages
+         color = STRATEGY_COLORS.get(strategy, COLORS["primary"])
+         ax.plot(
+             x_vals,
+             y_vals,
+             "o-",
+             label=strategy.replace("_", " ").title(),
+             color=color,
+             linewidth=2,
+             markersize=6,
+         )
+
+     if log_x:
+         ax.set_xscale("log")
+
+     ax.set_xlabel(x_label)
+     ax.set_ylabel(y_label)
+     ax.set_title(title)
+     ax.legend(loc="best")
+     ax.grid(True, alpha=0.3)
+
+     plt.tight_layout()
+
+     if output_path:
+         fig.savefig(output_path, format="pdf", bbox_inches="tight")
+         png_path = output_path.with_suffix(".png")
+         fig.savefig(png_path, format="png", dpi=300, bbox_inches="tight")
+
+     return fig
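
A call sketch with the documented {strategy: [(x, y), ...]} shape; the tool counts and accuracy fractions here are invented:

from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    plot_scaling_curve,
)

# Invented data points: (number of tools, accuracy fraction in [0, 1]).
scaling = {
    "keyword": [(10, 0.92), (50, 0.81), (100, 0.74)],
    "gorilla": [(10, 0.95), (50, 0.90), (100, 0.86)],
}
plot_scaling_curve(
    results=scaling,
    x_label="Number of Tools",
    y_label="Top-K Accuracy (%)",
    title="Tool Set Size Scaling",
    output_path=Path("tool_scaling.pdf"),
    log_x=True,  # tool counts often span orders of magnitude
)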


+ def plot_error_breakdown(
+     errors: dict[str, dict[str, int]],
+     challenge: str,
+     output_path: Optional[Path] = None,
+ ) -> Any:
+     """
+     Plot an error-type breakdown (one pie chart per strategy).
+
+     Args:
+         errors: {"strategy": {"error_type": count, ...}, ...}
+         challenge: Challenge name
+         output_path: Output path
+
+     Returns:
+         matplotlib Figure object
+     """
+     plt = setup_matplotlib()
+     if plt is None:
+         return None
+
+     fig, axes = plt.subplots(1, len(errors), figsize=(4 * len(errors), 4))
+     if len(errors) == 1:
+         axes = [axes]
+
+     colors = list(COLORS.values())[:6]
+
+     for ax, (strategy, error_counts) in zip(axes, errors.items()):
+         labels = list(error_counts.keys())
+         sizes = list(error_counts.values())
+
+         if sum(sizes) > 0:
+             ax.pie(sizes, labels=labels, autopct="%1.1f%%", colors=colors[: len(labels)])
+         else:
+             ax.text(0.5, 0.5, "No Errors", ha="center", va="center")
+
+         ax.set_title(strategy.replace("_", " ").title())
+
+     fig.suptitle(f"{challenge.title()} Challenge: Error Breakdown")
+     plt.tight_layout()
+
+     if output_path:
+         fig.savefig(output_path, format="pdf", bbox_inches="tight")
+         png_path = output_path.with_suffix(".png")
+         fig.savefig(png_path, format="png", dpi=300, bbox_inches="tight")
+
+     return fig
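
A sketch of calling plot_error_breakdown; the error taxonomy and counts below are hypothetical, not taken from the package's datasets:

from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    plot_error_breakdown,
)

# Hypothetical per-strategy error counts; one pie chart is drawn per strategy.
errors = {
    "rule_based": {"missed_trigger": 12, "false_trigger": 7, "timeout": 2},
    "llm_based": {"missed_trigger": 4, "false_trigger": 9, "timeout": 1},
}
plot_error_breakdown(errors, challenge="timing", output_path=Path("error_breakdown.pdf"))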


+ def plot_ablation_heatmap(
+     ablation_results: dict[str, dict[str, float]],
+     title: str,
+     output_path: Optional[Path] = None,
+ ) -> Any:
+     """
+     Plot an ablation-study heatmap.
+
+     Args:
+         ablation_results: {"config_name": {"metric1": 0.9, "metric2": 0.8}, ...}
+         title: Title
+         output_path: Output path
+
+     Returns:
+         matplotlib Figure object
+     """
+     plt = setup_matplotlib()
+     if plt is None:
+         return None
+
+     configs = list(ablation_results.keys())
+     metrics = list(next(iter(ablation_results.values())).keys())
+
+     data = np.array([[ablation_results[c][m] * 100 for m in metrics] for c in configs])
+
+     fig, ax = plt.subplots(figsize=FIGURE_SIZES["square"])
+
+     im = ax.imshow(data, cmap="RdYlGn", aspect="auto", vmin=0, vmax=100)
+
+     # Tick labels
+     ax.set_xticks(np.arange(len(metrics)))
+     ax.set_yticks(np.arange(len(configs)))
+     ax.set_xticklabels([m.replace("_", " ").title() for m in metrics])
+     ax.set_yticklabels([c.replace("_", " ").title() for c in configs])
+
+     # Rotate x labels
+     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+     # Annotate cell values
+     for i in range(len(configs)):
+         for j in range(len(metrics)):
+             ax.text(j, i, f"{data[i, j]:.1f}", ha="center", va="center", color="black", fontsize=9)
+
+     ax.set_title(title)
+     fig.colorbar(im, ax=ax, label="Score (%)")
+
+     plt.tight_layout()
+
+     if output_path:
+         fig.savefig(output_path, format="pdf", bbox_inches="tight")
+         png_path = output_path.with_suffix(".png")
+         fig.savefig(png_path, format="png", dpi=300, bbox_inches="tight")
+
+     return fig
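
Note that plot_ablation_heatmap derives its column set from the first entry, so every config must report the same metric keys; a sketch with invented configurations:

from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    plot_ablation_heatmap,
)

# Invented ablation configs; every inner dict must share the same metric keys.
ablation = {
    "full": {"accuracy": 0.95, "f1": 0.93},
    "no_reranker": {"accuracy": 0.89, "f1": 0.86},
    "no_context": {"accuracy": 0.84, "f1": 0.80},
}
plot_ablation_heatmap(
    ablation, title="Ablation Study Results", output_path=Path("ablation_heatmap.pdf")
)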


+ def plot_cross_dataset_comparison(
+     results: dict[str, dict[str, float]],
+     metric: str,
+     output_path: Optional[Path] = None,
+ ) -> Any:
+     """
+     Plot a cross-dataset comparison chart.
+
+     Args:
+         results: {"strategy": {"dataset1": 0.9, "dataset2": 0.8}, ...}
+         metric: Metric name
+         output_path: Output path
+
+     Returns:
+         matplotlib Figure object
+     """
+     plt = setup_matplotlib()
+     if plt is None:
+         return None
+
+     strategies = list(results.keys())
+     datasets = list(next(iter(results.values())).keys())
+
+     x = np.arange(len(datasets))
+     width = 0.8 / len(strategies)
+
+     fig, ax = plt.subplots(figsize=FIGURE_SIZES["wide"])
+
+     for i, strategy in enumerate(strategies):
+         values = [results[strategy].get(d, 0) * 100 for d in datasets]
+         offset = (i - len(strategies) / 2 + 0.5) * width
+         color = STRATEGY_COLORS.get(strategy.split(".")[-1], COLORS["primary"])
+         ax.bar(x + offset, values, width, label=strategy.replace("_", " ").title(), color=color)
+
+     ax.set_xlabel("Dataset")
+     ax.set_ylabel(f"{metric.replace('_', ' ').title()} (%)")
+     ax.set_title(f"Cross-Dataset Comparison: {metric.replace('_', ' ').title()}")
+     ax.set_xticks(x)
+     ax.set_xticklabels(datasets)
+     ax.legend(loc="upper right", ncol=2)
+     ax.set_ylim(0, 105)
+
+     plt.tight_layout()
+
+     if output_path:
+         fig.savefig(output_path, format="pdf", bbox_inches="tight")
+         png_path = output_path.with_suffix(".png")
+         fig.savefig(png_path, format="png", dpi=300, bbox_inches="tight")
+
+     return fig
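
A sketch for plot_cross_dataset_comparison; the dataset names and scores are placeholders (the dataset columns come from the first strategy entry):

from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    plot_cross_dataset_comparison,
)

# Placeholder strategies/datasets/scores for illustration only.
cross = {
    "keyword": {"ACEBench": 0.78, "ToolBench": 0.72},
    "gorilla": {"ACEBench": 0.88, "ToolBench": 0.84},
}
plot_cross_dataset_comparison(
    cross, metric="top5_accuracy", output_path=Path("cross_dataset.pdf")
)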


+ # =============================================================================
+ # LaTeX table generation
+ # =============================================================================
+
+
+ def generate_main_results_table(
+     timing_results: list[dict],
+     planning_results: list[dict],
+     selection_results: list[dict],
+     output_path: Optional[Path] = None,
+ ) -> str:
+     """
+     Generate the main-results LaTeX table.
+
+     Returns:
+         LaTeX table string
+     """
+     latex = r"""
+ \begin{table}[t]
+ \centering
+ \caption{Main Results across Three Challenges}
+ \label{tab:main_results}
+ \begin{tabular}{llccc}
+ \toprule
+ \textbf{Challenge} & \textbf{Strategy} & \textbf{Primary} & \textbf{Secondary} & \textbf{Target Met} \\
+ \midrule
+ """
+
+     # Timing results
+     for r in timing_results:
+         acc = r["metrics"].get("accuracy", 0) * 100
+         prec = r["metrics"].get("precision", 0) * 100
+         passed = "\\cmark" if r.get("passed", False) else "\\xmark"
+         latex += f"Timing & {r['strategy']} & {acc:.1f}\\% & {prec:.1f}\\% & {passed} \\\\\n"
+
+     latex += r"\midrule" + "\n"
+
+     # Planning results
+     for r in planning_results:
+         success = r["metrics"].get("plan_success_rate", 0) * 100
+         step_acc = r["metrics"].get("step_accuracy", 0) * 100
+         passed = "\\cmark" if r.get("passed", False) else "\\xmark"
+         latex += (
+             f"Planning & {r['strategy']} & {success:.1f}\\% & {step_acc:.1f}\\% & {passed} \\\\\n"
+         )
+
+     latex += r"\midrule" + "\n"
+
+     # Selection results
+     for r in selection_results:
+         topk = r["metrics"].get("top_k_accuracy", 0) * 100
+         mrr = r["metrics"].get("mrr", 0) * 100
+         passed = "\\cmark" if r.get("passed", False) else "\\xmark"
+         latex += f"Selection & {r['strategy']} & {topk:.1f}\\% & {mrr:.1f}\\% & {passed} \\\\\n"
+
+     latex += r"""
+ \bottomrule
+ \end{tabular}
+ \end{table}
+ """
+
+     if output_path:
+         with open(output_path, "w", encoding="utf-8") as f:
+             f.write(latex)
+
+     return latex
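
The emitted table uses booktabs rules and \cmark/\xmark macros that the consuming document must define; the definitions below are an assumption on our part (the package may provide them elsewhere), shown as a minimal preamble sketch:

% Assumed preamble for documents that \input the generated table.
\usepackage{booktabs}  % provides \toprule, \midrule, \bottomrule
\usepackage{pifont}    % provides \ding symbols
\newcommand{\cmark}{\ding{51}}  % check mark for "target met"
\newcommand{\xmark}{\ding{55}}  % cross mark for "target missed"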


+ def generate_detailed_table(
+     results: list[dict],
+     challenge: str,
+     metrics: list[str],
+     output_path: Optional[Path] = None,
+ ) -> str:
+     """
+     Generate a detailed-results LaTeX table.
+
+     Returns:
+         LaTeX table string
+     """
+     metric_headers = " & ".join([f"\\textbf{{{m.replace('_', ' ').title()}}}" for m in metrics])
+
+     latex = f"""
+ \\begin{{table}}[t]
+ \\centering
+ \\caption{{{challenge.title()} Challenge: Detailed Results}}
+ \\label{{tab:{challenge}_detailed}}
+ \\begin{{tabular}}{{l{"c" * len(metrics)}}}
+ \\toprule
+ \\textbf{{Strategy}} & {metric_headers} \\\\
+ \\midrule
+ """
+
+     for r in results:
+         values = " & ".join([f"{r['metrics'].get(m, 0) * 100:.1f}\\%" for m in metrics])
+         latex += f"{r['strategy']} & {values} \\\\\n"
+
+     latex += r"""
+ \bottomrule
+ \end{tabular}
+ \end{table}
+ """
+
+     if output_path:
+         with open(output_path, "w", encoding="utf-8") as f:
+             f.write(latex)
+
+     return latex
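
A call sketch for generate_detailed_table, reusing the illustrative result shape from the earlier sketches (values invented):

from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    generate_detailed_table,
)

results = [
    {"strategy": "keyword", "metrics": {"top_k_accuracy": 0.81, "mrr": 0.64}},
    {"strategy": "gorilla", "metrics": {"top_k_accuracy": 0.93, "mrr": 0.79}},
]
latex = generate_detailed_table(
    results,
    challenge="selection",
    metrics=["top_k_accuracy", "mrr"],
    output_path=Path("selection_detailed.tex"),
)
print(latex)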


+ # =============================================================================
+ # Batch generation
+ # =============================================================================
+
+
+ def generate_all_figures(
+     results_dir: Optional[Path] = None,
+     output_dir: Optional[Path] = None,
+ ) -> dict[str, Path]:
+     """
+     Generate all figures from a results directory.
+
+     Args:
+         results_dir: Results directory (defaults to .sage/benchmark/paper1/)
+         output_dir: Output directory (defaults to results_dir/figures/)
+
+     Returns:
+         {figure_name: output_path}
+     """
+     import json
+
+     # Use the module-level alias so the standalone (non-package) fallback
+     # from the top of this file keeps working; a relative import here would
+     # raise ImportError when the module is run as a plain script.
+     results_dir = results_dir or _exp_utils.DEFAULT_OUTPUT_DIR
+     output_dir = output_dir or get_figures_dir()
+
+     print(f"\n Generating figures from: {results_dir}")
+     print(f" Output directory: {output_dir}")
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+     generated = {}
+
+     # Load result files
+     def load_json(path):
+         if path.exists():
+             with open(path, encoding="utf-8") as f:
+                 return json.load(f)
+         return None
+
+     # Figure 1: Main Results Comparison
+     section_5_2 = results_dir / "section_5_2_main"
+     timing = load_json(section_5_2 / "timing_results.json")
+     planning = load_json(section_5_2 / "planning_results.json")
+     selection = load_json(section_5_2 / "selection_results.json")
+
+     if timing:
+         results = timing.get("results", [])
+         if results:
+             path = output_dir / "fig1_timing_comparison.pdf"
+             # Prepare data for plot_challenge_comparison
+             formatted_results = [
+                 {"strategy": r.get("strategy", ""), "metrics": r.get("metrics", {})}
+                 for r in results
+             ]
+             plot_challenge_comparison(
+                 results=formatted_results,
+                 challenge="timing",
+                 metrics=["accuracy"],
+                 target=0.95,
+                 title="RQ1: Timing Detection",
+                 output_path=path,
+             )
+             generated["timing_comparison"] = path
+             print(f" Generated: {path.name}")
+
+     if planning:
+         results = planning.get("results", [])
+         if results:
+             path = output_dir / "fig2_planning_comparison.pdf"
+             formatted_results = [
+                 {"strategy": r.get("strategy", ""), "metrics": r.get("metrics", {})}
+                 for r in results
+             ]
+             plot_challenge_comparison(
+                 results=formatted_results,
+                 challenge="planning",
+                 metrics=["plan_success_rate"],
+                 target=0.90,
+                 title="RQ2: Task Planning",
+                 output_path=path,
+             )
+             generated["planning_comparison"] = path
+             print(f" Generated: {path.name}")
+
+     if selection:
+         results = selection.get("results", [])
+         if results:
+             path = output_dir / "fig3_selection_comparison.pdf"
+             formatted_results = [
+                 {"strategy": r.get("strategy", ""), "metrics": r.get("metrics", {})}
+                 for r in results
+             ]
+             plot_challenge_comparison(
+                 results=formatted_results,
+                 challenge="selection",
+                 metrics=["top_k_accuracy"],
+                 target=0.95,
+                 title="RQ3: Tool Selection",
+                 output_path=path,
+             )
+             generated["selection_comparison"] = path
+             print(f" Generated: {path.name}")
+
+     # Figure 4: Error Breakdown
+     section_5_3 = results_dir / "section_5_3_analysis"
+     error_data = load_json(section_5_3 / "error_analysis.json")
+     if error_data:
+         path = output_dir / "fig4_error_breakdown.pdf"
+         plot_error_breakdown(error_data, challenge="all", output_path=path)
+         generated["error_breakdown"] = path
+         print(f" Generated: {path.name}")
+
+     # Figure 5: Scaling Analysis
+     scaling_data = load_json(section_5_3 / "scaling_analysis.json")
+     if scaling_data and "tool_scaling" in scaling_data:
+         path = output_dir / "fig5_tool_scaling.pdf"
+         plot_scaling_curve(
+             results=scaling_data["tool_scaling"],
+             x_label="Number of Tools",
+             y_label="Top-K Accuracy",
+             title="Tool Set Size Scaling",
+             output_path=path,
+         )
+         generated["tool_scaling"] = path
+         print(f" Generated: {path.name}")
+
+     # Figure 6: Ablation Heatmap
+     ablation_data = load_json(section_5_3 / "ablation_results.json")
+     if ablation_data:
+         path = output_dir / "fig6_ablation_heatmap.pdf"
+         plot_ablation_heatmap(ablation_data, title="Ablation Study Results", output_path=path)
+         generated["ablation_heatmap"] = path
+         print(f" Generated: {path.name}")
+
+     # Figure 7: Cross-Dataset
+     section_5_4 = results_dir / "section_5_4_generalization"
+     cross_data = load_json(section_5_4 / "cross_dataset_results.json")
+     if cross_data:
+         path = output_dir / "fig7_cross_dataset.pdf"
+         plot_cross_dataset_comparison(cross_data, metric="top5_accuracy", output_path=path)
+         generated["cross_dataset"] = path
+         print(f" Generated: {path.name}")
+
+     print(f"\n Total figures generated: {len(generated)}")
+     return generated
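
Assuming the result JSON files follow the section layout hard-coded above, the batch entry point can be driven from a few lines of Python; the directories here are illustrative:

from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.figure_generator import (
    generate_all_figures,
)

# Illustrative directories; generate_all_figures skips any missing JSON inputs.
figures = generate_all_figures(
    results_dir=Path(".sage/benchmark/paper1"),
    output_dir=Path(".sage/benchmark/paper1/figures"),
)
for name, path in figures.items():
    print(name, "->", path)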