isage-benchmark-agent 0.1.0.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. isage_benchmark_agent-0.1.0.1.dist-info/METADATA +91 -0
  2. isage_benchmark_agent-0.1.0.1.dist-info/RECORD +51 -0
  3. isage_benchmark_agent-0.1.0.1.dist-info/WHEEL +5 -0
  4. isage_benchmark_agent-0.1.0.1.dist-info/entry_points.txt +2 -0
  5. isage_benchmark_agent-0.1.0.1.dist-info/licenses/LICENSE +21 -0
  6. isage_benchmark_agent-0.1.0.1.dist-info/top_level.txt +1 -0
  7. sage/__init__.py +0 -0
  8. sage/benchmark/__init__.py +0 -0
  9. sage/benchmark/benchmark_agent/__init__.py +108 -0
  10. sage/benchmark/benchmark_agent/__main__.py +177 -0
  11. sage/benchmark/benchmark_agent/acebench_loader.py +369 -0
  12. sage/benchmark/benchmark_agent/adapter_registry.py +3036 -0
  13. sage/benchmark/benchmark_agent/config/config_loader.py +176 -0
  14. sage/benchmark/benchmark_agent/config/default_config.yaml +24 -0
  15. sage/benchmark/benchmark_agent/config/planning_exp.yaml +34 -0
  16. sage/benchmark/benchmark_agent/config/timing_detection_exp.yaml +34 -0
  17. sage/benchmark/benchmark_agent/config/tool_selection_exp.yaml +32 -0
  18. sage/benchmark/benchmark_agent/data_paths.py +332 -0
  19. sage/benchmark/benchmark_agent/evaluation/__init__.py +217 -0
  20. sage/benchmark/benchmark_agent/evaluation/analyzers/__init__.py +11 -0
  21. sage/benchmark/benchmark_agent/evaluation/analyzers/planning_analyzer.py +111 -0
  22. sage/benchmark/benchmark_agent/evaluation/analyzers/timing_analyzer.py +135 -0
  23. sage/benchmark/benchmark_agent/evaluation/analyzers/tool_selection_analyzer.py +124 -0
  24. sage/benchmark/benchmark_agent/evaluation/evaluator.py +228 -0
  25. sage/benchmark/benchmark_agent/evaluation/metrics.py +650 -0
  26. sage/benchmark/benchmark_agent/evaluation/report_builder.py +217 -0
  27. sage/benchmark/benchmark_agent/evaluation/unified_tool_selection.py +602 -0
  28. sage/benchmark/benchmark_agent/experiments/__init__.py +63 -0
  29. sage/benchmark/benchmark_agent/experiments/base_experiment.py +263 -0
  30. sage/benchmark/benchmark_agent/experiments/method_comparison.py +742 -0
  31. sage/benchmark/benchmark_agent/experiments/planning_exp.py +262 -0
  32. sage/benchmark/benchmark_agent/experiments/timing_detection_exp.py +198 -0
  33. sage/benchmark/benchmark_agent/experiments/tool_selection_exp.py +250 -0
  34. sage/benchmark/benchmark_agent/scripts/__init__.py +26 -0
  35. sage/benchmark/benchmark_agent/scripts/experiments/__init__.py +40 -0
  36. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_ablation.py +425 -0
  37. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_error.py +400 -0
  38. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_robustness.py +439 -0
  39. sage/benchmark/benchmark_agent/scripts/experiments/exp_analysis_scaling.py +565 -0
  40. sage/benchmark/benchmark_agent/scripts/experiments/exp_cross_dataset.py +406 -0
  41. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_planning.py +315 -0
  42. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_selection.py +344 -0
  43. sage/benchmark/benchmark_agent/scripts/experiments/exp_main_timing.py +270 -0
  44. sage/benchmark/benchmark_agent/scripts/experiments/exp_training_comparison.py +620 -0
  45. sage/benchmark/benchmark_agent/scripts/experiments/exp_utils.py +427 -0
  46. sage/benchmark/benchmark_agent/scripts/experiments/figure_generator.py +677 -0
  47. sage/benchmark/benchmark_agent/scripts/experiments/llm_service.py +332 -0
  48. sage/benchmark/benchmark_agent/scripts/experiments/run_paper1_experiments.py +627 -0
  49. sage/benchmark/benchmark_agent/scripts/experiments/sage_bench_cli.py +422 -0
  50. sage/benchmark/benchmark_agent/scripts/experiments/table_generator.py +430 -0
  51. sage/benchmark/benchmark_agent/tools_loader.py +212 -0
