multimetriceval 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.4
2
+ Name: multimetriceval
3
+ Version: 0.1.0
4
+ Summary: Multi-metric translation evaluation toolkit
5
+ Requires-Python: >=3.8
6
+ Requires-Dist: torch>=1.9.0
7
+ Requires-Dist: numpy
8
+ Requires-Dist: sacrebleu>=2.0.0
9
+ Provides-Extra: comet
10
+ Requires-Dist: unbabel-comet>=2.0.0; extra == "comet"
11
+ Provides-Extra: whisper
12
+ Requires-Dist: openai-whisper; extra == "whisper"
13
+ Provides-Extra: all
14
+ Requires-Dist: unbabel-comet>=2.0.0; extra == "all"
15
+ Requires-Dist: openai-whisper; extra == "all"
@@ -0,0 +1,432 @@
1
+ # 📊 MultiMetric-Eval
2
+
3
+ A multi-metric translation evaluation toolkit: compute BLEU, chrF++, COMET, and BLEURT with one line of code, for both text and speech input.
4
+
5
+ [![PyPI version](https://badge.fury.io/py/multimetric-eval.svg)](https://badge.fury.io/py/multimetric-eval)
6
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
7
+ <!-- [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) -->
8
+
9
+ ---
10
+
11
+ ## 🚀 Installation
12
+
13
+ ```bash
14
+ # Base install (BLEU + chrF++)
15
+ pip install multimetriceval
16
+
17
+ # Optional dependencies
18
+ pip install unbabel-comet # COMET metric
19
+ pip install openai-whisper # speech-to-text (ASR)
20
+ pip install bleurt # BLEURT metric
21
+ # See the "Supported metrics" table below for what each metric additionally requires
22
+ ```
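+
+ The `pyproject.toml` also declares optional extras (`comet`, `whisper`, `all`), so the optional dependencies can be pulled in together with the package. A sketch, assuming the published distribution name matches the package metadata (`multimetriceval`); note that BLEURT is not covered by an extra and still needs the separate install above:
+
+ ```bash
+ # Extras declared in pyproject.toml (assumed to resolve against the released package)
+ pip install "multimetriceval[comet]"    # adds unbabel-comet
+ pip install "multimetriceval[whisper]"  # adds openai-whisper
+ pip install "multimetriceval[all]"      # adds unbabel-comet and openai-whisper
+ ```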
23
+
24
+ ---
25
+
26
+ ## 📖 Quick start
27
+
28
+ ```python
29
+ from multimetric_eval import ModelEvaluator
30
+
31
+ # Initialize (the COMET model is downloaded automatically on first run)
32
+ evaluator = ModelEvaluator()
33
+
34
+ # Evaluate
35
+ results = evaluator.evaluate(
36
+ hypothesis=["The cat sits on the mat."],
37
+ reference=["The cat is sitting on the mat."],
38
+ source=["猫坐在垫子上。"]
39
+ )
40
+
41
+ print(results)
42
+ # {'sacreBLEU': 45.23, 'chrF++': 62.15, 'COMET': 0.8523}
43
+ ```
44
+
45
+ ---
46
+
47
+ ## 📁 Using the built-in datasets
48
+
49
+ ```python
50
+ from multimetric_eval import ModelEvaluator, load_dataset
51
+
52
+ # Load a built-in dataset (auto-downloaded to ./datasets/). If the download fails, fetch https://github.com/sjtuayj/MultiMetric-Eval/releases/download/v0.1.0/zh-en-littleprince.zip manually and unzip it to ./datasets/zh-en-littleprince
53
+ dataset = load_dataset("zh-en-littleprince")
54
+
55
+ # Initialize the evaluator
56
+ evaluator = ModelEvaluator(use_comet=True)
57
+ ```
58
+
59
+ ### Option 1: pass a list
60
+
61
+ ```python
62
+ results = evaluator.evaluate_dataset(
63
+ dataset=dataset,
64
+ hypothesis=["Translation 1", "Translation 2", "Translation 3"],
65
+ )
66
+ ```
67
+
68
+ ### Option 2: pass a JSON file
69
+
70
+ ```python
71
+ results = evaluator.evaluate_dataset(
72
+ dataset=dataset,
73
+ hypothesis="translations.json",
74
+ )
75
+ ```
76
+
77
+ ### Option 3: pass a TXT file
78
+
79
+ ```python
80
+ results = evaluator.evaluate_dataset(
81
+ dataset=dataset,
82
+ hypothesis="translations.txt",
83
+ )
84
+ ```
85
+
86
+ ### Option 4: pass an audio folder
87
+
88
+ ```python
89
+ # Whisper must be enabled
90
+ evaluator = ModelEvaluator(use_comet=True, use_whisper=True)
91
+
92
+ results = evaluator.evaluate_dataset(
93
+ dataset=dataset,
94
+ audio_folder="./my_audio/",
95
+ )
96
+ ```
97
+
98
+ ---
99
+
100
+ ## 📂 Using a custom dataset
101
+
102
+ ```python
103
+ from multimetric_eval import ModelEvaluator
104
+
105
+ evaluator = ModelEvaluator(use_comet=True)
106
+
107
+ # Prepare the reference data
108
+ reference = ["Reference 1", "Reference 2"]
109
+ source = ["源文本1", "源文本2"] # required by COMET
110
+ ```
111
+
112
+ ### Option 1: pass a list
113
+
114
+ ```python
115
+ results = evaluator.evaluate(
116
+ hypothesis=["Translation 1", "Translation 2"],
117
+ reference=reference,
118
+ source=source,
119
+ )
120
+ ```
121
+
122
+ ### Option 2: pass a JSON file
123
+
124
+ ```python
125
+ results = evaluator.evaluate_file(
126
+ hypothesis_file="translations.json",
127
+ reference=reference,
128
+ source=source,
129
+ )
130
+ ```
131
+
132
+ ### Option 3: pass a TXT file
133
+
134
+ ```python
135
+ results = evaluator.evaluate_file(
136
+ hypothesis_file="translations.txt",
137
+ reference=reference,
138
+ source=source,
139
+ )
140
+ ```
141
+
142
+ ### Option 4: pass an audio folder
143
+
144
+ ```python
145
+ evaluator = ModelEvaluator(use_comet=True, use_whisper=True)
146
+
147
+ results = evaluator.evaluate_audio_folder(
148
+ audio_folder="./my_audio/",
149
+ reference=reference,
150
+ source=source,
151
+ )
152
+ ```
153
+
154
+ ---
155
+
156
+ ## 📄 Input file formats
157
+
158
+ ### JSON file (all three formats below are supported)
159
+
160
+ **Format 1: dictionary**
161
+ ```json
162
+ {
163
+ "hypothesis": [
164
+ "Translation sentence 1.",
165
+ "Translation sentence 2.",
166
+ "Translation sentence 3."
167
+ ]
168
+ }
169
+ ```
170
+
171
+ **Format 2: array of objects**
172
+ ```json
173
+ [
174
+ {"id": "001", "hypothesis": "Translation sentence 1."},
175
+ {"id": "002", "hypothesis": "Translation sentence 2."},
176
+ {"id": "003", "hypothesis": "Translation sentence 3."}
177
+ ]
178
+ ```
179
+
180
+ **Format 3: plain string array**
181
+ ```json
182
+ [
183
+ "Translation sentence 1.",
184
+ "Translation sentence 2.",
185
+ "Translation sentence 3."
186
+ ]
187
+ ```
188
+
189
+ ### TXT file
190
+
191
+ One sentence per line; blank lines are ignored:
192
+
193
+ ```text
194
+ Translation sentence 1.
195
+ Translation sentence 2.
196
+ Translation sentence 3.
197
+ ```
198
+
199
+ ### Audio folder
200
+
201
+ ```
202
+ my_audio/
203
+ ├── 001.wav
204
+ ├── 002.wav
205
+ ├── 003.mp3
206
+ └── 004.flac
207
+ ```
208
+
209
+ - **Supported formats**: `.wav`, `.mp3`, `.flac`
210
+ - **Ordering**: files are sorted by filename automatically, so the order must match the reference translations (see the sketch below)
211
+ - **Naming tip**: use zero-padded numeric prefixes such as `001.wav`, `002.wav`
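+
+ The discovery logic mirrors `load_audio_from_folder` in `evaluator.py`: glob the supported extensions, then sort by the file stem. A minimal sketch of that behaviour (the folder path is just an example):
+
+ ```python
+ from pathlib import Path
+
+ folder = Path("./my_audio/")
+ files = []
+ for ext in (".wav", ".mp3", ".flac"):
+     files.extend(folder.glob(f"*{ext}"))
+ files = sorted(files, key=lambda p: p.stem)  # filename order == evaluation order
+ print([f.name for f in files])  # e.g. ['001.wav', '002.wav', '003.mp3', '004.flac']
+ ```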
212
+
213
+ ---
214
+
215
+ ## ⚙️ Configuration
216
+
217
+ ### Evaluator parameters
218
+
219
+ ```python
220
+ evaluator = ModelEvaluator(
221
+ use_comet=True, # enable COMET (requires source)
222
+ use_bleurt=False, # enable BLEURT
223
+ use_whisper=False, # enable speech-to-text
224
+ comet_model="Unbabel/wmt22-comet-da", # COMET model
225
+ whisper_model="medium", # tiny/base/small/medium/large
226
+ bleurt_path=None, # path to a local BLEURT checkpoint
227
+ device=None, # "cuda"/"cpu"; auto-detected by default
228
+ )
229
+ ```
230
+
231
+ | Parameter | Type | Default | Description |
232
+ |------|------|--------|------|
233
+ | `use_comet` | bool | `True` | Enable the COMET metric |
234
+ | `use_bleurt` | bool | `False` | Enable the BLEURT metric |
235
+ | `use_whisper` | bool | `False` | Enable speech-to-text |
236
+ | `comet_model` | str | `"Unbabel/wmt22-comet-da"` | COMET model name |
237
+ | `whisper_model` | str | `"medium"` | Whisper model size |
238
+ | `bleurt_path` | str | `None` | Local path to the BLEURT model |
239
+ | `device` | str | `None` | Compute device; GPU auto-detected |
240
+
241
+ ### Dataset parameters
242
+
243
+ ```python
244
+ dataset = load_dataset(
245
+ name="zh-en-littleprince", # dataset name
246
+ cache_dir="./datasets", # cache directory
247
+ force_download=False, # force a fresh download
248
+ )
249
+ ```
250
+
251
+ ---
252
+
253
+ ## 🎯 Common scenarios
254
+
255
+ ### Scenario 1: quick evaluation (BLEU + chrF++ only)
256
+
257
+ ```python
258
+ evaluator = ModelEvaluator(use_comet=False)
259
+
260
+ results = evaluator.evaluate(
261
+ hypothesis=["My translation"],
262
+ reference=["Reference translation"],
263
+ )
264
+ # {'sacreBLEU': 45.23, 'chrF++': 62.15}
265
+ ```
266
+
267
+ ### Scenario 2: full evaluation (all metrics)
268
+
269
+ ```python
270
+ evaluator = ModelEvaluator(
271
+ use_comet=True,
272
+ use_bleurt=True,
273
+ bleurt_path="./model/BLEURT-20",
274
+ )
275
+
276
+ results = evaluator.evaluate(
277
+ hypothesis=["My translation"],
278
+ reference=["Reference translation"],
279
+ source=["源文本"],
280
+ )
281
+ # {'sacreBLEU': 45.23, 'chrF++': 62.15, 'COMET': 0.85, 'BLEURT': 0.72}
282
+ ```
283
+
284
+ ### Scenario 3: speech evaluation
285
+
286
+ ```python
287
+ evaluator = ModelEvaluator(
288
+ use_comet=True,
289
+ use_whisper=True,
290
+ whisper_model="large", # higher accuracy
291
+ )
292
+
293
+ results = evaluator.evaluate_audio_folder(
294
+ audio_folder="./speech_outputs/",
295
+ reference=["Reference 1", "Reference 2"],
296
+ source=["源文本1", "源文本2"],
297
+ )
298
+
299
+ # Inspect the ASR transcripts
300
+ print(results["hypothesis"])
301
+ ```
302
+
303
+ ### Scenario 4: force CPU
304
+
305
+ ```python
306
+ evaluator = ModelEvaluator(
307
+ use_comet=True,
308
+ device="cpu",
309
+ )
310
+ ```
311
+
312
+ ---
313
+
314
+ ## 📊 Supported metrics
315
+
316
+ | Metric | Description | Needs source | Extra install |
317
+ |------|------|-------------|--------------|
318
+ | sacreBLEU | Standard BLEU score | ❌ | ❌ |
319
+ | chrF++ | Character n-gram F-score | ❌ | ❌ |
320
+ | COMET | Neural evaluation metric | ✅ | `unbabel-comet` |
321
+ | BLEURT | Google's BLEURT | ❌ | `bleurt` + model checkpoint |
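+
+ The two base metrics come straight from `sacrebleu`'s corpus-level API (the same calls the evaluator uses internally), so they can also be reproduced standalone:
+
+ ```python
+ import sacrebleu
+
+ hyp = ["The cat sits on the mat."]
+ ref = ["The cat is sitting on the mat."]
+
+ bleu = sacrebleu.corpus_bleu(hyp, [ref]).score                # sacreBLEU
+ chrf = sacrebleu.corpus_chrf(hyp, [ref], word_order=2).score  # chrF++ (word_order=2)
+ print(round(bleu, 2), round(chrf, 2))
+ ```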
322
+
323
+ ---
324
+
325
+ ## 📤 Output
326
+
327
+ ```python
328
+ results = evaluator.evaluate(...)
329
+
330
+ print(results)
331
+ # {
332
+ # "sacreBLEU": 45.23, # always returned
333
+ # "chrF++": 62.15, # always returned
334
+ # "COMET": 0.8523, # returned when use_comet=True
335
+ # "BLEURT": 0.7234, # returned when use_bleurt=True
336
+ # "hypothesis": [...], # ASR transcripts, returned for audio input
337
+ # }
338
+ ```
339
+
340
+ ---
341
+
342
+ ## 📋 Input format summary
343
+
344
+ | Format | Method | Example |
345
+ |------|------|------|
346
+ | Python list | `evaluate()` | `["sent1", "sent2"]` |
347
+ | JSON file | `evaluate_file()` | `"translations.json"` |
348
+ | TXT file | `evaluate_file()` | `"translations.txt"` |
349
+ | Audio folder | `evaluate_audio_folder()` | `"./audio/"` |
350
+
351
+ ---
352
+
353
+ ## 🔧 Advanced usage
354
+
355
+ ### Using the context manager (GPU memory is freed automatically)
356
+
357
+ ```python
358
+ with ModelEvaluator(use_comet=True) as evaluator:
359
+ results = evaluator.evaluate(
360
+ hypothesis=["Translation"],
361
+ reference=["Reference"],
362
+ source=["源文本"],
363
+ )
364
+ # GPU memory is released automatically on exit
365
+ ```
366
+
367
+ ### Create a custom dataset from a local JSON file
368
+
369
+ ```python
370
+ from multimetric_eval import create_dataset_from_json
371
+
372
+ # my_data.json format:
373
+ # [
374
+ # {"id": "001", "source_text": "源文本1", "reference_text": "Ref 1"},
375
+ # {"id": "002", "source_text": "源文本2", "reference_text": "Ref 2"}
376
+ # ]
377
+
378
+ dataset = create_dataset_from_json("./my_data.json")
379
+
380
+ results = evaluator.evaluate_dataset(
381
+ dataset=dataset,
382
+ hypothesis=["Translation 1", "Translation 2"],
383
+ )
384
+ ```
385
+
386
+ ### Listing the available datasets
387
+
388
+ ```python
389
+ from multimetric_eval import list_datasets, get_dataset_info
390
+
391
+ # List all available datasets
392
+ print(list_datasets())
393
+ # ['zh-en-littleprince']
394
+
395
+ # Inspect a dataset's details
396
+ info = get_dataset_info("zh-en-littleprince")
397
+ print(info)
398
+ # {
399
+ # 'name': 'zh-en-littleprince',
400
+ # 'is_downloaded': True,
401
+ # 'num_samples': 100,
402
+ # 'audio_complete': True
403
+ # }
404
+ ```
405
+
406
+ ---
407
+
408
+ ## ❓ FAQ
409
+
410
+ ### Q: The COMET score shows -1.0?
411
+ A: Make sure you pass the `source` argument; COMET needs the source text.
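+
+ For example, mirroring the quick-start data:
+
+ ```python
+ results = evaluator.evaluate(
+     hypothesis=["The cat sits on the mat."],
+     reference=["The cat is sitting on the mat."],
+     source=["猫坐在垫子上。"],  # omit this and COMET falls back to -1.0
+ )
+ ```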
412
+
413
+ ### Q: CUDA out of memory?
414
+ A: Use the context manager, or call `evaluator.cleanup()` manually to free GPU memory.
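+
+ Both options, sketched (`hyp`/`ref`/`src` stand in for your own lists):
+
+ ```python
+ # Option A: the context manager frees GPU memory on exit
+ with ModelEvaluator(use_comet=True) as evaluator:
+     results = evaluator.evaluate(hypothesis=hyp, reference=ref, source=src)
+
+ # Option B: explicit cleanup
+ evaluator = ModelEvaluator(use_comet=True)
+ results = evaluator.evaluate(hypothesis=hyp, reference=ref, source=src)
+ evaluator.cleanup()  # drops the COMET/Whisper/BLEURT models and empties the CUDA cache
+ ```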
415
+
416
+ ### Q: How do I use only the base metrics?
417
+ A: Set `use_comet=False`; only BLEU and chrF++ will be computed.
418
+
419
+ ### Q: The audio files end up in the wrong order?
420
+ A: Name them with zero-padded numeric prefixes such as `001.wav`, `002.wav`; the folder is sorted by filename, so this keeps the order aligned with the references.
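+
+ A quick illustration of why zero-padding matters with the lexicographic filename sort:
+
+ ```python
+ print(sorted(["1.wav", "2.wav", "10.wav"]))       # ['1.wav', '10.wav', '2.wav']  (wrong order)
+ print(sorted(["001.wav", "002.wav", "010.wav"]))  # ['001.wav', '002.wav', '010.wav']
+ ```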
421
+
422
+ ---
423
+
424
+ ## 📜 License
425
+
426
+ MIT License
427
+
428
+ ---
429
+
430
+ ## 🤝 Contributing
431
+
432
+ Issues and pull requests are welcome!
@@ -0,0 +1,22 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "multimetriceval"
7
+ version = "0.1.0"
8
+ description = "Multi-metric translation evaluation toolkit"
9
+ requires-python = ">=3.8"
10
+ dependencies = [
11
+ "torch>=1.9.0",
12
+ "numpy",
13
+ "sacrebleu>=2.0.0",
14
+ ]
15
+
16
+ [project.optional-dependencies]
17
+ comet = ["unbabel-comet>=2.0.0"]
18
+ whisper = ["openai-whisper"]
19
+ all = ["unbabel-comet>=2.0.0", "openai-whisper"]
20
+
21
+ [tool.setuptools.packages.find]
22
+ where = ["src"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,24 @@
1
+ from .evaluator import (
2
+ ModelEvaluator,
3
+ load_hypothesis_from_file,
4
+ load_audio_from_folder,
5
+ )
6
+ from .dataset import (
7
+ Dataset,
8
+ load_dataset,
9
+ list_datasets,
10
+ get_dataset_info,
11
+ create_dataset_from_json,
12
+ )
13
+
14
+ __version__ = "0.1.0"
15
+ __all__ = [
16
+ "ModelEvaluator",
17
+ "load_hypothesis_from_file",
18
+ "load_audio_from_folder",
19
+ "Dataset",
20
+ "load_dataset",
21
+ "list_datasets",
22
+ "get_dataset_info",
23
+ "create_dataset_from_json",
24
+ ]
@@ -0,0 +1,199 @@
1
+ """
2
+ Built-in dataset management, with audio file downloads.
3
+ Uses the local cache when available; otherwise downloads from the network.
4
+ """
5
+ import os
6
+ import json
7
+ import urllib.request
8
+ import zipfile
9
+ import shutil
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ # Dataset download URLs
13
+ DATASET_URLS = {
14
+ "zh-en-littleprince": "https://github.com/sjtuayj/MultiMetric-Eval/releases/download/v0.1.0/zh-en-littleprince.zip",
15
+ }
16
+
17
+ # Default cache directory
18
+ # DEFAULT_CACHE_DIR = os.path.expanduser("~/.datasets")
19
+ DEFAULT_CACHE_DIR = "./datasets"
20
+
21
+ class Dataset:
22
+ """Built-in dataset class."""
23
+
24
+ def __init__(self, data: List[Dict], base_dir: str):
25
+ self._data = data
26
+ self._base_dir = base_dir
27
+
28
+ def __len__(self) -> int:
29
+ return len(self._data)
30
+
31
+ def __getitem__(self, idx: int) -> Dict:
32
+ item = self._data[idx].copy()
33
+ if "source_speech_path" in item:
34
+ filename = os.path.basename(item["source_speech_path"])
35
+ item["source_speech_path"] = os.path.join(self._base_dir, "audio", filename)
36
+ return item
37
+
38
+ @property
39
+ def ids(self) -> List[str]:
40
+ return [item.get("id", f"sample_{i}") for i, item in enumerate(self._data)]
41
+
42
+ @property
43
+ def source_texts(self) -> List[str]:
44
+ return [item["source_text"] for item in self._data]
45
+
46
+ @property
47
+ def reference_texts(self) -> List[str]:
48
+ return [item["reference_text"] for item in self._data]
49
+
50
+ @property
51
+ def audio_paths(self) -> List[str]:
52
+ return [self[i].get("source_speech_path", "") for i in range(len(self))]
53
+
54
+ def verify_audio_files(self) -> Dict[str, Union[int, List[str]]]:
55
+ """Verify that the dataset's audio files are all present."""
56
+ missing = [p for p in self.audio_paths if not os.path.exists(p)]
57
+ return {
58
+ "total": len(self),
59
+ "found": len(self) - len(missing),
60
+ "missing": len(missing),
61
+ "missing_files": missing,
62
+ }
63
+
64
+
65
+ def list_datasets() -> List[str]:
66
+ """List all available datasets."""
67
+ return list(DATASET_URLS.keys())
68
+
69
+
70
+ def _is_dataset_cached(name: str, cache_dir: str) -> bool:
71
+ """Check whether a dataset has already been downloaded."""
72
+ dataset_dir = os.path.join(cache_dir, name)
73
+ data_file = os.path.join(dataset_dir, "dataset_paired.json")
74
+ audio_dir = os.path.join(dataset_dir, "audio")
75
+ return os.path.exists(data_file) and os.path.exists(audio_dir)
76
+
77
+
78
+ def get_dataset_info(name: str, cache_dir: Optional[str] = None) -> Dict:
79
+ """Get information about a dataset."""
80
+ if name not in DATASET_URLS:
81
+ raise ValueError(f"Unknown dataset: {name}")
82
+
83
+ cache_dir = cache_dir or DEFAULT_CACHE_DIR
84
+ is_cached = _is_dataset_cached(name, cache_dir)
85
+
86
+ info = {
87
+ "name": name,
88
+ "url": DATASET_URLS[name],
89
+ "cache_dir": os.path.join(cache_dir, name),
90
+ "is_downloaded": is_cached,
91
+ }
92
+
93
+ if is_cached:
94
+ dataset = load_dataset(name, cache_dir=cache_dir)
95
+ info["num_samples"] = len(dataset)
96
+ verify = dataset.verify_audio_files()
97
+ info["audio_complete"] = verify["missing"] == 0
98
+
99
+ return info
100
+
101
+
102
+ def load_dataset(
103
+ name: str,
104
+ cache_dir: Optional[str] = None,
105
+ force_download: bool = False,
106
+ ) -> Dataset:
107
+ """
108
+ Load a dataset (local cache first, download otherwise).
109
+
110
+ Args:
111
+ name: dataset name
112
+ cache_dir: cache directory
113
+ force_download: force a fresh download
114
+
115
+ Returns:
116
+ A Dataset object
117
+ """
118
+ if name not in DATASET_URLS:
119
+ available = ", ".join(DATASET_URLS.keys())
120
+ raise ValueError(f"Unknown dataset: {name}. Available: {available}")
121
+
122
+ cache_dir = cache_dir or DEFAULT_CACHE_DIR
123
+ dataset_dir = os.path.join(cache_dir, name)
124
+ data_file = os.path.join(dataset_dir, "dataset_paired.json")
125
+
126
+ # Check the local cache first
127
+ if _is_dataset_cached(name, cache_dir) and not force_download:
128
+ print(f"✅ [Local] Using cached dataset: {name}")
129
+ print(f" Path: {dataset_dir}")
130
+ else:
131
+ print(f"⏳ [Online] Downloading dataset: {name}")
132
+ _download_dataset(name, cache_dir)
133
+
134
+ # Load the data
135
+ with open(data_file, "r", encoding="utf-8") as f:
136
+ data = json.load(f)
137
+
138
+ dataset = Dataset(data=data, base_dir=dataset_dir)
139
+
140
+ # Verify completeness
141
+ verify = dataset.verify_audio_files()
142
+ if verify["missing"] > 0:
143
+ print(f" ⚠️ {verify['missing']} audio file(s) missing")
144
+ else:
145
+ print(f" ✅ Dataset complete ({verify['total']} samples, all audio present)")
146
+
147
+ return dataset
148
+
149
+
150
+ def _download_dataset(name: str, cache_dir: str):
151
+ """Download a dataset."""
152
+ url = DATASET_URLS[name]
153
+ dataset_dir = os.path.join(cache_dir, name)
154
+ zip_path = os.path.join(cache_dir, f"{name}.zip")
155
+
156
+ os.makedirs(cache_dir, exist_ok=True)
157
+
158
+ if os.path.exists(dataset_dir):
159
+ shutil.rmtree(dataset_dir)
160
+
161
+ print(f" URL: {url}")
162
+
163
+ try:
164
+ urllib.request.urlretrieve(url, zip_path, _download_progress)
165
+ print()
166
+
167
+ print(" 📦 Extracting...")
168
+ os.makedirs(dataset_dir, exist_ok=True)
169
+ with zipfile.ZipFile(zip_path, 'r') as zf:
170
+ zf.extractall(dataset_dir)
171
+
172
+ os.remove(zip_path)
173
+ print(f" ✅ Download complete: {dataset_dir}")
174
+
175
+ except Exception as e:
176
+ if os.path.exists(zip_path):
177
+ os.remove(zip_path)
178
+ if os.path.exists(dataset_dir):
179
+ shutil.rmtree(dataset_dir)
180
+ raise RuntimeError(f"Download failed: {e}")
181
+
182
+
183
+ def _download_progress(block_num, block_size, total_size):
184
+ """Simple download progress bar."""
185
+ downloaded = block_num * block_size
186
+ if total_size > 0:
187
+ percent = min(100, downloaded * 100 // total_size)
188
+ bar_len = 40
189
+ filled = int(bar_len * percent // 100)
190
+ bar = "█" * filled + "░" * (bar_len - filled)
191
+ print(f"\r [{bar}] {percent}%", end="", flush=True)
192
+
193
+
194
+ def create_dataset_from_json(json_path: str) -> Dataset:
195
+ """Create a dataset from a local JSON file."""
196
+ with open(json_path, "r", encoding="utf-8") as f:
197
+ data = json.load(f)
198
+ base_dir = os.path.dirname(os.path.abspath(json_path))
199
+ return Dataset(data=data, base_dir=base_dir)
@@ -0,0 +1,320 @@
1
+ """
2
+ MultiMetric Eval - multi-metric translation evaluation toolkit
3
+ """
4
+ import os
5
+ import gc
6
+ import json
7
+ import numpy as np
8
+ import sacrebleu
9
+ import torch
10
+ from typing import Dict, List, Optional, Union
11
+ from pathlib import Path
12
+
13
+ # ==================== Configuration ====================
14
+
15
+ CACHE_PATHS = {
16
+ "huggingface": os.path.expanduser("~/.cache/huggingface/hub"),
17
+ "whisper": os.path.expanduser("~/.cache/whisper"),
18
+ }
19
+
20
+ for var in ["HF_DATASETS_OFFLINE", "TRANSFORMERS_OFFLINE"]:
21
+ os.environ.pop(var, None)
22
+
23
+ # ==================== Optional dependencies ====================
24
+
25
+ try:
26
+ import whisper
27
+ except ImportError:
28
+ whisper = None
29
+
30
+ try:
31
+ from bleurt import score as bleurt_score
32
+ except ImportError:
33
+ bleurt_score = None
34
+
35
+ try:
36
+ from comet import download_model, load_from_checkpoint
37
+ except ImportError:
38
+ download_model = None
39
+ load_from_checkpoint = None
40
+
41
+
42
+ # ==================== Input loading helpers ====================
43
+
44
+ def load_hypothesis_from_file(file_path: str) -> List[str]:
45
+ """
46
+ Load the user's translations from a file.
47
+
48
+ Supported formats:
49
+ - .json: {"hypothesis": [...]} or [{"id": "x", "hypothesis": "..."}, ...]
50
+ - .txt: one sentence per line
51
+ """
52
+ path = Path(file_path)
53
+
54
+ if not path.exists():
55
+ raise FileNotFoundError(f"File not found: {file_path}")
56
+
57
+ suffix = path.suffix.lower()
58
+
59
+ if suffix == ".json":
60
+ with open(path, "r", encoding="utf-8") as f:
61
+ data = json.load(f)
62
+
63
+ if isinstance(data, dict) and "hypothesis" in data:
64
+ return data["hypothesis"]
65
+
66
+ if isinstance(data, list) and len(data) > 0:
67
+ if isinstance(data[0], dict) and "hypothesis" in data[0]:
68
+ return [item["hypothesis"] for item in data]
69
+ if isinstance(data[0], str):
70
+ return data
71
+
72
+ raise ValueError("Unrecognized JSON structure")
73
+
74
+ elif suffix == ".txt":
75
+ with open(path, "r", encoding="utf-8") as f:
76
+ return [line.strip() for line in f if line.strip()]
77
+
78
+ else:
79
+ raise ValueError(f"Unsupported file format: {suffix}")
80
+
81
+
82
+ def load_audio_from_folder(folder_path: str, extensions: tuple = (".wav", ".mp3", ".flac")) -> List[str]:
83
+ """Collect audio file paths from a folder, sorted by filename."""
84
+ folder = Path(folder_path)
85
+
86
+ if not folder.exists():
87
+ raise FileNotFoundError(f"Folder not found: {folder_path}")
88
+
89
+ audio_files = []
90
+ for ext in extensions:
91
+ audio_files.extend(folder.glob(f"*{ext}"))
92
+
93
+ audio_files = sorted(audio_files, key=lambda x: x.stem)
94
+
95
+ if not audio_files:
96
+ raise ValueError(f"No audio files found in: {folder_path}")
97
+
98
+ return [str(f) for f in audio_files]
99
+
100
+
101
+ # ==================== Evaluator ====================
102
+
103
+ class ModelEvaluator:
104
+ """Multi-metric translation evaluator."""
105
+
106
+ def __init__(
107
+ self,
108
+ use_comet: bool = True,
109
+ use_bleurt: bool = False,
110
+ use_whisper: bool = False,
111
+ comet_model: str = "Unbabel/wmt22-comet-da",
112
+ whisper_model: str = "medium",
113
+ bleurt_path: Optional[str] = None,
114
+ device: Optional[str] = None,
115
+ ):
116
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
117
+ print(f"🚀 Initializing evaluator (device: {self.device})")
118
+
119
+ self.comet = self._load_comet(comet_model) if use_comet else None
120
+ self.whisper = self._load_whisper(whisper_model) if use_whisper else None
121
+ self.bleurt = self._load_bleurt(bleurt_path) if use_bleurt else None
122
+
123
+ print("✅ Ready!")
124
+
125
+ # -------------------- Context manager --------------------
126
+
127
+ def __enter__(self):
128
+ return self
129
+
130
+ def __exit__(self, exc_type, exc_val, exc_tb):
131
+ self.cleanup()
132
+ return False
133
+
134
+ def cleanup(self):
135
+ """Release model GPU memory."""
136
+ if hasattr(self, 'comet') and self.comet is not None:
137
+ del self.comet
138
+ self.comet = None
139
+ if hasattr(self, 'whisper') and self.whisper is not None:
140
+ del self.whisper
141
+ self.whisper = None
142
+ if hasattr(self, 'bleurt') and self.bleurt is not None:
143
+ del self.bleurt
144
+ self.bleurt = None
145
+
146
+ if torch.cuda.is_available():
147
+ torch.cuda.empty_cache()
148
+
149
+ gc.collect()
150
+ print("🧹 Model memory released")
151
+
152
+ # -------------------- Model loading --------------------
153
+
154
+ def _load_comet(self, model_name: str):
155
+ """Load the COMET model."""
156
+ if not download_model:
157
+ print("⚠️ COMET is not installed: pip install unbabel-comet")
158
+ return None
159
+
160
+ cache = os.path.join(CACHE_PATHS["huggingface"], f"models--{model_name.replace('/', '--')}")
161
+ status = "[Local]" if os.path.exists(cache) else "[Online]"
162
+ print(f"⏳ {status} Loading COMET: {model_name}")
163
+
164
+ model = load_from_checkpoint(download_model(model_name))
165
+ model = model.to(self.device) if self.device == "cuda" else model
166
+ print("✅ COMET loaded!")
167
+ return model
168
+
169
+ def _load_whisper(self, model_name: str):
170
+ """Load the Whisper model."""
171
+ if not whisper:
172
+ print("⚠️ Whisper is not installed: pip install openai-whisper")
173
+ return None
174
+
175
+ cache = os.path.join(CACHE_PATHS["whisper"], f"{model_name}.pt")
176
+ status = "[Local]" if os.path.exists(cache) else "[Online]"
177
+ print(f"⏳ {status} Loading Whisper: {model_name}")
178
+
179
+ model = whisper.load_model(model_name, device=self.device)
180
+ print("✅ Whisper loaded!")
181
+ return model
182
+
183
+ def _load_bleurt(self, path: Optional[str]):
184
+ """Load the BLEURT model."""
185
+ if not bleurt_score:
186
+ print("⚠️ BLEURT is not installed: pip install bleurt")
187
+ return None
188
+ if not path or not os.path.exists(path):
189
+ print(f"⚠️ Invalid BLEURT path: {path}")
190
+ return None
191
+
192
+ print(f"⏳ [Local] Loading BLEURT: {path}")
193
+ try:
194
+ # Force BLEURT onto the CPU (avoids competing with PyTorch for GPU memory)
195
+ import tensorflow as tf
196
+ tf.config.set_visible_devices([], 'GPU')
197
+
198
+ scorer = bleurt_score.BleurtScorer(path)
199
+ print("✅ BLEURT loaded!")
200
+ return scorer
201
+ except Exception as e:
202
+ print(f"⚠️ Failed to load BLEURT: {e}")
203
+ return None
204
+
205
+ # -------------------- Core functionality --------------------
206
+
207
+ def transcribe(self, audio_paths: List[str]) -> List[str]:
208
+ """Transcribe speech to text."""
209
+ if not self.whisper:
210
+ raise RuntimeError("Set use_whisper=True to enable transcription")
211
+
212
+ print(f"🎤 ASR transcription ({len(audio_paths)} files)...")
213
+ results = []
214
+ for i, path in enumerate(audio_paths, 1):
215
+ if not os.path.exists(path):
216
+ print(f" ⚠️ [{i}] file not found")
217
+ results.append("")
218
+ else:
219
+ try:
220
+ text = self.whisper.transcribe(path, fp16=(self.device == "cuda"))["text"]
221
+ results.append(text.strip())
222
+ print(f" ✓ [{i}/{len(audio_paths)}] {os.path.basename(path)}")
223
+ except Exception:
224
+ results.append("")
225
+ return results
226
+
227
+ def evaluate(
228
+ self,
229
+ hypothesis: List[str],
230
+ reference: List[str],
231
+ source: Optional[List[str]] = None,
232
+ ) -> Dict[str, float]:
233
+ """Compute the evaluation metrics."""
234
+ print("📊 Computing metrics...")
235
+
236
+ results = {
237
+ "sacreBLEU": self._safe_calc(lambda: sacrebleu.corpus_bleu(hypothesis, [reference]).score),
238
+ "chrF++": self._safe_calc(lambda: sacrebleu.corpus_chrf(hypothesis, [reference], word_order=2).score),
239
+ }
240
+
241
+ if self.bleurt:
242
+ results["BLEURT"] = self._safe_calc(
243
+ lambda: float(np.mean(self.bleurt.score(references=reference, candidates=hypothesis)))
244
+ )
245
+
246
+ if self.comet:
247
+ if source:
248
+ data = [{"src": s, "mt": h, "ref": r} for s, h, r in zip(source, hypothesis, reference)]
249
+ gpus = 1 if self.device == "cuda" else 0
250
+ results["COMET"] = self._safe_calc(lambda: self.comet.predict(data, batch_size=8, gpus=gpus).system_score)
251
+ else:
252
+ print("⚠️ COMET requires the source argument")
253
+ results["COMET"] = -1.0
254
+
255
+ return {k: round(v, 4) if v >= 0 else v for k, v in results.items()}
256
+
257
+ def evaluate_file(
258
+ self,
259
+ hypothesis_file: str,
260
+ reference: List[str],
261
+ source: Optional[List[str]] = None,
262
+ ) -> Dict[str, float]:
263
+ """Load translations from a file and evaluate them."""
264
+ print(f"📂 Loading translations: {hypothesis_file}")
265
+ hypothesis = load_hypothesis_from_file(hypothesis_file)
266
+ print(f" Loaded {len(hypothesis)} translations")
267
+ return self.evaluate(hypothesis, reference, source)
268
+
269
+ def evaluate_audio_folder(
270
+ self,
271
+ audio_folder: str,
272
+ reference: List[str],
273
+ source: Optional[List[str]] = None,
274
+ ) -> Dict[str, Union[float, List[str]]]:
275
+ """Transcribe audio from a folder and evaluate the transcripts."""
276
+ print(f"📂 Loading audio folder: {audio_folder}")
277
+ audio_paths = load_audio_from_folder(audio_folder)
278
+ print(f" Found {len(audio_paths)} audio files")
279
+
280
+ hypothesis = self.transcribe(audio_paths)
281
+ results = self.evaluate(hypothesis, reference, source)
282
+ results["hypothesis"] = hypothesis
283
+ return results
284
+
285
+ def evaluate_dataset(
286
+ self,
287
+ dataset,
288
+ hypothesis: Optional[Union[List[str], str]] = None,
289
+ audio_folder: Optional[str] = None,
290
+ ) -> Dict[str, Union[float, List[str]]]:
291
+ """Evaluate against a built-in or custom Dataset."""
292
+ if hypothesis:
293
+ if isinstance(hypothesis, str):
294
+ hyp_list = load_hypothesis_from_file(hypothesis)
295
+ else:
296
+ hyp_list = hypothesis
297
+
298
+ results = self.evaluate(hyp_list, dataset.reference_texts, dataset.source_texts)
299
+ results["hypothesis"] = hyp_list
300
+
301
+ elif audio_folder:
302
+ results = self.evaluate_audio_folder(
303
+ audio_folder,
304
+ dataset.reference_texts,
305
+ dataset.source_texts
306
+ )
307
+
308
+ else:
309
+ raise ValueError("Provide either hypothesis (a list or a file path) or audio_folder")
310
+
311
+ return results
312
+
313
+ # -------------------- Utilities --------------------
314
+
315
+ @staticmethod
316
+ def _safe_calc(fn, default=-1.0) -> float:
317
+ try:
318
+ return fn()
319
+ except Exception:
320
+ return default
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.4
2
+ Name: multimetriceval
3
+ Version: 0.1.0
4
+ Summary: Multi-metric translation evaluation toolkit
5
+ Requires-Python: >=3.8
6
+ Requires-Dist: torch>=1.9.0
7
+ Requires-Dist: numpy
8
+ Requires-Dist: sacrebleu>=2.0.0
9
+ Provides-Extra: comet
10
+ Requires-Dist: unbabel-comet>=2.0.0; extra == "comet"
11
+ Provides-Extra: whisper
12
+ Requires-Dist: openai-whisper; extra == "whisper"
13
+ Provides-Extra: all
14
+ Requires-Dist: unbabel-comet>=2.0.0; extra == "all"
15
+ Requires-Dist: openai-whisper; extra == "all"
@@ -0,0 +1,10 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/multimetric_eval/__init__.py
4
+ src/multimetric_eval/dataset.py
5
+ src/multimetric_eval/evaluator.py
6
+ src/multimetriceval.egg-info/PKG-INFO
7
+ src/multimetriceval.egg-info/SOURCES.txt
8
+ src/multimetriceval.egg-info/dependency_links.txt
9
+ src/multimetriceval.egg-info/requires.txt
10
+ src/multimetriceval.egg-info/top_level.txt
@@ -0,0 +1,13 @@
1
+ torch>=1.9.0
2
+ numpy
3
+ sacrebleu>=2.0.0
4
+
5
+ [all]
6
+ unbabel-comet>=2.0.0
7
+ openai-whisper
8
+
9
+ [comet]
10
+ unbabel-comet>=2.0.0
11
+
12
+ [whisper]
13
+ openai-whisper
@@ -0,0 +1 @@
1
+ multimetric_eval