isage_benchmark_agent-0.1.0.1-cp311-none-any.whl
This diff represents the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
- isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
- isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
- isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
- isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
- isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
- sage/__init__.py +0 -0
- sage/benchmark/__init__.py +0 -0
- sage/benchmark/benchmark_agent/__init__.py +108 -0
- sage/benchmark/benchmark_agent/__main__.py +177 -0
- sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
- sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
- sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
- sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
- sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
- sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
- sage/benchmark/benchmark_agent/data_paths.py +332 -0
- sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
- sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
- sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
- sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
- sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
- sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
- sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
- sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
- sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
- sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
- sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
- sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
- sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
- sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
- sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
- sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
- sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
- sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
- sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
- sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
- sage/benchmark/benchmark_agent/tools_loader.py +212 -0
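The listed modules map one-to-one onto importable paths once the wheel is installed. The smoke test below is hypothetical and assumes the wheel and its dependencies are installed; it uses only module paths taken from the RECORD listing above.

# Hypothetical smoke test: module paths come straight from the listing above;
# nothing here relies on undocumented APIs.
import importlib

for module_path in (
    "sage.benchmark.benchmark_agent",
    "sage.benchmark.benchmark_agent.evaluation.unified_tool_selection",
    "sage.benchmark.benchmark_agent.experiments",
):
    importlib.import_module(module_path)
    print("imported:", module_path)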
sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py
@@ -0,0 +1,602 @@
"""
Unified Tool Selection Evaluation Framework

This module provides a unified interface for evaluating tool selection methods
across different benchmarks (SAGE, ACEBench, APIBench, etc.).

Design Principles:
- All methods implement the same ToolSelectorProtocol
- All benchmarks are converted to unified ToolSelectionSample format
- Same metrics (Top-K Accuracy, MRR, Recall@K) for all comparisons

SOTA Practice (from Gorilla, ToolACE papers):
- Input: Query + Candidate Tools (tool corpus)
- Output: Ranked list of selected tools
- Metrics: Accuracy, Recall@K, MRR (Mean Reciprocal Rank)
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional, Protocol, runtime_checkable

logger = logging.getLogger(__name__)


# =============================================================================
# 1. Unified Data Format
# =============================================================================


@dataclass
class Tool:
    """Unified tool representation."""

    id: str  # Unique identifier
    name: str  # Display name
    description: str  # Tool description
    parameters: dict[str, Any] = field(default_factory=dict)  # Optional params schema

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        if isinstance(other, Tool):
            return self.id == other.id
        return False


@dataclass
class ToolSelectionSample:
    """
    Unified sample format for tool selection evaluation.

    Supports conversion from SAGE, ACEBench, and other formats.
    """

    sample_id: str
    instruction: str  # User query/instruction
    candidate_tools: list[Tool]  # Tool corpus to select from
    ground_truth: list[str]  # List of correct tool IDs
    context: dict[str, Any] = field(default_factory=dict)  # Optional context

    @classmethod
    def from_sage(cls, sample: dict[str, Any]) -> ToolSelectionSample:
        """Convert SAGE benchmark format to unified format."""
        ground_truth_raw = sample.get("ground_truth", [])
        if isinstance(ground_truth_raw, dict):
            ground_truth = ground_truth_raw.get("top_k", [])
        else:
            ground_truth = ground_truth_raw

        # Convert candidate_tools to Tool objects
        candidate_tools = []
        for t in sample.get("candidate_tools", []):
            if isinstance(t, dict):
                candidate_tools.append(
                    Tool(
                        id=str(t.get("id", t.get("name", ""))),
                        name=str(t.get("name", t.get("id", ""))),
                        description=t.get("description", ""),
                        parameters=t.get("parameters", {}),
                    )
                )
            elif isinstance(t, str):
                candidate_tools.append(Tool(id=t, name=t, description=""))

        return cls(
            sample_id=sample.get("sample_id", ""),
            instruction=sample.get("instruction", ""),
            candidate_tools=candidate_tools,
            ground_truth=ground_truth if ground_truth is not None else [],
            context=sample.get("context", {}),
        )

    @classmethod
    def from_acebench(cls, sample: dict[str, Any]) -> ToolSelectionSample:
        """Convert ACEBench/ToolACE format to unified format."""
        # ACEBench format: instruction, tools (list of tool dicts), ground_truth
        candidate_tools = []
        for t in sample.get("tools", sample.get("candidate_tools", [])):
            if isinstance(t, dict):
                candidate_tools.append(
                    Tool(
                        id=str(t.get("name", t.get("id", ""))),
                        name=str(t.get("name", t.get("id", ""))),
                        description=t.get("description", ""),
                        parameters=t.get("parameters", {}),
                    )
                )
            elif isinstance(t, str):
                candidate_tools.append(Tool(id=t, name=t, description=""))

        ground_truth = sample.get("ground_truth", {})
        if isinstance(ground_truth, dict):
            gt_tools = ground_truth.get("top_k", ground_truth.get("tools", []))
        elif isinstance(ground_truth, list):
            gt_tools = ground_truth
        else:
            gt_tools = []

        return cls(
            sample_id=sample.get("sample_id", ""),
            instruction=sample.get("instruction", ""),
            candidate_tools=candidate_tools,
            ground_truth=gt_tools if gt_tools is not None else [],
            context=sample.get("context", {}),
        )


@dataclass
class SelectionResult:
    """Result from a tool selection method."""

    tool_ids: list[str]  # Ranked list of selected tool IDs
    scores: list[float] = field(default_factory=list)  # Optional confidence scores
    metadata: dict[str, Any] = field(default_factory=dict)  # Method-specific metadata


# =============================================================================
# 2. Unified Selector Protocol
# =============================================================================


@runtime_checkable
class ToolSelectorProtocol(Protocol):
    """
    Protocol that all tool selection methods must implement.

    This ensures consistency across:
    - Retrieval methods (keyword, embedding)
    - LLM-based methods (direct generation, Gorilla)
    - Hybrid methods (retrieval + reranking)
    """

    def select(
        self,
        query: str,
        candidate_tools: list[Tool],
        top_k: int = 5,
    ) -> SelectionResult:
        """
        Select top_k tools from candidates for the given query.

        Args:
            query: User query/instruction
            candidate_tools: List of candidate tools to choose from
            top_k: Number of tools to select

        Returns:
            SelectionResult with ranked tool IDs
        """
        ...


# =============================================================================
# 3. Base Selector Implementations
# =============================================================================


class BaseSelectorAdapter(ABC):
    """Base class for selector adapters."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Return selector name for logging/reporting."""
        ...

    @abstractmethod
    def select(
        self,
        query: str,
        candidate_tools: list[Tool],
        top_k: int = 5,
    ) -> SelectionResult:
        """Select tools."""
        ...


class KeywordSelectorAdapter(BaseSelectorAdapter):
    """Adapter for keyword/BM25-based selector."""

    def __init__(self, selector: Any = None):
        self._selector = selector

    @property
    def name(self) -> str:
        return "keyword"

    def select(
        self,
        query: str,
        candidate_tools: list[Tool],
        top_k: int = 5,
    ) -> SelectionResult:
        if self._selector is None:
            self._init_selector()

        # Pass tool IDs only (schema expects list[str])
        tool_ids_input = [t.id for t in candidate_tools]
        results = self._selector.select(query, tool_ids_input, top_k=top_k)

        tool_ids = [r.tool_id if hasattr(r, "tool_id") else str(r) for r in results]
        scores = [r.score if hasattr(r, "score") else 1.0 for r in results]

        return SelectionResult(tool_ids=tool_ids, scores=scores)

    def _init_selector(self):
        # Use adapter_registry which handles resources correctly
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
        self._selector = registry.get("selector.keyword")


class EmbeddingSelectorAdapter(BaseSelectorAdapter):
    """Adapter for embedding-based selector."""

    def __init__(self, selector: Any = None):
        self._selector = selector

    @property
    def name(self) -> str:
        return "embedding"

    def select(
        self,
        query: str,
        candidate_tools: list[Tool],
        top_k: int = 5,
    ) -> SelectionResult:
        if self._selector is None:
            self._init_selector()

        # Pass tool IDs only (schema expects list[str])
        tool_ids_input = [t.id for t in candidate_tools]
        results = self._selector.select(query, tool_ids_input, top_k=top_k)

        tool_ids = [r.tool_id if hasattr(r, "tool_id") else str(r) for r in results]
        scores = [r.score if hasattr(r, "score") else 1.0 for r in results]

        return SelectionResult(tool_ids=tool_ids, scores=scores)

    def _init_selector(self):
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
        self._selector = registry.get("selector.embedding")


class HybridSelectorAdapter(BaseSelectorAdapter):
    """Adapter for hybrid (keyword + embedding) selector."""

    def __init__(self, selector: Any = None):
        self._selector = selector

    @property
    def name(self) -> str:
        return "hybrid"

    def select(
        self,
        query: str,
        candidate_tools: list[Tool],
        top_k: int = 5,
    ) -> SelectionResult:
        if self._selector is None:
            self._init_selector()

        # Pass tool IDs only (schema expects list[str])
        tool_ids_input = [t.id for t in candidate_tools]
        results = self._selector.select(query, tool_ids_input, top_k=top_k)

        tool_ids = [r.tool_id if hasattr(r, "tool_id") else str(r) for r in results]
        scores = [r.score if hasattr(r, "score") else 1.0 for r in results]

        return SelectionResult(tool_ids=tool_ids, scores=scores)

    def _init_selector(self):
        from sage.benchmark.benchmark_agent import get_adapter_registry

        registry = get_adapter_registry()
        self._selector = registry.get("selector.hybrid")


class LLMDirectSelectorAdapter(BaseSelectorAdapter):
    """
    LLM-based tool selection (direct generation).

    This is the approach used in ACEBench evaluation - let LLM directly
    choose the best tool based on the query and tool descriptions.
    """

    def __init__(
        self,
        llm_client: Any = None,
        model_id: Optional[str] = None,
        use_embedded: bool = False,
    ):
        self._llm_client = llm_client
        self._model_id = model_id
        self._use_embedded = use_embedded

    @property
    def name(self) -> str:
        return "llm_direct"

    def select(
        self,
        query: str,
        candidate_tools: list[Tool],
        top_k: int = 5,
    ) -> SelectionResult:
        if self._llm_client is None:
            self._init_client()

        # Build prompt
        prompt = self._build_prompt(query, candidate_tools, top_k)

        # Generate response
        try:
            response = self._llm_client.chat([{"role": "user", "content": prompt}])
        except Exception as e:
            logger.warning(f"LLM generation failed: {e}")
            return SelectionResult(tool_ids=[])

        # Parse response
        tool_ids = self._parse_response(response, candidate_tools)

        return SelectionResult(tool_ids=tool_ids[:top_k])

    def _init_client(self):
        from sage.llm import UnifiedInferenceClient

        # UnifiedInferenceClient.create() handles local/cloud detection
        self._llm_client = UnifiedInferenceClient.create()

    def _build_prompt(self, query: str, candidate_tools: list[Tool], top_k: int) -> str:
        """Build prompt for LLM tool selection."""
        tools_desc = "\n".join(f"- {t.name}: {t.description}" for t in candidate_tools)

        return f"""You are a tool selection assistant. Given a user query and a list of available tools,
select the {top_k} most relevant tool(s) for the query.

Available Tools:
{tools_desc}

User Query: {query}

Respond with ONLY the tool name(s), one per line, in order of relevance.
Do not include any explanation or additional text."""

    def _parse_response(self, response: str, candidate_tools: list[Tool]) -> list[str]:
        """Parse LLM response to extract tool names."""
        tool_names = {t.name.lower(): t.id for t in candidate_tools}
        tool_ids_map = {t.id.lower(): t.id for t in candidate_tools}

        selected = []
        for line in response.strip().split("\n"):
            line = line.strip().strip("-").strip("*").strip()
            line_lower = line.lower()

            # Try exact match
            if line_lower in tool_names:
                selected.append(tool_names[line_lower])
            elif line_lower in tool_ids_map:
                selected.append(tool_ids_map[line_lower])
            else:
                # Try partial match
                for name, tool_id in tool_names.items():
                    if name in line_lower or line_lower in name:
                        if tool_id not in selected:
                            selected.append(tool_id)
                        break

        return selected


# =============================================================================
# 4. Unified Metrics
# =============================================================================


@dataclass
class EvaluationMetrics:
    """Unified evaluation metrics."""

    # Core metrics
    top_1_accuracy: float = 0.0
    top_3_accuracy: float = 0.0
    top_5_accuracy: float = 0.0
    mrr: float = 0.0  # Mean Reciprocal Rank

    # Additional metrics
    recall_at_k: dict[int, float] = field(default_factory=dict)
    precision_at_k: dict[int, float] = field(default_factory=dict)

    # Metadata
    total_samples: int = 0
    method_name: str = ""


def compute_metrics(
    predictions: list[SelectionResult],
    references: list[list[str]],  # Ground truth tool IDs for each sample
    k_values: list[int] | None = None,
) -> EvaluationMetrics:
    """
    Compute unified evaluation metrics.

    Args:
        predictions: List of SelectionResult from selector
        references: List of ground truth tool ID lists
        k_values: K values for Top-K and Recall@K metrics

    Returns:
        EvaluationMetrics with all computed metrics
    """
    if k_values is None:
        k_values = [1, 3, 5]

    if len(predictions) != len(references):
        raise ValueError(
            f"Predictions ({len(predictions)}) and references ({len(references)}) "
            "must have same length"
        )

    n = len(predictions)
    if n == 0:
        return EvaluationMetrics()

    # Initialize counters
    top_k_correct = dict.fromkeys(k_values, 0)
    reciprocal_ranks: list[float] = []
    recall_at_k: dict[int, list[float]] = {k: [] for k in k_values}
    precision_at_k: dict[int, list[float]] = {k: [] for k in k_values}

    for pred, ref in zip(predictions, references):
        pred_ids = pred.tool_ids
        ref_set = set(ref)

        # Top-K accuracy: any correct tool in top-k predictions
        for k in k_values:
            top_k_preds = set(pred_ids[:k])
            if top_k_preds & ref_set:
                top_k_correct[k] += 1

        # MRR: position of first correct prediction
        rr = 0.0
        for i, tool_id in enumerate(pred_ids):
            if tool_id in ref_set:
                rr = 1.0 / (i + 1)
                break
        reciprocal_ranks.append(rr)

        # Recall@K and Precision@K
        for k in k_values:
            top_k_preds = set(pred_ids[:k])
            hits = len(top_k_preds & ref_set)

            recall = hits / len(ref_set) if ref_set else 0.0
            precision = hits / k

            recall_at_k[k].append(recall)
            precision_at_k[k].append(precision)

    # Compute averages
    metrics = EvaluationMetrics(
        top_1_accuracy=top_k_correct.get(1, 0) / n,
        top_3_accuracy=top_k_correct.get(3, 0) / n,
        top_5_accuracy=top_k_correct.get(5, 0) / n,
        mrr=sum(reciprocal_ranks) / n,
        recall_at_k={k: sum(v) / n for k, v in recall_at_k.items()},
        precision_at_k={k: sum(v) / n for k, v in precision_at_k.items()},
        total_samples=n,
    )

    return metrics


# =============================================================================
# 5. Unified Evaluator
# =============================================================================


class UnifiedToolSelectionEvaluator:
    """
    Unified evaluator for tool selection methods.

    Supports multiple benchmarks and methods with consistent evaluation.
    """

    def __init__(self):
        self.selectors: dict[str, BaseSelectorAdapter] = {}

    def register_selector(self, name: str, selector: BaseSelectorAdapter):
        """Register a selector for evaluation."""
        self.selectors[name] = selector

    def register_default_selectors(self, use_embedded_llm: bool = False):
        """Register default set of selectors."""
        self.selectors["keyword"] = KeywordSelectorAdapter()
        self.selectors["embedding"] = EmbeddingSelectorAdapter()
        self.selectors["hybrid"] = HybridSelectorAdapter()
        self.selectors["llm_direct"] = LLMDirectSelectorAdapter(use_embedded=use_embedded_llm)

    def evaluate(
        self,
        samples: list[ToolSelectionSample],
        selector_names: Optional[list[str]] = None,
        top_k: int = 5,
    ) -> dict[str, EvaluationMetrics]:
        """
        Evaluate selectors on samples.

        Args:
            samples: List of unified ToolSelectionSample
            selector_names: Names of selectors to evaluate (None = all)
            top_k: Top-K value for selection

        Returns:
            Dict mapping selector name to EvaluationMetrics
        """
        if selector_names is None:
            selector_names = list(self.selectors.keys())

        results = {}

        for name in selector_names:
            if name not in self.selectors:
                logger.warning(f"Selector '{name}' not registered, skipping")
                continue

            selector = self.selectors[name]
            logger.info(f"Evaluating selector: {name}")

            predictions = []
            references = []

            for sample in samples:
                try:
                    result = selector.select(
                        query=sample.instruction,
                        candidate_tools=sample.candidate_tools,
                        top_k=top_k,
                    )
                    predictions.append(result)
                    references.append(sample.ground_truth)
                except Exception as e:
                    logger.warning(f"Selector {name} failed on sample {sample.sample_id}: {e}")
                    predictions.append(SelectionResult(tool_ids=[]))
                    references.append(sample.ground_truth)

            metrics = compute_metrics(predictions, references)
            metrics.method_name = name
            results[name] = metrics

        return results

    def print_results(self, results: dict[str, EvaluationMetrics]):
        """Print evaluation results in a table format."""
        print("\n" + "=" * 80)
        print("Tool Selection Evaluation Results")
        print("=" * 80)
        print(
            f"{'Method':<15} {'Top-1':>10} {'Top-3':>10} {'Top-5':>10} {'MRR':>10} {'Samples':>10}"
        )
        print("-" * 80)

        for name, metrics in results.items():
            print(
                f"{name:<15} "
                f"{metrics.top_1_accuracy * 100:>9.1f}% "
                f"{metrics.top_3_accuracy * 100:>9.1f}% "
                f"{metrics.top_5_accuracy * 100:>9.1f}% "
                f"{metrics.mrr * 100:>9.1f}% "
                f"{metrics.total_samples:>10}"
            )

        print("=" * 80)
sage/benchmark/benchmark_agent/experiments/__init__.py
@@ -0,0 +1,63 @@
"""
Experiment implementations for agent benchmark evaluation.

Available experiments:
- ToolSelectionExperiment: Tool retrieval and ranking
- PlanningExperiment: Multi-step planning with tool composition
- TimingDetectionExperiment: Timing judgment for tool invocation
"""

from sage.benchmark.benchmark_agent.experiments.base_experiment import (
    CONFIG_TYPES,
    BaseExperiment,
    ExperimentConfig,
    ExperimentResult,
    PlanningConfig,
    PlanningPrediction,
    PlanStep,
    ReportConfig,
    TimingDecision,
    TimingDetectionConfig,
    ToolPrediction,
    ToolSelectionConfig,
    create_config,
)
from sage.benchmark.benchmark_agent.experiments.planning_exp import (
    PlanningExperiment,
    PlanningTask,
)
from sage.benchmark.benchmark_agent.experiments.timing_detection_exp import (
    TimingDetectionExperiment,
    TimingMessage,
)
from sage.benchmark.benchmark_agent.experiments.tool_selection_exp import (
    ToolSelectionExperiment,
    ToolSelectionQuery,
)

__all__ = [
    # Base classes
    "BaseExperiment",
    "ExperimentConfig",
    "ExperimentResult",
    # Config models
    "ToolSelectionConfig",
    "PlanningConfig",
    "TimingDetectionConfig",
    "ReportConfig",
    # Result/task models
    "ToolPrediction",
    "PlanStep",
    "PlanningPrediction",
    "TimingDecision",
    "ToolSelectionQuery",
    "PlanningTask",
    "TimingMessage",
    # Utilities
    "CONFIG_TYPES",
    "create_config",
    # Experiment implementations
    "ToolSelectionExperiment",
    "PlanningExperiment",
    "TimingDetectionExperiment",
]
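Since this __init__.py only re-exports the experiment API, downstream code imports everything from the package namespace rather than the individual submodules. The constructor signatures of the experiment classes live in modules not reproduced in this excerpt, so the hypothetical example below only touches the re-exported names.

# Hypothetical: verifies the re-exported surface without constructing any
# experiment (constructor signatures are not shown in this diff).
from sage.benchmark.benchmark_agent.experiments import (
    CONFIG_TYPES,
    PlanningExperiment,
    TimingDetectionExperiment,
    ToolSelectionExperiment,
    create_config,
)

print(CONFIG_TYPES)
print(ToolSelectionExperiment.__module__)    # ...experiments.tool_selection_exp
print(PlanningExperiment.__module__)         # ...experiments.planning_exp
print(TimingDetectionExperiment.__module__)  # ...experiments.timing_detection_exp
print(create_config.__module__)              # ...experiments.base_experiment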