isage-benchmark-agent 0.1.0.1 (cp311-none-any.whl)
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py
@@ -0,0 +1,250 @@

```python
"""
Tool Selection Experiment

Experiment runner for evaluating tool selection capabilities.
Tests the ability to select relevant tools from a candidate set given a user instruction.
"""

from typing import Any

from sage.benchmark.benchmark_agent.experiments.base_experiment import (
    BaseExperiment,
    ExperimentResult,
    ToolSelectionConfig,
)


class ToolSelectionQuery:
    """Query for tool selection."""

    def __init__(
        self, sample_id: str, instruction: str, context: dict[str, Any], candidate_tools: list[str]
    ):
        self.sample_id = sample_id
        self.instruction = instruction
        self.context = context
        self.candidate_tools = candidate_tools
```
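For reference, a query built from illustrative values (the field names match `ToolSelectionQuery` above; the sample content is invented):

```python
# Illustrative values only; only the field names come from the class above.
query = ToolSelectionQuery(
    sample_id="ts-0001",
    instruction="Convert 100 USD to EUR and email me the result",
    context={"description": "finance assistant session"},
    candidate_tools=["currency_convert", "send_email", "web_search", "calendar_add"],
)
```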
```python
class ToolSelectionExperiment(BaseExperiment):
    """
    Experiment for tool selection evaluation.

    Workflow:
    1. Load benchmark samples from DataManager
    2. For each sample, call selector strategy to get predictions
    3. Collect predictions and ground truth
    4. Return ExperimentResult for evaluation
    """

    def __init__(
        self, config: ToolSelectionConfig, data_manager: Any = None, adapter_registry: Any = None
    ):
        """
        Initialize tool selection experiment.

        Args:
            config: Tool selection configuration
            data_manager: DataManager for data loading
            adapter_registry: Registry containing selector strategies
        """
        super().__init__(config, data_manager, adapter_registry)
        self.config: ToolSelectionConfig = config
        self._embedding_client = None

    def _create_embedding_client(self):
        """Create embedding client for selector."""
        # Try to create an embedding client wrapper
        try:
            import os

            from sage.common.components.sage_embedding import EmbeddingService

            # Determine embedding method and optional model from environment
            method = os.environ.get("SAGE_EMBEDDING_METHOD") or "hf"
            model = os.environ.get("SAGE_EMBEDDING_MODEL") or None
            api_key = os.environ.get("SAGE_EMBEDDING_API_KEY") or None
            base_url = os.environ.get("SAGE_EMBEDDING_BASE_URL") or None

            # Create a simple embedding client wrapper that uses the configured service
            class EmbeddingClientWrapper:
                """Wrapper to adapt EmbeddingService to the selector interface."""

                def __init__(self):
                    cfg = {
                        "method": method,
                        "model": model,
                        "api_key": api_key,
                        "base_url": base_url,
                        "normalize": True,
                    }
                    try:
                        self.service = EmbeddingService(cfg)
                        self.service.setup()
                    except Exception:
                        # Fall back to mockembedder if the configured method is not available
                        cfg["method"] = "mockembedder"
                        self.service = EmbeddingService(cfg)
                        self.service.setup()

                def embed(self, texts, model=None, batch_size=32):
                    """Embed texts and return a numpy array."""
                    import numpy as np

                    result = self.service.embed(texts, batch_size=batch_size)
                    return np.array(result["vectors"])

                def cleanup(self):
                    try:
                        self.service.cleanup()
                    except Exception:
                        pass

            return EmbeddingClientWrapper()
        except ImportError:
            return None
        except Exception as e:
            print(f"Warning: Could not create embedding client: {e}")
            return None
```
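`_create_embedding_client` is configured entirely through `SAGE_EMBEDDING_*` environment variables. A minimal sketch of pointing it at an OpenAI-compatible endpoint; only the variable names come from the code above, while the method, model, and URL values are placeholders and must match whatever `EmbeddingService` (not shown in this diff) actually accepts:

```python
import os

# Placeholder values; only the variable names are taken from the code above.
os.environ["SAGE_EMBEDDING_METHOD"] = "openai"  # the code falls back to "hf" when unset
os.environ["SAGE_EMBEDDING_MODEL"] = "text-embedding-3-small"
os.environ["SAGE_EMBEDDING_API_KEY"] = "sk-..."
os.environ["SAGE_EMBEDDING_BASE_URL"] = "https://api.example.com/v1"
```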
```python
    def prepare(self):
        """Prepare experiment: load data and initialize selector."""
        super().prepare()

        verbose = getattr(self.config, "verbose", False)
        if verbose:
            print(f"\n{'=' * 60}")
            print(f"Tool Selection Experiment: {self.experiment_id}")
            print(f"{'=' * 60}")
            print(f"Profile: {self.config.profile}")
            print(f"Split: {self.config.split}")
            print(f"Selector: {self.config.selector}")
            print(f"Top-k: {self.config.top_k}")

        # Load data through DataManager
        try:
            agent_eval = self.dm.get_by_usage("agent_eval")
            profile_data = agent_eval.load_profile(self.config.profile)

            self.benchmark_loader = profile_data.get("benchmark")
            self.tools_loader = profile_data.get("tools")

            if verbose:
                print("✓ Loaded benchmark data")

        except Exception as e:
            print(f"Warning: Could not load data: {e}")

        # Initialize selector strategy with real tools data
        if self.adapter_registry is not None:
            try:
                # Create resources with real tools loader
                from sage.libs.agentic.agents.action.tool_selection import (
                    SelectorResources,
                )

                # Create embedding client if selector needs it
                embedding_client = None
                selector_name = self.config.selector.lower()
                if "embedding" in selector_name or "hybrid" in selector_name:
                    embedding_client = self._create_embedding_client()
                    self._embedding_client = embedding_client
                    if verbose and embedding_client:
                        print("✓ Initialized embedding client")

                resources = SelectorResources(
                    tools_loader=self.tools_loader,
                    embedding_client=embedding_client,
                )

                self.strategy = self.adapter_registry.get(self.config.selector, resources=resources)
                if verbose:
                    print(f"✓ Initialized selector: {self.config.selector}")
            except Exception as e:
                print(f"Warning: Could not load selector: {e}")
                self.strategy = None
        else:
            self.strategy = None
```
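`prepare()` resolves the strategy through `adapter_registry.get(name, resources=...)`, and `run()` (below) only relies on the strategy exposing `predict(query, top_k)` and returning objects with `tool_id` and `score` attributes. A hypothetical minimal strategy that satisfies this contract; the real selectors live in `sage.libs.agentic` and are not shown in this diff:

```python
from dataclasses import dataclass


@dataclass
class ScoredTool:
    # Hypothetical container; run() only reads .tool_id and .score.
    tool_id: str
    score: float


class KeywordOverlapSelector:
    """Toy strategy: rank candidate tools by token overlap with the instruction."""

    def predict(self, query, top_k: int = 5):
        tokens = set(query.instruction.lower().split())
        scored = [
            ScoredTool(tool_id=tool, score=float(len(tokens & set(tool.lower().split("_")))))
            for tool in query.candidate_tools
        ]
        scored.sort(key=lambda s: s.score, reverse=True)
        return scored[:top_k]
```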
```python
    def run(self) -> ExperimentResult:
        """
        Run tool selection experiment.

        Returns:
            ExperimentResult with predictions and references
        """
        verbose = getattr(self.config, "verbose", False)
        if verbose:
            print("\nRunning experiment...")

        predictions = []
        references = []
        metadata = {"total_samples": 0, "failed_samples": 0}

        try:
            samples = self.benchmark_loader.iter_split(
                task_type="tool_selection", split=self.config.split
            )

            for idx, sample in enumerate(samples):
                if self.config.max_samples and idx >= self.config.max_samples:
                    break

                metadata["total_samples"] += 1

                try:
                    # Handle context - may be string or dict
                    context = sample.context if hasattr(sample, "context") else {}
                    if isinstance(context, str):
                        context = {"description": context}
                    elif context is None:
                        context = {}

                    query = ToolSelectionQuery(
                        sample_id=sample.sample_id,
                        instruction=sample.instruction,
                        context=context,
                        candidate_tools=sample.candidate_tools,
                    )

                    if self.strategy is not None:
                        pred_tools = self.strategy.predict(query, top_k=self.config.top_k)
                        pred_dict = {
                            "sample_id": sample.sample_id,
                            "predicted_tools": [
                                {"tool_id": p.tool_id, "score": p.score} for p in pred_tools
                            ],
                        }
                    else:
                        pred_dict = {"sample_id": sample.sample_id, "predicted_tools": []}

                    predictions.append(pred_dict)

                    gt = sample.get_typed_ground_truth()
                    ref_dict = {
                        "sample_id": sample.sample_id,
                        "ground_truth_tools": gt.top_k,
                        "explanation": gt.explanation if hasattr(gt, "explanation") else None,
                    }
                    references.append(ref_dict)

                    if verbose and (idx + 1) % 10 == 0:
                        print(f"  Processed {idx + 1} samples...")

                except Exception as e:
                    metadata["failed_samples"] += 1
                    if verbose:
                        print(f"  Error processing sample {idx}: {e}")
                    continue

        except Exception as e:
            print(f"Error iterating samples: {e}")

        if verbose:
            print("\nCompleted:")
            print(f"  Total samples: {metadata['total_samples']}")
            print(f"  Failed samples: {metadata['failed_samples']}")

        return self._create_result(
            predictions=predictions, references=references, metadata=metadata
        )
```
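Put together, a driver might look like the sketch below. The `ToolSelectionConfig` fields mirror the attributes referenced above (`profile`, `split`, `selector`, `top_k`, `max_samples`, `verbose`), but the constructor is defined in `base_experiment.py`, and `dm`/`registry` are assumed to be provided by the SAGE runtime, so treat this as a sketch against that interface rather than a verbatim recipe:

```python
# Sketch only: ToolSelectionConfig's exact signature lives in base_experiment.py;
# dm and registry are assumed to come from the SAGE runtime, and the selector
# name is a placeholder that must exist in the adapter registry.
config = ToolSelectionConfig(
    profile="default",
    split="test",
    selector="embedding_selector",
    top_k=5,
    max_samples=100,
    verbose=True,
)

exp = ToolSelectionExperiment(config, data_manager=dm, adapter_registry=registry)
exp.prepare()       # load data, resolve the selector strategy
result = exp.run()  # ExperimentResult with aligned predictions and references

# Assuming ExperimentResult exposes the metadata dict passed to _create_result:
print(result.metadata["failed_samples"], "of", result.metadata["total_samples"], "failed")
```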
sage/benchmark/benchmark_agent/scripts/__init__.py
@@ -0,0 +1,26 @@

```python
"""
SAGE-Bench Scripts

Unified entry point for the benchmark experiment scripts.

All experiment functionality lives in the experiments/ subpackage:
- run_paper1_experiments.py: unified entry point for the Paper 1 experiments
- exp_main_*.py: Section 5.2 main experiments
- exp_analysis_*.py: Section 5.3 analysis experiments
- exp_cross_dataset.py: Section 5.4 cross-dataset generalization
- exp_training_comparison.py: Section 5.5 training method comparison

Usage:
    # CLI entry point
    sage-bench run --quick
    sage-bench eval --dataset all
    sage-bench train --dry-run
    sage-bench llm status

    # Run directly
    python -m sage.benchmark.benchmark_agent.scripts.experiments.run_paper1_experiments --quick
"""

__all__ = [
    "experiments",
]
```
sage/benchmark/benchmark_agent/scripts/experiments/__init__.py
@@ -0,0 +1,40 @@

```python
"""
SAGE-Bench Paper 1 Experiments Package

Experiment scripts organized by paper experiment section:

- Section 5.2 (Main Results):
    - exp_main_timing.py          # RQ1: Timing Detection
    - exp_main_planning.py        # RQ2: Task Planning
    - exp_main_selection.py       # RQ3: Tool Selection

- Section 5.3 (Analysis):
    - exp_analysis_error.py       # 5.3.1 Error Analysis
    - exp_analysis_scaling.py     # 5.3.2 Scaling Analysis
    - exp_analysis_robustness.py  # 5.3.3 Robustness Analysis
    - exp_analysis_ablation.py    # 5.3.4 Ablation Studies

- Section 5.4 (Generalization):
    - exp_cross_dataset.py        # Cross-dataset evaluation

Usage:
    sage-bench paper1 run                # run all experiments
    sage-bench paper1 run --section 5.2  # main experiments only
    sage-bench paper1 timing             # a single experiment
"""

from .exp_utils import (
    get_embedding_client,
    get_llm_client,
    load_benchmark_data,
    save_results,
    setup_experiment_env,
)

__all__ = [
    "setup_experiment_env",
    "load_benchmark_data",
    "save_results",
    "get_llm_client",
    "get_embedding_client",
]
```
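The re-exported helpers can then be imported from the package root. The call forms below are assumptions, since `exp_utils.py`'s body is not shown in this diff; only the names come from the `__init__.py` above:

```python
from sage.benchmark.benchmark_agent.scripts.experiments import (
    load_benchmark_data,
    setup_experiment_env,
)

# Assumed zero-argument call forms; see exp_utils.py for the actual signatures.
setup_experiment_env()
data = load_benchmark_data()
```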