mteb 2.5.2__py3-none-any.whl → 2.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. mteb/__init__.py +2 -0
  2. mteb/_create_dataloaders.py +17 -18
  3. mteb/_evaluators/any_sts_evaluator.py +3 -3
  4. mteb/_evaluators/clustering_evaluator.py +2 -2
  5. mteb/_evaluators/evaluator.py +4 -2
  6. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +10 -8
  7. mteb/_evaluators/pair_classification_evaluator.py +5 -3
  8. mteb/_evaluators/retrieval_evaluator.py +2 -2
  9. mteb/_evaluators/retrieval_metrics.py +18 -17
  10. mteb/_evaluators/sklearn_evaluator.py +11 -10
  11. mteb/_evaluators/text/bitext_mining_evaluator.py +27 -18
  12. mteb/_evaluators/text/summarization_evaluator.py +23 -18
  13. mteb/_evaluators/zeroshot_classification_evaluator.py +5 -3
  14. mteb/abstasks/_data_filter/filters.py +1 -1
  15. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  16. mteb/abstasks/_statistics_calculation.py +18 -10
  17. mteb/abstasks/_stratification.py +18 -18
  18. mteb/abstasks/abstask.py +35 -28
  19. mteb/abstasks/aggregate_task_metadata.py +1 -9
  20. mteb/abstasks/aggregated_task.py +10 -29
  21. mteb/abstasks/classification.py +15 -10
  22. mteb/abstasks/clustering.py +19 -15
  23. mteb/abstasks/clustering_legacy.py +10 -10
  24. mteb/abstasks/image/image_text_pair_classification.py +7 -4
  25. mteb/abstasks/multilabel_classification.py +23 -19
  26. mteb/abstasks/pair_classification.py +20 -11
  27. mteb/abstasks/regression.py +4 -4
  28. mteb/abstasks/retrieval.py +28 -24
  29. mteb/abstasks/retrieval_dataset_loaders.py +2 -2
  30. mteb/abstasks/sts.py +8 -5
  31. mteb/abstasks/task_metadata.py +31 -33
  32. mteb/abstasks/text/bitext_mining.py +39 -28
  33. mteb/abstasks/text/reranking.py +8 -6
  34. mteb/abstasks/text/summarization.py +10 -5
  35. mteb/abstasks/zeroshot_classification.py +8 -4
  36. mteb/benchmarks/benchmark.py +4 -2
  37. mteb/benchmarks/benchmarks/__init__.py +4 -0
  38. mteb/benchmarks/benchmarks/benchmarks.py +112 -11
  39. mteb/benchmarks/get_benchmark.py +14 -55
  40. mteb/cache.py +182 -29
  41. mteb/cli/_display_tasks.py +2 -2
  42. mteb/cli/build_cli.py +110 -14
  43. mteb/cli/generate_model_card.py +43 -23
  44. mteb/deprecated_evaluator.py +63 -49
  45. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
  46. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
  47. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
  48. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
  49. mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
  50. mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
  51. mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
  52. mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
  53. mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
  54. mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
  55. mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
  56. mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
  57. mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
  58. mteb/evaluate.py +44 -33
  59. mteb/filter_tasks.py +25 -26
  60. mteb/get_tasks.py +29 -30
  61. mteb/languages/language_scripts.py +5 -3
  62. mteb/leaderboard/app.py +162 -34
  63. mteb/load_results.py +12 -12
  64. mteb/models/abs_encoder.py +10 -6
  65. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  66. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  67. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
  68. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
  69. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  70. mteb/models/get_model_meta.py +21 -3
  71. mteb/models/instruct_wrapper.py +28 -8
  72. mteb/models/model_implementations/align_models.py +1 -1
  73. mteb/models/model_implementations/andersborges.py +4 -4
  74. mteb/models/model_implementations/ara_models.py +1 -1
  75. mteb/models/model_implementations/arctic_models.py +8 -8
  76. mteb/models/model_implementations/b1ade_models.py +1 -1
  77. mteb/models/model_implementations/bge_models.py +45 -21
  78. mteb/models/model_implementations/bica_model.py +3 -3
  79. mteb/models/model_implementations/blip2_models.py +2 -2
  80. mteb/models/model_implementations/blip_models.py +16 -16
  81. mteb/models/model_implementations/bm25.py +4 -4
  82. mteb/models/model_implementations/bmretriever_models.py +6 -4
  83. mteb/models/model_implementations/cadet_models.py +1 -1
  84. mteb/models/model_implementations/cde_models.py +11 -4
  85. mteb/models/model_implementations/clip_models.py +6 -6
  86. mteb/models/model_implementations/clips_models.py +3 -3
  87. mteb/models/model_implementations/codefuse_models.py +5 -5
  88. mteb/models/model_implementations/codesage_models.py +3 -3
  89. mteb/models/model_implementations/cohere_models.py +5 -5
  90. mteb/models/model_implementations/cohere_v.py +2 -2
  91. mteb/models/model_implementations/colpali_models.py +3 -3
  92. mteb/models/model_implementations/colqwen_models.py +8 -8
  93. mteb/models/model_implementations/colsmol_models.py +2 -2
  94. mteb/models/model_implementations/conan_models.py +1 -1
  95. mteb/models/model_implementations/dino_models.py +42 -42
  96. mteb/models/model_implementations/e5_instruct.py +23 -4
  97. mteb/models/model_implementations/e5_models.py +9 -9
  98. mteb/models/model_implementations/e5_v.py +6 -6
  99. mteb/models/model_implementations/eagerworks_models.py +1 -1
  100. mteb/models/model_implementations/emillykkejensen_models.py +6 -6
  101. mteb/models/model_implementations/en_code_retriever.py +1 -1
  102. mteb/models/model_implementations/euler_models.py +2 -2
  103. mteb/models/model_implementations/fa_models.py +9 -9
  104. mteb/models/model_implementations/facebookai.py +14 -2
  105. mteb/models/model_implementations/geogpt_models.py +1 -1
  106. mteb/models/model_implementations/gme_v_models.py +6 -5
  107. mteb/models/model_implementations/google_models.py +1 -1
  108. mteb/models/model_implementations/granite_vision_embedding_models.py +1 -1
  109. mteb/models/model_implementations/gritlm_models.py +2 -2
  110. mteb/models/model_implementations/gte_models.py +25 -13
  111. mteb/models/model_implementations/hinvec_models.py +1 -1
  112. mteb/models/model_implementations/ibm_granite_models.py +30 -6
  113. mteb/models/model_implementations/inf_models.py +2 -2
  114. mteb/models/model_implementations/jasper_models.py +2 -2
  115. mteb/models/model_implementations/jina_clip.py +48 -10
  116. mteb/models/model_implementations/jina_models.py +18 -11
  117. mteb/models/model_implementations/kblab.py +12 -6
  118. mteb/models/model_implementations/kennethenevoldsen_models.py +4 -4
  119. mteb/models/model_implementations/kfst.py +1 -1
  120. mteb/models/model_implementations/kowshik24_models.py +1 -1
  121. mteb/models/model_implementations/lgai_embedding_models.py +1 -1
  122. mteb/models/model_implementations/linq_models.py +1 -1
  123. mteb/models/model_implementations/listconranker.py +1 -1
  124. mteb/models/model_implementations/llm2clip_models.py +6 -6
  125. mteb/models/model_implementations/llm2vec_models.py +8 -8
  126. mteb/models/model_implementations/mcinext_models.py +4 -1
  127. mteb/models/model_implementations/mdbr_models.py +17 -3
  128. mteb/models/model_implementations/misc_models.py +68 -68
  129. mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
  130. mteb/models/model_implementations/mme5_models.py +1 -1
  131. mteb/models/model_implementations/moco_models.py +4 -4
  132. mteb/models/model_implementations/mod_models.py +1 -1
  133. mteb/models/model_implementations/model2vec_models.py +14 -14
  134. mteb/models/model_implementations/moka_models.py +1 -1
  135. mteb/models/model_implementations/nbailab.py +3 -3
  136. mteb/models/model_implementations/no_instruct_sentence_models.py +2 -2
  137. mteb/models/model_implementations/nomic_models.py +30 -15
  138. mteb/models/model_implementations/nomic_models_vision.py +1 -1
  139. mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +15 -9
  140. mteb/models/model_implementations/nvidia_models.py +151 -19
  141. mteb/models/model_implementations/octen_models.py +61 -2
  142. mteb/models/model_implementations/openclip_models.py +13 -13
  143. mteb/models/model_implementations/opensearch_neural_sparse_models.py +5 -5
  144. mteb/models/model_implementations/ops_moa_models.py +1 -1
  145. mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
  146. mteb/models/model_implementations/pawan_models.py +1 -1
  147. mteb/models/model_implementations/piccolo_models.py +1 -1
  148. mteb/models/model_implementations/pixie_models.py +56 -0
  149. mteb/models/model_implementations/promptriever_models.py +4 -4
  150. mteb/models/model_implementations/pylate_models.py +10 -9
  151. mteb/models/model_implementations/qodo_models.py +2 -2
  152. mteb/models/model_implementations/qtack_models.py +1 -1
  153. mteb/models/model_implementations/qwen3_models.py +3 -3
  154. mteb/models/model_implementations/qzhou_models.py +2 -2
  155. mteb/models/model_implementations/random_baseline.py +3 -3
  156. mteb/models/model_implementations/rasgaard_models.py +2 -2
  157. mteb/models/model_implementations/reasonir_model.py +1 -1
  158. mteb/models/model_implementations/repllama_models.py +3 -3
  159. mteb/models/model_implementations/rerankers_custom.py +12 -6
  160. mteb/models/model_implementations/rerankers_monot5_based.py +17 -17
  161. mteb/models/model_implementations/richinfoai_models.py +1 -1
  162. mteb/models/model_implementations/ru_sentence_models.py +20 -20
  163. mteb/models/model_implementations/ruri_models.py +10 -10
  164. mteb/models/model_implementations/salesforce_models.py +3 -3
  165. mteb/models/model_implementations/samilpwc_models.py +1 -1
  166. mteb/models/model_implementations/sarashina_embedding_models.py +2 -2
  167. mteb/models/model_implementations/searchmap_models.py +1 -1
  168. mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +113 -146
  169. mteb/models/model_implementations/sentence_transformers_models.py +124 -22
  170. mteb/models/model_implementations/shuu_model.py +1 -1
  171. mteb/models/model_implementations/siglip_models.py +20 -20
  172. mteb/models/model_implementations/slm_models.py +416 -0
  173. mteb/models/model_implementations/spartan8806_atles_champion.py +1 -1
  174. mteb/models/model_implementations/stella_models.py +17 -4
  175. mteb/models/model_implementations/tarka_models.py +2 -2
  176. mteb/models/model_implementations/text2vec_models.py +9 -3
  177. mteb/models/model_implementations/ua_sentence_models.py +1 -1
  178. mteb/models/model_implementations/uae_models.py +7 -1
  179. mteb/models/model_implementations/vdr_models.py +1 -1
  180. mteb/models/model_implementations/vi_vn_models.py +6 -6
  181. mteb/models/model_implementations/vlm2vec_models.py +3 -3
  182. mteb/models/model_implementations/voyage_models.py +84 -0
  183. mteb/models/model_implementations/voyage_v.py +9 -7
  184. mteb/models/model_implementations/youtu_models.py +1 -1
  185. mteb/models/model_implementations/yuan_models.py +1 -1
  186. mteb/models/model_implementations/yuan_models_en.py +1 -1
  187. mteb/models/model_meta.py +80 -31
  188. mteb/models/models_protocols.py +22 -6
  189. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
  190. mteb/models/search_wrappers.py +33 -18
  191. mteb/models/sentence_transformer_wrapper.py +50 -25
  192. mteb/models/vllm_wrapper.py +327 -0
  193. mteb/py.typed +0 -0
  194. mteb/results/benchmark_results.py +29 -21
  195. mteb/results/model_result.py +52 -22
  196. mteb/results/task_result.py +80 -58
  197. mteb/similarity_functions.py +11 -7
  198. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  199. mteb/tasks/classification/est/estonian_valence.py +1 -1
  200. mteb/tasks/classification/kur/kurdish_sentiment_classification.py +2 -2
  201. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  202. mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
  203. mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
  204. mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
  205. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  206. mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
  207. mteb/tasks/retrieval/code/code_rag.py +12 -12
  208. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  209. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  210. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  211. mteb/tasks/retrieval/eng/__init__.py +2 -0
  212. mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
  213. mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
  214. mteb/tasks/retrieval/kor/__init__.py +15 -1
  215. mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
  216. mteb/tasks/retrieval/multilingual/__init__.py +2 -0
  217. mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
  218. mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +90 -100
  219. mteb/tasks/retrieval/nob/norquad.py +2 -2
  220. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  221. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  222. mteb/tasks/retrieval/vie/__init__.py +14 -6
  223. mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +39 -0
  224. mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +39 -0
  225. mteb/tasks/retrieval/vie/fevervn_retrieval.py +39 -0
  226. mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +39 -0
  227. mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +48 -0
  228. mteb/tasks/retrieval/vie/nqvn_retrieval.py +39 -0
  229. mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
  230. mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
  231. mteb/types/__init__.py +2 -0
  232. mteb/types/_encoder_io.py +12 -0
  233. mteb/types/_result.py +2 -1
  234. mteb/types/statistics.py +9 -3
  235. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/METADATA +15 -4
  236. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/RECORD +240 -219
  237. mteb/models/model_implementations/mxbai_models.py +0 -111
  238. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
  239. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
  240. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
  241. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@ from mteb.types import (
14
14
  Array,
15
15
  BatchedInput,
16
16
  CorpusDatasetType,
17
+ EncodeKwargs,
17
18
  PromptType,
18
19
  QueryDatasetType,
19
20
  RetrievalOutputType,
@@ -50,7 +51,7 @@ class SearchEncoderWrapper:
50
51
  task_metadata: TaskMetadata,
51
52
  hf_split: str,
52
53
  hf_subset: str,
53
- encode_kwargs: dict[str, Any],
54
+ encode_kwargs: EncodeKwargs,
54
55
  ) -> None:
