mteb 2.7.2__py3-none-any.whl → 2.7.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/_create_dataloaders.py +63 -14
- mteb/_evaluators/any_sts_evaluator.py +12 -5
- mteb/_evaluators/clustering_evaluator.py +12 -4
- mteb/_evaluators/evaluator.py +11 -5
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +14 -5
- mteb/_evaluators/pair_classification_evaluator.py +13 -5
- mteb/_evaluators/retrieval_evaluator.py +22 -13
- mteb/_evaluators/retrieval_metrics.py +9 -3
- mteb/_evaluators/sklearn_evaluator.py +20 -11
- mteb/_evaluators/text/bitext_mining_evaluator.py +10 -3
- mteb/_evaluators/text/summarization_evaluator.py +10 -4
- mteb/_evaluators/zeroshot_classification_evaluator.py +12 -3
- mteb/_helpful_enum.py +5 -1
- mteb/abstasks/_data_filter/filters.py +8 -2
- mteb/abstasks/_data_filter/task_pipelines.py +7 -2
- mteb/abstasks/_statistics_calculation.py +6 -4
- mteb/abstasks/abstask.py +48 -21
- mteb/abstasks/aggregate_task_metadata.py +20 -9
- mteb/abstasks/aggregated_task.py +15 -8
- mteb/abstasks/classification.py +25 -9
- mteb/abstasks/clustering.py +23 -10
- mteb/abstasks/clustering_legacy.py +22 -8
- mteb/abstasks/image/image_text_pair_classification.py +23 -9
- mteb/abstasks/multilabel_classification.py +13 -5
- mteb/abstasks/pair_classification.py +27 -11
- mteb/abstasks/regression.py +14 -6
- mteb/abstasks/retrieval.py +56 -30
- mteb/abstasks/retrieval_dataset_loaders.py +48 -37
- mteb/abstasks/sts.py +29 -13
- mteb/abstasks/task_metadata.py +17 -8
- mteb/abstasks/text/bitext_mining.py +23 -12
- mteb/abstasks/text/reranking.py +2 -2
- mteb/abstasks/text/summarization.py +19 -8
- mteb/abstasks/zeroshot_classification.py +23 -9
- mteb/benchmarks/_create_table.py +13 -7
- mteb/benchmarks/benchmark.py +11 -1
- mteb/benchmarks/benchmarks/__init__.py +2 -0
- mteb/benchmarks/benchmarks/benchmarks.py +41 -2
- mteb/benchmarks/benchmarks/rteb_benchmarks.py +20 -9
- mteb/cache.py +10 -5
- mteb/cli/_display_tasks.py +9 -3
- mteb/cli/build_cli.py +5 -2
- mteb/cli/generate_model_card.py +9 -2
- mteb/deprecated_evaluator.py +16 -12
- mteb/descriptive_stats/Retrieval/BrightAopsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightBiologyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightBiologyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEarthScienceLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEarthScienceRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEconomicsLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightEconomicsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightLeetcodeRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPonyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPonyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPsychologyLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightPsychologyRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightRoboticsLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightRoboticsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightStackoverflowLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightStackoverflowRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightSustainableLivingLongRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightSustainableLivingRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightTheoremQAQuestionsRetrieval.json +35 -0
- mteb/descriptive_stats/Retrieval/BrightTheoremQATheoremsRetrieval.json +35 -0
- mteb/evaluate.py +33 -20
- mteb/filter_tasks.py +12 -7
- mteb/get_tasks.py +9 -4
- mteb/languages/language_scripts.py +8 -3
- mteb/leaderboard/app.py +11 -4
- mteb/leaderboard/table.py +7 -2
- mteb/load_results.py +9 -3
- mteb/models/abs_encoder.py +22 -12
- mteb/models/cache_wrappers/cache_backend_protocol.py +5 -3
- mteb/models/cache_wrappers/cache_backends/_hash_utils.py +8 -4
- mteb/models/cache_wrappers/cache_backends/faiss_cache.py +8 -3
- mteb/models/cache_wrappers/cache_wrapper.py +14 -9
- mteb/models/get_model_meta.py +32 -6
- mteb/models/instruct_wrapper.py +13 -5
- mteb/models/model_implementations/align_models.py +10 -4
- mteb/models/model_implementations/amazon_models.py +1 -0
- mteb/models/model_implementations/andersborges.py +2 -0
- mteb/models/model_implementations/ara_models.py +1 -0
- mteb/models/model_implementations/arctic_models.py +8 -0
- mteb/models/model_implementations/b1ade_models.py +1 -0
- mteb/models/model_implementations/bedrock_models.py +20 -6
- mteb/models/model_implementations/bge_models.py +40 -1
- mteb/models/model_implementations/bica_model.py +1 -0
- mteb/models/model_implementations/blip2_models.py +11 -4
- mteb/models/model_implementations/blip_models.py +17 -4
- mteb/models/model_implementations/bm25.py +24 -14
- mteb/models/model_implementations/bmretriever_models.py +10 -2
- mteb/models/model_implementations/cadet_models.py +1 -0
- mteb/models/model_implementations/cde_models.py +11 -5
- mteb/models/model_implementations/clip_models.py +12 -4
- mteb/models/model_implementations/clips_models.py +3 -0
- mteb/models/model_implementations/codefuse_models.py +5 -0
- mteb/models/model_implementations/codesage_models.py +3 -0
- mteb/models/model_implementations/cohere_models.py +14 -4
- mteb/models/model_implementations/cohere_v.py +14 -4
- mteb/models/model_implementations/colpali_models.py +7 -3
- mteb/models/model_implementations/colqwen_models.py +17 -31
- mteb/models/model_implementations/colsmol_models.py +3 -1
- mteb/models/model_implementations/conan_models.py +11 -4
- mteb/models/model_implementations/dino_models.py +28 -4
- mteb/models/model_implementations/e5_instruct.py +4 -0
- mteb/models/model_implementations/e5_models.py +9 -0
- mteb/models/model_implementations/e5_v.py +10 -4
- mteb/models/model_implementations/eagerworks_models.py +11 -4
- mteb/models/model_implementations/emillykkejensen_models.py +3 -0
- mteb/models/model_implementations/en_code_retriever.py +1 -0
- mteb/models/model_implementations/euler_models.py +1 -0
- mteb/models/model_implementations/evaclip_models.py +13 -4
- mteb/models/model_implementations/fa_models.py +9 -0
- mteb/models/model_implementations/facebookai.py +2 -0
- mteb/models/model_implementations/geogpt_models.py +1 -0
- mteb/models/model_implementations/gme_v_models.py +7 -3
- mteb/models/model_implementations/google_models.py +15 -4
- mteb/models/model_implementations/granite_vision_embedding_models.py +7 -5
- mteb/models/model_implementations/gritlm_models.py +3 -0
- mteb/models/model_implementations/gte_models.py +9 -0
- mteb/models/model_implementations/hinvec_models.py +6 -1
- mteb/models/model_implementations/human.py +1 -0
- mteb/models/model_implementations/ibm_granite_models.py +6 -0
- mteb/models/model_implementations/inf_models.py +2 -0
- mteb/models/model_implementations/jasper_models.py +14 -5
- mteb/models/model_implementations/jina_clip.py +10 -4
- mteb/models/model_implementations/jina_models.py +17 -5
- mteb/models/model_implementations/kalm_models.py +24 -12
- mteb/models/model_implementations/kblab.py +1 -0
- mteb/models/model_implementations/kennethenevoldsen_models.py +2 -0
- mteb/models/model_implementations/kfst.py +1 -0
- mteb/models/model_implementations/kowshik24_models.py +1 -0
- mteb/models/model_implementations/lens_models.py +2 -0
- mteb/models/model_implementations/lgai_embedding_models.py +1 -0
- mteb/models/model_implementations/linq_models.py +7 -1
- mteb/models/model_implementations/listconranker.py +10 -4
- mteb/models/model_implementations/llm2clip_models.py +12 -4
- mteb/models/model_implementations/llm2vec_models.py +20 -6
- mteb/models/model_implementations/mcinext_models.py +8 -2
- mteb/models/model_implementations/mdbr_models.py +2 -0
- mteb/models/model_implementations/misc_models.py +63 -0
- mteb/models/model_implementations/mixedbread_ai_models.py +3 -0
- mteb/models/model_implementations/mme5_models.py +2 -1
- mteb/models/model_implementations/moco_models.py +11 -4
- mteb/models/model_implementations/mod_models.py +2 -1
- mteb/models/model_implementations/model2vec_models.py +23 -4
- mteb/models/model_implementations/moka_models.py +3 -0
- mteb/models/model_implementations/nbailab.py +3 -0
- mteb/models/model_implementations/no_instruct_sentence_models.py +13 -5
- mteb/models/model_implementations/nomic_models.py +17 -4
- mteb/models/model_implementations/nomic_models_vision.py +5 -3
- mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +9 -3
- mteb/models/model_implementations/nvidia_models.py +15 -4
- mteb/models/model_implementations/octen_models.py +3 -1
- mteb/models/model_implementations/openai_models.py +14 -4
- mteb/models/model_implementations/openclip_models.py +17 -4
- mteb/models/model_implementations/opensearch_neural_sparse_models.py +15 -4
- mteb/models/model_implementations/ops_moa_models.py +9 -2
- mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -0
- mteb/models/model_implementations/pawan_models.py +1 -0
- mteb/models/model_implementations/piccolo_models.py +2 -0
- mteb/models/model_implementations/promptriever_models.py +16 -6
- mteb/models/model_implementations/pylate_models.py +32 -13
- mteb/models/model_implementations/qodo_models.py +2 -0
- mteb/models/model_implementations/qtack_models.py +1 -0
- mteb/models/model_implementations/qwen3_models.py +11 -1
- mteb/models/model_implementations/qzhou_models.py +2 -0
- mteb/models/model_implementations/random_baseline.py +4 -3
- mteb/models/model_implementations/rasgaard_models.py +1 -0
- mteb/models/model_implementations/reasonir_model.py +65 -0
- mteb/models/model_implementations/repllama_models.py +15 -6
- mteb/models/model_implementations/rerankers_custom.py +13 -4
- mteb/models/model_implementations/rerankers_monot5_based.py +24 -4
- mteb/models/model_implementations/richinfoai_models.py +1 -0
- mteb/models/model_implementations/ru_sentence_models.py +20 -0
- mteb/models/model_implementations/ruri_models.py +10 -0
- mteb/models/model_implementations/salesforce_models.py +10 -1
- mteb/models/model_implementations/samilpwc_models.py +1 -0
- mteb/models/model_implementations/sarashina_embedding_models.py +2 -0
- mteb/models/model_implementations/searchmap_models.py +1 -0
- mteb/models/model_implementations/seed_1_6_embedding_models.py +5 -2
- mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +6 -2
- mteb/models/model_implementations/seed_models.py +2 -1
- mteb/models/model_implementations/sentence_transformers_models.py +18 -0
- mteb/models/model_implementations/shuu_model.py +1 -0
- mteb/models/model_implementations/siglip_models.py +19 -4
- mteb/models/model_implementations/slm_models.py +7 -4
- mteb/models/model_implementations/sonar_models.py +2 -1
- mteb/models/model_implementations/spartan8806_atles_champion.py +1 -0
- mteb/models/model_implementations/stella_models.py +6 -0
- mteb/models/model_implementations/tarka_models.py +2 -0
- mteb/models/model_implementations/text2vec_models.py +3 -0
- mteb/models/model_implementations/ua_sentence_models.py +1 -0
- mteb/models/model_implementations/uae_models.py +10 -4
- mteb/models/model_implementations/vdr_models.py +8 -1
- mteb/models/model_implementations/vi_vn_models.py +6 -0
- mteb/models/model_implementations/vista_models.py +11 -4
- mteb/models/model_implementations/vlm2vec_models.py +11 -4
- mteb/models/model_implementations/voyage_models.py +52 -4
- mteb/models/model_implementations/voyage_v.py +11 -6
- mteb/models/model_implementations/xyz_models.py +1 -0
- mteb/models/model_implementations/youtu_models.py +1 -0
- mteb/models/model_implementations/yuan_models.py +1 -0
- mteb/models/model_implementations/yuan_models_en.py +2 -1
- mteb/models/model_meta.py +47 -9
- mteb/models/models_protocols.py +23 -18
- mteb/models/search_encoder_index/search_backend_protocol.py +7 -3
- mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +12 -4
- mteb/models/search_wrappers.py +31 -12
- mteb/models/sentence_transformer_wrapper.py +4 -3
- mteb/models/vllm_wrapper.py +8 -6
- mteb/results/benchmark_results.py +22 -17
- mteb/results/model_result.py +21 -15
- mteb/results/task_result.py +32 -16
- mteb/similarity_functions.py +8 -2
- mteb/tasks/aggregated_tasks/eng/cqadupstack_retrieval.py +3 -3
- mteb/tasks/aggregated_tasks/eng/sts17_multilingual_visual_sts_eng.py +3 -3
- mteb/tasks/aggregated_tasks/eng/sts_benchmark_multilingual_visual_sts_eng.py +3 -3
- mteb/tasks/aggregated_tasks/fas/cqadupstack_retrieval_fa.py +3 -3
- mteb/tasks/aggregated_tasks/fas/syn_per_chatbot_conv_sa_classification.py +3 -3
- mteb/tasks/aggregated_tasks/multilingual/sts17_multilingual_vision_sts.py +3 -3
- mteb/tasks/aggregated_tasks/multilingual/sts_benchmark_multilingual_visual_sts.py +3 -3
- mteb/tasks/aggregated_tasks/nld/cqadupstack_nl_retrieval.py +3 -3
- mteb/tasks/aggregated_tasks/pol/cqadupstack_retrieval_pl.py +3 -3
- mteb/tasks/bitext_mining/eng/pub_chem_smiles_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/fas/fa_mteb_summary_retrieval.py +3 -3
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/norwegian_courts_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +2 -2
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -1
- mteb/tasks/classification/bul/bulgarian_store_review_sentiment_classfication.py +1 -1
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -1
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -1
- mteb/tasks/classification/ell/greek_legal_code_classification.py +1 -1
- mteb/tasks/classification/eng/dbpedia_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_conversations_classification.py +2 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +1 -1
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -1
- mteb/tasks/classification/eng/yelp_review_full_classification.py +2 -2
- mteb/tasks/classification/est/estonian_valence.py +1 -1
- mteb/tasks/classification/fas/fa_mteb_classification.py +6 -6
- mteb/tasks/classification/fas/persian_food_sentiment_classification.py +1 -1
- mteb/tasks/classification/fil/filipino_shopee_reviews_classification.py +1 -1
- mteb/tasks/classification/fin/fin_toxicity_classification.py +1 -1
- mteb/tasks/classification/fra/french_book_reviews.py +2 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -1
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -1
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -1
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +2 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -1
- mteb/tasks/classification/ita/dado_eval_coarse_classification.py +1 -1
- mteb/tasks/classification/ita/ita_casehold_classification.py +1 -1
- mteb/tasks/classification/ita/sardi_stance_classification.py +1 -1
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -1
- mteb/tasks/classification/jpn/wrime_classification.py +1 -1
- mteb/tasks/classification/kan/kannada_news_classification.py +2 -2
- mteb/tasks/classification/kor/klue_tc.py +2 -2
- mteb/tasks/classification/kor/kor_fin.py +1 -1
- mteb/tasks/classification/kor/kor_hate_classification.py +1 -1
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +1 -1
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -1
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/afri_senti_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -1
- mteb/tasks/classification/multilingual/cyrillic_turkic_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_nlp_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/masakha_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -1
- mteb/tasks/classification/multilingual/multilingual_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/scala_classification.py +1 -1
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -1
- mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
- mteb/tasks/classification/multilingual/tweet_sentiment_classification.py +1 -1
- mteb/tasks/classification/nep/nepali_news_classification.py +2 -2
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +1 -1
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +1 -1
- mteb/tasks/classification/ory/odia_news_classification.py +2 -2
- mteb/tasks/classification/pan/punjabi_news_classification.py +1 -1
- mteb/tasks/classification/ron/moroco.py +1 -1
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -1
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -1
- mteb/tasks/classification/rus/georeview_classification.py +1 -1
- mteb/tasks/classification/rus/headline_classification.py +2 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +2 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +2 -2
- mteb/tasks/classification/rus/ru_sci_bench_grnti_classification.py +1 -1
- mteb/tasks/classification/rus/ru_sci_bench_oecd_classification.py +1 -1
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -1
- mteb/tasks/classification/san/sanskrit_shlokas_classification.py +1 -1
- mteb/tasks/classification/sin/sinhala_news_classification.py +2 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +2 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -1
- mteb/tasks/classification/spa/spanish_news_classification.py +2 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -1
- mteb/tasks/classification/tam/tamil_news_classification.py +2 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +2 -2
- mteb/tasks/classification/tha/wongnai_reviews_classification.py +1 -1
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +2 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -2
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -1
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -1
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +2 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_p2p.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/arxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/arxiv_hierarchical_clustering.py +2 -2
- mteb/tasks/clustering/eng/big_patent_clustering.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/twenty_newsgroups_clustering.py +1 -1
- mteb/tasks/clustering/fas/fa_mteb_clustering.py +4 -4
- mteb/tasks/clustering/fra/hal_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -1
- mteb/tasks/clustering/multilingual/wiki_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nob/snl_clustering.py +8 -3
- mteb/tasks/clustering/nob/vg_clustering.py +8 -3
- mteb/tasks/clustering/pol/polish_clustering.py +3 -3
- mteb/tasks/clustering/rus/ru_sci_bench_grnti_clustering_p2p.py +1 -1
- mteb/tasks/clustering/rus/ru_sci_bench_oecd_clustering_p2p.py +1 -1
- mteb/tasks/clustering/zho/cmteb_clustering.py +4 -4
- mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
- mteb/tasks/multichoice/eng/cv_bench.py +4 -4
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -1
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -1
- mteb/tasks/multilabel_classification/rus/ru_toixic_multilabelclassification_okmlcup.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -1
- mteb/tasks/pair_classification/ara/ar_entail.py +1 -1
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -1
- mteb/tasks/pair_classification/deu/false_friends_de_en_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_ai_sentence_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +4 -3
- mteb/tasks/pair_classification/eng/pub_chem_synonym_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_wiki_paragraphs_pc.py +1 -1
- mteb/tasks/pair_classification/eng/sprint_duplicate_questions_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_sem_eval2015_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_url_corpus_pc.py +1 -1
- mteb/tasks/pair_classification/fas/fa_mteb_pair_classification.py +5 -5
- mteb/tasks/pair_classification/fas/fars_tail.py +2 -2
- mteb/tasks/pair_classification/hye/armenian_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/ita/dis_co_tex_pair_classification.py +1 -1
- mteb/tasks/pair_classification/kor/klue_nli.py +1 -1
- mteb/tasks/pair_classification/multilingual/rte3.py +2 -2
- mteb/tasks/pair_classification/multilingual/xnli.py +1 -1
- mteb/tasks/pair_classification/pol/polish_pc.py +4 -4
- mteb/tasks/pair_classification/por/assin2_rte.py +1 -1
- mteb/tasks/pair_classification/por/sick_br_pc.py +1 -1
- mteb/tasks/pair_classification/rus/terra.py +2 -2
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -1
- mteb/tasks/pair_classification/zho/cmteb_pair_classification.py +2 -2
- mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +4 -4
- mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
- mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
- mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
- mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
- mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
- mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +1 -1
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
- mteb/tasks/retrieval/eng/__init__.py +42 -0
- mteb/tasks/retrieval/eng/bright_retrieval.py +10 -2
- mteb/tasks/retrieval/eng/bright_v1_1_retrieval.py +968 -0
- mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/limit_retrieval.py +6 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/ml_questions.py +1 -1
- mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
- mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
- mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
- mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
- mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +5 -5
- mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +14 -4
- mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nob/norquad.py +1 -1
- mteb/tasks/retrieval/nob/snl_retrieval.py +1 -1
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
- mteb/tasks/sts/fao/faroese_sts.py +1 -1
- mteb/tasks/sts/fra/sick_fr_sts.py +1 -1
- mteb/tasks/sts/kor/klue_sts.py +1 -1
- mteb/tasks/sts/por/sick_br_sts.py +1 -1
- mteb/tasks/sts/rus/ru_para_phraser_sts.py +1 -1
- mteb/tasks/zeroshot_classification/eng/sci_mmir.py +1 -1
- mteb/types/_encoder_io.py +1 -1
- mteb/types/statistics.py +9 -2
- {mteb-2.7.2.dist-info → mteb-2.7.9.dist-info}/METADATA +1 -1
- {mteb-2.7.2.dist-info → mteb-2.7.9.dist-info}/RECORD +486 -465
- {mteb-2.7.2.dist-info → mteb-2.7.9.dist-info}/WHEEL +1 -1
- {mteb-2.7.2.dist-info → mteb-2.7.9.dist-info}/entry_points.txt +0 -0
- {mteb-2.7.2.dist-info → mteb-2.7.9.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.7.2.dist-info → mteb-2.7.9.dist-info}/top_level.txt +0 -0
mteb/abstasks/retrieval.py
CHANGED

```diff
@@ -1,13 +1,13 @@
+from __future__ import annotations
+
 import json
 import logging
 from collections import defaultdict
-from collections.abc import Callable, Mapping, Sequence
 from pathlib import Path
 from time import time
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
 from datasets import Dataset, DatasetDict, concatenate_datasets
-from typing_extensions import Self
 
 from mteb._create_dataloaders import (
     _combine_queries_with_instruction_text,
@@ -19,25 +19,12 @@ from mteb._evaluators.retrieval_metrics import make_score_dict
 from mteb.models import (
     CrossEncoderProtocol,
     EncoderProtocol,
-    MTEBModels,
     SearchCrossEncoderWrapper,
     SearchEncoderWrapper,
     SearchProtocol,
 )
-from mteb.types import (
-    EncodeKwargs,
-    HFSubset,
-    QueryDatasetType,
-    RelevantDocumentsType,
-    RetrievalOutputType,
-    ScoresDict,
-)
 from mteb.types.statistics import (
-    ImageStatistics,
-    RelevantDocsStatistics,
     SplitDescriptiveStatistics,
-    TextStatistics,
-    TopRankedStatistics,
 )
 
 from ._statistics_calculation import (
@@ -53,6 +40,30 @@ from .retrieval_dataset_loaders import (
     _combine_queries_with_instructions_datasets,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Callable, Mapping, Sequence
+
+    from typing_extensions import Self
+
+    from mteb.models import (
+        MTEBModels,
+    )
+    from mteb.types import (
+        EncodeKwargs,
+        HFSubset,
+        QueryDatasetType,
+        RelevantDocumentsType,
+        RetrievalOutputType,
+        ScoresDict,
+    )
+    from mteb.types.statistics import (
+        ImageStatistics,
+        RelevantDocsStatistics,
+        TextStatistics,
+        TopRankedStatistics,
+    )
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -137,7 +148,7 @@ class AbsTaskRetrieval(AbsTask):
         )
     )
 
-    def convert_v1_dataset_format_to_v2(self):
+    def convert_v1_dataset_format_to_v2(self, num_proc: int) -> None:
         """Convert dataset from v1 (from `self.queries`, `self.document`) format to v2 format (`self.dotaset`)."""
         # check if dataset is `v1` version
         if not hasattr(self, "queries"):
@@ -204,6 +215,7 @@ class AbsTaskRetrieval(AbsTask):
                 _combine_queries_with_instructions_datasets(
                     self.dataset[subset][split]["queries"],
                     instructions,
+                    num_proc,
                 )
             )
             if hasattr(self, "top_ranked"):
@@ -229,9 +241,10 @@ class AbsTaskRetrieval(AbsTask):
                 _combine_queries_with_instructions_datasets(
                     self.dataset[subset][split]["queries"],
                     instructions,
+                    num_proc,
                 )
             )
-            if hasattr(self, "top_ranked"):
+            if hasattr(self, "top_ranked") and self.top_ranked:
                 self.dataset[subset][split]["top_ranked"] = self.top_ranked[
                     split
                 ].copy()
@@ -244,13 +257,13 @@ class AbsTaskRetrieval(AbsTask):
         if hasattr(self, "top_ranked"):
             del self.top_ranked
 
-    def load_data(self) -> None:
+    def load_data(self, num_proc: int = 1, **kwargs) -> None:
        """Load the dataset for the retrieval task."""
         if self.data_loaded:
             return
 
         dataset_path = self.metadata.dataset["path"]
-        eval_splits = self.
+        eval_splits = self.eval_splits
         trust_remote_code = self.metadata.dataset.get("trust_remote_code", False)
         revision = self.metadata.dataset["revision"]
 
@@ -266,16 +279,18 @@ class AbsTaskRetrieval(AbsTask):
                 trust_remote_code=trust_remote_code,
                 split=split,
                 config=hf_subset,
-            ).load(
+            ).load(
+                num_proc=num_proc,
+            )
 
         if self.metadata.is_multilingual:
-            for lang in self.
+            for lang in self.hf_subsets:
                 for split in eval_splits:
                     _process_data(split, lang)
         else:
             for split in eval_splits:
                 _process_data(split)
-        self.dataset_transform()
+        self.dataset_transform(num_proc=num_proc)
         self.data_loaded = True
 
     def evaluate(
@@ -286,6 +301,7 @@ class AbsTaskRetrieval(AbsTask):
         *,
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
+        num_proc: int = 1,
         **kwargs: Any,
     ) -> Mapping[HFSubset, ScoresDict]:
         """Evaluate the model on the retrieval task.
@@ -297,16 +313,16 @@ class AbsTaskRetrieval(AbsTask):
             subsets_to_run: Optional list of subsets to evaluate on
             encode_kwargs: Keyword arguments passed to the encoder
             prediction_folder: Folder to save model predictions
+            num_proc: Number of processes to use
             **kwargs: Additional keyword arguments passed to the evaluator
 
-
         Returns:
             Dictionary mapping subsets to their evaluation scores
         """
         if not self.data_loaded:
-            self.load_data()
+            self.load_data(num_proc=num_proc)
         # TODO: convert all tasks directly https://github.com/embeddings-benchmark/mteb/issues/2030
-        self.convert_v1_dataset_format_to_v2()
+        self.convert_v1_dataset_format_to_v2(num_proc=num_proc)
 
         return super().evaluate(
             model,
@@ -314,6 +330,7 @@ class AbsTaskRetrieval(AbsTask):
             subsets_to_run,
             encode_kwargs=encode_kwargs,
             prediction_folder=prediction_folder,
+            num_proc=num_proc,
             **kwargs,
         )
 
@@ -325,6 +342,7 @@ class AbsTaskRetrieval(AbsTask):
         hf_split: str,
         hf_subset: str,
         prediction_folder: Path | None = None,
+        num_proc: int = 1,
         **kwargs,
     ) -> ScoresDict:
         """Evaluate a model on a specific subset of the data.
@@ -336,6 +354,7 @@ class AbsTaskRetrieval(AbsTask):
             hf_split: Split to evaluate on
             hf_subset: Subset to evaluate on
             prediction_folder: Folder with results prediction
+            num_proc: Number of processes to use
             **kwargs: Additional keyword arguments passed to the evaluator
 
         Returns:
@@ -375,6 +394,7 @@ class AbsTaskRetrieval(AbsTask):
         results = retriever(
             search_model,
             encode_kwargs=encode_kwargs,
+            num_proc=num_proc,
         )
         end_time = time()
         logger.debug(
@@ -449,9 +469,13 @@ class AbsTaskRetrieval(AbsTask):
             return {}
 
     def _calculate_descriptive_statistics_from_split(
-        self,
+        self,
+        split: str,
+        hf_subset: str | None = None,
+        compute_overall: bool = False,
+        num_proc: int = 1,
     ) -> RetrievalDescriptiveStatistics:
-        self.convert_v1_dataset_format_to_v2()
+        self.convert_v1_dataset_format_to_v2(num_proc)
         if hf_subset and hf_subset in self.dataset:
             split_data = self.dataset[hf_subset][split]
             queries = split_data["queries"]
@@ -556,8 +580,8 @@ class AbsTaskRetrieval(AbsTask):
             top_ranked_statistics=top_ranked_statistics,
         )
 
-    def _push_dataset_to_hub(self, repo_name: str) -> None:
-        self.convert_v1_dataset_format_to_v2()
+    def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
+        self.convert_v1_dataset_format_to_v2(num_proc)
 
         def _push_section(
             data: dict[str, RetrievalSplitData],
@@ -597,6 +621,7 @@ class AbsTaskRetrieval(AbsTask):
                 repo_name,
                 hf_subset_name,
                 commit_message=f"Add {hf_subset_name}-{subset_item}",
+                num_proc=num_proc,
             )
 
         for subset in self.dataset:
@@ -630,6 +655,7 @@ class AbsTaskRetrieval(AbsTask):
                 repo_name,
                 f"{subset}-qrels" if subset != "default" else "qrels",
                 commit_message=f"Add {subset}-qrels",
+                num_proc=num_proc,
             )
 
         _push_section(
```
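The functional thread running through these abstask diffs is a new `num_proc` parameter that `evaluate()` and `load_data()` forward down to the Hugging Face `datasets` calls (`load_dataset(..., num_proc=...)` and `Dataset.map(..., num_proc=...)`). A minimal usage sketch, assuming `mteb.get_task` and the `NFCorpus` task name are available; both are illustrative choices, not taken from this diff:

```python
import mteb

# Illustrative: any AbsTaskRetrieval subclass follows the same pattern.
task = mteb.get_task("NFCorpus")

# num_proc is threaded from load_data() through RetrievalDatasetLoader.load()
# into datasets.load_dataset and the query/instruction .map() calls, so the
# dataset preparation can run across several worker processes.
task.load_data(num_proc=4)
```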
mteb/abstasks/retrieval_dataset_loaders.py
CHANGED

```diff
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import TypedDict
+from typing import TYPE_CHECKING, TypedDict
 
 from datasets import (
     Dataset,
@@ -11,13 +13,14 @@ from datasets import (
     load_dataset,
 )
 
-from mteb.types import (
-    CorpusDatasetType,
-    InstructionDatasetType,
-    QueryDatasetType,
-    RelevantDocumentsType,
-    TopRankedDocumentsType,
-)
+if TYPE_CHECKING:
+    from mteb.types import (
+        CorpusDatasetType,
+        InstructionDatasetType,
+        QueryDatasetType,
+        RelevantDocumentsType,
+        TopRankedDocumentsType,
+    )
 
 logger = logging.getLogger(__name__)
 
@@ -73,28 +76,36 @@ class RetrievalDatasetLoader:
         self.config = config if config != "default" else None
         self.dataset_configs = get_dataset_config_names(self.hf_repo, self.revision)
 
-    def load(
+    def load(
+        self,
+        num_proc: int = 1,
+    ) -> RetrievalSplitData:
         """Loads the dataset split for the specified configuration.
 
+        Args:
+            num_proc: The number of processes to use.
+
         Returns:
             A dictionary containing the corpus, queries, relevant documents, instructions (if applicable), and top-ranked documents (if applicable).
         """
         top_ranked = None
 
-        qrels = self._load_qrels()
-        corpus = self._load_corpus()
-        queries = self._load_queries()
+        qrels = self._load_qrels(num_proc)
+        corpus = self._load_corpus(num_proc)
+        queries = self._load_queries(num_proc)
 
         queries = queries.filter(
             lambda x: x["id"] in qrels.keys(), desc="Filtering queries by qrels"
         )
 
         if any(c.endswith("top_ranked") for c in self.dataset_configs):
-            top_ranked = self._load_top_ranked()
+            top_ranked = self._load_top_ranked(num_proc)
 
         if any(c.endswith("instruction") for c in self.dataset_configs):
-            instructions = self._load_instructions()
-            queries = _combine_queries_with_instructions_datasets(
+            instructions = self._load_instructions(num_proc)
+            queries = _combine_queries_with_instructions_datasets(
+                queries, instructions, num_proc
+            )
 
         return RetrievalSplitData(
             corpus=corpus,
@@ -117,20 +128,21 @@ class RetrievalDatasetLoader:
                 f"Split {self.split} not found in {splits}. Please specify a valid split."
             )
 
-    def _load_dataset_split(self, config: str) -> Dataset:
+    def _load_dataset_split(self, config: str, num_proc: int) -> Dataset:
         return load_dataset(
             self.hf_repo,
             config,
             split=self._get_split(config),
             trust_remote_code=self.trust_remote_code,
             revision=self.revision,
+            num_proc=num_proc,
         )
 
-    def _load_corpus(self) -> CorpusDatasetType:
-        logger.info("Loading Corpus...")
-
+    def _load_corpus(self, num_proc: int) -> CorpusDatasetType:
         config = f"{self.config}-corpus" if self.config is not None else "corpus"
-        corpus_ds = self._load_dataset_split(config)
+        logger.info("Loading corpus subset: %s", config)
+
+        corpus_ds = self._load_dataset_split(config, num_proc)
         if "_id" in corpus_ds.column_names:
             corpus_ds = corpus_ds.cast_column("_id", Value("string")).rename_column(
                 "_id", "id"
@@ -139,13 +151,13 @@ class RetrievalDatasetLoader:
         logger.debug("Doc Example: %s", corpus_ds[0])
         return corpus_ds
 
-    def _load_queries(self) -> QueryDatasetType:
-        logger.info("Loading Queries...")
-
+    def _load_queries(self, num_proc: int) -> QueryDatasetType:
         config = f"{self.config}-queries" if self.config is not None else "queries"
+        logger.info("Loading queries subset: %s", config)
+
         if "query" in self.dataset_configs:
             config = "query"
-        queries_ds = self._load_dataset_split(config)
+        queries_ds = self._load_dataset_split(config, num_proc)
         if "_id" in queries_ds.column_names:
             queries_ds = queries_ds.cast_column("_id", Value("string")).rename_column(
                 "_id", "id"
@@ -156,10 +168,10 @@ class RetrievalDatasetLoader:
 
         return queries_ds
 
-    def _load_qrels(self) -> RelevantDocumentsType:
-        logger.info("Loading qrels...")
-
+    def _load_qrels(self, num_proc: int) -> RelevantDocumentsType:
         config = f"{self.config}-qrels" if self.config is not None else "default"
+
+        logger.info("Loading qrels subset: %s", config)
         if config == "default" and config not in self.dataset_configs:
             if "qrels" in self.dataset_configs:
                 config = "qrels"
@@ -168,7 +180,7 @@ class RetrievalDatasetLoader:
                     "No qrels or default config found. Please specify a valid config or ensure the dataset has qrels."
                 )
 
-        qrels_ds = self._load_dataset_split(config)
+        qrels_ds = self._load_dataset_split(config, num_proc)
         qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"])
 
         qrels_ds = qrels_ds.cast(
@@ -191,13 +203,12 @@ class RetrievalDatasetLoader:
         logger.info("Loaded %d %s qrels.", len(qrels_dict), self.split.upper())
         return qrels_dict
 
-    def _load_top_ranked(self) -> TopRankedDocumentsType:
-        logger.info("Loading Top Ranked")
-
+    def _load_top_ranked(self, num_proc: int) -> TopRankedDocumentsType:
         config = (
             f"{self.config}-top_ranked" if self.config is not None else "top_ranked"
         )
-        top_ranked_ds = self._load_dataset_split(config)
+        logger.info("Loading top ranked subset: %s", config)
+        top_ranked_ds = self._load_dataset_split(config, num_proc)
         top_ranked_ds = top_ranked_ds.cast(
             Features(
                 {
@@ -215,13 +226,12 @@ class RetrievalDatasetLoader:
         logger.info(f"Top ranked loaded: {len(top_ranked_ds)}")
         return top_ranked_dict
 
-    def _load_instructions(self) -> InstructionDatasetType:
-        logger.info("Loading Instructions")
-
+    def _load_instructions(self, num_proc: int) -> InstructionDatasetType:
         config = (
             f"{self.config}-instruction" if self.config is not None else "instruction"
         )
-        instructions_ds = self._load_dataset_split(config)
+        logger.info("Loading instruction subset: %s", config)
+        instructions_ds = self._load_dataset_split(config, num_proc)
         instructions_ds = instructions_ds.cast(
             Features(
                 {
@@ -236,6 +246,7 @@ class RetrievalDatasetLoader:
 def _combine_queries_with_instructions_datasets(
     queries_dataset: QueryDatasetType,
     instruction_dataset: InstructionDatasetType | dict[str, str],
+    num_proc: int,
 ) -> Dataset:
     if isinstance(instruction_dataset, Dataset):
         instruction_to_query_idx = {
@@ -248,4 +259,4 @@ def _combine_queries_with_instructions_datasets(
         row["instruction"] = instruction_to_query_idx[row["id"]]
         return row
 
-    return queries_dataset.map(_add_instruction_to_query)
+    return queries_dataset.map(_add_instruction_to_query, num_proc=num_proc)
```
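The import churn across these files follows one pattern: `from __future__ import annotations` is added and typing-only imports are moved under an `if TYPE_CHECKING:` guard, so they are resolved by static type checkers but never executed at runtime. A self-contained sketch of the pattern; the names below are placeholders, not mteb code:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; this avoids runtime import cost and
    # breaks import cycles between modules that only need each other's types.
    from collections.abc import Mapping


def summarize(counts: Mapping[str, int]) -> str:
    # With PEP 563 the annotation is a string at runtime, so Mapping does not
    # need to exist outside of type checking.
    return ", ".join(f"{key}={value}" for key, value in counts.items())


if __name__ == "__main__":
    print(summarize({"queries": 100, "corpus": 5000}))
```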
mteb/abstasks/sts.py
CHANGED

```diff
@@ -1,19 +1,14 @@
+from __future__ import annotations
+
 import logging
-from
-from typing import Any, TypedDict, cast
+from typing import TYPE_CHECKING, Any, TypedDict, cast
 
-from datasets import Dataset
 from scipy.stats import pearsonr, spearmanr
 
 from mteb._evaluators import AnySTSEvaluator
-from mteb.
-from mteb.models import EncoderProtocol, MTEBModels
-from mteb.types import EncodeKwargs, PromptType
+from mteb.models import EncoderProtocol
 from mteb.types.statistics import (
-    ImageStatistics,
-    ScoreStatistics,
     SplitDescriptiveStatistics,
-    TextStatistics,
 )
 
 from ._statistics_calculation import (
@@ -23,6 +18,20 @@ from ._statistics_calculation import (
 )
 from .abstask import AbsTask
 
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from datasets import Dataset
+
+    from mteb._evaluators.any_sts_evaluator import STSEvaluatorScores
+    from mteb.models import MTEBModels
+    from mteb.types import EncodeKwargs, PromptType
+    from mteb.types.statistics import (
+        ImageStatistics,
+        ScoreStatistics,
+        TextStatistics,
+    )
+
 logger = logging.getLogger(__name__)
 
 
@@ -109,6 +118,7 @@ class AbsTaskSTS(AbsTask):
         hf_split: str,
         hf_subset: str,
         prediction_folder: Path | None = None,
+        num_proc: int = 1,
         **kwargs: Any,
     ) -> STSMetrics:
         if not isinstance(model, EncoderProtocol):
@@ -127,7 +137,11 @@ class AbsTaskSTS(AbsTask):
             input2_prompt_type=self.input2_prompt_type,
             **kwargs,
         )
-        scores = evaluator(
+        scores = evaluator(
+            model,
+            encode_kwargs=encode_kwargs,
+            num_proc=num_proc,
+        )
 
         if prediction_folder:
             self._save_task_predictions(
@@ -182,7 +196,7 @@ class AbsTaskSTS(AbsTask):
         self, split: str, hf_subset: str | None = None, compute_overall: bool = False
     ) -> AnySTSDescriptiveStatistics:
         first_column, second_column = self.column_names
-        self.dataset = cast(dict[str, dict[str, Dataset]], self.dataset)
+        self.dataset = cast("dict[str, dict[str, Dataset]]", self.dataset)
 
         if hf_subset:
             sentence1 = self.dataset[hf_subset][split][first_column]
@@ -236,9 +250,11 @@ class AbsTaskSTS(AbsTask):
             label_statistics=labels_statistics,
         )
 
-    def _push_dataset_to_hub(self, repo_name: str) -> None:
+    def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
         self._upload_dataset_to_hub(
-            repo_name,
+            repo_name,
+            [self.column_names[0], self.column_names[1], "score"],
+            num_proc=num_proc,
         )
 
     def _normalize(self, x: float) -> float:
```
mteb/abstasks/task_metadata.py
CHANGED

```diff
@@ -1,11 +1,12 @@
+from __future__ import annotations
+
 import json
 import logging
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 from huggingface_hub import (
-    CardData,
     DatasetCard,
     DatasetCardData,
     constants,
@@ -17,13 +18,11 @@ from pydantic import (
     ConfigDict,
     field_validator,
 )
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, TypedDict  # noqa: TC002
 
 import mteb
 from mteb.languages import check_language_code
 from mteb.types import (
-    HFSubset,
-    ISOLanguageScript,
     Languages,
     Licenses,
     Modalities,
@@ -31,7 +30,17 @@ from mteb.types import (
     StrDate,
     StrURL,
 )
-from mteb.types.statistics import DescriptiveStatistics
+
+if TYPE_CHECKING:
+    from huggingface_hub import (
+        CardData,
+    )
+
+    from mteb.types import (
+        HFSubset,
+        ISOLanguageScript,
+    )
+    from mteb.types.statistics import DescriptiveStatistics
 
 logger = logging.getLogger(__name__)
 
@@ -368,7 +377,7 @@ class TaskMetadata(BaseModel):
         """Return a dictionary mapping huggingface subsets to languages."""
         if isinstance(self.eval_langs, dict):
             return self.eval_langs
-        return {"default": cast(list[str], self.eval_langs)}
+        return {"default": cast("list[str]", self.eval_langs)}
 
     @property
     def intext_citation(self, include_cite: bool = True) -> str:
@@ -697,7 +706,7 @@ class TaskMetadata(BaseModel):
             for val in self.eval_langs.values():
                 languages.extend(val)
         else:
-            languages = cast(list[str], self.eval_langs)
+            languages = cast("list[str]", self.eval_langs)
         # value "python" is not valid. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters),
         # or a special value like "code", "multilingual".
         readme_langs = []
```
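With imports moved behind `TYPE_CHECKING`, these diffs also switch `typing.cast` calls from a bare type to a quoted forward reference, e.g. `cast(list[str], x)` becomes `cast("list[str]", x)`; `cast` performs no runtime check either way, so the string form keeps working even when the named type is not imported at runtime. A small sketch of the idea, with placeholder names that are not mteb code:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    from datasets import Dataset  # imported for type checking only


def first_row(maybe_dataset: object) -> dict:
    # cast() simply returns its second argument at runtime; the quoted type is
    # a forward reference that only the type checker resolves, so Dataset does
    # not have to be importable when this function actually runs.
    dataset = cast("Dataset", maybe_dataset)
    return dataset[0]
```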
mteb/abstasks/text/bitext_mining.py
CHANGED

```diff
@@ -1,7 +1,8 @@
+from __future__ import annotations
+
 import logging
 from collections import defaultdict
-from
-from typing import Any, ClassVar, TypedDict, cast
+from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
 
 from datasets import Dataset, DatasetDict
 from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
@@ -9,9 +10,15 @@ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
 from mteb._evaluators import BitextMiningEvaluator
 from mteb.abstasks._statistics_calculation import calculate_text_statistics
 from mteb.abstasks.abstask import AbsTask
-from mteb.models import EncoderProtocol
-from mteb.types import
-
+from mteb.models import EncoderProtocol
+from mteb.types.statistics import SplitDescriptiveStatistics
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from mteb.models import MTEBModels
+    from mteb.types import EncodeKwargs, HFSubset, ScoresDict
+    from mteb.types.statistics import TextStatistics
 
 logger = logging.getLogger(__name__)
 
@@ -75,6 +82,7 @@ class AbsTaskBitextMining(AbsTask):
         *,
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
+        num_proc: int = 1,
         **kwargs: Any,
     ) -> dict[HFSubset, ScoresDict]:
         """Added load for "parallel" datasets"""
@@ -82,7 +90,7 @@ class AbsTaskBitextMining(AbsTask):
             raise TypeError("Expected model to be an instance of EncoderProtocol")
 
         if not self.data_loaded:
-            self.load_data()
+            self.load_data(num_proc=num_proc)
 
         hf_subsets = self.hf_subsets
 
@@ -90,7 +98,7 @@ class AbsTaskBitextMining(AbsTask):
         if subsets_to_run is not None:
             hf_subsets = [s for s in hf_subsets if s in subsets_to_run]
 
-        encoder_model = cast(EncoderProtocol, model)
+        encoder_model = cast("EncoderProtocol", model)
 
         if self.dataset is None:
             raise ValueError("Dataset is not loaded.")
@@ -105,6 +113,7 @@ class AbsTaskBitextMining(AbsTask):
                 hf_subset="parallel",
                 encode_kwargs=encode_kwargs,
                 prediction_folder=prediction_folder,
+                num_proc=num_proc,
                 **kwargs,
             )
         else:
@@ -124,10 +133,11 @@ class AbsTaskBitextMining(AbsTask):
                     hf_subset=hf_subset,
                     encode_kwargs=encode_kwargs,
                     prediction_folder=prediction_folder,
+                    num_proc=num_proc,
                     **kwargs,
                 )
 
-        return cast(dict[HFSubset, ScoresDict], scores)
+        return cast("dict[HFSubset, ScoresDict]", scores)
 
     def _get_pairs(self, parallel: bool) -> list[tuple[str, str]]:
         pairs = self._DEFAULT_PAIR
@@ -145,6 +155,7 @@ class AbsTaskBitextMining(AbsTask):
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
         parallel: bool = False,
+        num_proc: int = 1,
         **kwargs,
     ) -> BitextMiningMetrics | dict[str, BitextMiningMetrics]:
         pairs = self._get_pairs(parallel)
@@ -164,7 +175,7 @@ class AbsTaskBitextMining(AbsTask):
             else data_split["gold"]
         )
 
-        neighbours = evaluator(model, encode_kwargs=encode_kwargs)
+        neighbours = evaluator(model, encode_kwargs=encode_kwargs, num_proc=num_proc)
 
         if prediction_folder:
             self._save_task_predictions(
@@ -257,7 +268,7 @@ class AbsTaskBitextMining(AbsTask):
             sentence2_statistics=text2_statistics,
         )
 
-    def _push_dataset_to_hub(self, repo_name: str) -> None:
+    def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
         if self.dataset is None:
             raise ValueError("Dataset is not loaded.")
 
@@ -280,7 +291,7 @@ class AbsTaskBitextMining(AbsTask):
             dataset_dict = DatasetDict(
                 {split: Dataset.from_dict(dataset[split]) for split in dataset}
             )
-            dataset_dict.push_to_hub(repo_name)
+            dataset_dict.push_to_hub(repo_name, num_proc=num_proc)
         else:
             sentences = {}
             for split in self.dataset:
@@ -292,4 +303,4 @@ class AbsTaskBitextMining(AbsTask):
                 }
             )
             sentences = DatasetDict(sentences)
-            sentences.push_to_hub(repo_name)
+            sentences.push_to_hub(repo_name, num_proc=num_proc)
```