mteb 2.1.4__py3-none-any.whl → 2.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/__init__.py +6 -0
- mteb/_create_dataloaders.py +22 -20
- mteb/_evaluators/any_sts_evaluator.py +23 -14
- mteb/_evaluators/classification_metrics.py +54 -0
- mteb/_evaluators/clustering_evaluator.py +3 -3
- mteb/_evaluators/evaluator.py +4 -2
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +18 -11
- mteb/_evaluators/pair_classification_evaluator.py +34 -40
- mteb/_evaluators/retrieval_evaluator.py +2 -2
- mteb/_evaluators/retrieval_metrics.py +18 -17
- mteb/_evaluators/sklearn_evaluator.py +25 -37
- mteb/_evaluators/text/bitext_mining_evaluator.py +31 -19
- mteb/_evaluators/text/summarization_evaluator.py +27 -20
- mteb/_evaluators/zeroshot_classification_evaluator.py +7 -5
- mteb/abstasks/_data_filter/__init__.py +0 -0
- mteb/abstasks/_data_filter/filters.py +125 -0
- mteb/abstasks/_data_filter/task_pipelines.py +105 -0
- mteb/abstasks/_statistics_calculation.py +23 -11
- mteb/abstasks/_stratification.py +18 -18
- mteb/abstasks/abstask.py +35 -28
- mteb/abstasks/aggregate_task_metadata.py +1 -9
- mteb/abstasks/aggregated_task.py +10 -29
- mteb/abstasks/classification.py +15 -12
- mteb/abstasks/clustering.py +20 -16
- mteb/abstasks/clustering_legacy.py +13 -10
- mteb/abstasks/image/image_text_pair_classification.py +7 -4
- mteb/abstasks/multilabel_classification.py +33 -22
- mteb/abstasks/pair_classification.py +27 -11
- mteb/abstasks/regression.py +4 -4
- mteb/abstasks/retrieval.py +28 -24
- mteb/abstasks/retrieval_dataset_loaders.py +2 -2
- mteb/abstasks/sts.py +14 -4
- mteb/abstasks/task_metadata.py +32 -33
- mteb/abstasks/text/bitext_mining.py +39 -28
- mteb/abstasks/text/reranking.py +8 -6
- mteb/abstasks/text/summarization.py +10 -5
- mteb/abstasks/zeroshot_classification.py +8 -4
- mteb/benchmarks/_create_table.py +84 -37
- mteb/benchmarks/benchmark.py +77 -16
- mteb/benchmarks/benchmarks/__init__.py +12 -0
- mteb/benchmarks/benchmarks/benchmarks.py +361 -16
- mteb/benchmarks/get_benchmark.py +14 -53
- mteb/cache.py +227 -37
- mteb/cli/_display_tasks.py +2 -2
- mteb/cli/build_cli.py +110 -14
- mteb/cli/generate_model_card.py +43 -23
- mteb/deprecated_evaluator.py +71 -62
- mteb/descriptive_stats/BitextMining/RuSciBenchBitextMining.v2.json +61 -0
- mteb/descriptive_stats/Classification/HebrewSentimentAnalysis.v3.json +60 -0
- mteb/descriptive_stats/Classification/TurkishConstitutionalCourtViolation.json +54 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3ComputerScienceRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3EnergyRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceEnRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3FinanceFrRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3HrRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3IndustrialRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3NuclearRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PharmaceuticalsRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3PhysicsRetrieval.json +214 -0
- mteb/descriptive_stats/Image/DocumentUnderstanding/Vidore3TelecomRetrieval.json +214 -0
- mteb/descriptive_stats/PairClassification/TERRa.V2.json +35 -0
- mteb/descriptive_stats/Reranking/JQaRARerankingLite.json +35 -0
- mteb/descriptive_stats/Reranking/JaCWIRRerankingLite.json +35 -0
- mteb/descriptive_stats/Reranking/MultiLongDocReranking.json +466 -0
- mteb/descriptive_stats/Retrieval/ArguAna-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
- mteb/descriptive_stats/Retrieval/JaCWIRRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/JaqketRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/MIRACLJaRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/MrTyDiJaRetrievalLite.json +30 -0
- mteb/descriptive_stats/Retrieval/NFCorpus-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
- mteb/descriptive_stats/Retrieval/SCIDOCS-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/SQuADKorV1Retrieval.json +30 -0
- mteb/descriptive_stats/Retrieval/SciFact-NL.v2.json +30 -0
- mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
- mteb/evaluate.py +106 -75
- mteb/filter_tasks.py +25 -26
- mteb/get_tasks.py +29 -30
- mteb/languages/language_scripts.py +5 -3
- mteb/leaderboard/app.py +414 -151
- mteb/leaderboard/benchmark_selector.py +14 -5
- mteb/leaderboard/figures.py +13 -15
- mteb/leaderboard/table.py +82 -17
- mteb/load_results.py +12 -12
- mteb/models/__init__.py +4 -1
- mteb/models/abs_encoder.py +31 -23
- mteb/models/cache_wrappers/__init__.py +2 -1
- mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +7 -6
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
- mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
- mteb/models/cache_wrappers/cache_wrapper.py +3 -3
- mteb/models/get_model_meta.py +25 -118
- mteb/models/instruct_wrapper.py +33 -9
- mteb/models/model_implementations/align_models.py +8 -1
- mteb/models/model_implementations/amazon_models.py +1 -0
- mteb/models/model_implementations/andersborges.py +65 -0
- mteb/models/model_implementations/ara_models.py +9 -1
- mteb/models/model_implementations/arctic_models.py +16 -8
- mteb/models/model_implementations/b1ade_models.py +2 -1
- mteb/models/model_implementations/bedrock_models.py +4 -0
- mteb/models/model_implementations/bge_models.py +101 -17
- mteb/models/model_implementations/bica_model.py +35 -0
- mteb/models/model_implementations/blip2_models.py +13 -2
- mteb/models/model_implementations/blip_models.py +43 -16
- mteb/models/model_implementations/bm25.py +5 -4
- mteb/models/model_implementations/bmretriever_models.py +10 -4
- mteb/models/model_implementations/cadet_models.py +10 -1
- mteb/models/model_implementations/cde_models.py +25 -4
- mteb/models/model_implementations/clip_models.py +9 -6
- mteb/models/model_implementations/clips_models.py +100 -0
- mteb/models/model_implementations/codefuse_models.py +165 -3
- mteb/models/model_implementations/codesage_models.py +18 -3
- mteb/models/model_implementations/cohere_models.py +13 -6
- mteb/models/model_implementations/cohere_v.py +7 -2
- mteb/models/model_implementations/colpali_models.py +17 -9
- mteb/models/model_implementations/colqwen_models.py +275 -5
- mteb/models/model_implementations/colsmol_models.py +4 -2
- mteb/models/model_implementations/conan_models.py +2 -1
- mteb/models/model_implementations/dino_models.py +194 -23
- mteb/models/model_implementations/e5_instruct.py +27 -4
- mteb/models/model_implementations/e5_models.py +21 -110
- mteb/models/model_implementations/e5_v.py +7 -6
- mteb/models/model_implementations/eagerworks_models.py +164 -0
- mteb/models/model_implementations/emillykkejensen_models.py +91 -0
- mteb/models/model_implementations/en_code_retriever.py +2 -1
- mteb/models/model_implementations/euler_models.py +32 -0
- mteb/models/model_implementations/evaclip_models.py +4 -0
- mteb/models/model_implementations/fa_models.py +67 -9
- mteb/models/model_implementations/facebookai.py +205 -0
- mteb/models/model_implementations/geogpt_models.py +2 -1
- mteb/models/model_implementations/gme_v_models.py +17 -10
- mteb/models/model_implementations/google_models.py +17 -6
- mteb/models/model_implementations/granite_vision_embedding_models.py +8 -3
- mteb/models/model_implementations/gritlm_models.py +4 -2
- mteb/models/model_implementations/gte_models.py +99 -9
- mteb/models/model_implementations/hinvec_models.py +2 -1
- mteb/models/model_implementations/human.py +1 -0
- mteb/models/model_implementations/ibm_granite_models.py +36 -6
- mteb/models/model_implementations/inf_models.py +4 -2
- mteb/models/model_implementations/jasper_models.py +256 -3
- mteb/models/model_implementations/jina_clip.py +49 -10
- mteb/models/model_implementations/jina_models.py +222 -11
- mteb/models/model_implementations/kalm_models.py +203 -25
- mteb/models/model_implementations/kblab.py +37 -0
- mteb/models/model_implementations/kennethenevoldsen_models.py +74 -0
- mteb/models/model_implementations/kfst.py +25 -0
- mteb/models/model_implementations/kowshik24_models.py +32 -0
- mteb/models/model_implementations/lens_models.py +2 -0
- mteb/models/model_implementations/lgai_embedding_models.py +2 -1
- mteb/models/model_implementations/linq_models.py +4 -3
- mteb/models/model_implementations/listconranker.py +2 -2
- mteb/models/model_implementations/llm2clip_models.py +9 -6
- mteb/models/model_implementations/llm2vec_models.py +16 -8
- mteb/models/model_implementations/mcinext_models.py +7 -1
- mteb/models/model_implementations/mdbr_models.py +19 -3
- mteb/models/model_implementations/misc_models.py +422 -60
- mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
- mteb/models/model_implementations/mme5_models.py +2 -1
- mteb/models/model_implementations/moco_models.py +15 -4
- mteb/models/model_implementations/mod_models.py +191 -0
- mteb/models/model_implementations/model2vec_models.py +27 -14
- mteb/models/model_implementations/moka_models.py +4 -1
- mteb/models/model_implementations/nbailab.py +70 -0
- mteb/models/model_implementations/no_instruct_sentence_models.py +3 -2
- mteb/models/model_implementations/nomic_models.py +173 -6
- mteb/models/model_implementations/nomic_models_vision.py +8 -3
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +32 -19
- mteb/models/model_implementations/nvidia_models.py +155 -20
- mteb/models/model_implementations/octen_models.py +254 -0
- mteb/models/model_implementations/openai_models.py +20 -16
- mteb/models/model_implementations/openclip_models.py +37 -13
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +10 -5
- mteb/models/model_implementations/ops_moa_models.py +5 -3
- mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
- mteb/models/model_implementations/pawan_models.py +39 -0
- mteb/models/model_implementations/piccolo_models.py +9 -1
- mteb/models/model_implementations/pixie_models.py +56 -0
- mteb/models/model_implementations/promptriever_models.py +12 -8
- mteb/models/model_implementations/pylate_models.py +46 -12
- mteb/models/model_implementations/qodo_models.py +4 -2
- mteb/models/model_implementations/qtack_models.py +2 -1
- mteb/models/model_implementations/qwen3_models.py +9 -6
- mteb/models/model_implementations/qzhou_models.py +5 -3
- mteb/models/model_implementations/random_baseline.py +19 -24
- mteb/models/model_implementations/rasgaard_models.py +34 -0
- mteb/models/model_implementations/reasonir_model.py +2 -1
- mteb/models/model_implementations/repllama_models.py +5 -3
- mteb/models/model_implementations/rerankers_custom.py +15 -9
- mteb/models/model_implementations/rerankers_monot5_based.py +31 -31
- mteb/models/model_implementations/richinfoai_models.py +2 -1
- mteb/models/model_implementations/ru_sentence_models.py +71 -20
- mteb/models/model_implementations/ruri_models.py +322 -0
- mteb/models/model_implementations/salesforce_models.py +6 -3
- mteb/models/model_implementations/samilpwc_models.py +2 -1
- mteb/models/model_implementations/sarashina_embedding_models.py +168 -0
- mteb/models/model_implementations/searchmap_models.py +2 -1
- mteb/models/model_implementations/seed_1_6_embedding_models.py +8 -2
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +625 -0
- mteb/models/model_implementations/seed_models.py +1 -0
- mteb/models/model_implementations/sentence_transformers_models.py +177 -18
- mteb/models/model_implementations/shuu_model.py +32 -31
- mteb/models/model_implementations/siglip_models.py +30 -20
- mteb/models/model_implementations/slm_models.py +416 -0
- mteb/models/model_implementations/sonar_models.py +1 -0
- mteb/models/model_implementations/spartan8806_atles_champion.py +34 -0
- mteb/models/model_implementations/stella_models.py +23 -4
- mteb/models/model_implementations/tarka_models.py +376 -0
- mteb/models/model_implementations/text2vec_models.py +9 -3
- mteb/models/model_implementations/ua_sentence_models.py +11 -1
- mteb/models/model_implementations/uae_models.py +8 -1
- mteb/models/model_implementations/vdr_models.py +3 -1
- mteb/models/model_implementations/vi_vn_models.py +45 -6
- mteb/models/model_implementations/vista_models.py +2 -0
- mteb/models/model_implementations/vlm2vec_models.py +5 -3
- mteb/models/model_implementations/voyage_models.py +99 -0
- mteb/models/model_implementations/voyage_v.py +17 -9
- mteb/models/model_implementations/xyz_models.py +1 -0
- mteb/models/model_implementations/youtu_models.py +2 -1
- mteb/models/model_implementations/yuan_models.py +34 -0
- mteb/models/model_implementations/yuan_models_en.py +58 -0
- mteb/models/model_meta.py +498 -29
- mteb/models/models_protocols.py +22 -6
- mteb/models/search_encoder_index/__init__.py +7 -0
- mteb/models/search_encoder_index/search_backend_protocol.py +50 -0
- mteb/models/search_encoder_index/search_indexes/__init__.py +5 -0
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +160 -0
- mteb/models/search_wrappers.py +197 -65
- mteb/models/sentence_transformer_wrapper.py +52 -32
- mteb/models/vllm_wrapper.py +327 -0
- mteb/py.typed +0 -0
- mteb/results/benchmark_results.py +114 -65
- mteb/results/model_result.py +63 -26
- mteb/results/task_result.py +117 -77
- mteb/similarity_functions.py +60 -7
- mteb/tasks/bitext_mining/multilingual/__init__.py +2 -1
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +4 -2
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining_fast.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ru_sci_bench_bitext_mining.py +47 -5
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -6
- mteb/tasks/classification/ara/ajgt.py +1 -2
- mteb/tasks/classification/ara/hotel_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ara/tweet_emotion_classification.py +1 -2
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_document_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -2
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -2
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +1 -2
- mteb/tasks/classification/dan/angry_tweets_classification.py +1 -2
- mteb/tasks/classification/dan/danish_political_comments_classification.py +1 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -2
- mteb/tasks/classification/dan/dk_hate_classification.py +2 -3
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -2
- mteb/tasks/classification/deu/ten_k_gnad_classification.py +1 -2
- mteb/tasks/classification/eng/amazon_polarity_classification.py +1 -2
- mteb/tasks/classification/eng/arxiv_classification.py +1 -2
- mteb/tasks/classification/eng/banking77_classification.py +1 -2
- mteb/tasks/classification/eng/dbpedia_classification.py +1 -2
- mteb/tasks/classification/eng/emotion_classification.py +1 -2
- mteb/tasks/classification/eng/financial_phrasebank_classification.py +1 -2
- mteb/tasks/classification/eng/frenk_en_classification.py +1 -2
- mteb/tasks/classification/eng/gtsrb_classification.py +1 -1
- mteb/tasks/classification/eng/imdb_classification.py +1 -2
- mteb/tasks/classification/eng/legal_bench_classification.py +14 -120
- mteb/tasks/classification/eng/news_classification.py +1 -2
- mteb/tasks/classification/eng/patch_camelyon_classification.py +1 -1
- mteb/tasks/classification/eng/patent_classification.py +1 -2
- mteb/tasks/classification/eng/poem_sentiment_classification.py +1 -2
- mteb/tasks/classification/eng/sds_eye_protection_classification.py +1 -2
- mteb/tasks/classification/eng/sds_gloves_classification.py +1 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -19
- mteb/tasks/classification/eng/toxic_conversations_classification.py +1 -2
- mteb/tasks/classification/eng/tweet_sentiment_extraction_classification.py +1 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +2 -13
- mteb/tasks/classification/eng/ucf101_classification.py +1 -5
- mteb/tasks/classification/eng/wikipedia_bio_met_chem_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_chem_fields_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_comp_chem_spectroscopy_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_crystallography_analytical_classification.py +1 -2
- mteb/tasks/classification/eng/wikipedia_theoretical_applied_classification.py +1 -2
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -2
- mteb/tasks/classification/eng/yelp_review_full_classification.py +1 -2
- mteb/tasks/classification/est/estonian_valence.py +2 -3
- mteb/tasks/classification/fas/fa_mteb_classification.py +7 -14
- mteb/tasks/classification/fil/filipino_hate_speech_classification.py +1 -2
- mteb/tasks/classification/fin/fin_toxicity_classification.py +2 -11
- mteb/tasks/classification/fra/french_book_reviews.py +1 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -2
- mteb/tasks/classification/heb/__init__.py +6 -1
- mteb/tasks/classification/heb/hebrew_sentiment_analysis.py +62 -4
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -2
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -2
- mteb/tasks/classification/hrv/frenk_hr_classification.py +1 -2
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +1 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -2
- mteb/tasks/classification/ita/italian_linguist_acceptability_classification.py +1 -2
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -2
- mteb/tasks/classification/jpn/wrime_classification.py +1 -2
- mteb/tasks/classification/kan/kannada_news_classification.py +1 -2
- mteb/tasks/classification/kor/klue_tc.py +1 -2
- mteb/tasks/classification/kor/kor_hate_classification.py +2 -17
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +2 -19
- mteb/tasks/classification/kur/kurdish_sentiment_classification.py +3 -4
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -2
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -2
- mteb/tasks/classification/mkd/macedonian_tweet_sentiment_classification.py +1 -2
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -6
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -4
- mteb/tasks/classification/multilingual/ru_sci_bench_classification.py +4 -23
- mteb/tasks/classification/multilingual/scala_classification.py +2 -3
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -6
- mteb/tasks/classification/mya/myanmar_news.py +1 -2
- mteb/tasks/classification/nep/nepali_news_classification.py +1 -2
- mteb/tasks/classification/nld/dutch_book_review_sentiment_classification.py +4 -2
- mteb/tasks/classification/nld/dutch_cola_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_government_bias_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_news_articles_classification.py +3 -0
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +3 -0
- mteb/tasks/classification/nld/iconclass_classification.py +3 -0
- mteb/tasks/classification/nld/open_tender_classification.py +3 -0
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +3 -0
- mteb/tasks/classification/nob/no_rec_classification.py +1 -2
- mteb/tasks/classification/nob/norwegian_parliament_classification.py +1 -2
- mteb/tasks/classification/ory/odia_news_classification.py +1 -2
- mteb/tasks/classification/pol/polish_classification.py +3 -6
- mteb/tasks/classification/ron/moroco.py +1 -2
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -2
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -2
- mteb/tasks/classification/rus/georeview_classification.py +1 -2
- mteb/tasks/classification/rus/headline_classification.py +1 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +1 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +1 -2
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -2
- mteb/tasks/classification/rus/senti_ru_eval.py +1 -2
- mteb/tasks/classification/sin/sinhala_news_classification.py +1 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +1 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/slk/slovak_hate_speech_classification.py +1 -2
- mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +1 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -2
- mteb/tasks/classification/spa/spanish_news_classification.py +1 -2
- mteb/tasks/classification/spa/spanish_sentiment_classification.py +1 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -2
- mteb/tasks/classification/swa/swahili_news_classification.py +1 -2
- mteb/tasks/classification/swe/dalaj_classification.py +1 -2
- mteb/tasks/classification/swe/swe_rec_classification.py +1 -2
- mteb/tasks/classification/swe/swedish_sentiment_classification.py +1 -2
- mteb/tasks/classification/tam/tamil_news_classification.py +1 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +1 -2
- mteb/tasks/classification/tha/wisesight_sentiment_classification.py +1 -2
- mteb/tasks/classification/tsn/tswana_news_classification.py +1 -2
- mteb/tasks/classification/tur/__init__.py +4 -0
- mteb/tasks/classification/tur/turkish_constitutional_court.py +41 -0
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +1 -2
- mteb/tasks/classification/tur/turkish_product_sentiment_classification.py +1 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -15
- mteb/tasks/classification/urd/urdu_roman_sentiment_classification.py +1 -2
- mteb/tasks/classification/vie/amazon_counterfactual_vn_classification.py +1 -6
- mteb/tasks/classification/vie/amazon_polarity_vn_classification.py +1 -6
- mteb/tasks/classification/vie/amazon_reviews_vn_classification.py +1 -5
- mteb/tasks/classification/vie/banking77_vn_classification.py +1 -5
- mteb/tasks/classification/vie/emotion_vn_classification.py +1 -5
- mteb/tasks/classification/vie/imdb_vn_classification.py +1 -5
- mteb/tasks/classification/vie/massive_intent_vn_classification.py +1 -5
- mteb/tasks/classification/vie/massive_scenario_vn_classification.py +1 -5
- mteb/tasks/classification/vie/mtop_domain_vn_classification.py +1 -5
- mteb/tasks/classification/vie/mtop_intent_vn_classification.py +1 -5
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -5
- mteb/tasks/classification/vie/tweet_sentiment_extraction_vn_classification.py +1 -5
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -2
- mteb/tasks/classification/zho/cmteb_classification.py +5 -10
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +1 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -2
- mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
- mteb/tasks/clustering/jpn/mews_c16_ja_clustering.py +1 -3
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -6
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/open_tender_clustering_s2s.py +3 -0
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +3 -0
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +3 -0
- mteb/tasks/clustering/vie/reddit_clustering_p2p_vn.py +1 -5
- mteb/tasks/clustering/vie/reddit_clustering_vn.py +1 -5
- mteb/tasks/clustering/vie/stack_exchange_clustering_p2p_vn.py +1 -5
- mteb/tasks/clustering/vie/stack_exchange_clustering_vn.py +1 -5
- mteb/tasks/clustering/vie/twenty_newsgroups_clustering_vn.py +1 -5
- mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -5
- mteb/tasks/multilabel_classification/kor/kor_hate_speech_ml_classification.py +1 -9
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -6
- mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +3 -0
- mteb/tasks/multilabel_classification/nld/vabb_multi_label_classification.py +3 -0
- mteb/tasks/multilabel_classification/por/brazilian_toxic_tweets_classification.py +1 -6
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -2
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -6
- mteb/tasks/pair_classification/eng/legal_bench_pc.py +1 -9
- mteb/tasks/pair_classification/nld/sick_nl_pair_classification.py +3 -0
- mteb/tasks/pair_classification/nld/xlwic_nl_pair_classification.py +3 -0
- mteb/tasks/pair_classification/rus/__init__.py +2 -2
- mteb/tasks/pair_classification/rus/terra.py +51 -25
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -5
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -5
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -5
- mteb/tasks/regression/multilingual/ru_sci_bench_regression.py +2 -6
- mteb/tasks/reranking/jpn/__init__.py +9 -1
- mteb/tasks/reranking/jpn/j_qa_ra_reranking_lite.py +49 -0
- mteb/tasks/reranking/jpn/ja_cwir_reranking_lite.py +47 -0
- mteb/tasks/reranking/multilingual/__init__.py +2 -0
- mteb/tasks/reranking/multilingual/multi_long_doc_reranking.py +70 -0
- mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
- mteb/tasks/reranking/multilingual/x_glue_wpr_reranking.py +1 -2
- mteb/tasks/reranking/vie/ask_ubuntu_dup_questions_vn.py +1 -5
- mteb/tasks/reranking/vie/sci_docs_reranking_vn.py +1 -5
- mteb/tasks/reranking/vie/stack_overflow_dup_questions_vn.py +1 -5
- mteb/tasks/retrieval/code/code_rag.py +12 -12
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +8 -5
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/eng/__init__.py +2 -0
- mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
- mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -8
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +4 -0
- mteb/tasks/retrieval/jpn/__init__.py +8 -0
- mteb/tasks/retrieval/jpn/ja_cwir_retrieval.py +1 -4
- mteb/tasks/retrieval/jpn/ja_cwir_retrieval_lite.py +47 -0
- mteb/tasks/retrieval/jpn/jaqket_retrieval_lite.py +50 -0
- mteb/tasks/retrieval/jpn/miracl_ja_retrieval_lite.py +52 -0
- mteb/tasks/retrieval/jpn/mr_tydi_ja_retrieval_lite.py +48 -0
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +11 -4
- mteb/tasks/retrieval/kor/__init__.py +16 -1
- mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
- mteb/tasks/retrieval/kor/squad_kor_v1_retrieval.py +47 -0
- mteb/tasks/retrieval/multilingual/__init__.py +24 -0
- mteb/tasks/retrieval/multilingual/belebele_retrieval.py +5 -4
- mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +56 -42
- mteb/tasks/retrieval/multilingual/mkqa_retrieval.py +1 -2
- mteb/tasks/retrieval/multilingual/mlqa_retrieval.py +1 -4
- mteb/tasks/retrieval/multilingual/multi_long_doc_retrieval.py +1 -2
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +9 -4
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -12
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +4 -2
- mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +389 -0
- mteb/tasks/retrieval/nld/__init__.py +8 -4
- mteb/tasks/retrieval/nld/argu_ana_nl_retrieval.py +46 -27
- mteb/tasks/retrieval/nld/bbsard_nl_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/dutch_news_articles_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/legal_qa_nl_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/nf_corpus_nl_retrieval.py +42 -25
- mteb/tasks/retrieval/nld/open_tender_retrieval.py +3 -0
- mteb/tasks/retrieval/nld/sci_fact_nl_retrieval.py +42 -24
- mteb/tasks/retrieval/nld/scidocsnl_retrieval.py +44 -27
- mteb/tasks/retrieval/nld/vabb_retrieval.py +3 -0
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -7
- mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
- mteb/tasks/retrieval/vie/__init__.py +14 -6
- mteb/tasks/retrieval/vie/argu_ana_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +40 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_android_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_gis_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_mathematica_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_physics_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_programmers_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_stats_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_tex_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_unix_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_webmasters_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/cqa_dupstack_wordpress_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +40 -5
- mteb/tasks/retrieval/vie/fevervn_retrieval.py +40 -7
- mteb/tasks/retrieval/vie/fi_qa2018_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/green_node_table_markdown_retrieval.py +16 -1
- mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +40 -6
- mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +49 -5
- mteb/tasks/retrieval/vie/nf_corpus_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/nqvn_retrieval.py +40 -5
- mteb/tasks/retrieval/vie/quora_vn_retrieval.py +1 -6
- mteb/tasks/retrieval/vie/sci_fact_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/scidocsvn_retrieval.py +1 -6
- mteb/tasks/retrieval/vie/touche2020_vn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/treccovidvn_retrieval.py +1 -5
- mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
- mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
- mteb/tasks/sts/nld/sick_nl_sts.py +1 -0
- mteb/tasks/sts/vie/biosses_stsvn.py +1 -5
- mteb/tasks/sts/vie/sickr_stsvn.py +1 -5
- mteb/tasks/sts/vie/sts_benchmark_stsvn.py +1 -5
- mteb/tasks/zeroshot_classification/eng/gtsrb.py +1 -1
- mteb/tasks/zeroshot_classification/eng/patch_camelyon.py +1 -1
- mteb/tasks/zeroshot_classification/eng/ucf101.py +1 -5
- mteb/types/__init__.py +2 -0
- mteb/types/_encoder_io.py +19 -2
- mteb/types/_result.py +2 -1
- mteb/types/statistics.py +9 -3
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/METADATA +25 -8
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/RECORD +525 -438
- mteb/models/model_implementations/mxbai_models.py +0 -102
- mteb/models/model_implementations/nb_sbert.py +0 -25
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.1.4.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
mteb/models/search_wrappers.py
CHANGED
|
@@ -14,6 +14,7 @@ from mteb.types import (
|
|
|
14
14
|
Array,
|
|
15
15
|
BatchedInput,
|
|
16
16
|
CorpusDatasetType,
|
|
17
|
+
EncodeKwargs,
|
|
17
18
|
PromptType,
|
|
18
19
|
QueryDatasetType,
|
|
19
20
|
RetrievalOutputType,
|
|
@@ -21,6 +22,7 @@ from mteb.types import (
|
|
|
21
22
|
)
|
|
22
23
|
|
|
23
24
|
from .models_protocols import CrossEncoderProtocol, EncoderProtocol
|
|
25
|
+
from .search_encoder_index.search_backend_protocol import IndexEncoderSearchProtocol
|
|
24
26
|
|
|
25
27
|
logger = logging.getLogger(__name__)
|
|
26
28
|
|
|
@@ -28,13 +30,19 @@ logger = logging.getLogger(__name__)
|
|
|
28
30
|
class SearchEncoderWrapper:
|
|
29
31
|
"""Wrapper for Encoder models to be used in search tasks."""
|
|
30
32
|
|
|
31
|
-
corpus_chunk_size = 50_000
|
|
32
33
|
task_corpus: CorpusDatasetType | None
|
|
33
34
|
|
|
34
|
-
def __init__(
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
model: EncoderProtocol,
|
|
38
|
+
corpus_chunk_size: int = 50_000,
|
|
39
|
+
index_backend: IndexEncoderSearchProtocol | None = None,
|
|
40
|
+
) -> None:
|
|
35
41
|
self.model = model
|
|
36
42
|
self.task_corpus = None
|
|
37
43
|
self.mteb_model_meta = model.mteb_model_meta
|
|
44
|
+
self.corpus_chunk_size = corpus_chunk_size
|
|
45
|
+
self.index_backend = index_backend
|
|
38
46
|
|
|
39
47
|
def index(
|
|
40
48
|
self,
|
|
@@ -43,7 +51,7 @@ class SearchEncoderWrapper:
|
|
|
43
51
|
task_metadata: TaskMetadata,
|
|
44
52
|
hf_split: str,
|
|
45
53
|
hf_subset: str,
|
|
46
|
-
encode_kwargs:
|
|
54
|
+
encode_kwargs: EncodeKwargs,
|
|
47
55
|
) -> None:
|
|
48
56
|
"""Index the corpus for retrieval.
|
|
49
57
|
|
|
@@ -56,6 +64,22 @@ class SearchEncoderWrapper:
|
|
|
56
64
|
"""
|
|
57
65
|
# Always retain corpus for potential reranking or fallback flows
|
|
58
66
|
self.task_corpus = corpus
|
|
67
|
+
if self.index_backend is not None:
|
|
68
|
+
all_doc_embeddings = self.model.encode(
|
|
69
|
+
create_dataloader(
|
|
70
|
+
corpus,
|
|
71
|
+
task_metadata,
|
|
72
|
+
prompt_type=PromptType.document,
|
|
73
|
+
**encode_kwargs,
|
|
74
|
+
),
|
|
75
|
+
task_metadata=task_metadata,
|
|
76
|
+
hf_split=hf_split,
|
|
77
|
+
hf_subset=hf_subset,
|
|
78
|
+
prompt_type=PromptType.document,
|
|
79
|
+
**encode_kwargs,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
self.index_backend.add_documents(all_doc_embeddings, corpus["id"])
|
|
59
83
|
|
|
60
84
|
def search(
|
|
61
85
|
self,
|
|
@@ -65,7 +89,7 @@ class SearchEncoderWrapper:
|
|
|
65
89
|
hf_split: str,
|
|
66
90
|
hf_subset: str,
|
|
67
91
|
top_k: int,
|
|
68
|
-
encode_kwargs:
|
|
92
|
+
encode_kwargs: EncodeKwargs,
|
|
69
93
|
top_ranked: TopRankedDocumentsType | None = None,
|
|
70
94
|
) -> RetrievalOutputType:
|
|
71
95
|
"""Search the corpus for the given queries.
|
|
@@ -90,7 +114,7 @@ class SearchEncoderWrapper:
|
|
|
90
114
|
queries,
|
|
91
115
|
task_metadata,
|
|
92
116
|
prompt_type=PromptType.query,
|
|
93
|
-
|
|
117
|
+
**encode_kwargs,
|
|
94
118
|
)
|
|
95
119
|
|
|
96
120
|
query_embeddings = self.model.encode(
|
|
@@ -105,32 +129,79 @@ class SearchEncoderWrapper:
|
|
|
105
129
|
|
|
106
130
|
if top_ranked is not None:
|
|
107
131
|
logger.info("Reranking pre-ranked documents...")
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
132
|
+
if self.index_backend is None:
|
|
133
|
+
result_heaps = self._rerank_documents(
|
|
134
|
+
query_idx_to_id=query_idx_to_id,
|
|
135
|
+
query_embeddings=query_embeddings,
|
|
136
|
+
top_ranked=top_ranked,
|
|
137
|
+
top_k=top_k,
|
|
138
|
+
task_metadata=task_metadata,
|
|
139
|
+
hf_subset=hf_subset,
|
|
140
|
+
hf_split=hf_split,
|
|
141
|
+
encode_kwargs=encode_kwargs,
|
|
142
|
+
)
|
|
143
|
+
else:
|
|
144
|
+
cos_scores_top_k_values, cos_scores_top_k_idx = (
|
|
145
|
+
self.index_backend.search(
|
|
146
|
+
query_embeddings,
|
|
147
|
+
top_k,
|
|
148
|
+
similarity_fn=self.model.similarity,
|
|
149
|
+
top_ranked=top_ranked,
|
|
150
|
+
query_idx_to_id=query_idx_to_id,
|
|
151
|
+
)
|
|
152
|
+
)
|
|
153
|
+
result_heaps = {qid: [] for qid in query_idx_to_id.values()}
|
|
154
|
+
for query_itr in range(len(query_embeddings)):
|
|
155
|
+
result_heaps = self._rerank_sort_results(
|
|
156
|
+
result_heaps=result_heaps,
|
|
157
|
+
query_id=query_idx_to_id[query_itr],
|
|
158
|
+
ranked_ids=top_ranked[query_idx_to_id[query_itr]],
|
|
159
|
+
scores_top_k_idx=torch.tensor(
|
|
160
|
+
[cos_scores_top_k_idx[query_itr]]
|
|
161
|
+
),
|
|
162
|
+
scores_top_k_values=torch.tensor(
|
|
163
|
+
[cos_scores_top_k_values[query_itr]]
|
|
164
|
+
),
|
|
165
|
+
)
|
|
166
|
+
self.index_backend.clear()
|
|
118
167
|
else:
|
|
119
168
|
logger.info("Performing full corpus search...")
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
169
|
+
if self.index_backend is None:
|
|
170
|
+
result_heaps = self._full_corpus_search(
|
|
171
|
+
query_idx_to_id=query_idx_to_id,
|
|
172
|
+
query_embeddings=query_embeddings,
|
|
173
|
+
task_metadata=task_metadata,
|
|
174
|
+
hf_subset=hf_subset,
|
|
175
|
+
hf_split=hf_split,
|
|
176
|
+
top_k=top_k,
|
|
177
|
+
encode_kwargs=encode_kwargs,
|
|
178
|
+
)
|
|
179
|
+
else:
|
|
180
|
+
cos_scores_top_k_values, cos_scores_top_k_idx = (
|
|
181
|
+
self.index_backend.search(
|
|
182
|
+
query_embeddings,
|
|
183
|
+
top_k,
|
|
184
|
+
similarity_fn=self.model.similarity,
|
|
185
|
+
top_ranked=None,
|
|
186
|
+
query_idx_to_id=None,
|
|
187
|
+
)
|
|
188
|
+
)
|
|
189
|
+
result_heaps = {qid: [] for qid in query_idx_to_id.values()}
|
|
190
|
+
result_heaps = self._sort_full_corpus_results(
|
|
191
|
+
result_heaps=result_heaps,
|
|
192
|
+
query_idx_to_id=query_idx_to_id,
|
|
193
|
+
query_embeddings=query_embeddings,
|
|
194
|
+
cos_scores_top_k_idx=cos_scores_top_k_idx,
|
|
195
|
+
cos_scores_top_k_values=cos_scores_top_k_values,
|
|
196
|
+
sub_corpus_ids=self.task_corpus["id"],
|
|
197
|
+
top_k=top_k,
|
|
198
|
+
)
|
|
199
|
+
self.index_backend.clear()
|
|
129
200
|
|
|
130
201
|
# Reset the task corpus dataloader to None to free up memory
|
|
131
202
|
self.task_corpus = None
|
|
132
203
|
|
|
133
|
-
results = {qid: {} for qid in query_idx_to_id.values()}
|
|
204
|
+
results: RetrievalOutputType = {qid: {} for qid in query_idx_to_id.values()}
|
|
134
205
|
for qid in result_heaps:
|
|
135
206
|
for score, corpus_id in result_heaps[qid]:
|
|
136
207
|
results[qid][corpus_id] = score
|
|
@@ -145,16 +216,22 @@ class SearchEncoderWrapper:
|
|
|
145
216
|
hf_subset: str,
|
|
146
217
|
hf_split: str,
|
|
147
218
|
top_k: int,
|
|
148
|
-
encode_kwargs:
|
|
219
|
+
encode_kwargs: EncodeKwargs,
|
|
149
220
|
) -> dict[str, list[tuple[float, str]]]:
|
|
150
|
-
logger.info("Encoding Corpus in batches
|
|
221
|
+
logger.info("Encoding Corpus in batches (this might take a while)...")
|
|
222
|
+
if self.task_corpus is None:
|
|
223
|
+
raise ValueError("Corpus must be indexed before searching.")
|
|
224
|
+
|
|
151
225
|
itr = range(0, len(self.task_corpus), self.corpus_chunk_size)
|
|
152
226
|
|
|
153
|
-
result_heaps
|
|
227
|
+
result_heaps: dict[str, list[tuple[float, str]]] = {
|
|
228
|
+
qid: [] for qid in query_idx_to_id.values()
|
|
229
|
+
}
|
|
154
230
|
for batch_num, corpus_start_idx in enumerate(itr):
|
|
155
231
|
logger.info(f"Encoding Batch {batch_num + 1}/{len(itr)}...")
|
|
156
232
|
corpus_end_idx = min(
|
|
157
|
-
corpus_start_idx + self.corpus_chunk_size,
|
|
233
|
+
corpus_start_idx + self.corpus_chunk_size,
|
|
234
|
+
len(self.task_corpus),
|
|
158
235
|
)
|
|
159
236
|
sub_corpus = self.task_corpus.select(
|
|
160
237
|
range(corpus_start_idx, corpus_end_idx)
|
|
@@ -165,7 +242,7 @@ class SearchEncoderWrapper:
|
|
|
165
242
|
sub_corpus,
|
|
166
243
|
task_metadata,
|
|
167
244
|
prompt_type=PromptType.document,
|
|
168
|
-
|
|
245
|
+
**encode_kwargs,
|
|
169
246
|
),
|
|
170
247
|
task_metadata=task_metadata,
|
|
171
248
|
hf_split=hf_split,
|
|
@@ -179,8 +256,8 @@ class SearchEncoderWrapper:
|
|
|
179
256
|
scores = self.model.similarity(query_embeddings, sub_corpus_embeddings)
|
|
180
257
|
|
|
181
258
|
# get top-k values
|
|
182
|
-
|
|
183
|
-
torch.
|
|
259
|
+
cos_scores_top_k_values_tensor, cos_scores_top_k_idx_tensor = torch.topk(
|
|
260
|
+
torch.as_tensor(scores),
|
|
184
261
|
min(
|
|
185
262
|
top_k + 1,
|
|
186
263
|
len(scores[1]) if len(scores) > 1 else len(scores[-1]),
|
|
@@ -188,22 +265,49 @@ class SearchEncoderWrapper:
|
|
|
188
265
|
dim=1,
|
|
189
266
|
largest=True,
|
|
190
267
|
)
|
|
191
|
-
cos_scores_top_k_idx =
|
|
192
|
-
cos_scores_top_k_values =
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
268
|
+
cos_scores_top_k_idx = cos_scores_top_k_idx_tensor.cpu().tolist()
|
|
269
|
+
cos_scores_top_k_values = cos_scores_top_k_values_tensor.cpu().tolist()
|
|
270
|
+
|
|
271
|
+
sub_corpus_ids = list(sub_corpus_ids)
|
|
272
|
+
result_heaps = self._sort_full_corpus_results(
|
|
273
|
+
result_heaps=result_heaps,
|
|
274
|
+
query_idx_to_id=query_idx_to_id,
|
|
275
|
+
query_embeddings=query_embeddings,
|
|
276
|
+
cos_scores_top_k_idx=cos_scores_top_k_idx,
|
|
277
|
+
cos_scores_top_k_values=cos_scores_top_k_values,
|
|
278
|
+
sub_corpus_ids=sub_corpus_ids,
|
|
279
|
+
top_k=top_k,
|
|
280
|
+
)
|
|
281
|
+
return result_heaps
|
|
282
|
+
|
|
283
|
+
def _sort_full_corpus_results(
|
|
284
|
+
self,
|
|
285
|
+
result_heaps: dict[str, list[tuple[float, str]]],
|
|
286
|
+
query_idx_to_id: dict[int, str],
|
|
287
|
+
query_embeddings: Array,
|
|
288
|
+
cos_scores_top_k_idx: list[list[int]],
|
|
289
|
+
cos_scores_top_k_values: list[list[float]],
|
|
290
|
+
sub_corpus_ids: list[str],
|
|
291
|
+
top_k: int,
|
|
292
|
+
) -> dict[str, list[tuple[float, str]]]:
|
|
293
|
+
"""Sort the heaps into descending order lists.
|
|
294
|
+
|
|
295
|
+
Returns:
|
|
296
|
+
A dictionary mapping query IDs to a sorted list of tuples, each containing a relevance score and a document ID.
|
|
297
|
+
"""
|
|
298
|
+
for query_itr in range(len(query_embeddings)):
|
|
299
|
+
query_id = query_idx_to_id[query_itr]
|
|
300
|
+
for sub_corpus_id, score in zip(
|
|
301
|
+
cos_scores_top_k_idx[query_itr],
|
|
302
|
+
cos_scores_top_k_values[query_itr],
|
|
303
|
+
):
|
|
304
|
+
corpus_id = sub_corpus_ids[sub_corpus_id]
|
|
305
|
+
if len(result_heaps[query_id]) < top_k:
|
|
306
|
+
# push item on the heap
|
|
307
|
+
heapq.heappush(result_heaps[query_id], (score, corpus_id))
|
|
308
|
+
else:
|
|
309
|
+
# If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
|
|
310
|
+
heapq.heappushpop(result_heaps[query_id], (score, corpus_id))
|
|
207
311
|
return result_heaps
|
|
208
312
|
|
|
209
313
|
def _rerank_documents(
|
|
@@ -215,14 +319,18 @@ class SearchEncoderWrapper:
|
|
|
215
319
|
task_metadata: TaskMetadata,
|
|
216
320
|
hf_subset: str,
|
|
217
321
|
hf_split: str,
|
|
218
|
-
encode_kwargs:
|
|
322
|
+
encode_kwargs: EncodeKwargs,
|
|
219
323
|
) -> dict[str, list[tuple[float, str]]]:
|
|
220
324
|
"""Rerank documents based on pre-ranked documents.
|
|
221
325
|
|
|
222
326
|
Returns:
|
|
223
327
|
A dictionary mapping query IDs to a list of tuples, each containing a relevance score and a document ID.
|
|
224
328
|
"""
|
|
225
|
-
|
|
329
|
+
if self.task_corpus is None:
|
|
330
|
+
raise ValueError("Corpus must be indexed before searching.")
|
|
331
|
+
result_heaps: dict[str, list[tuple[float, str]]] = {
|
|
332
|
+
qid: [] for qid in query_idx_to_id.values()
|
|
333
|
+
}
|
|
226
334
|
doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
|
|
227
335
|
|
|
228
336
|
all_doc_embeddings = self.model.encode(
|
|
@@ -230,7 +338,7 @@ class SearchEncoderWrapper:
|
|
|
230
338
|
self.task_corpus,
|
|
231
339
|
task_metadata,
|
|
232
340
|
prompt_type=PromptType.document,
|
|
233
|
-
|
|
341
|
+
**encode_kwargs,
|
|
234
342
|
),
|
|
235
343
|
task_metadata=task_metadata,
|
|
236
344
|
hf_split=hf_split,
|
|
@@ -243,7 +351,8 @@ class SearchEncoderWrapper:
|
|
|
243
351
|
for query_idx, query_embedding in enumerate(query_embeddings):
|
|
244
352
|
query_id = query_idx_to_id[query_idx]
|
|
245
353
|
if query_id not in top_ranked:
|
|
246
|
-
|
|
354
|
+
msg = f"No pre-ranked documents found for query {query_id}"
|
|
355
|
+
logger.warning(msg)
|
|
247
356
|
continue
|
|
248
357
|
|
|
249
358
|
ranked_ids = top_ranked[query_id]
|
|
@@ -278,14 +387,34 @@ class SearchEncoderWrapper:
|
|
|
278
387
|
scores_top_k_values = scores_top_k_values.cpu()
|
|
279
388
|
scores_top_k_idx = scores_top_k_idx.cpu()
|
|
280
389
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
390
|
+
result_heaps = self._rerank_sort_results(
|
|
391
|
+
result_heaps=result_heaps,
|
|
392
|
+
query_id=query_id,
|
|
393
|
+
ranked_ids=ranked_ids,
|
|
394
|
+
scores_top_k_idx=scores_top_k_idx,
|
|
395
|
+
scores_top_k_values=scores_top_k_values,
|
|
396
|
+
)
|
|
397
|
+
return result_heaps
|
|
398
|
+
|
|
399
|
+
def _rerank_sort_results(
|
|
400
|
+
self,
|
|
401
|
+
result_heaps: dict[str, list[tuple[float, str]]],
|
|
402
|
+
query_id: str,
|
|
403
|
+
ranked_ids: list[str],
|
|
404
|
+
scores_top_k_idx: torch.Tensor,
|
|
405
|
+
scores_top_k_values: torch.Tensor,
|
|
406
|
+
) -> dict[str, list[tuple[float, str]]]:
|
|
407
|
+
"""Sort the heap into descending order list.
|
|
288
408
|
|
|
409
|
+
Returns:
|
|
410
|
+
A sorted list of tuples, each containing a relevance score and a document ID.
|
|
411
|
+
"""
|
|
412
|
+
for doc_idx, score in zip(
|
|
413
|
+
scores_top_k_idx[0].tolist(),
|
|
414
|
+
scores_top_k_values[0].tolist(),
|
|
415
|
+
):
|
|
416
|
+
corpus_id = ranked_ids[doc_idx]
|
|
417
|
+
heapq.heappush(result_heaps[query_id], (score, corpus_id))
|
|
289
418
|
return result_heaps
|
|
290
419
|
|
|
291
420
|
def encode(
|
|
@@ -342,7 +471,7 @@ class SearchCrossEncoderWrapper:
|
|
|
342
471
|
task_metadata: TaskMetadata,
|
|
343
472
|
hf_split: str,
|
|
344
473
|
hf_subset: str,
|
|
345
|
-
encode_kwargs:
|
|
474
|
+
encode_kwargs: EncodeKwargs,
|
|
346
475
|
) -> None:
|
|
347
476
|
"""Index the corpus for retrieval.
|
|
348
477
|
|
|
@@ -363,7 +492,7 @@ class SearchCrossEncoderWrapper:
|
|
|
363
492
|
hf_split: str,
|
|
364
493
|
hf_subset: str,
|
|
365
494
|
top_k: int,
|
|
366
|
-
encode_kwargs:
|
|
495
|
+
encode_kwargs: EncodeKwargs,
|
|
367
496
|
top_ranked: TopRankedDocumentsType | None = None,
|
|
368
497
|
) -> RetrievalOutputType:
|
|
369
498
|
"""Search the corpus using the given queries.
|
|
@@ -385,6 +514,8 @@ class SearchCrossEncoderWrapper:
|
|
|
385
514
|
raise ValueError(
|
|
386
515
|
"CrossEncoder search requires top_ranked documents for reranking."
|
|
387
516
|
)
|
|
517
|
+
if self.task_corpus is None:
|
|
518
|
+
raise ValueError("Corpus must be indexed before searching.")
|
|
388
519
|
|
|
389
520
|
query_id_to_idx = {row["id"]: i for i, row in enumerate(queries)}
|
|
390
521
|
doc_id_to_idx = {doc["id"]: idx for idx, doc in enumerate(self.task_corpus)}
|
|
@@ -394,7 +525,8 @@ class SearchCrossEncoderWrapper:
|
|
|
394
525
|
doc_pairs_ids: list[tuple[str, str]] = []
|
|
395
526
|
for query_id, corpus_ids in top_ranked.items():
|
|
396
527
|
if query_id not in top_ranked:
|
|
397
|
-
|
|
528
|
+
msg = f"No pre-ranked documents found for query {query_id}"
|
|
529
|
+
logger.warning(msg)
|
|
398
530
|
continue
|
|
399
531
|
|
|
400
532
|
query_idx = query_id_to_idx[query_id]
|
|
@@ -407,13 +539,13 @@ class SearchCrossEncoderWrapper:
|
|
|
407
539
|
Dataset.from_list(total_queries),
|
|
408
540
|
task_metadata,
|
|
409
541
|
prompt_type=PromptType.document,
|
|
410
|
-
|
|
542
|
+
**encode_kwargs,
|
|
411
543
|
)
|
|
412
544
|
corpus_loader = create_dataloader(
|
|
413
545
|
Dataset.from_list(total_docs),
|
|
414
546
|
task_metadata,
|
|
415
547
|
prompt_type=PromptType.document,
|
|
416
|
-
|
|
548
|
+
**encode_kwargs,
|
|
417
549
|
)
|
|
418
550
|
predictions = self.model.predict(
|
|
419
551
|
inputs1=queries_loader,
|
|
@@ -423,7 +555,7 @@ class SearchCrossEncoderWrapper:
|
|
|
423
555
|
hf_subset=hf_subset,
|
|
424
556
|
)
|
|
425
557
|
|
|
426
|
-
results = {qid: {} for qid in queries["id"]}
|
|
558
|
+
results: RetrievalOutputType = {qid: {} for qid in queries["id"]}
|
|
427
559
|
for (query_id, corpus_id), score in zip(doc_pairs_ids, predictions):
|
|
428
560
|
results[query_id][corpus_id] = float(score)
|
|
429
561
|
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
+
import warnings
|
|
4
5
|
from typing import TYPE_CHECKING, Any
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import torch
|
|
8
9
|
from packaging.version import Version
|
|
9
10
|
from torch.utils.data import DataLoader
|
|
11
|
+
from typing_extensions import Unpack
|
|
10
12
|
|
|
11
13
|
from mteb._log_once import LogOnce
|
|
12
14
|
from mteb.models import ModelMeta
|
|
13
|
-
from mteb.types import Array, BatchedInput, PromptType
|
|
15
|
+
from mteb.types import Array, BatchedInput, EncodeKwargs, PromptType
|
|
14
16
|
|
|
15
17
|
from .abs_encoder import AbsEncoder
|
|
16
18
|
|
|
@@ -25,17 +27,18 @@ SENTENCE_TRANSFORMERS_QUERY_ENCODE_VERSION = "5.0.0"
|
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
def sentence_transformers_loader(
|
|
28
|
-
model_name: str, revision: str | None = None, **kwargs
|
|
30
|
+
model_name: str, revision: str | None = None, device: str | None = None, **kwargs
|
|
29
31
|
) -> SentenceTransformerEncoderWrapper:
|
|
30
32
|
"""Loads a SentenceTransformer model and wraps it in a SentenceTransformerEncoderWrapper.
|
|
31
33
|
|
|
32
34
|
Args:
|
|
33
35
|
model_name: The name of the SentenceTransformer model to load.
|
|
34
36
|
revision: The revision of the model to load.
|
|
37
|
+
device: The device used to load the model.
|
|
35
38
|
kwargs: Additional arguments to pass to the SentenceTransformer model.
|
|
36
39
|
"""
|
|
37
40
|
return SentenceTransformerEncoderWrapper(
|
|
38
|
-
model=model_name, revision=revision, **kwargs
|
|
41
|
+
model=model_name, revision=revision, device=device, **kwargs
|
|
39
42
|
)
|
|
40
43
|
|
|
41
44
|
|
|
@@ -48,6 +51,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
48
51
|
self,
|
|
49
52
|
model: str | SentenceTransformer,
|
|
50
53
|
revision: str | None = None,
|
|
54
|
+
device: str | None = None,
|
|
51
55
|
model_prompts: dict[str, str] | None = None,
|
|
52
56
|
**kwargs,
|
|
53
57
|
) -> None:
|
|
@@ -56,6 +60,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
56
60
|
Args:
|
|
57
61
|
model: The SentenceTransformer model to use. Can be a string (model name), a SentenceTransformer model, or a CrossEncoder model.
|
|
58
62
|
revision: The revision of the model to use.
|
|
63
|
+
device: The device used to load the model.
|
|
59
64
|
model_prompts: A dictionary mapping task names to prompt names.
|
|
60
65
|
First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt,
|
|
61
66
|
then to the composed prompt of task type + prompt type, then to the specific task type prompt,
|
|
@@ -65,22 +70,21 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
65
70
|
from sentence_transformers import SentenceTransformer
|
|
66
71
|
|
|
67
72
|
if isinstance(model, str):
|
|
68
|
-
self.model = SentenceTransformer(
|
|
73
|
+
self.model = SentenceTransformer(
|
|
74
|
+
model, revision=revision, device=device, **kwargs
|
|
75
|
+
)
|
|
69
76
|
else:
|
|
70
77
|
self.model = model
|
|
71
|
-
from mteb.models.get_model_meta import (
|
|
72
|
-
_model_meta_from_sentence_transformers,
|
|
73
|
-
)
|
|
74
78
|
|
|
75
|
-
self.mteb_model_meta =
|
|
79
|
+
self.mteb_model_meta = ModelMeta.from_sentence_transformer_model(self.model)
|
|
76
80
|
|
|
77
81
|
built_in_prompts = getattr(self.model, "prompts", None)
|
|
78
82
|
if built_in_prompts and not model_prompts:
|
|
79
83
|
model_prompts = built_in_prompts
|
|
80
84
|
elif model_prompts and built_in_prompts:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
)
|
|
85
|
+
msg = f"Model prompts specified, these will overwrite the default model prompts. Current prompts will be:\n {model_prompts}"
|
|
86
|
+
logger.warning(msg)
|
|
87
|
+
warnings.warn(msg)
|
|
84
88
|
self.model.prompts = model_prompts
|
|
85
89
|
|
|
86
90
|
self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name(
|
|
@@ -89,9 +93,9 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
89
93
|
|
|
90
94
|
if invalid_prompts:
|
|
91
95
|
invalid_prompts = "\n".join(invalid_prompts)
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
)
|
|
96
|
+
msg = f"Some prompts are not in the expected format and will be ignored. Problems:\n\n{invalid_prompts}"
|
|
97
|
+
logger.warning(msg)
|
|
98
|
+
warnings.warn(msg)
|
|
95
99
|
|
|
96
100
|
if (
|
|
97
101
|
self.model_prompts
|
|
@@ -101,13 +105,15 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
101
105
|
or PromptType.document.value not in self.model_prompts
|
|
102
106
|
)
|
|
103
107
|
):
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
)
|
|
108
|
+
msg = f"SentenceTransformers that use prompts most often need to be configured with at least 'query' and 'document' prompts to ensure optimal performance. Received {self.model_prompts}"
|
|
109
|
+
logger.warning(msg)
|
|
110
|
+
warnings.warn(msg)
|
|
108
111
|
|
|
112
|
+
def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:
|
|
113
|
+
"""Compute the similarity between two collections of embeddings."""
|
|
109
114
|
if hasattr(self.model, "similarity") and callable(self.model.similarity):
|
|
110
|
-
|
|
115
|
+
return self.model.similarity(embeddings1, embeddings2)
|
|
116
|
+
return super().similarity(embeddings1, embeddings2)
|
|
111
117
|
|
|
112
118
|
def encode(
|
|
113
119
|
self,
|
|
@@ -117,7 +123,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
117
123
|
hf_split: str,
|
|
118
124
|
hf_subset: str,
|
|
119
125
|
prompt_type: PromptType | None = None,
|
|
120
|
-
**kwargs:
|
|
126
|
+
**kwargs: Unpack[EncodeKwargs],
|
|
121
127
|
) -> Array:
|
|
122
128
|
"""Encodes the given sentences using the encoder.
|
|
123
129
|
|
|
@@ -153,7 +159,7 @@ class SentenceTransformerEncoderWrapper(AbsEncoder):
|
|
|
153
159
|
prompt_name = None
|
|
154
160
|
if self.model_prompts is not None:
|
|
155
161
|
prompt_name = self.get_prompt_name(task_metadata, prompt_type)
|
|
156
|
-
prompt = self.model_prompts.get(prompt_name, None)
|
|
162
|
+
prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
|
|
157
163
|
if prompt_name:
|
|
158
164
|
prompt_log = f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
|
|
159
165
|
else:
|
|
@@ -196,7 +202,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
196
202
|
hf_split: str,
|
|
197
203
|
hf_subset: str,
|
|
198
204
|
prompt_type: PromptType | None = None,
|
|
199
|
-
**kwargs:
|
|
205
|
+
**kwargs: Unpack[EncodeKwargs],
|
|
200
206
|
) -> Array:
|
|
201
207
|
"""Encodes the given sentences using the encoder.
|
|
202
208
|
|
|
@@ -224,7 +230,7 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
224
230
|
prompt_name = None
|
|
225
231
|
if self.model_prompts is not None:
|
|
226
232
|
prompt_name = self.get_prompt_name(task_metadata, prompt_type)
|
|
227
|
-
prompt = self.model_prompts.get(prompt_name, None)
|
|
233
|
+
prompt = self.model_prompts.get(prompt_name, None) # type: ignore[arg-type]
|
|
228
234
|
if prompt_name:
|
|
229
235
|
logger.info(
|
|
230
236
|
f"Using {prompt_name=} for task={task_metadata.name} {prompt_type=} with {prompt=}"
|
|
@@ -237,7 +243,9 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
237
243
|
all_embeddings = []
|
|
238
244
|
for batch in inputs:
|
|
239
245
|
batch_column = next(iter(batch.keys()))
|
|
240
|
-
batched_input
|
|
246
|
+
batched_input: list[dict[str, Any]] = [
|
|
247
|
+
dict() for _ in range(len(batch[batch_column]))
|
|
248
|
+
]
|
|
241
249
|
|
|
242
250
|
# transform from {"text": [text1, text2], "image": [image1, image2]} to
|
|
243
251
|
# [{"text": text1, "image": image1}, {"text": text2, "image": image2}]
|
|
@@ -258,24 +266,36 @@ class SentenceTransformerMultimodalEncoderWrapper(SentenceTransformerEncoderWrap
|
|
|
258
266
|
|
|
259
267
|
|
|
260
268
|
class CrossEncoderWrapper:
|
|
261
|
-
"""Wrapper for CrossEncoder models.
|
|
269
|
+
"""Wrapper for CrossEncoder models.
|
|
270
|
+
|
|
271
|
+
Args:
|
|
272
|
+
model: The CrossEncoder model to use. Can be a string (model name) or a CrossEncoder model.
|
|
273
|
+
revision: The revision of the model to use.
|
|
274
|
+
device: The device used to load the model.
|
|
275
|
+
query_prefix: A prefix to add to all queries.
|
|
276
|
+
passage_prefix: A prefix to add to all passages.
|
|
277
|
+
**kwargs: Additional arguments to pass to the CrossEncoder model.
|
|
278
|
+
"""
|
|
262
279
|
|
|
263
280
|
def __init__(
|
|
264
281
|
self,
|
|
265
282
|
model: CrossEncoder | str,
|
|
266
283
|
revision: str | None = None,
|
|
284
|
+
device: str | None = None,
|
|
285
|
+
query_prefix: str = "",
|
|
286
|
+
passage_prefix: str = "",
|
|
267
287
|
**kwargs,
|
|
268
288
|
) -> None:
|
|
269
289
|
from sentence_transformers import CrossEncoder
|
|
270
290
|
|
|
271
|
-
from mteb.models.get_model_meta import _model_meta_from_cross_encoder
|
|
272
|
-
|
|
273
291
|
if isinstance(model, CrossEncoder):
|
|
274
292
|
self.model = model
|
|
275
293
|
elif isinstance(model, str):
|
|
276
|
-
self.model = CrossEncoder(model, revision=revision, **kwargs)
|
|
294
|
+
self.model = CrossEncoder(model, revision=revision, device=device, **kwargs)
|
|
277
295
|
|
|
278
|
-
self.mteb_model_meta =
|
|
296
|
+
self.mteb_model_meta = ModelMeta.from_cross_encoder(self.model)
|
|
297
|
+
self.query_prefix = query_prefix
|
|
298
|
+
self.passage_prefix = passage_prefix
|
|
279
299
|
|
|
280
300
|
def predict(
|
|
281
301
|
self,
|
|
@@ -286,7 +306,7 @@ class CrossEncoderWrapper:
|
|
|
286
306
|
hf_split: str,
|
|
287
307
|
hf_subset: str,
|
|
288
308
|
prompt_type: PromptType | None = None,
|
|
289
|
-
**kwargs:
|
|
309
|
+
**kwargs: Unpack[EncodeKwargs],
|
|
290
310
|
) -> Array:
|
|
291
311
|
"""Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.
|
|
292
312
|
|
|
@@ -304,10 +324,10 @@ class CrossEncoderWrapper:
|
|
|
304
324
|
The predicted relevance scores for each inputs pair.
|
|
305
325
|
"""
|
|
306
326
|
all_queries_with_instructions = [
|
|
307
|
-
text for batch in inputs1 for text in batch["text"]
|
|
327
|
+
self.query_prefix + text for batch in inputs1 for text in batch["text"]
|
|
308
328
|
]
|
|
309
329
|
all_corpus_with_instructions = [
|
|
310
|
-
text for batch in inputs2 for text in batch["text"]
|
|
330
|
+
self.passage_prefix + text for batch in inputs2 for text in batch["text"]
|
|
311
331
|
]
|
|
312
332
|
|
|
313
333
|
return self.model.predict(
|