55
56
  """Index the corpus for retrieval.
56
57
 
@@ -88,7 +89,7 @@ class SearchEncoderWrapper:
88
89
  hf_split: str,
89
90
  hf_subset: str,
90
91
  top_k: int,
91
- encode_kwargs: dict[str, Any],
92
+ encode_kwargs: EncodeKwargs,
92
93
  top_ranked: TopRankedDocumentsType | None = None,
93
94
  ) -> RetrievalOutputType:
94
95
  """Search the corpus for the given queries.
@@ -200,7 +201,7 @@ class SearchEncoderWrapper:
200
201
  # Reset the task corpus dataloader to None to free up memory
201
202
  self.task_corpus = None
202
203
 
203
- results = {qid: {} for qid in query_idx_to_id.values()}
204
+ results: RetrievalOutputType = {qid: {} for qid in query_idx_to_id.values()}
204
205
  for qid in result_heaps:
205
206
  for score, corpus_id in result_heaps[qid]:
206
207
  results[qid][corpus_id] = score
@@ -215,16 +216,22 @@ class SearchEncoderWrapper:
215
216
  hf_subset: str,
216
217
  hf_split: str,
217
218
  top_k: int,
218
- encode_kwargs: dict[str, Any],
219
+ encode_kwargs: EncodeKwargs,
219
220
  ) -> dict[str, list[tuple[float, str]]]:
