isage-benchmark-agent 0.1.0.1 (isage_benchmark_agent-0.1.0.1-cp311-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py
@@ -0,0 +1,602 @@
+"""
+Unified Tool Selection Evaluation Framework
+
+This module provides a unified interface for evaluating tool selection methods
+across different benchmarks (SAGE, ACEBench, APIBench, etc.).
+
+Design Principles:
+- All methods implement the same ToolSelectorProtocol
+- All benchmarks are converted to unified ToolSelectionSample format
+- Same metrics (Top-K Accuracy, MRR, Recall@K) for all comparisons
+
+SOTA Practice (from Gorilla, ToolACE papers):
+- Input: Query + Candidate Tools (tool corpus)
+- Output: Ranked list of selected tools
+- Metrics: Accuracy, Recall@K, MRR (Mean Reciprocal Rank)
+"""
+
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Optional, Protocol, runtime_checkable
+
+logger = logging.getLogger(__name__)
+
+
+# =============================================================================
+# 1. Unified Data Format
+# =============================================================================
+
+
+@dataclass
+class Tool:
+    """Unified tool representation."""
+
+    id: str  # Unique identifier
+    name: str  # Display name
+    description: str  # Tool description
+    parameters: dict[str, Any] = field(default_factory=dict)  # Optional params schema
+
+    def __hash__(self):
+        return hash(self.id)
+
+    def __eq__(self, other):
+        if isinstance(other, Tool):
+            return self.id == other.id
+        return False
+
+
+@dataclass
+class ToolSelectionSample:
+    """
+    Unified sample format for tool selection evaluation.
+
+    Supports conversion from SAGE, ACEBench, and other formats.
+    """
+
+    sample_id: str
+    instruction: str  # User query/instruction
+    candidate_tools: list[Tool]  # Tool corpus to select from
+    ground_truth: list[str]  # List of correct tool IDs
+    context: dict[str, Any] = field(default_factory=dict)  # Optional context
+
+    @classmethod
+    def from_sage(cls, sample: dict[str, Any]) -> ToolSelectionSample:
+        """Convert SAGE benchmark format to unified format."""
+        ground_truth_raw = sample.get("ground_truth", [])
+        if isinstance(ground_truth_raw, dict):
+            ground_truth = ground_truth_raw.get("top_k", [])
+        else:
+            ground_truth = ground_truth_raw
+
+        # Convert candidate_tools to Tool objects
+        candidate_tools = []
+        for t in sample.get("candidate_tools", []):
+            if isinstance(t, dict):
+                candidate_tools.append(
+                    Tool(
+                        id=str(t.get("id", t.get("name", ""))),
+                        name=str(t.get("name", t.get("id", ""))),
+                        description=t.get("description", ""),
+                        parameters=t.get("parameters", {}),
+                    )
+                )
+            elif isinstance(t, str):
+                candidate_tools.append(Tool(id=t, name=t, description=""))
+
+        return cls(
+            sample_id=sample.get("sample_id", ""),
+            instruction=sample.get("instruction", ""),
+            candidate_tools=candidate_tools,
+            ground_truth=ground_truth if ground_truth is not None else [],
+            context=sample.get("context", {}),
+        )
+
+    @classmethod
+    def from_acebench(cls, sample: dict[str, Any]) -> ToolSelectionSample:
+        """Convert ACEBench/ToolACE format to unified format."""
+        # ACEBench format: instruction, tools (list of tool dicts), ground_truth
+        candidate_tools = []
+        for t in sample.get("tools", sample.get("candidate_tools", [])):
+            if isinstance(t, dict):
+                candidate_tools.append(
+                    Tool(
+                        id=str(t.get("name", t.get("id", ""))),
+                        name=str(t.get("name", t.get("id", ""))),
+                        description=t.get("description", ""),
+                        parameters=t.get("parameters", {}),
+                    )
+                )
+            elif isinstance(t, str):
+                candidate_tools.append(Tool(id=t, name=t, description=""))
+
+        ground_truth = sample.get("ground_truth", {})
+        if isinstance(ground_truth, dict):
+            gt_tools = ground_truth.get("top_k", ground_truth.get("tools", []))
+        elif isinstance(ground_truth, list):
+            gt_tools = ground_truth
+        else:
+            gt_tools = []
+
+        return cls(
+            sample_id=sample.get("sample_id", ""),
+            instruction=sample.get("instruction", ""),
+            candidate_tools=candidate_tools,
+            ground_truth=gt_tools if gt_tools is not None else [],
+            context=sample.get("context", {}),
+        )
+
+
+@dataclass
+class SelectionResult:
+    """Result from a tool selection method."""
+
+    tool_ids: list[str]  # Ranked list of selected tool IDs
+    scores: list[float] = field(default_factory=list)  # Optional confidence scores
+    metadata: dict[str, Any] = field(default_factory=dict)  # Method-specific metadata
+
+
+# =============================================================================
+# 2. Unified Selector Protocol
+# =============================================================================
+
+
+@runtime_checkable
+class ToolSelectorProtocol(Protocol):
+    """
+    Protocol that all tool selection methods must implement.
+
+    This ensures consistency across:
+    - Retrieval methods (keyword, embedding)
+    - LLM-based methods (direct generation, Gorilla)
+    - Hybrid methods (retrieval + reranking)
+    """
+
+    def select(
+        self,
+        query: str,
+        candidate_tools: list[Tool],
+        top_k: int = 5,
+    ) -> SelectionResult:
+        """
+        Select top_k tools from candidates for the given query.
+
+        Args:
+            query: User query/instruction
+            candidate_tools: List of candidate tools to choose from
+            top_k: Number of tools to select
+
+        Returns:
+            SelectionResult with ranked tool IDs
+        """
+        ...
+
+
+# =============================================================================
+# 3. Base Selector Implementations
+# =============================================================================
+
+
+class BaseSelectorAdapter(ABC):
+    """Base class for selector adapters."""
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Return selector name for logging/reporting."""
+        ...
+
+    @abstractmethod
+    def select(
+        self,
+        query: str,
+        candidate_tools: list[Tool],
+        top_k: int = 5,
+    ) -> SelectionResult:
+        """Select tools."""
+        ...
+
+
+class KeywordSelectorAdapter(BaseSelectorAdapter):
+    """Adapter for keyword/BM25-based selector."""
+
+    def __init__(self, selector: Any = None):
+        self._selector = selector
+
+    @property
+    def name(self) -> str:
+        return "keyword"
+
+    def select(
+        self,
+        query: str,
+        candidate_tools: list[Tool],
+        top_k: int = 5,
+    ) -> SelectionResult:
+        if self._selector is None:
+            self._init_selector()
+
+        # Pass tool IDs only (schema expects list[str])
+        tool_ids_input = [t.id for t in candidate_tools]
+        results = self._selector.select(query, tool_ids_input, top_k=top_k)
+
+        tool_ids = [r.tool_id if hasattr(r, "tool_id") else str(r) for r in results]
+        scores = [r.score if hasattr(r, "score") else 1.0 for r in results]
+
+        return SelectionResult(tool_ids=tool_ids, scores=scores)
+
+    def _init_selector(self):
+        # Use adapter_registry which handles resources correctly
+        from sage.benchmark.benchmark_agent import get_adapter_registry
+
+        registry = get_adapter_registry()
+        self._selector = registry.get("selector.keyword")
+
+
+class EmbeddingSelectorAdapter(BaseSelectorAdapter):
+    """Adapter for embedding-based selector."""
+
+    def __init__(self, selector: Any = None):
+        self._selector = selector
+
+    @property
+    def name(self) -> str:
+        return "embedding"
+
+    def select(
+        self,
+        query: str,
+        candidate_tools: list[Tool],
+        top_k: int = 5,
+    ) -> SelectionResult:
+        if self._selector is None:
+            self._init_selector()
+
+        # Pass tool IDs only (schema expects list[str])
+        tool_ids_input = [t.id for t in candidate_tools]
+        results = self._selector.select(query, tool_ids_input, top_k=top_k)
+
+        tool_ids = [r.tool_id if hasattr(r, "tool_id") else str(r) for r in results]
+        scores = [r.score if hasattr(r, "score") else 1.0 for r in results]
+
+        return SelectionResult(tool_ids=tool_ids, scores=scores)
+
+    def _init_selector(self):
+        from sage.benchmark.benchmark_agent import get_adapter_registry
+
+        registry = get_adapter_registry()
+        self._selector = registry.get("selector.embedding")
+
+
+class HybridSelectorAdapter(BaseSelectorAdapter):
+    """Adapter for hybrid (keyword + embedding) selector."""
+
+    def __init__(self, selector: Any = None):
+        self._selector = selector
+
+    @property
+    def name(self) -> str:
+        return "hybrid"
+
+    def select(
+        self,
+        query: str,
+        candidate_tools: list[Tool],
+        top_k: int = 5,
+    ) -> SelectionResult:
+        if self._selector is None:
+            self._init_selector()
+
+        # Pass tool IDs only (schema expects list[str])
+        tool_ids_input = [t.id for t in candidate_tools]
+        results = self._selector.select(query, tool_ids_input, top_k=top_k)
+
+        tool_ids = [r.tool_id if hasattr(r, "tool_id") else str(r) for r in results]
+        scores = [r.score if hasattr(r, "score") else 1.0 for r in results]
+
+        return SelectionResult(tool_ids=tool_ids, scores=scores)
+
+    def _init_selector(self):
+        from sage.benchmark.benchmark_agent import get_adapter_registry
+
+        registry = get_adapter_registry()
+        self._selector = registry.get("selector.hybrid")
+
+
+class LLMDirectSelectorAdapter(BaseSelectorAdapter):
+    """
+    LLM-based tool selection (direct generation).
+
+    This is the approach used in ACEBench evaluation - let LLM directly
+    choose the best tool based on the query and tool descriptions.
+    """
+
+    def __init__(
+        self,
+        llm_client: Any = None,
+        model_id: Optional[str] = None,
+        use_embedded: bool = False,
+    ):
+        self._llm_client = llm_client
+        self._model_id = model_id
+        self._use_embedded = use_embedded
+
+    @property
+    def name(self) -> str:
+        return "llm_direct"
+
+    def select(
+        self,
+        query: str,
+        candidate_tools: list[Tool],
+        top_k: int = 5,
+    ) -> SelectionResult:
+        if self._llm_client is None:
+            self._init_client()
+
+        # Build prompt
+        prompt = self._build_prompt(query, candidate_tools, top_k)
+
+        # Generate response
+        try:
+            response = self._llm_client.chat([{"role": "user", "content": prompt}])
+        except Exception as e:
+            logger.warning(f"LLM generation failed: {e}")
+            return SelectionResult(tool_ids=[])
+
+        # Parse response
+        tool_ids = self._parse_response(response, candidate_tools)
+
+        return SelectionResult(tool_ids=tool_ids[:top_k])
+
+    def _init_client(self):
+        from sage.llm import UnifiedInferenceClient
+
+        # UnifiedInferenceClient.create() handles local/cloud detection
+        self._llm_client = UnifiedInferenceClient.create()
+
+    def _build_prompt(self, query: str, candidate_tools: list[Tool], top_k: int) -> str:
+        """Build prompt for LLM tool selection."""
+        tools_desc = "\n".join(f"- {t.name}: {t.description}" for t in candidate_tools)
+
+        return f"""You are a tool selection assistant. Given a user query and a list of available tools,
+select the {top_k} most relevant tool(s) for the query.
+
+Available Tools:
+{tools_desc}
+
+User Query: {query}
+
+Respond with ONLY the tool name(s), one per line, in order of relevance.
+Do not include any explanation or additional text."""
+
+    def _parse_response(self, response: str, candidate_tools: list[Tool]) -> list[str]:
+        """Parse LLM response to extract tool names."""
+        tool_names = {t.name.lower(): t.id for t in candidate_tools}
+        tool_ids_map = {t.id.lower(): t.id for t in candidate_tools}
+
+        selected = []
+        for line in response.strip().split("\n"):
+            line = line.strip().strip("-").strip("*").strip()
+            line_lower = line.lower()
+
+            # Try exact match
+            if line_lower in tool_names:
+                selected.append(tool_names[line_lower])
+            elif line_lower in tool_ids_map:
+                selected.append(tool_ids_map[line_lower])
+            else:
+                # Try partial match
+                for name, tool_id in tool_names.items():
+                    if name in line_lower or line_lower in name:
+                        if tool_id not in selected:
+                            selected.append(tool_id)
+                        break
+
+        return selected
+
+
+# =============================================================================
+# 4. Unified Metrics
+# =============================================================================
+
+
+@dataclass
+class EvaluationMetrics:
+    """Unified evaluation metrics."""
+
+    # Core metrics
+    top_1_accuracy: float = 0.0
+    top_3_accuracy: float = 0.0
+    top_5_accuracy: float = 0.0
+    mrr: float = 0.0  # Mean Reciprocal Rank
+
+    # Additional metrics
+    recall_at_k: dict[int, float] = field(default_factory=dict)
+    precision_at_k: dict[int, float] = field(default_factory=dict)
+
+    # Metadata
+    total_samples: int = 0
+    method_name: str = ""
+
+
+def compute_metrics(
+    predictions: list[SelectionResult],
+    references: list[list[str]],  # Ground truth tool IDs for each sample
+    k_values: list[int] | None = None,
+) -> EvaluationMetrics:
+    """
+    Compute unified evaluation metrics.
+
+    Args:
+        predictions: List of SelectionResult from selector
+        references: List of ground truth tool ID lists
+        k_values: K values for Top-K and Recall@K metrics
+
+    Returns:
+        EvaluationMetrics with all computed metrics
+    """
+    if k_values is None:
+        k_values = [1, 3, 5]
+
+    if len(predictions) != len(references):
+        raise ValueError(
+            f"Predictions ({len(predictions)}) and references ({len(references)}) "
+            "must have same length"
+        )
+
+    n = len(predictions)
+    if n == 0:
+        return EvaluationMetrics()
+
+    # Initialize counters
+    top_k_correct = dict.fromkeys(k_values, 0)
+    reciprocal_ranks: list[float] = []
+    recall_at_k: dict[int, list[float]] = {k: [] for k in k_values}
+    precision_at_k: dict[int, list[float]] = {k: [] for k in k_values}
+
+    for pred, ref in zip(predictions, references):
+        pred_ids = pred.tool_ids
+        ref_set = set(ref)
+
+        # Top-K accuracy: any correct tool in top-k predictions
+        for k in k_values:
+            top_k_preds = set(pred_ids[:k])
+            if top_k_preds & ref_set:
+                top_k_correct[k] += 1
+
+        # MRR: position of first correct prediction
+        rr = 0.0
+        for i, tool_id in enumerate(pred_ids):
+            if tool_id in ref_set:
+                rr = 1.0 / (i + 1)
+                break
+        reciprocal_ranks.append(rr)
+
+        # Recall@K and Precision@K
+        for k in k_values:
+            top_k_preds = set(pred_ids[:k])
+            hits = len(top_k_preds & ref_set)
+
+            recall = hits / len(ref_set) if ref_set else 0.0
+            precision = hits / k
+
+            recall_at_k[k].append(recall)
+            precision_at_k[k].append(precision)
+
+    # Compute averages
+    metrics = EvaluationMetrics(
+        top_1_accuracy=top_k_correct.get(1, 0) / n,
+        top_3_accuracy=top_k_correct.get(3, 0) / n,
+        top_5_accuracy=top_k_correct.get(5, 0) / n,
+        mrr=sum(reciprocal_ranks) / n,
+        recall_at_k={k: sum(v) / n for k, v in recall_at_k.items()},
+        precision_at_k={k: sum(v) / n for k, v in precision_at_k.items()},
+        total_samples=n,
+    )
+
+    return metrics
+
+
+# =============================================================================
+# 5. Unified Evaluator
+# =============================================================================
+
+
+class UnifiedToolSelectionEvaluator:
+    """
+    Unified evaluator for tool selection methods.
+
+    Supports multiple benchmarks and methods with consistent evaluation.
+    """
+
+    def __init__(self):
+        self.selectors: dict[str, BaseSelectorAdapter] = {}
+
+    def register_selector(self, name: str, selector: BaseSelectorAdapter):
+        """Register a selector for evaluation."""
+        self.selectors[name] = selector
+
+    def register_default_selectors(self, use_embedded_llm: bool = False):
+        """Register default set of selectors."""
+        self.selectors["keyword"] = KeywordSelectorAdapter()
+        self.selectors["embedding"] = EmbeddingSelectorAdapter()
+        self.selectors["hybrid"] = HybridSelectorAdapter()
+        self.selectors["llm_direct"] = LLMDirectSelectorAdapter(use_embedded=use_embedded_llm)
+
+    def evaluate(
+        self,
+        samples: list[ToolSelectionSample],
+        selector_names: Optional[list[str]] = None,
+        top_k: int = 5,
+    ) -> dict[str, EvaluationMetrics]:
+        """
+        Evaluate selectors on samples.
+
+        Args:
+            samples: List of unified ToolSelectionSample
+            selector_names: Names of selectors to evaluate (None = all)
+            top_k: Top-K value for selection
+
+        Returns:
+            Dict mapping selector name to EvaluationMetrics
+        """
+        if selector_names is None:
+            selector_names = list(self.selectors.keys())
+
+        results = {}
+
+        for name in selector_names:
+            if name not in self.selectors:
+                logger.warning(f"Selector '{name}' not registered, skipping")
+                continue
+
+            selector = self.selectors[name]
+            logger.info(f"Evaluating selector: {name}")
+
+            predictions = []
+            references = []
+
+            for sample in samples:
+                try:
+                    result = selector.select(
+                        query=sample.instruction,
+                        candidate_tools=sample.candidate_tools,
+                        top_k=top_k,
+                    )
+                    predictions.append(result)
+                    references.append(sample.ground_truth)
+                except Exception as e:
+                    logger.warning(f"Selector {name} failed on sample {sample.sample_id}: {e}")
+                    predictions.append(SelectionResult(tool_ids=[]))
+                    references.append(sample.ground_truth)
+
+            metrics = compute_metrics(predictions, references)
+            metrics.method_name = name
+            results[name] = metrics
+
+        return results
+
+    def print_results(self, results: dict[str, EvaluationMetrics]):
+        """Print evaluation results in a table format."""
+        print("\n" + "=" * 80)
+        print("Tool Selection Evaluation Results")
+        print("=" * 80)
+        print(
+            f"{'Method':<15} {'Top-1':>10} {'Top-3':>10} {'Top-5':>10} {'MRR':>10} {'Samples':>10}"
+        )
+        print("-" * 80)
+
+        for name, metrics in results.items():
+            print(
+                f"{name:<15} "
+                f"{metrics.top_1_accuracy * 100:>9.1f}% "
+                f"{metrics.top_3_accuracy * 100:>9.1f}% "
+                f"{metrics.top_5_accuracy * 100:>9.1f}% "
+                f"{metrics.mrr * 100:>9.1f}% "
+                f"{metrics.total_samples:>10}"
+            )
+
+        print("=" * 80)
sage/benchmark/benchmark_agent/experiments/__init__.py
@@ -0,0 +1,63 @@
+"""
+Experiment implementations for agent benchmark evaluation.
+
+Available experiments:
+- ToolSelectionExperiment: Tool retrieval and ranking
+- PlanningExperiment: Multi-step planning with tool composition
+- TimingDetectionExperiment: Timing judgment for tool invocation
+"""
+
+from sage.benchmark.benchmark_agent.experiments.base_experiment import (
+    CONFIG_TYPES,
+    BaseExperiment,
+    ExperimentConfig,
+    ExperimentResult,
+    PlanningConfig,
+    PlanningPrediction,
+    PlanStep,
+    ReportConfig,
+    TimingDecision,
+    TimingDetectionConfig,
+    ToolPrediction,
+    ToolSelectionConfig,
+    create_config,
+)
+from sage.benchmark.benchmark_agent.experiments.planning_exp import (
+    PlanningExperiment,
+    PlanningTask,
+)
+from sage.benchmark.benchmark_agent.experiments.timing_detection_exp import (
+    TimingDetectionExperiment,
+    TimingMessage,
+)
+from sage.benchmark.benchmark_agent.experiments.tool_selection_exp import (
+    ToolSelectionExperiment,
+    ToolSelectionQuery,
+)
+
+__all__ = [
+    # Base classes
+    "BaseExperiment",
+    "ExperimentConfig",
+    "ExperimentResult",
+    # Config models
+    "ToolSelectionConfig",
+    "PlanningConfig",
+    "TimingDetectionConfig",
+    "ReportConfig",
+    # Result/task models
+    "ToolPrediction",
+    "PlanStep",
+    "PlanningPrediction",
+    "TimingDecision",
+    "ToolSelectionQuery",
+    "PlanningTask",
+    "TimingMessage",
+    # Utilities
+    "CONFIG_TYPES",
+    "create_config",
+    # Experiment implementations
+    "ToolSelectionExperiment",
+    "PlanningExperiment",
+    "TimingDetectionExperiment",
+]
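
This __init__.py only re-exports names; constructor signatures and the contents of CONFIG_TYPES live in base_experiment.py and the individual experiment modules, which are listed above but not shown in this section. A minimal, illustrative check of that public surface (assuming the wheel is installed; it relies solely on the __all__ list shown above):

    # Illustrative sketch only; not part of the wheel's contents.
    from sage.benchmark.benchmark_agent import experiments

    # Enumerate the re-exported public API of the experiments package.
    print(f"{len(experiments.__all__)} exported names:")
    for name in sorted(experiments.__all__):
        print(" -", name)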