@@ -0,0 +1,427 @@
1
+ """
2
+ Experiment Utilities - 实验共享工具函数
3
+
4
+ 为所有 Paper 1 实验提供统一的:
5
+ - 环境设置
6
+ - 数据加载
7
+ - 结果保存
8
+ - 客户端获取
9
+ - 进度显示
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import os
16
+ import random
17
+ import urllib.request
18
+ from datetime import datetime
19
+ from pathlib import Path
20
+ from typing import Any, Optional
21
+
22
+ import numpy as np
23
+
24
+ # =============================================================================
25
+ # Controlled-variable configuration (synced from adapter_registry)
26
+ # =============================================================================
27
+ RANDOM_SEED = 42
28
+ BENCHMARK_EMBEDDING_MODEL = "BAAI/bge-small-zh-v1.5"
29
+ BENCHMARK_LLM_TEMPERATURE = 0.1
30
+
31
+ # =============================================================================
32
+ # Path configuration
33
+ # =============================================================================
34
+ SCRIPT_DIR = Path(__file__).resolve().parent
35
+ BENCHMARK_AGENT_DIR = SCRIPT_DIR.parent.parent
36
+ BENCHMARK_ROOT = BENCHMARK_AGENT_DIR.parent.parent.parent.parent
37
+
38
+ # Try to import the data-paths module
39
+ try:
40
+ from sage.benchmark.benchmark_agent.data_paths import get_runtime_paths
41
+
42
+ _runtime_paths = get_runtime_paths()
43
+ DEFAULT_OUTPUT_DIR = _runtime_paths.results_root.parent / "paper1"
44
+ DEFAULT_DATA_DIR = _runtime_paths.data_root
45
+ except ImportError:  # data_paths not importable: derive defaults from this file's location
46
+ SAGE_ROOT = BENCHMARK_ROOT.parent.parent
47
+ DEFAULT_OUTPUT_DIR = SAGE_ROOT / ".sage" / "benchmark" / "paper1"
48
+ DEFAULT_DATA_DIR = SAGE_ROOT / ".sage" / "benchmark" / "data"
49
+
50
+
51
+ # =============================================================================
52
+ # 环境设置
53
+ # =============================================================================
54
def ensure_hf_endpoint_configured(verbose: bool = False) -> tuple[bool, bool]:
    """Make sure a HuggingFace endpoint is configured in the environment.

    When ``HF_ENDPOINT`` is unset or blank, ``https://huggingface.co`` is
    probed with a 3-second timeout; if the probe fails for any reason,
    ``HF_ENDPOINT`` is set to the mirror ``https://hf-mirror.com``. When
    ``HF_ENDPOINT`` already has a value, it is normalized (surrounding
    whitespace and trailing slashes removed) and written back. Afterwards
    ``HF_HUB_ENDPOINT`` is copied from ``HF_ENDPOINT`` if it is missing.

    Args:
        verbose: Print one line for every environment variable that is set
            automatically.

    Returns:
        A ``(configured_endpoint, synced_hub)`` pair: whether the mirror was
        written to ``HF_ENDPOINT`` and whether ``HF_HUB_ENDPOINT`` was copied
        from ``HF_ENDPOINT`` during this call.
    """
    configured_endpoint = False
    synced_hub = False

    endpoint = os.environ.get("HF_ENDPOINT", "").strip()
    if endpoint:
        endpoint = endpoint.rstrip("/")
        os.environ["HF_ENDPOINT"] = endpoint
    else:
        try:
            # Only reachability matters; the context manager closes the
            # response so the probe does not leak an open connection.
            with urllib.request.urlopen("https://huggingface.co", timeout=3):
                pass
        except Exception:
            # Deliberate best effort: any failure to reach the default host
            # (DNS, timeout, TLS, HTTP error) means "switch to the mirror".
            endpoint = "https://hf-mirror.com"
            os.environ["HF_ENDPOINT"] = endpoint
            configured_endpoint = True
            if verbose:
                print(f" Auto-configured HF mirror: {endpoint}")

    if os.environ.get("HF_ENDPOINT") and not os.environ.get("HF_HUB_ENDPOINT"):
        os.environ["HF_HUB_ENDPOINT"] = os.environ["HF_ENDPOINT"]
        synced_hub = True
        if verbose:
            print(f" HF_HUB_ENDPOINT synchronized to {os.environ['HF_HUB_ENDPOINT']}")

    return configured_endpoint, synced_hub
81
+
82
+
83
+ # Configure the mirror as early as possible at module import time, so that a later transformers import does not hit the default domain
84
+ ensure_hf_endpoint_configured(verbose=False)
85
+
86
+
87
def setup_experiment_env(seed: int = RANDOM_SEED, verbose: bool = True) -> None:
    """Prepare the current process for a reproducible experiment run.

    Seeds Python's ``random`` module and NumPy, seeds PyTorch when it can be
    imported, fills in several environment variables without overriding
    values that are already present, and finally makes sure a HuggingFace
    endpoint is configured.

    Args:
        seed: Seed passed to every random number generator.
        verbose: Print the effective settings (also forwarded to the
            HuggingFace endpoint setup).
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators, only when PyTorch is importable.
    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            # Deterministic algorithms (may reduce performance).
            # NOTE(review): nesting reconstructed from a flattened listing;
            # the cuDNN flags are assumed to belong to the CUDA branch.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    except ImportError:
        pass

    # Environment defaults; setdefault keeps values that are already set.
    env_defaults = (
        ("SAGE_TEST_MODE", "true"),
        ("PYTHONHASHSEED", str(seed)),
        # vLLM configuration
        ("VLLM_ATTENTION_BACKEND", "FLASH_ATTN"),
        # Silence PyTorch distributed warnings
        ("GLOO_SOCKET_IFNAME", "lo"),
        ("NCCL_SOCKET_IFNAME", "lo"),
        ("TORCH_DISTRIBUTED_DEBUG", "OFF"),
    )
    for variable, default in env_defaults:
        os.environ.setdefault(variable, default)

    ensure_hf_endpoint_configured(verbose=verbose)

    if not verbose:
        return
    print(f" Random seed: {seed}")
    print(f" Embedding model: {BENCHMARK_EMBEDDING_MODEL}")
    print(f" LLM temperature: {BENCHMARK_LLM_TEMPERATURE}")
