mteb 2.6.9__py3-none-any.whl → 2.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,327 @@
+ from __future__ import annotations
+
+ import atexit
+ import gc
+ import logging
+ import os
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any, Literal
+
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader
+
+ from mteb._requires_package import requires_package
+ from mteb.abstasks.task_metadata import TaskMetadata
+ from mteb.models import ModelMeta
+ from mteb.models.abs_encoder import AbsEncoder
+ from mteb.types import Array, BatchedInput, PromptType
+
+ if TYPE_CHECKING:
+     from vllm.config import PoolerConfig  # type: ignore[import-not-found]
+ else:
+     PoolerConfig = dict[str, Any]
+
+ logger = logging.getLogger(__name__)
+
+ Dtype = Literal["half", "float16", "float", "float32", "bfloat16", "auto"]
+
+
+ class VllmWrapperBase:
+     """Wrapper for vllm serving engine."""
+
+     convert = "auto"
+     mteb_model_meta: ModelMeta | None = None
+
+     def __init__(
+         self,
+         model: str | ModelMeta,
+         revision: str | None = None,
+         *,
+         trust_remote_code: bool = True,
+         dtype: Dtype = "auto",
+         head_dtype: Literal["model"] | Dtype | None = None,
+         max_model_len: int | None = None,
+         max_num_batched_tokens: int | None = None,
+         max_num_seqs: int = 128,
+         tensor_parallel_size: int = 1,
+         enable_prefix_caching: bool | None = None,
+         gpu_memory_utilization: float = 0.9,
+         hf_overrides: dict[str, Any] | None = None,
+         pooler_config: PoolerConfig | None = None,
+         enforce_eager: bool = False,
+         **kwargs: Any,
+     ):
+         """Wrapper for vllm serving engine.
+
+         Args:
+             model: Model name string or ModelMeta.
+             revision: The revision of the model to use.
+             trust_remote_code: Whether to trust remote code execution when loading the model.
+                 Should be True for models with custom code.
+             dtype: Data type for model weights. "auto" will automatically select an appropriate
+                 dtype based on hardware and model capabilities. vllm uses flash attention by
+                 default, which does not support fp32. Therefore, it defaults to using fp16 for
+                 inference on fp32 models. Testing has shown a relatively small drop in accuracy.
+                 You can manually opt for fp32, but inference speed will be very slow.
+             head_dtype: "head" refers to the last Linear layer(s) of an LLM, such as the score
+                 or classifier layer in a classification model. Uses fp32 for the head by default to
+                 gain extra precision.
+             max_model_len: Maximum sequence length (context window) supported by the model.
+                 If None, uses the model's default maximum length.
+             max_num_batched_tokens: Maximum number of tokens to process in a single batch.
+                 If None, automatically determined.
+             max_num_seqs: Maximum number of sequences to process concurrently.
+             tensor_parallel_size: Number of GPUs for tensor parallelism.
+             enable_prefix_caching: Whether to enable KV cache sharing for common prompt prefixes.
+                 If None, uses the model's default setting.
+             gpu_memory_utilization: Target GPU memory utilization ratio (0.0 to 1.0).
+             hf_overrides: Dictionary mapping Hugging Face configuration keys to override values.
+             pooler_config: Controls the behavior of output pooling in pooling models.
+             enforce_eager: Whether to disable CUDA graph optimization and use eager execution.
+             **kwargs: Additional arguments to pass to the vllm serving engine model.
+         """
+         requires_package(
+             self,
+             "vllm",
+             "Wrapper for vllm serving engine",
+             install_instruction="pip install mteb[vllm]",
+         )
+
+         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+         from vllm import LLM, EngineArgs
+
+         hf_overrides = {} if hf_overrides is None else hf_overrides
+
+         if head_dtype is not None:
+             hf_overrides["head_dtype"] = head_dtype
+
+         model_name = model if isinstance(model, str) else model.name
+
+         if isinstance(model, ModelMeta):
+             logger.info(
+                 "Using revision from model meta. Passed revision will be ignored"
+             )
+             revision = model.revision
+
+         args = EngineArgs(
+             model=model_name,
+             revision=revision,
+             runner="pooling",
+             convert=self.convert,  # type: ignore[arg-type]
+             max_model_len=max_model_len,
+             max_num_batched_tokens=max_num_batched_tokens,
+             max_num_seqs=max_num_seqs,
+             tensor_parallel_size=tensor_parallel_size,
+             enable_prefix_caching=enable_prefix_caching,
+             gpu_memory_utilization=gpu_memory_utilization,
+             hf_overrides=hf_overrides,
+             pooler_config=pooler_config,
+             enforce_eager=enforce_eager,
+             trust_remote_code=trust_remote_code,
+             dtype=dtype,
+             **kwargs,
+         )
+         self.llm = LLM(**vars(args))
+
+         if isinstance(model, str):
+             self.mteb_model_meta = ModelMeta.from_hub(model=model, revision=revision)
+         else:
+             self.mteb_model_meta = model
+
+         atexit.register(self.cleanup)
+
+     def cleanup(self):
+         """Clean up the VLLM distributed runtime environment and release GPU resources."""
+         if self.llm is None:
+             return
+
+         from vllm.distributed import (  # type: ignore[import-not-found]
+             cleanup_dist_env_and_memory,
+         )
+
+         self.llm = None
+         gc.collect()
+         cleanup_dist_env_and_memory()
+
+     def __del__(self):
+         try:
+             self.cleanup()
+         except Exception:
+             pass
+
+
+ class VllmEncoderWrapper(AbsEncoder, VllmWrapperBase):
+     """vLLM wrapper for Encoder models.
+
+     Args:
+         model: Model name string or ModelMeta.
+         revision: The revision of the model to use.
+         prompt_dict: A dictionary mapping task names to prompt strings.
+         use_instructions: Whether to use instructions from the prompt_dict.
+             When False, values from prompt_dict are used as static prompts (prefixes).
+             When True, values from prompt_dict are used as instructions to be formatted
+             using the instruction_template.
+         instruction_template: A template or callable to format instructions.
+             Can be a string with an '{instruction}' placeholder or a callable that takes
+             the instruction and prompt type and returns a formatted string.
+         apply_instruction_to_documents: Whether to apply instructions to document prompts.
+         **kwargs: Additional arguments to pass to the vllm serving engine model.
+     """
+
+     convert = "embed"
+
+     def __init__(
+         self,
+         model: str | ModelMeta,
+         revision: str | None = None,
+         prompt_dict: dict[str, str] | None = None,
+         use_instructions: bool = False,
+         instruction_template: (
+             str | Callable[[str, PromptType | None], str] | None
+         ) = None,
+         apply_instruction_to_documents: bool = True,
+         **kwargs: Any,
+     ):
+         if use_instructions and instruction_template is None:
+             raise ValueError(
+                 "To use instructions, an instruction_template must be provided. "
+                 "For example, `Instruction: {instruction}`"
+             )
+
+         if (
+             isinstance(instruction_template, str)
+             and "{instruction}" not in instruction_template
+         ):
+             raise ValueError(
+                 "Instruction template must contain the string '{instruction}'."
+             )
+
+         self.prompts_dict = prompt_dict
+         self.use_instructions = use_instructions
+         self.instruction_template = instruction_template
+         self.apply_instruction_to_passages = apply_instruction_to_documents
+         super().__init__(
+             model,
+             revision,
+             **kwargs,
+         )
+
+     def encode(
+         self,
+         inputs: DataLoader[BatchedInput],
+         *,
+         task_metadata: TaskMetadata,
+         hf_split: str,
+         hf_subset: str,
+         prompt_type: PromptType | None = None,
+         **kwargs: Any,
+     ) -> Array:
+         """Encodes the given sentences using the encoder.
+
+         Args:
+             inputs: The sentences to encode.
+             task_metadata: The metadata of the task. Used to determine which
+                 prompt to use from a specified dictionary.
+             prompt_type: The type of prompt (query or document).
+             hf_split: Split of the current task.
+             hf_subset: Subset of the current task.
+             **kwargs: Additional arguments to pass to the encoder.
+
+         Returns:
+             The encoded sentences.
+         """
+         prompt = ""
+         if self.use_instructions and self.prompts_dict is not None:
+             prompt = self.get_task_instruction(task_metadata, prompt_type)
+         elif self.prompts_dict is not None:
+             prompt_name = self.get_prompt_name(task_metadata, prompt_type)
+             if prompt_name is not None:
+                 prompt = self.prompts_dict.get(prompt_name, "")
+
+         if (
+             self.use_instructions
+             and self.apply_instruction_to_passages is False
+             and prompt_type == PromptType.document
+         ):
+             logger.info(
+                 f"No instruction used, because prompt type = {prompt_type.document}"
+             )
+             prompt = ""
+         else:
+             logger.info(
+                 f"Using instruction: '{prompt}' for task: '{task_metadata.name}' prompt type: '{prompt_type}'"
+             )
+
+         prompts = [prompt + text for batch in inputs for text in batch["text"]]
+         outputs = self.llm.encode(
+             prompts, pooling_task="embed", truncate_prompt_tokens=-1
+         )
+         embeddings = torch.stack([output.outputs.data for output in outputs])
+         return embeddings
+
+
+ class VllmCrossEncoderWrapper(VllmWrapperBase):
+     """vLLM wrapper for CrossEncoder models."""
+
+     convert = "classify"
+
+     def __init__(
+         self,
+         model: str | ModelMeta,
+         revision: str | None = None,
+         query_prefix: str = "",
+         document_prefix: str = "",
+         **kwargs: Any,
+     ):
+         super().__init__(
+             model,
+             revision,
+             **kwargs,
+         )
+         self.query_prefix = query_prefix
+         self.document_prefix = document_prefix
+
+     def predict(
+         self,
+         inputs1: DataLoader[BatchedInput],
+         inputs2: DataLoader[BatchedInput],
+         *,
+         task_metadata: TaskMetadata,
+         hf_split: str,
+         hf_subset: str,
+         prompt_type: PromptType | None = None,
+         **kwargs: Any,
+     ) -> Array:
+         """Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.
+
+         Args:
+             inputs1: First DataLoader of inputs to encode. For reranking tasks, these are queries (for text-only tasks `QueryDatasetType`).
+             inputs2: Second DataLoader of inputs to encode. For reranking, these are documents (for text-only tasks `RetrievalOutputType`).
+             task_metadata: Metadata of the current task.
+             hf_split: Split of the current task; provides additional information about the
+                 current split, e.g. the current language.
+             hf_subset: Subset of the current task. Similar to `hf_split`, provides more information.
+             prompt_type: The type of prompt (query or document).
+             **kwargs: Additional arguments to pass to the cross-encoder.
+
+         Returns:
+             The predicted relevance scores for each input pair.
+         """
+         queries = [
+             self.query_prefix + text for batch in inputs1 for text in batch["text"]
+         ]
+         corpus = [
+             self.document_prefix + text for batch in inputs2 for text in batch["text"]
+         ]
+         # TODO: support score prompt
+
+         outputs = self.llm.score(
+             queries,
+             corpus,
+             truncate_prompt_tokens=-1,
+             use_tqdm=False,
+         )
+         scores = np.array([output.outputs.score for output in outputs])
+         return scores
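
For orientation, here is a minimal usage sketch of the two new wrappers. Only the constructor signatures come from the code above; the import path, model IDs, and prompt keys are illustrative assumptions.

# Sketch only: module path and model IDs are assumptions, adjust as needed.
from mteb.models.vllm_wrapper import VllmCrossEncoderWrapper, VllmEncoderWrapper

# Bi-encoder: with the default use_instructions=False, prompt_dict values are
# used as static prefixes prepended to each input text.
encoder = VllmEncoderWrapper(
    "intfloat/e5-base-v2",  # example model id
    prompt_dict={"query": "query: ", "document": "passage: "},  # example prompt keys
    gpu_memory_utilization=0.8,
    max_num_seqs=64,
)

# Cross-encoder: scores (query, document) pairs via the vllm score API.
reranker = VllmCrossEncoderWrapper(
    "BAAI/bge-reranker-base",  # example model id
    query_prefix="",
    document_prefix="",
)

Either object is then handed to mteb's evaluation flow like any other encoder or cross-encoder implementation.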
@@ -18,6 +18,7 @@ from .built_bench_retrieval import BuiltBenchRetrieval
  from .chat_doctor_retrieval import ChatDoctorRetrieval
  from .chem_hotpot_qa_retrieval import ChemHotpotQARetrieval
  from .chem_nq_retrieval import ChemNQRetrieval
+ from .chemrxiv import ChemRxivRetrieval
  from .cirr_it2i_retrieval import CIRRIT2IRetrieval
  from .climate_fever_retrieval import (
      ClimateFEVER,
@@ -254,6 +255,7 @@ __all__ = [
      "ChatDoctorRetrieval",
      "ChemHotpotQARetrieval",
      "ChemNQRetrieval",
+     "ChemRxivRetrieval",
      "ClimateFEVER",
      "ClimateFEVERHardNegatives",
      "ClimateFEVERHardNegativesV2",
@@ -0,0 +1,33 @@
+ from mteb.abstasks.retrieval import AbsTaskRetrieval
+ from mteb.abstasks.task_metadata import TaskMetadata
+
+
+ class ChemRxivRetrieval(AbsTaskRetrieval):
+     metadata = TaskMetadata(
+         name="ChemRxivRetrieval",
+         dataset={
+             "path": "BASF-AI/ChemRxivRetrieval",
+             "revision": "5377aa18f309ec440ff6325a4c2cd3362c2cb8d7",
+         },
+         description="A retrieval task based on ChemRxiv papers where queries are LLM-synthesized to match specific paragraphs.",
+         reference="https://arxiv.org/abs/2508.01643",
+         type="Retrieval",
+         category="t2t",
+         modalities=["text"],
+         eval_splits=["test"],
+         eval_langs=["eng-Latn"],
+         main_score="ndcg_at_10",
+         date=("2025-01-01", "2025-05-01"),
+         domains=["Chemistry"],
+         task_subtypes=["Question answering", "Article retrieval"],
+         license="cc-by-nc-sa-4.0",
+         annotations_creators="LM-generated and reviewed",
+         dialect=[],
+         sample_creation="found",
+         bibtex_citation="""@article{kasmaee2025chembed,
+   author = {Kasmaee, Ali Shiraee and Khodadad, Mohammad and Astaraki, Mahdi and Saloot, Mohammad Arshi and Sherck, Nicholas and Mahyar, Hamidreza and Samiee, Soheila},
+   journal = {arXiv preprint arXiv:2508.01643},
+   title = {Chembed: Enhancing chemical literature search through domain-specific text embeddings},
+   year = {2025},
+ }""",
+     )
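
Once the task is registered in the package (see the `__init__.py` hunk above), it can be selected by name. A small sketch, assuming the standard `mteb.get_tasks` selection API:

import mteb

# Select the newly added chemistry retrieval task by name.
tasks = mteb.get_tasks(tasks=["ChemRxivRetrieval"])

# The TaskMetadata defined above is available on the task object.
task = tasks[0]
print(task.metadata.main_score)   # "ndcg_at_10"
print(task.metadata.eval_splits)  # ["test"]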
@@ -6,6 +6,7 @@ from .cross_lingual_semantic_discrimination_wmt21 import (
      CrossLingualSemanticDiscriminationWMT21,
  )
  from .cur_ev1_retrieval import CUREv1Retrieval
+ from .euro_pirq_retrieval import EuroPIRQRetrieval
  from .indic_qa_retrieval import IndicQARetrieval
  from .jina_vdr_bench_retrieval import (
      JinaVDRAirbnbSyntheticRetrieval,
@@ -107,6 +108,7 @@ __all__ = [
      "CUREv1Retrieval",
      "CrossLingualSemanticDiscriminationWMT19",
      "CrossLingualSemanticDiscriminationWMT21",
+     "EuroPIRQRetrieval",
      "IndicQARetrieval",
      "JinaVDRAirbnbSyntheticRetrieval",
      "JinaVDRArabicChartQARetrieval",
@@ -0,0 +1,43 @@
+ from mteb.abstasks.retrieval import AbsTaskRetrieval
+ from mteb.abstasks.task_metadata import TaskMetadata
+
+ _LANGUAGES = {
+     "en": ["eng-Latn"],
+     "fi": ["fin-Latn"],
+     "pt": ["por-Latn"],
+ }
+
+
+ class EuroPIRQRetrieval(AbsTaskRetrieval):
+     metadata = TaskMetadata(
+         name="EuroPIRQRetrieval",
+         description="The EuroPIRQ retrieval dataset is a multilingual collection designed for evaluating retrieval and cross-lingual retrieval tasks. The dataset contains 10,000 parallel passages and 100 parallel (synthetic) queries in three languages: English, Portuguese, and Finnish, constructed from the European Union's DGT-Acquis corpus.",
+         reference="https://huggingface.co/datasets/eherra/EuroPIRQ-retrieval",
+         dataset={
+             "path": "eherra/EuroPIRQ-retrieval",
+             "revision": "59225ed25fbcea2185e1acbc8c3c80f1a8cd8341",
+         },
+         type="Retrieval",
+         category="t2t",
+         modalities=["text"],
+         eval_splits=["test"],
+         eval_langs=_LANGUAGES,
+         main_score="ndcg_at_10",
+         date=("2025-12-01", "2025-12-31"),
+         domains=["Legal"],
+         task_subtypes=[],
+         license="not specified",
+         annotations_creators="LM-generated and reviewed",
+         dialect=[],
+         sample_creation="found",
+         is_public=True,
+         bibtex_citation=r"""
+ @misc{eherra_2025_europirq,
+   author = { {Elias Herranen} },
+   publisher = { Hugging Face },
+   title = { EuroPIRQ: European Parallel Information Retrieval Queries },
+   url = { https://huggingface.co/datasets/eherra/EuroPIRQ-retrieval },
+   year = {2025},
+ }
+ """,
+     )
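
Because EuroPIRQRetrieval is multilingual (`eval_langs=_LANGUAGES`), it can also be narrowed to a single language subset when selecting tasks. A sketch, assuming `mteb.get_tasks` accepts ISO 639-3 language codes as it does for other multilingual tasks:

import mteb

# Sketch: select only tasks covering Finnish; EuroPIRQRetrieval qualifies
# through its "fi" subset ("fin-Latn"). The `languages` filter and its code
# format are assumptions based on mteb's usual task-selection behavior.
tasks = mteb.get_tasks(tasks=["EuroPIRQRetrieval"], languages=["fin"])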