modern-bert-score 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modern_bert_score-0.0.1/LICENSE.md +13 -0
- modern_bert_score-0.0.1/PKG-INFO +124 -0
- modern_bert_score-0.0.1/README.md +88 -0
- modern_bert_score-0.0.1/modern_bert_score/__init__.py +17 -0
- modern_bert_score-0.0.1/modern_bert_score/bert_score.py +310 -0
- modern_bert_score-0.0.1/modern_bert_score/consts.py +8 -0
- modern_bert_score-0.0.1/modern_bert_score/inference.py +139 -0
- modern_bert_score-0.0.1/modern_bert_score.egg-info/PKG-INFO +124 -0
- modern_bert_score-0.0.1/modern_bert_score.egg-info/SOURCES.txt +13 -0
- modern_bert_score-0.0.1/modern_bert_score.egg-info/dependency_links.txt +1 -0
- modern_bert_score-0.0.1/modern_bert_score.egg-info/requires.txt +16 -0
- modern_bert_score-0.0.1/modern_bert_score.egg-info/top_level.txt +1 -0
- modern_bert_score-0.0.1/pyproject.toml +66 -0
- modern_bert_score-0.0.1/setup.cfg +4 -0
- modern_bert_score-0.0.1/tests/test_bert_score.py +326 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
Copyright 2026 Philipp Koch
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: modern-bert-score
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: A reimplementation of the BERTScore metric optimized for modern inference workflows.
|
|
5
|
+
Author-email: Philipp Koch <PhillKoch@protonmail.com>
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/LazerLambda/modern-bert-score
|
|
8
|
+
Project-URL: Issues, https://github.com/LazerLambda/modern-bert-score/issues
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
+
Requires-Python: >=3.12
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE.md
|
|
19
|
+
Requires-Dist: sentence-transformers>=5.2.3
|
|
20
|
+
Requires-Dist: torch>=2.9.1
|
|
21
|
+
Requires-Dist: transformers>=4.57.6
|
|
22
|
+
Provides-Extra: vllm
|
|
23
|
+
Requires-Dist: vllm>=0.16.0; extra == "vllm"
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: sphinx; extra == "dev"
|
|
26
|
+
Requires-Dist: sphinx_rtd_theme; extra == "dev"
|
|
27
|
+
Requires-Dist: myst-parser; extra == "dev"
|
|
28
|
+
Requires-Dist: pytest; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
30
|
+
Requires-Dist: ruff; extra == "dev"
|
|
31
|
+
Requires-Dist: flake8; extra == "dev"
|
|
32
|
+
Requires-Dist: mypy; extra == "dev"
|
|
33
|
+
Requires-Dist: flake8; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
35
|
+
Dynamic: license-file
|
|
36
|
+
|
|
37
|
+
# Modern BERTScore for Fast Inference
|
|
38
|
+
|
|
39
|
+
[](https://github.com/LazerLambda/modern-bert-score/actions/workflows/ci.yml)
|
|
40
|
+
[](https://www.python.org/downloads/release/python-3120/)
|
|
41
|
+
[](https://www.python.org/downloads/release/python-3130/)
|
|
42
|
+
[](https://www.python.org/downloads/)
|
|
43
|
+
[](https://opensource.org/licenses/Apache-2.0)
|
|
44
|
+
|
|
45
|
+

|
|
46
|
+
|
|
47
|
+
**Modern-BERT-Score** is a reimplementation of the BERTScore metric introduced by [Zhang et al., 2019](https://arxiv.org/abs/1904.09675), optimized for modern inference workflows using [SentenceTransformers](https://www.sbert.net/) and [vLLM](https://vllm.ai/).
|
|
48
|
+
|
|
49
|
+
This library provides fast, GPU-accelerated scoring for text generation evaluation, making BERTScore practical for large-scale inference tasks.
|
|
50
|
+
|
|
51
|
+
---
|
|
52
|
+
|
|
53
|
+
## ⚡ Features
|
|
54
|
+
- Fast, efficient computation with optional vLLM support
|
|
55
|
+
- Compatible with all Hugging Face transformer models
|
|
56
|
+
- Supports truncated and optimized model versions for faster inference
|
|
57
|
+
- Works seamlessly with both CPU and GPU setups
|
|
58
|
+
|
|
59
|
+
---
|
|
60
|
+
|
|
61
|
+
## 📦 Installation
|
|
62
|
+
|
|
63
|
+
Modern-BERT-Score comes in **two variants**: a base version and a vLLM-enhanced version. For vLLM, an NVIDIA GPU is strongly recommended.
|
|
64
|
+
|
|
65
|
+
### Base Version
|
|
66
|
+
```bash
|
|
67
|
+
pip install modern-bert-score
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
### vLLM Version
|
|
71
|
+
```bash
|
|
72
|
+
pip install 'modern-bert-score[vllm]'
|
|
73
|
+
```
|
|
74
|
+
This implementation is significantly faster than the original BERTScore, especially with GPU acceleration.
|
|
75
|
+
|
|
76
|
+
## 📝 BERTScore
|
|
77
|
+
BERTScore ([Zhang et al., 2019](https://arxiv.org/abs/1904.09675)) evaluates the similarity between candidate and reference texts by comparing their contextual embeddings from a pre-trained transformer model. Each candidate token is matched to its most similar reference token (by cosine similarity) to compute **precision**, and each reference token is matched to its most similar candidate token to compute **recall**; **F1** is their harmonic mean. Optionally, **IDF-weighting** can be applied to give more importance to rare and informative words, improving the metric’s sensitivity to meaningful content over common words. Additionally, optional **Baseline Rescaling** linearly rescales each score against a per-model baseline `b` as `(score - b) / (1 - b)`, which spreads the scores out so that they typically fall in [0, 1] (they can still be negative). This approach captures semantic similarity beyond exact word matches, making it robust for tasks such as machine translation and text generation evaluation.
|
|
78
|
+
|
|
79
|
+
The following figure, taken from the original paper, illustrates how BERTScore works:
|
|
80
|
+
|
|
81
|
+

|
|
82
|
+
|
|
83
|
+
## 🛠 Usage
|
|
84
|
+
### Example
|
|
85
|
+
```python
|
|
86
|
+
from modern_bert_score import BertScore
|
|
87
|
+
|
|
88
|
+
candidates = ["Hello World!", "A robin is a bird."]
|
|
89
|
+
references = ["Hi World!", "A robin is not a bird."]
|
|
90
|
+
|
|
91
|
+
metric = BertScore(model_id="roberta-base")
|
|
92
|
+
scores = metric(candidates, references)
|
|
93
|
+
|
|
94
|
+
# scores is a list of (Precision, Recall, F1) tuples
|
|
95
|
+
# To get separate lists of P, R, F1:
|
|
96
|
+
P, R, F1 = zip(*scores)
|
|
97
|
+
|
|
98
|
+
print("Precision scores:", P)
|
|
99
|
+
print("Recall scores:", R)
|
|
100
|
+
print("F1 scores:", F1)
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## ⚠️ NOTICE
|
|
104
|
+
|
|
105
|
+
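- For best performance, an optimal layer should be used for each model. Scores are computed from the embeddings the model outputs, so load a model truncated to its optimal layer (such as the pre-truncated models listed below).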
|
|
106
|
+
- To find the optimal layer, [please use this script from the original BERTScore implementation](https://github.com/Tiiiger/bert_score/tree/master/tune_layers).
|
|
107
|
+
|
|
108
|
+
Some pre-truncated models optimized for vLLM are available on [Hugging Face](https://huggingface.co/collections/LazerLambda/modern-bertscore) and directly available in this library:
|
|
109
|
+
|
|
110
|
+
- `LazerLambda/ModernBERT-base-ModBERTScore-12` -> `ModernBERTBaseScore`
|
|
111
|
+
- `LazerLambda/ModernBERT-large-ModBERTScore-19` -> `ModernBERTLargeScore`
|
|
112
|
+
- `LazerLambda/roberta-base-ModBERTScore-10` -> `RobertaBaseScore`
|
|
113
|
+
- `LazerLambda/roberta-large-ModBERTScore-17` -> `RobertaLargeScore`
|
|
114
|
+
- `LazerLambda/roberta-large-mnli-ModBERTScore-19` -> `RobertaLargeMNLIScore`
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
## 🗺 Roadmap
|
|
118
|
+
|
|
119
|
+
- [x] Implement base version and vLLM addon
|
|
120
|
+
- [x] Add IDF-weighted scoring
|
|
121
|
+
- [ ] Add baseline-rescaling and scripts for identifying optimal baselines
|
|
122
|
+
- [ ] Add model (vLLM-)adaptation script for slicing the model
|
|
123
|
+
- [ ] Add multilingual support
|
|
124
|
+
- [ ] Add CLI tool
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# Modern BERTScore for Fast Inference
|
|
2
|
+
|
|
3
|
+
[](https://github.com/LazerLambda/modern-bert-score/actions/workflows/ci.yml)
|
|
4
|
+
[](https://www.python.org/downloads/release/python-3120/)
|
|
5
|
+
[](https://www.python.org/downloads/release/python-3130/)
|
|
6
|
+
[](https://www.python.org/downloads/)
|
|
7
|
+
[](https://opensource.org/licenses/Apache-2.0)
|
|
8
|
+
|
|
9
|
+

|
|
10
|
+
|
|
11
|
+
**Modern-BERT-Score** is a reimplementation of the BERTScore metric introduced by [Zhang et al., 2019](https://arxiv.org/abs/1904.09675), optimized for modern inference workflows using [SentenceTransformers](https://www.sbert.net/) and [vLLM](https://vllm.ai/).
|
|
12
|
+
|
|
13
|
+
This library provides fast, GPU-accelerated scoring for text generation evaluation, making BERTScore practical for large-scale inference tasks.
|
|
14
|
+
|
|
15
|
+
---
|
|
16
|
+
|
|
17
|
+
## ⚡ Features
|
|
18
|
+
- Fast, efficient computation with optional vLLM support
|
|
19
|
+
- Compatible with all Hugging Face transformer models
|
|
20
|
+
- Supports truncated and optimized model versions for faster inference
|
|
21
|
+
- Works seamlessly with both CPU and GPU setups
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
## 📦 Installation
|
|
26
|
+
|
|
27
|
+
Modern-BERT-Score comes in **two variants**: a base version and a vLLM-enhanced version. For vLLM, an NVIDIA GPU is strongly recommended.
|
|
28
|
+
|
|
29
|
+
### Base Version
|
|
30
|
+
```bash
|
|
31
|
+
pip install modern-bert-score
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
### vLLM Version
|
|
35
|
+
```bash
|
|
36
|
+
pip install 'modern-bert-score[vllm]'
|
|
37
|
+
```
|
|
38
|
+
This implementation is significantly faster than the original BERTScore, especially with GPU acceleration.
|
|
39
|
+
|
|
40
|
+
## 📝 BERTScore
|
|
41
|
+
BERTScore ([Zhang et al., 2019](https://arxiv.org/abs/1904.09675)) evaluates the similarity between candidate and reference texts by comparing their contextual embeddings from a pre-trained transformer model. Each candidate token is matched to its most similar reference token (by cosine similarity) to compute **precision**, and each reference token is matched to its most similar candidate token to compute **recall**; **F1** is their harmonic mean. Optionally, **IDF-weighting** can be applied to give more importance to rare and informative words, improving the metric’s sensitivity to meaningful content over common words. Additionally, optional **Baseline Rescaling** linearly rescales each score against a per-model baseline `b` as `(score - b) / (1 - b)`, which spreads the scores out so that they typically fall in [0, 1] (they can still be negative). This approach captures semantic similarity beyond exact word matches, making it robust for tasks such as machine translation and text generation evaluation.
|
|
42
|
+
|
|
43
|
+
The following figure, taken from the original paper, illustrates how BERTScore works:
|
|
44
|
+
|
|
45
|
+

|
|
46
|
+
|
|
47
|
+
## 🛠 Usage
|
|
48
|
+
### Example
|
|
49
|
+
```python
|
|
50
|
+
from modern_bert_score import BertScore
|
|
51
|
+
|
|
52
|
+
candidates = ["Hello World!", "A robin is a bird."]
|
|
53
|
+
references = ["Hi World!", "A robin is not a bird."]
|
|
54
|
+
|
|
55
|
+
metric = BertScore(model_id="roberta-base")
|
|
56
|
+
scores = metric(candidates, references)
|
|
57
|
+
|
|
58
|
+
# scores is a list of (Precision, Recall, F1) tuples
|
|
59
|
+
# To get separate lists of P, R, F1:
|
|
60
|
+
P, R, F1 = zip(*scores)
|
|
61
|
+
|
|
62
|
+
print("Precision scores:", P)
|
|
63
|
+
print("Recall scores:", R)
|
|
64
|
+
print("F1 scores:", F1)
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
## ⚠️ NOTICE
|
|
68
|
+
|
|
69
|
+
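- For best performance, an optimal layer should be used for each model. Scores are computed from the embeddings the model outputs, so load a model truncated to its optimal layer (such as the pre-truncated models listed below).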
|
|
70
|
+
- To find the optimal layer, [please use this script from the original BERTScore implementation](https://github.com/Tiiiger/bert_score/tree/master/tune_layers).
|
|
71
|
+
|
|
72
|
+
Some pre-truncated models optimized for vLLM are available on [Hugging Face](https://huggingface.co/collections/LazerLambda/modern-bertscore) and directly available in this library:
|
|
73
|
+
|
|
74
|
+
- `LazerLambda/ModernBERT-base-ModBERTScore-12` -> `ModernBERTBaseScore`
|
|
75
|
+
- `LazerLambda/ModernBERT-large-ModBERTScore-19` -> `ModernBERTLargeScore`
|
|
76
|
+
- `LazerLambda/roberta-base-ModBERTScore-10` -> `RobertaBaseScore`
|
|
77
|
+
- `LazerLambda/roberta-large-ModBERTScore-17` -> `RobertaLargeScore`
|
|
78
|
+
- `LazerLambda/roberta-large-mnli-ModBERTScore-19` -> `RobertaLargeMNLIScore`
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
## 🗺 Roadmap
|
|
82
|
+
|
|
83
|
+
- [x] Implement base version and vLLM addon
|
|
84
|
+
- [x] Add IDF-weighted scoring
|
|
85
|
+
- [ ] Add baseline-rescaling and scripts for identifying optimal baselines
|
|
86
|
+
- [ ] Add model (vLLM-)adaptation script for slicing the model
|
|
87
|
+
- [ ] Add multilingual support
|
|
88
|
+
- [ ] Add CLI tool
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from .bert_score import (
|
|
2
|
+
BertScore,
|
|
3
|
+
ModernBERTBaseScore,
|
|
4
|
+
ModernBERTLargeScore,
|
|
5
|
+
RobertaBaseScore,
|
|
6
|
+
RobertaLargeMNLIScore,
|
|
7
|
+
RobertaLargeScore,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
# Public names re-exported from ``modern_bert_score.bert_score``: the generic
# scorer followed by the preset scorers for the pre-truncated checkpoints.
__all__ = [
    "BertScore",
    "ModernBERTBaseScore", "ModernBERTLargeScore",
    "RobertaBaseScore", "RobertaLargeScore", "RobertaLargeMNLIScore",
]
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
from collections import Counter, defaultdict
|
|
2
|
+
from functools import partial, reduce
|
|
3
|
+
from itertools import chain, islice
|
|
4
|
+
from math import log
|
|
5
|
+
from multiprocessing import Pool
|
|
6
|
+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from transformers import AutoTokenizer
|
|
10
|
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
|
11
|
+
|
|
12
|
+
from modern_bert_score.consts import BASELINES
|
|
13
|
+
from modern_bert_score.inference import STInference, VLLMInference
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BertScore:
    """Compute BERTScore (precision, recall, F1) for candidate/reference pairs.

    Each candidate is compared with the reference at the same index. Token
    embeddings come from a pluggable inference engine (Sentence Transformers
    by default, or vLLM) that L2-normalises them, so the dot products below
    are cosine similarities. IDF weighting (fitted on the references of each
    call) and baseline rescaling are optional.
    """

    inference_engine: Optional[Union[STInference, VLLMInference]]

    def __init__(
        self,
        model_id: str = "answerdotai/ModernBERT-base",
        idf_weighting: bool = False,
        baseline_rescaling: bool = False,
        custom_baseline: Optional[Tuple[float, float, float]] = None,
        device: str = "cpu",  # TODO Enum cuda, mlx, cpu?
        backend: str = "default",  # TODO Enum default, vllm, onnx, etc
        sentence_transformers_args: Optional[Dict[str, Any]] = None,
        vllm_args: Optional[Dict[str, Any]] = None,
    ):
        """Load the tokenizer and the inference backend for ``model_id``.

        Args:
            model_id: Hugging Face model id or local path, used for both the
                tokenizer and the inference backend.
            idf_weighting: Weight token similarities by IDF computed over the
                references passed to each call.
            baseline_rescaling: Rescale every score ``s`` to
                ``(s - b) / (1 - b)`` using a per-metric baseline ``b``.
            custom_baseline: ``(P, R, F1)`` baseline. When provided it takes
                precedence over the built-in entry in ``BASELINES``.
            device: Device for the Sentence Transformers backend.
            backend: ``"default"`` (Sentence Transformers) or ``"vllm"``.
            sentence_transformers_args: Extra keyword arguments for
                ``STInference``.
            vllm_args: Extra keyword arguments for ``VLLMInference``.

        Raises:
            ValueError: If ``backend`` is unsupported, or rescaling is enabled
                without a usable baseline.
        """
        # Validate the configuration before downloading/loading any weights.
        if backend not in ("default", "vllm"):
            raise ValueError(f"Unsupported backend {backend}")
        self.idf_weighting: bool = idf_weighting
        self.baseline_rescaling: bool = baseline_rescaling
        if baseline_rescaling:
            self.baseline: Tuple[float, float, float] = self._resolve_baseline(
                model_id, custom_baseline
            )

        self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
            model_id
        )
        if backend == "vllm":
            self.inference_engine = VLLMInference(
                model=model_id,
                runner="pooling",
                convert="embed",
                **(vllm_args or {}),
            )
        else:
            self.inference_engine = STInference(
                model_id, device=device, **(sentence_transformers_args or {})
            )

    @staticmethod
    def _resolve_baseline(
        model_id: str,
        custom_baseline: Optional[Tuple[float, float, float]],
    ) -> Tuple[float, float, float]:
        """Pick the ``(P, R, F1)`` baseline: custom first, then ``BASELINES``.

        Raises:
            ValueError: If no baseline is available, the custom baseline does
                not have three values, or a value is >= 1 (which would make
                the rescaling denominator ``1 - b`` zero or negative).
        """
        if custom_baseline is not None:
            if len(custom_baseline) != 3:
                raise ValueError(
                    "custom_baseline must contain exactly three values "
                    "(precision, recall, F1)."
                )
            baseline = custom_baseline
        elif model_id in BASELINES:
            baseline = BASELINES[model_id]
        else:
            raise ValueError(
                "Baseline rescaling enabled but no "
                f"baseline found for model {model_id}"
                " and no custom baseline provided."
            )
        if any(value >= 1 for value in baseline):
            raise ValueError(
                "Baseline values must be < 1 to rescale with (s - b) / (1 - b)."
            )
        return baseline

    def __call__(
        self,
        candidates: Union[str, List[str]],
        references: Union[str, List[str]],
        **kwargs: Any,
    ) -> List[Tuple[float, float, float]]:
        """Score each candidate against the reference at the same index.

        Args:
            candidates: A string or a list (any iterable) of candidate strings.
            references: A string or a list (any iterable) of reference strings.
            **kwargs: Forwarded to the inference engine's ``inference``.

        Returns:
            One ``(precision, recall, f1)`` tuple per pair, rescaled when
            baseline rescaling is enabled. Empty inputs give ``[]``.

        Raises:
            AssertionError: If an input is not a string or strings.
            ValueError: If the inputs differ in length or no inference engine
                is configured.
        """
        if isinstance(candidates, str):
            candidates = [candidates]
        if isinstance(references, str):
            references = [references]
        candidates = list(candidates)
        references = list(references)
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if not all(isinstance(c, str) for c in candidates):
            raise AssertionError(
                "Candidates must be a list of strings or a string."
            )
        if not all(isinstance(r, str) for r in references):
            raise AssertionError(
                "References must be a list of strings or a string."
            )
        if len(candidates) != len(references):
            raise ValueError(
                "Number of candidates and references must be the same."
            )
        if not candidates:
            return []
        candidates = [c.strip() for c in candidates]
        references = [r.strip() for r in references]

        engine = getattr(self, "inference_engine", None)
        if engine is None:
            raise ValueError(
                "Inference engine not initialized. Check backend "
                "configuration."
            )

        if self.idf_weighting:
            # IDF is fitted on the references only and applied to both sides.
            idf_dict_ref, input_ids_ref = self.get_idf_dict(references)
            input_ids_cand = self._tokenize_data(candidates)
            cand_embs, ref_embs = engine.inference(
                candidates, references, **kwargs
            )
            scores = [
                self.bert_score(
                    candidates=c,
                    references=r,
                    idf_dict_ref=idf_dict_ref,
                    input_ids_cand=ids_c,
                    input_ids_ref=ids_r,
                )
                for c, r, ids_c, ids_r in zip(
                    cand_embs, ref_embs, input_ids_cand, input_ids_ref
                )
            ]
        else:
            cand_embs, ref_embs = engine.inference(
                candidates, references, **kwargs
            )
            scores = [
                self.bert_score(candidates=c, references=r)
                for c, r in zip(cand_embs, ref_embs)
            ]

        if self.baseline_rescaling:
            b_p, b_r, b_f1 = self.baseline
            scores = [
                (
                    (p - b_p) / (1 - b_p),
                    (r - b_r) / (1 - b_r),
                    (f1 - b_f1) / (1 - b_f1),
                )
                for p, r, f1 in scores
            ]
        return scores

    @staticmethod
    def _check_nan(f1: torch.Tensor) -> torch.Tensor:
        """Replace a NaN F1 (from 0 / 0) with ``tensor([0.0])``."""
        if torch.isnan(f1):
            f1 = torch.Tensor([0.0])
        return f1

    @staticmethod
    def _idf_weighted_mean(
        values: torch.Tensor, weights: List[float]
    ) -> torch.Tensor:
        """Return the average of ``values`` weighted by normalised ``weights``.

        Falls back to the unweighted mean when every weight is zero. IDF is
        ``log(1) == 0`` for tokens present in every reference (e.g. all
        reference tokens when scoring a single pair), and normalising such
        weights would divide by zero and yield NaN.

        Raises:
            ValueError: If ``weights`` and ``values`` differ in length, which
                would otherwise broadcast silently when one side has length 1.
        """
        weight_tensor = torch.tensor(weights)
        if weight_tensor.shape[0] != values.shape[0]:
            raise ValueError(
                f"Got {weight_tensor.shape[0]} IDF weights for "
                f"{values.shape[0]} token embeddings; the tokenizer and the "
                "inference backend disagree on tokenization (e.g. different "
                "truncation lengths)."
            )
        total = weight_tensor.sum()
        if total.item() == 0:
            return values.mean()
        return (values * (weight_tensor / total)).sum()

    def bert_score(
        self,
        candidates: torch.Tensor,
        references: torch.Tensor,
        idf_dict_ref: Optional[Dict[int, float]] = None,
        input_ids_cand: Optional[List[int]] = None,
        input_ids_ref: Optional[List[int]] = None,
    ) -> Tuple[float, float, float]:
        """Score one candidate/reference pair of token-embedding matrices.

        Args:
            candidates: 2-D tensor with one row per candidate token; the first
                and last rows (special tokens) are dropped. Rows are assumed
                L2-normalised so that dot products are cosine similarities.
            references: Same as ``candidates`` for the reference.
            idf_dict_ref: Token id -> IDF weight, fitted on the references.
            input_ids_cand: Candidate token ids without special tokens,
                aligned with the remaining candidate rows.
            input_ids_ref: Same as ``input_ids_cand`` for the reference.

        Returns:
            ``(precision, recall, f1)``. If either side has no tokens left
            after dropping special tokens, ``(0.0, 0.0, 0.0)`` is returned.

        Raises:
            ValueError: If only one of the IDF dict and the input ids is
                given, or the ids do not align with the embeddings.
        """
        has_idf_dict = idf_dict_ref is not None
        has_input_ids = (
            input_ids_cand is not None and input_ids_ref is not None
        )
        if has_idf_dict != has_input_ids:
            raise ValueError(
                "`idf_dict` and `input_ids` must either both be provided or "
                "both be None."
            )
        # TODO: w cuda?
        assert len(candidates.shape) == 2 and len(references.shape) == 2
        candidates = candidates[1:-1]  # remove CLS and SEP
        references = references[1:-1]
        if candidates.shape[0] == 0 or references.shape[0] == 0:
            # Empty sentence: ``max`` over a zero-sized dimension would raise.
            return 0.0, 0.0, 0.0

        similarities: torch.Tensor = candidates @ references.T
        # Precision: best reference match per candidate token (row-wise max);
        # recall: best candidate match per reference token (column-wise max).
        p_bert = similarities.max(dim=1).values.cpu()
        r_bert = similarities.max(dim=0).values.cpu()
        if (
            idf_dict_ref is not None
            and input_ids_cand is not None
            and input_ids_ref is not None
        ):
            p_bert = self._idf_weighted_mean(
                p_bert, [idf_dict_ref[tok_id] for tok_id in input_ids_cand]
            )
            r_bert = self._idf_weighted_mean(
                r_bert, [idf_dict_ref[tok_id] for tok_id in input_ids_ref]
            )
        else:
            p_bert = p_bert.mean()
            r_bert = r_bert.mean()
        f1 = 2 * p_bert * r_bert / (p_bert + r_bert)
        # handle p_bert + r_bert == 0
        f1 = self._check_nan(f1)
        return p_bert.item(), r_bert.item(), f1.item()

    @staticmethod
    def _batchify(
        iterable: List[str], batch_size: int
    ) -> Generator[List[str], None, None]:
        """Yield consecutive lists of at most ``batch_size`` items."""
        iterator = iter(iterable)
        while True:
            batch = list(islice(iterator, batch_size))
            if not batch:
                break
            yield batch

    @staticmethod
    def _process_batch(
        batch: List[str],
        tokenizer: PreTrainedTokenizerFast,
        ignore_counter: bool = False,
    ) -> Tuple[Counter[int], List[List[int]]]:
        """Tokenize a batch and optionally count document frequencies.

        Returns:
            ``(counts, input_ids)``. ``counts`` counts each token id at most
            once per entry (empty when ``ignore_counter``); ``input_ids`` has
            the first and last token of every entry removed (assumed to be
            the special tokens added by the tokenizer).
        """
        stripped_batch = [sample.strip() for sample in batch]
        encoded_batch = tokenizer(
            stripped_batch,
            add_special_tokens=True,
            truncation=True,
            return_attention_mask=False,
            return_token_type_ids=False,
        )["input_ids"]
        encoded_batch = [e[1:-1] for e in encoded_batch]  # remove CLS and SEP
        if ignore_counter:
            return Counter(), encoded_batch
        # ``set`` makes this a document frequency rather than a token count.
        batch_count: Counter[int] = Counter(
            chain.from_iterable(map(set, encoded_batch))
        )
        return batch_count, encoded_batch

    def _map_batches(
        self,
        corpus: List[str],
        batch_size: int,
        nthreads: int,
        ignore_counter: bool,
    ) -> List[Tuple[Counter[int], List[List[int]]]]:
        """Run ``_process_batch`` over ``corpus`` in order, one result per batch.

        A process pool is only started when there is more than one batch:
        with a single batch it adds process start-up and tokenizer pickling
        cost without any parallelism.

        Raises:
            ValueError: If ``batch_size`` is not positive (it would otherwise
                silently produce no batches and no token ids).
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
        process = partial(
            self._process_batch,
            tokenizer=self.tokenizer,
            ignore_counter=ignore_counter,
        )
        batches = self._batchify(corpus, batch_size)
        if nthreads > 0 and len(corpus) > batch_size:
            with Pool(nthreads) as pool:
                # imap preserves batch order, keeping ids aligned with corpus.
                return list(pool.imap(process, batches, chunksize=1))
        return list(map(process, batches))

    def _tokenize_data(
        self,
        corpus: List[str],
        batch_size: int = 100_000,
        nthreads: int = 4,
    ) -> List[List[int]]:
        """Return each entry's token ids (special tokens removed), in order."""
        collected_input_ids: List[List[int]] = []
        for _, batch_input_ids in self._map_batches(
            corpus, batch_size, nthreads, ignore_counter=True
        ):
            collected_input_ids.extend(batch_input_ids)
        return collected_input_ids

    def get_idf_dict(
        self,
        corpus: List[str],
        nthreads: int = 4,
        batch_size: int = 100_000,
    ) -> Tuple[Dict[int, float], List[List[int]]]:  # TODO: Return dict
        """Build an IDF (Inverse-Document-Frequency) dictionary for a corpus.

        Args:
            corpus: Documents to fit IDF on.
            nthreads: Worker processes used when there are several batches;
                ``0`` disables the pool.
            batch_size: Number of documents tokenized per batch.

        Returns:
            ``(idf_dict, input_ids)``. ``idf_dict`` maps a token id to
            ``log((N + 1) / (df + 1))`` and returns ``log(N + 1)`` for unseen
            ids; ``input_ids`` holds each entry's token ids (special tokens
            removed) in the same order as ``corpus``.
        """
        num_docs = len(corpus)
        idf_count: Counter[int] = Counter()
        collected_input_ids: List[List[int]] = []
        for batch_count, batch_input_ids in self._map_batches(
            corpus, batch_size, nthreads, ignore_counter=False
        ):
            idf_count.update(batch_count)
            collected_input_ids.extend(batch_input_ids)

        idf_dict: Dict[int, float] = defaultdict(
            lambda: log((num_docs + 1) / (1))
        )
        idf_dict.update(
            {
                idx: log((num_docs + 1) / (count + 1))
                for idx, count in idf_count.items()
            }
        )
        return idf_dict, collected_input_ids
|
|
282
|
+
|
|
283
|
+
def _pretrained_bert_score(model_id: str, kwargs: Dict[str, Any]) -> "BertScore":
    """Build a ``BertScore`` for a preset ``model_id``.

    Any ``model_id`` entry in ``kwargs`` is discarded so the preset always
    wins; every other keyword argument is forwarded unchanged.
    """
    forwarded = {name: value for name, value in kwargs.items() if name != "model_id"}
    return BertScore(model_id=model_id, **forwarded)


def ModernBERTBaseScore(**kwargs: Any) -> "BertScore":
    """BertScore with ModernBERT-base-ModBERTScore-12"""
    return _pretrained_bert_score("LazerLambda/ModernBERT-base-ModBERTScore-12", kwargs)


def ModernBERTLargeScore(**kwargs: Any) -> "BertScore":
    """BertScore with ModernBERT-large-ModBERTScore-19"""
    return _pretrained_bert_score("LazerLambda/ModernBERT-large-ModBERTScore-19", kwargs)


def RobertaBaseScore(**kwargs: Any) -> "BertScore":
    """BertScore with roberta-base-ModBERTScore-10"""
    return _pretrained_bert_score("LazerLambda/roberta-base-ModBERTScore-10", kwargs)


def RobertaLargeScore(**kwargs: Any) -> "BertScore":
    """BertScore with roberta-large-ModBERTScore-17"""
    return _pretrained_bert_score("LazerLambda/roberta-large-ModBERTScore-17", kwargs)


def RobertaLargeMNLIScore(**kwargs: Any) -> "BertScore":
    """BertScore with roberta-large-mnli-ModBERTScore-19"""
    return _pretrained_bert_score("LazerLambda/roberta-large-mnli-ModBERTScore-19", kwargs)
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
|
|
2
|
+
# (precision, recall, F1) baselines per preset checkpoint, consumed by
# ``BertScore`` when ``baseline_rescaling=True``: s -> (s - b) / (1 - b).
BASELINES: dict[str, tuple[float, float, float]] = {
    "LazerLambda/ModernBERT-base-ModBERTScore-12": (
        0.7899982, 0.790043, 0.7898656,
    ),
    "LazerLambda/ModernBERT-large-ModBERTScore-19": (
        0.64403415, 0.64410657, 0.64355606,
    ),
    "LazerLambda/roberta-base-ModBERTScore-10": (
        0.80465585, 0.80464727, 0.8043641,
    ),
    "LazerLambda/roberta-large-ModBERTScore-17": (
        0.8243552, 0.8243522, 0.8240494,
    ),
    "LazerLambda/roberta-large-mnli-ModBERTScore-19": (
        0.6901594, 0.69018, 0.6896288,
    ),
}
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
from typing import Any, List
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from sentence_transformers import SentenceTransformer
|
|
6
|
+
from torch.nn import functional as F
|
|
7
|
+
from transformers import AutoTokenizer
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from vllm import LLM
|
|
11
|
+
|
|
12
|
+
VLLM_AVAILABLE = True
|
|
13
|
+
except ImportError:
|
|
14
|
+
LLM = object # To prevent NameError if vllm is not installed
|
|
15
|
+
VLLM_AVAILABLE = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# TODO: Cache reference embeddings
|
|
19
|
+
class Inference:
    """Base interface for backends that produce per-token embeddings."""

    # Backend model handle; ``None`` until a subclass assigns a loaded model.
    model: Any = None

    def inference(
        self, candidates: List[str], references: List[str], **kwargs: Any
    ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Embed ``candidates`` and ``references``; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Method must be implemented in Subclass."
        raise NotImplementedError(message)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class STInference(Inference):
    """Per-token embeddings from a Sentence Transformers model."""

    def __init__(
        self,
        model_id: str,
        device: str = "cpu",
        batch_size: int = 64,
        **kwargs: Any,
    ):
        """Load ``model_id`` as a ``SentenceTransformer``.

        Args:
            model_id: Model name or path.
            device: Device passed to ``SentenceTransformer``.
            batch_size: Default batch size for ``encode``; a ``batch_size``
                passed to :meth:`inference` overrides it.
            **kwargs: Forwarded to the ``SentenceTransformer`` constructor.
        """
        self.model = SentenceTransformer(
            model_name_or_path=model_id, device=device, **kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id
        )  # TODO: Maybe switch to PreTrainedTokenizerFast for clarity?
        self.batch_size = batch_size
        self.eps: float = 1e-12

    def _encode(
        self, texts: List[str], encode_kwargs: dict[str, Any]
    ) -> List[torch.Tensor]:
        """Encode ``texts`` to token embeddings, L2-normalised per token."""
        embeddings = self.model.encode(
            texts,
            output_value="token_embeddings",
            convert_to_tensor=True,
            **encode_kwargs,
        )
        return [F.normalize(e, p=2, dim=-1, eps=self.eps) for e in embeddings]

    def inference(
        self, candidates: List[str], references: List[str], **kwargs: Any
    ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Return L2-normalised token embeddings for candidates and references.

        Args:
            candidates: Candidate texts.
            references: Reference texts.
            **kwargs: Forwarded to ``SentenceTransformer.encode``.

        Returns:
            ``(candidate_embeddings, reference_embeddings)`` with one tensor
            per input text.

        Raises:
            RuntimeError: If the model is not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded.")
        # Apply the constructor's batch_size unless this call overrides it.
        encode_kwargs = {"batch_size": self.batch_size, **kwargs}
        embds_refs = self._encode(references, encode_kwargs)
        embds_cnds = self._encode(candidates, encode_kwargs)
        return embds_cnds, embds_refs
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class VLLMInference(Inference):
|
|
75
|
+
def __init__(self, **kwargs: Any):
|
|
76
|
+
if not VLLM_AVAILABLE:
|
|
77
|
+
raise ImportError(
|
|
78
|
+
"vLLM is not installed. To use the vLLM backend, please "
|
|
79
|
+
"install it with `pip install vllm` or "
|
|
80
|
+
"`pip install 'modern-bert-score[vllm]'`."
|
|
81
|
+
)
|
|
82
|
+
# Backward compatibility for old callsites that pass task="embed".
|
|
83
|
+
kwargs = self._prepare_args(kwargs)
|
|
84
|
+
try:
|
|
85
|
+
self.model = LLM(**kwargs)
|
|
86
|
+
except Exception as exc:
|
|
87
|
+
message = str(exc)
|
|
88
|
+
if (
|
|
89
|
+
"Model architectures" in message
|
|
90
|
+
and "ModernBertForMaskedLM" in message
|
|
91
|
+
):
|
|
92
|
+
raise RuntimeError(
|
|
93
|
+
"vLLM does not accept the masked-LM ModernBERT checkpoint "
|
|
94
|
+
"directly. Export an encoder-only checkpoint first with "
|
|
95
|
+
"prepare_model.py, which rewrites the saved config to "
|
|
96
|
+
"advertise ModernBertModel, then load that local path in "
|
|
97
|
+
"VLLMInference. If you do not need vLLM specifically, use "
|
|
98
|
+
"STInference for the original HF checkpoint."
|
|
99
|
+
) from exc
|
|
100
|
+
raise
|
|
101
|
+
self.eps: float = 1e-12
|
|
102
|
+
|
|
103
|
+
@staticmethod
|
|
104
|
+
def _prepare_args(kwargs: Any) -> Any:
|
|
105
|
+
task = kwargs.pop("task", None)
|
|
106
|
+
if task == "embed":
|
|
107
|
+
kwargs.setdefault("runner", "pooling")
|
|
108
|
+
kwargs.setdefault("convert", "embed")
|
|
109
|
+
return kwargs
|
|
110
|
+
|
|
111
|
+
def inference(
|
|
112
|
+
self, candidates: List[str], references: List[str], **kwargs: Any
|
|
113
|
+
) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
|
|
114
|
+
if self.model is None:
|
|
115
|
+
raise RuntimeError("Model not loaded.")
|
|
116
|
+
outputs_cands = self.model.encode(
|
|
117
|
+
candidates, pooling_task="token_embed", **kwargs
|
|
118
|
+
)
|
|
119
|
+
outputs_refs = self.model.encode(
|
|
120
|
+
references, pooling_task="token_embed", **kwargs
|
|
121
|
+
)
|
|
122
|
+
collector: List[torch.Tensor] = []
|
|
123
|
+
for output in outputs_cands:
|
|
124
|
+
embeds = output.outputs.data
|
|
125
|
+
collector.append(embeds)
|
|
126
|
+
for output in outputs_refs:
|
|
127
|
+
embeds = output.outputs.data
|
|
128
|
+
collector.append(embeds)
|
|
129
|
+
|
|
130
|
+
collector = [
|
|
131
|
+
F.normalize(e, p=2, dim=-1, eps=self.eps) for e in collector
|
|
132
|
+
] # TODO: Check superflous?
|
|
133
|
+
return collector[0 : len(candidates)], collector[len(candidates) :]
|
|
134
|
+
|
|
135
|
+
def cleanup(self) -> None:
    """Drop the engine reference and return cached GPU memory.

    ``self.model`` is reset to ``None`` rather than deleted. This keeps the
    ``self.model is None`` guard in ``inference`` working after cleanup: a
    later call raises ``RuntimeError("Model not loaded.")``. With ``del``
    the attribute would be missing, so that guard itself would fail with
    ``AttributeError``. Calling this method repeatedly is harmless.
    """
    if getattr(self, "model", None) is not None:
        self.model = None
    # Collect the now-unreferenced engine before asking the CUDA caching
    # allocator to release its blocks.
    gc.collect()
    torch.cuda.empty_cache()
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: modern-bert-score
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: A reimplementation of the BERTScore metric optimized for modern inference workflows.
|
|
5
|
+
Author-email: Philipp Koch <PhillKoch@protonmail.com>
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/pypa/sampleproject
|
|
8
|
+
Project-URL: Issues, https://github.com/pypa/sampleproject/issues
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Environment :: GPU :: NVIDIA CUDA
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
|
+
Requires-Python: >=3.12
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE.md
|
|
19
|
+
Requires-Dist: sentence-transformers>=5.2.3
|
|
20
|
+
Requires-Dist: torch>=2.9.1
|
|
21
|
+
Requires-Dist: transformers>=4.57.6
|
|
22
|
+
Provides-Extra: vllm
|
|
23
|
+
Requires-Dist: vllm>=0.16.0; extra == "vllm"
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: sphinx; extra == "dev"
|
|
26
|
+
Requires-Dist: sphinx_rtd_theme; extra == "dev"
|
|
27
|
+
Requires-Dist: myst-parser; extra == "dev"
|
|
28
|
+
Requires-Dist: pytest; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
30
|
+
Requires-Dist: ruff; extra == "dev"
|
|
31
|
+
Requires-Dist: flake8; extra == "dev"
|
|
32
|
+
Requires-Dist: mypy; extra == "dev"
|
|
33
|
+
Requires-Dist: flake8; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
35
|
+
Dynamic: license-file
|
|
36
|
+
|
|
37
|
+
# Modern BERTScore for Fast Inference
|
|
38
|
+
|
|
39
|
+
[](https://github.com/LazerLambda/modern-bert-score/actions/workflows/ci.yml)
|
|
40
|
+
[](https://www.python.org/downloads/release/python-3120/)
|
|
41
|
+
[](https://www.python.org/downloads/release/python-3130/)
|
|
42
|
+
[](https://www.python.org/downloads/)
|
|
43
|
+
[](https://opensource.org/licenses/Apache-2.0)
|
|
44
|
+
|
|
45
|
+

|
|
46
|
+
|
|
47
|
+
**Modern-BERT-Score** is a reimplementation of the BERTScore metric introduced by [Zhang et al., 2019](https://arxiv.org/abs/1904.09675), optimized for modern inference workflows using [SentenceTransformers](https://www.sbert.net/) and [vLLM](https://vllm.ai/).
|
|
48
|
+
|
|
49
|
+
This library provides fast, GPU-accelerated scoring for text generation evaluation, making BERTScore practical for large-scale inference tasks.
|
|
50
|
+
|
|
51
|
+
---
|
|
52
|
+
|
|
53
|
+
## ⚡ Features
|
|
54
|
+
- Fast, efficient computation with optional vLLM support
|
|
55
|
+
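- Compatible with Hugging Face transformer models (the vLLM backend additionally requires an architecture vLLM supports, e.g. ModernBERT checkpoints must first be exported as encoder-only)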
|
|
56
|
+
- Supports truncated and optimized model versions for faster inference
|
|
57
|
+
- Works seamlessly with both CPU and GPU setups
|
|
58
|
+
|
|
59
|
+
---
|
|
60
|
+
|
|
61
|
+
## 📦 Installation
|
|
62
|
+
|
|
63
|
+
Modern-BERT-Score comes in **two variants**: a base version and a vLLM-enhanced version. For vLLM, an NVIDIA GPU is strongly recommended.
|
|
64
|
+
|
|
65
|
+
### Base Version
|
|
66
|
+
```bash
|
|
67
|
+
pip install modern-bert-score
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
### vLLM Version
|
|
71
|
+
```bash
|
|
72
|
+
pip install 'modern-bert-score[vllm]'
|
|
73
|
+
```
|
|
74
|
+
This implementation is significantly faster than the original BERTScore, especially with GPU acceleration.
|
|
75
|
+
|
|
76
|
+
## 📝 BERTScore
|
|
77
|
+
BERTScore ([Zhang et al., 2019](https://arxiv.org/abs/1904.09675)) evaluates the similarity between candidate and reference texts by comparing their contextual embeddings from a pre-trained transformer model. For each token in the candidate, it finds the most similar token in the reference (using cosine similarity), and vice versa. It then aggregates these scores into **precision**, **recall**, and **F1**. Optionally, **IDF weighting** gives more importance to rare, informative words, making the metric more sensitive to meaningful content than to common words. Optional **baseline rescaling** linearly rescales the scores against an empirical baseline so that they spread over a more readable range, typically between 0 and 1 (scores below the baseline become negative). This approach captures semantic similarity beyond exact word matches, making it robust for tasks such as machine translation and text generation evaluation.
|
|
78
|
+
|
|
79
|
+
The following figure, taken from the original paper, illustrates how BERTScore works:
|
|
80
|
+
|
|
81
|
+

|
|
82
|
+
|
|
83
|
+
## 🛠 Usage
|
|
84
|
+
### Example
|
|
85
|
+
```python
|
|
86
|
+
from modern_bert_score import BertScore
|
|
87
|
+
|
|
88
|
+
candidates = ["Hello World!", "A robin is a bird."]
|
|
89
|
+
references = ["Hi World!", "A robin is not a bird."]
|
|
90
|
+
|
|
91
|
+
metric = BertScore(model_id="roberta-base")
|
|
92
|
+
scores = metric(candidates, references)
|
|
93
|
+
|
|
94
|
+
# scores is a list of (Precision, Recall, F1) tuples
|
|
95
|
+
# To get separate lists of P, R, F1:
|
|
96
|
+
P, R, F1 = zip(*scores)
|
|
97
|
+
|
|
98
|
+
print("Precision scores:", P)
|
|
99
|
+
print("Recall scores:", R)
|
|
100
|
+
print("F1 scores:", F1)
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## ⚠️ NOTICE
|
|
104
|
+
|
|
105
|
+
- For the best correlation with human judgments, use the optimal layer for each model.
|
|
106
|
+
- To find the optimal layer, [please use this script from the original BERTScore implementation](https://github.com/Tiiiger/bert_score/tree/master/tune_layers).
|
|
107
|
+
|
|
108
|
+
Some pre-truncated models optimized for vLLM are available on [Hugging Face](https://huggingface.co/collections/LazerLambda/modern-bertscore) and are exposed directly in this library as ready-made classes:
|
|
109
|
+
|
|
110
|
+
- `LazerLambda/ModernBERT-base-ModBERTScore-12` -> `ModernBERTBaseScore`
|
|
111
|
+
- `LazerLambda/ModernBERT-large-ModBERTScore-19` -> `ModernBERTLargeScore`
|
|
112
|
+
- `LazerLambda/roberta-base-ModBERTScore-10` -> `RobertaBaseScore`
|
|
113
|
+
- `LazerLambda/roberta-large-ModBERTScore-17` -> `RobertaLargeScore`
|
|
114
|
+
- `LazerLambda/roberta-large-mnli-ModBERTScore-19` -> `RobertaLargeMNLIScore`
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
## 🗺 Roadmap
|
|
118
|
+
|
|
119
|
+
- [x] Implement base version and vLLM addon
|
|
120
|
+
- [x] Add IDF-weighted scoring
|
|
121
|
+
- [ ] Add baseline-rescaling and scripts for identifying optimal baselines
|
|
122
|
+
- [ ] Add model (vLLM-)adaptation script for slicing the model
|
|
123
|
+
- [ ] Add multilingual support
|
|
124
|
+
- [ ] Add CLI tool
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
LICENSE.md
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
modern_bert_score/__init__.py
|
|
5
|
+
modern_bert_score/bert_score.py
|
|
6
|
+
modern_bert_score/consts.py
|
|
7
|
+
modern_bert_score/inference.py
|
|
8
|
+
modern_bert_score.egg-info/PKG-INFO
|
|
9
|
+
modern_bert_score.egg-info/SOURCES.txt
|
|
10
|
+
modern_bert_score.egg-info/dependency_links.txt
|
|
11
|
+
modern_bert_score.egg-info/requires.txt
|
|
12
|
+
modern_bert_score.egg-info/top_level.txt
|
|
13
|
+
tests/test_bert_score.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
modern_bert_score
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "modern-bert-score"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
description = "A reimplementation of the BERTScore metric optimized for modern inference workflows."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
classifiers = [
|
|
12
|
+
"Programming Language :: Python :: 3",
|
|
13
|
+
"Operating System :: OS Independent",
|
|
14
|
+
"Development Status :: 2 - Pre-Alpha",
|
|
15
|
+
"Intended Audience :: Developers",
|
|
16
|
+
"Intended Audience :: Science/Research",
|
|
17
|
+
"Environment :: GPU :: NVIDIA CUDA",
|
|
18
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence"
|
|
19
|
+
]
|
|
20
|
+
license = "Apache-2.0"
|
|
21
|
+
license-files = ["LICEN[CS]E*"]
|
|
22
|
+
authors = [
|
|
23
|
+
{name = "Philipp Koch", email="PhillKoch@protonmail.com"},
|
|
24
|
+
]
|
|
25
|
+
dependencies = [
|
|
26
|
+
"sentence-transformers>=5.2.3",
|
|
27
|
+
"torch>=2.9.1",
|
|
28
|
+
"transformers>=4.57.6",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[project.urls]
|
|
32
|
+
Homepage = "https://github.com/LazerLambda/modern-bert-score"
|
|
33
|
+
Issues = "https://github.com/LazerLambda/modern-bert-score/issues"
|
|
34
|
+
|
|
35
|
+
[project.optional-dependencies]
|
|
36
|
+
vllm = ["vllm>=0.16.0"]
|
|
37
|
+
dev = [
|
|
38
|
+
"sphinx",
|
|
39
|
+
"sphinx_rtd_theme",
|
|
40
|
+
"myst-parser",
|
|
41
|
+
"pytest",
|
|
42
|
+
"pytest-cov",
|
|
43
|
+
"ruff",
|
|
44
|
+
"flake8",
|
|
45
|
+
"mypy",
|
|
46
|
+
"flake8",
|
|
47
|
+
"pytest-cov"
|
|
48
|
+
]
|
|
49
|
+
|
|
50
|
+
[tool.setuptools.packages.find]
|
|
51
|
+
include = ["modern_bert_score*"]
|
|
52
|
+
|
|
53
|
+
[tool.ruff]
|
|
54
|
+
line-length = 90
|
|
55
|
+
extend-exclude = ["tests"]
|
|
56
|
+
|
|
57
|
+
[tool.ruff.lint]
|
|
58
|
+
select = ["E", "F", "W", "I"]
|
|
59
|
+
|
|
60
|
+
[tool.mypy]
|
|
61
|
+
strict = true
|
|
62
|
+
ignore_missing_imports = true
|
|
63
|
+
|
|
64
|
+
[tool.pytest.ini_options]
|
|
65
|
+
filterwarnings = ["ignore::DeprecationWarning"]
|
|
66
|
+
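# pytest has no "exclude" ini option (it already skips build/ via its default norecursedirs); a "^build/" regex belongs under [tool.mypy] exclude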
|
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from unittest.mock import MagicMock, patch
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import torch
|
|
6
|
+
from transformers import logging as transformers_logging
|
|
7
|
+
|
|
8
|
+
from modern_bert_score.bert_score import BertScore
|
|
9
|
+
from modern_bert_score.inference import VLLM_AVAILABLE
|
|
10
|
+
|
|
11
|
+
# Keep test output quiet: transformers otherwise warns about missing or
# unused weights while loading the test checkpoint.
transformers_logging.set_verbosity_error()

# Small (per its name, a 2-layer BERT-Tiny) checkpoint shared by every test below.
TEST_MODEL = "LazerLambda/BERT-Tiny-L-2-H-128-A-2-ModBERTScore-TEST"
|
|
15
|
+
|
|
16
|
+
class TestBertScore(unittest.TestCase):
    """Tests for :class:`BertScore`.

    Tests without ``@patch`` decorators load ``TEST_MODEL`` through the
    ``"default"`` backend, so the checkpoint must be downloadable or cached.
    ``@patch``-decorated tests replace the tokenizer and/or the inference
    engine with mocks.
    """

    test_model: str = TEST_MODEL

    def setUp(self):
        # Five aligned pairs: pair 0 is an exact match; every other pair
        # differs by a single word.
        self.candidates = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
            "The cat is on the table",
        ]
        self.references = [
            "Hello, my name is",
            "The head of the United States is",
            "The capital of Japan is",
            "The future of Work is",
            "The dog is on the table",
        ]

    def _assert_scores_close(self, actual, expected):
        """Assert each per-pair ``(p, r, f1)`` matches the ``expected`` row within 1e-4.

        ``strict=True`` makes a length mismatch fail loudly. A plain ``zip``
        would silently compare only a prefix, or nothing at all for an empty
        result.
        """
        for (p, r, f1), (p_exp, r_exp, f1_exp) in zip(actual, expected, strict=True):
            self.assertTrue(torch.allclose(torch.tensor(p), p_exp, atol=1e-4))
            self.assertTrue(torch.allclose(torch.tensor(r), r_exp, atol=1e-4))
            self.assertTrue(torch.allclose(torch.tensor(f1), f1_exp, atol=1e-4))

    def test_against_original(self):
        """Match hard-coded reference scores, without and with IDF weighting.

        The expected values are presumably taken from the original BERTScore
        implementation for the same model (hence the test name). Each list
        holds [P, R, F1] over the five pairs. Stacking and transposing yields
        one ``(p, r, f1)`` row per pair.
        """
        original_p_r_f1 = [
            torch.tensor([1.0000, 0.9423, 0.8705, 0.9105, 0.9637]),
            torch.tensor([1.0000, 0.9423, 0.8705, 0.9105, 0.9637]),
            torch.tensor([1.0000, 0.9423, 0.8705, 0.9105, 0.9637]),
        ]
        original_p_r_f1 = torch.stack(original_p_r_f1).T
        original_p_r_f1_idf = [
            torch.tensor([1.0000, 0.9233, 0.8167, 0.8609, 0.9369]),
            torch.tensor([1.0000, 0.9358, 0.8304, 0.8861, 0.9505]),
            torch.tensor([1.0000, 0.9295, 0.8235, 0.8733, 0.9436]),
        ]
        original_p_r_f1_idf = torch.stack(original_p_r_f1_idf).T
        bs = BertScore(model_id=self.test_model, backend="default")
        p_r_f1 = bs(self.candidates, self.references)
        # Toggle IDF weighting on the same instance instead of reloading the model.
        bs.idf_weighting = True
        p_r_f1_idf = bs(self.candidates, self.references)
        self._assert_scores_close(p_r_f1, original_p_r_f1)
        self._assert_scores_close(p_r_f1_idf, original_p_r_f1_idf)

    def test_id(self):
        """Identical texts score 1.0, also when the reference has surrounding whitespace."""
        cand1 = ["Hello World!"]
        ref1 = ["Hello World!"]
        bs = BertScore(model_id=self.test_model, backend="default")
        p_r_f1 = bs(cand1, ref1)
        self.assertAlmostEqual(p_r_f1[0][0], 1.0, places=5)
        self.assertAlmostEqual(p_r_f1[0][1], 1.0, places=5)
        self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)

        cand2 = ["Hello World!"]
        ref2 = [" Hello World! "]
        p_r_f1 = bs(cand2, ref2)
        self.assertAlmostEqual(p_r_f1[0][0], 1.0, places=5)
        self.assertAlmostEqual(p_r_f1[0][1], 1.0, places=5)
        self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)

    def test_unequal_length(self):
        """Mismatched candidate/reference list lengths raise ValueError in either direction."""
        # One instance suffices: the length check rejects the input before
        # inference, so the first failure does not affect the second call.
        bs = BertScore(model_id=self.test_model, backend="default")
        with self.assertRaises(ValueError):
            bs(["Hello World!"], ["Hello World!", "Hello World!"])
        with self.assertRaises(ValueError):
            bs(["Hello World!", "Hello World!"], ["Hello World!"])

    def test_unequal_input(self):
        """Different texts score strictly below 1.0 on P, R and F1."""
        cand1 = ["Hello World!"]
        ref1 = ["Bye World!"]
        bs = BertScore(model_id=self.test_model, backend="default")
        p_r_f1 = bs(cand1, ref1)
        self.assertTrue(p_r_f1[0][0] < 1.0)
        self.assertTrue(p_r_f1[0][1] < 1.0)
        self.assertTrue(p_r_f1[0][2] < 1.0)

    def test_empty_input(self):
        """Empty input lists yield an empty result list."""
        cand1 = []
        ref1 = []
        bs = BertScore(model_id=self.test_model, backend="default")
        p_r_f1 = bs(cand1, ref1)
        self.assertEqual(p_r_f1, [])

    def test_empty_inference_engine(self):
        """Scoring without an inference engine raises ValueError."""
        bs = BertScore(model_id=self.test_model, backend="default")
        bs.inference_engine = None
        with self.assertRaises(ValueError):
            bs(self.candidates, self.references)

    def test_check_nan(self):
        """``_check_nan`` maps a NaN score to 0.0."""
        f1_nan = torch.tensor(torch.nan)
        f1_checked = BertScore._check_nan(f1_nan)
        self.assertEqual(f1_checked, torch.Tensor([0.0]))

    def test_exception_idf_bertscore(self):
        """A direct ``bert_score`` call with raw token-id lists raises ValueError.

        The test name suggests this concerns IDF handling; confirm against
        ``BertScore.bert_score``.
        """
        bs = BertScore(model_id=self.test_model, backend="default")
        with self.assertRaises(ValueError):
            bs.bert_score(
                candidates=torch.rand(3, 4),
                references=torch.rand(3, 4),
                input_ids_cand=[0, 1, 2],
                input_ids_ref=[0, 1, 2],
            )

    def test_single(self):
        """Smoke test: plain strings (not lists) are accepted without raising."""
        cand1 = "Hello World!"
        ref1 = "Hello World!"
        bs = BertScore(model_id=self.test_model, backend="default")
        bs(cand1, ref1)

    def test_tokenize_batch(self):
        """``_process_batch`` returns an empty counter when ``ignore_counter=True``."""
        bs = BertScore(model_id=self.test_model, backend="default")
        counter, _ = bs._process_batch(
            ["Hello World!", "Hello World!"],
            bs.tokenizer,
            ignore_counter=True,
        )
        self.assertEqual(len(counter), 0)
        counter, _ = bs._process_batch(
            ["Hello World!", "Hello World!"],
            bs.tokenizer,
            ignore_counter=False,
        )
        self.assertIsNotNone(counter)

    def test_tokenize_data(self):
        """Smoke test: ``_tokenize_data`` runs with and without worker threads."""
        bs = BertScore(model_id=self.test_model, backend="default")
        bs._tokenize_data(["Hello World!", "Hello World!"], nthreads=4)
        bs._tokenize_data(["Hello World!", "Hello World!"], nthreads=0)

    def test_get_idf_dict(self):
        """Smoke test: ``get_idf_dict`` runs with and without worker threads."""
        bs = BertScore(model_id=self.test_model, backend="default")
        bs.get_idf_dict(["Hello World!", "Hello World!"], nthreads=4)
        bs.get_idf_dict(["Hello World!", "Hello World!"], nthreads=0)

    def test_base_line(self):
        """With a custom (0.5, 0.5, 0.5) baseline, identical texts still score 1.0."""
        cand1 = ["Hello World!"]
        ref1 = ["Hello World!"]
        bs = BertScore(
            model_id=self.test_model,
            backend="default",
            baseline_rescaling=True,
            custom_baseline=(0.5, 0.5, 0.5),
        )
        p_r_f1 = bs(cand1, ref1)
        self.assertAlmostEqual(p_r_f1[0][0], 1.0, places=5)
        self.assertAlmostEqual(p_r_f1[0][1], 1.0, places=5)
        self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)

    def test_baseline_rescaling_exception(self):
        """Test that ValueError is raised when baseline is missing."""
        with self.assertRaises(ValueError):
            BertScore(model_id=self.test_model, backend="default", baseline_rescaling=True)

    def test_baseline_rescaling_injection(self):
        """Test baseline rescaling with injected baseline for the test model."""
        from modern_bert_score.consts import BASELINES

        # Inject a (P, R, F1) baseline for the test model. patch.dict restores
        # BASELINES on exit, including any entry that already existed for
        # this key (a bare ``del`` in ``finally`` would have removed it).
        with patch.dict(BASELINES, {self.test_model: (0.5, 0.5, 0.5)}):
            bs = BertScore(model_id=self.test_model, backend="default", baseline_rescaling=True)
            self.assertEqual(bs.baseline, (0.5, 0.5, 0.5))
            # Identical texts still map to 1.0 after rescaling (F1 is index 2).
            cand1 = ["Hello World!"]
            ref1 = ["Hello World!"]
            p_r_f1 = bs(cand1, ref1)
            self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)

    def test_inference_base_not_implemented_error(self):
        """The abstract ``Inference`` base class does not implement ``inference``."""
        from modern_bert_score.inference import Inference

        with self.assertRaises(NotImplementedError):
            dummy = Inference()
            dummy.inference(["Hello"], ["Hello"])

    @patch("modern_bert_score.bert_score.AutoTokenizer")
    @patch("modern_bert_score.bert_score.STInference")
    def test_initialization_default(self, mock_st_inference, mock_tokenizer):
        """Test initialization with default backend."""
        mock_tokenizer.from_pretrained.return_value = MagicMock()

        bs = BertScore(model_id=self.test_model, backend="default")

        mock_tokenizer.from_pretrained.assert_called_with(self.test_model)
        mock_st_inference.assert_called_once()
        self.assertEqual(bs.inference_engine, mock_st_inference.return_value)

    @patch("modern_bert_score.bert_score.AutoTokenizer")
    @patch("modern_bert_score.bert_score.VLLMInference")
    def test_initialization_vllm(self, mock_vllm_inference, mock_tokenizer):
        """Test initialization with vllm backend."""
        mock_tokenizer.from_pretrained.return_value = MagicMock()

        bs = BertScore(model_id=self.test_model, backend="vllm")

        mock_vllm_inference.assert_called_once()
        self.assertEqual(bs.inference_engine, mock_vllm_inference.return_value)

    @patch("modern_bert_score.bert_score.AutoTokenizer")
    @patch("modern_bert_score.inference.VLLM_AVAILABLE", False)
    def test_initialization_vllm_without_vllm_installed(self, mock_tokenizer):
        """Test that using vllm backend without vllm installed raises ImportError."""
        mock_tokenizer.from_pretrained.return_value = MagicMock()
        with self.assertRaises(ImportError) as cm:
            BertScore(model_id=self.test_model, backend="vllm")

        self.assertIn("vLLM is not installed", str(cm.exception))
        self.assertIn("pip install 'modern-bert-score[vllm]'", str(cm.exception))

    # create=True: the name ``LLM`` may be unbound in modern_bert_score.inference
    # when vllm is not installed. Without it, patch() would raise AttributeError
    # and the test could not run "regardless" as its docstring promises.
    @patch("modern_bert_score.inference.LLM", create=True)
    @patch("modern_bert_score.inference.VLLM_AVAILABLE", True)
    @patch("modern_bert_score.bert_score.AutoTokenizer")
    def test_initialization_vllm_masked_lm_error(self, mock_tokenizer, mock_llm):
        """Test that appropriate error is raised when vLLM rejects MaskedLM architecture.

        This test uses mocks for vLLM components, so it will run (and pass) regardless
        of whether the `vllm` package is installed in the test environment.
        """
        mock_tokenizer.from_pretrained.return_value = MagicMock()

        # Simulate vLLM raising exception about architecture
        mock_llm.side_effect = Exception(
            "ValueError: Model architectures ['ModernBertForMaskedLM'] are not supported for now. "
            "Supported architectures: ['ModernBertModel', ...]"
        )

        with self.assertRaises(RuntimeError) as cm:
            BertScore(model_id=self.test_model, backend="vllm")

        self.assertIn(
            "vLLM does not accept the masked-LM ModernBERT checkpoint directly",
            str(cm.exception),
        )

    @patch("modern_bert_score.bert_score.AutoTokenizer")
    def test_initialization_invalid_backend(self, mock_tokenizer):
        """Test initialization with invalid backend raises ValueError."""
        with self.assertRaises(ValueError):
            BertScore(model_id=self.test_model, backend="invalid")

    @patch("modern_bert_score.bert_score.AutoTokenizer")
    @patch("modern_bert_score.bert_score.STInference")
    def test_call_simple(self, mock_st_inference, mock_tokenizer):
        """Test scoring call with default backend."""
        mock_engine = MagicMock()
        mock_st_inference.return_value = mock_engine

        # Mock embeddings: one [seq_len, hidden_dim] tensor per side.
        # BertScore slices [1:-1] (removes CLS and SEP), so at least 3 tokens are needed.
        c_emb = torch.rand(3, 4)
        r_emb = torch.rand(3, 4)

        mock_engine.inference.return_value = ([c_emb], [r_emb])

        bs = BertScore(model_id=self.test_model, backend="default")
        results = bs(self.candidates, self.references)

        # One result per embedding pair returned by the (mocked) engine.
        self.assertEqual(len(results), 1)
        p, r, f1 = results[0]
        self.assertIsInstance(p, float)
        self.assertIsInstance(r, float)
        self.assertIsInstance(f1, float)

        # Verify inference called correctly
        mock_engine.inference.assert_called_with(self.candidates, self.references)

    @patch("modern_bert_score.bert_score.AutoTokenizer")
    @patch("modern_bert_score.bert_score.STInference")
    def test_input_validation(self, mock_st_inference, mock_tokenizer):
        """Test input validation for candidates/references."""
        bs = BertScore(model_id=self.test_model, backend="default")

        # Mismatched lengths
        with self.assertRaises(ValueError):
            bs(["a"], ["b", "c"])
|
|
299
|
+
|
|
300
|
+
@pytest.fixture(scope="module")
def vllm_bert_score():
    """Provide a vLLM-backed :class:`BertScore` for ``TEST_MODEL``, shared by the module.

    The engine is built once per module. It is released via the inference
    engine's ``cleanup()`` after the last test in the module. This keeps it
    from holding GPU memory for the rest of the test session.
    """
    bs = BertScore(
        model_id=TEST_MODEL,
        backend="vllm",
        vllm_args={
            # Fraction of GPU memory vLLM may reserve; leaves headroom for others.
            "gpu_memory_utilization": 0.3,
            # Run in eager mode (no CUDA-graph capture).
            "enforce_eager": True,
            "distributed_executor_backend": "mp",
            # Presumably mapped to runner="pooling"/convert="embed" by
            # VLLMInference._prepare_args.
            "task": "embed",
        },
    )
    yield bs
    # Teardown: drop the engine and return cached GPU memory.
    bs.inference_engine.cleanup()
|
|
309
|
+
|
|
310
|
+
# End-to-end check of the vLLM backend; skipped when vLLM is not installed.
@pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
def test_vllm_only_feature(vllm_bert_score):
    """An identical candidate/reference pair scores P = R = F1 = 1 on the vLLM backend."""
    sentence = "Hello World!"
    first_pair = vllm_bert_score([sentence], [sentence])[0]

    # Plain asserts: this is a pytest function, not a TestCase method.
    for position in range(3):
        assert first_pair[position] == pytest.approx(1.0)


# TODO: add a test checking that vllm_args are forwarded to VLLMInference, e.g.
# {"task": "embed", "gpu_memory_utilization": 0.3, "enforce_eager": True,
#  "distributed_executor_backend": "mp"}.
|