132
+
133
+
134
+ # =============================================================================
135
+ # 数据加载
136
+ # =============================================================================
137
def load_benchmark_data(
    challenge: str,
    split: str = "test",
    max_samples: Optional[int] = None,
    data_dir: Optional[Path] = None,
) -> list[dict]:
    """Load the samples of one benchmark challenge from its JSONL file.

    The file read is ``<data_dir>/<challenge directory>/<split>.jsonl``.
    Blank lines are skipped and every other line is parsed as JSON.

    Args:
        challenge: Challenge type: "timing", "planning" or "selection".
        split: Dataset split; used as the file stem (e.g. "train", "dev",
            "test").
        max_samples: Maximum number of samples to return. ``None`` or a
            non-positive value returns every sample.
        data_dir: Data directory (a ``Path`` or a plain string path).
            ``None`` uses the module default.

    Returns:
        The parsed samples, or an empty list when the data file does not
        exist (a warning is printed in that case).

    Raises:
        ValueError: If ``challenge`` is not one of the known challenges.
    """
    # Path() lets callers pass a plain string; a Path is passed through.
    root = Path(data_dir or DEFAULT_DATA_DIR)

    # Map each challenge to its data sub-directory.
    challenge_dirs = {
        "timing": "timing_judgment",
        "planning": "task_planning",
        "selection": "tool_selection",
    }

    if challenge not in challenge_dirs:
        raise ValueError(f"Unknown challenge: {challenge}. Use: {list(challenge_dirs.keys())}")

    data_file = root / challenge_dirs[challenge] / f"{split}.jsonl"

    if not data_file.exists():
        print(f" Warning: Data file not found: {data_file}")
        return []

    # A non-positive cap means "no cap", matching the historical behaviour.
    limit = max_samples if max_samples is not None and max_samples > 0 else None

    samples: list[dict] = []
    with open(data_file, encoding="utf-8") as f:
        for line in f:
            # Stop as soon as the cap is reached instead of parsing the
            # remainder of the file only to discard it.
            if limit is not None and len(samples) >= limit:
                break
            if line.strip():
                samples.append(json.loads(line))

    return samples
183
+
184
+
185
def load_jsonl(file_path: Path) -> list[dict]:
    """Load a JSONL file into a list of parsed objects.

    Args:
        file_path: Location of the file (a ``Path`` or a plain string path).

    Returns:
        One parsed JSON value per non-blank line, in file order; an empty
        list when the file does not exist.
    """
    # Path() lets callers pass a plain string; a Path is passed through.
    path = Path(file_path)
    if not path.exists():
        return []

    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
