mteb 2.5.2__py3-none-any.whl → 2.7.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/__init__.py +2 -0
- mteb/_create_dataloaders.py +78 -30
- mteb/_evaluators/any_sts_evaluator.py +13 -6
- mteb/_evaluators/clustering_evaluator.py +13 -5
- mteb/_evaluators/evaluator.py +12 -4
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +22 -11
- mteb/_evaluators/pair_classification_evaluator.py +17 -7
- mteb/_evaluators/retrieval_evaluator.py +23 -14
- mteb/_evaluators/retrieval_metrics.py +26 -19
- mteb/_evaluators/sklearn_evaluator.py +27 -17
- mteb/_evaluators/text/bitext_mining_evaluator.py +36 -20
- mteb/_evaluators/text/summarization_evaluator.py +31 -20
- mteb/_evaluators/zeroshot_classification_evaluator.py +16 -5
- mteb/_helpful_enum.py +5 -1
- mteb/abstasks/_data_filter/filters.py +9 -3
- mteb/abstasks/_data_filter/task_pipelines.py +10 -2
- mteb/abstasks/_statistics_calculation.py +21 -11
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +78 -44
- mteb/abstasks/aggregate_task_metadata.py +21 -18
- mteb/abstasks/aggregated_task.py +23 -35
- mteb/abstasks/classification.py +39 -18
- mteb/abstasks/clustering.py +37 -20
- mteb/abstasks/clustering_legacy.py +30 -16
- mteb/abstasks/image/image_text_pair_classification.py +26 -9
- mteb/abstasks/multilabel_classification.py +33 -21
- mteb/abstasks/pair_classification.py +44 -19
- mteb/abstasks/regression.py +18 -10
- mteb/abstasks/retrieval.py +82 -52
- mteb/abstasks/retrieval_dataset_loaders.py +50 -39
- mteb/abstasks/sts.py +34 -15
- mteb/abstasks/task_metadata.py +44 -37
- mteb/abstasks/text/bitext_mining.py +57 -35
- mteb/abstasks/text/reranking.py +10 -8
- mteb/abstasks/text/summarization.py +26 -10
- mteb/abstasks/zeroshot_classification.py +27 -9
- mteb/benchmarks/_create_table.py +13 -7
- mteb/benchmarks/benchmark.py +15 -3
- mteb/benchmarks/benchmarks/__init__.py +6 -0
- mteb/benchmarks/benchmarks/benchmarks.py +153 -13
- mteb/benchmarks/benchmarks/rteb_benchmarks.py +20 -9
- mteb/benchmarks/get_benchmark.py +14 -55
- mteb/cache.py +189 -31
- mteb/cli/_display_tasks.py +10 -4
- mteb/cli/build_cli.py +112 -13
- mteb/cli/generate_model_card.py +50 -23
- mteb/deprecated_evaluator.py +72 -54
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
- mteb/descriptive_stats/Retrieval/BrightAopsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightBiologyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightBiologyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEarthScienceLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEarthScienceRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEconomicsLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEconomicsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightLeetcodeRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPonyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPonyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPsychologyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPsychologyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightRoboticsLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightRoboticsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightStackoverflowLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightStackoverflowRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightSustainableLivingLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightSustainableLivingRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightTheoremQAQuestionsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightTheoremQATheoremsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
- mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
- mteb/evaluate.py +71 -47
- mteb/filter_tasks.py +36 -32
- mteb/get_tasks.py +37 -33
- mteb/languages/language_scripts.py +11 -4
- mteb/leaderboard/app.py +172 -37
- mteb/leaderboard/table.py +7 -2
- mteb/load_results.py +20 -14
- mteb/models/abs_encoder.py +30 -16
- mteb/models/cache_wrappers/cache_backend_protocol.py +7 -7
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +10 -5
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +13 -4
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
- mteb/models/cache_wrappers/cache_wrapper.py +16 -11
- mteb/models/get_model_meta.py +53 -9
- mteb/models/instruct_wrapper.py +41 -13
- mteb/models/model_implementations/align_models.py +11 -5
- mteb/models/model_implementations/amazon_models.py +1 -0
- mteb/models/model_implementations/andersborges.py +6 -4
- mteb/models/model_implementations/ara_models.py +2 -1
- mteb/models/model_implementations/arctic_models.py +16 -8
- mteb/models/model_implementations/b1ade_models.py +2 -1
- mteb/models/model_implementations/bedrock_models.py +20 -6
- mteb/models/model_implementations/bge_models.py +85 -22
- mteb/models/model_implementations/bica_model.py +4 -3
- mteb/models/model_implementations/blip2_models.py +13 -6
- mteb/models/model_implementations/blip_models.py +33 -20
- mteb/models/model_implementations/bm25.py +27 -17
- mteb/models/model_implementations/bmretriever_models.py +16 -6
- mteb/models/model_implementations/cadet_models.py +2 -1
- mteb/models/model_implementations/cde_models.py +22 -9
- mteb/models/model_implementations/clip_models.py +18 -10
- mteb/models/model_implementations/clips_models.py +6 -3
- mteb/models/model_implementations/codefuse_models.py +10 -5
- mteb/models/model_implementations/codesage_models.py +6 -3
- mteb/models/model_implementations/cohere_models.py +19 -9
- mteb/models/model_implementations/cohere_v.py +16 -6
- mteb/models/model_implementations/colpali_models.py +10 -6
- mteb/models/model_implementations/colqwen_models.py +24 -38
- mteb/models/model_implementations/colsmol_models.py +5 -3
- mteb/models/model_implementations/conan_models.py +12 -5
- mteb/models/model_implementations/dino_models.py +70 -46
- mteb/models/model_implementations/e5_instruct.py +27 -4
- mteb/models/model_implementations/e5_models.py +18 -9
- mteb/models/model_implementations/e5_v.py +16 -10
- mteb/models/model_implementations/eagerworks_models.py +12 -5
- mteb/models/model_implementations/emillykkejensen_models.py +9 -6
- mteb/models/model_implementations/en_code_retriever.py +2 -1
- mteb/models/model_implementations/euler_models.py +3 -2
- mteb/models/model_implementations/evaclip_models.py +13 -4
- mteb/models/model_implementations/fa_models.py +18 -9
- mteb/models/model_implementations/facebookai.py +16 -2
- mteb/models/model_implementations/geogpt_models.py +2 -1
- mteb/models/model_implementations/gme_v_models.py +13 -8
- mteb/models/model_implementations/google_models.py +16 -5
- mteb/models/model_implementations/granite_vision_embedding_models.py +8 -6
- mteb/models/model_implementations/gritlm_models.py +5 -2
- mteb/models/model_implementations/gte_models.py +34 -13
- mteb/models/model_implementations/hinvec_models.py +7 -2
- mteb/models/model_implementations/human.py +1 -0
- mteb/models/model_implementations/ibm_granite_models.py +36 -6
- mteb/models/model_implementations/inf_models.py +4 -2
- mteb/models/model_implementations/jasper_models.py +16 -7
- mteb/models/model_implementations/jina_clip.py +58 -14
- mteb/models/model_implementations/jina_models.py +35 -16
- mteb/models/model_implementations/kalm_models.py +24 -12
- mteb/models/model_implementations/kblab.py +13 -6
- mteb/models/model_implementations/kennethenevoldsen_models.py +6 -4
- mteb/models/model_implementations/kfst.py +2 -1
- mteb/models/model_implementations/kowshik24_models.py +2 -1
- mteb/models/model_implementations/lens_models.py +2 -0
- mteb/models/model_implementations/lgai_embedding_models.py +2 -1
- mteb/models/model_implementations/linq_models.py +8 -2
- mteb/models/model_implementations/listconranker.py +11 -5
- mteb/models/model_implementations/llm2clip_models.py +18 -10
- mteb/models/model_implementations/llm2vec_models.py +28 -14
- mteb/models/model_implementations/mcinext_models.py +12 -3
- mteb/models/model_implementations/mdbr_models.py +19 -3
- mteb/models/model_implementations/misc_models.py +131 -68
- mteb/models/model_implementations/mixedbread_ai_models.py +335 -0
- mteb/models/model_implementations/mme5_models.py +3 -2
- mteb/models/model_implementations/moco_models.py +15 -8
- mteb/models/model_implementations/mod_models.py +3 -2
- mteb/models/model_implementations/model2vec_models.py +37 -18
- mteb/models/model_implementations/moka_models.py +4 -1
- mteb/models/model_implementations/nbailab.py +6 -3
- mteb/models/model_implementations/no_instruct_sentence_models.py +15 -7
- mteb/models/model_implementations/nomic_models.py +47 -19
- mteb/models/model_implementations/nomic_models_vision.py +6 -4
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +20 -8
- mteb/models/model_implementations/nvidia_models.py +165 -22
- mteb/models/model_implementations/octen_models.py +64 -3
- mteb/models/model_implementations/openai_models.py +14 -4
- mteb/models/model_implementations/openclip_models.py +30 -17
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +20 -9
- mteb/models/model_implementations/ops_moa_models.py +10 -3
- mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +2 -1
- mteb/models/model_implementations/pawan_models.py +2 -1
- mteb/models/model_implementations/piccolo_models.py +3 -1
- mteb/models/model_implementations/pixie_models.py +56 -0
- mteb/models/model_implementations/promptriever_models.py +20 -10
- mteb/models/model_implementations/pylate_models.py +41 -21
- mteb/models/model_implementations/qodo_models.py +4 -2
- mteb/models/model_implementations/qtack_models.py +2 -1
- mteb/models/model_implementations/qwen3_models.py +14 -4
- mteb/models/model_implementations/qzhou_models.py +4 -2
- mteb/models/model_implementations/random_baseline.py +7 -6
- mteb/models/model_implementations/rasgaard_models.py +3 -2
- mteb/models/model_implementations/reasonir_model.py +66 -1
- mteb/models/model_implementations/repllama_models.py +18 -9
- mteb/models/model_implementations/rerankers_custom.py +25 -10
- mteb/models/model_implementations/rerankers_monot5_based.py +41 -21
- mteb/models/model_implementations/richinfoai_models.py +2 -1
- mteb/models/model_implementations/ru_sentence_models.py +40 -20
- mteb/models/model_implementations/ruri_models.py +20 -10
- mteb/models/model_implementations/salesforce_models.py +13 -4
- mteb/models/model_implementations/samilpwc_models.py +2 -1
- mteb/models/model_implementations/sarashina_embedding_models.py +4 -2
- mteb/models/model_implementations/searchmap_models.py +2 -1
- mteb/models/model_implementations/seed_1_6_embedding_models.py +5 -2
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +119 -148
- mteb/models/model_implementations/seed_models.py +2 -1
- mteb/models/model_implementations/sentence_transformers_models.py +142 -22
- mteb/models/model_implementations/shuu_model.py +2 -1
- mteb/models/model_implementations/siglip_models.py +39 -24
- mteb/models/model_implementations/slm_models.py +419 -0
- mteb/models/model_implementations/sonar_models.py +2 -1
- mteb/models/model_implementations/spartan8806_atles_champion.py +2 -1
- mteb/models/model_implementations/stella_models.py +23 -4
- mteb/models/model_implementations/tarka_models.py +4 -2
- mteb/models/model_implementations/text2vec_models.py +12 -3
- mteb/models/model_implementations/ua_sentence_models.py +2 -1
- mteb/models/model_implementations/uae_models.py +17 -5
- mteb/models/model_implementations/vdr_models.py +9 -2
- mteb/models/model_implementations/vi_vn_models.py +12 -6
- mteb/models/model_implementations/vista_models.py +11 -4
- mteb/models/model_implementations/vlm2vec_models.py +14 -7
- mteb/models/model_implementations/voyage_models.py +136 -4
- mteb/models/model_implementations/voyage_v.py +17 -10
- mteb/models/model_implementations/xyz_models.py +1 -0
- mteb/models/model_implementations/youtu_models.py +2 -1
- mteb/models/model_implementations/yuan_models.py +2 -1
- mteb/models/model_implementations/yuan_models_en.py +3 -2
- mteb/models/model_meta.py +127 -40
- mteb/models/models_protocols.py +43 -22
- mteb/models/search_encoder_index/search_backend_protocol.py +7 -3
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +21 -10
- mteb/models/search_wrappers.py +63 -29
- mteb/models/sentence_transformer_wrapper.py +52 -26
- mteb/models/vllm_wrapper.py +329 -0
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +48 -35
- mteb/results/model_result.py +68 -32
- mteb/results/task_result.py +110 -72
- mteb/similarity_functions.py +19 -9
- mteb/tasks/aggregated_tasks/eng/cqadupstack_retrieval.py +3 -3
- mteb/tasks/aggregated_tasks/eng/sts17_multilingual_visual_sts_eng.py +3 -3
- mteb/tasks/aggregated_tasks/eng/sts_benchmark_multilingual_visual_sts_eng.py +3 -3
- mteb/tasks/aggregated_tasks/fas/cqadupstack_retrieval_fa.py +3 -3
- mteb/tasks/aggregated_tasks/fas/syn_per_chatbot_conv_sa_classification.py +3 -3
- mteb/tasks/aggregated_tasks/multilingual/sts17_multilingual_vision_sts.py +3 -3
- mteb/tasks/aggregated_tasks/multilingual/sts_benchmark_multilingual_visual_sts.py +3 -3
- mteb/tasks/aggregated_tasks/nld/cqadupstack_nl_retrieval.py +3 -3
- mteb/tasks/aggregated_tasks/pol/cqadupstack_retrieval_pl.py +3 -3
- mteb/tasks/bitext_mining/eng/pub_chem_smiles_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/fas/fa_mteb_summary_retrieval.py +3 -3
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/norwegian_courts_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +2 -2
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -1
- mteb/tasks/classification/bul/bulgarian_store_review_sentiment_classfication.py +1 -1
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -1
- mteb/tasks/classification/dan/dk_hate_classification.py +2 -2
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -1
- mteb/tasks/classification/ell/greek_legal_code_classification.py +1 -1
- mteb/tasks/classification/eng/dbpedia_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_conversations_classification.py +2 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +1 -1
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -1
- mteb/tasks/classification/eng/yelp_review_full_classification.py +2 -2
- mteb/tasks/classification/est/estonian_valence.py +2 -2
- mteb/tasks/classification/fas/fa_mteb_classification.py +6 -6
- mteb/tasks/classification/fas/persian_food_sentiment_classification.py +1 -1
- mteb/tasks/classification/fil/filipino_shopee_reviews_classification.py +1 -1
- mteb/tasks/classification/fin/fin_toxicity_classification.py +1 -1
- mteb/tasks/classification/fra/french_book_reviews.py +2 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -1
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -1
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -1
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +2 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -1
- mteb/tasks/classification/ita/dado_eval_coarse_classification.py +1 -1
- mteb/tasks/classification/ita/ita_casehold_classification.py +1 -1
- mteb/tasks/classification/ita/sardi_stance_classification.py +1 -1
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -1
- mteb/tasks/classification/jpn/wrime_classification.py +1 -1
- mteb/tasks/classification/kan/kannada_news_classification.py +2 -2
- mteb/tasks/classification/kor/klue_tc.py +2 -2
- mteb/tasks/classification/kor/kor_fin.py +1 -1
- mteb/tasks/classification/kor/kor_hate_classification.py +1 -1
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +1 -1
- mteb/tasks/classification/kur/kurdish_sentiment_classification.py +2 -2
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -1
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/afri_senti_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -1
- mteb/tasks/classification/multilingual/cyrillic_turkic_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_nlp_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/masakha_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -1
- mteb/tasks/classification/multilingual/multilingual_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/scala_classification.py +2 -2
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -1
- mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
- mteb/tasks/classification/multilingual/tweet_sentiment_classification.py +1 -1
- mteb/tasks/classification/nep/nepali_news_classification.py +2 -2
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +1 -1
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +1 -1
- mteb/tasks/classification/ory/odia_news_classification.py +2 -2
- mteb/tasks/classification/pan/punjabi_news_classification.py +1 -1
- mteb/tasks/classification/ron/moroco.py +1 -1
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -1
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -1
- mteb/tasks/classification/rus/georeview_classification.py +1 -1
- mteb/tasks/classification/rus/headline_classification.py +2 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +2 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +2 -2
- mteb/tasks/classification/rus/ru_sci_bench_grnti_classification.py +1 -1
- mteb/tasks/classification/rus/ru_sci_bench_oecd_classification.py +1 -1
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -1
- mteb/tasks/classification/san/sanskrit_shlokas_classification.py +1 -1
- mteb/tasks/classification/sin/sinhala_news_classification.py +2 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +2 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -1
- mteb/tasks/classification/spa/spanish_news_classification.py +2 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -1
- mteb/tasks/classification/tam/tamil_news_classification.py +2 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +2 -2
- mteb/tasks/classification/tha/wongnai_reviews_classification.py +1 -1
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +2 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -2
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -1
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -1
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +2 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_p2p.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/arxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/arxiv_hierarchical_clustering.py +2 -2
- mteb/tasks/clustering/eng/big_patent_clustering.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/twenty_newsgroups_clustering.py +1 -1
- mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/fas/fa_mteb_clustering.py +4 -4
- mteb/tasks/clustering/fra/hal_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -1
- mteb/tasks/clustering/multilingual/wiki_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nob/snl_clustering.py +8 -3
- mteb/tasks/clustering/nob/vg_clustering.py +8 -3
- mteb/tasks/clustering/pol/polish_clustering.py +3 -3
- mteb/tasks/clustering/rus/ru_sci_bench_grnti_clustering_p2p.py +1 -1
- mteb/tasks/clustering/rus/ru_sci_bench_oecd_clustering_p2p.py +1 -1
- mteb/tasks/clustering/zho/cmteb_clustering.py +6 -6
- mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +2 -2
- mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
- mteb/tasks/multichoice/eng/cv_bench.py +4 -4
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -1
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -1
- mteb/tasks/multilabel_classification/rus/ru_toixic_multilabelclassification_okmlcup.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -1
- mteb/tasks/pair_classification/ara/ar_entail.py +1 -1
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -1
- mteb/tasks/pair_classification/deu/false_friends_de_en_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_ai_sentence_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +4 -3
- mteb/tasks/pair_classification/eng/pub_chem_synonym_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_wiki_paragraphs_pc.py +1 -1
- mteb/tasks/pair_classification/eng/sprint_duplicate_questions_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_sem_eval2015_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_url_corpus_pc.py +1 -1
- mteb/tasks/pair_classification/fas/fa_mteb_pair_classification.py +5 -5
- mteb/tasks/pair_classification/fas/fars_tail.py +2 -2
- mteb/tasks/pair_classification/hye/armenian_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/ita/dis_co_tex_pair_classification.py +1 -1
- mteb/tasks/pair_classification/kor/klue_nli.py +1 -1
- mteb/tasks/pair_classification/multilingual/rte3.py +2 -2
- mteb/tasks/pair_classification/multilingual/xnli.py +1 -1
- mteb/tasks/pair_classification/pol/polish_pc.py +4 -4
- mteb/tasks/pair_classification/por/assin2_rte.py +1 -1
- mteb/tasks/pair_classification/por/sick_br_pc.py +1 -1
- mteb/tasks/pair_classification/rus/terra.py +2 -2
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -1
- mteb/tasks/pair_classification/zho/cmteb_pair_classification.py +2 -2
- mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
- mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +16 -16
- mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
- mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
- mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
- mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
- mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
- mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +2 -2
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +3 -3
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +3 -3
- mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
- mteb/tasks/retrieval/eng/__init__.py +44 -0
- mteb/tasks/retrieval/eng/bright_retrieval.py +10 -2
- mteb/tasks/retrieval/eng/bright_v1_1_retrieval.py +968 -0
- mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
- mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/limit_retrieval.py +6 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/ml_questions.py +1 -1
- mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
- mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
- mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
- mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
- mteb/tasks/retrieval/kor/__init__.py +15 -1
- mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
- mteb/tasks/retrieval/multilingual/__init__.py +2 -0
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
- mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +5 -5
- mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +14 -4
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +90 -100
- mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nob/norquad.py +3 -3
- mteb/tasks/retrieval/nob/snl_retrieval.py +3 -3
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/tasks/retrieval/vie/__init__.py +14 -6
- mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/fevervn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +48 -0
- mteb/tasks/retrieval/vie/nqvn_retrieval.py +39 -0
- mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
- mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
- mteb/tasks/sts/fao/faroese_sts.py +1 -1
- mteb/tasks/sts/fra/sick_fr_sts.py +1 -1
- mteb/tasks/sts/kor/klue_sts.py +1 -1
- mteb/tasks/sts/por/sick_br_sts.py +1 -1
- mteb/tasks/sts/rus/ru_para_phraser_sts.py +1 -1
- mteb/tasks/zeroshot_classification/eng/sci_mmir.py +1 -1
- mteb/types/__init__.py +2 -0
- mteb/types/_encoder_io.py +13 -1
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +18 -5
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/METADATA +15 -4
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/RECORD +528 -486
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/WHEEL +1 -1
- mteb/models/model_implementations/mxbai_models.py +0 -111
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/entry_points.txt +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.5.2.dist-info → mteb-2.7.9.dist-info}/top_level.txt +0 -0
mteb/models/search_wrappers.py
CHANGED
@@ -1,27 +1,35 @@
+from __future__ import annotations
+
 import heapq
 import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import torch
 from datasets import Dataset
-from torch.utils.data import DataLoader

 from mteb._create_dataloaders import (
     create_dataloader,
 )
-from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.types import (
-    Array,
-    BatchedInput,
-    CorpusDatasetType,
     PromptType,
-    QueryDatasetType,
-    RetrievalOutputType,
-    TopRankedDocumentsType,
 )

-
-from .
+if TYPE_CHECKING:
+    from torch.utils.data import DataLoader
+
+    from mteb.abstasks.task_metadata import TaskMetadata
+    from mteb.types import (
+        Array,
+        BatchedInput,
+        CorpusDatasetType,
+        EncodeKwargs,
+        QueryDatasetType,
+        RetrievalOutputType,
+        TopRankedDocumentsType,
+    )
+
+    from .models_protocols import CrossEncoderProtocol, EncoderProtocol
+    from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol

 logger = logging.getLogger(__name__)

@@ -50,7 +58,8 @@ class SearchEncoderWrapper:
         task_metadata: TaskMetadata,
         hf_split: str,
         hf_subset: str,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> None:
         """Index the corpus for retrieval.

@@ -60,6 +69,7 @@ class SearchEncoderWrapper:
             hf_split: Split of current task, allows to know some additional information about current split.
             hf_subset: Subset of current task. Similar to `hf_split` to get more information
             encode_kwargs: Additional arguments to pass to the encoder during indexing.
+            num_proc: Number of processes to use for dataloading.
         """
         # Always retain corpus for potential reranking or fallback flows
         self.task_corpus = corpus
@@ -69,6 +79,7 @@ class SearchEncoderWrapper:
                 corpus,
                 task_metadata,
                 prompt_type=PromptType.document,
+                num_proc=num_proc,
                 **encode_kwargs,
             ),
             task_metadata=task_metadata,
@@ -88,8 +99,9 @@ class SearchEncoderWrapper:
         hf_split: str,
         hf_subset: str,
         top_k: int,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
         top_ranked: TopRankedDocumentsType | None = None,
+        num_proc: int = 1,
     ) -> RetrievalOutputType:
         """Search the corpus for the given queries.

@@ -102,6 +114,7 @@ class SearchEncoderWrapper:
                 Passed only from Reranking tasks.
             top_k: Number of top documents to return for each query.
             encode_kwargs: Additional arguments to pass to the encoder during indexing.
+            num_proc: Number of processes to use for dataloading.

         Returns:
             Dictionary with query IDs as keys with dict as values, where each value is a mapping of document IDs to their relevance scores.
@@ -113,6 +126,7 @@ class SearchEncoderWrapper:
             queries,
             task_metadata,
             prompt_type=PromptType.query,
+            num_proc=num_proc,
             **encode_kwargs,
         )

@@ -200,7 +214,7 @@ class SearchEncoderWrapper:
         # Reset the task corpus dataloader to None to free up memory
         self.task_corpus = None

-        results = {qid: {} for qid in query_idx_to_id.values()}
+        results: RetrievalOutputType = {qid: {} for qid in query_idx_to_id.values()}
         for qid in result_heaps:
             for score, corpus_id in result_heaps[qid]:
                 results[qid][corpus_id] = score
@@ -215,16 +229,22 @@ class SearchEncoderWrapper:
         hf_subset: str,
         hf_split: str,
         top_k: int,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> dict[str, list[tuple[float, str]]]:
         logger.info("Encoding Corpus in batches (this might take a while)...")
+        if self.task_corpus is None:
+            raise ValueError("Corpus must be indexed before searching.")
+
         itr = range(0, len(self.task_corpus), self.corpus_chunk_size)

-        result_heaps
+        result_heaps: dict[str, list[tuple[float, str]]] = {
+            qid: [] for qid in query_idx_to_id.values()
+        }
         for batch_num, corpus_start_idx in enumerate(itr):
             logger.info(f"Encoding Batch {batch_num + 1}/{len(itr)}...")
             corpus_end_idx = min(
-                corpus_start_idx + self.corpus_chunk_size,
+                corpus_start_idx + self.corpus_chunk_size,
+                len(self.task_corpus),
             )
             sub_corpus = self.task_corpus.select(
                 range(corpus_start_idx, corpus_end_idx)
@@ -249,7 +269,7 @@ class SearchEncoderWrapper:
             scores = self.model.similarity(query_embeddings, sub_corpus_embeddings)

             # get top-k values
-
+            cos_scores_top_k_values_tensor, cos_scores_top_k_idx_tensor = torch.topk(
                 torch.as_tensor(scores),
                 min(
                     top_k + 1,
@@ -258,8 +278,8 @@ class SearchEncoderWrapper:
                 dim=1,
                 largest=True,
             )
-            cos_scores_top_k_idx =
-            cos_scores_top_k_values =
+            cos_scores_top_k_idx = cos_scores_top_k_idx_tensor.cpu().tolist()
+            cos_scores_top_k_values = cos_scores_top_k_values_tensor.cpu().tolist()

             sub_corpus_ids = list(sub_corpus_ids)
             result_heaps = self._sort_full_corpus_results(
@@ -312,14 +332,18 @@ class SearchEncoderWrapper:
         task_metadata: TaskMetadata,
         hf_subset: str,
         hf_split: str,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
     ) -> dict[str, list[tuple[float, str]]]:
         """Rerank documents based on pre-ranked documents.

         Returns:
             A dictionary mapping query IDs to a list of tuples, each containing a relevance score and a document ID.
         """
-
+        if self.task_corpus is None:
+            raise ValueError("Corpus must be indexed before searching.")
+        result_heaps: dict[str, list[tuple[float, str]]] = {
+            qid: [] for qid in query_idx_to_id.values()
+        }
         doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}

         all_doc_embeddings = self.model.encode(
@@ -340,7 +364,8 @@ class SearchEncoderWrapper:
         for query_idx, query_embedding in enumerate(query_embeddings):
             query_id = query_idx_to_id[query_idx]
             if query_id not in top_ranked:
-
+                msg = f"No pre-ranked documents found for query {query_id}"
+                logger.warning(msg)
                 continue

             ranked_ids = top_ranked[query_id]
@@ -386,12 +411,12 @@ class SearchEncoderWrapper:

     def _rerank_sort_results(
         self,
-        result_heaps: list[tuple[float, str]],
+        result_heaps: dict[str, list[tuple[float, str]]],
         query_id: str,
         ranked_ids: list[str],
         scores_top_k_idx: torch.Tensor,
         scores_top_k_values: torch.Tensor,
-    ) -> list[tuple[float, str]]:
+    ) -> dict[str, list[tuple[float, str]]]:
         """Sort the heap into descending order list.

         Returns:
@@ -459,7 +484,8 @@ class SearchCrossEncoderWrapper:
         task_metadata: TaskMetadata,
         hf_split: str,
         hf_subset: str,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> None:
         """Index the corpus for retrieval.

@@ -469,6 +495,7 @@ class SearchCrossEncoderWrapper:
             hf_split: Split of current task, allows to know some additional information about current split.
             hf_subset: Subset of current task. Similar to `hf_split` to get more information
             encode_kwargs: Additional arguments to pass to the encoder during indexing.
+            num_proc: Number of processes to use.
         """
         self.task_corpus = corpus

@@ -480,8 +507,9 @@ class SearchCrossEncoderWrapper:
         hf_split: str,
         hf_subset: str,
         top_k: int,
-        encode_kwargs:
+        encode_kwargs: EncodeKwargs,
         top_ranked: TopRankedDocumentsType | None = None,
+        num_proc: int = 1,
     ) -> RetrievalOutputType:
         """Search the corpus using the given queries.

@@ -494,6 +522,7 @@ class SearchCrossEncoderWrapper:
                 Passed only from Reranking tasks.
             top_k: Number of top documents to return for each query.
             encode_kwargs: Additional arguments to pass to the encoder during indexing.
+            num_proc: Number of processes to use.

         Returns:
             Dictionary with query IDs as keys with dict as values, where each value is a mapping of document IDs to their relevance scores.
@@ -502,6 +531,8 @@ class SearchCrossEncoderWrapper:
             raise ValueError(
                 "CrossEncoder search requires top_ranked documents for reranking."
             )
+        if self.task_corpus is None:
+            raise ValueError("Corpus must be indexed before searching.")

         query_id_to_idx = {row["id"]: i for i, row in enumerate(queries)}
         doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
@@ -511,7 +542,8 @@ class SearchCrossEncoderWrapper:
         doc_pairs_ids: list[tuple[str, str]] = []
         for query_id, corpus_ids in top_ranked.items():
             if query_id not in top_ranked:
-
+                msg = f"No pre-ranked documents found for query {query_id}"
+                logger.warning(msg)
                 continue

             query_idx = query_id_to_idx[query_id]
@@ -524,12 +556,14 @@ class SearchCrossEncoderWrapper:
             Dataset.from_list(total_queries),
             task_metadata,
             prompt_type=PromptType.document,
+            num_proc=num_proc,
             **encode_kwargs,
         )
         corpus_loader = create_dataloader(
             Dataset.from_list(total_docs),
             task_metadata,
             prompt_type=PromptType.document,
+            num_proc=num_proc,
             **encode_kwargs,
         )
         predictions = self.model.predict(
@@ -540,7 +574,7 @@ class SearchCrossEncoderWrapper:
             hf_subset=hf_subset,
         )

-        results = {qid: {} for qid in queries["id"]}
+        results: RetrievalOutputType = {qid: {} for qid in queries["id"]}
         for (query_id, corpus_id), score in zip(doc_pairs_ids, predictions):
             results[query_id][corpus_id] = float(score)
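In short, both search wrappers above now take a num_proc argument (default 1) that is forwarded to create_dataloader, their encode_kwargs parameter is typed as the EncodeKwargs TypedDict, and searching before index() has been called now raises a ValueError. Below is a minimal, hypothetical sketch of how a caller might exercise the updated index/search signatures; the wrapper, corpus, queries, and task_metadata objects are assumed to exist already, and the empty encode_kwargs dict is a placeholder rather than a value taken from this diff.

def run_retrieval(wrapper, corpus, queries, task_metadata):
    # Hypothetical sketch of the 2.7.9 call pattern; only encode_kwargs,
    # num_proc, and top_k reflect parameters changed in the diff above,
    # everything else is assumed setup.
    wrapper.index(
        corpus,
        task_metadata,
        hf_split="test",
        hf_subset="default",
        encode_kwargs={},   # now typed as mteb.types.EncodeKwargs
        num_proc=4,         # new: number of dataloading worker processes
    )
    return wrapper.search(
        queries,
        task_metadata,
        hf_split="test",
        hf_subset="default",
        top_k=10,
        encode_kwargs={},
        num_proc=4,
    )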
mteb/models/sentence_transformer_wrapper.py
CHANGED

@@ -1,23 +1,26 @@
 from __future__ import annotations

 import logging
+import warnings
 from typing import TYPE_CHECKING, Any

 import numpy as np
 import torch
 from packaging.version import Version
-from torch.utils.data import DataLoader

 from mteb._log_once import LogOnce
 from mteb.models import ModelMeta
-from mteb.types import
+from mteb.types import PromptType

 from .abs_encoder import AbsEncoder

 if TYPE_CHECKING:
     from sentence_transformers import CrossEncoder, SentenceTransformer
+    from torch.utils.data import DataLoader
+    from typing_extensions import Unpack

     from mteb.abstasks.task_metadata import TaskMetadata
+    from mteb.types import Array, BatchedInput, EncodeKwargs

 logger = logging.getLogger(__name__)

@@ -25,17 +28,18 @@ SENTENCE_TRANSFORMERS_QUERY_ENCODE_VERSION = "5.0.0"


 def sentence_transformers_loader(
-    model_name: str, revision: str | None = None, **kwargs
+    model_name: str, revision: str | None = None, device: str | None = None, **kwargs
 ) -> SentenceTransformerEncoderWrapper:
     """Loads a SentenceTransformer model and wraps it in a SentenceTransformerEncoderWrapper.

     Args:
         model_name: The name of the SentenceTransformer model to load.
         revision: The revision of the model to load.
+        device: The device used to load the model.
         kwargs: Additional arguments to pass to the SentenceTransformer model.
     """
     return SentenceTransformerEncoderWrapper(
-        model=model_name, revision=revision, **kwargs
+        model=model_name, revision=revision, device=device, **kwargs
     )


@@ -48,6 +52,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
         self,
         model: str | SentenceTransformer,
         revision: str | None = None,
+        device: str | None = None,
         model_prompts: dict[str, str] | None = None,
         **kwargs,
     ) -> None:
@@ -56,6 +61,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
         Args:
             model: The SentenceTransformer model to use. Can be a string (model name), a SentenceTransformer model, or a CrossEncoder model.
             revision: The revision of the model to use.
+            device: The device used to load the model.
             model_prompts: A dictionary mapping task names to prompt names.
                 First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt,
                 then to the composed prompt of task type + prompt type, then to the specific task type prompt,
@@ -65,7 +71,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
         from sentence_transformers import SentenceTransformer

         if isinstance(model, str):
-            self.model = SentenceTransformer(
+            self.model = SentenceTransformer(
+                model, revision=revision, device=device, **kwargs
+            )
         else:
             self.model = model

@@ -75,9 +83,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
         if built_in_prompts and not model_prompts:
             model_prompts = built_in_prompts
         elif model_prompts and built_in_prompts:
-
-
-            )
+            msg = f"Model prompts specified, these will overwrite the default model prompts. Current prompts will be:\n {model_prompts}"
+            logger.warning(msg)
+            warnings.warn(msg)
             self.model.prompts = model_prompts

         self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name(
@@ -86,9 +94,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):

         if invalid_prompts:
             invalid_prompts = "\n".join(invalid_prompts)
-
-
-            )
+            msg = f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
+            logger.warning(msg)
+            warnings.warn(msg)

         if (
             self.model_prompts
@@ -98,13 +106,15 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
                 or PromptType.document.value not in self.model_prompts
             )
         ):
-
-
-
-            )
+            msg = f"SentenceTransformers that use prompts most often need to be configured with at least 'query' and 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
+            logger.warning(msg)
+            warnings.warn(msg)

+    def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:
+        """Compute the similarity between two collections of embeddings."""
         if hasattr(self.model, "similarity") and callable(self.model.similarity):
-
+            return self.model.similarity(embeddings1, embeddings2)
+        return super().similarity(embeddings1, embeddings2)

     def encode(
         self,
@@ -114,7 +124,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
         hf_split: str,
         hf_subset: str,
         prompt_type: PromptType | None = None,
-        **kwargs:
+        **kwargs: Unpack[EncodeKwargs],
     ) -> Array:
         """Encodes the given sentences using the encoder.

@@ -150,7 +160,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
         prompt_name = None
         if self.model_prompts is not None:
             prompt_name = self.get_prompt_name(task_metadata, prompt_type)
-            prompt = self.model_prompts.get(prompt_name, None)
+            prompt = self.model_prompts.get(prompt_name, None)  # type: ignore[arg-type]
         if prompt_name:
             prompt_log = f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
         else:
@@ -193,7 +203,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
         hf_split: str,
         hf_subset: str,
         prompt_type: PromptType | None = None,
-        **kwargs:
+        **kwargs: Unpack[EncodeKwargs],
     ) -> Array:
         """Encodes the given sentences using the encoder.

@@ -221,7 +231,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
         prompt_name = None
         if self.model_prompts is not None:
             prompt_name = self.get_prompt_name(task_metadata, prompt_type)
-            prompt = self.model_prompts.get(prompt_name, None)
+            prompt = self.model_prompts.get(prompt_name, None)  # type: ignore[arg-type]
         if prompt_name:
             logger.info(
                 f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
@@ -234,7 +244,9 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
         all_embeddings = []
         for batch in inputs:
             batch_column = next(iter(batch.keys()))
-            batched_input
+            batched_input: list[dict[str, Any]] = [
+                dict() for _ in range(len(batch[batch_column]))
+            ]

             # transform from {"text": [text1, text2], "image": [image1, image2]} to
             # [{"text": text1, "image": image1}, {"text": text2, "image": image2}]
@@ -255,12 +267,24 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap


 class CrossEncoderWrapper:
-    """Wrapper for CrossEncoder models.
+    """Wrapper for CrossEncoder models.
+
+    Args:
+        model: The CrossEncoder model to use. Can be a string (model name) or a CrossEncoder model.
+        revision: The revision of the model to use.
+        device: The device used to load the model.
+        query_prefix: A prefix to add to all queries.
+        passage_prefix: A prefix to add to all passages.
+        **kwargs: Additional arguments to pass to the CrossEncoder model.
+    """

     def __init__(
         self,
         model: CrossEncoder | str,
         revision: str | None = None,
+        device: str | None = None,
+        query_prefix: str = "",
+        passage_prefix: str = "",
         **kwargs,
     ) -> None:
         from sentence_transformers import CrossEncoder
@@ -268,9 +292,11 @@ class CrossEncoderWrapper:
         if isinstance(model, CrossEncoder):
             self.model = model
         elif isinstance(model, str):
-            self.model = CrossEncoder(model, revision=revision, **kwargs)
+            self.model = CrossEncoder(model, revision=revision, device=device, **kwargs)

         self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)
+        self.query_prefix = query_prefix
+        self.passage_prefix = passage_prefix

     def predict(
         self,
@@ -281,7 +307,7 @@ class CrossEncoderWrapper:
         hf_split: str,
         hf_subset: str,
         prompt_type: PromptType | None = None,
-        **kwargs:
+        **kwargs: Unpack[EncodeKwargs],
     ) -> Array:
         """Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.

@@ -299,10 +325,10 @@ class CrossEncoderWrapper:
             The predicted relevance scores for each inputs pair.
         """
         all_queries_with_instructions = [
-            text for batch in inputs1 for text in batch["text"]
+            self.query_prefix + text for batch in inputs1 for text in batch["text"]
         ]
         all_corpus_with_instructions = [
-            text for batch in inputs2 for text in batch["text"]
+            self.passage_prefix + text for batch in inputs2 for text in batch["text"]
         ]

         return self.model.predict(
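Taken together, the bi-encoder loader and the cross-encoder wrapper in this file now both accept an explicit device, and the cross-encoder additionally accepts query/passage prefixes that are prepended inside predict(). A minimal usage sketch follows; the model names and prefix strings are placeholders, and only the keyword arguments come from the signatures shown in this diff.

# Illustrative sketch; model IDs and prefix strings below are placeholders,
# while device, query_prefix, and passage_prefix are the parameters added in 2.7.9.
from mteb.models.sentence_transformer_wrapper import (
    CrossEncoderWrapper,
    sentence_transformers_loader,
)

# Bi-encoder wrapper: `device` is now forwarded to SentenceTransformer(...).
encoder = sentence_transformers_loader(
    "sentence-transformers/all-MiniLM-L6-v2",
    device="cuda",
)

# Cross-encoder wrapper: optional prefixes are prepended to every query and
# passage inside predict().
reranker = CrossEncoderWrapper(
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
    device="cuda",
    query_prefix="query: ",
    passage_prefix="passage: ",
)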
|