isage_benchmark_agent-0.1.0.1-cp311-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py
@@ -0,0 +1,430 @@
"""
Table Generator - LaTeX table generation module.

Generates LaTeX-format tables for all Paper 1 experiments:
- Main Results table
- Benchmark Details table
- Ablation Study table
- Training Comparison table
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Optional

from .exp_utils import get_tables_dir

# =============================================================================
# LaTeX table generation
# =============================================================================


def generate_main_results_table(
    timing_results: list[dict],
    planning_results: list[dict],
    selection_results: list[dict],
    output_path: Optional[Path] = None,
) -> str:
    """
    Generate the main results table (Table 1).

    Args:
        timing_results: Timing experiment results
        planning_results: Planning experiment results
        selection_results: Selection experiment results
        output_path: Output path

    Returns:
        LaTeX table string
    """

    # Find the best result for each challenge
    def get_best(results, metric):
        if not results:
            return None
        return max(results, key=lambda r: r.get("metrics", {}).get(metric, 0))

    timing_best = get_best(timing_results, "accuracy")
    planning_best = get_best(planning_results, "plan_success_rate")
    selection_best = get_best(selection_results, "top_k_accuracy")

    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Main results on SAGE-AgentBench. Best performing strategy for each challenge.}",
        r"\label{tab:main}",
        r"\begin{tabular}{lccc}",
        r"\toprule",
        r"\textbf{Challenge} & \textbf{Strategy} & \textbf{Primary Metric} & \textbf{Target Met} \\",
        r"\midrule",
    ]

    if timing_best:
        acc = timing_best.get("metrics", {}).get("accuracy", 0) * 100
        strategy = timing_best.get("strategy", "").replace("timing.", "").replace("_", " ").title()
        passed = acc >= 95
        status = r"\cmark" if passed else r"\xmark"
        lines.append(f"Timing Detection & {strategy} & {acc:.1f}\\% & {status} \\\\")

    if planning_best:
        rate = planning_best.get("metrics", {}).get("plan_success_rate", 0) * 100
        strategy = (
            planning_best.get("strategy", "").replace("planner.", "").replace("_", " ").title()
        )
        passed = rate >= 90
        status = r"\cmark" if passed else r"\xmark"
        lines.append(f"Task Planning & {strategy} & {rate:.1f}\\% & {status} \\\\")

    if selection_best:
        acc = selection_best.get("metrics", {}).get("top_k_accuracy", 0) * 100
        strategy = (
            selection_best.get("strategy", "").replace("selector.", "").replace("_", " ").title()
        )
        passed = acc >= 95
        status = r"\cmark" if passed else r"\xmark"
        lines.append(f"Tool Selection & {strategy} & {acc:.1f}\\% & {status} \\\\")

    lines.extend(
        [
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{table}",
        ]
    )

    content = "\n".join(lines)

    if output_path:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content)

    return content


def generate_benchmark_details_table(
    timing_results: list[dict],
    planning_results: list[dict],
    selection_results: list[dict],
    output_path: Optional[Path] = None,
) -> str:
    """
    Generate the detailed benchmark results table (Table 2).

    Shows every strategy's performance on all metrics.
    """
    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Detailed benchmark results across all strategies.}",
        r"\label{tab:benchmark}",
        r"\small",
        r"\begin{tabular}{llcccc}",
        r"\toprule",
        r"\textbf{Challenge} & \textbf{Strategy} & \textbf{Metric 1} & \textbf{Metric 2} & \textbf{Metric 3} & \textbf{Pass} \\",
        r"\midrule",
    ]

    # Timing results
    for r in timing_results or []:
        name = r.get("strategy", "").replace("timing.", "").replace("_", " ").title()
        metrics = r.get("metrics", {})
        acc = metrics.get("accuracy", 0) * 100
        prec = metrics.get("precision", 0) * 100
        rec = metrics.get("recall", 0) * 100
        passed = acc >= 95
        status = r"\cmark" if passed else r"\xmark"
        lines.append(
            f"Timing & {name} & Acc: {acc:.1f}\\% & Prec: {prec:.1f}\\% & Rec: {rec:.1f}\\% & {status} \\\\"
        )

    if timing_results:
        lines.append(r"\midrule")

    # Planning results
    for r in planning_results or []:
        name = r.get("strategy", "").replace("planner.", "").replace("_", " ").title()
        metrics = r.get("metrics", {})
        success = metrics.get("plan_success_rate", 0) * 100
        step = metrics.get("step_accuracy", 0) * 100
        seq = metrics.get("sequence_match", 0) * 100
        passed = success >= 90
        status = r"\cmark" if passed else r"\xmark"
        lines.append(
            f"Planning & {name} & Succ: {success:.1f}\\% & Step: {step:.1f}\\% & Seq: {seq:.1f}\\% & {status} \\\\"
        )

    if planning_results:
        lines.append(r"\midrule")

    # Tool selection results
    for r in selection_results or []:
        name = r.get("strategy", "").replace("selector.", "").replace("_", " ").title()
        metrics = r.get("metrics", {})
        top_k = metrics.get("top_k_accuracy", 0) * 100
        mrr = metrics.get("mrr", 0) * 100
        recall = metrics.get("recall_at_k", 0) * 100
        passed = top_k >= 95
        status = r"\cmark" if passed else r"\xmark"
        lines.append(
            f"Selection & {name} & Top-K: {top_k:.1f}\\% & MRR: {mrr:.1f}\\% & R@K: {recall:.1f}\\% & {status} \\\\"
        )

    lines.extend(
        [
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{table}",
        ]
    )

    content = "\n".join(lines)

    if output_path:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content)

    return content


def generate_training_comparison_table(
    training_results: list[dict],
    output_path: Optional[Path] = None,
) -> str:
    """
    Generate the training method comparison table (Table 3).
    """
    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Training method comparison. Performance after fine-tuning on SAGE-AgentBench.}",
        r"\label{tab:training}",
        r"\small",
        r"\begin{tabular}{lcccc}",
        r"\toprule",
        r"\textbf{Method} & \textbf{Timing} & \textbf{Planning} & \textbf{Selection} & \textbf{Overall} \\",
        r"\midrule",
    ]

    for r in training_results or []:
        name = r.get("method_name", "").replace("_", " ")
        timing = r.get("timing_accuracy", 0) * 100
        planning = r.get("planning_success_rate", 0) * 100
        selection = r.get("selection_top_k_accuracy", 0) * 100
        overall = (timing + planning + selection) / 3
        lines.append(
            f"{name} & {timing:.1f}\\% & {planning:.1f}\\% & {selection:.1f}\\% & {overall:.1f}\\% \\\\"
        )

    lines.extend(
        [
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{table}",
        ]
    )

    content = "\n".join(lines)

    if output_path:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content)

    return content


def generate_ablation_table(
    ablation_results: dict[str, dict[str, float]],
    title: str = "Ablation study results",
    label: str = "tab:ablation",
    output_path: Optional[Path] = None,
) -> str:
    """
    Generate the ablation study table.

    Args:
        ablation_results: {"config_name": {"metric1": 0.9, ...}, ...}
        title: Table caption
        label: LaTeX label
        output_path: Output path
    """
    if not ablation_results:
        return ""

    configs = list(ablation_results.keys())
    metrics = list(next(iter(ablation_results.values())).keys())

    # Build the table
    col_spec = "l" + "c" * len(metrics)
    header = (
        r"\textbf{Config} & "
        + " & ".join([f"\\textbf{{{m.replace('_', ' ').title()}}}" for m in metrics])
        + r" \\"
    )

    lines = [
        r"\begin{table}[t]",
        r"\centering",
        f"\\caption{{{title}}}",
        f"\\label{{{label}}}",
        r"\small",
        f"\\begin{{tabular}}{{{col_spec}}}",
        r"\toprule",
        header,
        r"\midrule",
    ]

    for config in configs:
        values = ablation_results[config]
        row = config.replace("_", " ")
        for metric in metrics:
            val = values.get(metric, 0)
            if isinstance(val, float):
                row += f" & {val * 100:.1f}\\%"
            else:
                row += f" & {val}"
        row += r" \\"
        lines.append(row)

    lines.extend(
        [
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{table}",
        ]
    )

    content = "\n".join(lines)

    if output_path:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content)

    return content


def generate_cross_dataset_table(
    cross_dataset_results: dict[str, dict[str, float]],
    output_path: Optional[Path] = None,
) -> str:
    """
    Generate the cross-dataset comparison table.

    Args:
        cross_dataset_results: {"method": {"sage": 0.9, "acebench": 0.85, ...}, ...}
    """
    if not cross_dataset_results:
        return ""

    methods = list(cross_dataset_results.keys())
    datasets = list(next(iter(cross_dataset_results.values())).keys())

    col_spec = "l" + "c" * len(datasets)
    header = (
        r"\textbf{Method} & " + " & ".join([f"\\textbf{{{d.upper()}}}" for d in datasets]) + r" \\"
    )

    lines = [
        r"\begin{table}[t]",
        r"\centering",
        r"\caption{Cross-dataset generalization. Top-5 accuracy on different benchmarks.}",
        r"\label{tab:crossdataset}",
        r"\small",
        f"\\begin{{tabular}}{{{col_spec}}}",
        r"\toprule",
        header,
        r"\midrule",
    ]

    for method in methods:
        row = method.replace("_", " ").title()
        for dataset in datasets:
            val = cross_dataset_results[method].get(dataset, 0) * 100
            row += f" & {val:.1f}\\%"
        row += r" \\"
        lines.append(row)

    lines.extend(
        [
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{table}",
        ]
    )

    content = "\n".join(lines)

    if output_path:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(content)

    return content


# =============================================================================
# Batch generation
# =============================================================================


def generate_all_tables(
    all_results: dict[str, Any],
    output_dir: Optional[Path] = None,
) -> dict[str, Path]:
    """
    Generate all paper tables.

    Args:
        all_results: Dictionary containing all experiment results
        output_dir: Output directory

    Returns:
        Dictionary mapping table names to generated file paths
    """
    tables_dir = output_dir or get_tables_dir()
    tables_dir.mkdir(parents=True, exist_ok=True)

    generated = {}

    # Table 1: Main Results
    timing = all_results.get("timing", [])
    planning = all_results.get("planning", [])
    selection = all_results.get("selection", [])

    if timing or planning or selection:
        path = tables_dir / "table1_main_results.tex"
        generate_main_results_table(timing, planning, selection, path)
        generated["main_results"] = path
        print(f" Generated: {path}")

    # Table 2: Benchmark Details
    if timing or planning or selection:
        path = tables_dir / "table2_benchmark_details.tex"
        generate_benchmark_details_table(timing, planning, selection, path)
        generated["benchmark_details"] = path
        print(f" Generated: {path}")

    # Table 3: Training Comparison
    training = all_results.get("training", [])
    if training:
        path = tables_dir / "table3_training_comparison.tex"
        generate_training_comparison_table(training, path)
        generated["training_comparison"] = path
        print(f" Generated: {path}")

    # Table 4: Ablation
    ablation = all_results.get("ablation", {})
    if ablation:
        path = tables_dir / "table4_ablation.tex"
        generate_ablation_table(ablation, output_path=path)
        generated["ablation"] = path
        print(f" Generated: {path}")

    # Table 5: Cross-Dataset
    cross_dataset = all_results.get("cross_dataset", {})
    if cross_dataset:
        path = tables_dir / "table5_cross_dataset.tex"
        generate_cross_dataset_table(cross_dataset, path)
        generated["cross_dataset"] = path
        print(f" Generated: {path}")

    return generated
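For orientation, a minimal driver sketch (not part of the package). The result shapes — top-level "timing"/"planning"/"selection" keys, and per-result "strategy" and "metrics" fields — are inferred from the accessors in table_generator.py above; the strategy names, scores, and output directory are invented for illustration. Note that the emitted tables use \cmark and \xmark, which are not standard LaTeX: the consuming preamble must define them (commonly via pifont's \ding{51} and \ding{55}).

# Hypothetical driver for the table generator above. The result-dict shapes
# are inferred from table_generator.py's accessors; all values are made up.
from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.table_generator import (
    generate_all_tables,
)

all_results = {
    "timing": [
        {
            "strategy": "timing.llm_judge",  # hypothetical strategy name
            "metrics": {"accuracy": 0.96, "precision": 0.95, "recall": 0.94},
        }
    ],
    "planning": [
        {
            "strategy": "planner.react",  # hypothetical strategy name
            "metrics": {
                "plan_success_rate": 0.91,
                "step_accuracy": 0.88,
                "sequence_match": 0.85,
            },
        }
    ],
    "selection": [
        {
            "strategy": "selector.embedding",  # hypothetical strategy name
            "metrics": {"top_k_accuracy": 0.97, "mrr": 0.92, "recall_at_k": 0.95},
        }
    ],
}

# Writes table1_main_results.tex and table2_benchmark_details.tex under
# paper/tables/ (an arbitrary directory) and returns {table_name: path}.
generated = generate_all_tables(all_results, output_dir=Path("paper/tables"))
for name, path in generated.items():
    print(name, "->", path)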
sage/benchmark/benchmark_agent/tools_loader.py
@@ -0,0 +1,212 @@
"""
SAGE Tools Loader - Loads 1,200 real tools from tool_catalog.jsonl.

This module provides the SageToolsLoader class that loads tool definitions
from the SAGE benchmark's tool catalog and exposes them via a unified interface
compatible with the selector resources.
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterator

logger = logging.getLogger(__name__)

# Default path to tool catalog
TOOL_CATALOG_PATH = (
    Path(__file__).parent.parent.parent
    / "data"
    / "sources"
    / "agent_tools"
    / "data"
    / "tool_catalog.jsonl"
)


@dataclass
class SageTool:
    """
    Represents a tool from the SAGE tool catalog.

    Attributes:
        tool_id: Unique identifier (e.g., 'environment_weather_001')
        name: Human-readable name (e.g., 'Weather Fetch 1')
        description: Tool description (generated from name/category/capabilities)
        category: Tool category (e.g., 'environment/weather')
        capabilities: List of capabilities (e.g., ['forecast', 'radar'])
        inputs: List of input parameter definitions
        outputs: List of output field definitions
        metadata: Additional metadata (owner, version, etc.)
    """

    tool_id: str
    name: str
    description: str
    category: str
    capabilities: list[str] = field(default_factory=list)
    inputs: list[dict[str, Any]] = field(default_factory=list)
    outputs: list[dict[str, Any]] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_json(cls, data: dict[str, Any]) -> SageTool:
        """Create SageTool from JSON data."""
        # Generate description from available fields
        name = data.get("name", "")
        category = data.get("category", "")
        capabilities = data.get("capabilities", [])

        # Build description
        desc_parts = [name]
        if category:
            desc_parts.append(f"Category: {category}")
        if capabilities:
            desc_parts.append(f"Capabilities: {', '.join(capabilities)}")

        # Add input info
        inputs = data.get("inputs", [])
        if inputs:
            input_names = [inp.get("name", "") for inp in inputs if inp.get("name")]
            if input_names:
                desc_parts.append(f"Inputs: {', '.join(input_names[:5])}")

        description = ". ".join(desc_parts)

        return cls(
            tool_id=data.get("tool_id", ""),
            name=name,
            description=description,
            category=category,
            capabilities=capabilities,
            inputs=inputs,
            outputs=data.get("outputs", []),
            metadata=data.get("metadata", {}),
        )


class SageToolsLoader:
    """
    Loads tools from the SAGE benchmark tool catalog.

    This loader reads the tool_catalog.jsonl file and provides methods
    to iterate over tools and retrieve tools by ID.

    Usage:
        loader = SageToolsLoader()
        for tool in loader.iter_all():
            print(tool.tool_id, tool.name)

        tool = loader.get_tool('environment_weather_001')
    """

    def __init__(self, catalog_path: str | Path | None = None):
        """
        Initialize the loader.

        Args:
            catalog_path: Path to tool_catalog.jsonl. If None, uses default path.
        """
        self.catalog_path = Path(catalog_path) if catalog_path else TOOL_CATALOG_PATH
        self._tools: dict[str, SageTool] = {}
        self._loaded = False

    def _ensure_loaded(self) -> None:
        """Load tools if not already loaded."""
        if self._loaded:
            return

        if not self.catalog_path.exists():
            logger.warning(f"Tool catalog not found: {self.catalog_path}")
            self._loaded = True
            return

        try:
            with open(self.catalog_path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        tool = SageTool.from_json(data)
                        self._tools[tool.tool_id] = tool
                    except json.JSONDecodeError as e:
                        logger.warning(f"Failed to parse tool JSON: {e}")
                        continue

            logger.info(f"Loaded {len(self._tools)} tools from {self.catalog_path}")
        except Exception as e:
            logger.error(f"Failed to load tool catalog: {e}")

        self._loaded = True

    def iter_all(self) -> Iterator[SageTool]:
        """
        Iterate over all tools.

        Yields:
            SageTool instances
        """
        self._ensure_loaded()
        yield from self._tools.values()

    def get_tool(self, tool_id: str) -> SageTool:
        """
        Get a tool by ID.

        Args:
            tool_id: Tool identifier

        Returns:
            SageTool instance

        Raises:
            KeyError: If tool not found
        """
        self._ensure_loaded()
        if tool_id not in self._tools:
            raise KeyError(f"Tool not found: {tool_id}")
        return self._tools[tool_id]

    def get_tools(self, tool_ids: list[str]) -> list[SageTool]:
        """
        Get multiple tools by IDs.

        Args:
            tool_ids: List of tool identifiers

        Returns:
            List of SageTool instances (skips missing tools)
        """
        self._ensure_loaded()
        tools = []
        for tid in tool_ids:
            if tid in self._tools:
                tools.append(self._tools[tid])
        return tools

    def __len__(self) -> int:
        """Return number of tools."""
        self._ensure_loaded()
        return len(self._tools)

    def __contains__(self, tool_id: str) -> bool:
        """Check if tool exists."""
        self._ensure_loaded()
        return tool_id in self._tools


# Singleton instance for convenience
_default_loader: SageToolsLoader | None = None


def get_sage_tools_loader() -> SageToolsLoader:
    """Get the default SageToolsLoader instance."""
    global _default_loader
    if _default_loader is None:
        _default_loader = SageToolsLoader()
    return _default_loader