mteb 2.1.4__py3-none-any.whl → 2.7.2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (527)
  1. mteb/__init__.py +6 -0
  2. mteb/_create_dataloaders.py +22 -20
  3. mteb/_evaluators/any_sts_evaluator.py +23 -14
  4. mteb/_evaluators/classification_metrics.py +54 -0
  5. mteb/_evaluators/clustering_evaluator.py +3 -3
  6. mteb/_evaluators/evaluator.py +4 -2
  7. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +18 -11
  8. mteb/_evaluators/pair_classification_evaluator.py +34 -40
  9. mteb/_evaluators/retrieval_evaluator.py +2 -2
  10. mteb/_evaluators/retrieval_metrics.py +18 -17
  11. mteb/_evaluators/sklearn_evaluator.py +25 -37
  12. mteb/_evaluators/text/bitext_mining_evaluator.py +31 -19
  13. mteb/_evaluators/text/summarization_evaluator.py +27 -20
  14. mteb/_evaluators/zeroshot_classification_evaluator.py +7 -5
  15. mteb/abstasks/_data_filter/__init__.py +0 -0
  16. mteb/abstasks/_data_filter/filters.py +125 -0
  17. mteb/abstasks/_data_filter/task_pipelines.py +105 -0
  18. mteb/abstasks/_statistics_calculation.py +23 -11
  19. mteb/abstasks/_stratification.py +18 -18
  20. mteb/abstasks/abstask.py +35 -28
  21. mteb/abstasks/aggregate_task_metadata.py +1 -9
  22. mteb/abstasks/aggregated_task.py +10 -29
  23. mteb/abstasks/classification.py +15 -12
  24. mteb/abstasks/clustering.py +20 -16
  25. mteb/abstasks/clustering_legacy.py +13 -10
  26. mteb/abstasks/image/image_text_pair_classification.py +7 -4
  27. mteb/abstasks/multilabel_classification.py +33 -22
  28. mteb/abstasks/pair_classification.py +27 -11
  29. mteb/abstasks/regression.py +4 -4
  30. mteb/abstasks/retrieval.py +28 -24
  31. mteb/abstasks/retrieval_dataset_loaders.py +2 -2
  32. mteb/abstasks/sts.py +14 -4
  33. mteb/abstasks/task_metadata.py +32 -33
  34. mteb/abstasks/text/bitext_mining.py +39 -28
  35. mteb/abstasks/text/reranking.py +8 -6
  36. mteb/abstasks/text/summarization.py +10 -5
  37. mteb/abstasks/zeroshot_classification.py +8 -4
  38. mteb/benchmarks/_create_table.py +84 -37
  39. mteb/benchmarks/benchmark.py +77 -16
  40. mteb/benchmarks/benchmarks/__init__.py +12 -0
  41. mteb/benchmarks/benchmarks/benchmarks.py +361 -16
  42. mteb/benchmarks/get_benchmark.py +14 -53
  43. mteb/cache.py +227 -37
  44. mteb/cli/_display_tasks.py +2 -2
  45. mteb/cli/build_cli.py +110 -14
  46. mteb/cli/generate_model_card.py +43 -23
  47. mteb/deprecated_evaluator.py +71 -62
  48. mteb/descriptive_stats/BitextMining/RuSciBenchBitextMining.v2.json +61 -0
  49. mteb/descriptive_stats/Classification/HebrewSentimentAnalysis.v3.json +60 -0
  50. mteb/descriptive_stats/Classification/TurkishConstitutionalCourtViolation.json +54 -0
  51. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
  52. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
  53. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
  54. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
  55. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3ComputerScienceRetrieval.json +214 -0
  56. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3EnergyRetrieval.json +214 -0
  57. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceEnRetrieval.json +214 -0
  58. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceFrRetrieval.json +214 -0
  59. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3HrRetrieval.json +214 -0
  60. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3IndustrialRetrieval.json +214 -0
  61. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3NuclearRetrieval.json +214 -0
  62. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PharmaceuticalsRetrieval.json +214 -0
  63. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PhysicsRetrieval.json +214 -0
  64. mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3TelecomRetrieval.json +214 -0
  65. mteb/descriptive_stats/PairClassification/TERRa.V2.json +35 -0
  66. mteb/descriptive_stats/Reranking/JQaRARerankingLite.json +35 -0
  67. mteb/descriptive_stats/Reranking/JaCWIRRerankingLite.json +35 -0
  68. mteb/descriptive_stats/Reranking/MultiLongDocReranking.json +466 -0
  69. mteb/descriptive_stats/Retrieval/ArguAna-NL.v2.json +30 -0
  70. mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
  71. mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
  72. mteb/descriptive_stats/Retrieval/JaCWIRRetrievalLite.json +30 -0
  73. mteb/descriptive_stats/Retrieval/JaqketRetrievalLite.json +30 -0
  74. mteb/descriptive_stats/Retrieval/MIRACLJaRetrievalLite.json +30 -0
  75. mteb/descriptive_stats/Retrieval/MrTyDiJaRetrievalLite.json +30 -0
  76. mteb/descriptive_stats/Retrieval/NFCorpus-NL.v2.json +30 -0
  77. mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
  78. mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
  79. mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
  80. mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
  81. mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
  82. mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
  83. mteb/descriptive_stats/Retrieval/SCIDOCS-NL.v2.json +30 -0
  84. mteb/descriptive_stats/Retrieval/SQuADKorV1Retrieval.json +30 -0
  85. mteb/descriptive_stats/Retrieval/SciFact-NL.v2.json +30 -0
  86. mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
  87. mteb/evaluate.py +106 -75
  88. mteb/filter_tasks.py +25 -26
  89. mteb/get_tasks.py +29 -30
  90. mteb/languages/language_scripts.py +5 -3
  91. mteb/leaderboard/app.py +414 -151
  92. mteb/leaderboard/benchmark_selector.py +14 -5
  93. mteb/leaderboard/figures.py +13 -15
  94. mteb/leaderboard/table.py +82 -17
  95. mteb/load_results.py +12 -12
  96. mteb/models/__init__.py +4 -1
  97. mteb/models/abs_encoder.py +31 -23
  98. mteb/models/cache_wrappers/__init__.py +2 -1
  99. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  100. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +7 -6
  101. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
  102. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
  103. mteb/models/cache_wrappers/cache_wrapper.py +3 -3
  104. mteb/models/get_model_meta.py +25 -118
  105. mteb/models/instruct_wrapper.py +33 -9
  106. mteb/models/model_implementations/align_models.py +8 -1
  107. mteb/models/model_implementations/amazon_models.py +1 -0
  108. mteb/models/model_implementations/andersborges.py +65 -0
  109. mteb/models/model_implementations/ara_models.py +9 -1
  110. mteb/models/model_implementations/arctic_models.py +16 -8
  111. mteb/models/model_implementations/b1ade_models.py +2 -1
  112. mteb/models/model_implementations/bedrock_models.py +4 -0
  113. mteb/models/model_implementations/bge_models.py +101 -17
  114. mteb/models/model_implementations/bica_model.py +35 -0
  115. mteb/models/model_implementations/blip2_models.py +13 -2
  116. mteb/models/model_implementations/blip_models.py +43 -16
  117. mteb/models/model_implementations/bm25.py +5 -4
  118. mteb/models/model_implementations/bmretriever_models.py +10 -4
  119. mteb/models/model_implementations/cadet_models.py +10 -1
  120. mteb/models/model_implementations/cde_models.py +25 -4
  121. mteb/models/model_implementations/clip_models.py +9 -6
  122. mteb/models/model_implementations/clips_models.py +100 -0
  123. mteb/models/model_implementations/codefuse_models.py +165 -3
  124. mteb/models/model_implementations/codesage_models.py +18 -3
  125. mteb/models/model_implementations/cohere_models.py +13 -6
  126. mteb/models/model_implementations/cohere_v.py +7 -2
  127. mteb/models/model_implementations/colpali_models.py +17 -9
  128. mteb/models/model_implementations/colqwen_models.py +275 -5
  129. mteb/models/model_implementations/colsmol_models.py +4 -2
  130. mteb/models/model_implementations/conan_models.py +2 -1
  131. mteb/models/model_implementations/dino_models.py +194 -23
  132. mteb/models/model_implementations/e5_instruct.py +27 -4
  133. mteb/models/model_implementations/e5_models.py +21 -110
  134. mteb/models/model_implementations/e5_v.py +7 -6
  135. mteb/models/model_implementations/eagerworks_models.py +164 -0
  136. mteb/models/model_implementations/emillykkejensen_models.py +91 -0
  137. mteb/models/model_implementations/en_code_retriever.py +2 -1
  138. mteb/models/model_implementations/euler_models.py +32 -0
  139. mteb/models/model_implementations/evaclip_models.py +4 -0
  140. mteb/models/model_implementations/fa_models.py +67 -9
  141. mteb/models/model_implementations/facebookai.py +205 -0
  142. mteb/models/model_implementations/geogpt_models.py +2 -1
  143. mteb/models/model_implementations/gme_v_models.py +17 -10
  144. mteb/models/model_implementations/google_models.py +17 -6
  145. mteb/models/model_implementations/granite_vision_embedding_models.py +8 -3
  146. mteb/models/model_implementations/gritlm_models.py +4 -2
  147. mteb/models/model_implementations/gte_models.py +99 -9
  148. mteb/models/model_implementations/hinvec_models.py +2 -1
  149. mteb/models/model_implementations/human.py +1 -0
  150. mteb/models/model_implementations/ibm_granite_models.py +36 -6
  151. mteb/models/model_implementations/inf_models.py +4 -2
  152. mteb/models/model_implementations/jasper_models.py +256 -3
  153. mteb/models/model_implementations/jina_clip.py +49 -10
  154. mteb/models/model_implementations/jina_models.py +222 -11
  155. mteb/models/model_implementations/kalm_models.py +203 -25
  156. mteb/models/model_implementations/kblab.py +37 -0
  157. mteb/models/model_implementations/kennethenevoldsen_models.py +74 -0
  158. mteb/models/model_implementations/kfst.py +25 -0
  159. mteb/models/model_implementations/kowshik24_models.py +32 -0
  160. mteb/models/model_implementations/lens_models.py +2 -0
  161. mteb/models/model_implementations/lgai_embedding_models.py +2 -1
  162. mteb/models/model_implementations/linq_models.py +4 -3
  163. mteb/models/model_implementations/listconranker.py +2 -2
  164. mteb/models/model_implementations/llm2clip_models.py +9 -6
  165. mteb/models/model_implementations/llm2vec_models.py +16 -8
  166. mteb/models/model_implementations/mcinext_models.py +7 -1
  167. mteb/models/model_implementations/mdbr_models.py +19 -3
  168. mteb/models/model_implementations/misc_models.py +422 -60
  169. mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
  170. mteb/models/model_implementations/mme5_models.py +2 -1
  171. mteb/models/model_implementations/moco_models.py +15 -4
  172. mteb/models/model_implementations/mod_models.py +191 -0
  173. mteb/models/model_implementations/model2vec_models.py +27 -14
  174. mteb/models/model_implementations/moka_models.py +4 -1
  175. mteb/models/model_implementations/nbailab.py +70 -0
  176. mteb/models/model_implementations/no_instruct_sentence_models.py +3 -2
  177. mteb/models/model_implementations/nomic_models.py +173 -6
  178. mteb/models/model_implementations/nomic_models_vision.py +8 -3
  179. mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +32 -19
  180. mteb/models/model_implementations/nvidia_models.py +155 -20
  181. mteb/models/model_implementations/octen_models.py +254 -0
  182. mteb/models/model_implementations/openai_models.py +20 -16
  183. mteb/models/model_implementations/openclip_models.py +37 -13
  184. mteb/models/model_implementations/opensearch_neural_sparse_models.py +10 -5
  185. mteb/models/model_implementations/ops_moa_models.py +5 -3
  186. mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
  187. mteb/models/model_implementations/pawan_models.py +39 -0
  188. mteb/models/model_implementations/piccolo_models.py +9 -1
  189. mteb/models/model_implementations/pixie_models.py +56 -0
  190. mteb/models/model_implementations/promptriever_models.py +12 -8
  191. mteb/models/model_implementations/pylate_models.py +46 -12
  192. mteb/models/model_implementations/qodo_models.py +4 -2
  193. mteb/models/model_implementations/qtack_models.py +2 -1
  194. mteb/models/model_implementations/qwen3_models.py +9 -6
  195. mteb/models/model_implementations/qzhou_models.py +5 -3
  196. mteb/models/model_implementations/random_baseline.py +19 -24
  197. mteb/models/model_implementations/rasgaard_models.py +34 -0
  198. mteb/models/model_implementations/reasonir_model.py +2 -1
  199. mteb/models/model_implementations/repllama_models.py +5 -3
  200. mteb/models/model_implementations/rerankers_custom.py +15 -9
  201. mteb/models/model_implementations/rerankers_monot5_based.py +31 -31
  202. mteb/models/model_implementations/richinfoai_models.py +2 -1
  203. mteb/models/model_implementations/ru_sentence_models.py +71 -20
  204. mteb/models/model_implementations/ruri_models.py +322 -0
  205. mteb/models/model_implementations/salesforce_models.py +6 -3
  206. mteb/models/model_implementations/samilpwc_models.py +2 -1
  207. mteb/models/model_implementations/sarashina_embedding_models.py +168 -0
  208. mteb/models/model_implementations/searchmap_models.py +2 -1
  209. mteb/models/model_implementations/seed_1_6_embedding_models.py +8 -2
  210. mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +625 -0
  211. mteb/models/model_implementations/seed_models.py +1 -0
  212. mteb/models/model_implementations/sentence_transformers_models.py +177 -18
  213. mteb/models/model_implementations/shuu_model.py +32 -31
  214. mteb/models/model_implementations/siglip_models.py +30 -20
  215. mteb/models/model_implementations/slm_models.py +416 -0
  216. mteb/models/model_implementations/sonar_models.py +1 -0
  217. mteb/models/model_implementations/spartan8806_atles_champion.py +34 -0
  218. mteb/models/model_implementations/stella_models.py +23 -4
  219. mteb/models/model_implementations/tarka_models.py +376 -0
  220. mteb/models/model_implementations/text2vec_models.py +9 -3
  221. mteb/models/model_implementations/ua_sentence_models.py +11 -1
  222. mteb/models/model_implementations/uae_models.py +8 -1
  223. mteb/models/model_implementations/vdr_models.py +3 -1
  224. mteb/models/model_implementations/vi_vn_models.py +45 -6
  225. mteb/models/model_implementations/vista_models.py +2 -0
  226. mteb/models/model_implementations/vlm2vec_models.py +5 -3
  227. mteb/models/model_implementations/voyage_models.py +99 -0
  228. mteb/models/model_implementations/voyage_v.py +17 -9
  229. mteb/models/model_implementations/xyz_models.py +1 -0
  230. mteb/models/model_implementations/youtu_models.py +2 -1
  231. mteb/models/model_implementations/yuan_models.py +34 -0
  232. mteb/models/model_implementations/yuan_models_en.py +58 -0
  233. mteb/models/model_meta.py +498 -29
  234. mteb/models/models_protocols.py +22 -6
  235. mteb/models/search_encoder_index/__init__.py +7 -0
  236. mteb/models/search_encoder_index/search_backend_protocol.py +50 -0
  237. mteb/models/search_encoder_index/search_indexes/__init__.py +5 -0
  238. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +160 -0
  239. mteb/models/search_wrappers.py +197 -65
  240. mteb/models/sentence_transformer_wrapper.py +52 -32
  241. mteb/models/vllm_wrapper.py +327 -0
  242. mteb/py.typed +0 -0
  243. mteb/results/benchmark_results.py +114 -65
  244. mteb/results/model_result.py +63 -26
  245. mteb/results/task_result.py +117 -77
  246. mteb/similarity_functions.py +60 -7
  247. mteb/tasks/bitext_mining/multilingual/__init__.py +2 -1
  248. mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +4 -2
  249. mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining_fast.py +1 -1
  250. mteb/tasks/bitext_mining/multilingual/ru_sci_bench_bitext_mining.py +47 -5
  251. mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -6
  252. mteb/tasks/classification/ara/ajgt.py +1 -2
  253. mteb/tasks/classification/ara/hotel_review_sentiment_classification.py +1 -2
  254. mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -2
  255. mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -2
  256. mteb/tasks/classification/ara/tweet_emotion_classification.py +1 -2
  257. mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -2
  258. mteb/tasks/classification/ben/bengali_document_classification.py +1 -2
  259. mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -2
  260. mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -2
  261. mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +1 -2
  262. mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +1 -2
  263. mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +1 -2
  264. mteb/tasks/classification/dan/angry_tweets_classification.py +1 -2
  265. mteb/tasks/classification/dan/danish_political_comments_classification.py +1 -2
  266. mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -2
  267. mteb/tasks/classification/dan/dk_hate_classification.py +2 -3
  268. mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -2
  269. mteb/tasks/classification/deu/ten_k_gnad_classification.py +1 -2
  270. mteb/tasks/classification/eng/amazon_polarity_classification.py +1 -2
  271. mteb/tasks/classification/eng/arxiv_classification.py +1 -2
  272. mteb/tasks/classification/eng/banking77_classification.py +1 -2
  273. mteb/tasks/classification/eng/dbpedia_classification.py +1 -2
  274. mteb/tasks/classification/eng/emotion_classification.py +1 -2
  275. mteb/tasks/classification/eng/financial_phrasebank_classification.py +1 -2
  276. mteb/tasks/classification/eng/frenk_en_classification.py +1 -2
  277. mteb/tasks/classification/eng/gtsrb_classification.py +1 -1
  278. mteb/tasks/classification/eng/imdb_classification.py +1 -2
  279. mteb/tasks/classification/eng/legal_bench_classification.py +14 -120
  280. mteb/tasks/classification/eng/news_classification.py +1 -2
  281. mteb/tasks/classification/eng/patch_camelyon_classification.py +1 -1
  282. mteb/tasks/classification/eng/patent_classification.py +1 -2
  283. mteb/tasks/classification/eng/poem_sentiment_classification.py +1 -2
  284. mteb/tasks/classification/eng/sds_eye_protection_classification.py +1 -2
  285. mteb/tasks/classification/eng/sds_gloves_classification.py +1 -2
  286. mteb/tasks/classification/eng/toxic_chat_classification.py +2 -19
  287. mteb/tasks/classification/eng/toxic_conversations_classification.py +1 -2
  288. mteb/tasks/classification/eng/tweet_sentiment_extraction_classification.py +1 -2
  289. mteb/tasks/classification/eng/tweet_topic_single_classification.py +2 -13
  290. mteb/tasks/classification/eng/ucf101_classification.py +1 -5
  291. mteb/tasks/classification/eng/wikipedia_bio_met_chem_classification.py +1 -2
  292. mteb/tasks/classification/eng/wikipedia_chem_fields_classification.py +1 -2
  293. mteb/tasks/classification/eng/wikipedia_comp_chem_spectroscopy_classification.py +1 -2
  294. mteb/tasks/classification/eng/wikipedia_crystallography_analytical_classification.py +1 -2
  295. mteb/tasks/classification/eng/wikipedia_theoretical_applied_classification.py +1 -2
  296. mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -2
  297. mteb/tasks/classification/eng/yelp_review_full_classification.py +1 -2
  298. mteb/tasks/classification/est/estonian_valence.py +2 -3
  299. mteb/tasks/classification/fas/fa_mteb_classification.py +7 -14
  300. mteb/tasks/classification/fil/filipino_hate_speech_classification.py +1 -2
  301. mteb/tasks/classification/fin/fin_toxicity_classification.py +2 -11
  302. mteb/tasks/classification/fra/french_book_reviews.py +1 -2
  303. mteb/tasks/classification/fra/movie_review_sentiment_classification.py +1 -2
  304. mteb/tasks/classification/guj/gujarati_news_classification.py +1 -2
  305. mteb/tasks/classification/heb/__init__.py +6 -1
  306. mteb/tasks/classification/heb/hebrew_sentiment_analysis.py +62 -4
  307. mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -2
  308. mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -2
  309. mteb/tasks/classification/hrv/frenk_hr_classification.py +1 -2
  310. mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +1 -2
  311. mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -2
  312. mteb/tasks/classification/ita/italian_linguist_acceptability_classification.py +1 -2
  313. mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -2
  314. mteb/tasks/classification/jpn/wrime_classification.py +1 -2
  315. mteb/tasks/classification/kan/kannada_news_classification.py +1 -2
  316. mteb/tasks/classification/kor/klue_tc.py +1 -2
  317. mteb/tasks/classification/kor/kor_hate_classification.py +2 -17
  318. mteb/tasks/classification/kor/kor_sarcasm_classification.py +2 -19
  319. mteb/tasks/classification/kur/kurdish_sentiment_classification.py +3 -4
  320. mteb/tasks/classification/mal/malayalam_news_classification.py +1 -2
  321. mteb/tasks/classification/mar/marathi_news_classification.py +1 -2
  322. mteb/tasks/classification/mkd/macedonian_tweet_sentiment_classification.py +1 -2
  323. mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -6
  324. mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -4
  325. mteb/tasks/classification/multilingual/ru_sci_bench_classification.py +4 -23
  326. mteb/tasks/classification/multilingual/scala_classification.py +2 -3
  327. mteb/tasks/classification/multilingual/sib200_classification.py +1 -6
  328. mteb/tasks/classification/mya/myanmar_news.py +1 -2
  329. mteb/tasks/classification/nep/nepali_news_classification.py +1 -2
  330. mteb/tasks/classification/nld/dutch_book_review_sentiment_classification.py +4 -2
  331. mteb/tasks/classification/nld/dutch_cola_classification.py +3 -0
  332. mteb/tasks/classification/nld/dutch_government_bias_classification.py +3 -0
  333. mteb/tasks/classification/nld/dutch_news_articles_classification.py +3 -0
  334. mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +3 -0
  335. mteb/tasks/classification/nld/iconclass_classification.py +3 -0
  336. mteb/tasks/classification/nld/open_tender_classification.py +3 -0
  337. mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +3 -0
  338. mteb/tasks/classification/nob/no_rec_classification.py +1 -2
  339. mteb/tasks/classification/nob/norwegian_parliament_classification.py +1 -2
  340. mteb/tasks/classification/ory/odia_news_classification.py +1 -2
  341. mteb/tasks/classification/pol/polish_classification.py +3 -6
  342. mteb/tasks/classification/ron/moroco.py +1 -2
  343. mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -2
  344. mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -2
  345. mteb/tasks/classification/rus/georeview_classification.py +1 -2
  346. mteb/tasks/classification/rus/headline_classification.py +1 -2
  347. mteb/tasks/classification/rus/inappropriateness_classification.py +1 -2
  348. mteb/tasks/classification/rus/ru_reviews_classification.py +1 -2
  349. mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -2
  350. mteb/tasks/classification/rus/senti_ru_eval.py +1 -2
  351. mteb/tasks/classification/sin/sinhala_news_classification.py +1 -2
  352. mteb/tasks/classification/sin/sinhala_news_source_classification.py +1 -2
  353. mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +1 -2
  354. mteb/tasks/classification/slk/slovak_hate_speech_classification.py +1 -2
  355. mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +1 -2
  356. mteb/tasks/classification/slv/frenk_sl_classification.py +1 -2
  357. mteb/tasks/classification/spa/spanish_news_classification.py +1 -2
  358. mteb/tasks/classification/spa/spanish_sentiment_classification.py +1 -2
  359. mteb/tasks/classification/ssw/siswati_news_classification.py +1 -2
  360. mteb/tasks/classification/swa/swahili_news_classification.py +1 -2
  361. mteb/tasks/classification/swe/dalaj_classification.py +1 -2
  362. mteb/tasks/classification/swe/swe_rec_classification.py +1 -2
  363. mteb/tasks/classification/swe/swedish_sentiment_classification.py +1 -2
  364. mteb/tasks/classification/tam/tamil_news_classification.py +1 -2
  365. mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +1 -2
  366. mteb/tasks/classification/tha/wisesight_sentiment_classification.py +1 -2
  367. mteb/tasks/classification/tsn/tswana_news_classification.py +1 -2
  368. mteb/tasks/classification/tur/__init__.py +4 -0
  369. mteb/tasks/classification/tur/turkish_constitutional_court.py +41 -0
  370. mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +1 -2
  371. mteb/tasks/classification/tur/turkish_product_sentiment_classification.py +1 -2
  372. mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -15
  373. mteb/tasks/classification/urd/urdu_roman_sentiment_classification.py +1 -2
  374. mteb/tasks/classification/vie/amazon_counterfactual_vn_classification.py +1 -6
  375. mteb/tasks/classification/vie/amazon_polarity_vn_classification.py +1 -6
  376. mteb/tasks/classification/vie/amazon_reviews_vn_classification.py +1 -5
  377. mteb/tasks/classification/vie/banking77_vn_classification.py +1 -5
  378. mteb/tasks/classification/vie/emotion_vn_classification.py +1 -5
  379. mteb/tasks/classification/vie/imdb_vn_classification.py +1 -5
  380. mteb/tasks/classification/vie/massive_intent_vn_classification.py +1 -5
  381. mteb/tasks/classification/vie/massive_scenario_vn_classification.py +1 -5
  382. mteb/tasks/classification/vie/mtop_domain_vn_classification.py +1 -5
  383. mteb/tasks/classification/vie/mtop_intent_vn_classification.py +1 -5
  384. mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -5
  385. mteb/tasks/classification/vie/tweet_sentiment_extraction_vn_classification.py +1 -5
  386. mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -2
  387. mteb/tasks/classification/zho/cmteb_classification.py +5 -10
  388. mteb/tasks/classification/zho/yue_openrice_review_classification.py +1 -2
  389. mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -2
  390. mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
  391. mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
  392. mteb/tasks/clustering/jpn/mews_c16_ja_clustering.py +1 -3
  393. mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -6
  394. mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +3 -0
  395. mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +3 -0
  396. mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +3 -0
  397. mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +3 -0
  398. mteb/tasks/clustering/nld/open_tender_clustering_s2s.py +3 -0
  399. mteb/tasks/clustering/nld/vabb_clustering_p2p.py +3 -0
  400. mteb/tasks/clustering/nld/vabb_clustering_s2s.py +3 -0
  401. mteb/tasks/clustering/vie/reddit_clustering_p2p_vn.py +1 -5
  402. mteb/tasks/clustering/vie/reddit_clustering_vn.py +1 -5
  403. mteb/tasks/clustering/vie/stack_exchange_clustering_p2p_vn.py +1 -5
  404. mteb/tasks/clustering/vie/stack_exchange_clustering_vn.py +1 -5
  405. mteb/tasks/clustering/vie/twenty_newsgroups_clustering_vn.py +1 -5
  406. mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
  407. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  408. mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -5
  409. mteb/tasks/multilabel_classification/kor/kor_hate_speech_ml_classification.py +1 -9
  410. mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -6
  411. mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +3 -0
  412. mteb/tasks/multilabel_classification/nld/vabb_multi_label_classification.py +3 -0
  413. mteb/tasks/multilabel_classification/por/brazilian_toxic_tweets_classification.py +1 -6
  414. mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
  415. mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -2
  416. mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -6
  417. mteb/tasks/pair_classification/eng/legal_bench_pc.py +1 -9
  418. mteb/tasks/pair_classification/nld/sick_nl_pair_classification.py +3 -0
  419. mteb/tasks/pair_classification/nld/xlwic_nl_pair_classification.py +3 -0
  420. mteb/tasks/pair_classification/rus/__init__.py +2 -2
  421. mteb/tasks/pair_classification/rus/terra.py +51 -25
  422. mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -5
  423. mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -5
  424. mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -5
  425. mteb/tasks/regression/multilingual/ru_sci_bench_regression.py +2 -6
  426. mteb/tasks/reranking/jpn/__init__.py +9 -1
  427. mteb/tasks/reranking/jpn/j_qa_ra_reranking_lite.py +49 -0
  428. mteb/tasks/reranking/jpn/ja_cwir_reranking_lite.py +47 -0
  429. mteb/tasks/reranking/multilingual/__init__.py +2 -0
  430. mteb/tasks/reranking/multilingual/multi_long_doc_reranking.py +70 -0
  431. mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
  432. mteb/tasks/reranking/multilingual/x_glue_wpr_reranking.py +1 -2
  433. mteb/tasks/reranking/vie/ask_ubuntu_dup_questions_vn.py +1 -5
  434. mteb/tasks/reranking/vie/sci_docs_reranking_vn.py +1 -5
  435. mteb/tasks/reranking/vie/stack_overflow_dup_questions_vn.py +1 -5
  436. mteb/tasks/retrieval/code/code_rag.py +12 -12
  437. mteb/tasks/retrieval/code/fresh_stack_retrieval.py +8 -5
  438. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  439. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  440. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  441. mteb/tasks/retrieval/eng/__init__.py +2 -0
  442. mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
  443. mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
  444. mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -8
  445. mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +4 -0
  446. mteb/tasks/retrieval/jpn/__init__.py +8 -0
  447. mteb/tasks/retrieval/jpn/ja_cwir_retrieval.py +1 -4
  448. mteb/tasks/retrieval/jpn/ja_cwir_retrieval_lite.py +47 -0
  449. mteb/tasks/retrieval/jpn/jaqket_retrieval_lite.py +50 -0
  450. mteb/tasks/retrieval/jpn/miracl_ja_retrieval_lite.py +52 -0
  451. mteb/tasks/retrieval/jpn/mr_tydi_ja_retrieval_lite.py +48 -0
  452. mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +11 -4
  453. mteb/tasks/retrieval/kor/__init__.py +16 -1
  454. mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
  455. mteb/tasks/retrieval/kor/squad_kor_v1_retrieval.py +47 -0
  456. mteb/tasks/retrieval/multilingual/__init__.py +24 -0
  457. mteb/tasks/retrieval/multilingual/belebele_retrieval.py +5 -4
  458. mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
  459. mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +56 -42
  460. mteb/tasks/retrieval/multilingual/mkqa_retrieval.py +1 -2
  461. mteb/tasks/retrieval/multilingual/mlqa_retrieval.py +1 -4
  462. mteb/tasks/retrieval/multilingual/multi_long_doc_retrieval.py +1 -2
  463. mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +9 -4
  464. mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -12
  465. mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +4 -2
  466. mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +389 -0
  467. mteb/tasks/retrieval/nld/__init__.py +8 -4
  468. mteb/tasks/retrieval/nld/argu_ana_nl_retrieval.py +46 -27
  469. mteb/tasks/retrieval/nld/bbsard_nl_retrieval.py +3 -0
  470. mteb/tasks/retrieval/nld/dutch_news_articles_retrieval.py +3 -0
  471. mteb/tasks/retrieval/nld/legal_qa_nl_retrieval.py +3 -0
  472. mteb/tasks/retrieval/nld/nf_corpus_nl_retrieval.py +42 -25
  473. mteb/tasks/retrieval/nld/open_tender_retrieval.py +3 -0
  474. mteb/tasks/retrieval/nld/sci_fact_nl_retrieval.py +42 -24
  475. mteb/tasks/retrieval/nld/scidocsnl_retrieval.py +44 -27
  476. mteb/tasks/retrieval/nld/vabb_retrieval.py +3 -0
  477. mteb/tasks/retrieval/nob/norquad.py +2 -2
  478. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  479. mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -7
  480. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  481. mteb/tasks/retrieval/vie/__init__.py +14 -6
  482. mteb/tasks/retrieval/vie/argu_ana_vn_retrieval.py +1 -5
  483. mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +40 -5
  484. mteb/tasks/retrieval/vie/cqa_dupstack_android_vn_retrieval.py +1 -5
  485. mteb/tasks/retrieval/vie/cqa_dupstack_gis_vn_retrieval.py +1 -5
  486. mteb/tasks/retrieval/vie/cqa_dupstack_mathematica_vn_retrieval.py +1 -5
  487. mteb/tasks/retrieval/vie/cqa_dupstack_physics_vn_retrieval.py +1 -5
  488. mteb/tasks/retrieval/vie/cqa_dupstack_programmers_vn_retrieval.py +1 -5
  489. mteb/tasks/retrieval/vie/cqa_dupstack_stats_vn_retrieval.py +1 -5
  490. mteb/tasks/retrieval/vie/cqa_dupstack_tex_vn_retrieval.py +1 -5
  491. mteb/tasks/retrieval/vie/cqa_dupstack_unix_vn_retrieval.py +1 -5
  492. mteb/tasks/retrieval/vie/cqa_dupstack_webmasters_vn_retrieval.py +1 -5
  493. mteb/tasks/retrieval/vie/cqa_dupstack_wordpress_vn_retrieval.py +1 -5
  494. mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +40 -5
  495. mteb/tasks/retrieval/vie/fevervn_retrieval.py +40 -7
  496. mteb/tasks/retrieval/vie/fi_qa2018_vn_retrieval.py +1 -5
  497. mteb/tasks/retrieval/vie/green_node_table_markdown_retrieval.py +16 -1
  498. mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +40 -6
  499. mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +49 -5
  500. mteb/tasks/retrieval/vie/nf_corpus_vn_retrieval.py +1 -5
  501. mteb/tasks/retrieval/vie/nqvn_retrieval.py +40 -5
  502. mteb/tasks/retrieval/vie/quora_vn_retrieval.py +1 -6
  503. mteb/tasks/retrieval/vie/sci_fact_vn_retrieval.py +1 -5
  504. mteb/tasks/retrieval/vie/scidocsvn_retrieval.py +1 -6
  505. mteb/tasks/retrieval/vie/touche2020_vn_retrieval.py +1 -5
  506. mteb/tasks/retrieval/vie/treccovidvn_retrieval.py +1 -5
  507. mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
  508. mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
  509. mteb/tasks/sts/nld/sick_nl_sts.py +1 -0
  510. mteb/tasks/sts/vie/biosses_stsvn.py +1 -5
  511. mteb/tasks/sts/vie/sickr_stsvn.py +1 -5
  512. mteb/tasks/sts/vie/sts_benchmark_stsvn.py +1 -5
  513. mteb/tasks/zeroshot_classification/eng/gtsrb.py +1 -1
  514. mteb/tasks/zeroshot_classification/eng/patch_camelyon.py +1 -1
  515. mteb/tasks/zeroshot_classification/eng/ucf101.py +1 -5
  516. mteb/types/__init__.py +2 -0
  517. mteb/types/_encoder_io.py +19 -2
  518. mteb/types/_result.py +2 -1
  519. mteb/types/statistics.py +9 -3
  520. {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/METADATA +25 -8
  521. {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/RECORD +525 -438
  522. mteb/models/model_implementations/mxbai_models.py +0 -102
  523. mteb/models/model_implementations/nb_sbert.py +0 -25
  524. {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
  525. {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
  526. {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
  527. {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@ from mteb.types import (
14
14
  Array,
15
15
  BatchedInput,
16
16
  CorpusDatasetType,
17
+ EncodeKwargs,
17
18
  PromptType,
18
19
  QueryDatasetType,
19
20
  RetrievalOutputType,
@@ -21,6 +22,7 @@ from mteb.types import (
21
22
  )
22
23
 
23
24
  from .models_protocols import CrossEncoderProtocol, EncoderProtocol
25
+ from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol
24
26
 
25
27
  logger = logging.getLogger(__name__)
26
28
 
@@ -28,13 +30,19 @@ logger = logging.getLogger(__name__)
28
30
  class SearchEncoderWrapper:
29
31
  """Wrapper for Encoder models to be used in search tasks."""
30
32
 
31
- corpus_chunk_size = 50_000
32
33
  task_corpus: CorpusDatasetType | None
33
34
 
34
- def __init__(self, model: EncoderProtocol):
35
+ def __init__(
36
+ self,
37
+ model: EncoderProtocol,
38
+ corpus_chunk_size: int = 50_000,
39
+ index_backend: IndexEncoderSearchProtocol | None = None,
40
+ ) -> None:
35
41
  self.model = model
36
42
  self.task_corpus = None
37
43
  self.mteb_model_meta = model.mteb_model_meta
44
+ self.corpus_chunk_size = corpus_chunk_size
45
+ self.index_backend = index_backend
38
46
 
39
47
  def index(
40
48
  self,
@@ -43,7 +51,7 @@ class SearchEncoderWrapper:
43
51
  task_metadata: TaskMetadata,
44
52
  hf_split: str,
45
53
  hf_subset: str,
46
- encode_kwargs: dict[str, Any],
54
+ encode_kwargs: EncodeKwargs,
47
55
  ) -> None:
48
56
  """Index the corpus for retrieval.
49
57
 
@@ -56,6 +64,22 @@ class SearchEncoderWrapper:
56
64
  """
57
65
  # Always retain corpus for potential reranking or fallback flows
58
66
  self.task_corpus = corpus
67
+ if self.index_backend is not None:
68
+ all_doc_embeddings = self.model.encode(
69
+ create_dataloader(
70
+ corpus,
71
+ task_metadata,
72
+ prompt_type=PromptType.document,
73
+ **encode_kwargs,
74
+ ),
75
+ task_metadata=task_metadata,
76
+ hf_split=hf_split,
77
+ hf_subset=hf_subset,
78
+ prompt_type=PromptType.document,
79
+ **encode_kwargs,
80
+ )
81
+
82
+ self.index_backend.add_documents(all_doc_embeddings, corpus["id"])
59
83
 
60
84
  def search(
61
85
  self,
@@ -65,7 +89,7 @@ class SearchEncoderWrapper:
65
89
  hf_split: str,
66
90
  hf_subset: str,
67
91
  top_k: int,
68
- encode_kwargs: dict[str, Any],
92
+ encode_kwargs: EncodeKwargs,
69
93
  top_ranked: TopRankedDocumentsType | None = None,
70
94
  ) -> RetrievalOutputType:
71
95
  """Search the corpus for the given queries.
@@ -90,7 +114,7 @@ class SearchEncoderWrapper:
90
114
  queries,
91
115
  task_metadata,
92
116
  prompt_type=PromptType.query,
93
- batch_size=encode_kwargs.get("batch_size", 32),
117
+ **encode_kwargs,
94
118
  )
95
119
 
96
120
  query_embeddings = self.model.encode(
@@ -105,32 +129,79 @@ class SearchEncoderWrapper:
105
129
 
106
130
  if top_ranked is not None:
107
131
  logger.info("Reranking pre-ranked documents...")
108
- result_heaps = self._rerank_documents(
109
- query_idx_to_id=query_idx_to_id,
110
- query_embeddings=query_embeddings,
111
- top_ranked=top_ranked,
112
- top_k=top_k,
113
- task_metadata=task_metadata,
114
- hf_subset=hf_subset,
115
- hf_split=hf_split,
116
- encode_kwargs=encode_kwargs,
117
- )
132
+ if self.index_backend is None:
133
+ result_heaps = self._rerank_documents(
134
+ query_idx_to_id=query_idx_to_id,
135
+ query_embeddings=query_embeddings,
136
+ top_ranked=top_ranked,
137
+ top_k=top_k,
138
+ task_metadata=task_metadata,
139
+ hf_subset=hf_subset,
140
+ hf_split=hf_split,
141
+ encode_kwargs=encode_kwargs,
142
+ )
143
+ else:
144
+ cos_scores_top_k_values, cos_scores_top_k_idx = (
145
+ self.index_backend.search(
146
+ query_embeddings,
147
+ top_k,
148
+ similarity_fn=self.model.similarity,
149
+ top_ranked=top_ranked,
150
+ query_idx_to_id=query_idx_to_id,
151
+ )
152
+ )
153
+ result_heaps = {qid: [] for qid in query_idx_to_id.values()}
154
+ for query_itr in range(len(query_embeddings)):
155
+ result_heaps = self._rerank_sort_results(
156
+ result_heaps=result_heaps,
157
+ query_id=query_idx_to_id[query_itr],
158
+ ranked_ids=top_ranked[query_idx_to_id[query_itr]],
159
+ scores_top_k_idx=torch.tensor(
160
+ [cos_scores_top_k_idx[query_itr]]
161
+ ),
162
+ scores_top_k_values=torch.tensor(
163
+ [cos_scores_top_k_values[query_itr]]
164
+ ),
165
+ )
166
+ self.index_backend.clear()
118
167
  else:
119
168
  logger.info("Performing full corpus search...")
120
- result_heaps = self._full_corpus_search(
121
- query_idx_to_id=query_idx_to_id,
122
- query_embeddings=query_embeddings,
123
- task_metadata=task_metadata,
124
- hf_subset=hf_subset,
125
- hf_split=hf_split,
126
- top_k=top_k,
127
- encode_kwargs=encode_kwargs,
128
- )
169
+ if self.index_backend is None:
170
+ result_heaps = self._full_corpus_search(
171
+ query_idx_to_id=query_idx_to_id,
172
+ query_embeddings=query_embeddings,
173
+ task_metadata=task_metadata,
174
+ hf_subset=hf_subset,
175
+ hf_split=hf_split,
176
+ top_k=top_k,
177
+ encode_kwargs=encode_kwargs,
178
+ )
179
+ else:
180
+ cos_scores_top_k_values, cos_scores_top_k_idx = (
181
+ self.index_backend.search(
182
+ query_embeddings,
183
+ top_k,
184
+ similarity_fn=self.model.similarity,
185
+ top_ranked=None,
186
+ query_idx_to_id=None,
187
+ )
188
+ )
189
+ result_heaps = {qid: [] for qid in query_idx_to_id.values()}
190
+ result_heaps = self._sort_full_corpus_results(
191
+ result_heaps=result_heaps,
192
+ query_idx_to_id=query_idx_to_id,
193
+ query_embeddings=query_embeddings,
194
+ cos_scores_top_k_idx=cos_scores_top_k_idx,
195
+ cos_scores_top_k_values=cos_scores_top_k_values,
196
+ sub_corpus_ids=self.task_corpus["id"],
197
+ top_k=top_k,
198
+ )
199
+ self.index_backend.clear()
129
200
 
130
201
  # Reset the task corpus dataloader to None to free up memory
131
202
  self.task_corpus = None
132
203
 
133
- results = {qid: {} for qid in query_idx_to_id.values()}
204
+ results: RetrievalOutputType = {qid: {} for qid in query_idx_to_id.values()}
134
205
  for qid in result_heaps:
135
206
  for score, corpus_id in result_heaps[qid]:
136
207
  results[qid][corpus_id] = score
@@ -145,16 +216,22 @@ class SearchEncoderWrapper:
145
216
  hf_subset: str,
146
217
  hf_split: str,
147
218
  top_k: int,
148
- encode_kwargs: dict[str, Any],
219
+ encode_kwargs: EncodeKwargs,
149
220
  ) -> dict[str, list[tuple[float, str]]]:
150
- logger.info("Encoding Corpus in batches... Warning: This might take a while!")
221
+ logger.info("Encoding Corpus in batches (this might take a while)...")
222
+ if self.task_corpus is None:
223
+ raise ValueError("Corpus must be indexed before searching.")
224
+
151
225
  itr = range(0, len(self.task_corpus), self.corpus_chunk_size)
152
226
 
153
- result_heaps = {qid: [] for qid in query_idx_to_id.values()}
227
+ result_heaps: dict[str, list[tuple[float, str]]] = {
228
+ qid: [] for qid in query_idx_to_id.values()
229
+ }
154
230
  for batch_num, corpus_start_idx in enumerate(itr):
155
231
  logger.info(f"Encoding Batch {batch_num + 1}/{len(itr)}...")
156
232
  corpus_end_idx = min(
157
- corpus_start_idx + self.corpus_chunk_size, len(self.task_corpus)
233
+ corpus_start_idx + self.corpus_chunk_size,
234
+ len(self.task_corpus),
158
235
  )
159
236
  sub_corpus = self.task_corpus.select(
160
237
  range(corpus_start_idx, corpus_end_idx)
@@ -165,7 +242,7 @@ class SearchEncoderWrapper:
165
242
  sub_corpus,
166
243
  task_metadata,
167
244
  prompt_type=PromptType.document,
168
- batch_size=encode_kwargs.get("batch_size", 32),
245
+ **encode_kwargs,
169
246
  ),
170
247
  task_metadata=task_metadata,
171
248
  hf_split=hf_split,
@@ -179,8 +256,8 @@ class SearchEncoderWrapper:
179
256
  scores = self.model.similarity(query_embeddings, sub_corpus_embeddings)
180
257
 
181
258
  # get top-k values
182
- cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
183
- torch.tensor(scores),
259
+ cos_scores_top_k_values_tensor, cos_scores_top_k_idx_tensor = torch.topk(
260
+ torch.as_tensor(scores),
184
261
  min(
185
262
  top_k + 1,
186
263
  len(scores[1]) if len(scores) > 1 else len(scores[-1]),
@@ -188,22 +265,49 @@ class SearchEncoderWrapper:
188
265
  dim=1,
189
266
  largest=True,
190
267
  )
191
- cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
192
- cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
193
-
194
- for query_itr in range(len(query_embeddings)):
195
- query_id = query_idx_to_id[query_itr]
196
- for sub_corpus_id, score in zip(
197
- cos_scores_top_k_idx[query_itr],
198
- cos_scores_top_k_values[query_itr],
199
- ):
200
- corpus_id = sub_corpus_ids[sub_corpus_id]
201
- if len(result_heaps[query_id]) < top_k:
202
- # push item on the heap
203
- heapq.heappush(result_heaps[query_id], (score, corpus_id))
204
- else:
205
- # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
206
- heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
268
+ cos_scores_top_k_idx = cos_scores_top_k_idx_tensor.cpu().tolist()
269
+ cos_scores_top_k_values = cos_scores_top_k_values_tensor.cpu().tolist()
270
+
271
+ sub_corpus_ids = list(sub_corpus_ids)
272
+ result_heaps = self._sort_full_corpus_results(
273
+ result_heaps=result_heaps,
274
+ query_idx_to_id=query_idx_to_id,
275
+ query_embeddings=query_embeddings,
276
+ cos_scores_top_k_idx=cos_scores_top_k_idx,
277
+ cos_scores_top_k_values=cos_scores_top_k_values,
278
+ sub_corpus_ids=sub_corpus_ids,
279
+ top_k=top_k,
280
+ )
281
+ return result_heaps
282
+
283
+ def _sort_full_corpus_results(
284
+ self,
285
+ result_heaps: dict[str, list[tuple[float, str]]],
286
+ query_idx_to_id: dict[int, str],
287
+ query_embeddings: Array,
288
+ cos_scores_top_k_idx: list[list[int]],
289
+ cos_scores_top_k_values: list[list[float]],
290
+ sub_corpus_ids: list[str],
291
+ top_k: int,
292
+ ) -> dict[str, list[tuple[float, str]]]:
293
+ """Sort the heaps into descending order lists.
294
+
295
+ Returns:
296
+ A dictionary mapping query IDs to a sorted list of tuples, each containing a relevance score and a document ID.
297
+ """
298
+ for query_itr in range(len(query_embeddings)):
299
+ query_id = query_idx_to_id[query_itr]
300
+ for sub_corpus_id, score in zip(
301
+ cos_scores_top_k_idx[query_itr],
302
+ cos_scores_top_k_values[query_itr],
303
+ ):
304
+ corpus_id = sub_corpus_ids[sub_corpus_id]
305
+ if len(result_heaps[query_id]) < top_k:
306
+ # push item on the heap
307
+ heapq.heappush(result_heaps[query_id], (score, corpus_id))
308
+ else:
309
+ # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
310
+ heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
207
311
  return result_heaps
208
312
 
209
313
  def _rerank_documents(
@@ -215,14 +319,18 @@ class SearchEncoderWrapper:
215
319
  task_metadata: TaskMetadata,
216
320
  hf_subset: str,
217
321
  hf_split: str,
218
- encode_kwargs: dict[str, Any],
322
+ encode_kwargs: EncodeKwargs,
219
323
  ) -> dict[str, list[tuple[float, str]]]:
220
324
  """Rerank documents based on pre-ranked documents.
221
325
 
222
326
  Returns:
223
327
  A dictionary mapping query IDs to a list of tuples, each containing a relevance score and a document ID.
224
328
  """
225
- result_heaps = {qid: [] for qid in query_idx_to_id.values()}
329
+ if self.task_corpus is None:
330
+ raise ValueError("Corpus must be indexed before searching.")
331
+ result_heaps: dict[str, list[tuple[float, str]]] = {
332
+ qid: [] for qid in query_idx_to_id.values()
333
+ }
226
334
  doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
227
335
 
228
336
  all_doc_embeddings = self.model.encode(
@@ -230,7 +338,7 @@ class SearchEncoderWrapper:
230
338
  self.task_corpus,
231
339
  task_metadata,
232
340
  prompt_type=PromptType.document,
233
- batch_size=encode_kwargs.get("batch_size", 32),
341
+ **encode_kwargs,
234
342
  ),
235
343
  task_metadata=task_metadata,
236
344
  hf_split=hf_split,
@@ -243,7 +351,8 @@ class SearchEncoderWrapper:
243
351
  for query_idx, query_embedding in enumerate(query_embeddings):
244
352
  query_id = query_idx_to_id[query_idx]
245
353
  if query_id not in top_ranked:
246
- logger.warning(f"No pre-ranked documents found for query {query_id}")
354
+ msg = f"No pre-ranked documents found for query {query_id}"
355
+ logger.warning(msg)
247
356
  continue
248
357
 
249
358
  ranked_ids = top_ranked[query_id]
@@ -278,14 +387,34 @@ class SearchEncoderWrapper:
278
387
  scores_top_k_values = scores_top_k_values.cpu()
279
388
  scores_top_k_idx = scores_top_k_idx.cpu()
280
389
 
281
- # Build result heap
282
- for doc_idx, score in zip(
283
- scores_top_k_idx[0].tolist(),
284
- scores_top_k_values[0].tolist(),
285
- ):
286
- corpus_id = ranked_ids[doc_idx]
287
- heapq.heappush(result_heaps[query_id], (score, corpus_id))
390
+ result_heaps = self._rerank_sort_results(
391
+ result_heaps=result_heaps,
392
+ query_id=query_id,
393
+ ranked_ids=ranked_ids,
394
+ scores_top_k_idx=scores_top_k_idx,
395
+ scores_top_k_values=scores_top_k_values,
396
+ )
397
+ return result_heaps
398
+
399
+ def _rerank_sort_results(
400
+ self,
401
+ result_heaps: dict[str, list[tuple[float, str]]],
402
+ query_id: str,
403
+ ranked_ids: list[str],
404
+ scores_top_k_idx: torch.Tensor,
405
+ scores_top_k_values: torch.Tensor,
406
+ ) -> dict[str, list[tuple[float, str]]]:
407
+ """Sort the heap into descending order list.
288
408
 
409
+ Returns:
410
+ A sorted list of tuples, each containing a relevance score and a document ID.
411
+ """
412
+ for doc_idx, score in zip(
413
+ scores_top_k_idx[0].tolist(),
414
+ scores_top_k_values[0].tolist(),
415
+ ):
416
+ corpus_id = ranked_ids[doc_idx]
417
+ heapq.heappush(result_heaps[query_id], (score, corpus_id))
289
418
  return result_heaps
290
419
 
291
420
  def encode(
@@ -342,7 +471,7 @@ class SearchCrossEncoderWrapper:
342
471
  task_metadata: TaskMetadata,
343
472
  hf_split: str,
344
473
  hf_subset: str,
345
- encode_kwargs: dict[str, Any],
474
+ encode_kwargs: EncodeKwargs,
346
475
  ) -> None:
347
476
  """Index the corpus for retrieval.
348
477
 
@@ -363,7 +492,7 @@ class SearchCrossEncoderWrapper:
363
492
  hf_split: str,
364
493
  hf_subset: str,
365
494
  top_k: int,
366
- encode_kwargs: dict[str, Any],
495
+ encode_kwargs: EncodeKwargs,
367
496
  top_ranked: TopRankedDocumentsType | None = None,
368
497
  ) -> RetrievalOutputType:
369
498
  """Search the corpus using the given queries.
@@ -385,6 +514,8 @@ class SearchCrossEncoderWrapper:
385
514
  raise ValueError(
386
515
  "CrossEncoder search requires top_ranked documents for reranking."
387
516
  )
517
+ if self.task_corpus is None:
518
+ raise ValueError("Corpus must be indexed before searching.")
388
519
 
389
520
  query_id_to_idx = {row["id"]: i for i, row in enumerate(queries)}
390
521
  doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
@@ -394,7 +525,8 @@ class SearchCrossEncoderWrapper:
394
525
  doc_pairs_ids: list[tuple[str, str]] = []
395
526
  for query_id, corpus_ids in top_ranked.items():
396
527
  if query_id not in top_ranked:
397
- logger.warning(f"No pre-ranked documents found for query {query_id}")
528
+ msg = f"No pre-ranked documents found for query {query_id}"
529
+ logger.warning(msg)
398
530
  continue
399
531
 
400
532
  query_idx = query_id_to_idx[query_id]
@@ -407,13 +539,13 @@ class SearchCrossEncoderWrapper:
407
539
  Dataset.from_list(total_queries),
408
540
  task_metadata,
409
541
  prompt_type=PromptType.document,
410
- batch_size=encode_kwargs.get("batch_size", 32),
542
+ **encode_kwargs,
411
543
  )
412
544
  corpus_loader = create_dataloader(
413
545
  Dataset.from_list(total_docs),
414
546
  task_metadata,
415
547
  prompt_type=PromptType.document,
416
- batch_size=encode_kwargs.get("batch_size", 32),
548
+ **encode_kwargs,
417
549
  )
418
550
  predictions = self.model.predict(
419
551
  inputs1=queries_loader,
@@ -423,7 +555,7 @@ class SearchCrossEncoderWrapper:
423
555
  hf_subset=hf_subset,
424
556
  )
425
557
 
426
- results = {qid: {} for qid in queries["id"]}
558
+ results: RetrievalOutputType = {qid: {} for qid in queries["id"]}
427
559
  for (query_id, corpus_id), score in zip(doc_pairs_ids, predictions):
428
560
  results[query_id][corpus_id] = float(score)
429
561
 
@@ -1,16 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
+ import warnings
4
5
  from typing import TYPE_CHECKING, Any
5
6
 
6
7
  import numpy as np
7
8
  import torch
8
9
  from packaging.version import Version
9
10
  from torch.utils.data import DataLoader
11
+ from typing_extensions import Unpack
10
12
 
11
13
  from mteb._log_once import LogOnce
12
14
  from mteb.models import ModelMeta
13
- from mteb.types import Array, BatchedInput, PromptType
15
+ from mteb.types import Array, BatchedInput, EncodeKwargs, PromptType
14
16
 
15
17
  from .abs_encoder import AbsEncoder
16
18
 
@@ -25,17 +27,18 @@ SENTENCE_TRANSFORMERS_QUERY_ENCODE_VERSION = "5.0.0"
25
27
 
26
28
 
27
29
  def sentence_transformers_loader(
28
- model_name: str, revision: str | None = None, **kwargs
30
+ model_name: str, revision: str | None = None, device: str | None = None, **kwargs
29
31
  ) -> SentenceTransformerEncoderWrapper:
30
32
  """Loads a SentenceTransformer model and wraps it in a SentenceTransformerEncoderWrapper.
31
33
 
32
34
  Args:
33
35
  model_name: The name of the SentenceTransformer model to load.
34
36
  revision: The revision of the model to load.
37
+ device: The device used to load the model.
35
38
  kwargs: Additional arguments to pass to the SentenceTransformer model.
36
39
  """
37
40
  return SentenceTransformerEncoderWrapper(
38
- model=model_name, revision=revision, **kwargs
41
+ model=model_name, revision=revision, device=device, **kwargs
39
42
  )
40
43
 
41
44
 
@@ -48,6 +51,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
48
51
  self,
49
52
  model: str | SentenceTransformer,
50
53
  revision: str | None = None,
54
+ device: str | None = None,
51
55
  model_prompts: dict[str, str] | None = None,
52
56
  **kwargs,
53
57
  ) -> None:
@@ -56,6 +60,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
56
60
  Args:
57
61
  model: The SentenceTransformer model to use. Can be a string (model name), a SentenceTransformer model, or a CrossEncoder model.
58
62
  revision: The revision of the model to use.
63
+ device: The device used to load the model.
59
64
  model_prompts: A dictionary mapping task names to prompt names.
60
65
  First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt,
61
66
  then to the composed prompt of task type + prompt type, then to the specific task type prompt,
@@ -65,22 +70,21 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
65
70
  from sentence_transformers import SentenceTransformer
66
71
 
67
72
  if isinstance(model, str):
68
- self.model = SentenceTransformer(model, revision=revision, **kwargs)
73
+ self.model = SentenceTransformer(
74
+ model, revision=revision, device=device, **kwargs
75
+ )
69
76
  else:
70
77
  self.model = model
71
- from mteb.models.get_model_meta import (
72
- _model_meta_from_sentence_transformers,
73
- )
74
78
 
75
- self.mteb_model_meta = _model_meta_from_sentence_transformers(self.model)
79
+ self.mteb_model_meta = ModelMeta.from_sentence_transformer_model(self.model)
76
80
 
77
81
  built_in_prompts = getattr(self.model, "prompts", None)
78
82
  if built_in_prompts and not model_prompts:
79
83
  model_prompts = built_in_prompts
80
84
  elif model_prompts and built_in_prompts:
81
- logger.warning(
82
- f"Model prompts specified, these will overwrite the default model prompts. Current prompts will be:\n {model_prompts}"
83
- )
85
+ msg = f"Model prompts specified, these will overwrite the default model prompts. Current prompts will be:\n {model_prompts}"
86
+ logger.warning(msg)
87
+ warnings.warn(msg)
84
88
  self.model.prompts = model_prompts
85
89
 
86
90
  self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name(
@@ -89,9 +93,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
89
93
 
90
94
  if invalid_prompts:
91
95
  invalid_prompts = "\n".join(invalid_prompts)
92
- logger.warning(
93
- f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
94
- )
96
+ msg = f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
97
+ logger.warning(msg)
98
+ warnings.warn(msg)
95
99
 
96
100
  if (
97
101
  self.model_prompts
@@ -101,13 +105,15 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
101
105
  or PromptType.document.value not in self.model_prompts
102
106
  )
103
107
  ):
104
- logger.warning(
105
- "SentenceTransformers that use prompts most often need to be configured with at least 'query' and"
106
- f" 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
107
- )
108
+ msg = f"SentenceTransformers that use prompts most often need to be configured with at least 'query' and 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
109
+ logger.warning(msg)
110
+ warnings.warn(msg)
108
111
 
112
+ def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:
113
+ """Compute the similarity between two collections of embeddings."""
109
114
  if hasattr(self.model, "similarity") and callable(self.model.similarity):
110
- self.similarity = self.model.similarity
115
+ return self.model.similarity(embeddings1, embeddings2)
116
+ return super().similarity(embeddings1, embeddings2)
111
117
 
112
118
  def encode(
113
119
  self,
@@ -117,7 +123,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
117
123
  hf_split: str,
118
124
  hf_subset: str,
119
125
  prompt_type: PromptType | None = None,
120
- **kwargs: Any,
126
+ **kwargs: Unpack[EncodeKwargs],
121
127
  ) -> Array:
122
128
  """Encodes the given sentences using the encoder.
123
129
 
@@ -153,7 +159,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
153
159
  prompt_name = None
154
160
  if self.model_prompts is not None:
155
161
  prompt_name = self.get_prompt_name(task_metadata, prompt_type)
156
- prompt = self.model_prompts.get(prompt_name, None)
162
+ prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
157
163
  if prompt_name:
158
164
  prompt_log = f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
159
165
  else:
@@ -196,7 +202,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
196
202
  hf_split: str,
197
203
  hf_subset: str,
198
204
  prompt_type: PromptType | None = None,
199
- **kwargs: Any,
205
+ **kwargs: Unpack[EncodeKwargs],
200
206
  ) -> Array:
201
207
  """Encodes the given sentences using the encoder.
202
208
 
@@ -224,7 +230,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
224
230
  prompt_name = None
225
231
  if self.model_prompts is not None:
226
232
  prompt_name = self.get_prompt_name(task_metadata, prompt_type)
227
- prompt = self.model_prompts.get(prompt_name, None)
233
+ prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
228
234
  if prompt_name:
229
235
  logger.info(
230
236
  f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
@@ -237,7 +243,9 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
237
243
  all_embeddings = []
238
244
  for batch in inputs:
239
245
  batch_column = next(iter(batch.keys()))
240
- batched_input = [dict() for _ in range(len(batch[batch_column]))]
246
+ batched_input: list[dict[str, Any]] = [
247
+ dict() for _ in range(len(batch[batch_column]))
248
+ ]
241
249
 
242
250
  # transform from {"text": [text1, text2], "image": [image1, image2]} to
243
251
  # [{"text": text1, "image": image1}, {"text": text2, "image": image2}]
@@ -258,24 +266,36 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
258
266
 
259
267
 
260
268
  class CrossEncoderWrapper:
261
- """Wrapper for CrossEncoder models."""
269
+ """Wrapper for CrossEncoder models.
270
+
271
+ Args:
272
+ model: The CrossEncoder model to use. Can be a string (model name) or a CrossEncoder model.
273
+ revision: The revision of the model to use.
274
+ device: The device used to load the model.
275
+ query_prefix: A prefix to add to all queries.
276
+ passage_prefix: A prefix to add to all passages.
277
+ **kwargs: Additional arguments to pass to the CrossEncoder model.
278
+ """
262
279
 
263
280
  def __init__(
264
281
  self,
265
282
  model: CrossEncoder | str,
266
283
  revision: str | None = None,
284
+ device: str | None = None,
285
+ query_prefix: str = "",
286
+ passage_prefix: str = "",
267
287
  **kwargs,
268
288
  ) -> None:
269
289
  from sentence_transformers import CrossEncoder
270
290
 
271
- from mteb.models.get_model_meta import _model_meta_from_cross_encoder
272
-
273
291
  if isinstance(model, CrossEncoder):
274
292
  self.model = model
275
293
  elif isinstance(model, str):
276
- self.model = CrossEncoder(model, revision=revision, **kwargs)
294
+ self.model = CrossEncoder(model, revision=revision, device=device, **kwargs)
277
295
 
278
- self.mteb_model_meta = _model_meta_from_cross_encoder(self.model)
296
+ self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)
297
+ self.query_prefix = query_prefix
298
+ self.passage_prefix = passage_prefix
279
299
 
280
300
  def predict(
281
301
  self,
@@ -286,7 +306,7 @@ class CrossEncoderWrapper:
286
306
  hf_split: str,
287
307
  hf_subset: str,
288
308
  prompt_type: PromptType | None = None,
289
- **kwargs: Any,
309
+ **kwargs: Unpack[EncodeKwargs],
290
310
  ) -> Array:
291
311
  """Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.
292
312
 
@@ -304,10 +324,10 @@ class CrossEncoderWrapper:
304
324
  The predicted relevance scores for each inputs pair.
305
325
  """
306
326
  all_queries_with_instructions = [
307
- text for batch in inputs1 for text in batch["text"]
327
+ self.query_prefix + text for batch in inputs1 for text in batch["text"]
308
328
  ]
309
329
  all_corpus_with_instructions = [
310
- text for batch in inputs2 for text in batch["text"]
330
+ self.passage_prefix + text for batch in inputs2 for text in batch["text"]
311
331
  ]
312
332
 
313
333
  return self.model.predict(