isage-benchmark-agent 0.1.0.1 (isage_benchmark_agent-0.1.0.1-cp311-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
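The wheel ships a `__main__.py` and an `entry_points.txt`, so the benchmark agent should be invocable as a module once the wheel is installed. Below is a minimal smoke-test sketch of driving it that way from Python; the module path `sage.benchmark.benchmark_agent` and the `--help` flag are assumptions inferred from the file list above, not verified against the package.

    # Hypothetical smoke test: invoke the packaged CLI module and print its help.
    # Assumes `pip install isage-benchmark-agent` has already been run.
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "-m", "sage.benchmark.benchmark_agent", "--help"],
        check=True,
    )

The diff that follows covers the largest new script in the manifest, the Paper 1 experiment runner.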
sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py

@@ -0,0 +1,627 @@

#!/usr/bin/env python3
"""
SAGE-Bench Paper 1 Experiment Runner

Unified entry point that runs all experiments in the order of the paper's Experiment section.

Usage:
    # Run all experiments
    python run_paper1_experiments.py

    # Quick mode (fewer samples)
    python run_paper1_experiments.py --quick

    # Main experiments only (Section 5.2)
    python run_paper1_experiments.py --section 5.2

    # Analysis experiments only (Section 5.3)
    python run_paper1_experiments.py --section 5.3

    # Run a single experiment
    python run_paper1_experiments.py --exp timing
    python run_paper1_experiments.py --exp scaling
    python run_paper1_experiments.py --exp cross-dataset

    # Skip LLM-based methods (quick test)
    python run_paper1_experiments.py --skip-llm

Output directory:
    .sage/benchmark/paper1/
    ├── section_5_2_main/
    ├── section_5_3_analysis/
    ├── section_5_4_generalization/
    ├── figures/
    └── tables/
"""

from __future__ import annotations

import argparse
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any

# Add the experiments module directory to the import path
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))


def print_banner():
    """Print the startup banner."""
    banner = """
╔══════════════════════════════════════════════════════════════════════╗
║ ║
║ ███████╗ █████╗ ██████╗ ███████╗ ██████╗ ███████╗███╗ ██╗ ║
║ ██╔════╝██╔══██╗██╔════╝ ██╔════╝ ██╔══██╗██╔════╝████╗ ██║ ║
║ ███████╗███████║██║ ███╗█████╗ █████╗██████╔╝█████╗ ██╔██╗ ██║ ║
║ ╚════██║██╔══██║██║ ██║██╔══╝ ╚════╝██╔══██╗██╔══╝ ██║╚██╗██║ ║
║ ███████║██║ ██║╚██████╔╝███████╗ ██████╔╝███████╗██║ ╚████║ ║
║ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚══════╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ║
║ ║
║ Paper 1: Benchmark Experiments Runner ║
║ ║
╚══════════════════════════════════════════════════════════════════════╝
"""
    print(banner)


def run_section_5_2(args) -> dict[str, Any]:
    """
    Section 5.2: Main Results (RQ1-RQ3)
    """
    print("\n" + "=" * 70)
    print("📊 SECTION 5.2: MAIN RESULTS")
    print("=" * 70)

    results = {}
    max_samples = 50 if args.quick else 150

    # RQ1: Timing Detection
    if not args.exp or args.exp in ["timing", "all"]:
        print("\n▶ Running RQ1: Timing Detection...")
        try:
            from experiments.exp_main_timing import run_timing_experiment

            results["timing"] = run_timing_experiment(
                max_samples=max_samples,
                skip_llm=args.skip_llm,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Timing experiment failed: {e}")
            results["timing"] = None

    # RQ2: Task Planning
    if not args.exp or args.exp in ["planning", "all"]:
        print("\n▶ Running RQ2: Task Planning...")
        try:
            from experiments.exp_main_planning import run_planning_experiment

            results["planning"] = run_planning_experiment(
                max_samples=max_samples,
                skip_llm=args.skip_llm,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Planning experiment failed: {e}")
            results["planning"] = None

    # RQ3: Tool Selection
    if not args.exp or args.exp in ["selection", "all"]:
        print("\n▶ Running RQ3: Tool Selection...")
        try:
            from experiments.exp_main_selection import run_selection_experiment

            results["selection"] = run_selection_experiment(
                max_samples=max_samples,
                top_k=5,
                skip_llm=args.skip_llm,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Selection experiment failed: {e}")
            results["selection"] = None

    return results


def run_section_5_3(args) -> dict[str, Any]:
    """
    Section 5.3: Analysis & Discussion
    """
    print("\n" + "=" * 70)
    print("🔬 SECTION 5.3: ANALYSIS & DISCUSSION")
    print("=" * 70)

    results = {}
    max_samples = 30 if args.quick else 100

    # 5.3.1 Error Analysis
    if not args.exp or args.exp in ["error", "all"]:
        print("\n▶ Running 5.3.1: Error Analysis...")
        try:
            from experiments.exp_analysis_error import run_error_analysis

            results["error_analysis"] = run_error_analysis(
                challenge="all",
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Error analysis failed: {e}")
            results["error_analysis"] = None

    # 5.3.2 Scaling Analysis
    if not args.exp or args.exp in ["scaling", "all"]:
        print("\n▶ Running 5.3.2: Scaling Analysis...")
        try:
            from experiments.exp_analysis_scaling import run_scaling_analysis

            # Decide whether to run the real LLM scaling test
            run_llm_scaling = not args.skip_llm and not getattr(args, "skip_llm_scaling", False)

            results["scaling_analysis"] = run_scaling_analysis(
                tool_scaling=True,
                llm_scaling=run_llm_scaling,
                max_samples=max_samples,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Scaling analysis failed: {e}")
            results["scaling_analysis"] = None

    # 5.3.3 Robustness Analysis
    if not args.exp or args.exp in ["robustness", "all"]:
        print("\n▶ Running 5.3.3: Robustness Analysis...")
        try:
            from experiments.exp_analysis_robustness import run_robustness_analysis

            results["robustness_analysis"] = run_robustness_analysis(
                semantic_variation=True,
                instruction_quality=True,
                reliability=True,
                max_samples=max_samples,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Robustness analysis failed: {e}")
            results["robustness_analysis"] = None

    # 5.3.4 Ablation Studies
    if not args.exp or args.exp in ["ablation", "all"]:
        print("\n▶ Running 5.3.4: Ablation Studies...")
        try:
            from experiments.exp_analysis_ablation import run_ablation_study

            results["ablation_study"] = run_ablation_study(
                prompt_ablation=True,
                hybrid_ablation=True,
                max_samples=max_samples,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Ablation study failed: {e}")
            results["ablation_study"] = None

    return results


def run_section_5_4(args) -> dict[str, Any]:
    """
    Section 5.4: Cross-Dataset Generalization
    """
    print("\n" + "=" * 70)
    print("🌐 SECTION 5.4: CROSS-DATASET GENERALIZATION")
    print("=" * 70)

    results = {}
    max_samples = 50 if args.quick else 100

    if not args.exp or args.exp in ["cross-dataset", "all"]:
        print("\n▶ Running Cross-Dataset Evaluation...")
        try:
            from experiments.exp_cross_dataset import run_cross_dataset_evaluation

            results["cross_dataset"] = run_cross_dataset_evaluation(
                datasets=["sage", "acebench"],
                max_samples=max_samples,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Cross-dataset evaluation failed: {e}")
            results["cross_dataset"] = None

    return results


def print_final_summary(all_results: dict[str, Any], elapsed_time: float):
    """Print the final summary."""
    print("\n")
    print("╔" + "═" * 68 + "╗")
    print("║" + " " * 20 + "EXPERIMENT SUMMARY" + " " * 30 + "║")
    print("╚" + "═" * 68 + "╝")

    # Section 5.2 Summary
    if "section_5_2" in all_results:
        print("\n📊 Section 5.2: Main Results")
        print("-" * 50)
        for challenge, result in all_results["section_5_2"].items():
            if result and hasattr(result, "target_met"):
                status = "✅ PASS" if result.target_met else "❌ FAIL"
                best = f"{result.best_metric * 100:.1f}%" if result.best_metric else "N/A"
                print(f"  {challenge:15s}: {best:>8s} {status}")

    # Section 5.3 Summary
    if "section_5_3" in all_results:
        print("\n🔬 Section 5.3: Analysis")
        print("-" * 50)
        for analysis, result in all_results["section_5_3"].items():
            status = "✅ Done" if result else "❌ Failed"
            print(f"  {analysis:20s}: {status}")

    # Section 5.4 Summary
    if "section_5_4" in all_results:
        print("\n🌐 Section 5.4: Generalization")
        print("-" * 50)
        for eval_name, result in all_results["section_5_4"].items():
            status = "✅ Done" if result else "❌ Failed"
            print(f"  {eval_name:20s}: {status}")

    # Section 5.5 Summary
    if "section_5_5" in all_results:
        print("\n🎓 Section 5.5: Training Comparison")
        print("-" * 50)
        tc = all_results["section_5_5"].get("training_comparison")
        if tc and hasattr(tc, "best_method"):
            print(f"  Best method: {tc.best_method}")
            print(f"  Best score: {tc.best_score * 100:.1f}%")
        elif tc:
            status = "✅ Done"
            print(f"  training_comparison: {status}")
        else:
            print("  training_comparison: ❌ Failed")

    # Timing
    print("\n" + "=" * 50)
    print(f"  Total time: {elapsed_time / 60:.1f} minutes")
    print(f"  Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("=" * 50)


def run_section_5_5(args) -> dict[str, Any]:
    """
    Section 5.5: Training Method Comparison
    """
    print("\n" + "=" * 70)
    print("🎓 SECTION 5.5: TRAINING METHOD COMPARISON")
    print("=" * 70)

    results = {}

    if not args.exp or args.exp in ["training", "all"]:
        print("\n▶ Running Training Method Comparison...")
        try:
            from experiments.exp_training_comparison import run_training_comparison

            # Paper 1 compares published SOTA training methods
            # SIAS methods (B_coreset, C_continual, D_combined) are for Paper 2
            if args.train_methods:
                methods = args.train_methods.split(",")
            else:
                # Paper 1 default: compare published SOTA methods
                methods = [
                    "A_baseline",     # Standard SFT (full params)
                    "A_lora",         # LoRA (Hu et al., 2021)
                    "A_qlora",        # QLoRA (Dettmers et al., 2023)
                    "A_fireact",      # FireAct trajectory tuning
                    "A_agenttuning",  # AgentTuning multi-task
                ]

            results["training_comparison"] = run_training_comparison(
                methods=methods,
                base_model=args.train_model,
                quick=args.quick,
                dry_run=args.dry_run,
                verbose=args.verbose,
            )
        except Exception as e:
            print(f"  ❌ Training comparison failed: {e}")
            results["training_comparison"] = None

    return results


def generate_paper_materials(all_results: dict[str, Any], output_dir: Path):
    """Generate the figures and tables needed for the paper."""
    print("\n" + "=" * 70)
    print("📝 GENERATING PAPER MATERIALS")
    print("=" * 70)

    try:
        from experiments.table_generator import generate_all_tables

        # Collect all results
        collected: dict[str, Any] = {
            "timing": [],
            "planning": [],
            "selection": [],
            "training": [],
            "ablation": {},
            "cross_dataset": {},
        }

        # Collect from section_5_2
        if "section_5_2" in all_results:
            s52 = all_results["section_5_2"]
            if s52.get("timing"):
                collected["timing"] = [
                    r.to_dict() if hasattr(r, "to_dict") else r
                    for r in (s52["timing"].results if hasattr(s52["timing"], "results") else [])
                ]
            if s52.get("planning"):
                collected["planning"] = [
                    r.to_dict() if hasattr(r, "to_dict") else r
                    for r in (
                        s52["planning"].results if hasattr(s52["planning"], "results") else []
                    )
                ]
            if s52.get("selection"):
                collected["selection"] = [
                    r.to_dict() if hasattr(r, "to_dict") else r
                    for r in (
                        s52["selection"].results if hasattr(s52["selection"], "results") else []
                    )
                ]

        # Collect training results from section_5_5
        if "section_5_5" in all_results:
            s55 = all_results["section_5_5"]
            if s55.get("training_comparison"):
                tc = s55["training_comparison"]
                if hasattr(tc, "results"):
                    collected["training"] = [r.to_dict() for r in tc.results]

        # Generate tables
        generate_all_tables(collected, output_dir / "tables")
        print("  ✓ LaTeX tables generated")

    except Exception as e:
        print(f"  ⚠️ Failed to generate tables: {e}")


def main():
    parser = argparse.ArgumentParser(
        description="SAGE-Bench Paper 1 Experiment Runner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run all experiments
  python run_paper1_experiments.py

  # Quick mode (fewer samples)
  python run_paper1_experiments.py --quick

  # Run specific section
  python run_paper1_experiments.py --section 5.2   # Main results only
  python run_paper1_experiments.py --section 5.5   # Training comparison

  # Run specific experiment
  python run_paper1_experiments.py --exp timing
  python run_paper1_experiments.py --exp training

  # Skip LLM-based methods (faster)
  python run_paper1_experiments.py --skip-llm

  # Training method comparison
  python run_paper1_experiments.py --exp training --train-methods A_baseline,D_combined
  python run_paper1_experiments.py --exp training --dry-run

  # LLM service management
  python run_paper1_experiments.py --llm-status
  python run_paper1_experiments.py --llm-start
  python run_paper1_experiments.py --llm-stop
""",
    )

    parser.add_argument(
        "--section",
        type=str,
        choices=["5.2", "5.3", "5.4", "5.5", "all"],
        default="all",
        help="Run specific section (default: all)",
    )

    parser.add_argument(
        "--exp",
        type=str,
        choices=[
            "timing",
            "planning",
            "selection",      # 5.2
            "error",
            "scaling",
            "robustness",
            "ablation",       # 5.3
            "cross-dataset",  # 5.4
            "training",       # 5.5
            "all",
        ],
        default=None,
        help="Run specific experiment only",
    )

    parser.add_argument(
        "--quick",
        action="store_true",
        help="Quick mode with fewer samples",
    )

    parser.add_argument(
        "--skip-llm",
        action="store_true",
        help="Skip LLM-based methods (faster)",
    )

    parser.add_argument(
        "--skip-llm-scaling",
        action="store_true",
        help="Skip real LLM scaling test (uses estimates instead)",
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        default=True,
        help="Verbose output",
    )

    # Training comparison arguments
    parser.add_argument(
        "--train-methods",
        type=str,
        default=None,
        help="Training methods to compare, comma-separated (e.g., A_baseline,D_combined)",
    )

    parser.add_argument(
        "--train-model",
        type=str,
        default="Qwen/Qwen2.5-1.5B-Instruct",
        help="Base model for training comparison",
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Simulate training without actual model training",
    )

    # LLM service management
    parser.add_argument(
        "--llm-status",
        action="store_true",
        help="Check LLM service status",
    )

    parser.add_argument(
        "--llm-start",
        action="store_true",
        help="Start LLM service",
    )

    parser.add_argument(
        "--llm-stop",
        action="store_true",
        help="Stop LLM service",
    )

    parser.add_argument(
        "--llm-model",
        type=str,
        default="Qwen/Qwen2.5-0.5B-Instruct",
        help="Model for LLM service",
    )

    parser.add_argument(
        "--llm-port",
        type=int,
        default=8901,
        help="Port for LLM service",
    )

    # Output
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for results",
    )

    parser.add_argument(
        "--generate-paper",
        action="store_true",
        help="Generate paper materials (figures and tables)",
    )

    args = parser.parse_args()

    # Handle LLM service commands
    if args.llm_status or args.llm_start or args.llm_stop:
        from experiments.llm_service import (
            print_llm_status,
            start_llm_service,
            stop_llm_service,
        )

        if args.llm_status:
            print_llm_status()
            return

        if args.llm_start:
            start_llm_service(model=args.llm_model, port=args.llm_port)
            return

        if args.llm_stop:
            stop_llm_service()
            return

    # Print the banner
    print_banner()

    print("Configuration:")
    print(f"  Section: {args.section}")
    print(f"  Experiment: {args.exp or 'all'}")
    print(f"  Quick mode: {args.quick}")
    print(f"  Skip LLM: {args.skip_llm}")
    if args.train_methods:
        print(f"  Train methods: {args.train_methods}")

    # Check LLM availability (unless skipping)
    if not args.skip_llm:
        print("\n📡 Checking LLM service...")
        from experiments.llm_service import ensure_llm_available

        # Try to auto-start the local service; cloud APIs are not allowed
        # Check the port and model given on the command line first
        if not ensure_llm_available(
            port=args.llm_port,
            model=args.llm_model,
            auto_start=True,
            allow_cloud=False,
        ):
            print("  ❌ Failed to connect to or start local LLM service. Aborting.")
            print("  💡 Please check logs or start service manually: sage llm run")
            sys.exit(1)

    start_time = time.time()
    all_results = {}

    # Run the selected sections
    if args.section in ["5.2", "all"]:
        all_results["section_5_2"] = run_section_5_2(args)

    if args.section in ["5.3", "all"]:
        all_results["section_5_3"] = run_section_5_3(args)

    if args.section in ["5.4", "all"]:
        all_results["section_5_4"] = run_section_5_4(args)

    if args.section in ["5.5", "all"] or args.exp == "training":
        all_results["section_5_5"] = run_section_5_5(args)

    elapsed_time = time.time() - start_time

    # Generate paper materials
    if args.generate_paper:
        from experiments.exp_utils import DEFAULT_OUTPUT_DIR

        output_dir = Path(args.output_dir) if args.output_dir else DEFAULT_OUTPUT_DIR
        generate_paper_materials(all_results, output_dir)

    # Print the final summary
    print_final_summary(all_results, elapsed_time)


if __name__ == "__main__":
    main()