isage_benchmark_agent-0.1.0.1-cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python3
+"""
+Section 5.2.1: Timing Detection Experiment (RQ1)
+
+Research question: how well do existing methods decide whether a tool call is needed?
+
+Methods under test:
+- timing.rule_based : keywords + regex rules
+- timing.embedding  : semantic-similarity decision
+- timing.llm_based  : direct LLM inference
+- timing.hybrid     : rule-based prefilter + LLM refinement
+
+Target metrics:
+- Primary: Accuracy ≥ 95%
+- Secondary: Precision, Recall, F1
+- Tertiary: Latency (ms)
+
+Usage:
+    python exp_main_timing.py
+    python exp_main_timing.py --max-samples 100
+    python exp_main_timing.py --skip-llm
+"""
+
+from __future__ import annotations
+
+import argparse
+import time
+
+from .exp_utils import (
+    ExperimentResult,
+    ExperimentSummary,
+    create_progress_bar,
+    load_benchmark_data,
+    print_metrics_detail,
+    print_result_row,
+    print_section_header,
+    print_subsection_header,
+    save_results,
+    setup_experiment_env,
+)
+from .figure_generator import generate_detailed_table, plot_challenge_comparison
+
+
+def run_timing_experiment(
+    max_samples: int = 150,
+    skip_llm: bool = False,
+    verbose: bool = True,
+) -> ExperimentSummary:
+    """
+    Run the timing detection experiment.
+
+    Args:
+        max_samples: Maximum number of test samples
+        skip_llm: Whether to skip LLM-based methods
+        verbose: Whether to print detailed output
+
+    Returns:
+        An ExperimentSummary object
+    """
+    setup_experiment_env(verbose=verbose)
+
+    print_section_header("Section 5.2.1: Timing Detection (RQ1)")
+    print(" Target: Accuracy ≥ 95%")
+    print(f" Max samples: {max_samples}")
+
+    # Load data
+    samples = load_benchmark_data("timing", split="test", max_samples=max_samples)
+    if not samples:
+        print(" ❌ No timing data available")
+        return ExperimentSummary(section="5_2_main", challenge="timing")
+
+    print(f" Loaded {len(samples)} samples")
+
+    # Fetch the strategy registry
+    try:
+        from sage.benchmark.benchmark_agent import get_adapter_registry
+
+        registry = get_adapter_registry()
+    except ImportError as e:
+        print(f" ❌ Failed to import adapter registry: {e}")
+        return ExperimentSummary(section="5_2_main", challenge="timing")
+
+    # Strategies under test
+    strategies = [
+        ("timing.rule_based", "Rule-based"),
+        ("timing.embedding", "Embedding"),
+        ("timing.llm_based", "LLM-based"),
+        ("timing.hybrid", "Hybrid"),
+    ]
+
+    # Optionally skip LLM strategies
+    LLM_STRATEGIES = {"timing.llm_based", "timing.hybrid"}
+    if skip_llm:
+        strategies = [(name, display) for name, display in strategies if name not in LLM_STRATEGIES]
+        print(" ⚠️ Skipping LLM-based strategies")
+
+    results = []
+    target = 0.95
+
+    for strategy_name, display_name in strategies:
+        print_subsection_header(f"Testing: {display_name}")
+
+        try:
+            detector = registry.get(strategy_name)
+        except Exception as e:
+            print(f" ⚠️ Failed to create detector: {e}")
+            continue
+
+        # Run the evaluation
+        start_time = time.time()
+        correct = 0
+        true_positives = 0
+        false_positives = 0
+        false_negatives = 0
+        true_negatives = 0
+
+        with create_progress_bar(len(samples), desc=f" {display_name}") as pbar:
+            for sample in samples:
+                try:
+                    # Build the message object
+                    from sage.benchmark.benchmark_agent.experiments.timing_detection_exp import (
+                        TimingMessage,
+                    )
+
+                    message = TimingMessage(
+                        sample_id=sample.get("sample_id", ""),
+                        message=sample.get("message", ""),
+                        context=sample.get("context", {}),
+                    )
+
+                    result = detector.decide(message)
+                    predicted = result.should_call_tool
+                    expected = sample.get("should_call_tool", False)
+
+                    if predicted == expected:
+                        correct += 1
+
+                    # Confusion matrix
+                    if predicted and expected:
+                        true_positives += 1
+                    elif predicted and not expected:
+                        false_positives += 1
+                    elif not predicted and expected:
+                        false_negatives += 1
+                    else:
+                        true_negatives += 1
+
+                except Exception as e:
+                    if verbose:
+                        print(f" Error: {e}")
+
+                pbar.update(1)
+
+        elapsed = time.time() - start_time
+
+        # Compute metrics
+        n = len(samples)
+        accuracy = correct / n if n > 0 else 0
+        precision = (
+            true_positives / (true_positives + false_positives)
+            if (true_positives + false_positives) > 0
+            else 0
+        )
+        recall = (
+            true_positives / (true_positives + false_negatives)
+            if (true_positives + false_negatives) > 0
+            else 0
+        )
+        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+
+        exp_result = ExperimentResult(
+            challenge="timing",
+            strategy=strategy_name,
+            metrics={
+                "accuracy": accuracy,
+                "precision": precision,
+                "recall": recall,
+                "f1": f1,
+            },
+            metadata={
+                "total_samples": n,
+                "correct": correct,
+                "latency_ms": elapsed * 1000 / n if n > 0 else 0,
+                "confusion_matrix": {
+                    "tp": true_positives,
+                    "fp": false_positives,
+                    "fn": false_negatives,
+                    "tn": true_negatives,
+                },
+            },
+            passed=accuracy >= target,
+            target=target,
+        )
+        results.append(exp_result)
+
+        # Print per-strategy results
+        print_result_row(display_name, exp_result.metrics, exp_result.passed, target)
+        if verbose:
+            print_metrics_detail(exp_result.metrics)
+
+    # Pick the best strategy
+    best_result = max(results, key=lambda r: r.metrics["accuracy"]) if results else None
+
+    summary = ExperimentSummary(
+        section="5_2_main",
+        challenge="timing",
+        results=results,
+        best_strategy=best_result.strategy if best_result else None,
+        best_metric=best_result.metrics["accuracy"] if best_result else None,
+        target_met=any(r.passed for r in results),
+    )
+
+    # Save results
+    output_file = save_results(summary.to_dict(), "5_2_main", "timing")
+    print(f"\n Results saved to: {output_file}")
+
+    # Generate figures and tables
+    if results:
+        from .figure_generator import get_figures_dir, get_tables_dir
+
+        figures_dir = get_figures_dir()
+        tables_dir = get_tables_dir()
+
+        # Comparison figure
+        plot_challenge_comparison(
+            [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
+            challenge="timing",
+            metrics=["accuracy", "precision", "recall", "f1"],
+            target=target,
+            output_path=figures_dir / "fig1_main_timing_comparison.pdf",
+            title="Timing Detection: Strategy Comparison",
+        )
+
+        # Detailed table
+        generate_detailed_table(
+            [{"strategy": r.strategy.split(".")[-1], "metrics": r.metrics} for r in results],
+            challenge="timing",
+            metrics=["accuracy", "precision", "recall", "f1"],
+            output_path=tables_dir / "table_timing_detailed.tex",
+        )
+
+        print(f" Figure saved to: {figures_dir / 'fig1_main_timing_comparison.pdf'}")
+
+    return summary
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Section 5.2.1: Timing Detection Experiment")
+    parser.add_argument("--max-samples", type=int, default=150, help="Maximum samples to test")
+    parser.add_argument("--skip-llm", action="store_true", help="Skip LLM-based methods")
+    parser.add_argument("--verbose", action="store_true", default=True, help="Verbose output")
+    args = parser.parse_args()
+
+    summary = run_timing_experiment(
+        max_samples=args.max_samples,
+        skip_llm=args.skip_llm,
+        verbose=args.verbose,
+    )
+
+    # Print summary
+    print("\n" + "=" * 70)
+    print("📊 Summary")
+    print("=" * 70)
+    print(f" Best strategy: {summary.best_strategy}")
+    print(f" Best accuracy: {summary.best_metric * 100:.1f}%" if summary.best_metric else " N/A")
+    print(f" Target met: {'✅ YES' if summary.target_met else '❌ NO'}")
+
+
+if __name__ == "__main__":
+    main()