modern-bert-score 0.0.1__tar.gz

@@ -0,0 +1,13 @@
+ Copyright 2026 Philipp Koch
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
@@ -0,0 +1,122 @@
+ Metadata-Version: 2.4
+ Name: modern-bert-score
+ Version: 0.0.1
+ Summary: A reimplementation of the BERTScore metric optimized for modern inference workflows.
+ Author-email: Philipp Koch <PhillKoch@protonmail.com>
+ License-Expression: Apache-2.0
+ Project-URL: Homepage, https://github.com/LazerLambda/modern-bert-score
+ Project-URL: Issues, https://github.com/LazerLambda/modern-bert-score/issues
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Operating System :: OS Independent
+ Classifier: Development Status :: 2 - Pre-Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Environment :: GPU :: NVIDIA CUDA
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.12
+ Description-Content-Type: text/markdown
+ License-File: LICENSE.md
+ Requires-Dist: sentence-transformers>=5.2.3
+ Requires-Dist: torch>=2.9.1
+ Requires-Dist: transformers>=4.57.6
+ Provides-Extra: vllm
+ Requires-Dist: vllm>=0.16.0; extra == "vllm"
+ Provides-Extra: dev
+ Requires-Dist: sphinx; extra == "dev"
+ Requires-Dist: sphinx_rtd_theme; extra == "dev"
+ Requires-Dist: myst-parser; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: pytest-cov; extra == "dev"
+ Requires-Dist: ruff; extra == "dev"
+ Requires-Dist: flake8; extra == "dev"
+ Requires-Dist: mypy; extra == "dev"
+ Dynamic: license-file
+
+ # Modern BERTScore for Fast Inference
+
+ [![CI](https://github.com/LazerLambda/modern-bert-score/actions/workflows/ci.yml/badge.svg)](https://github.com/LazerLambda/modern-bert-score/actions/workflows/ci.yml)
+ [![Python 3.12](https://img.shields.io/badge/python-3.12-blue.svg)](https://www.python.org/downloads/release/python-3120/)
+ [![Python 3.13](https://img.shields.io/badge/python-3.13-blue.svg)](https://www.python.org/downloads/release/python-3130/)
+ [![Python 3.14](https://img.shields.io/badge/python-3.14-blue.svg)](https://www.python.org/downloads/)
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+
+ ![Performance](runtime_comp_nvidia_rtx_4060.png)
+
+ **Modern-BERT-Score** is a reimplementation of the BERTScore metric introduced by [Zhang et al., 2019](https://arxiv.org/abs/1904.09675), optimized for modern inference workflows using [SentenceTransformers](https://www.sbert.net/) and [vLLM](https://vllm.ai/).
+
+ This library provides fast scoring for text generation evaluation, with optional GPU acceleration, making BERTScore practical for large-scale evaluation.
+
+ ---
+
+ ## ⚡ Features
+ - Fast, efficient computation with optional vLLM support
+ - Compatible with Hugging Face transformer models that SentenceTransformers or vLLM can load
+ - Supports truncated and optimized model versions for faster inference
+ - Works with both CPU and GPU setups
+
+ ---
+
+ ## 📦 Installation
+
+ Modern-BERT-Score comes in **two variants**: a base version and a vLLM-enhanced version. For the vLLM version, an NVIDIA GPU is strongly recommended.
+
+ ### Base Version
+ ```bash
+ pip install modern-bert-score
+ ```
+
+ ### vLLM Version
+ ```bash
+ pip install "modern-bert-score[vllm]"
+ ```
+ This implementation is designed to be significantly faster than the original BERTScore, especially with GPU acceleration; see the runtime comparison figure at the top of this README.
+
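The variant you install determines which `backend` value the `BertScore` constructor can use. `"default"` (SentenceTransformers) works with the base install, while `"vllm"` needs the extra. The following is a minimal sketch for choosing the backend at runtime, assuming you want to fall back to SentenceTransformers when vLLM is missing. `pick_backend` is a hypothetical helper, not part of the library. It only checks that the `vllm` package can be found, not that it imports cleanly.

```python
import importlib.util


def pick_backend() -> str:
    """Hypothetical helper: "vllm" if the vllm package is installed, else "default"."""
    # find_spec locates the package on the import path without importing it.
    if importlib.util.find_spec("vllm") is not None:
        return "vllm"
    return "default"


backend = pick_backend()
print(f"Selected backend: {backend}")

# With modern-bert-score installed, the value would then be passed on, e.g.:
# metric = BertScore(model_id="LazerLambda/roberta-base-ModBERTScore-10", backend=backend)
```

The library itself detects vLLM with a `try: from vllm import LLM` guard and raises `ImportError` only when `backend="vllm"` is requested without it. Checking up front simply avoids that error path.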
+ ## 📝 BERTScore
+ BERTScore ([Zhang et al., 2019](https://arxiv.org/abs/1904.09675)) evaluates the similarity between candidate and reference texts by comparing their contextual embeddings from a pre-trained transformer model. Each candidate token is greedily matched to its most similar reference token (by cosine similarity), and averaging these maximum similarities gives **precision**. Matching each reference token to its most similar candidate token in the same way gives **recall**, and **F1** is their harmonic mean. Optionally, **IDF weighting** (with IDF values computed from the references) gives rare, informative tokens more weight than common ones, making the metric more sensitive to meaningful content. Optional **baseline rescaling** linearly rescales each score against an empirical baseline so that scores spread over a more readable range, typically between 0 and 1; scores below the baseline become negative. This approach captures semantic similarity beyond exact word matches, making it robust for tasks such as machine translation and text generation evaluation.
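The greedy matching and optional IDF weighting described above can be sketched in plain Python. This is an illustration, not the library's code path. It assumes token embeddings are already L2-normalized, so a dot product is the cosine similarity, and that special tokens have already been dropped. The helpers `build_idf`, `weighted_mean`, and `greedy_bertscore` are hypothetical. The IDF formula, log((M + 1) / (df + 1)) over M references with log(M + 1) for unseen tokens, mirrors the one in `get_idf_dict`.

```python
import math
from collections import Counter
from typing import Dict, List, Optional, Sequence, Tuple

Vector = Sequence[float]


def build_idf(reference_token_ids: List[List[int]]) -> Dict[int, float]:
    """IDF per token id: log((M + 1) / (df + 1)) over M reference texts."""
    num_docs = len(reference_token_ids)
    # Document frequency: each token counts at most once per reference.
    df = Counter(tok for ids in reference_token_ids for tok in set(ids))
    return {tok: math.log((num_docs + 1) / (n + 1)) for tok, n in df.items()}


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def weighted_mean(values: List[float], weights: Optional[Sequence[float]]) -> float:
    # Assumption: fall back to a plain mean when no (or all-zero) weights are given.
    if weights is None or sum(weights) == 0:
        return sum(values) / len(values)
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def greedy_bertscore(
    cand: Sequence[Vector],
    ref: Sequence[Vector],
    cand_weights: Optional[Sequence[float]] = None,
    ref_weights: Optional[Sequence[float]] = None,
) -> Tuple[float, float, float]:
    """Precision, recall and F1 from greedy max-cosine token matching."""
    # Precision: match every candidate token to its best reference token.
    best_for_cand = [max(dot(c_tok, r_tok) for r_tok in ref) for c_tok in cand]
    # Recall: match every reference token to its best candidate token.
    best_for_ref = [max(dot(r_tok, c_tok) for c_tok in cand) for r_tok in ref]
    p = weighted_mean(best_for_cand, cand_weights)
    r = weighted_mean(best_for_ref, ref_weights)
    f1 = 0.0 if p + r == 0 else 2 * p * r / (p + r)
    return p, r, f1


# Toy 2-d unit vectors standing in for contextual token embeddings.
cand = [[1.0, 0.0], [0.0, 1.0]]
ref = [[1.0, 0.0], [0.6, 0.8]]
print(greedy_bertscore(cand, ref))

# IDF from two toy references (token ids); token 5 occurs in both, so its weight is 0.
idf = build_idf([[5, 7], [5, 9]])
unseen = math.log(2 + 1)  # weight for ids absent from every reference
cand_ids, ref_ids = [5, 9], [5, 7]
cand_w = [idf.get(t, unseen) for t in cand_ids]
ref_w = [idf.get(t, unseen) for t in ref_ids]
print(greedy_bertscore(cand, ref, cand_w, ref_w))
```

As in `BertScore.bert_score`, the IDF weights are derived from the references only and applied to both the candidate and the reference side.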
+
+ The following figure, taken from the original paper, illustrates how BERTScore works:
+
+ ![BERTScore](zhang_19_figure_1.png)
+
+ ## 🛠 Usage
+ ### Example
+ ```python
+ from modern_bert_score import BertScore
+
+ candidates = ["Hello World!", "A robin is a bird."]
+ references = ["Hi World!", "A robin is not a bird."]
+
+ metric = BertScore(model_id="roberta-base")  # default backend; runs on CPU unless device="cuda"
+ scores = metric(candidates, references)
+
+ # scores is a list of (precision, recall, F1) tuples, one per candidate/reference pair
+ # To get separate sequences of P, R, F1:
+ P, R, F1 = zip(*scores)
+
+ print("Precision scores:", P)
+ print("Recall scores:", R)
+ print("F1 scores:", F1)
+ ```
+
+ ## ⚠️ NOTICE
+
+ - For the best correlation with human judgments, embeddings should be taken from a model-specific optimal layer rather than from the last layer.
+ - To find the optimal layer, [please use this script from the original BERTScore implementation](https://github.com/Tiiiger/bert_score/tree/master/tune_layers).
+
+ Some pre-truncated models optimized for vLLM are published on [Hugging Face](https://huggingface.co/collections/LazerLambda/modern-bertscore) and can be loaded directly through the following constructors in this library:
+
+ - `LazerLambda/ModernBERT-base-ModBERTScore-12` -> `ModernBERTBaseScore`
+ - `LazerLambda/ModernBERT-large-ModBERTScore-19` -> `ModernBERTLargeScore`
+ - `LazerLambda/roberta-base-ModBERTScore-10` -> `RobertaBaseScore`
+ - `LazerLambda/roberta-large-ModBERTScore-17` -> `RobertaLargeScore`
+ - `LazerLambda/roberta-large-mnli-ModBERTScore-19` -> `RobertaLargeMNLIScore`
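Each of these checkpoints has a (P, R, F1) baseline in `modern_bert_score.consts.BASELINES`. When `BertScore` is created with `baseline_rescaling=True`, it uses that baseline (or a `custom_baseline` when one is supplied) and maps each score s to (s - b) / (1 - b), as done in `BertScore.__call__`. The sketch below reproduces only that arithmetic in plain Python so its effect is easy to inspect. `rescale` is a hypothetical helper, and the baseline tuple is copied from the `roberta-base-ModBERTScore-10` entry.

```python
from typing import Tuple

Scores = Tuple[float, float, float]

# (P, R, F1) baseline listed for LazerLambda/roberta-base-ModBERTScore-10 in consts.py
ROBERTA_BASE_BASELINE: Scores = (0.80465585, 0.80464727, 0.8043641)


def rescale(scores: Scores, baseline: Scores) -> Scores:
    """Map each of P, R, F1 independently to (s - b) / (1 - b)."""
    p, r, f1 = ((s - b) / (1 - b) for s, b in zip(scores, baseline))
    return p, r, f1


raw = (0.95, 0.93, 0.94)
print(rescale(raw, ROBERTA_BASE_BASELINE))

# A raw score equal to the baseline maps to 0, a perfect score of 1 stays 1,
# and raw scores below the baseline become negative.
print(rescale((0.5, 0.5, 0.5), ROBERTA_BASE_BASELINE))
```

With the library itself, the equivalent setup would be `RobertaBaseScore(baseline_rescaling=True)`, since the factory functions forward their keyword arguments to `BertScore`.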
+
+
+ ## 🗺 Roadmap
+
+ - [x] Implement base version and vLLM addon
+ - [x] Add IDF-weighted scoring
+ - [ ] Add baseline-rescaling and scripts for identifying optimal baselines
+ - [ ] Add model (vLLM-)adaptation script for slicing the model
+ - [ ] Add multilingual support
+ - [ ] Add CLI tool
@@ -0,0 +1,17 @@
+ from .bert_score import (
+     BertScore,
+     ModernBERTBaseScore,
+     ModernBERTLargeScore,
+     RobertaBaseScore,
+     RobertaLargeMNLIScore,
+     RobertaLargeScore,
+ )
+
+ __all__ = [
+     "BertScore",
+     "ModernBERTBaseScore",
+     "ModernBERTLargeScore",
+     "RobertaBaseScore",
+     "RobertaLargeScore",
+     "RobertaLargeMNLIScore",
+ ]
@@ -0,0 +1,310 @@
+ from collections import Counter, defaultdict
+ from functools import partial, reduce
+ from itertools import chain, islice
+ from math import log
+ from multiprocessing import Pool
+ from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import AutoTokenizer
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+
+ from modern_bert_score.consts import BASELINES
+ from modern_bert_score.inference import STInference, VLLMInference
+
+
+ class BertScore:
+     inference_engine: Optional[Union[STInference, VLLMInference]]
+
+     def __init__(
+         self,
+         model_id: str = "answerdotai/ModernBERT-base",
+         idf_weighting: bool = False,
+         baseline_rescaling: bool = False,
+         custom_baseline: Optional[Tuple[float, float, float]] = None,
+         device: str = "cpu",  # TODO Enum cuda, mlx, cpu?
+         backend: str = "default",  # TODO Enum default, vllm, onnx, etc
+         sentence_transformers_args: Optional[Dict[str, Any]] = None,
+         vllm_args: Optional[Dict[str, Any]] = None,
+     ):
+         self.tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_id)
+         self.idf_weighting: bool = idf_weighting
+         if baseline_rescaling:
+             if custom_baseline is not None:
+                 # An explicitly provided baseline takes precedence.
+                 self.baseline: Tuple[float, float, float] = custom_baseline
+             elif model_id in BASELINES:
+                 self.baseline = BASELINES[model_id]
+             else:
+                 raise ValueError(
+                     (
+                         "Baseline rescaling enabled but no "
+                         f"baseline found for model {model_id}"
+                         " and no custom baseline provided."
+                     )
+                 )
+         self.baseline_rescaling: bool = baseline_rescaling
+         if backend == "vllm":
+             self.inference_engine = VLLMInference(
+                 model=model_id,
+                 runner="pooling",
+                 convert="embed",
+                 **(vllm_args or {}),
+             )
+         elif backend == "default":
+             self.inference_engine = STInference(
+                 model_id, device=device, **(sentence_transformers_args or {})
+             )
+         else:
+             raise ValueError(f"Unsupported backend {backend}")
+
+     def __call__(
+         self,
+         candidates: Union[str, List[str]],
+         references: Union[str, List[str]],
+         **kwargs: Any
+     ) -> List[Tuple[float, float, float]]:
+         assert reduce(
+             lambda acc, x: acc and isinstance(x, str), candidates, True
+         ), "Candidates must be a list of strings or a string."
+         assert reduce(
+             lambda acc, x: acc and isinstance(x, str), references, True
+         ), "References must be a list of strings or a string."
+         if isinstance(candidates, str):
+             candidates = [candidates]
+         if isinstance(references, str):
+             references = [references]
+         if len(candidates) != len(references):
+             raise ValueError(
+                 "Number of candidates and references must be the same."
+             )
+         if len(candidates) == 0:
+             return []
+         candidates = [c.strip() for c in candidates]
+         references = [r.strip() for r in references]
+         if self.idf_weighting:
+             idf_dict_ref, input_ids_ref = self.get_idf_dict(references)
+             input_ids_cand = self._tokenize_data(candidates)
+         else:
+             idf_dict_ref = None
+
+         if self.inference_engine is None:
+             raise ValueError(
+                 "Inference engine not initialized. Check backend "
+                 "configuration."
+             )
+         cand_embs, ref_embs = self.inference_engine.inference(
+             candidates, references, **kwargs
+         )
+         if self.idf_weighting:
+             scores = [
+                 self.bert_score(
+                     candidates=c,
+                     references=r,
+                     idf_dict_ref=idf_dict_ref,
+                     input_ids_cand=ids_cand,
+                     input_ids_ref=ids_ref,
+                 )
+                 for c, r, ids_cand, ids_ref in zip(
+                     cand_embs, ref_embs, input_ids_cand, input_ids_ref
+                 )
+             ]
+         else:
+             scores = [
+                 self.bert_score(candidates=c, references=r)
+                 for c, r in zip(cand_embs, ref_embs)
+             ]
+         if self.baseline_rescaling:
+             rescaled_scores = []
+             for p, r, f1 in scores:
+                 rescaled_p = (p - self.baseline[0]) / (1 - self.baseline[0])
+                 rescaled_r = (r - self.baseline[1]) / (1 - self.baseline[1])
+                 rescaled_f1 = (f1 - self.baseline[2]) / (1 - self.baseline[2])
+                 rescaled_scores.append((rescaled_p, rescaled_r, rescaled_f1))
+             return rescaled_scores
+         return scores
+
+     @staticmethod
+     def _check_nan(f1: torch.Tensor) -> torch.Tensor:
+         if torch.isnan(f1):
+             f1 = torch.Tensor([0.0])
+         return f1
+
+
+     def bert_score(
+         self,
+         candidates: torch.Tensor,
+         references: torch.Tensor,
+         idf_dict_ref: Optional[Dict[int, float]] = None,
+         input_ids_cand: Optional[List[int]] = None,
+         input_ids_ref: Optional[List[int]] = None,
+     ) -> Tuple[float, float, float]:
+         has_idf_dict = idf_dict_ref is not None
+         has_input_ids = (
+             input_ids_cand is not None and input_ids_ref is not None
+         )
+         if has_idf_dict != has_input_ids:
+             raise ValueError(
+                 "`idf_dict` and `input_ids` must either both be provided or "
+                 "both be None."
+             )
+         # TODO: w cuda?
+         assert len(candidates.shape) == 2 and len(references.shape) == 2
+         candidates = candidates[1:-1]  # remove CLS and SEP
+         references = references[1:-1]
+         similarities: torch.Tensor = candidates @ references.T
+         r_bert = similarities.max(dim=0).values.cpu()
+         p_bert = similarities.max(dim=1).values.cpu()
+         if idf_dict_ref and input_ids_cand and input_ids_ref:
+             idf_weights_cand = torch.tensor(
+                 [idf_dict_ref[tok_id] for tok_id in input_ids_cand]
+             )
+             idf_weights_ref = torch.tensor(
+                 [idf_dict_ref[tok_id] for tok_id in input_ids_ref]
+             )
+             idf_weights_cand /= idf_weights_cand.sum()
+             idf_weights_ref /= idf_weights_ref.sum()
+             p_bert = (p_bert * idf_weights_cand).sum()
+             r_bert = (r_bert * idf_weights_ref).sum()
+         else:
+             r_bert = r_bert.mean()
+             p_bert = p_bert.mean()
+         f1 = 2 * p_bert * r_bert / (p_bert + r_bert)
+         # handle p_bert + r_bert == 0
+         f1 = self._check_nan(f1)
+         return p_bert.item(), r_bert.item(), f1.item()
+
+     @staticmethod
+     def _batchify(
+         iterable: List[str], batch_size: int
+     ) -> Generator[List[str], None, None]:
+         iterator = iter(iterable)
+         while True:
+             batch = list(islice(iterator, batch_size))
+             if not batch:
+                 break
+             yield batch
+
+     @staticmethod
+     def _process_batch(
+         batch: List[str],
+         tokenizer: PreTrainedTokenizerFast,
+         ignore_counter: bool = False,
+     ) -> Tuple[Counter[int], List[List[int]]]:
+         stripped_batch = [sample.strip() for sample in batch]
+
+         encoded_batch = tokenizer(
+             stripped_batch,
+             add_special_tokens=True,
+             truncation=True,
+             return_attention_mask=False,
+             return_token_type_ids=False,
+         )["input_ids"]
+         encoded_batch = [e[1:-1] for e in encoded_batch]  # remove CLS and SEP
+         if ignore_counter:
+             return Counter(), encoded_batch
+         else:
+             batch_count: Counter[int] = Counter(
+                 chain.from_iterable(map(set, encoded_batch))
+             )
+             return batch_count, encoded_batch
+
+     def _tokenize_data(
+         self,
+         corpus: List[str],
+         batch_size: int = 100_000,
+         nthreads: int = 4,
+     ) -> List[List[int]]:
+         collected_input_ids: List[List[int]] = []
+
+         process_partial = partial(
+             self._process_batch, tokenizer=self.tokenizer, ignore_counter=True
+         )
+         batches = self._batchify(corpus, batch_size)
+
+         if nthreads > 0:
+             with Pool(nthreads) as pool:
+                 for batch_result in pool.imap(
+                     process_partial, batches, chunksize=1
+                 ):
+                     _, batch_input_ids = batch_result
+                     collected_input_ids.extend(batch_input_ids)
+         else:
+             for batch_result in map(process_partial, batches):
+                 _, batch_input_ids = batch_result
+                 collected_input_ids.extend(batch_input_ids)
+         return collected_input_ids
+
+     def get_idf_dict(
+         self,
+         corpus: List[str],
+         nthreads: int = 4,
+         batch_size: int = 100_000,
+     ) -> Tuple[Dict[int, float], List[List[int]]]:  # TODO: Return dict
+         """Build an IDF (Inverse-Document-Frequency) dictionary for a corpus.
+
+         Also returns the tokenized ``input_ids`` (special tokens removed) for
+         each corpus entry, in the same order as ``corpus``.
+         """
+         idf_count: Counter[int] = Counter()
+         collected_input_ids: List[List[int]] = []
+         num_docs = len(corpus)
+
+         process_partial = partial(
+             self._process_batch, tokenizer=self.tokenizer
+         )
+         batches = self._batchify(corpus, batch_size)
+
+         if nthreads > 0:
+             with Pool(nthreads) as pool:
+                 for batch_result in pool.imap(
+                     process_partial, batches, chunksize=1
+                 ):
+                     batch_count, batch_input_ids = batch_result
+                     collected_input_ids.extend(batch_input_ids)
+                     idf_count.update(batch_count)
+         else:
+             for batch_result in map(process_partial, batches):
+                 batch_count, batch_input_ids = batch_result
+                 collected_input_ids.extend(batch_input_ids)
+                 idf_count.update(batch_count)
+
+         idf_dict: Dict[int, float] = defaultdict(
+             lambda: log((num_docs + 1) / (1))
+         )
+         idf_dict.update(
+             {
+                 idx: log((num_docs + 1) / (count + 1))
+                 for idx, count in idf_count.items()
+             }
+         )
+         return idf_dict, collected_input_ids
+
+ def ModernBERTBaseScore(**kwargs: Any) -> "BertScore":
+     """BertScore with ModernBERT-base-ModBERTScore-12"""
+     kwargs.pop("model_id", None)
+     return BertScore(model_id="LazerLambda/ModernBERT-base-ModBERTScore-12", **kwargs)
+
+
+ def ModernBERTLargeScore(**kwargs: Any) -> "BertScore":
+     """BertScore with ModernBERT-large-ModBERTScore-19"""
+     kwargs.pop("model_id", None)
+     return BertScore(model_id="LazerLambda/ModernBERT-large-ModBERTScore-19", **kwargs)
+
+
+ def RobertaBaseScore(**kwargs: Any) -> "BertScore":
+     """BertScore with roberta-base-ModBERTScore-10"""
+     kwargs.pop("model_id", None)
+     return BertScore(model_id="LazerLambda/roberta-base-ModBERTScore-10", **kwargs)
+
+
+ def RobertaLargeScore(**kwargs: Any) -> "BertScore":
+     """BertScore with roberta-large-ModBERTScore-17"""
+     kwargs.pop("model_id", None)
+     return BertScore(model_id="LazerLambda/roberta-large-ModBERTScore-17", **kwargs)
+
+
+ def RobertaLargeMNLIScore(**kwargs: Any) -> "BertScore":
+     """BertScore with roberta-large-mnli-ModBERTScore-19"""
+     kwargs.pop("model_id", None)
+     return BertScore(model_id="LazerLambda/roberta-large-mnli-ModBERTScore-19", **kwargs)
@@ -0,0 +1,8 @@
+
+ BASELINES = {
+     "LazerLambda/ModernBERT-base-ModBERTScore-12": (0.7899982, 0.790043, 0.7898656),
+     "LazerLambda/ModernBERT-large-ModBERTScore-19": (0.64403415, 0.64410657, 0.64355606),
+     "LazerLambda/roberta-base-ModBERTScore-10": (0.80465585, 0.80464727, 0.8043641),
+     "LazerLambda/roberta-large-ModBERTScore-17": (0.8243552, 0.8243522, 0.8240494),
+     "LazerLambda/roberta-large-mnli-ModBERTScore-19": (0.6901594, 0.69018, 0.6896288),
+ }
@@ -0,0 +1,139 @@
+ import gc
+ from typing import Any, List
+
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from torch.nn import functional as F
+ from transformers import AutoTokenizer
+
+ try:
+     from vllm import LLM
+
+     VLLM_AVAILABLE = True
+ except ImportError:
+     LLM = object  # To prevent NameError if vllm is not installed
+     VLLM_AVAILABLE = False
+
+
+ # TODO: Cache reference embeddings
+ class Inference:
+
+     model: Any = None
+
+     def inference(
+         self, candidates: List[str], references: List[str], **kwargs: Any
+     ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
+         raise NotImplementedError("Method must be implemented in Subclass.")
+
+
+ class STInference(Inference):
+     def __init__(
+         self,
+         model_id: str,
+         device: str = "cpu",
+         batch_size: int = 64,
+         **kwargs: Any,
+     ):
+         self.model = SentenceTransformer(
+             model_name_or_path=model_id, device=device, **kwargs
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_id
+         )  # TODO: Maybe switch to PreTrainedTokenizerFast for clarity?
+         self.batch_size = batch_size
+         self.eps: float = 1e-12
+
+     def inference(
+         self, candidates: List[str], references: List[str], **kwargs: Any
+     ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
+         kwargs.setdefault("batch_size", self.batch_size)  # caller may override
+         if self.model is None:
+             raise RuntimeError("Model not loaded.")
+         embds_refs = self.model.encode(
+             references,
+             output_value="token_embeddings",
+             convert_to_tensor=True,
+             **kwargs,
+         )
+         embds_refs = [
+             F.normalize(e, p=2, dim=-1, eps=self.eps) for e in embds_refs
+         ]
+         embds_cnds = self.model.encode(
+             candidates,
+             output_value="token_embeddings",
+             convert_to_tensor=True,
+             **kwargs,
+         )
+         embds_cnds = [
+             F.normalize(e, p=2, dim=-1, eps=self.eps) for e in embds_cnds
+         ]
+
+         return embds_cnds, embds_refs
+
+
+ class VLLMInference(Inference):
+     def __init__(self, **kwargs: Any):
+         if not VLLM_AVAILABLE:
+             raise ImportError(
+                 "vLLM is not installed. To use the vLLM backend, please "
+                 "install it with `pip install vllm` or "
+                 "`pip install 'modern-bert-score[vllm]'`."
+             )
+         # Backward compatibility for old callsites that pass task="embed".
+         kwargs = self._prepare_args(kwargs)
+         try:
+             self.model = LLM(**kwargs)
+         except Exception as exc:
+             message = str(exc)
+             if (
+                 "Model architectures" in message
+                 and "ModernBertForMaskedLM" in message
+             ):
+                 raise RuntimeError(
+                     "vLLM does not accept the masked-LM ModernBERT checkpoint "
+                     "directly. Export an encoder-only checkpoint first with "
+                     "prepare_model.py, which rewrites the saved config to "
+                     "advertise ModernBertModel, then load that local path in "
+                     "VLLMInference. If you do not need vLLM specifically, use "
+                     "STInference for the original HF checkpoint."
+                 ) from exc
+             raise
+         self.eps: float = 1e-12
+
+     @staticmethod
+     def _prepare_args(kwargs: Any) -> Any:
+         task = kwargs.pop("task", None)
+         if task == "embed":
+             kwargs.setdefault("runner", "pooling")
+             kwargs.setdefault("convert", "embed")
+         return kwargs
+
+     def inference(
+         self, candidates: List[str], references: List[str], **kwargs: Any
+     ) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
+         if self.model is None:
+             raise RuntimeError("Model not loaded.")
+         outputs_cands = self.model.encode(
+             candidates, pooling_task="token_embed", **kwargs
+         )
+         outputs_refs = self.model.encode(
+             references, pooling_task="token_embed", **kwargs
+         )
+         collector: List[torch.Tensor] = []
+         for output in outputs_cands:
+             embeds = output.outputs.data
+             collector.append(embeds)
+         for output in outputs_refs:
+             embeds = output.outputs.data
+             collector.append(embeds)
+
+         collector = [
+             F.normalize(e, p=2, dim=-1, eps=self.eps) for e in collector
+         ]  # TODO: Check superfluous?
+         return collector[0 : len(candidates)], collector[len(candidates) :]
+
+     def cleanup(self) -> None:
+         if hasattr(self, "model") and self.model:
+             del self.model
+             gc.collect()
+             torch.cuda.empty_cache()
@@ -0,0 +1,124 @@
1
+ Metadata-Version: 2.4
2
+ Name: modern-bert-score
3
+ Version: 0.0.1
4
+ Summary: A reimplementation of the BERTScore metric optimized for modern inference workflows.
5
+ Author-email: Philipp Koch <PhillKoch@protonmail.com>
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://github.com/pypa/sampleproject
8
+ Project-URL: Issues, https://github.com/pypa/sampleproject/issues
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Development Status :: 2 - Pre-Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Environment :: GPU :: NVIDIA CUDA
15
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
+ Requires-Python: >=3.12
17
+ Description-Content-Type: text/markdown
18
+ License-File: LICENSE.md
19
+ Requires-Dist: sentence-transformers>=5.2.3
20
+ Requires-Dist: torch>=2.9.1
21
+ Requires-Dist: transformers>=4.57.6
22
+ Provides-Extra: vllm
23
+ Requires-Dist: vllm>=0.16.0; extra == "vllm"
24
+ Provides-Extra: dev
25
+ Requires-Dist: sphinx; extra == "dev"
26
+ Requires-Dist: sphinx_rtd_theme; extra == "dev"
27
+ Requires-Dist: myst-parser; extra == "dev"
28
+ Requires-Dist: pytest; extra == "dev"
29
+ Requires-Dist: pytest-cov; extra == "dev"
30
+ Requires-Dist: ruff; extra == "dev"
31
+ Requires-Dist: flake8; extra == "dev"
32
+ Requires-Dist: mypy; extra == "dev"
33
+ Requires-Dist: flake8; extra == "dev"
34
+ Requires-Dist: pytest-cov; extra == "dev"
35
+ Dynamic: license-file
36
+
37
+ # Modern BERTScore for Fast Inference
38
+
39
+ [![CI](https://github.com/LazerLambda/modern-bert-score/actions/workflows/ci.yml/badge.svg)](https://github.com/LazerLambda/modern-bert-score/actions/workflows/ci.yml)
40
+ [![Python 3.12](https://img.shields.io/badge/python-3.12-blue.svg)](https://www.python.org/downloads/release/python-3120/)
41
+ [![Python 3.13](https://img.shields.io/badge/python-3.13-blue.svg)](https://www.python.org/downloads/release/python-3130/)
42
+ [![Python 3.14](https://img.shields.io/badge/python-3.14-blue.svg)](https://www.python.org/downloads/)
43
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
44
+
45
+ ![Performance](runtime_comp_nvidia_rtx_4060.png)
46
+
47
+ **Modern-BERT-Score** is a reimplementation of the BERTScore metric introduced by [Zhang et al., 2019](https://arxiv.org/abs/1904.09675), optimized for modern inference workflows using [SentenceTransformers](https://www.sbert.net/) and [vLLM](https://vllm.ai/).
48
+
49
+ This library provides fast, GPU-accelerated scoring for text generation evaluation, making BERTScore practical for large-scale inference tasks.
50
+
51
+ ---
52
+
53
+ ## ⚡ Features
54
+ - Fast, efficient computation with optional vLLM support
55
+ - Compatible with all Hugging Face transformer models
56
+ - Supports truncated and optimized model versions for faster inference
57
+ - Works seamlessly with both CPU and GPU setups
58
+
59
+ ---
60
+
61
+ ## 📦 Installation
62
+
63
+ Modern-BERT-Score comes in **two variants**: a base version and a vLLM-enhanced version. For vLLM, an NVIDIA GPU is strongly recommended.
64
+
65
+ ### Base Version
66
+ ```bash
67
+ pip install modern-bert-score
68
+ ```
69
+
70
+ ### vLLM Version
71
+ ```{bash}
72
+ pip install modern-bert-score[vllm]
73
+ ```
74
+ This implementation is significantly faster than the original BERTScore, especially with GPU acceleration.
75
+
76
+ ## 📝 BERTScore
77
+ BERTScore ([Zhang et al., 2019](https://arxiv.org/abs/1904.09675)) evaluates the similarity between candidate and reference texts by comparing their contextual embeddings from a pre-trained transformer model. For each token in the candidate, it finds the most similar token in the reference (using cosine similarity) and aggregates these scores to compute **precision**, **recall**, and **F1**. Optionally, **IDF-weighting** can be applied to give more importance to rare and informative words, improving the metric’s sensitivity to meaningful content over common words. Additionally, optional **Baseline Rescaling** shifts the scores such that the score is in the range of [0,1]. This approach captures semantic similarity beyond exact word matches, making it robust for tasks such as machine translation and text generation evaluation.
78
+
79
+ The following figure, taken from the original paper, illustrates how BERTScore works:
80
+
81
+ ![BERTScore](zhang_19_figure_1.png)
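The optional IDF weights are computed from document frequencies over a corpus. In the original BERTScore implementation, that corpus is the set of reference sentences. The sketch below is minimal and uses whitespace tokenization as a stand-in for the model's subword tokenizer. `build_idf` is a hypothetical helper. The smoothing follows the original `bert_score` package as we understand it: log((M + 1) / (df + 1)) for seen tokens and log(M + 1) for unseen ones. This library's own `get_idf_dict` may differ in detail.

```python
import math
from collections import Counter, defaultdict


def build_idf(references: list[str]) -> defaultdict[str, float]:
    """Map each token to its smoothed inverse document frequency."""
    num_docs = len(references)
    doc_freq: Counter[str] = Counter()
    for text in references:
        # set(): count each token at most once per document (document frequency).
        doc_freq.update(set(text.lower().split()))
    # Tokens never seen in the corpus get the maximum weight, log(M + 1).
    idf: defaultdict[str, float] = defaultdict(lambda: math.log(num_docs + 1))
    idf.update({tok: math.log((num_docs + 1) / (df + 1)) for tok, df in doc_freq.items()})
    return idf


idf = build_idf(["A robin is a bird.", "A robin is not a bird."])
# "a" occurs in every document, so its weight is log(3 / 3) = 0.0.
print(idf["a"], idf["not"])
```

Tokens that appear in every reference receive zero weight. They therefore contribute nothing to the IDF-weighted precision and recall.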
82
+
83
+ ## 🛠 Usage
84
+ ### Example
85
+ ```python
86
+ from modern_bert_score import BertScore
87
+
88
+ candidates = ["Hello World!", "A robin is a bird."]
89
+ references = ["Hi World!", "A robin is not a bird."]
90
+
91
+ metric = BertScore(model_id="roberta-base")
92
+ scores = metric(candidates, references)
93
+
94
+ # scores is a list of (Precision, Recall, F1) tuples
95
+ # To get separate lists of P, R, F1:
96
+ P, R, F1 = zip(*scores)
97
+
98
+ print("Precision scores:", P)
99
+ print("Recall scores:", R)
100
+ print("F1 scores:", F1)
101
+ ```
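The call above returns one `(P, R, F1)` tuple per candidate/reference pair. To report a single corpus-level number, a common convention is to average each component over all pairs. The sketch below shows this; `corpus_score` is a hypothetical helper and not part of this library.

```python
from statistics import fmean


def corpus_score(scores: list[tuple[float, float, float]]) -> tuple[float, float, float]:
    """Average per-pair (P, R, F1) tuples into one corpus-level triple."""
    if not scores:
        # BertScore returns [] for empty input, so guard before averaging.
        raise ValueError("no scores to aggregate")
    precisions, recalls, f1s = zip(*scores)
    return fmean(precisions), fmean(recalls), fmean(f1s)


print(corpus_score([(1.0, 1.0, 1.0), (0.5, 0.7, 0.6)]))
```

With the example above, `corpus_score(scores)` yields the mean precision, recall, and F1 across both pairs.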
102
+
103
+ ## ⚠️ NOTICE
104
+
105
+ - For best results, take embeddings from each model's optimal layer, i.e. the layer whose scores correlate best with human judgments. This is often not the final layer.
106
+ - To find the optimal layer, use the [layer-tuning script from the original BERTScore implementation](https://github.com/Tiiiger/bert_score/tree/master/tune_layers).
107
+
108
+ Several pre-truncated models optimized for vLLM are published on [Hugging Face](https://huggingface.co/collections/LazerLambda/modern-bertscore) and exposed in this library as ready-made classes:
109
+
110
+ - `LazerLambda/ModernBERT-base-ModBERTScore-12` -> `ModernBERTBaseScore`
111
+ - `LazerLambda/ModernBERT-large-ModBERTScore-19` -> `ModernBERTLargeScore`
112
+ - `LazerLambda/roberta-base-ModBERTScore-10` -> `RobertaBaseScore`
113
+ - `LazerLambda/roberta-large-ModBERTScore-17` -> `RobertaLargeScore`
114
+ - `LazerLambda/roberta-large-mnli-ModBERTScore-19` -> `RobertaLargeMNLIScore`
115
+
116
+
117
+ ## 🗺 Roadmap
118
+
119
+ - [x] Implement base version and vLLM addon
120
+ - [x] Add IDF-weighted scoring
121
+ - [ ] Add baseline-rescaling and scripts for identifying optimal baselines
122
+ - [ ] Add model (vLLM-)adaptation script for slicing the model
123
+ - [ ] Add multilingual support
124
+ - [ ] Add CLI tool
@@ -0,0 +1,13 @@
1
+ LICENSE.md
2
+ README.md
3
+ pyproject.toml
4
+ modern_bert_score/__init__.py
5
+ modern_bert_score/bert_score.py
6
+ modern_bert_score/consts.py
7
+ modern_bert_score/inference.py
8
+ modern_bert_score.egg-info/PKG-INFO
9
+ modern_bert_score.egg-info/SOURCES.txt
10
+ modern_bert_score.egg-info/dependency_links.txt
11
+ modern_bert_score.egg-info/requires.txt
12
+ modern_bert_score.egg-info/top_level.txt
13
+ tests/test_bert_score.py
@@ -0,0 +1,16 @@
1
+ sentence-transformers>=5.2.3
2
+ torch>=2.9.1
3
+ transformers>=4.57.6
4
+
5
+ [dev]
6
+ sphinx
7
+ sphinx_rtd_theme
8
+ myst-parser
9
+ pytest
10
+ pytest-cov
11
+ ruff
12
+ flake8
13
+ mypy
14
+
15
+ [vllm]
16
+ vllm>=0.16.0
@@ -0,0 +1 @@
1
+ modern_bert_score
@@ -0,0 +1,66 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "modern-bert-score"
7
+ version = "0.0.1"
8
+ description = "A reimplementation of the BERTScore metric optimized for modern inference workflows."
9
+ readme = "README.md"
10
+ requires-python = ">=3.12"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "Operating System :: OS Independent",
14
+ "Development Status :: 2 - Pre-Alpha",
15
+ "Intended Audience :: Developers",
16
+ "Intended Audience :: Science/Research",
17
+ "Environment :: GPU :: NVIDIA CUDA",
18
+ "Topic :: Scientific/Engineering :: Artificial Intelligence"
19
+ ]
20
+ license = "Apache-2.0"
21
+ license-files = ["LICEN[CS]E*"]
22
+ authors = [
23
+ {name = "Philipp Koch", email="PhillKoch@protonmail.com"},
24
+ ]
25
+ dependencies = [
26
+ "sentence-transformers>=5.2.3",
27
+ "torch>=2.9.1",
28
+ "transformers>=4.57.6",
29
+ ]
30
+
31
+ [project.urls]
32
+ Homepage = "https://github.com/pypa/sampleproject"
33
+ Issues = "https://github.com/pypa/sampleproject/issues"
34
+
35
+ [project.optional-dependencies]
36
+ vllm = ["vllm>=0.16.0"]
37
+ dev = [
38
+ "sphinx",
39
+ "sphinx_rtd_theme",
40
+ "myst-parser",
41
+ "pytest",
42
+ "pytest-cov",
43
+ "ruff",
44
+ "flake8",
45
+ "mypy",
46
+ "flake8",
47
+ "pytest-cov"
48
+ ]
49
+
50
+ [tool.setuptools.packages.find]
51
+ include = ["modern_bert_score*"]
52
+
53
+ [tool.ruff]
54
+ line-length = 90
55
+ extend-exclude = ["tests"]
56
+
57
+ [tool.ruff.lint]
58
+ select = ["E", "F", "W", "I"]
59
+
60
+ [tool.mypy]
61
+ strict = true
62
+ ignore_missing_imports = true
63
+
64
+ [tool.pytest.ini_options]
65
+ filterwarnings = ["ignore::DeprecationWarning"]
66
+ exclude = ["^build/"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,326 @@
1
+ import unittest
2
+ from unittest.mock import MagicMock, patch
3
+
4
+ import pytest
5
+ import torch
6
+ from transformers import logging as transformers_logging
7
+
8
+ from modern_bert_score.bert_score import BertScore
9
+ from modern_bert_score.inference import VLLM_AVAILABLE
10
+
11
+ # Suppress transformers warnings about missing weights during testing
12
+ transformers_logging.set_verbosity_error()
13
+
14
+ TEST_MODEL = "LazerLambda/BERT-Tiny-L-2-H-128-A-2-ModBERTScore-TEST"
15
+
16
+ class TestBertScore(unittest.TestCase):
17
+
18
+ test_model: str = TEST_MODEL
19
+
20
+ def setUp(self):
21
+ self.candidates = [
22
+ "Hello, my name is",
23
+ "The president of the United States is",
24
+ "The capital of France is",
25
+ "The future of AI is",
26
+ "The cat is on the table",
27
+ ]
28
+ self.references = [
29
+ "Hello, my name is",
30
+ "The head of the United States is",
31
+ "The capital of Japan is",
32
+ "The future of Work is",
33
+ "The dog is on the table",
34
+ ]
35
+
36
+ def test_against_original(self):
37
+ """Test basic functionality."""
38
+ original_p_r_f1 = [
39
+ torch.tensor([1.0000, 0.9423, 0.8705, 0.9105, 0.9637]),
40
+ torch.tensor([1.0000, 0.9423, 0.8705, 0.9105, 0.9637]),
41
+ torch.tensor([1.0000, 0.9423, 0.8705, 0.9105, 0.9637])
42
+ ]
43
+ original_p_r_f1 = torch.stack(original_p_r_f1).T
44
+ original_p_r_f1_idf = [
45
+ torch.tensor([1.0000, 0.9233, 0.8167, 0.8609, 0.9369]),
46
+ torch.tensor([1.0000, 0.9358, 0.8304, 0.8861, 0.9505]),
47
+ torch.tensor([1.0000, 0.9295, 0.8235, 0.8733, 0.9436])
48
+ ]
49
+ original_p_r_f1_idf = torch.stack(original_p_r_f1_idf).T
50
+ bs = BertScore(model_id=self.test_model, backend="default")
51
+ p_r_f1 = bs(self.candidates, self.references)
52
+ bs.idf_weighting = True
53
+ p_r_f1_idf = bs(self.candidates, self.references)
54
+ for (p, r, f1), (p_exp, r_exp, f1_exp) in zip(p_r_f1, original_p_r_f1):
55
+ self.assertTrue(torch.allclose(torch.tensor(p), p_exp, atol=1e-4))
56
+ self.assertTrue(torch.allclose(torch.tensor(r), r_exp, atol=1e-4))
57
+ self.assertTrue(torch.allclose(torch.tensor(f1), f1_exp, atol=1e-4))
58
+ for (p, r, f1), (p_exp, r_exp, f1_exp) in zip(p_r_f1_idf, original_p_r_f1_idf):
59
+ self.assertTrue(torch.allclose(torch.tensor(p), p_exp, atol=1e-4))
60
+ self.assertTrue(torch.allclose(torch.tensor(r), r_exp, atol=1e-4))
61
+ self.assertTrue(torch.allclose(torch.tensor(f1), f1_exp, atol=1e-4))
62
+
63
+ def test_id(self):
64
+ cand1 = ["Hello World!"]
65
+ ref1 = ["Hello World!"]
66
+ bs = BertScore(model_id=self.test_model, backend="default")
67
+ p_r_f1 = bs(cand1, ref1)
68
+ self.assertAlmostEqual(p_r_f1[0][0], 1.0, places=5)
69
+ self.assertAlmostEqual(p_r_f1[0][1], 1.0, places=5)
70
+ self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)
71
+
72
+ cand2 = ["Hello World!"]
73
+ ref2 = [" Hello World! "]
74
+ p_r_f1 = bs(cand2, ref2)
75
+ self.assertAlmostEqual(p_r_f1[0][0], 1.0, places=5)
76
+ self.assertAlmostEqual(p_r_f1[0][1], 1.0, places=5)
77
+ self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)
78
+
79
+ def test_unequal_length(self):
80
+ cand1 = ["Hello World!"]
81
+ ref1 = ["Hello World!", "Hello World!"]
82
+ bs = BertScore(model_id=self.test_model, backend="default")
83
+ with self.assertRaises(ValueError):
84
+ bs(cand1, ref1)
85
+
86
+ cand2 = ["Hello World!", "Hello World!"]
87
+ ref2 = ["Hello World!"]
88
+ bs = BertScore(model_id=self.test_model, backend="default")
89
+ with self.assertRaises(ValueError):
90
+ bs(cand2, ref2)
91
+
92
+ def test_unequal_input(self):
93
+ cand1 = ["Hello World!"]
94
+ ref1 = ["Bye World!"]
95
+ bs = BertScore(model_id=self.test_model, backend="default")
96
+ p_r_f1 = bs(cand1, ref1)
97
+ self.assertTrue(p_r_f1[0][0] < 1.0)
98
+ self.assertTrue(p_r_f1[0][1] < 1.0)
99
+ self.assertTrue(p_r_f1[0][2] < 1.0)
100
+
101
+ def test_empty_input(self):
102
+ cand1 = []
103
+ ref1 = []
104
+ bs = BertScore(model_id=self.test_model, backend="default")
105
+ p_r_f1 = bs(cand1, ref1)
106
+ self.assertEqual(p_r_f1, [])
107
+
108
+ def test_empty_inference_engine(self):
109
+ bs = BertScore(model_id=self.test_model, backend="default")
110
+ bs.inference_engine = None
111
+ with self.assertRaises(ValueError):
112
+ bs(self.candidates, self.references)
113
+
114
+ def test_check_nan(self):
115
+ f1_nan = torch.tensor(torch.nan)
116
+ f1_checked = BertScore._check_nan(f1_nan)
117
+ self.assertEqual(f1_checked, torch.Tensor([0.0]))
118
+
119
+ def test_exception_idf_bertscore(self):
120
+ bs = BertScore(model_id=self.test_model, backend="default")
121
+ with self.assertRaises(ValueError):
122
+ bs.bert_score(
123
+ candidates=torch.rand(3, 4),
124
+ references=torch.rand(3, 4),
125
+ input_ids_cand=[0, 1, 2],
126
+ input_ids_ref=[0, 1, 2],
127
+ )
128
+
129
+ def test_single(self):
130
+ cand1 = "Hello World!"
131
+ ref1 = "Hello World!"
132
+ bs = BertScore(model_id=self.test_model, backend="default")
133
+ bs(cand1, ref1)
134
+
135
+ def test_tokenize_batch(self):
136
+ bs = BertScore(model_id=self.test_model, backend="default")
137
+ counter, input_ids = bs._process_batch(
138
+ ["Hello World!", "Hello World!"],
139
+ bs.tokenizer,
140
+ ignore_counter=True
141
+ )
142
+ assert len(counter) == 0
143
+ counter, input_ids = bs._process_batch(
144
+ ["Hello World!", "Hello World!"],
145
+ bs.tokenizer,
146
+ ignore_counter=False
147
+ )
148
+ assert counter is not None
149
+
150
+ def test_tokenize_data(self):
151
+ bs = BertScore(model_id=self.test_model, backend="default")
152
+ bs._tokenize_data(["Hello World!", "Hello World!"], nthreads=4)
153
+ bs._tokenize_data(["Hello World!", "Hello World!"], nthreads=0)
154
+
155
+ def test_get_idf_dict(self):
156
+ bs = BertScore(model_id=self.test_model, backend="default")
157
+ bs.get_idf_dict(["Hello World!", "Hello World!"], nthreads=4)
158
+ bs.get_idf_dict(["Hello World!", "Hello World!"], nthreads=0)
159
+
160
+ def test_base_line(self):
161
+ cand1 = ["Hello World!"]
162
+ ref1 = ["Hello World!"]
163
+ bs = BertScore(model_id=self.test_model, backend="default", baseline_rescaling=True, custom_baseline=(0.5, 0.5, 0.5))
164
+ p_r_f1 = bs(cand1, ref1)
165
+ self.assertAlmostEqual(p_r_f1[0][0], 1.0, places=5)
166
+ self.assertAlmostEqual(p_r_f1[0][1], 1.0, places=5)
167
+ self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)
168
+
169
+ def test_baseline_rescaling_exception(self):
170
+ """Test that ValueError is raised when baseline is missing."""
171
+ with self.assertRaises(ValueError):
172
+ BertScore(model_id=self.test_model, backend="default", baseline_rescaling=True)
173
+
174
+ def test_baseline_rescaling_injection(self):
175
+ """Test baseline rescaling with injected baseline for the test model."""
176
+ from modern_bert_score.consts import BASELINES
177
+ # Inject test model baseline (P, R, F1) tuple. F1 is at index 2.
178
+ BASELINES[self.test_model] = (0.5, 0.5, 0.5)
179
+ try:
180
+ bs = BertScore(model_id=self.test_model, backend="default", baseline_rescaling=True)
181
+ self.assertEqual(bs.baseline, (0.5, 0.5, 0.5))
182
+ # Test computation
183
+ cand1 = ["Hello World!"]
184
+ ref1 = ["Hello World!"]
185
+ p_r_f1 = bs(cand1, ref1)
186
+ self.assertAlmostEqual(p_r_f1[0][2], 1.0, places=5)
187
+ finally:
188
+ del BASELINES[self.test_model]
189
+
190
+ def test_inference_base_not_implemented_error(self):
191
+ from modern_bert_score.inference import Inference
192
+ with self.assertRaises(NotImplementedError):
193
+ dummy = Inference()
194
+ dummy.inference(["Hello"], ["Hello"])
195
+
196
+ @patch("modern_bert_score.bert_score.AutoTokenizer")
197
+ @patch("modern_bert_score.bert_score.STInference")
198
+ def test_initialization_default(self, mock_st_inference, mock_tokenizer):
199
+ """Test initialization with default backend."""
200
+ # Setup mocks
201
+ mock_tokenizer.from_pretrained.return_value = MagicMock()
202
+
203
+ bs = BertScore(model_id=self.test_model, backend="default")
204
+
205
+ mock_tokenizer.from_pretrained.assert_called_with(self.test_model)
206
+ mock_st_inference.assert_called_once()
207
+ self.assertEqual(bs.inference_engine, mock_st_inference.return_value)
208
+
209
+ @patch("modern_bert_score.bert_score.AutoTokenizer")
210
+ @patch("modern_bert_score.bert_score.VLLMInference")
211
+ def test_initialization_vllm(self, mock_vllm_inference, mock_tokenizer):
212
+ """Test initialization with vllm backend."""
213
+ # Setup mocks
214
+ mock_tokenizer.from_pretrained.return_value = MagicMock()
215
+
216
+ bs = BertScore(model_id=self.test_model, backend="vllm")
217
+
218
+ mock_vllm_inference.assert_called_once()
219
+ self.assertEqual(bs.inference_engine, mock_vllm_inference.return_value)
220
+
221
+ @patch("modern_bert_score.bert_score.AutoTokenizer")
222
+ @patch("modern_bert_score.inference.VLLM_AVAILABLE", False)
223
+ def test_initialization_vllm_without_vllm_installed(self, mock_tokenizer):
224
+ """Test that using vllm backend without vllm installed raises ImportError."""
225
+ mock_tokenizer.from_pretrained.return_value = MagicMock()
226
+ with self.assertRaises(ImportError) as cm:
227
+ BertScore(model_id=self.test_model, backend="vllm")
228
+
229
+ self.assertIn("vLLM is not installed", str(cm.exception))
230
+ self.assertIn("pip install 'modern-bert-score[vllm]'", str(cm.exception))
231
+
232
+ @patch("modern_bert_score.inference.LLM")
233
+ @patch("modern_bert_score.inference.VLLM_AVAILABLE", True)
234
+ @patch("modern_bert_score.bert_score.AutoTokenizer")
235
+ def test_initialization_vllm_masked_lm_error(self, mock_tokenizer, mock_llm):
236
+ """Test that appropriate error is raised when vLLM rejects MaskedLM architecture.
237
+
238
+ This test uses mocks for vLLM components, so it will run (and pass) regardless
239
+ of whether the `vllm` package is installed in the test environment.
240
+ """
241
+ mock_tokenizer.from_pretrained.return_value = MagicMock()
242
+
243
+ # Simulate vLLM raising exception about architecture
244
+ mock_llm.side_effect = Exception(
245
+ "ValueError: Model architectures ['ModernBertForMaskedLM'] are not supported for now. "
246
+ "Supported architectures: ['ModernBertModel', ...]"
247
+ )
248
+
249
+ with self.assertRaises(RuntimeError) as cm:
250
+ BertScore(model_id=self.test_model, backend="vllm")
251
+
252
+ self.assertIn(
253
+ "vLLM does not accept the masked-LM ModernBERT checkpoint directly",
254
+ str(cm.exception),
255
+ )
256
+
257
+ @patch("modern_bert_score.bert_score.AutoTokenizer")
258
+ def test_initialization_invalid_backend(self, mock_tokenizer):
259
+ """Test initialization with invalid backend raises ValueError."""
260
+ with self.assertRaises(ValueError):
261
+ BertScore(model_id=self.test_model, backend="invalid")
262
+
263
+ @patch("modern_bert_score.bert_score.AutoTokenizer")
264
+ @patch("modern_bert_score.bert_score.STInference")
265
+ def test_call_simple(self, mock_st_inference, mock_tokenizer):
266
+ """Test scoring call with default backend."""
267
+ mock_engine = MagicMock()
268
+ mock_st_inference.return_value = mock_engine
269
+
270
+ # Mock embeddings: List[Tensor]
271
+ # Need shape [seq_len, hidden_dim].
272
+ # BertScore slices [1:-1] (removes CLS and SEP), so we need at least 3 tokens.
273
+ c_emb = torch.rand(3, 4)
274
+ r_emb = torch.rand(3, 4)
275
+
276
+ mock_engine.inference.return_value = ([c_emb], [r_emb])
277
+
278
+ bs = BertScore(model_id=self.test_model, backend="default")
279
+ results = bs(self.candidates, self.references)
280
+
281
+ self.assertEqual(len(results), 1)
282
+ p, r, f1 = results[0]
283
+ self.assertIsInstance(p, float)
284
+ self.assertIsInstance(r, float)
285
+ self.assertIsInstance(f1, float)
286
+
287
+ # Verify inference called correctly
288
+ mock_engine.inference.assert_called_with(self.candidates, self.references)
289
+
290
+ @patch("modern_bert_score.bert_score.AutoTokenizer")
291
+ @patch("modern_bert_score.bert_score.STInference")
292
+ def test_input_validation(self, mock_st_inference, mock_tokenizer):
293
+ """Test input validation for candidates/references."""
294
+ bs = BertScore(model_id=self.test_model, backend="default")
295
+
296
+ # Mismatched lengths
297
+ with self.assertRaises(ValueError):
298
+ bs(["a"], ["b", "c"])
299
+
300
+ @pytest.fixture(scope="module")
301
+ def vllm_bert_score():
302
+ # This setup runs once for the module
303
+ bs = BertScore(
304
+ model_id=TEST_MODEL,
305
+ backend="vllm",
306
+ vllm_args={"gpu_memory_utilization": 0.3, "enforce_eager": True, "distributed_executor_backend": "mp", "task": "embed"}
307
+ )
308
+ return bs
309
+
310
+ # vLLM-only test; skipped when vLLM is not installed
311
+ @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
312
+ def test_vllm_only_feature(vllm_bert_score):
313
+ """Test that requests the fixture above."""
314
+ cand1 = ["Hello World!"]
315
+ ref1 = ["Hello World!"]
316
+
317
+ # Use the fixture passed as an argument
318
+ p_r_f1 = vllm_bert_score(cand1, ref1)
319
+
320
+ # Use standard asserts instead of self.assertEqual
321
+ assert p_r_f1[0][0] == pytest.approx(1.0)
322
+ assert p_r_f1[0][1] == pytest.approx(1.0)
323
+ assert p_r_f1[0][2] == pytest.approx(1.0)
324
+
325
+ # Test kwargs passed to vLLMInference
326
+ # kwargs = {"task": "embed", "gpu_memory_utilization": 0.3, "enforce_eager": True, "distributed_executor_backend": "mp"}