isage_benchmark_agent-0.1.0.1-cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py
@@ -0,0 +1,430 @@
+ """
+ Table Generator - LaTeX table generation module.
+
+ Generates LaTeX-formatted tables for all Paper 1 experiments:
+ - Main Results table
+ - Benchmark Details table
+ - Ablation Study table
+ - Training Comparison table
+ """
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, Optional
+
+ from .exp_utils import get_tables_dir
+
+ # =============================================================================
+ # LaTeX table generation
+ # =============================================================================
+
+
+ def generate_main_results_table(
+     timing_results: list[dict],
+     planning_results: list[dict],
+     selection_results: list[dict],
+     output_path: Optional[Path] = None,
+ ) -> str:
+     """
+     Generate the main results table (Table 1).
+
+     Args:
+         timing_results: Timing experiment results
+         planning_results: Planning experiment results
+         selection_results: Selection experiment results
+         output_path: Output file path
+
+     Returns:
+         LaTeX table string
+     """
+
+     # Find the best result for each challenge
+     def get_best(results, metric):
+         if not results:
+             return None
+         return max(results, key=lambda r: r.get("metrics", {}).get(metric, 0))
+
+     timing_best = get_best(timing_results, "accuracy")
+     planning_best = get_best(planning_results, "plan_success_rate")
+     selection_best = get_best(selection_results, "top_k_accuracy")
+
+     lines = [
+         r"\begin{table}[t]",
+         r"\centering",
+         r"\caption{Main results on SAGE-AgentBench. Best performing strategy for each challenge.}",
+         r"\label{tab:main}",
+         r"\begin{tabular}{lccc}",
+         r"\toprule",
+         r"\textbf{Challenge} & \textbf{Strategy} & \textbf{Primary Metric} & \textbf{Target Met} \\",
+         r"\midrule",
+     ]
+
+     if timing_best:
+         acc = timing_best.get("metrics", {}).get("accuracy", 0) * 100
+         strategy = timing_best.get("strategy", "").replace("timing.", "").replace("_", " ").title()
+         passed = acc >= 95
+         status = r"\cmark" if passed else r"\xmark"
+         lines.append(f"Timing Detection & {strategy} & {acc:.1f}\\% & {status} \\\\")
+
+     if planning_best:
+         rate = planning_best.get("metrics", {}).get("plan_success_rate", 0) * 100
+         strategy = (
+             planning_best.get("strategy", "").replace("planner.", "").replace("_", " ").title()
+         )
+         passed = rate >= 90
+         status = r"\cmark" if passed else r"\xmark"
+         lines.append(f"Task Planning & {strategy} & {rate:.1f}\\% & {status} \\\\")
+
+     if selection_best:
+         acc = selection_best.get("metrics", {}).get("top_k_accuracy", 0) * 100
+         strategy = (
+             selection_best.get("strategy", "").replace("selector.", "").replace("_", " ").title()
+         )
+         passed = acc >= 95
+         status = r"\cmark" if passed else r"\xmark"
+         lines.append(f"Tool Selection & {strategy} & {acc:.1f}\\% & {status} \\\\")
+
+     lines.extend(
+         [
+             r"\bottomrule",
+             r"\end{tabular}",
+             r"\end{table}",
+         ]
+     )
+
+     content = "\n".join(lines)
+
+     if output_path:
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+         output_path.write_text(content)
+
+     return content
+
+
+ def generate_benchmark_details_table(
+     timing_results: list[dict],
+     planning_results: list[dict],
+     selection_results: list[dict],
+     output_path: Optional[Path] = None,
+ ) -> str:
+     """
+     Generate the detailed benchmark results table (Table 2).
+
+     Shows every strategy's performance on every metric.
+     """
+     lines = [
+         r"\begin{table}[t]",
+         r"\centering",
+         r"\caption{Detailed benchmark results across all strategies.}",
+         r"\label{tab:benchmark}",
+         r"\small",
+         r"\begin{tabular}{llcccc}",
+         r"\toprule",
+         r"\textbf{Challenge} & \textbf{Strategy} & \textbf{Metric 1} & \textbf{Metric 2} & \textbf{Metric 3} & \textbf{Pass} \\",
+         r"\midrule",
+     ]
+
+     # Timing results
+     for r in timing_results or []:
+         name = r.get("strategy", "").replace("timing.", "").replace("_", " ").title()
+         metrics = r.get("metrics", {})
+         acc = metrics.get("accuracy", 0) * 100
+         prec = metrics.get("precision", 0) * 100
+         rec = metrics.get("recall", 0) * 100
+         passed = acc >= 95
+         status = r"\cmark" if passed else r"\xmark"
+         lines.append(
+             f"Timing & {name} & Acc: {acc:.1f}\\% & Prec: {prec:.1f}\\% & Rec: {rec:.1f}\\% & {status} \\\\"
+         )
+
+     if timing_results:
+         lines.append(r"\midrule")
+
+     # Planning results
+     for r in planning_results or []:
+         name = r.get("strategy", "").replace("planner.", "").replace("_", " ").title()
+         metrics = r.get("metrics", {})
+         success = metrics.get("plan_success_rate", 0) * 100
+         step = metrics.get("step_accuracy", 0) * 100
+         seq = metrics.get("sequence_match", 0) * 100
+         passed = success >= 90
+         status = r"\cmark" if passed else r"\xmark"
+         lines.append(
+             f"Planning & {name} & Succ: {success:.1f}\\% & Step: {step:.1f}\\% & Seq: {seq:.1f}\\% & {status} \\\\"
+         )
+
+     if planning_results:
+         lines.append(r"\midrule")
+
+     # Tool selection results
+     for r in selection_results or []:
+         name = r.get("strategy", "").replace("selector.", "").replace("_", " ").title()
+         metrics = r.get("metrics", {})
+         top_k = metrics.get("top_k_accuracy", 0) * 100
+         mrr = metrics.get("mrr", 0) * 100
+         recall = metrics.get("recall_at_k", 0) * 100
+         passed = top_k >= 95
+         status = r"\cmark" if passed else r"\xmark"
+         lines.append(
+             f"Selection & {name} & Top-K: {top_k:.1f}\\% & MRR: {mrr:.1f}\\% & R@K: {recall:.1f}\\% & {status} \\\\"
+         )
+
+     lines.extend(
+         [
+             r"\bottomrule",
+             r"\end{tabular}",
+             r"\end{table}",
+         ]
+     )
+
+     content = "\n".join(lines)
+
+     if output_path:
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+         output_path.write_text(content)
+
+     return content
+
+
+ def generate_training_comparison_table(
+     training_results: list[dict],
+     output_path: Optional[Path] = None,
+ ) -> str:
+     """
+     Generate the training method comparison table (Table 3).
+     """
+     lines = [
+         r"\begin{table}[t]",
+         r"\centering",
+         r"\caption{Training method comparison. Performance after fine-tuning on SAGE-AgentBench.}",
+         r"\label{tab:training}",
+         r"\small",
+         r"\begin{tabular}{lcccc}",
+         r"\toprule",
+         r"\textbf{Method} & \textbf{Timing} & \textbf{Planning} & \textbf{Selection} & \textbf{Overall} \\",
+         r"\midrule",
+     ]
+
+     for r in training_results or []:
+         name = r.get("method_name", "").replace("_", " ")
+         timing = r.get("timing_accuracy", 0) * 100
+         planning = r.get("planning_success_rate", 0) * 100
+         selection = r.get("selection_top_k_accuracy", 0) * 100
+         overall = (timing + planning + selection) / 3
+         lines.append(
+             f"{name} & {timing:.1f}\\% & {planning:.1f}\\% & {selection:.1f}\\% & {overall:.1f}\\% \\\\"
+         )
+
+     lines.extend(
+         [
+             r"\bottomrule",
+             r"\end{tabular}",
+             r"\end{table}",
+         ]
+     )
+
+     content = "\n".join(lines)
+
+     if output_path:
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+         output_path.write_text(content)
+
+     return content
+
+
+ def generate_ablation_table(
+     ablation_results: dict[str, dict[str, float]],
+     title: str = "Ablation study results",
+     label: str = "tab:ablation",
+     output_path: Optional[Path] = None,
+ ) -> str:
+     """
+     Generate the ablation study table.
+
+     Args:
+         ablation_results: {"config_name": {"metric1": 0.9, ...}, ...}
+         title: Table caption
+         label: LaTeX label
+         output_path: Output file path
+     """
+     if not ablation_results:
+         return ""
+
+     configs = list(ablation_results.keys())
+     metrics = list(next(iter(ablation_results.values())).keys())
+
+     # Build the table
+     col_spec = "l" + "c" * len(metrics)
+     header = (
+         r"\textbf{Config} & "
+         + " & ".join([f"\\textbf{{{m.replace('_', ' ').title()}}}" for m in metrics])
+         + r" \\"
+     )
+
+     lines = [
+         r"\begin{table}[t]",
+         r"\centering",
+         f"\\caption{{{title}}}",
+         f"\\label{{{label}}}",
+         r"\small",
+         f"\\begin{{tabular}}{{{col_spec}}}",
+         r"\toprule",
+         header,
+         r"\midrule",
+     ]
+
+     for config in configs:
+         values = ablation_results[config]
+         row = config.replace("_", " ")
+         for metric in metrics:
+             val = values.get(metric, 0)
+             if isinstance(val, float):
+                 row += f" & {val * 100:.1f}\\%"
+             else:
+                 row += f" & {val}"
+         row += r" \\"
+         lines.append(row)
+
+     lines.extend(
+         [
+             r"\bottomrule",
+             r"\end{tabular}",
+             r"\end{table}",
+         ]
+     )
+
+     content = "\n".join(lines)
+
+     if output_path:
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+         output_path.write_text(content)
+
+     return content
+
+
+ def generate_cross_dataset_table(
+     cross_dataset_results: dict[str, dict[str, float]],
+     output_path: Optional[Path] = None,
+ ) -> str:
+     """
+     Generate the cross-dataset comparison table.
+
+     Args:
+         cross_dataset_results: {"method": {"sage": 0.9, "acebench": 0.85, ...}, ...}
+     """
+     if not cross_dataset_results:
+         return ""
+
+     methods = list(cross_dataset_results.keys())
+     datasets = list(next(iter(cross_dataset_results.values())).keys())
+
+     col_spec = "l" + "c" * len(datasets)
+     header = (
+         r"\textbf{Method} & " + " & ".join([f"\\textbf{{{d.upper()}}}" for d in datasets]) + r" \\"
+     )
+
+     lines = [
+         r"\begin{table}[t]",
+         r"\centering",
+         r"\caption{Cross-dataset generalization. Top-5 accuracy on different benchmarks.}",
+         r"\label{tab:crossdataset}",
+         r"\small",
+         f"\\begin{{tabular}}{{{col_spec}}}",
+         r"\toprule",
+         header,
+         r"\midrule",
+     ]
+
+     for method in methods:
+         row = method.replace("_", " ").title()
+         for dataset in datasets:
+             val = cross_dataset_results[method].get(dataset, 0) * 100
+             row += f" & {val:.1f}\\%"
+         row += r" \\"
+         lines.append(row)
+
+     lines.extend(
+         [
+             r"\bottomrule",
+             r"\end{tabular}",
+             r"\end{table}",
+         ]
+     )
+
+     content = "\n".join(lines)
+
+     if output_path:
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+         output_path.write_text(content)
+
+     return content
+
+
+ # =============================================================================
+ # Batch generation
+ # =============================================================================
+
+
+ def generate_all_tables(
+     all_results: dict[str, Any],
+     output_dir: Optional[Path] = None,
+ ) -> dict[str, Path]:
+     """
+     Generate all paper tables.
+
+     Args:
+         all_results: Dictionary containing all experiment results
+         output_dir: Output directory
+
+     Returns:
+         Dictionary mapping table names to generated file paths
+     """
+     tables_dir = output_dir or get_tables_dir()
+     tables_dir.mkdir(parents=True, exist_ok=True)
+
+     generated = {}
+
+     # Table 1: Main Results
+     timing = all_results.get("timing", [])
+     planning = all_results.get("planning", [])
+     selection = all_results.get("selection", [])
+
+     if timing or planning or selection:
+         path = tables_dir / "table1_main_results.tex"
+         generate_main_results_table(timing, planning, selection, path)
+         generated["main_results"] = path
+         print(f" Generated: {path}")
+
+     # Table 2: Benchmark Details
+     if timing or planning or selection:
+         path = tables_dir / "table2_benchmark_details.tex"
+         generate_benchmark_details_table(timing, planning, selection, path)
+         generated["benchmark_details"] = path
+         print(f" Generated: {path}")
+
+     # Table 3: Training Comparison
+     training = all_results.get("training", [])
+     if training:
+         path = tables_dir / "table3_training_comparison.tex"
+         generate_training_comparison_table(training, path)
+         generated["training_comparison"] = path
+         print(f" Generated: {path}")
+
+     # Table 4: Ablation
+     ablation = all_results.get("ablation", {})
+     if ablation:
+         path = tables_dir / "table4_ablation.tex"
+         generate_ablation_table(ablation, output_path=path)
+         generated["ablation"] = path
+         print(f" Generated: {path}")
+
+     # Table 5: Cross-Dataset
+     cross_dataset = all_results.get("cross_dataset", {})
+     if cross_dataset:
+         path = tables_dir / "table5_cross_dataset.tex"
+         generate_cross_dataset_table(cross_dataset, path)
+         generated["cross_dataset"] = path
+         print(f" Generated: {path}")
+
+     return generated
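
For orientation, a minimal driver for table_generator.py (a sketch, not part of this diff): the result-dict shape (a "strategy" name plus a "metrics" mapping) is taken from the functions above, while the strategy names and numbers here are invented for illustration.

# Hypothetical usage of table_generator; sample values are made up.
from pathlib import Path

from sage.benchmark.benchmark_agent.scripts.experiments.table_generator import (
    generate_main_results_table,
)

timing = [{"strategy": "timing.rule_based", "metrics": {"accuracy": 0.96}}]
planning = [{"strategy": "planner.llm", "metrics": {"plan_success_rate": 0.91}}]
selection = [{"strategy": "selector.hybrid", "metrics": {"top_k_accuracy": 0.97}}]

tex = generate_main_results_table(
    timing, planning, selection, output_path=Path("tables/table1_main_results.tex")
)
print(tex.splitlines()[0])  # -> \begin{table}[t]

Note that the emitted fragments use booktabs rules and \cmark/\xmark, which base LaTeX does not define. A preamble along these conventional lines would be needed to compile them (an assumption about the consuming paper, not something this package ships):

% Preamble sketch: the usual pifont definitions for \cmark/\xmark.
\usepackage{booktabs}           % \toprule, \midrule, \bottomrule
\usepackage{pifont}             % \ding symbols
\newcommand{\cmark}{\ding{51}}  % check mark for "target met"
\newcommand{\xmark}{\ding{55}}  % cross mark otherwise
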
sage/benchmark/benchmark_agent/tools_loader.py
@@ -0,0 +1,212 @@
+ """
+ SAGE Tools Loader - Loads 1,200 real tools from tool_catalog.jsonl.
+
+ This module provides the SageToolsLoader class that loads tool definitions
+ from the SAGE benchmark's tool catalog and exposes them via a unified interface
+ compatible with the selector resources.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Iterator
+
+ logger = logging.getLogger(__name__)
+
+ # Default path to tool catalog
+ TOOL_CATALOG_PATH = (
+     Path(__file__).parent.parent.parent
+     / "data"
+     / "sources"
+     / "agent_tools"
+     / "data"
+     / "tool_catalog.jsonl"
+ )
+
+
+ @dataclass
+ class SageTool:
+     """
+     Represents a tool from the SAGE tool catalog.
+
+     Attributes:
+         tool_id: Unique identifier (e.g., 'environment_weather_001')
+         name: Human-readable name (e.g., 'Weather Fetch 1')
+         description: Tool description (generated from name/category/capabilities)
+         category: Tool category (e.g., 'environment/weather')
+         capabilities: List of capabilities (e.g., ['forecast', 'radar'])
+         inputs: List of input parameter definitions
+         outputs: List of output field definitions
+         metadata: Additional metadata (owner, version, etc.)
+     """
+
+     tool_id: str
+     name: str
+     description: str
+     category: str
+     capabilities: list[str] = field(default_factory=list)
+     inputs: list[dict[str, Any]] = field(default_factory=list)
+     outputs: list[dict[str, Any]] = field(default_factory=list)
+     metadata: dict[str, Any] = field(default_factory=dict)
+
+     @classmethod
+     def from_json(cls, data: dict[str, Any]) -> SageTool:
+         """Create SageTool from JSON data."""
+         # Generate description from available fields
+         name = data.get("name", "")
+         category = data.get("category", "")
+         capabilities = data.get("capabilities", [])
+
+         # Build description
+         desc_parts = [name]
+         if category:
+             desc_parts.append(f"Category: {category}")
+         if capabilities:
+             desc_parts.append(f"Capabilities: {', '.join(capabilities)}")
+
+         # Add input info
+         inputs = data.get("inputs", [])
+         if inputs:
+             input_names = [inp.get("name", "") for inp in inputs if inp.get("name")]
+             if input_names:
+                 desc_parts.append(f"Inputs: {', '.join(input_names[:5])}")
+
+         description = ". ".join(desc_parts)
+
+         return cls(
+             tool_id=data.get("tool_id", ""),
+             name=name,
+             description=description,
+             category=category,
+             capabilities=capabilities,
+             inputs=inputs,
+             outputs=data.get("outputs", []),
+             metadata=data.get("metadata", {}),
+         )
+
+
+ class SageToolsLoader:
+     """
+     Loads tools from the SAGE benchmark tool catalog.
+
+     This loader reads the tool_catalog.jsonl file and provides methods
+     to iterate over tools and retrieve tools by ID.
+
+     Usage:
+         loader = SageToolsLoader()
+         for tool in loader.iter_all():
+             print(tool.tool_id, tool.name)
+
+         tool = loader.get_tool('environment_weather_001')
+     """
+
+     def __init__(self, catalog_path: str | Path | None = None):
+         """
+         Initialize the loader.
+
+         Args:
+             catalog_path: Path to tool_catalog.jsonl. If None, uses default path.
+         """
+         self.catalog_path = Path(catalog_path) if catalog_path else TOOL_CATALOG_PATH
+         self._tools: dict[str, SageTool] = {}
+         self._loaded = False
+
+     def _ensure_loaded(self) -> None:
+         """Load tools if not already loaded."""
+         if self._loaded:
+             return
+
+         if not self.catalog_path.exists():
+             logger.warning(f"Tool catalog not found: {self.catalog_path}")
+             self._loaded = True
+             return
+
+         try:
+             with open(self.catalog_path, encoding="utf-8") as f:
+                 for line in f:
+                     line = line.strip()
+                     if not line:
+                         continue
+                     try:
+                         data = json.loads(line)
+                         tool = SageTool.from_json(data)
+                         self._tools[tool.tool_id] = tool
+                     except json.JSONDecodeError as e:
+                         logger.warning(f"Failed to parse tool JSON: {e}")
+                         continue
+
+             logger.info(f"Loaded {len(self._tools)} tools from {self.catalog_path}")
+         except Exception as e:
+             logger.error(f"Failed to load tool catalog: {e}")
+
+         self._loaded = True
+
+     def iter_all(self) -> Iterator[SageTool]:
+         """
+         Iterate over all tools.
+
+         Yields:
+             SageTool instances
+         """
+         self._ensure_loaded()
+         yield from self._tools.values()
+
+     def get_tool(self, tool_id: str) -> SageTool:
+         """
+         Get a tool by ID.
+
+         Args:
+             tool_id: Tool identifier
+
+         Returns:
+             SageTool instance
+
+         Raises:
+             KeyError: If tool not found
+         """
+         self._ensure_loaded()
+         if tool_id not in self._tools:
+             raise KeyError(f"Tool not found: {tool_id}")
+         return self._tools[tool_id]
+
+     def get_tools(self, tool_ids: list[str]) -> list[SageTool]:
+         """
+         Get multiple tools by IDs.
+
+         Args:
+             tool_ids: List of tool identifiers
+
+         Returns:
+             List of SageTool instances (skips missing tools)
+         """
+         self._ensure_loaded()
+         tools = []
+         for tid in tool_ids:
+             if tid in self._tools:
+                 tools.append(self._tools[tid])
+         return tools
+
+     def __len__(self) -> int:
+         """Return number of tools."""
+         self._ensure_loaded()
+         return len(self._tools)
+
+     def __contains__(self, tool_id: str) -> bool:
+         """Check if tool exists."""
+         self._ensure_loaded()
+         return tool_id in self._tools
+
+
+ # Singleton instance for convenience
+ _default_loader: SageToolsLoader | None = None
+
+
+ def get_sage_tools_loader() -> SageToolsLoader:
+     """Get the default SageToolsLoader instance."""
+     global _default_loader
+     if _default_loader is None:
+         _default_loader = SageToolsLoader()
+     return _default_loader
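
A quick smoke test for the loader (a sketch; the tool ID is the example from the docstrings above and resolves only if the bundled catalog actually contains it):

# Hypothetical usage of tools_loader via the convenience singleton.
from sage.benchmark.benchmark_agent.tools_loader import get_sage_tools_loader

loader = get_sage_tools_loader()
print(len(loader))  # tool count; 0 if tool_catalog.jsonl is missing

if "environment_weather_001" in loader:  # __contains__ check
    tool = loader.get_tool("environment_weather_001")
    print(tool.name, tool.category, tool.capabilities)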