196
+
197
+
198
+ # =============================================================================
199
+ # 结果保存
200
+ # =============================================================================
201
def save_results(
    results: dict[str, Any],
    section: str,
    name: str,
    output_dir: Optional[Path] = None,
) -> Path:
    """Write experiment results, plus run metadata, to a JSON file.

    The file is ``<output_dir>/section_<section>/<name>_results.json``; the
    section directory is created if necessary. A ``"_metadata"`` entry
    (timestamp, seed, embedding model, LLM temperature) is stored alongside
    the results in the written file. The caller's ``results`` dict is left
    untouched.

    Args:
        results: Result dictionary to serialize.
        section: Paper section (e.g. "5_2_main", "5_3_analysis",
            "5_4_generalization").
        name: Result name (e.g. "timing", "error_analysis").
        output_dir: Output directory (a ``Path`` or a plain string path).
            ``None`` uses the module default.

    Returns:
        Path of the file that was written.
    """
    # Path() lets callers pass a plain string; a Path is passed through.
    section_dir = Path(output_dir or DEFAULT_OUTPUT_DIR) / f"section_{section}"
    section_dir.mkdir(parents=True, exist_ok=True)

    # Build a new mapping rather than inserting into the argument, so that
    # saving does not silently modify the caller's data.
    payload = {
        **results,
        "_metadata": {
            "timestamp": datetime.now().isoformat(),
            "seed": RANDOM_SEED,
            "embedding_model": BENCHMARK_EMBEDDING_MODEL,
            "llm_temperature": BENCHMARK_LLM_TEMPERATURE,
        },
    }

    output_file = section_dir / f"{name}_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    return output_file
239
+
240
+
241
def get_output_dir(section: str, output_dir: Optional[Path] = None) -> Path:
    """Return the output directory of a paper section, creating it if needed.

    Args:
        section: Section identifier; the directory is ``section_<section>``.
        output_dir: Base directory; the module default is used when omitted.

    Returns:
        The (now existing) section directory.
    """
    base = output_dir if output_dir else DEFAULT_OUTPUT_DIR
    target = base / f"section_{section}"
    target.mkdir(parents=True, exist_ok=True)
    return target
247
+
248
+
249
def get_figures_dir(output_dir: Optional[Path] = None) -> Path:
    """Return the ``figures`` directory, creating it if needed.

    Args:
        output_dir: Base directory; the module default is used when omitted.

    Returns:
        The (now existing) ``figures`` directory under the base directory.
    """
    base = output_dir if output_dir else DEFAULT_OUTPUT_DIR
    target = base / "figures"
    target.mkdir(parents=True, exist_ok=True)
    return target
255
+
256
+
257
def get_tables_dir(output_dir: Optional[Path] = None) -> Path:
    """Return the ``tables`` directory, creating it if needed.

    Args:
        output_dir: Base directory; the module default is used when omitted.

    Returns:
        The (now existing) ``tables`` directory under the base directory.
    """
    base = output_dir if output_dir else DEFAULT_OUTPUT_DIR
    target = base / "tables"
    target.mkdir(parents=True, exist_ok=True)
    return target
263
+
264
+
265
+ # =============================================================================
266
+ # 客户端获取
267
+ # =============================================================================
268
def get_llm_client():
    """Create the unified LLM client.

    Returns:
        The object returned by ``UnifiedInferenceClient.create()``, or
        ``None`` (after printing a warning) when an ``ImportError`` is
        raised while importing or creating the client.
    """
    try:
        from sage.llm import UnifiedInferenceClient as client_cls

        client = client_cls.create()
        return client
    except ImportError as exc:
        print(f" Warning: Could not create LLM client: {exc}")
        return None
282
+
283
+
284
def get_embedding_client():
    """Create the embedding client for the benchmark embedding model.

    Returns:
        An ``EmbeddingClientAdapter`` wrapping an embedder created through
        ``EmbeddingFactory.create("hf", ...)``, or ``None`` (after printing
        a warning) when an ``ImportError`` is raised along the way.
    """
    try:
        from sage.common.components import sage_embedding as embedding_pkg

        # Resolve both names before building anything, as a from-import would.
        adapter_cls = embedding_pkg.EmbeddingClientAdapter
        factory = embedding_pkg.EmbeddingFactory

        embedder = factory.create("hf", model=BENCHMARK_EMBEDDING_MODEL)
        return adapter_cls(embedder)
    except ImportError as exc:
        print(f" Warning: Could not create embedding client: {exc}")
        return None
