isage-refiner-benchmark 0.1.0.1__cp311-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,62 @@
1
+ Metadata-Version: 2.4
2
+ Name: isage-refiner-benchmark
3
+ Version: 0.1.0.1
4
+ Summary: Context compression and refiner algorithm benchmark framework for the SAGE ecosystem
5
+ Author-email: IntelliStream Team <shuhao_zhang@hust.edu.cn>
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/intellistream/sage-refiner-benchmark
8
+ Project-URL: Documentation, https://github.com/intellistream/sage-refiner-benchmark#readme
9
+ Project-URL: Repository, https://github.com/intellistream/sage-refiner-benchmark
10
+ Project-URL: Issues, https://github.com/intellistream/sage-refiner-benchmark/issues
11
+ Keywords: sage,benchmark,refiner,context-compression,evaluation,intellistream
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3 :: Only
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.11
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: isage-common
24
+ Requires-Dist: isage-kernel
25
+ Requires-Dist: isage-libs
26
+ Requires-Dist: datasets>=2.14.0
27
+ Provides-Extra: refiner
28
+ Requires-Dist: jieba>=0.42; extra == "refiner"
29
+ Requires-Dist: fuzzywuzzy>=0.18; extra == "refiner"
30
+ Requires-Dist: python-Levenshtein>=0.12; extra == "refiner"
31
+ Requires-Dist: rouge>=1.0.0; extra == "refiner"
32
+ Provides-Extra: dev
33
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
34
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
35
+ Requires-Dist: black>=23.0.0; extra == "dev"
36
+ Requires-Dist: ruff==0.14.6; extra == "dev"
37
+ Requires-Dist: pre-commit>=3.0.0; extra == "dev"
38
+ Requires-Dist: mypy>=1.0.0; extra == "dev"
39
+ Dynamic: license-file
40
+
41
+ # SAGE Refiner Benchmark
42
+
43
+ **Context Compression & Refiner Benchmark Framework** - Evaluates multiple context compression algorithms including LongRefiner, REFORM, Provence, and more.
46
+
47
+ ## Install
48
+
49
+ ```bash
50
+ pip install isage-refiner-benchmark
51
+ ```
52
+
53
+ ## Features
54
+
55
+ - 📊 Support for multiple context compression algorithms
56
+ - ⚡ Standardized evaluation metrics
57
+ - 🔧 Flexible configuration system
58
+ - 📈 Detailed performance analysis reports
59
+
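+ ## Usage
+
+ A minimal pipeline sketch, assuming the SAGE pipeline API shown in the module docstring of `sage.benchmark_refiner`; `env` and `OpenAIGenerator` are illustrative placeholders taken from that example, not objects defined by this package:
+
+ ```python
+ from sage.benchmark_refiner import LongBenchBatch, LongBenchEvaluator, LongBenchPromptor
+
+ config = {
+     "source": {"hf_dataset_name": "THUDM/LongBench", "hf_dataset_config": "hotpotqa", "max_samples": 100},
+     "promptor": {"max_input_tokens": 120000, "is_chat_model": True, "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct"},
+     "generator": {},  # settings for whichever Generator the pipeline uses
+     "evaluate": {"output_path": "results/longbench_results.jsonl", "model_name": "Qwen/Qwen2.5-7B-Instruct"},
+ }
+
+ # Wire the benchmark stages into a SAGE pipeline.
+ (env.from_batch(LongBenchBatch, config["source"])
+     .map(LongBenchPromptor, config["promptor"])
+     .map(OpenAIGenerator, config["generator"])
+     .map(LongBenchEvaluator, config["evaluate"]))
+ ```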
60
+ ## Repository
61
+
62
+ - GitHub: https://github.com/intellistream/sage-refiner-benchmark
@@ -0,0 +1,14 @@
1
+ isage_refiner_benchmark-0.1.0.1.dist-info/licenses/LICENSE,sha256=vBNVIGkYYZY0B8f0Ui1ITYwRu7WNtSwyxvIAVGYS6jU,1075
2
+ sage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sage/benchmark_refiner/__init__.py,sha256=ps5uQv6Es7E0b53hUvpyOWGQE8b22z76vVPGBzIWanU,1812
4
+ sage/benchmark_refiner/_version.py,sha256=jsH11gEb0e-Ciggsva1C1QDkTp82TcG33ZVQp2QhXB0,153
5
+ sage/benchmark_refiner/batch.py,sha256=V_-b65827BauiGmYK97WVtYdVCdLEE7MMwlbEGXNIbc,5126
6
+ sage/benchmark_refiner/constants.py,sha256=Rfs7JO34iqcqVqwADH1IMfnb_kbxixBQnLzwJknswp4,1467
7
+ sage/benchmark_refiner/evaluator.py,sha256=cJExCQQDnc7zilJyBuLdZSIrNz1e-mldHD0UH3DOA_M,12846
8
+ sage/benchmark_refiner/metrics.py,sha256=_HtI5ERi4P4pGVogQOFhLW_BBYPg171ThnMhGn7GdVE,7274
9
+ sage/benchmark_refiner/promptor.py,sha256=rz6hhYTT0HaxtHQuFo-tOb0sZBL9B1CTPpoM__wMgWA,7690
10
+ sage/benchmark_refiner/utils.py,sha256=lo7NBY_1keB6Mczel2RY1vgeVM3MXZetP_hj2nvgHS4,3044
11
+ isage_refiner_benchmark-0.1.0.1.dist-info/METADATA,sha256=1arb06fh6yaHhP_HpEF7bh1iaYpKe55ZslVS-6jOKUA,2428
12
+ isage_refiner_benchmark-0.1.0.1.dist-info/WHEEL,sha256=yk-B4c9kYsinhQ_MzhPAVcDm9mhkAVmdo0rg0jgFCmo,94
13
+ isage_refiner_benchmark-0.1.0.1.dist-info/top_level.txt,sha256=hibFyzQHiLOMK68qL1OWsNKaXOmSXqZjeLTBem6Yy7I,5
14
+ isage_refiner_benchmark-0.1.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.10.2)
3
+ Root-Is-Purelib: true
4
+ Tag: cp311-none-any
5
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 IntelliStream Team
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
sage/__init__.py ADDED
File without changes
@@ -0,0 +1,65 @@
1
+ """
2
+ Context Compression & Refiner Benchmark Module
3
+ ===============================================
4
+
5
+ Context compression and Refiner algorithm evaluation components.
6
+
7
+ Supported algorithms / datasets:
8
+ - LongBench (THUDM/LongBench) - long-context understanding benchmark
9
+ - LongRefiner - context compression
10
+ - REFORM - retrieval optimization
11
+ - Provence - context pruning
12
+ - other custom Refiner implementations
13
+
14
+ Components:
15
+ - LongBenchBatch: loads data from THUDM/LongBench
16
+ - LongBenchPromptor: builds the official LongBench prompts
17
+ - LongBenchEvaluator: official LongBench evaluation metrics
18
+
19
+ Usage example:
20
+ from sage.benchmark_refiner import (
21
+ LongBenchBatch,
22
+ LongBenchPromptor,
23
+ LongBenchEvaluator,
24
+ )
25
+
26
+ # use within a pipeline
27
+ env.from_batch(LongBenchBatch, config["source"])
28
+ .map(LongBenchPromptor, config["promptor"])
29
+ .map(OpenAIGenerator, config["generator"])
30
+ .map(LongBenchEvaluator, config["evaluate"])
31
+
32
+ Configuration example:
33
+ source:
34
+ hf_dataset_name: "THUDM/LongBench"
35
+ hf_dataset_config: "hotpotqa"
36
+ max_samples: 100
37
+
38
+ promptor:
39
+ max_input_tokens: 120000
40
+ is_chat_model: true
41
+ model_name_or_path: "Qwen/Qwen2.5-7B-Instruct"
42
+
43
+ evaluate:
44
+ longbench_e_buckets: false
45
+ output_path: "results/longbench_results.jsonl"
46
+ model_name: "Qwen/Qwen2.5-7B-Instruct"
47
+
48
+ Installing dependencies:
49
+ pip install isage-refiner-benchmark[refiner]
50
+
51
+ # Optional dependency notes:
52
+ # - jieba: Chinese word segmentation (needed to evaluate Chinese datasets)
53
+ # - fuzzywuzzy + python-Levenshtein: code similarity (needed for lcc, repobench-p)
54
+ # - rouge: ROUGE-L score (needed for summarization tasks)
55
+ """
56
+
57
+ from .batch import LongBenchBatch
58
+ from .evaluator import LongBenchEvaluator
59
+ from .promptor import LongBenchPromptor
60
+
61
+ __all__ = [
62
+ "LongBenchBatch",
63
+ "LongBenchPromptor",
64
+ "LongBenchEvaluator",
65
+ ]
@@ -0,0 +1,4 @@
1
+ """Version information for sage-longbench-benchmark."""
2
+ __version__ = "0.1.0.0"
3
+ __author__ = "IntelliStream Team"
4
+ __email__ = "shuhao_zhang@hust.edu.cn"
@@ -0,0 +1,123 @@
1
+ """
2
+ LongBench Batch Processing
3
+ ==========================
4
+
5
+ LongBench dataset batch-processing function: loads data from THUDM/LongBench and converts it into the SAGE standard format.
6
+
7
+ Migrated from sage-libs/foundation/io/batch.py, following the SAGE architecture:
8
+ - benchmark-related components live together in sage-benchmark (L5)
9
+ """
10
+
11
+ from typing import Any
12
+
13
+ from sage.libs.foundation.io import HFDatasetBatch
14
+
15
+
16
+ class LongBenchBatch(HFDatasetBatch):
17
+ """
18
+ Batch-processing function for the LongBench dataset
19
+
20
+ Built specifically for the LongBench long-context understanding benchmark. Field mapping:
21
+ - input → query (user question)
22
+ - context → context (long-document context shipped with LongBench, no retrieval needed)
23
+ - answers → references (list of reference answers)
24
+ - all_classes → all_classes (class list for classification tasks)
25
+ - length → length (original text length, used for LongBench-E bucket evaluation)
26
+
27
+ **Alignment with the SAGE RAG pipeline**:
28
+
29
+ Standard SAGE RAG data flow:
30
+ - query: user question
31
+ - references: reference answers (for evaluation)
32
+ - retrieval_results: retrieved documents, List[Dict] (Retriever output)
33
+ - refining_results: compressed documents, List[str] (Refiner output)
34
+ - context: context string or list (read by the Promptor)
35
+ - generated: generated answer (Generator output)
36
+
37
+ LongBench-specific handling:
38
+ - LongBench ships its own context, so the Retriever stage is skipped
39
+ - context is passed straight through as the `context` field for the Promptor
40
+ - `retrieval_results` is set to an empty list (no retrieval performed)
41
+
42
+ Input: None (reads directly from the HF dataset)
43
+ Output: SAGE RAGResponse-compatible format plus LongBench-specific fields
44
+
45
+ Config Keys (inherited from HFDatasetBatch):
46
+ hf_dataset_name: str - fixed to "THUDM/LongBench"
47
+ hf_dataset_config: str - e.g. "multi_news", "hotpotqa", "multi_news_e"
48
+ hf_split: str - defaults to "test"
49
+ max_samples: int - maximum number of samples
50
+
51
+ Output Fields (standard SAGE RAG fields):
52
+ query: str - user question (from LongBench input)
53
+ references: List[str] - reference answers (from LongBench answers, used for evaluation)
54
+ context: str - long-document context (from LongBench context, consumed by the Promptor)
55
+ retrieval_results: List[Dict] - empty list (LongBench skips retrieval)
56
+
57
+ Output Fields (LongBench-specific fields):
58
+ all_classes: List[str] | None - classification classes (trec, lsht, etc.)
59
+ length: int - original text length (for LongBench-E bucket evaluation)
60
+ _dataset: str - dataset name (used to select the evaluation metric)
61
+ _is_longbench_e: bool - whether this is the LongBench-E variant
62
+ """
63
+
64
+ def __init__(self, config: dict | None = None, **kwargs):
65
+ super().__init__(config, **kwargs)
66
+ # Parse the dataset name and whether this is LongBench-E
67
+ self._dataset_name = self._parse_dataset_name()
68
+ self._is_longbench_e = self._check_longbench_e()
69
+
70
+ def _parse_dataset_name(self) -> str:
71
+ """从 hf_dataset_config 解析数据集名称(去除 _e 后缀)"""
72
+ config_name = self.hf_config or ""
73
+ if config_name.endswith("_e"):
74
+ return config_name[:-2]  # strip the _e suffix
75
+ return config_name
76
+
77
+ def _check_longbench_e(self) -> bool:
78
+ """检查是否是 LongBench-E 版本"""
79
+ config_name = self.hf_config or ""
80
+ return config_name.endswith("_e")
81
+
82
+ def _build_iter(self):
83
+ """构建 LongBench 数据集迭代器,重写父类方法"""
84
+ try:
85
+ from datasets import load_dataset
86
+ except ImportError:
87
+ raise ImportError(
88
+ "datasets library is required for LongBenchBatch. "
89
+ "Install with: pip install datasets"
90
+ )
91
+
92
+ ds = load_dataset(self.hf_name, self.hf_config, split=self.hf_split, streaming=True)
93
+ for ex in ds:
94
+ if isinstance(ex, dict):
95
+ yield self._transform_example(ex)
96
+
97
+ def _transform_example(self, ex: dict[str, Any]) -> dict[str, Any]:
98
+ """
99
+ Map LongBench fields onto the standard SAGE RAG pipeline format.
100
+
101
+ LongBench source fields → standard SAGE RAG fields:
102
+ - input → query (user question)
103
+ - context → context (context string, consumed by the Promptor)
104
+ - answers → references (reference answers, used by the Evaluator)
105
+
106
+ LongBench-specific fields kept as-is:
107
+ - all_classes (classification classes)
108
+ - length (original length, for LongBench-E buckets)
109
+ """
110
+ return {
111
+ # ========== Standard SAGE RAG pipeline fields ==========
112
+ "query": ex.get("input", ""),
113
+ "references": ex.get("answers") or [],
114
+ "context": ex.get("context", ""),
115
+ # Empty list means the retrieval stage is skipped (LongBench ships its own context)
116
+ "retrieval_results": [],
117
+ # ========== LongBench-specific fields ==========
118
+ "all_classes": ex.get("all_classes"),
119
+ "length": ex.get("length", 0),
120
+ # ========== Internal metadata (underscore prefix, passed along the pipeline) ==========
121
+ "_dataset": self._dataset_name,
122
+ "_is_longbench_e": self._is_longbench_e,
123
+ }
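+
+ # Illustrative sketch (not part of the original module): with a config such as
+ #   {"hf_dataset_name": "THUDM/LongBench", "hf_dataset_config": "hotpotqa_e", "max_samples": 10}
+ # _parse_dataset_name() returns "hotpotqa", _check_longbench_e() returns True, and
+ # _transform_example({"input": "Who ...?", "context": "...", "answers": ["X"], "length": 5321})
+ # yields a record of the form
+ #   {"query": "Who ...?", "references": ["X"], "context": "...", "retrieval_results": [],
+ #    "all_classes": None, "length": 5321, "_dataset": "hotpotqa", "_is_longbench_e": True}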
@@ -0,0 +1,47 @@
1
+ """
2
+ LongBench constant definitions
3
+
4
+ Dataset-to-metric mapping (copied from the official eval.py)
5
+ """
6
+
7
+ # Dataset-to-metric mapping (from the official eval.py)
8
+ DATASET_TO_METRIC: dict[str, str] = {
9
+ # QA tasks - F1 score
10
+ "narrativeqa": "qa_f1",
11
+ "qasper": "qa_f1",
12
+ "multifieldqa_en": "qa_f1",
13
+ "hotpotqa": "qa_f1",
14
+ "2wikimqa": "qa_f1",
15
+ "musique": "qa_f1",
16
+ "triviaqa": "qa_f1",
17
+ # Chinese QA (requires jieba segmentation)
18
+ "multifieldqa_zh": "qa_f1_zh",
19
+ # Summarization tasks - ROUGE score
20
+ "gov_report": "rouge",
21
+ "qmsum": "rouge",
22
+ "multi_news": "rouge",
23
+ "samsum": "rouge",
24
+ # Chinese summarization (requires jieba segmentation)
25
+ "dureader": "rouge_zh",
26
+ "vcsum": "rouge_zh",
27
+ # Classification tasks
28
+ "trec": "classification",
29
+ "lsht": "classification",
30
+ # Retrieval tasks
31
+ "passage_retrieval_en": "retrieval",
32
+ "passage_retrieval_zh": "retrieval_zh",
33
+ "passage_count": "count",
34
+ # Code tasks (require fuzzywuzzy)
35
+ "lcc": "code_sim",
36
+ "repobench-p": "code_sim",
37
+ }
38
+
39
+ # Datasets where only the first line of the prediction is scored (copied from the official eval.py)
40
+ FIRST_LINE_DATASETS: set[str] = {"trec", "triviaqa", "samsum", "lsht"}
41
+
42
+ # Datasets that do not use a chat template (copied from the official pred.py comment)
43
+ # chat models are better off without build prompts on these tasks
44
+ NO_CHAT_DATASETS: set[str] = {"trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"}
45
+
46
+ # List of supported datasets
47
+ SUPPORTED_DATASETS: set[str] = set(DATASET_TO_METRIC.keys())
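+
+ # Illustrative lookups (follow directly from the mappings above):
+ #   DATASET_TO_METRIC["hotpotqa"]   -> "qa_f1"
+ #   DATASET_TO_METRIC["multi_news"] -> "rouge"
+ #   "trec" in FIRST_LINE_DATASETS   -> True (only the first line of the prediction is scored)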
@@ -0,0 +1,339 @@
1
+ """
2
+ LongBench Evaluator - optimized version
3
+
4
+ Improvements:
5
+ 1. Model-specific post-processing (copied from the official post_process)
6
+ 2. Result persistence (written to JSONL, one record at a time)
7
+ 3. Compressed-text tracking
8
+ 4. Timing statistics
9
+ """
10
+
11
+ import json
12
+ from pathlib import Path
13
+ from typing import Any, Optional, TextIO
14
+
15
+ from sage.common.core import StopSignal
16
+ from sage.common.core.functions import MapFunction as MapOperator
17
+
18
+ from .constants import DATASET_TO_METRIC, FIRST_LINE_DATASETS
19
+ from .metrics import METRIC_FUNCTIONS
20
+ from .utils import post_process
21
+
22
+
23
+ class LongBenchEvaluator(MapOperator):
24
+ """
25
+ LongBench-specific evaluator.
26
+
27
+ Features:
28
+ 1. Automatically selects the evaluation metric based on the dataset
29
+ 2. Supports both the standard single score and LongBench-E length buckets
30
+ 3. Integrates all LongBench metric functions
31
+ 4. Post-processes predictions (some datasets keep only the first line)
32
+ 5. Model-specific post-processing (copied from the official post_process)
33
+ 6. Result persistence (JSONL format)
34
+
35
+ Input data format (from the Generator):
36
+ {
37
+ "query": str,
38
+ "generated": str, # answer generated by the model
39
+ "references": List[str], # list of reference answers
40
+ "_dataset": str, # dataset name
41
+ "all_classes": List[str], # classification classes (optional)
42
+ "length": int, # original length (for LongBench-E buckets)
43
+ }
44
+
45
+ Configuration parameters:
46
+ - longbench_e_buckets: bool - whether to report LongBench-E bucket scores
47
+ - output_path: str | None - path for saving results (JSONL format)
48
+ - model_name: str | None - model name (used for post-processing)
49
+ """
50
+
51
+ def __init__(self, config: Optional[dict[str, Any]] = None, **kwargs: Any) -> None:
52
+ super().__init__(**kwargs)
53
+ self.config = config or {}
54
+ self.longbench_e_buckets: bool = self.config.get("longbench_e_buckets", False)
55
+ self.model_name: str = self.config.get("model_name", "")
56
+
57
+ # Result persistence
58
+ self._output_path: Optional[str] = self.config.get("output_path")
59
+ self._output_file: Optional[TextIO] = None
60
+ if self._output_path:
61
+ Path(self._output_path).parent.mkdir(parents=True, exist_ok=True)
62
+ self._output_file = open(self._output_path, "a", encoding="utf-8")
63
+
64
+ # Score collectors (used to compute averages)
65
+ self._scores: list[float] = []
66
+ self._dataset_scores: dict[str, list[float]] = {}
67
+
68
+ # LongBench-E bucket scores
69
+ self._bucket_scores: dict[str, list[float]] = {
70
+ "0-4k": [],
71
+ "4-8k": [],
72
+ "8k+": [],
73
+ }
74
+
75
+ # Timing collectors
76
+ self._refine_times: list[float] = []
77
+ self._generate_times: list[float] = []
78
+ self._retrieve_times: list[float] = []
79
+
80
+ def _post_process_prediction(self, pred: str, dataset: str) -> str:
81
+ """预测结果后处理"""
82
+ # 1. Model-specific post-processing (copied from the official code)
83
+ if self.model_name:
84
+ pred = post_process(pred, self.model_name)
85
+
86
+ # 2. Dataset-specific post-processing: keep only the first line (copied from the official eval.py)
87
+ if dataset in FIRST_LINE_DATASETS:
88
+ pred = pred.lstrip("\n").split("\n")[0]
89
+
90
+ return pred
91
+
92
+ def _get_length_bucket(self, length: int) -> str:
93
+ """根据长度获取分桶名称"""
94
+ if length < 4000:
95
+ return "0-4k"
96
+ elif length < 8000:
97
+ return "4-8k"
98
+ return "8k+"
99
+
100
+ def _compute_score(
101
+ self,
102
+ pred: str,
103
+ ground_truths: list[str],
104
+ dataset: str,
105
+ all_classes: Optional[list[str]] = None,
106
+ ) -> float:
107
+ """计算单个样本的分数"""
108
+ # Look up the metric type
109
+ metric_type = DATASET_TO_METRIC.get(dataset, "qa_f1")
110
+ metric_fn = METRIC_FUNCTIONS.get(metric_type)
111
+
112
+ if metric_fn is None:
113
+ self.logger.warning(f"Unknown metric type: {metric_type}")
114
+ return 0.0
115
+
116
+ # Post-process the prediction
117
+ pred = self._post_process_prediction(pred, dataset)
118
+
119
+ # Score against every reference answer and keep the best (copied from the official eval.py scorer)
120
+ best_score = 0.0
121
+ for ground_truth in ground_truths:
122
+ try:
123
+ score = metric_fn(pred, ground_truth, all_classes=all_classes or [])
124
+ best_score = max(best_score, score)
125
+ except Exception as e:
126
+ self.logger.warning(f"Error computing score for {dataset}: {e}")
127
+
128
+ return best_score
129
+
130
+ def _save_result(self, data: dict[str, Any], score: float, dataset: str) -> None:
131
+ """保存单条结果到 JSONL 文件"""
132
+ if not self._output_file:
133
+ return
134
+
135
+ result: dict[str, Any] = {
136
+ "pred": data.get("generated", ""),
137
+ "answers": data.get("references", []),
138
+ "score": score,
139
+ "dataset": dataset,
140
+ "all_classes": data.get("all_classes"),
141
+ "length": data.get("length", 0),
142
+ }
143
+
144
+ # Save the compressed context (if present)
145
+ refining_results = data.get("refining_results", [])
146
+ if refining_results:
147
+ if isinstance(refining_results, list):
148
+ result["compressed_text"] = "\n\n".join(refining_results)
149
+ else:
150
+ result["compressed_text"] = str(refining_results)
151
+
152
+ # Save timing data
153
+ times: dict[str, float] = {}
154
+ if "retrieve_time" in data:
155
+ times["retrieve"] = data["retrieve_time"]
156
+ if "refine_time" in data:
157
+ times["refine"] = data["refine_time"]
158
+ if "generate_time" in data:
159
+ times["generate"] = data["generate_time"]
160
+ if times:
161
+ result["times"] = times
162
+
163
+ self._output_file.write(json.dumps(result, ensure_ascii=False) + "\n")
164
+ self._output_file.flush()
165
+
166
+ def execute(self, data: Any) -> Any:
167
+ """执行评估"""
168
+ # Handle StopSignal - print summary statistics
169
+ if isinstance(data, StopSignal):
170
+ self._print_summary()
171
+ if self._output_file:
172
+ self._output_file.close()
173
+ return data
174
+
175
+ # Extract the required fields
176
+ dataset: str = data.get("_dataset", "unknown")
177
+ pred: str = data.get("generated", "")
178
+ references: list[str] = data.get("references", [])
179
+ all_classes: Optional[list[str]] = data.get("all_classes")
180
+ length: int = data.get("length", 0)
181
+
182
+ # Compute the score
183
+ score = self._compute_score(pred, references, dataset, all_classes)
184
+
185
+ # Score * 100 (consistent with the original LongBench)
186
+ score_percent = round(score * 100, 2)
187
+
188
+ # Collect scores
189
+ self._scores.append(score)
190
+ if dataset not in self._dataset_scores:
191
+ self._dataset_scores[dataset] = []
192
+ self._dataset_scores[dataset].append(score)
193
+
194
+ # LongBench-E bucketing
195
+ if self.longbench_e_buckets and length > 0:
196
+ bucket = self._get_length_bucket(length)
197
+ self._bucket_scores[bucket].append(score)
198
+
199
+ # Collect timing data (added automatically by MapOperator)
200
+ if "refine_time" in data:
201
+ self._refine_times.append(data["refine_time"])
202
+ if "generate_time" in data:
203
+ self._generate_times.append(data["generate_time"])
204
+ if "retrieve_time" in data:
205
+ self._retrieve_times.append(data["retrieve_time"])
206
+
207
+ # Persist the result to JSONL
208
+ self._save_result(data, score, dataset)
209
+
210
+ # Print the per-sample score and timing
211
+ metric_type = DATASET_TO_METRIC.get(dataset, "qa_f1")
212
+ total_time = (
213
+ data.get("retrieve_time", 0) + data.get("refine_time", 0) + data.get("generate_time", 0)
214
+ )
215
+ time_str = f" (time={total_time:.3f}s)" if total_time > 0 else ""
216
+ print(f"\033[92m[LongBench {dataset}] {metric_type}: {score_percent}{time_str}\033[0m")
217
+
218
+ # Attach the scores to the data record
219
+ data["longbench_score"] = score
220
+ data["longbench_score_percent"] = score_percent
221
+ data["longbench_metric"] = metric_type
222
+
223
+ return data
224
+
225
+ def _print_summary(self) -> None:
226
+ """打印汇总统计"""
227
+ if not self._scores:
228
+ print("\n" + "=" * 80)
229
+ print("No LongBench samples processed")
230
+ print("=" * 80)
231
+ return
232
+
233
+ print("\n" + "=" * 80)
234
+ print(f"LONGBENCH EVALUATION SUMMARY ({len(self._scores)} samples)")
235
+ print("=" * 80)
236
+
237
+ # Overall average score
238
+ avg_score = sum(self._scores) / len(self._scores) * 100
239
+ print(f"\033[92m[Overall Average Score]: {avg_score:.2f}\033[0m")
240
+
241
+ # Per-dataset average scores
242
+ if self._dataset_scores:
243
+ print("\n--- Per-Dataset Scores ---")
244
+ for dataset, scores in sorted(self._dataset_scores.items()):
245
+ avg = sum(scores) / len(scores) * 100
246
+ metric_type = DATASET_TO_METRIC.get(dataset, "qa_f1")
247
+ print(f" {dataset} ({metric_type}): {avg:.2f} ({len(scores)} samples)")
248
+
249
+ # LongBench-E bucket scores
250
+ if self.longbench_e_buckets:
251
+ print("\n--- LongBench-E Length Buckets ---")
252
+ for bucket, scores in self._bucket_scores.items():
253
+ if scores:
254
+ avg = sum(scores) / len(scores) * 100
255
+ print(f" {bucket}: {avg:.2f} ({len(scores)} samples)")
256
+
257
+ # Timing statistics
258
+ has_time_data = self._refine_times or self._generate_times or self._retrieve_times
259
+ if has_time_data:
260
+ print("\n--- Timing Statistics (seconds) ---")
261
+ if self._retrieve_times:
262
+ avg_retrieve = sum(self._retrieve_times) / len(self._retrieve_times)
263
+ total_retrieve = sum(self._retrieve_times)
264
+ print(
265
+ f" Retrieve: avg={avg_retrieve:.3f}s, total={total_retrieve:.2f}s ({len(self._retrieve_times)} samples)"
266
+ )
267
+ if self._refine_times:
268
+ avg_refine = sum(self._refine_times) / len(self._refine_times)
269
+ total_refine = sum(self._refine_times)
270
+ print(
271
+ f" Refine: avg={avg_refine:.3f}s, total={total_refine:.2f}s ({len(self._refine_times)} samples)"
272
+ )
273
+ if self._generate_times:
274
+ avg_generate = sum(self._generate_times) / len(self._generate_times)
275
+ total_generate = sum(self._generate_times)
276
+ print(
277
+ f" Generate: avg={avg_generate:.3f}s, total={total_generate:.2f}s ({len(self._generate_times)} samples)"
278
+ )
279
+ # Total time
280
+ total_time = (
281
+ sum(self._retrieve_times) + sum(self._refine_times) + sum(self._generate_times)
282
+ )
283
+ print(f" \033[92mTotal Pipeline Time: {total_time:.2f}s\033[0m")
284
+
285
+ print("=" * 80 + "\n")
286
+
287
+ def get_results(self) -> dict[str, Any]:
288
+ """获取评估结果(用于程序化访问)"""
289
+ results: dict[str, Any] = {
290
+ "overall_score": (sum(self._scores) / len(self._scores) * 100 if self._scores else 0),
291
+ "sample_count": len(self._scores),
292
+ "per_dataset": {},
293
+ "timing": {},
294
+ }
295
+
296
+ for dataset, scores in self._dataset_scores.items():
297
+ results["per_dataset"][dataset] = {
298
+ "score": sum(scores) / len(scores) * 100 if scores else 0,
299
+ "count": len(scores),
300
+ "metric": DATASET_TO_METRIC.get(dataset, "qa_f1"),
301
+ }
302
+
303
+ if self.longbench_e_buckets:
304
+ results["buckets"] = {}
305
+ for bucket, scores in self._bucket_scores.items():
306
+ results["buckets"][bucket] = {
307
+ "score": sum(scores) / len(scores) * 100 if scores else 0,
308
+ "count": len(scores),
309
+ }
310
+
311
+ # Timing statistics
312
+ if self._retrieve_times:
313
+ results["timing"]["retrieve"] = {
314
+ "avg": sum(self._retrieve_times) / len(self._retrieve_times),
315
+ "total": sum(self._retrieve_times),
316
+ "count": len(self._retrieve_times),
317
+ }
318
+ if self._refine_times:
319
+ results["timing"]["refine"] = {
320
+ "avg": sum(self._refine_times) / len(self._refine_times),
321
+ "total": sum(self._refine_times),
322
+ "count": len(self._refine_times),
323
+ }
324
+ if self._generate_times:
325
+ results["timing"]["generate"] = {
326
+ "avg": sum(self._generate_times) / len(self._generate_times),
327
+ "total": sum(self._generate_times),
328
+ "count": len(self._generate_times),
329
+ }
330
+
331
+ return results
332
+
333
+ def __del__(self) -> None:
334
+ """对象销毁时关闭文件"""
335
+ try:
336
+ if self._output_file:
337
+ self._output_file.close()
338
+ except Exception:
339
+ pass
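+
+ # Minimal standalone sketch (assumes the surrounding SAGE runtime; in practice the evaluator
+ # runs as a pipeline stage and consumes records produced by the Generator):
+ #   evaluator = LongBenchEvaluator({"output_path": "results/out.jsonl", "model_name": "qwen"})
+ #   evaluator.execute({"generated": "Paris", "references": ["Paris"], "_dataset": "hotpotqa", "length": 3200})
+ #   evaluator.get_results()  # -> {"overall_score": 100.0, "sample_count": 1, "per_dataset": {...}, "timing": {}}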
@@ -0,0 +1,216 @@
1
+ """
2
+ Official LongBench evaluation metric functions
3
+
4
+ Copied directly from: https://github.com/THUDM/LongBench/blob/main/metrics.py
5
+ Type annotations and optional-dependency handling added.
6
+ """
7
+
8
+ import re
9
+ import string
10
+ from collections import Counter
11
+ from typing import List
12
+
13
+ # Optional dependencies
14
+ try:
15
+ import jieba
16
+
17
+ HAS_JIEBA = True
18
+ except ImportError:
19
+ HAS_JIEBA = False
20
+
21
+ try:
22
+ from fuzzywuzzy import fuzz
23
+
24
+ HAS_FUZZYWUZZY = True
25
+ except ImportError:
26
+ HAS_FUZZYWUZZY = False
27
+
28
+ try:
29
+ from rouge import Rouge
30
+
31
+ HAS_ROUGE = True
32
+ except ImportError:
33
+ HAS_ROUGE = False
34
+
35
+
36
+ def normalize_answer(s: str) -> str:
37
+ """Lower text and remove punctuation, articles and extra whitespace."""
38
+
39
+ def remove_articles(text: str) -> str:
40
+ return re.sub(r"\b(a|an|the)\b", " ", text)
41
+
42
+ def white_space_fix(text: str) -> str:
43
+ return " ".join(text.split())
44
+
45
+ def remove_punc(text: str) -> str:
46
+ exclude = set(string.punctuation)
47
+ return "".join(ch for ch in text if ch not in exclude)
48
+
49
+ def lower(text: str) -> str:
50
+ return text.lower()
51
+
52
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
53
+
54
+
55
+ def normalize_zh_answer(s: str) -> str:
56
+ """Lower text and remove punctuation, extra whitespace (Chinese)."""
57
+
58
+ def white_space_fix(text: str) -> str:
59
+ return "".join(text.split())
60
+
61
+ def remove_punc(text: str) -> str:
62
+ cn_punctuation = (
63
+ "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—''‛"
64
+ "„‟…‧﹏."
65
+ )
66
+ all_punctuation = set(string.punctuation + cn_punctuation)
67
+ return "".join(ch for ch in text if ch not in all_punctuation)
68
+
69
+ def lower(text: str) -> str:
70
+ return text.lower()
71
+
72
+ return white_space_fix(remove_punc(lower(s)))
73
+
74
+
75
+ def f1_score(prediction: List[str], ground_truth: List[str], **kwargs) -> float:
76
+ """Token-level F1 score."""
77
+ common = Counter(prediction) & Counter(ground_truth)
78
+ num_same = sum(common.values())
79
+ if num_same == 0:
80
+ return 0.0
81
+ precision = 1.0 * num_same / len(prediction)
82
+ recall = 1.0 * num_same / len(ground_truth)
83
+ f1 = (2 * precision * recall) / (precision + recall)
84
+ return f1
85
+
86
+
87
+ def qa_f1_score(prediction: str, ground_truth: str, **kwargs) -> float:
88
+ """QA F1 score (English)."""
89
+ normalized_prediction = normalize_answer(prediction)
90
+ normalized_ground_truth = normalize_answer(ground_truth)
91
+ prediction_tokens = normalized_prediction.split()
92
+ ground_truth_tokens = normalized_ground_truth.split()
93
+ return f1_score(prediction_tokens, ground_truth_tokens)
94
+
95
+
96
+ def qa_f1_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
97
+ """QA F1 score (Chinese, requires jieba)."""
98
+ if not HAS_JIEBA:
99
+ raise ImportError(
100
+ "jieba is required for Chinese evaluation. Install with: pip install jieba"
101
+ )
102
+
103
+ prediction_tokens = list(jieba.cut(prediction, cut_all=False))
104
+ ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
105
+ prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
106
+ ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
107
+ prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
108
+ ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
109
+ return f1_score(prediction_tokens, ground_truth_tokens)
110
+
111
+
112
+ def rouge_score(prediction: str, ground_truth: str, **kwargs) -> float:
113
+ """ROUGE-L F1 score."""
114
+ if not HAS_ROUGE:
115
+ raise ImportError("rouge is required. Install with: pip install rouge")
116
+
117
+ rouge = Rouge()
118
+ try:
119
+ scores = rouge.get_scores([prediction], [ground_truth], avg=True)
120
+ return scores["rouge-l"]["f"]
121
+ except Exception:
122
+ return 0.0
123
+
124
+
125
+ def rouge_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
126
+ """ROUGE-L F1 score (Chinese, requires jieba)."""
127
+ if not HAS_JIEBA:
128
+ raise ImportError(
129
+ "jieba is required for Chinese evaluation. Install with: pip install jieba"
130
+ )
131
+
132
+ prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
133
+ ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
134
+ return rouge_score(prediction, ground_truth)
135
+
136
+
137
+ def classification_score(prediction: str, ground_truth: str, **kwargs) -> float:
138
+ """Classification score with all_classes matching."""
139
+ all_classes = kwargs.get("all_classes", [])
140
+ if not all_classes:
141
+ return 0.0
142
+
143
+ em_match_list = []
144
+ for class_name in all_classes:
145
+ if class_name in prediction:
146
+ em_match_list.append(class_name)
147
+
148
+ for match_term in em_match_list.copy():
149
+ if match_term in ground_truth and match_term != ground_truth:
150
+ em_match_list.remove(match_term)
151
+
152
+ if ground_truth in em_match_list:
153
+ return 1.0 / len(em_match_list)
154
+ return 0.0
155
+
156
+
157
+ def retrieval_score(prediction: str, ground_truth: str, **kwargs) -> float:
158
+ """Retrieval score (English)."""
159
+ pattern = r"Paragraph (\d+)"
160
+ matches = re.findall(pattern, ground_truth)
161
+ if not matches:
162
+ return 0.0
163
+ ground_truth_id = matches[0]
164
+ numbers = re.findall(r"\d+", prediction)
165
+ right_num = sum(1 for number in numbers if str(number) == str(ground_truth_id))
166
+ return 0.0 if len(numbers) == 0 else float(right_num / len(numbers))
167
+
168
+
169
+ def retrieval_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
170
+ """Retrieval score (Chinese)."""
171
+ pattern = r"段落(\d+)"
172
+ matches = re.findall(pattern, ground_truth)
173
+ if not matches:
174
+ return 0.0
175
+ ground_truth_id = matches[0]
176
+ numbers = re.findall(r"\d+", prediction)
177
+ right_num = sum(1 for number in numbers if str(number) == str(ground_truth_id))
178
+ return 0.0 if len(numbers) == 0 else float(right_num / len(numbers))
179
+
180
+
181
+ def count_score(prediction: str, ground_truth: str, **kwargs) -> float:
182
+ """Count score for passage_count task."""
183
+ numbers = re.findall(r"\d+", prediction)
184
+ right_num = sum(1 for number in numbers if str(number) == str(ground_truth))
185
+ return 0.0 if len(numbers) == 0 else float(right_num / len(numbers))
186
+
187
+
188
+ def code_sim_score(prediction: str, ground_truth: str, **kwargs) -> float:
189
+ """Code similarity score (requires fuzzywuzzy)."""
190
+ if not HAS_FUZZYWUZZY:
191
+ raise ImportError(
192
+ "fuzzywuzzy is required for code evaluation. "
193
+ "Install with: pip install fuzzywuzzy python-Levenshtein"
194
+ )
195
+
196
+ all_lines = prediction.lstrip("\n").split("\n")
197
+ processed_prediction = ""
198
+ for line in all_lines:
199
+ if ("`" not in line) and ("#" not in line) and ("//" not in line):
200
+ processed_prediction = line
201
+ break
202
+ return fuzz.ratio(processed_prediction, ground_truth) / 100.0
203
+
204
+
205
+ # Metric function registry (mirrors the official dataset2metric mapping)
206
+ METRIC_FUNCTIONS = {
207
+ "qa_f1": qa_f1_score,
208
+ "qa_f1_zh": qa_f1_zh_score,
209
+ "rouge": rouge_score,
210
+ "rouge_zh": rouge_zh_score,
211
+ "classification": classification_score,
212
+ "retrieval": retrieval_score,
213
+ "retrieval_zh": retrieval_zh_score,
214
+ "count": count_score,
215
+ "code_sim": code_sim_score,
216
+ }
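+
+ # Worked example (follows from normalize_answer and f1_score above):
+ #   qa_f1_score("The answer is Paris", "Paris")
+ #   -> prediction tokens ["answer", "is", "paris"], reference tokens ["paris"]
+ #   -> precision = 1/3, recall = 1/1, F1 = 2 * (1/3 * 1) / (1/3 + 1) = 0.5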
@@ -0,0 +1,210 @@
1
+ """
2
+ LongBench Promptor - optimized version
3
+
4
+ Improvements:
5
+ 1. Loads configuration from JSON files (instead of hard-coding)
6
+ 2. Supports model-specific chat templates (copied from the official build_chat)
7
+ 3. Middle-truncation strategy (copied from the official implementation)
8
+ 4. Dataset validation
9
+ """
10
+
11
+ import json
12
+ from pathlib import Path
13
+ from typing import Any, Optional
14
+
15
+ from sage.common.core.functions import MapFunction as MapOperator
16
+
17
+ from .constants import NO_CHAT_DATASETS, SUPPORTED_DATASETS
18
+ from .utils import build_chat, truncate_middle
19
+
20
+
21
+ class LongBenchPromptor(MapOperator):
22
+ """
23
+ LongBench-specific Promptor.
24
+
25
+ Features:
26
+ 1. Loads task-specific prompt templates from JSON files
27
+ 2. Fills the template with context and input (query)
28
+ 3. Token-level middle truncation (keeps the head and tail when max_input_tokens is exceeded)
29
+ 4. Decides per dataset whether to apply a chat template (few-shot and code tasks do not use one)
30
+
31
+ Configuration parameters:
32
+ - max_input_tokens: int | None - maximum number of input tokens; longer prompts are truncated in the middle
33
+ - is_chat_model: bool - whether to apply a chat template (default False)
34
+ - model_name_or_path: str | None - model path, used to load the tokenizer
35
+ """
36
+
37
+ # Config-file cache (class level, avoids repeated loading)
38
+ _prompt_templates: Optional[dict[str, str]] = None
39
+ _max_gen_lengths: Optional[dict[str, int]] = None
40
+ _model_max_lengths: Optional[dict[str, int]] = None
41
+
42
+ @classmethod
43
+ def _load_configs(cls) -> None:
44
+ """延迟加载 JSON 配置文件"""
45
+ if cls._prompt_templates is None:
46
+ config_dir = Path(__file__).parent / "config"
47
+ with open(config_dir / "dataset2prompt.json", encoding="utf-8") as f:
48
+ cls._prompt_templates = json.load(f)
49
+ with open(config_dir / "dataset2maxlen.json", encoding="utf-8") as f:
50
+ cls._max_gen_lengths = json.load(f)
51
+ with open(config_dir / "model2maxlen.json", encoding="utf-8") as f:
52
+ cls._model_max_lengths = json.load(f)
53
+
54
+ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
55
+ super().__init__(**kwargs)
56
+ self._load_configs()
57
+
58
+ self.config = config
59
+ self.max_input_tokens: Optional[int] = config.get("max_input_tokens")
60
+ self.is_chat_model: bool = config.get("is_chat_model", False)
61
+ self.model_name: str = config.get("model_name_or_path", "")
62
+
63
+ # Lazily loaded tokenizer
64
+ self._tokenizer: Optional[Any] = None
65
+
66
+ @property
67
+ def tokenizer(self) -> Optional[Any]:
68
+ """延迟加载 tokenizer"""
69
+ if self._tokenizer is None and self.model_name:
70
+ try:
71
+ from transformers import AutoTokenizer
72
+
73
+ self._tokenizer = AutoTokenizer.from_pretrained(
74
+ self.model_name, trust_remote_code=True
75
+ )
76
+ self.logger.info(f"Loaded tokenizer from {self.model_name}")
77
+ except Exception as e:
78
+ self.logger.warning(f"Failed to load tokenizer: {e}")
79
+ return self._tokenizer
80
+
81
+ def execute(self, data: dict[str, Any]) -> list[Any]:
82
+ """
83
+ Build a LongBench-style prompt.
84
+
85
+ Input format (from LongBenchBatch, possibly after a Refiner):
86
+ {
87
+ "query": str, # user question (original input field)
88
+ "context": str, # long-document context (original)
89
+ "refining_results": List[str], # compressed context (optional, Refiner output)
90
+ "references": list, # reference answers
91
+ "_dataset": str, # dataset name
92
+ ...
93
+ }
94
+
95
+ **Context selection priority**:
96
+ 1. refining_results (if present and non-empty, the Refiner-compressed result)
97
+ 2. context (original context shipped with LongBench)
98
+
99
+ Output format:
100
+ [original_data, prompt_string]
101
+
102
+ Args:
103
+ data: dict containing query, context, _dataset and other fields
104
+
105
+ Returns:
106
+ A list: [original data, prompt string]
107
+ """
108
+ dataset = data.get("_dataset", "")
109
+ query = data.get("query", "")
110
+
111
+ # Dataset validation
112
+ if dataset and dataset not in SUPPORTED_DATASETS:
113
+ self.logger.warning(
114
+ f"Unknown dataset '{dataset}', using default template. "
115
+ f"Supported: {SUPPORTED_DATASETS}"
116
+ )
117
+
118
+ # Context selection priority: refining_results > context
119
+ refining_results = data.get("refining_results", [])
120
+ if refining_results:
121
+ if isinstance(refining_results, list):
122
+ context = "\n\n".join(refining_results)
123
+ else:
124
+ context = str(refining_results)
125
+ self.logger.info("Using refining_results (compressed context)")
126
+ else:
127
+ context = data.get("context", "")
128
+
129
+ # 1. Get the dataset-specific template (loaded from JSON)
130
+ assert self._prompt_templates is not None
131
+ template = self._prompt_templates.get(dataset, "{context}\n\nQuestion: {input}\nAnswer:")
132
+
133
+ # 2. Fill the template (LongBench uses the {context} and {input} placeholders)
134
+ prompt = template.format(context=context, input=query)
135
+
136
+ # 3. Middle truncation (if max_input_tokens is configured)
137
+ if self.max_input_tokens and self.tokenizer:
138
+ prompt = truncate_middle(prompt, self.tokenizer, self.max_input_tokens)
139
+
140
+ # 4. Chat template (decided per dataset)
141
+ # 原始 pred.py: if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
142
+ # prompt = build_chat(tokenizer, prompt, model_name)
143
+ if self.is_chat_model and dataset not in NO_CHAT_DATASETS:
144
+ prompt = build_chat(prompt, self.model_name, self.tokenizer)
145
+
146
+ # 5. Set max_gen_tokens for the Generator
147
+ assert self._max_gen_lengths is not None
148
+ data["_max_gen_tokens"] = self._max_gen_lengths.get(dataset, 128)
149
+
150
+ self.logger.info(f"dataset={dataset}, prompt_length={len(prompt)}")
151
+ return [data, prompt]
152
+
153
+ @classmethod
154
+ def get_max_gen_length(cls, dataset: str) -> int:
155
+ """
156
+ Get the maximum generation length for a dataset.
157
+
158
+ Args:
159
+ dataset: dataset name
160
+
161
+ Returns:
162
+ Maximum number of generated tokens
163
+ """
164
+ cls._load_configs()
165
+ assert cls._max_gen_lengths is not None
166
+ return cls._max_gen_lengths.get(dataset, 128)
167
+
168
+ @classmethod
169
+ def get_prompt_template(cls, dataset: str) -> str:
170
+ """
171
+ Get the prompt template for a dataset.
172
+
173
+ Args:
174
+ dataset: dataset name
175
+
176
+ Returns:
177
+ Prompt template string
178
+ """
179
+ cls._load_configs()
180
+ assert cls._prompt_templates is not None
181
+ return cls._prompt_templates.get(dataset, "{context}\n\nQuestion: {input}\nAnswer:")
182
+
183
+ @classmethod
184
+ def get_model_max_length(cls, model_name: str, default: int = 8192) -> int:
185
+ """
186
+ Get the maximum context length for a model.
187
+
188
+ Supports exact matching and fuzzy matching (model-name containment).
189
+
190
+ Args:
191
+ model_name: model name or path
192
+ default: default maximum length (used when the model is not in the mapping)
193
+
194
+ Returns:
195
+ Maximum context length of the model, in tokens
196
+ """
197
+ cls._load_configs()
198
+ assert cls._model_max_lengths is not None
199
+
200
+ # Exact match
201
+ if model_name in cls._model_max_lengths:
202
+ return cls._model_max_lengths[model_name]
203
+
204
+ # Fuzzy match
205
+ model_lower = model_name.lower()
206
+ for known_model, max_len in cls._model_max_lengths.items():
207
+ if known_model.lower() in model_lower:
208
+ return max_len
209
+
210
+ return default
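+
+ # Illustrative class-method calls (a sketch; the actual values come from the bundled JSON config files):
+ #   LongBenchPromptor.get_prompt_template("hotpotqa")  # template with {context} and {input} placeholders
+ #   LongBenchPromptor.get_max_gen_length("hotpotqa")   # e.g. 32 tokens for short-answer QA
+ #   LongBenchPromptor.get_model_max_length("Qwen/Qwen2.5-7B-Instruct", default=8192)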
@@ -0,0 +1,106 @@
1
+ """
2
+ LongBench utility functions
3
+
4
+ Model-specific chat-template and post-processing logic (copied from the official pred.py)
5
+ """
6
+
7
+ from typing import Any, Optional
8
+
9
+
10
+ def build_chat(prompt: str, model_name: str, tokenizer: Optional[Any] = None) -> str:
11
+ """
12
+ Build a model-specific chat prompt.
13
+
14
+ Copied from the build_chat function in the official pred.py.
15
+
16
+ Args:
17
+ prompt: original prompt
18
+ model_name: model name
19
+ tokenizer: tokenizer instance (used for chatglm3)
20
+
21
+ Returns:
22
+ The wrapped prompt
23
+ """
24
+ model_name_lower = model_name.lower()
25
+
26
+ if "chatglm3" in model_name_lower and tokenizer:
27
+ # chatglm3 uses a dedicated tokenizer method
28
+ return tokenizer.build_chat_input(prompt)
29
+ elif "chatglm" in model_name_lower and tokenizer:
30
+ return tokenizer.build_prompt(prompt)
31
+ elif "llama2" in model_name_lower:
32
+ return f"[INST]{prompt}[/INST]"
33
+ elif "xgen" in model_name_lower:
34
+ header = (
35
+ "A chat between a curious human and an artificial intelligence assistant. "
36
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
37
+ )
38
+ return header + f" ### Human: {prompt}\n###"
39
+ elif "internlm" in model_name_lower:
40
+ return f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
41
+
42
+ # Default: try the tokenizer's apply_chat_template
43
+ if tokenizer and hasattr(tokenizer, "apply_chat_template"):
44
+ try:
45
+ messages = [{"role": "user", "content": prompt}]
46
+ return tokenizer.apply_chat_template(
47
+ messages, tokenize=False, add_generation_prompt=True
48
+ )
49
+ except Exception:
50
+ pass
51
+
52
+ return prompt
53
+
54
+
55
+ def post_process(response: str, model_name: str) -> str:
56
+ """
57
+ Model-specific output post-processing.
58
+
59
+ Copied from the post_process function in the official pred.py.
60
+
61
+ Args:
62
+ response: raw model output
63
+ model_name: model name
64
+
65
+ Returns:
66
+ The post-processed output
67
+ """
68
+ model_name_lower = model_name.lower()
69
+
70
+ if "xgen" in model_name_lower:
71
+ response = response.strip().replace("Assistant:", "")
72
+ elif "internlm" in model_name_lower:
73
+ response = response.split("<eoa>")[0]
74
+
75
+ return response
76
+
77
+
78
+ def truncate_middle(
79
+ prompt: str,
80
+ tokenizer: Any,
81
+ max_length: int,
82
+ ) -> str:
83
+ """
84
+ Middle-truncation strategy (keep the head and the tail).
85
+
86
+ Copied from the official pred.py:
87
+ # truncate to fit max_length (we suggest truncate in the middle,
88
+ # since the left and right side may contain crucial instructions)
89
+
90
+ Args:
91
+ prompt: original prompt
92
+ tokenizer: tokenizer instance
93
+ max_length: maximum number of tokens
94
+
95
+ Returns:
96
+ The truncated prompt
97
+ """
98
+ tokenized = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
99
+
100
+ if len(tokenized) > max_length:
101
+ half = max_length // 2
102
+ prompt = tokenizer.decode(tokenized[:half], skip_special_tokens=True) + tokenizer.decode(
103
+ tokenized[-half:], skip_special_tokens=True
104
+ )
105
+
106
+ return prompt
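+
+ # Illustrative behaviour (follows from the branches above):
+ #   build_chat("Hello", "llama2-7b-chat")            -> "[INST]Hello[/INST]"
+ #   build_chat("Hello", "internlm-7b")               -> "<|User|>:Hello<eoh>\n<|Bot|>:"
+ #   post_process("Answer<eoa> extra", "internlm-7b") -> "Answer"
+ # truncate_middle keeps the first max_length // 2 and the last max_length // 2 tokens, so a
+ # 10-token prompt truncated to max_length = 4 keeps tokens [0, 1] and [8, 9].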