220
221
  logger.info("Encoding Corpus in batches (this might take a while)...")
222
+ if self.task_corpus is None:
223
+ raise ValueError("Corpus must be indexed before searching.")
224
+
221
225
  itr = range(0, len(self.task_corpus), self.corpus_chunk_size)
222
226
 
223
- result_heaps = {qid: [] for qid in query_idx_to_id.values()}
227
+ result_heaps: dict[str, list[tuple[float, str]]] = {
228
+ qid: [] for qid in query_idx_to_id.values()
229
+ }
224
230
  for batch_num, corpus_start_idx in enumerate(itr):
225
231
  logger.info(f"Encoding Batch {batch_num + 1}/{len(itr)}...")
226
232
  corpus_end_idx = min(
227
- corpus_start_idx + self.corpus_chunk_size, len(self.task_corpus)
233
+ corpus_start_idx + self.corpus_chunk_size,
234
+ len(self.task_corpus),
228
235
  )
229
236
  sub_corpus = self.task_corpus.select(
230
237
  range(corpus_start_idx, corpus_end_idx)
@@ -249,7 +256,7 @@ class SearchEncoderWrapper:
249
256
  scores = self.model.similarity(query_embeddings, sub_corpus_embeddings)
250
257
 
251
258
  # get top-k values
252
- cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
259
+ cos_scores_top_k_values_tensor, cos_scores_top_k_idx_tensor = torch.topk(
253
260
  torch.as_tensor(scores),
254
261
  min(
255
262
  top_k + 1,
@@ -258,8 +265,8 @@ class SearchEncoderWrapper:
258
265
  dim=1,
259
266
  largest=True,
260
267
  )
261
- cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
262
- cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
268
+ cos_scores_top_k_idx = cos_scores_top_k_idx_tensor.cpu().tolist()
269
+ cos_scores_top_k_values = cos_scores_top_k_values_tensor.cpu().tolist()
263
270
 
264
271
  sub_corpus_ids = list(sub_corpus_ids)
265
272
  result_heaps = self._sort_full_corpus_results(
@@ -312,14 +319,18 @@ class SearchEncoderWrapper:
312
319
  task_metadata: TaskMetadata,
313
320
  hf_subset: str,
314
321
  hf_split: str,
315
- encode_kwargs: dict[str, Any],
322
+ encode_kwargs: EncodeKwargs,
316
323
  ) -> dict[str, list[tuple[float, str]]]:
317
324
  """Rerank documents based on pre-ranked documents.
318
325
 
319
326
  Returns:
320
327
  A dictionary mapping query IDs to a list of tuples, each containing a relevance score and a document ID.
321
328
  """
322
- result_heaps = {qid: [] for qid in query_idx_to_id.values()}
329
+ if self.task_corpus is None:
330
+ raise ValueError("Corpus must be indexed before searching.")
331
+ result_heaps: dict[str, list[tuple[float, str]]] = {
332
+ qid: [] for qid in query_idx_to_id.values()
333
+ }
323
334
  doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
324
335
 
325
336
  all_doc_embeddings = self.model.encode(
@@ -340,7 +351,8 @@ class SearchEncoderWrapper:
340
351
  for query_idx, query_embedding in enumerate(query_embeddings):
341
352
  query_id = query_idx_to_id[query_idx]
342
353
  if query_id not in top_ranked:
343
- logger.warning(f"No pre-ranked documents found for query {query_id}")
354
+ msg = f"No pre-ranked documents found for query {query_id}"
355
+ logger.warning(msg)
344
356
  continue
345
357
 
346
358
  ranked_ids = top_ranked[query_id]
@@ -386,12 +398,12 @@ class SearchEncoderWrapper:
386
398
 
387
399
  def _rerank_sort_results(
388
400
  self,
389
- result_heaps: list[tuple[float, str]],
401
+ result_heaps: dict[str, list[tuple[float, str]]],
390
402
  query_id: str,
391
403
  ranked_ids: list[str],
392
404
  scores_top_k_idx: torch.Tensor,
393
405
  scores_top_k_values: torch.Tensor,
394
- ) -> list[tuple[float, str]]:
406
+ ) -> dict[str, list[tuple[float, str]]]:
395
407
  """Sort the heap into descending order list.
396
408
 
397
409
  Returns:
@@ -459,7 +471,7 @@ class SearchCrossEncoderWrapper:
459
471
  task_metadata: TaskMetadata,
460
472
  hf_split: str,
461
473
  hf_subset: str,
462
- encode_kwargs: dict[str, Any],
474
+ encode_kwargs: EncodeKwargs,
463
475
  ) -> None:
464
476
  """Index the corpus for retrieval.
465
477
 
@@ -480,7 +492,7 @@ class SearchCrossEncoderWrapper:
480
492
  hf_split: str,
481
493
  hf_subset: str,
482
494
  top_k: int,
483
- encode_kwargs: dict[str, Any],
495
+ encode_kwargs: EncodeKwargs,
484
496
  top_ranked: TopRankedDocumentsType | None = None,
485
497
  ) -> RetrievalOutputType:
486
498
  """Search the corpus using the given queries.
@@ -502,6 +514,8 @@ class SearchCrossEncoderWrapper:
502
514
  raise ValueError(
503
515
  "CrossEncoder search requires top_ranked documents for reranking."
504
516
  )
517
+ if self.task_corpus is None:
518
+ raise ValueError("Corpus must be indexed before searching.")
505
519
 
506
520
  query_id_to_idx = {row["id"]: i for i, row in enumerate(queries)}
507
521
  doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
@@ -511,7 +525,8 @@ class SearchCrossEncoderWrapper:
511
525
  doc_pairs_ids: list[tuple[str, str]] = []
512
526
  for query_id, corpus_ids in top_ranked.items():
513
527
  if query_id not in top_ranked:
514
- logger.warning(f"No pre-ranked documents found for query {query_id}")
528
+ msg = f"No pre-ranked documents found for query {query_id}"
529
+ logger.warning(msg)
515
530
  continue
516
531
 
517
532
  query_idx = query_id_to_idx[query_id]
@@ -540,7 +555,7 @@ class SearchCrossEncoderWrapper:
540
555
  hf_subset=hf_subset,
541
556
  )
542
557
 
543
- results = {qid: {} for qid in queries["id"]}
558
+ results: RetrievalOutputType = {qid: {} for qid in queries["id"]}
544
559
  for (query_id, corpus_id), score in zip(doc_pairs_ids, predictions):
545
560
  results[query_id][corpus_id] = float(score)
546
561
 
@@ -1,16 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ import warnings
4
5
  from typing import TYPE_CHECKING, Any
5
6
 
6
7
  import numpy as np
7
8
  import torch
8
9
  from packaging.version import Version
9
10
  from torch.utils.data import DataLoader
11
+ from typing_extensions import Unpack
10
12
 
11
13
  from mteb._log_once import LogOnce
12
14
  from mteb.models import ModelMeta
13
- from mteb.types import Array, BatchedInput, PromptType
15
+ from mteb.types import Array, BatchedInput, EncodeKwargs, PromptType
14
16
 
15
17
  from .abs_encoder import AbsEncoder
16
18
 
@@ -25,17 +27,18 @@ SENTENCE_TRANSFORMERS_QUERY_ENCODE_VERSION = "5.0.0"
25
27
 
26
28
 
27
29
  def sentence_transformers_loader(
28
- model_name: str, revision: str | None = None, **kwargs
30
+ model_name: str, revision: str | None = None, device: str | None = None, **kwargs
29
31
  ) -> SentenceTransformerEncoderWrapper:
30
32
  """Loads a SentenceTransformer model and wraps it in a SentenceTransformerEncoderWrapper.
31
33
 
32
34
  Args:
33
35
  model_name: The name of the SentenceTransformer model to load.
34
36
  revision: The revision of the model to load.
37
+ device: The device used to load the model.
35
38
  kwargs: Additional arguments to pass to the SentenceTransformer model.
36
39
  """
37
40
  return SentenceTransformerEncoderWrapper(
38
- model=model_name, revision=revision, **kwargs
41
+ model=model_name, revision=revision, device=device, **kwargs
39
42
  )
40
43
 
41
44
 
@@ -48,6 +51,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
48
51
  self,
49
52
  model: str | SentenceTransformer,
50
53
  revision: str | None = None,
54
+ device: str | None = None,
51
55
  model_prompts: dict[str, str] | None = None,
52
56
  **kwargs,
53
57
  ) -> None:
@@ -56,6 +60,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
56
60
  Args:
57
61
  model: The SentenceTransformer model to use. Can be a string (model name), a SentenceTransformer model, or a CrossEncoder model.
58
62
  revision: The revision of the model to use.
63
+ device: The device used to load the model.
59
64
  model_prompts: A dictionary mapping task names to prompt names.
60
65
  First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt,
61
66
  then to the composed prompt of task type + prompt type, then to the specific task type prompt,
@@ -65,7 +70,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
65
70
  from sentence_transformers import SentenceTransformer
66
71
 
67
72
  if isinstance(model, str):
68
- self.model = SentenceTransformer(model, revision=revision, **kwargs)
73
+ self.model = SentenceTransformer(
74
+ model, revision=revision, device=device, **kwargs
75
+ )
69
76
  else:
70
77
  self.model = model
71
78
 
@@ -75,9 +82,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
75
82
  if built_in_prompts and not model_prompts:
76
83
  model_prompts = built_in_prompts
77
84
  elif model_prompts and built_in_prompts:
78
- logger.warning(
79
- f"Model prompts specified, these will overwrite the default model prompts. Current prompts will be:\n {model_prompts}"
80
- )
85
+ msg = f"Model prompts specified, these will overwrite the default model prompts. Current prompts will be:\n {model_prompts}"
86
+ logger.warning(msg)
87
+ warnings.warn(msg)
81
88
  self.model.prompts = model_prompts
82
89
 
83
90
  self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name(
@@ -86,9 +93,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
86
93
 
87
94
  if invalid_prompts:
88
95
  invalid_prompts = "\n".join(invalid_prompts)
89
- logger.warning(
90
- f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
91
- )
96
+ msg = f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
97
+ logger.warning(msg)
98
+ warnings.warn(msg)
92
99
 
93
100
  if (
94
101
  self.model_prompts
@@ -98,13 +105,15 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
98
105
  or PromptType.document.value not in self.model_prompts
99
106
  )
100
107
  ):
101
- logger.warning(
102
- "SentenceTransformers that use prompts most often need to be configured with at least 'query' and"
103
- f" 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
104
- )
108
+ msg = f"SentenceTransformers that use prompts most often need to be configured with at least 'query' and 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
109
+ logger.warning(msg)
110
+ warnings.warn(msg)
105
111
 
112
+ def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:
113
+ """Compute the similarity between two collections of embeddings."""
106
114
  if hasattr(self.model, "similarity") and callable(self.model.similarity):
107
- self.similarity = self.model.similarity
115
+ return self.model.similarity(embeddings1, embeddings2)
116
+ return super().similarity(embeddings1, embeddings2)
108
117
 
109
118
  def encode(
110
119
  self,
@@ -114,7 +123,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
114
123
  hf_split: str,
115
124
  hf_subset: str,
116
125
  prompt_type: PromptType | None = None,
117
- **kwargs: Any,
126
+ **kwargs: Unpack[EncodeKwargs],
118
127
  ) -> Array:
119
128
  """Encodes the given sentences using the encoder.
120
129
 
@@ -150,7 +159,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
150
159
  prompt_name = None
151
160
  if self.model_prompts is not None:
152
161
  prompt_name = self.get_prompt_name(task_metadata, prompt_type)
153
- prompt = self.model_prompts.get(prompt_name, None)
162
+ prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
154
163
  if prompt_name:
155
164
  prompt_log = f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
156
165
  else:
@@ -193,7 +202,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
193
202
  hf_split: str,
194
203
  hf_subset: str,
195
204
  prompt_type: PromptType | None = None,
196
- **kwargs: Any,
205
+ **kwargs: Unpack[EncodeKwargs],
197
206
  ) -> Array:
198
207
  """Encodes the given sentences using the encoder.
199
208
 
@@ -221,7 +230,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
221
230
  prompt_name = None
222
231
  if self.model_prompts is not None:
223
232
  prompt_name = self.get_prompt_name(task_metadata, prompt_type)
224
- prompt = self.model_prompts.get(prompt_name, None)
233
+ prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
225
234
  if prompt_name:
226
235
  logger.info(
227
236
  f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
@@ -234,7 +243,9 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
234
243
  all_embeddings = []
235
244
  for batch in inputs:
236
245
  batch_column = next(iter(batch.keys()))
237
- batched_input = [dict() for _ in range(len(batch[batch_column]))]
246
+ batched_input: list[dict[str, Any]] = [
247
+ dict() for _ in range(len(batch[batch_column]))
248
+ ]
238
249
 
239
250
  # transform from {"text": [text1, text2], "image": [image1, image2]} to
240
251
  # [{"text": text1, "image": image1}, {"text": text2, "image": image2}]
@@ -255,12 +266,24 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
255
266
 
256
267
 
257
268
  class CrossEncoderWrapper:
258
- """Wrapper for CrossEncoder models."""
269
+ """Wrapper for CrossEncoder models.
270
+
271
+ Args:
272
+ model: The CrossEncoder model to use. Can be a string (model name) or a CrossEncoder model.
273
+ revision: The revision of the model to use.
274
+ device: The device used to load the model.
275
+ query_prefix: A prefix to add to all queries.
276
+ passage_prefix: A prefix to add to all passages.
277
+ **kwargs: Additional arguments to pass to the CrossEncoder model.
278
+ """
259
279
 
260
280
  def __init__(
261
281
  self,
262
282
  model: CrossEncoder | str,
263
283
  revision: str | None = None,
284
+ device: str | None = None,
285
+ query_prefix: str = "",
286
+ passage_prefix: str = "",
264
287
  **kwargs,
265
288
  ) -> None:
266
289
  from sentence_transformers import CrossEncoder
@@ -268,9 +291,11 @@ class CrossEncoderWrapper:
268
291
  if isinstance(model, CrossEncoder):
269
292
  self.model = model
270
293
  elif isinstance(model, str):
271
- self.model = CrossEncoder(model, revision=revision, **kwargs)
294
+ self.model = CrossEncoder(model, revision=revision, device=device, **kwargs)
272
295
 
273
296
  self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)
297
+ self.query_prefix = query_prefix
298
+ self.passage_prefix = passage_prefix
274
299
 
275
300
  def predict(
276
301
  self,
@@ -281,7 +306,7 @@ class CrossEncoderWrapper:
281
306
  hf_split: str,
282
307
  hf_subset: str,
283
308
  prompt_type: PromptType | None = None,
284
- **kwargs: Any,
309
+ **kwargs: Unpack[EncodeKwargs],
285
310
  ) -> Array:
286
311
  """Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.
287
312
 
@@ -299,10 +324,10 @@ class CrossEncoderWrapper:
299
324
  The predicted relevance scores for each inputs pair.
300
325
  """
301
326
  all_queries_with_instructions = [
302
- text for batch in inputs1 for text in batch["text"]
327
+ self.query_prefix + text for batch in inputs1 for text in batch["text"]
303
328
  ]
304
329
  all_corpus_with_instructions = [
305
- text for batch in inputs2 for text in batch["text"]
330
+ self.passage_prefix + text for batch in inputs2 for text in batch["text"]
306
331
  ]
307
332
 
308
333
  return self.model.predict(