302
+
303
+
304
+ # =============================================================================
305
+ # 进度显示
306
+ # =============================================================================
307
def create_progress_bar(total: int, desc: str = "Processing"):
    """Create a progress indicator.

    Args:
        total: Total number of steps.
        desc: Label shown with the progress output.

    Returns:
        A ``tqdm`` bar (80 columns wide) when ``tqdm`` can be imported;
        otherwise a minimal stand-in offering ``update``, ``close`` and
        context-manager support that prints progress lines.
    """
    try:
        from tqdm import tqdm

        return tqdm(total=total, desc=desc, ncols=80)
    except ImportError:

        class SimpleProgress:
            """Print-based fallback used when tqdm is unavailable."""

            def __init__(self, total, desc):
                self.total = total
                self.current = 0
                self.desc = desc

            def update(self, n=1):
                """Advance by ``n``; report when the count hits a multiple of a tenth of the total."""
                self.current += n
                step = max(1, self.total // 10)
                if self.current % step == 0:
                    print(f" {self.desc}: {self.current}/{self.total}")

            def close(self):
                """Print the completion line."""
                print(f" {self.desc}: Complete ({self.total})")

            def __enter__(self):
                return self

            def __exit__(self, *exc_info):
                self.close()

        return SimpleProgress(total, desc)
345
+
346
+
347
+ # =============================================================================
348
+ # 实验结果数据类
349
+ # =============================================================================
350
+ from dataclasses import dataclass, field
351
+
352
+
353
@dataclass
class ExperimentResult:
    """Experiment result of a single strategy."""

    # Required fields (positional order is part of the constructor).
    challenge: str
    strategy: str
    metrics: dict[str, float]

    # Optional fields; metadata gets a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
    passed: bool = False
    target: float = 0.0
363
+
364
+
365
@dataclass
class ExperimentSummary:
    """Summary of one experiment: per-strategy results plus the best outcome."""

    section: str
    challenge: str
    results: list[ExperimentResult] = field(default_factory=list)
    best_strategy: Optional[str] = None
    best_metric: Optional[float] = None
    target_met: bool = False

    def to_dict(self) -> dict:
        """Convert the summary to a plain dictionary.

        Each result contributes its ``strategy``, ``metrics``, ``passed``
        and ``target`` values; other result attributes are not included.
        """
        rows = []
        for item in self.results:
            row = {
                "strategy": item.strategy,
                "metrics": item.metrics,
                "passed": item.passed,
                "target": item.target,
            }
            rows.append(row)

        summary = {
            "section": self.section,
            "challenge": self.challenge,
            "results": rows,
            "best_strategy": self.best_strategy,
            "best_metric": self.best_metric,
            "target_met": self.target_met,
        }
        return summary
394
+
395
+
396
+ # =============================================================================
397
+ # 打印工具
398
+ # =============================================================================
399
def print_section_header(title: str, width: int = 70) -> None:
    """Print a section title framed by two rules of ``=`` characters.

    Args:
        title: Text of the section title.
        width: Number of ``=`` characters in each rule.
    """
    rule = "=" * width
    print("\n" + rule)
    print(f"📊 {title}")
    print(rule)
404
+
405
+
406
def print_subsection_header(title: str) -> None:
    """Print a subsection title followed by a 50-character dashed rule.

    Args:
        title: Text of the subsection title.
    """
    dashes = "-" * 50
    print(f"\n ▸ {title}")
    print(" " + dashes)
410
+
411
+
412
def print_result_row(strategy: str, metrics: dict, passed: bool, target: float) -> None:
    """Print a one-line result for a strategy.

    The first value in ``metrics`` (0.0 when ``metrics`` is empty) and
    ``target`` are shown as percentages, followed by a pass/fail marker.

    Args:
        strategy: Strategy name, padded to 20 characters.
        metrics: Metric values; only the first one is displayed.
        passed: Selects the PASS or FAIL marker.
        target: Target value, shown as a whole-number percentage.
    """
    if metrics:
        headline = next(iter(metrics.values()))
    else:
        headline = 0.0

    if passed:
        verdict = "✅ PASS"
    else:
        verdict = "❌ FAIL"

    line = f" {strategy:20s} | {headline * 100:6.1f}% (target: {target * 100:.0f}%) {verdict}"
    print(line)
419
+
420
+
421
def print_metrics_detail(metrics: dict) -> None:
    """Print every metric on its own line.

    ``float`` values are shown as percentages with one decimal place; all
    other values are printed unchanged.

    Args:
        metrics: Mapping of metric name to value.
    """
    for label in metrics:
        value = metrics[label]
        shown = f"{value * 100:.1f}%" if isinstance(value, float) else f"{value}"
        print(f" - {label}: {shown}")