mteb 2.1.4__py3-none-any.whl → 2.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/__init__.py +4 -0
- mteb/_create_dataloaders.py +6 -3
- mteb/_evaluators/any_sts_evaluator.py +21 -12
- mteb/_evaluators/classification_metrics.py +54 -0
- mteb/_evaluators/clustering_evaluator.py +1 -1
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +9 -4
- mteb/_evaluators/pair_classification_evaluator.py +30 -38
- mteb/_evaluators/sklearn_evaluator.py +15 -28
- mteb/_evaluators/text/bitext_mining_evaluator.py +4 -1
- mteb/_evaluators/text/summarization_evaluator.py +4 -2
- mteb/_evaluators/zeroshot_classification_evaluator.py +2 -2
- mteb/abstasks/_data_filter/__init__.py +0 -0
- mteb/abstasks/_data_filter/filters.py +125 -0
- mteb/abstasks/_data_filter/task_pipelines.py +102 -0
- mteb/abstasks/_statistics_calculation.py +6 -2
- mteb/abstasks/classification.py +0 -2
- mteb/abstasks/clustering.py +1 -1
- mteb/abstasks/clustering_legacy.py +3 -0
- mteb/abstasks/multilabel_classification.py +10 -3
- mteb/abstasks/pair_classification.py +8 -1
- mteb/abstasks/sts.py +7 -0
- mteb/abstasks/task_metadata.py +1 -0
- mteb/benchmarks/_create_table.py +84 -37
- mteb/benchmarks/benchmark.py +74 -15
- mteb/benchmarks/benchmarks/__init__.py +8 -0
- mteb/benchmarks/benchmarks/benchmarks.py +259 -15
- mteb/benchmarks/get_benchmark.py +2 -0
- mteb/cache.py +47 -10
- mteb/deprecated_evaluator.py +8 -13
- mteb/descriptive_stats/BitextMining/RuSciBenchBitextMining.v2.json +61 -0
- mteb/descriptive_stats/Classification/HebrewSentimentAnalysis.v3.json +60 -0
- mteb/descriptive_stats/Classification/TurkishConstitutionalCourtViolation.json +54 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3ComputerScienceRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3EnergyRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceEnRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceFrRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3HrRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3IndustrialRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3NuclearRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PharmaceuticalsRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PhysicsRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3TelecomRetrieval.json +214 -0
- mteb/descriptive_stats/PairClassification/TERRa.V2.json +35 -0
- mteb/descriptive_stats/Reranking/JQaRARerankingLite.json +35 -0
- mteb/descriptive_stats/Reranking/JaCWIRRerankingLite.json +35 -0
- mteb/descriptive_stats/Reranking/MultiLongDocReranking.json +466 -0
- mteb/descriptive_stats/Retrieval/ArguAna-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/JaCWIRRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/JaqketRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/MIRACLJaRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/MrTyDiJaRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/NFCorpus-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/SCIDOCS-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/SQuADKorV1Retrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/SciFact-NL.v2.json +30 -0
- mteb/evaluate.py +65 -45
- mteb/leaderboard/app.py +268 -133
- mteb/leaderboard/benchmark_selector.py +14 -5
- mteb/leaderboard/figures.py +13 -15
- mteb/leaderboard/table.py +82 -17
- mteb/models/__init__.py +4 -1
- mteb/models/abs_encoder.py +21 -17
- mteb/models/cache_wrappers/__init__.py +2 -1
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +2 -2
- mteb/models/cache_wrappers/cache_wrapper.py +1 -1
- mteb/models/get_model_meta.py +3 -114
- mteb/models/instruct_wrapper.py +5 -1
- mteb/models/model_implementations/align_models.py +7 -0
- mteb/models/model_implementations/amazon_models.py +1 -0
- mteb/models/model_implementations/andersborges.py +65 -0
- mteb/models/model_implementations/ara_models.py +8 -0
- mteb/models/model_implementations/arctic_models.py +8 -0
- mteb/models/model_implementations/b1ade_models.py +1 -0
- mteb/models/model_implementations/bedrock_models.py +4 -0
- mteb/models/model_implementations/bge_models.py +60 -0
- mteb/models/model_implementations/bica_model.py +35 -0
- mteb/models/model_implementations/blip2_models.py +11 -0
- mteb/models/model_implementations/blip_models.py +27 -0
- mteb/models/model_implementations/bm25.py +1 -0
- mteb/models/model_implementations/bmretriever_models.py +4 -0
- mteb/models/model_implementations/cadet_models.py +9 -0
- mteb/models/model_implementations/cde_models.py +14 -0
- mteb/models/model_implementations/clip_models.py +3 -0
- mteb/models/model_implementations/clips_models.py +100 -0
- mteb/models/model_implementations/codefuse_models.py +162 -0
- mteb/models/model_implementations/codesage_models.py +15 -0
- mteb/models/model_implementations/cohere_models.py +8 -1
- mteb/models/model_implementations/cohere_v.py +5 -0
- mteb/models/model_implementations/colpali_models.py +14 -6
- mteb/models/model_implementations/colqwen_models.py +271 -1
- mteb/models/model_implementations/colsmol_models.py +2 -0
- mteb/models/model_implementations/conan_models.py +1 -0
- mteb/models/model_implementations/dino_models.py +171 -0
- mteb/models/model_implementations/e5_instruct.py +4 -0
- mteb/models/model_implementations/e5_models.py +12 -101
- mteb/models/model_implementations/e5_v.py +1 -0
- mteb/models/model_implementations/eagerworks_models.py +164 -0
- mteb/models/model_implementations/emillykkejensen_models.py +91 -0
- mteb/models/model_implementations/en_code_retriever.py +1 -0
- mteb/models/model_implementations/euler_models.py +32 -0
- mteb/models/model_implementations/evaclip_models.py +4 -0
- mteb/models/model_implementations/fa_models.py +58 -0
- mteb/models/model_implementations/facebookai.py +193 -0
- mteb/models/model_implementations/geogpt_models.py +1 -0
- mteb/models/model_implementations/gme_v_models.py +11 -5
- mteb/models/model_implementations/google_models.py +16 -5
- mteb/models/model_implementations/granite_vision_embedding_models.py +7 -2
- mteb/models/model_implementations/gritlm_models.py +2 -0
- mteb/models/model_implementations/gte_models.py +78 -0
- mteb/models/model_implementations/hinvec_models.py +1 -0
- mteb/models/model_implementations/human.py +1 -0
- mteb/models/model_implementations/ibm_granite_models.py +6 -0
- mteb/models/model_implementations/inf_models.py +2 -0
- mteb/models/model_implementations/jasper_models.py +255 -2
- mteb/models/model_implementations/jina_clip.py +1 -0
- mteb/models/model_implementations/jina_models.py +209 -5
- mteb/models/model_implementations/kalm_models.py +203 -25
- mteb/models/model_implementations/kblab.py +31 -0
- mteb/models/model_implementations/kennethenevoldsen_models.py +74 -0
- mteb/models/model_implementations/kfst.py +25 -0
- mteb/models/model_implementations/kowshik24_models.py +32 -0
- mteb/models/model_implementations/lens_models.py +2 -0
- mteb/models/model_implementations/lgai_embedding_models.py +1 -0
- mteb/models/model_implementations/linq_models.py +3 -2
- mteb/models/model_implementations/listconranker.py +1 -1
- mteb/models/model_implementations/llm2clip_models.py +3 -0
- mteb/models/model_implementations/llm2vec_models.py +8 -0
- mteb/models/model_implementations/mcinext_models.py +3 -0
- mteb/models/model_implementations/mdbr_models.py +2 -0
- mteb/models/model_implementations/misc_models.py +362 -0
- mteb/models/model_implementations/mme5_models.py +1 -0
- mteb/models/model_implementations/moco_models.py +11 -0
- mteb/models/model_implementations/mod_models.py +191 -0
- mteb/models/model_implementations/model2vec_models.py +13 -0
- mteb/models/model_implementations/moka_models.py +3 -0
- mteb/models/model_implementations/mxbai_models.py +9 -0
- mteb/models/model_implementations/nbailab.py +70 -0
- mteb/models/model_implementations/no_instruct_sentence_models.py +1 -0
- mteb/models/model_implementations/nomic_models.py +156 -4
- mteb/models/model_implementations/nomic_models_vision.py +7 -2
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +23 -16
- mteb/models/model_implementations/nvidia_models.py +4 -1
- mteb/models/model_implementations/octen_models.py +195 -0
- mteb/models/model_implementations/openai_models.py +20 -16
- mteb/models/model_implementations/openclip_models.py +24 -0
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +5 -0
- mteb/models/model_implementations/ops_moa_models.py +4 -2
- mteb/models/model_implementations/pawan_models.py +39 -0
- mteb/models/model_implementations/piccolo_models.py +8 -0
- mteb/models/model_implementations/promptriever_models.py +8 -4
- mteb/models/model_implementations/pylate_models.py +37 -4
- mteb/models/model_implementations/qodo_models.py +2 -0
- mteb/models/model_implementations/qtack_models.py +1 -0
- mteb/models/model_implementations/qwen3_models.py +6 -3
- mteb/models/model_implementations/qzhou_models.py +3 -1
- mteb/models/model_implementations/random_baseline.py +16 -21
- mteb/models/model_implementations/rasgaard_models.py +34 -0
- mteb/models/model_implementations/reasonir_model.py +1 -0
- mteb/models/model_implementations/repllama_models.py +2 -0
- mteb/models/model_implementations/rerankers_custom.py +3 -3
- mteb/models/model_implementations/rerankers_monot5_based.py +14 -14
- mteb/models/model_implementations/richinfoai_models.py +1 -0
- mteb/models/model_implementations/ru_sentence_models.py +51 -0
- mteb/models/model_implementations/ruri_models.py +322 -0
- mteb/models/model_implementations/salesforce_models.py +3 -0
- mteb/models/model_implementations/samilpwc_models.py +1 -0
- mteb/models/model_implementations/sarashina_embedding_models.py +168 -0
- mteb/models/model_implementations/searchmap_models.py +1 -0
- mteb/models/model_implementations/seed_1_6_embedding_models.py +8 -2
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +658 -0
- mteb/models/model_implementations/seed_models.py +1 -0
- mteb/models/model_implementations/sentence_transformers_models.py +57 -0
- mteb/models/model_implementations/shuu_model.py +32 -31
- mteb/models/model_implementations/siglip_models.py +10 -0
- mteb/models/model_implementations/sonar_models.py +1 -0
- mteb/models/model_implementations/spartan8806_atles_champion.py +34 -0
- mteb/models/model_implementations/stella_models.py +6 -0
- mteb/models/model_implementations/tarka_models.py +376 -0
- mteb/models/model_implementations/ua_sentence_models.py +10 -0
- mteb/models/model_implementations/uae_models.py +1 -0
- mteb/models/model_implementations/vdr_models.py +2 -0
- mteb/models/model_implementations/vi_vn_models.py +39 -0
- mteb/models/model_implementations/vista_models.py +2 -0
- mteb/models/model_implementations/vlm2vec_models.py +2 -0
- mteb/models/model_implementations/voyage_models.py +15 -0
- mteb/models/model_implementations/voyage_v.py +8 -2
- mteb/models/model_implementations/xyz_models.py +1 -0
- mteb/models/model_implementations/youtu_models.py +1 -0
- mteb/models/model_implementations/yuan_models.py +34 -0
- mteb/models/model_implementations/yuan_models_en.py +58 -0
- mteb/models/model_meta.py +442 -22
- mteb/models/search_encoder_index/__init__.py +7 -0
- mteb/models/search_encoder_index/search_backend_protocol.py +50 -0
- mteb/models/search_encoder_index/search_indexes/__init__.py +5 -0
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +157 -0
- mteb/models/search_wrappers.py +165 -48
- mteb/models/sentence_transformer_wrapper.py +2 -7
- mteb/results/benchmark_results.py +88 -47
- mteb/results/model_result.py +11 -4
- mteb/results/task_result.py +37 -19
- mteb/similarity_functions.py +49 -0
- mteb/tasks/bitext_mining/multilingual/__init__.py +2 -1
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +4 -2
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining_fast.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ru_sci_bench_bitext_mining.py +47 -5
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -6
- mteb/tasks/classification/ara/ajgt.py +1 -2
- mteb/tasks/classification/ara/hotel_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/tweet_emotion_classification.py +1 -2
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_document_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -2
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +1 -2
- mteb/tasks/classification/dan/angry_tweets_classification.py +1 -2
- mteb/tasks/classification/dan/danish_political_comments_classification.py +1 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -2
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -2
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -2
- mteb/tasks/classification/deu/ten_k_gnad_classification.py +1 -2
- mteb/tasks/classification/eng/amazon_polarity_classification.py +1 -2
- mteb/tasks/classification/eng/arxiv_classification.py +1 -2
- mteb/tasks/classification/eng/banking77_classification.py +1 -2
- mteb/tasks/classification/eng/dbpedia_classification.py +1 -2
- mteb/tasks/classification/eng/emotion_classification.py +1 -2
- mteb/tasks/classification/eng/financial_phrasebank_classification.py +1 -2
- mteb/tasks/classification/eng/frenk_en_classification.py +1 -2
- mteb/tasks/classification/eng/gtsrb_classification.py +1 -1
- mteb/tasks/classification/eng/imdb_classification.py +1 -2
- mteb/tasks/classification/eng/legal_bench_classification.py +14 -120
- mteb/tasks/classification/eng/news_classification.py +1 -2
- mteb/tasks/classification/eng/patch_camelyon_classification.py +1 -1
- mteb/tasks/classification/eng/patent_classification.py +1 -2
- mteb/tasks/classification/eng/poem_sentiment_classification.py +1 -2
- mteb/tasks/classification/eng/sds_eye_protection_classification.py +1 -2
- mteb/tasks/classification/eng/sds_gloves_classification.py +1 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -19
- mteb/tasks/classification/eng/toxic_conversations_classification.py +1 -2
- mteb/tasks/classification/eng/tweet_sentiment_extraction_classification.py +1 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +2 -13
- mteb/tasks/classification/eng/ucf101_classification.py +1 -5
- mteb/tasks/classification/eng/wikipedia_bio_met_chem_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_chem_fields_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_comp_chem_spectroscopy_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_crystallography_analytical_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_theoretical_applied_classification.py +1 -2
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -2
- mteb/tasks/classification/eng/yelp_review_full_classification.py +1 -2
- mteb/tasks/classification/est/estonian_valence.py +1 -2
- mteb/tasks/classification/fas/fa_mteb_classification.py +7 -14
- mteb/tasks/classification/fil/filipino_hate_speech_classification.py +1 -2
- mteb/tasks/classification/fin/fin_toxicity_classification.py +2 -11
- mteb/tasks/classification/fra/french_book_reviews.py +1 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -2
- mteb/tasks/classification/heb/__init__.py +6 -1
- mteb/tasks/classification/heb/hebrew_sentiment_analysis.py +62 -4
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -2
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -2
- mteb/tasks/classification/hrv/frenk_hr_classification.py +1 -2
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +1 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -2
- mteb/tasks/classification/ita/italian_linguist_acceptability_classification.py +1 -2
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -2
- mteb/tasks/classification/jpn/wrime_classification.py +1 -2
- mteb/tasks/classification/kan/kannada_news_classification.py +1 -2
- mteb/tasks/classification/kor/klue_tc.py +1 -2
- mteb/tasks/classification/kor/kor_hate_classification.py +2 -17
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +2 -19
- mteb/tasks/classification/kur/kurdish_sentiment_classification.py +1 -2
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -2
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -2
- mteb/tasks/classification/mkd/macedonian_tweet_sentiment_classification.py +1 -2
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -6
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -4
- mteb/tasks/classification/multilingual/ru_sci_bench_classification.py +4 -23
- mteb/tasks/classification/multilingual/scala_classification.py +1 -2
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -6
- mteb/tasks/classification/mya/myanmar_news.py +1 -2
- mteb/tasks/classification/nep/nepali_news_classification.py +1 -2
- mteb/tasks/classification/nld/dutch_book_review_sentiment_classification.py +4 -2
- mteb/tasks/classification/nld/dutch_cola_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_government_bias_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_news_articles_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +3 -0
- mteb/tasks/classification/nld/iconclass_classification.py +3 -0
- mteb/tasks/classification/nld/open_tender_classification.py +3 -0
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +3 -0
- mteb/tasks/classification/nob/no_rec_classification.py +1 -2
- mteb/tasks/classification/nob/norwegian_parliament_classification.py +1 -2
- mteb/tasks/classification/ory/odia_news_classification.py +1 -2
- mteb/tasks/classification/pol/polish_classification.py +3 -6
- mteb/tasks/classification/ron/moroco.py +1 -2
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -2
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -2
- mteb/tasks/classification/rus/georeview_classification.py +1 -2
- mteb/tasks/classification/rus/headline_classification.py +1 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +1 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +1 -2
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -2
- mteb/tasks/classification/rus/senti_ru_eval.py +1 -2
- mteb/tasks/classification/sin/sinhala_news_classification.py +1 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +1 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/slk/slovak_hate_speech_classification.py +1 -2
- mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -2
- mteb/tasks/classification/spa/spanish_news_classification.py +1 -2
- mteb/tasks/classification/spa/spanish_sentiment_classification.py +1 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -2
- mteb/tasks/classification/swa/swahili_news_classification.py +1 -2
- mteb/tasks/classification/swe/dalaj_classification.py +1 -2
- mteb/tasks/classification/swe/swe_rec_classification.py +1 -2
- mteb/tasks/classification/swe/swedish_sentiment_classification.py +1 -2
- mteb/tasks/classification/tam/tamil_news_classification.py +1 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +1 -2
- mteb/tasks/classification/tha/wisesight_sentiment_classification.py +1 -2
- mteb/tasks/classification/tsn/tswana_news_classification.py +1 -2
- mteb/tasks/classification/tur/__init__.py +4 -0
- mteb/tasks/classification/tur/turkish_constitutional_court.py +41 -0
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +1 -2
- mteb/tasks/classification/tur/turkish_product_sentiment_classification.py +1 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -15
- mteb/tasks/classification/urd/urdu_roman_sentiment_classification.py +1 -2
- mteb/tasks/classification/vie/amazon_counterfactual_vn_classification.py +1 -6
- mteb/tasks/classification/vie/amazon_polarity_vn_classification.py +1 -6
- mteb/tasks/classification/vie/amazon_reviews_vn_classification.py +1 -5
- mteb/tasks/classification/vie/banking77_vn_classification.py +1 -5
- mteb/tasks/classification/vie/emotion_vn_classification.py +1 -5
- mteb/tasks/classification/vie/imdb_vn_classification.py +1 -5
- mteb/tasks/classification/vie/massive_intent_vn_classification.py +1 -5
- mteb/tasks/classification/vie/massive_scenario_vn_classification.py +1 -5
- mteb/tasks/classification/vie/mtop_domain_vn_classification.py +1 -5
- mteb/tasks/classification/vie/mtop_intent_vn_classification.py +1 -5
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -5
- mteb/tasks/classification/vie/tweet_sentiment_extraction_vn_classification.py +1 -5
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -2
- mteb/tasks/classification/zho/cmteb_classification.py +5 -10
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +1 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -2
- mteb/tasks/clustering/jpn/mews_c16_ja_clustering.py +1 -3
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -6
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/open_tender_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +3 -0
- mteb/tasks/clustering/vie/reddit_clustering_p2p_vn.py +1 -5
- mteb/tasks/clustering/vie/reddit_clustering_vn.py +1 -5
- mteb/tasks/clustering/vie/stack_exchange_clustering_p2p_vn.py +1 -5
- mteb/tasks/clustering/vie/stack_exchange_clustering_vn.py +1 -5
- mteb/tasks/clustering/vie/twenty_newsgroups_clustering_vn.py +1 -5
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -5
- mteb/tasks/multilabel_classification/kor/kor_hate_speech_ml_classification.py +1 -9
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -6
- mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +3 -0
- mteb/tasks/multilabel_classification/nld/vabb_multi_label_classification.py +3 -0
- mteb/tasks/multilabel_classification/por/brazilian_toxic_tweets_classification.py +1 -6
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -2
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -6
- mteb/tasks/pair_classification/eng/legal_bench_pc.py +1 -9
- mteb/tasks/pair_classification/nld/sick_nl_pair_classification.py +3 -0
- mteb/tasks/pair_classification/nld/xlwic_nl_pair_classification.py +3 -0
- mteb/tasks/pair_classification/rus/__init__.py +2 -2
- mteb/tasks/pair_classification/rus/terra.py +51 -25
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -5
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -5
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -5
- mteb/tasks/regression/multilingual/ru_sci_bench_regression.py +2 -6
- mteb/tasks/reranking/jpn/__init__.py +9 -1
- mteb/tasks/reranking/jpn/j_qa_ra_reranking_lite.py +49 -0
- mteb/tasks/reranking/jpn/ja_cwir_reranking_lite.py +47 -0
- mteb/tasks/reranking/multilingual/__init__.py +2 -0
- mteb/tasks/reranking/multilingual/multi_long_doc_reranking.py +70 -0
- mteb/tasks/reranking/multilingual/x_glue_wpr_reranking.py +1 -2
- mteb/tasks/reranking/vie/ask_ubuntu_dup_questions_vn.py +1 -5
- mteb/tasks/reranking/vie/sci_docs_reranking_vn.py +1 -5
- mteb/tasks/reranking/vie/stack_overflow_dup_questions_vn.py +1 -5
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +8 -5
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -8
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +4 -0
- mteb/tasks/retrieval/jpn/__init__.py +8 -0
- mteb/tasks/retrieval/jpn/ja_cwir_retrieval.py +1 -4
- mteb/tasks/retrieval/jpn/ja_cwir_retrieval_lite.py +47 -0
- mteb/tasks/retrieval/jpn/jaqket_retrieval_lite.py +50 -0
- mteb/tasks/retrieval/jpn/miracl_ja_retrieval_lite.py +52 -0
- mteb/tasks/retrieval/jpn/mr_tydi_ja_retrieval_lite.py +48 -0
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +11 -4
- mteb/tasks/retrieval/kor/__init__.py +2 -1
- mteb/tasks/retrieval/kor/squad_kor_v1_retrieval.py +47 -0
- mteb/tasks/retrieval/multilingual/__init__.py +22 -0
- mteb/tasks/retrieval/multilingual/belebele_retrieval.py +5 -4
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +56 -42
- mteb/tasks/retrieval/multilingual/mkqa_retrieval.py +1 -2
- mteb/tasks/retrieval/multilingual/mlqa_retrieval.py +1 -4
- mteb/tasks/retrieval/multilingual/multi_long_doc_retrieval.py +1 -2
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +9 -4
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -12
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +4 -2
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +399 -0
- mteb/tasks/retrieval/nld/__init__.py +8 -4
- mteb/tasks/retrieval/nld/argu_ana_nl_retrieval.py +46 -27
- mteb/tasks/retrieval/nld/bbsard_nl_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/dutch_news_articles_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/legal_qa_nl_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/nf_corpus_nl_retrieval.py +42 -25
- mteb/tasks/retrieval/nld/open_tender_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/sci_fact_nl_retrieval.py +42 -24
- mteb/tasks/retrieval/nld/scidocsnl_retrieval.py +44 -27
- mteb/tasks/retrieval/nld/vabb_retrieval.py +3 -0
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -7
- mteb/tasks/retrieval/vie/argu_ana_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_android_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_gis_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_mathematica_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_physics_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_programmers_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_stats_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_tex_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_unix_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_webmasters_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_wordpress_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/fevervn_retrieval.py +1 -7
- mteb/tasks/retrieval/vie/fi_qa2018_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/green_node_table_markdown_retrieval.py +16 -1
- mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +1 -6
- mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/nf_corpus_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/nqvn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/quora_vn_retrieval.py +1 -6
- mteb/tasks/retrieval/vie/sci_fact_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/scidocsvn_retrieval.py +1 -6
- mteb/tasks/retrieval/vie/touche2020_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/treccovidvn_retrieval.py +1 -5
- mteb/tasks/sts/nld/sick_nl_sts.py +1 -0
- mteb/tasks/sts/vie/biosses_stsvn.py +1 -5
- mteb/tasks/sts/vie/sickr_stsvn.py +1 -5
- mteb/tasks/sts/vie/sts_benchmark_stsvn.py +1 -5
- mteb/tasks/zeroshot_classification/eng/gtsrb.py +1 -1
- mteb/tasks/zeroshot_classification/eng/patch_camelyon.py +1 -1
- mteb/tasks/zeroshot_classification/eng/ucf101.py +1 -5
- mteb/types/_encoder_io.py +7 -2
- {mteb-2.1.4.dist-info → mteb-2.5.2.dist-info}/METADATA +11 -5
- {mteb-2.1.4.dist-info → mteb-2.5.2.dist-info}/RECORD +457 -391
- mteb/models/model_implementations/nb_sbert.py +0 -25
- {mteb-2.1.4.dist-info → mteb-2.5.2.dist-info}/WHEEL +0 -0
- {mteb-2.1.4.dist-info → mteb-2.5.2.dist-info}/entry_points.txt +0 -0
- {mteb-2.1.4.dist-info → mteb-2.5.2.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.1.4.dist-info → mteb-2.5.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from mteb._requires_package import requires_package
|
|
8
|
+
from mteb.models.model_meta import ScoringFunction
|
|
9
|
+
from mteb.models.models_protocols import EncoderProtocol
|
|
10
|
+
from mteb.types import Array, TopRankedDocumentsType
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FaissSearchIndex:
|
|
16
|
+
"""FAISS-based backend for encoder-based search.
|
|
17
|
+
|
|
18
|
+
Supports both full-corpus retrieval and reranking (via `top_ranked`).
|
|
19
|
+
|
|
20
|
+
Notes:
|
|
21
|
+
- Stores *all* embeddings in memory (IndexFlatIP or IndexFlatL2).
|
|
22
|
+
- Expects embeddings to be normalized if cosine similarity is desired.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
_normalize: bool = False
|
|
26
|
+
|
|
27
|
+
def __init__(self, model: EncoderProtocol) -> None:
    """Select the FAISS index type from the model's similarity function.

    Args:
        model: Encoder whose ``mteb_model_meta.similarity_fn_name`` chooses
            the index: dot product -> ``IndexFlatIP``; cosine ->
            ``IndexFlatIP`` over L2-normalized vectors; euclidean ->
            ``IndexFlatL2``.

    Raises:
        ValueError: If the similarity function is not one of the three
            supported scoring functions.
    """
    # Fail early with an install hint when the optional faiss dependency
    # is not installed.
    requires_package(
        self,
        "faiss",
        "FAISS-based search",
        install_instruction="pip install mteb[faiss-cpu]",
    )

    import faiss
    from faiss import IndexFlatIP, IndexFlatL2

    # https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
    similarity_fn_name = model.mteb_model_meta.similarity_fn_name
    # `==` instead of `is`: identical for enum members, and additionally
    # matches an equal plain value if ScoringFunction is str-based
    # (its definition is not visible here).
    if similarity_fn_name == ScoringFunction.DOT_PRODUCT:
        self.index_type = IndexFlatIP
    elif similarity_fn_name == ScoringFunction.COSINE:
        # Cosine similarity is the inner product of unit-length vectors,
        # so vectors are L2-normalized before being added/queried.
        self.index_type = IndexFlatIP
        self._normalize = True
    elif similarity_fn_name == ScoringFunction.EUCLIDEAN:
        self.index_type = IndexFlatL2
    else:
        raise ValueError(
            f"FAISS backend does not support similarity function {similarity_fn_name}. "
            f"Available: {ScoringFunction.DOT_PRODUCT}, {ScoringFunction.COSINE}, "
            f"{ScoringFunction.EUCLIDEAN}."
        )

    # Document IDs, positionally aligned with the vectors in the index.
    self.idxs: list[str] = []
    # Created lazily by add_documents once the embedding dimension is known.
    self.index: faiss.Index | None = None
|
|
54
|
+
|
|
55
|
+
def add_documents(self, embeddings: Array, idxs: list[str]) -> None:
    """Add document embeddings and their IDs to the FAISS index.

    May be called repeatedly. The index is created on the first call using
    the dimensionality of that batch.

    Args:
        embeddings: 2-D array or tensor with one row per document.
        idxs: Document IDs, aligned row-for-row with ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D, if the number of IDs does
            not match the number of rows, or if the embedding dimension
            differs from that of the existing index.
    """
    import faiss

    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.detach().cpu().numpy()

    # Always copy into a C-contiguous float32 buffer: FAISS expects this
    # layout, and normalize_L2 below works in place, so the caller's array
    # must never be aliased.
    embeddings = np.array(embeddings, dtype=np.float32, order="C")

    if embeddings.ndim != 2:
        raise ValueError(
            f"Expected 2-D document embeddings, got shape {embeddings.shape}."
        )
    if embeddings.shape[0] != len(idxs):
        raise ValueError(
            f"Got {embeddings.shape[0]} embeddings but {len(idxs)} document ids."
        )

    if self._normalize:
        faiss.normalize_L2(embeddings)

    dim = embeddings.shape[1]
    if self.index is None:
        self.index = self.index_type(dim)
    elif self.index.d != dim:
        raise ValueError(
            f"Embedding dimension {dim} does not match existing FAISS index "
            f"dimension {self.index.d}."
        )

    self.index.add(embeddings)
    # Record the IDs only after FAISS accepted the vectors, so positions in
    # self.idxs stay aligned with the flat index's sequential row ids.
    self.idxs.extend(idxs)
    logger.info(
        "Added %d vectors of dim %d to FAISS index (total: %d).",
        len(idxs),
        dim,
        self.index.ntotal,
    )
|
|
74
|
+
|
|
75
|
+
def search(
|
|
76
|
+
self,
|
|
77
|
+
embeddings: Array,
|
|
78
|
+
top_k: int,
|
|
79
|
+
similarity_fn: Callable[[Array, Array], Array],
|
|
80
|
+
top_ranked: TopRankedDocumentsType | None = None,
|
|
81
|
+
query_idx_to_id: dict[int, str] | None = None,
|
|
82
|
+
) -> tuple[list[list[float]], list[list[int]]]:
|
|
83
|
+
"""Search using FAISS."""
|
|
84
|
+
import faiss
|
|
85
|
+
|
|
86
|
+
if self.index is None:
|
|
87
|
+
raise ValueError("No index built. Call add_document() first.")
|
|
88
|
+
|
|
89
|
+
if isinstance(embeddings, torch.Tensor):
|
|
90
|
+
embeddings = embeddings.detach().cpu().numpy()
|
|
91
|
+
|
|
92
|
+
if self._normalize:
|
|
93
|
+
faiss.normalize_L2(embeddings)
|
|
94
|
+
|
|
95
|
+
if top_ranked is not None:
|
|
96
|
+
if query_idx_to_id is None:
|
|
97
|
+
raise ValueError("query_idx_to_id must be provided when reranking.")
|
|
98
|
+
|
|
99
|
+
similarities, ids = self._reranking(
|
|
100
|
+
embeddings,
|
|
101
|
+
top_k,
|
|
102
|
+
top_ranked=top_ranked,
|
|
103
|
+
query_idx_to_id=query_idx_to_id,
|
|
104
|
+
)
|
|
105
|
+
else:
|
|
106
|
+
similarities, ids = self.index.search(embeddings.astype(np.float32), top_k)
|
|
107
|
+
similarities = similarities.tolist()
|
|
108
|
+
ids = ids.tolist()
|
|
109
|
+
|
|
110
|
+
if issubclass(self.index_type, faiss.IndexFlatL2):
|
|
111
|
+
similarities = -np.sqrt(np.maximum(similarities, 0))
|
|
112
|
+
|
|
113
|
+
return similarities, ids
|
|
114
|
+
|
|
115
|
+
def _reranking(
|
|
116
|
+
self,
|
|
117
|
+
embeddings: Array,
|
|
118
|
+
top_k: int,
|
|
119
|
+
top_ranked: TopRankedDocumentsType | None = None,
|
|
120
|
+
query_idx_to_id: dict[int, str] | None = None,
|
|
121
|
+
) -> tuple[list[list[float]], list[list[int]]]:
|
|
122
|
+
doc_id_to_idx = {doc_id: i for i, doc_id in enumerate(self.idxs)}
|
|
123
|
+
scores_all: list[list[float]] = []
|
|
124
|
+
idxs_all: list[list[int]] = []
|
|
125
|
+
|
|
126
|
+
for query_idx, query_emb in enumerate(embeddings):
|
|
127
|
+
query_id = query_idx_to_id[query_idx]
|
|
128
|
+
ranked_ids = top_ranked.get(query_id)
|
|
129
|
+
if not ranked_ids:
|
|
130
|
+
logger.warning(f"No top-ranked documents for query {query_id}")
|
|
131
|
+
scores_all.append([])
|
|
132
|
+
idxs_all.append([])
|
|
133
|
+
continue
|
|
134
|
+
|
|
135
|
+
candidate_indices = [doc_id_to_idx[doc_id] for doc_id in ranked_ids]
|
|
136
|
+
d = self.index.d
|
|
137
|
+
candidate_embs = np.vstack(
|
|
138
|
+
[self.index.reconstruct(idx) for idx in candidate_indices]
|
|
139
|
+
)
|
|
140
|
+
sub_reranking_index = self.index_type(d)
|
|
141
|
+
sub_reranking_index.add(candidate_embs)
|
|
142
|
+
|
|
143
|
+
# Search returns scores and indices in one call
|
|
144
|
+
scores, local_indices = sub_reranking_index.search(
|
|
145
|
+
query_emb.reshape(1, -1).astype(np.float32),
|
|
146
|
+
min(top_k, len(candidate_indices)),
|
|
147
|
+
)
|
|
148
|
+
# faiss will output 2d arrays even for single query
|
|
149
|
+
scores_all.append(scores[0].tolist())
|
|
150
|
+
idxs_all.append(local_indices[0].tolist())
|
|
151
|
+
|
|
152
|
+
return scores_all, idxs_all
|
|
153
|
+
|
|
154
|
+
def clear(self) -> None:
|
|
155
|
+
"""Clear all stored documents and embeddings from the backend."""
|
|
156
|
+
self.index = None
|
|
157
|
+
self.idxs = []
|
mteb/models/search_wrappers.py
CHANGED
|
@@ -21,6 +21,7 @@ from mteb.types import (
|
|
|
21
21
|
)
|
|
22
22
|
|
|
23
23
|
from .models_protocols import CrossEncoderProtocol, EncoderProtocol
|
|
24
|
+
from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol
|
|
24
25
|
|
|
25
26
|
logger = logging.getLogger(__name__)
|
|
26
27
|
|
|
@@ -28,13 +29,19 @@ logger = logging.getLogger(__name__)
|
|
|
28
29
|
class SearchEncoderWrapper:
|
|
29
30
|
"""Wrapper for Encoder models to be used in search tasks."""
|
|
30
31
|
|
|
31
|
-
corpus_chunk_size = 50_000
|
|
32
32
|
task_corpus: CorpusDatasetType | None
|
|
33
33
|
|
|
34
|
-
def __init__(
|
|
34
|
+
def __init__(
|
|
35
|
+
self,
|
|
36
|
+
model: EncoderProtocol,
|
|
37
|
+
corpus_chunk_size: int = 50_000,
|
|
38
|
+
index_backend: IndexEncoderSearchProtocol | None = None,
|
|
39
|
+
) -> None:
|
|
35
40
|
self.model = model
|
|
36
41
|
self.task_corpus = None
|
|
37
42
|
self.mteb_model_meta = model.mteb_model_meta
|
|
43
|
+
self.corpus_chunk_size = corpus_chunk_size
|
|
44
|
+
self.index_backend = index_backend
|
|
38
45
|
|
|
39
46
|
def index(
|
|
40
47
|
self,
|
|
@@ -56,6 +63,22 @@ class SearchEncoderWrapper:
|
|
|
56
63
|
"""
|
|
57
64
|
# Always retain corpus for potential reranking or fallback flows
|
|
58
65
|
self.task_corpus = corpus
|
|
66
|
+
if self.index_backend is not None:
|
|
67
|
+
all_doc_embeddings = self.model.encode(
|
|
68
|
+
create_dataloader(
|
|
69
|
+
corpus,
|
|
70
|
+
task_metadata,
|
|
71
|
+
prompt_type=PromptType.document,
|
|
72
|
+
**encode_kwargs,
|
|
73
|
+
),
|
|
74
|
+
task_metadata=task_metadata,
|
|
75
|
+
hf_split=hf_split,
|
|
76
|
+
hf_subset=hf_subset,
|
|
77
|
+
prompt_type=PromptType.document,
|
|
78
|
+
**encode_kwargs,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
self.index_backend.add_documents(all_doc_embeddings, corpus["id"])
|
|
59
82
|
|
|
60
83
|
def search(
|
|
61
84
|
self,
|
|
@@ -90,7 +113,7 @@ class SearchEncoderWrapper:
|
|
|
90
113
|
queries,
|
|
91
114
|
task_metadata,
|
|
92
115
|
prompt_type=PromptType.query,
|
|
93
|
-
|
|
116
|
+
**encode_kwargs,
|
|
94
117
|
)
|
|
95
118
|
|
|
96
119
|
query_embeddings = self.model.encode(
|
|
@@ -105,27 +128,74 @@ class SearchEncoderWrapper:
|
|
|
105
128
|
|
|
106
129
|
if top_ranked is not None:
|
|
107
130
|
logger.info("Reranking pre-ranked documents...")
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
131
|
+
if self.index_backend is None:
|
|
132
|
+
result_heaps = self._rerank_documents(
|
|
133
|
+
query_idx_to_id=query_idx_to_id,
|
|
134
|
+
query_embeddings=query_embeddings,
|
|
135
|
+
top_ranked=top_ranked,
|
|
136
|
+
top_k=top_k,
|
|
137
|
+
task_metadata=task_metadata,
|
|
138
|
+
hf_subset=hf_subset,
|
|
139
|
+
hf_split=hf_split,
|
|
140
|
+
encode_kwargs=encode_kwargs,
|
|
141
|
+
)
|
|
142
|
+
else:
|
|
143
|
+
cos_scores_top_k_values, cos_scores_top_k_idx = (
|
|
144
|
+
self.index_backend.search(
|
|
145
|
+
query_embeddings,
|
|
146
|
+
top_k,
|
|
147
|
+
similarity_fn=self.model.similarity,
|
|
148
|
+
top_ranked=top_ranked,
|
|
149
|
+
query_idx_to_id=query_idx_to_id,
|
|
150
|
+
)
|
|
151
|
+
)
|
|
152
|
+
result_heaps = {qid: [] for qid in query_idx_to_id.values()}
|
|
153
|
+
for query_itr in range(len(query_embeddings)):
|
|
154
|
+
result_heaps = self._rerank_sort_results(
|
|
155
|
+
result_heaps=result_heaps,
|
|
156
|
+
query_id=query_idx_to_id[query_itr],
|
|
157
|
+
ranked_ids=top_ranked[query_idx_to_id[query_itr]],
|
|
158
|
+
scores_top_k_idx=torch.tensor(
|
|
159
|
+
[cos_scores_top_k_idx[query_itr]]
|
|
160
|
+
),
|
|
161
|
+
scores_top_k_values=torch.tensor(
|
|
162
|
+
[cos_scores_top_k_values[query_itr]]
|
|
163
|
+
),
|
|
164
|
+
)
|
|
165
|
+
self.index_backend.clear()
|
|
118
166
|
else:
|
|
119
167
|
logger.info("Performing full corpus search...")
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
168
|
+
if self.index_backend is None:
|
|
169
|
+
result_heaps = self._full_corpus_search(
|
|
170
|
+
query_idx_to_id=query_idx_to_id,
|
|
171
|
+
query_embeddings=query_embeddings,
|
|
172
|
+
task_metadata=task_metadata,
|
|
173
|
+
hf_subset=hf_subset,
|
|
174
|
+
hf_split=hf_split,
|
|
175
|
+
top_k=top_k,
|
|
176
|
+
encode_kwargs=encode_kwargs,
|
|
177
|
+
)
|
|
178
|
+
else:
|
|
179
|
+
cos_scores_top_k_values, cos_scores_top_k_idx = (
|
|
180
|
+
self.index_backend.search(
|
|
181
|
+
query_embeddings,
|
|
182
|
+
top_k,
|
|
183
|
+
similarity_fn=self.model.similarity,
|
|
184
|
+
top_ranked=None,
|
|
185
|
+
query_idx_to_id=None,
|
|
186
|
+
)
|
|
187
|
+
)
|
|
188
|
+
result_heaps = {qid: [] for qid in query_idx_to_id.values()}
|
|
189
|
+
result_heaps = self._sort_full_corpus_results(
|
|
190
|
+
result_heaps=result_heaps,
|
|
191
|
+
query_idx_to_id=query_idx_to_id,
|
|
192
|
+
query_embeddings=query_embeddings,
|
|
193
|
+
cos_scores_top_k_idx=cos_scores_top_k_idx,
|
|
194
|
+
cos_scores_top_k_values=cos_scores_top_k_values,
|
|
195
|
+
sub_corpus_ids=self.task_corpus["id"],
|
|
196
|
+
top_k=top_k,
|
|
197
|
+
)
|
|
198
|
+
self.index_backend.clear()
|
|
129
199
|
|
|
130
200
|
# Reset the task corpus dataloader to None to free up memory
|
|
131
201
|
self.task_corpus = None
|
|
@@ -147,7 +217,7 @@ class SearchEncoderWrapper:
|
|
|
147
217
|
top_k: int,
|
|
148
218
|
encode_kwargs: dict[str, Any],
|
|
149
219
|
) -> dict[str, list[tuple[float, str]]]:
|
|
150
|
-
logger.info("Encoding Corpus in batches
|
|
220
|
+
logger.info("Encoding Corpus in batches (this might take a while)...")
|
|
151
221
|
itr = range(0, len(self.task_corpus), self.corpus_chunk_size)
|
|
152
222
|
|
|
153
223
|
result_heaps = {qid: [] for qid in query_idx_to_id.values()}
|
|
@@ -165,7 +235,7 @@ class SearchEncoderWrapper:
|
|
|
165
235
|
sub_corpus,
|
|
166
236
|
task_metadata,
|
|
167
237
|
prompt_type=PromptType.document,
|
|
168
|
-
|
|
238
|
+
**encode_kwargs,
|
|
169
239
|
),
|
|
170
240
|
task_metadata=task_metadata,
|
|
171
241
|
hf_split=hf_split,
|
|
@@ -180,7 +250,7 @@ class SearchEncoderWrapper:
|
|
|
180
250
|
|
|
181
251
|
# get top-k values
|
|
182
252
|
cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
|
|
183
|
-
torch.
|
|
253
|
+
torch.as_tensor(scores),
|
|
184
254
|
min(
|
|
185
255
|
top_k + 1,
|
|
186
256
|
len(scores[1]) if len(scores) > 1 else len(scores[-1]),
|
|
@@ -191,19 +261,46 @@ class SearchEncoderWrapper:
|
|
|
191
261
|
cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
|
|
192
262
|
cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
|
|
193
263
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
264
|
+
sub_corpus_ids = list(sub_corpus_ids)
|
|
265
|
+
result_heaps = self._sort_full_corpus_results(
|
|
266
|
+
result_heaps=result_heaps,
|
|
267
|
+
query_idx_to_id=query_idx_to_id,
|
|
268
|
+
query_embeddings=query_embeddings,
|
|
269
|
+
cos_scores_top_k_idx=cos_scores_top_k_idx,
|
|
270
|
+
cos_scores_top_k_values=cos_scores_top_k_values,
|
|
271
|
+
sub_corpus_ids=sub_corpus_ids,
|
|
272
|
+
top_k=top_k,
|
|
273
|
+
)
|
|
274
|
+
return result_heaps
|
|
275
|
+
|
|
276
|
+
def _sort_full_corpus_results(
|
|
277
|
+
self,
|
|
278
|
+
result_heaps: dict[str, list[tuple[float, str]]],
|
|
279
|
+
query_idx_to_id: dict[int, str],
|
|
280
|
+
query_embeddings: Array,
|
|
281
|
+
cos_scores_top_k_idx: list[list[int]],
|
|
282
|
+
cos_scores_top_k_values: list[list[float]],
|
|
283
|
+
sub_corpus_ids: list[str],
|
|
284
|
+
top_k: int,
|
|
285
|
+
) -> dict[str, list[tuple[float, str]]]:
|
|
286
|
+
"""Sort the heaps into descending order lists.
|
|
287
|
+
|
|
288
|
+
Returns:
|
|
289
|
+
A dictionary mapping query IDs to a sorted list of tuples, each containing a relevance score and a document ID.
|
|
290
|
+
"""
|
|
291
|
+
for query_itr in range(len(query_embeddings)):
|
|
292
|
+
query_id = query_idx_to_id[query_itr]
|
|
293
|
+
for sub_corpus_id, score in zip(
|
|
294
|
+
cos_scores_top_k_idx[query_itr],
|
|
295
|
+
cos_scores_top_k_values[query_itr],
|
|
296
|
+
):
|
|
297
|
+
corpus_id = sub_corpus_ids[sub_corpus_id]
|
|
298
|
+
if len(result_heaps[query_id]) < top_k:
|
|
299
|
+
# push item on the heap
|
|
300
|
+
heapq.heappush(result_heaps[query_id], (score, corpus_id))
|
|
301
|
+
else:
|
|
302
|
+
# If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
|
|
303
|
+
heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
|
|
207
304
|
return result_heaps
|
|
208
305
|
|
|
209
306
|
def _rerank_documents(
|
|
@@ -230,7 +327,7 @@ class SearchEncoderWrapper:
|
|
|
230
327
|
self.task_corpus,
|
|
231
328
|
task_metadata,
|
|
232
329
|
prompt_type=PromptType.document,
|
|
233
|
-
|
|
330
|
+
**encode_kwargs,
|
|
234
331
|
),
|
|
235
332
|
task_metadata=task_metadata,
|
|
236
333
|
hf_split=hf_split,
|
|
@@ -278,14 +375,34 @@ class SearchEncoderWrapper:
|
|
|
278
375
|
scores_top_k_values = scores_top_k_values.cpu()
|
|
279
376
|
scores_top_k_idx = scores_top_k_idx.cpu()
|
|
280
377
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
378
|
+
result_heaps = self._rerank_sort_results(
|
|
379
|
+
result_heaps=result_heaps,
|
|
380
|
+
query_id=query_id,
|
|
381
|
+
ranked_ids=ranked_ids,
|
|
382
|
+
scores_top_k_idx=scores_top_k_idx,
|
|
383
|
+
scores_top_k_values=scores_top_k_values,
|
|
384
|
+
)
|
|
385
|
+
return result_heaps
|
|
386
|
+
|
|
387
|
+
def _rerank_sort_results(
|
|
388
|
+
self,
|
|
389
|
+
result_heaps: list[tuple[float, str]],
|
|
390
|
+
query_id: str,
|
|
391
|
+
ranked_ids: list[str],
|
|
392
|
+
scores_top_k_idx: torch.Tensor,
|
|
393
|
+
scores_top_k_values: torch.Tensor,
|
|
394
|
+
) -> list[tuple[float, str]]:
|
|
395
|
+
"""Sort the heap into descending order list.
|
|
288
396
|
|
|
397
|
+
Returns:
|
|
398
|
+
A sorted list of tuples, each containing a relevance score and a document ID.
|
|
399
|
+
"""
|
|
400
|
+
for doc_idx, score in zip(
|
|
401
|
+
scores_top_k_idx[0].tolist(),
|
|
402
|
+
scores_top_k_values[0].tolist(),
|
|
403
|
+
):
|
|
404
|
+
corpus_id = ranked_ids[doc_idx]
|
|
405
|
+
heapq.heappush(result_heaps[query_id], (score, corpus_id))
|
|
289
406
|
return result_heaps
|
|
290
407
|
|
|
291
408
|
def encode(
|
|
@@ -407,13 +524,13 @@ class SearchCrossEncoderWrapper:
|
|
|
407
524
|
Dataset.from_list(total_queries),
|
|
408
525
|
task_metadata,
|
|
409
526
|
prompt_type=PromptType.document,
|
|
410
|
-
|
|
527
|
+
**encode_kwargs,
|
|
411
528
|
)
|
|
412
529
|
corpus_loader = create_dataloader(
|
|
413
530
|
Dataset.from_list(total_docs),
|
|
414
531
|
task_metadata,
|
|
415
532
|
prompt_type=PromptType.document,
|
|
416
|
-
|
|
533
|
+
**encode_kwargs,
|
|
417
534
|
)
|
|
418
535
|
predictions = self.model.predict(
|
|
419
536
|
inputs1=queries_loader,
|
|
@@ -68,11 +68,8 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
68
68
|
self.model = SentenceTransformer(model, revision=revision, **kwargs)
|
|
69
69
|
else:
|
|
70
70
|
self.model = model
|
|
71
|
-
from mteb.models.get_model_meta import (
|
|
72
|
-
_model_meta_from_sentence_transformers,
|
|
73
|
-
)
|
|
74
71
|
|
|
75
|
-
self.mteb_model_meta =
|
|
72
|
+
self.mteb_model_meta = ModelMeta.from_sentence_transformer_model(self.model)
|
|
76
73
|
|
|
77
74
|
built_in_prompts = getattr(self.model, "prompts", None)
|
|
78
75
|
if built_in_prompts and not model_prompts:
|
|
@@ -268,14 +265,12 @@ class CrossEncoderWrapper:
|
|
|
268
265
|
) -> None:
|
|
269
266
|
from sentence_transformers import CrossEncoder
|
|
270
267
|
|
|
271
|
-
from mteb.models.get_model_meta import _model_meta_from_cross_encoder
|
|
272
|
-
|
|
273
268
|
if isinstance(model, CrossEncoder):
|
|
274
269
|
self.model = model
|
|
275
270
|
elif isinstance(model, str):
|
|
276
271
|
self.model = CrossEncoder(model, revision=revision, **kwargs)
|
|
277
272
|
|
|
278
|
-
self.mteb_model_meta =
|
|
273
|
+
self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)
|
|
279
274
|
|
|
280
275
|
def predict(
|
|
281
276
|
self,
|