mteb 2.7.4__py3-none-any.whl → 2.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/_create_dataloaders.py +47 -5
- mteb/_evaluators/any_sts_evaluator.py +2 -0
- mteb/_evaluators/clustering_evaluator.py +2 -0
- mteb/_evaluators/evaluator.py +2 -1
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +8 -1
- mteb/_evaluators/pair_classification_evaluator.py +3 -0
- mteb/_evaluators/retrieval_evaluator.py +3 -0
- mteb/_evaluators/sklearn_evaluator.py +6 -1
- mteb/_evaluators/text/bitext_mining_evaluator.py +2 -0
- mteb/_evaluators/text/summarization_evaluator.py +2 -0
- mteb/_evaluators/zeroshot_classification_evaluator.py +2 -0
- mteb/abstasks/abstask.py +31 -12
- mteb/abstasks/classification.py +10 -3
- mteb/abstasks/clustering.py +6 -2
- mteb/abstasks/clustering_legacy.py +8 -2
- mteb/abstasks/image/image_text_pair_classification.py +6 -2
- mteb/abstasks/multilabel_classification.py +2 -0
- mteb/abstasks/pair_classification.py +8 -2
- mteb/abstasks/retrieval.py +26 -11
- mteb/abstasks/retrieval_dataset_loaders.py +29 -19
- mteb/abstasks/sts.py +10 -3
- mteb/abstasks/text/bitext_mining.py +9 -5
- mteb/abstasks/text/reranking.py +2 -2
- mteb/abstasks/text/summarization.py +2 -1
- mteb/abstasks/zeroshot_classification.py +8 -2
- mteb/evaluate.py +13 -2
- mteb/models/model_implementations/bm25.py +2 -0
- mteb/models/model_implementations/pylate_models.py +10 -0
- mteb/models/models_protocols.py +4 -0
- mteb/models/search_wrappers.py +12 -0
- mteb/tasks/bitext_mining/eng/pub_chem_smiles_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/fas/fa_mteb_summary_retrieval.py +3 -3
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/norwegian_courts_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +2 -2
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -1
- mteb/tasks/classification/bul/bulgarian_store_review_sentiment_classfication.py +1 -1
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -1
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -1
- mteb/tasks/classification/ell/greek_legal_code_classification.py +1 -1
- mteb/tasks/classification/eng/dbpedia_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_conversations_classification.py +2 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +1 -1
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -1
- mteb/tasks/classification/eng/yelp_review_full_classification.py +2 -2
- mteb/tasks/classification/est/estonian_valence.py +1 -1
- mteb/tasks/classification/fas/fa_mteb_classification.py +6 -6
- mteb/tasks/classification/fas/persian_food_sentiment_classification.py +1 -1
- mteb/tasks/classification/fil/filipino_shopee_reviews_classification.py +1 -1
- mteb/tasks/classification/fin/fin_toxicity_classification.py +1 -1
- mteb/tasks/classification/fra/french_book_reviews.py +2 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -1
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -1
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -1
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +2 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -1
- mteb/tasks/classification/ita/dado_eval_coarse_classification.py +1 -1
- mteb/tasks/classification/ita/ita_casehold_classification.py +1 -1
- mteb/tasks/classification/ita/sardi_stance_classification.py +1 -1
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -1
- mteb/tasks/classification/jpn/wrime_classification.py +1 -1
- mteb/tasks/classification/kan/kannada_news_classification.py +2 -2
- mteb/tasks/classification/kor/klue_tc.py +2 -2
- mteb/tasks/classification/kor/kor_fin.py +1 -1
- mteb/tasks/classification/kor/kor_hate_classification.py +1 -1
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +1 -1
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -1
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/afri_senti_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -1
- mteb/tasks/classification/multilingual/cyrillic_turkic_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_nlp_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/masakha_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -1
- mteb/tasks/classification/multilingual/multilingual_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/scala_classification.py +1 -1
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -1
- mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
- mteb/tasks/classification/multilingual/tweet_sentiment_classification.py +1 -1
- mteb/tasks/classification/nep/nepali_news_classification.py +2 -2
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +1 -1
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +1 -1
- mteb/tasks/classification/ory/odia_news_classification.py +2 -2
- mteb/tasks/classification/pan/punjabi_news_classification.py +1 -1
- mteb/tasks/classification/ron/moroco.py +1 -1
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -1
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -1
- mteb/tasks/classification/rus/georeview_classification.py +1 -1
- mteb/tasks/classification/rus/headline_classification.py +2 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +2 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +2 -2
- mteb/tasks/classification/rus/ru_sci_bench_grnti_classification.py +1 -1
- mteb/tasks/classification/rus/ru_sci_bench_oecd_classification.py +1 -1
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -1
- mteb/tasks/classification/san/sanskrit_shlokas_classification.py +1 -1
- mteb/tasks/classification/sin/sinhala_news_classification.py +2 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +2 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -1
- mteb/tasks/classification/spa/spanish_news_classification.py +2 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -1
- mteb/tasks/classification/tam/tamil_news_classification.py +2 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +2 -2
- mteb/tasks/classification/tha/wongnai_reviews_classification.py +1 -1
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +2 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -2
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -1
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -1
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +2 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_p2p.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/arxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/arxiv_hierarchical_clustering.py +2 -2
- mteb/tasks/clustering/eng/big_patent_clustering.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/twenty_newsgroups_clustering.py +1 -1
- mteb/tasks/clustering/fas/fa_mteb_clustering.py +4 -4
- mteb/tasks/clustering/fra/hal_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -1
- mteb/tasks/clustering/multilingual/wiki_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nob/snl_clustering.py +1 -1
- mteb/tasks/clustering/nob/vg_clustering.py +1 -1
- mteb/tasks/clustering/pol/polish_clustering.py +3 -3
- mteb/tasks/clustering/rus/ru_sci_bench_grnti_clustering_p2p.py +1 -1
- mteb/tasks/clustering/rus/ru_sci_bench_oecd_clustering_p2p.py +1 -1
- mteb/tasks/clustering/zho/cmteb_clustering.py +4 -4
- mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
- mteb/tasks/multichoice/eng/cv_bench.py +4 -4
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -1
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -1
- mteb/tasks/multilabel_classification/rus/ru_toixic_multilabelclassification_okmlcup.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -1
- mteb/tasks/pair_classification/ara/ar_entail.py +1 -1
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -1
- mteb/tasks/pair_classification/deu/false_friends_de_en_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_ai_sentence_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_synonym_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_wiki_paragraphs_pc.py +1 -1
- mteb/tasks/pair_classification/eng/sprint_duplicate_questions_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_sem_eval2015_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_url_corpus_pc.py +1 -1
- mteb/tasks/pair_classification/fas/fa_mteb_pair_classification.py +5 -5
- mteb/tasks/pair_classification/fas/fars_tail.py +2 -2
- mteb/tasks/pair_classification/hye/armenian_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/ita/dis_co_tex_pair_classification.py +1 -1
- mteb/tasks/pair_classification/kor/klue_nli.py +1 -1
- mteb/tasks/pair_classification/multilingual/rte3.py +2 -2
- mteb/tasks/pair_classification/multilingual/xnli.py +1 -1
- mteb/tasks/pair_classification/pol/polish_pc.py +4 -4
- mteb/tasks/pair_classification/por/assin2_rte.py +1 -1
- mteb/tasks/pair_classification/por/sick_br_pc.py +1 -1
- mteb/tasks/pair_classification/rus/terra.py +2 -2
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -1
- mteb/tasks/pair_classification/zho/cmteb_pair_classification.py +2 -2
- mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +4 -4
- mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
- mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
- mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
- mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
- mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
- mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +1 -1
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
- mteb/tasks/retrieval/eng/bright_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/ml_questions.py +1 -1
- mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
- mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
- mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
- mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
- mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -2
- mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +14 -4
- mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nob/norquad.py +1 -1
- mteb/tasks/retrieval/nob/snl_retrieval.py +1 -1
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
- mteb/tasks/sts/fao/faroese_sts.py +1 -1
- mteb/tasks/sts/fra/sick_fr_sts.py +1 -1
- mteb/tasks/sts/kor/klue_sts.py +1 -1
- mteb/tasks/sts/por/sick_br_sts.py +1 -1
- mteb/tasks/sts/rus/ru_para_phraser_sts.py +1 -1
- mteb/tasks/zeroshot_classification/eng/sci_mmir.py +1 -1
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/METADATA +1 -1
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/RECORD +287 -287
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/WHEEL +0 -0
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/entry_points.txt +0 -0
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/top_level.txt +0 -0
mteb/_create_dataloaders.py
CHANGED
@@ -30,6 +30,7 @@ logger = logging.getLogger(__name__)
 def _create_dataloader_from_texts(
     text: list[str],
     batch_size: int = 32,
+    num_proc: int = 1,
     **kwargs: Any,
 ) -> DataLoader[TextInput]:
     """Create a dataloader from a list of text.
@@ -37,15 +38,17 @@ def _create_dataloader_from_texts(
     Args:
         text: A list of text to create a dataloader from.
         batch_size: Batch size for the dataloader.
+        num_proc: Number of processes to use.
         kwargs: Not used, present catching extra arguments.

     Returns:
         A dataloader with the text.
     """
     dataset = Dataset.from_dict({"text": text})
     return DataLoader(
         dataset,
         batch_size=batch_size,
+        num_workers=num_proc if num_proc > 1 else 0,
     )


@@ -71,20 +74,27 @@ def _corpus_to_dict(
 def _create_dataloader_for_retrieval_corpus(
     dataset: Dataset,
     batch_size: int = 32,
+    num_proc: int = 1,
 ) -> DataLoader[CorpusInput]:
     """Create a dataloader from a corpus.

     Args:
         dataset: Corpus
         batch_size: Batch size for the dataloader.
+        num_proc: Number of processes to use.

     Returns:
         A dataloader with the corpus.
     """
-    new_ds = dataset.map(_corpus_to_dict, desc="Converting corpus dict")
-    return DataLoader(
+    new_ds = dataset.map(
+        _corpus_to_dict,
+        desc="Converting corpus dict",
+        num_proc=num_proc,
+    )
+    return DataLoader(
         new_ds,
         batch_size=batch_size,
+        num_workers=num_proc if num_proc > 1 else 0,
     )


@@ -101,12 +111,14 @@ def _combine_queries_with_instruction_text(row: dict[str, str]) -> dict[str, str]:
 def _create_text_dataloader_for_queries(
     queries: QueryDatasetType,
     batch_size: int = 32,
+    num_proc: int = 1,
 ) -> DataLoader[QueryInput]:
     """Create a dataloader from a list of queries.

     Args:
         queries: A list of queries.
         batch_size: Batch size for the dataloader.
+        num_proc: Number of processes to use.

     Returns:
         A dataloader with the queries.
@@ -114,10 +126,12 @@ def _create_text_dataloader_for_queries(
     queries = queries.map(
         _combine_queries_with_instruction_text,
         desc="Processing queries for dataloading",
+        num_proc=num_proc,
     )
     return DataLoader(
         queries,
         batch_size=batch_size,
+        num_workers=num_proc if num_proc > 1 else 0,
     )


@@ -186,12 +200,14 @@ def _convert_conv_history_to_query(
 def _create_dataloader_for_queries_conversation(
     queries: QueryDatasetType,
     batch_size: int = 32,
+    num_proc: int = 1,
 ) -> DataLoader[QueryInput]:
     """Create a dataloader from a list of queries.

     Args:
         queries: A list of queries.
         batch_size: Batch size for the dataloader.
+        num_proc: Number of processes to use.

     Returns:
         A dataloader with the queries.
@@ -200,9 +216,11 @@ def _create_dataloader_for_queries_conversation(
         queries.map(
             _convert_conv_history_to_query,
             desc="Converting conversations to queries",
+            num_proc=num_proc,
         ),
         collate_fn=_custom_collate_fn,
         batch_size=batch_size,
+        num_workers=num_proc if num_proc > 1 else 0,
     )


@@ -247,6 +265,7 @@ def _prepare_image_dataset(
     dataset: Dataset,
     image_column_name: str | None = None,
     transform: Callable[[Any], Any] | None = None,
+    num_proc: int = 1,
 ) -> Dataset:
     """Prepare the image dataset by converting images to RGB and applying transformations."""
     if (
@@ -262,6 +281,7 @@ def _prepare_image_dataset(
         _convert_images_to_rgb,
         fn_kwargs={"image_col_name": "image", "transform": transform},
         desc="Converting images to RGB",
+        num_proc=num_proc,
     )


@@ -295,6 +315,7 @@ def _create_image_dataloader(
     batch_size: int = 32,
     transform: Callable[[Any], Any] | None = None,
     collate_fn: Callable[[list[dict[str, Any]]], dict[str, Any]] = _custom_collate_fn,
+    num_proc: int = 1,
 ) -> DataLoader[ImageInput]:
     """Creates a DataLoader with the image dataset prepared using the explicit transformation.

@@ -304,33 +325,41 @@ def _create_image_dataloader(
         batch_size: Batch size for the dataloader.
         transform: A transformation function to apply to each image (e.g., converting to tensor).
         collate_fn: A custom collate function to handle batching.
+        num_proc: Number of processes to use.

     Returns:
         A DataLoader with the image dataset.
     """
     dataset = _prepare_image_dataset(
-        dataset, image_column_name, transform
+        dataset,
+        image_column_name,
+        transform,
+        num_proc=num_proc,
     ).select_columns(["image"])
     return DataLoader(
         dataset,
         batch_size=batch_size,
         collate_fn=collate_fn,
         shuffle=False,
+        num_workers=num_proc if num_proc > 1 else 0,
     )


 def _create_text_queries_dataloader(
     dataset: Dataset,
     batch_size: int = 32,
+    num_proc: int = 1,
 ) -> DataLoader[QueryInput]:
     if not isinstance(dataset["text"][0], list):
         return _create_text_dataloader_for_queries(
             dataset,
             batch_size=batch_size,
+            num_proc=num_proc,
         )
     return _create_dataloader_for_queries_conversation(
         dataset,
         batch_size=batch_size,
+        num_proc=num_proc,
     )


@@ -339,6 +368,7 @@ def _create_queries_dataloader(
     task_metadata: TaskMetadata,
     input_column: str | None = None,
     batch_size: int = 32,
+    num_proc: int = 1,
 ) -> DataLoader[QueryInput | ImageInput]:
     """Create a dataloader for queries."""
     queries_type = task_metadata.get_modalities(PromptType.query)
@@ -346,12 +376,14 @@
         return _create_text_queries_dataloader(
             dataset,
             batch_size=batch_size,
+            num_proc=num_proc,
         )
     if "image" in queries_type:  # contains image
         return _create_image_dataloader(
             dataset,
             image_column_name="image",
             batch_size=batch_size,
+            num_proc=num_proc,
         )
     raise ValueError(f"Can't handle queries type {queries_type}")

@@ -361,6 +393,7 @@ def _create_document_dataloader(
     task_metadata: TaskMetadata,
     input_column: str | None = None,
     batch_size: int = 32,
+    num_proc: int = 1,
 ) -> DataLoader[CorpusInput | ImageInput]:
     """Create a dataloader for documents.

@@ -369,6 +402,7 @@
         task_metadata: Metadata of the task to determine the document type.
         input_column: The column to use as input. If None, it will use the first column that matches the modality.
         batch_size: Batch size for the dataloader.
+        num_proc: Number of processes to use.

     Returns:
         A dataloader for the documents.
@@ -378,12 +412,14 @@
         return _create_dataloader_for_retrieval_corpus(
             dataset,
             batch_size=batch_size,
+            num_proc=num_proc,
         )
     if "image" in document_type:  # contains image
         return _create_image_dataloader(
             dataset,
             image_column_name="image",
             batch_size=batch_size,
+            num_proc=num_proc,
         )
     raise ValueError(f"Can't handle queries type {document_type}")

@@ -394,6 +430,7 @@ def create_dataloader(
     prompt_type: PromptType | None = None,
     input_column: str | None = None,
     batch_size: int = 32,
+    num_proc: int = 1,
     **kwargs: Any,
 ) -> DataLoader[BatchedInput]:
     """Create a dataloader from a dataset.
@@ -407,6 +444,7 @@
         prompt_type: The type of prompt to create a dataloader for. If None, it will be inferred from the task metadata.
         input_column: The column to use as input. If None, it will use the first column that matches the modality.
         batch_size: The batch size for the dataloader.
+        num_proc: The number of processes to use for dataset processing.
         **kwargs: Additional arguments to pass to the dataloader creation functions.

     Returns:
@@ -418,6 +456,7 @@
             task_metadata,
             batch_size=batch_size,
             input_column=input_column,
+            num_proc=num_proc,
         )
     if prompt_type == PromptType.document:
         return _create_document_dataloader(
@@ -425,6 +464,7 @@
             task_metadata,
             input_column=input_column,
             batch_size=batch_size,
+            num_proc=num_proc,
         )

     if "image" in task_metadata.modalities:
@@ -432,6 +472,7 @@
             dataset,
             image_column_name=input_column,
             batch_size=batch_size,
+            num_proc=num_proc,
         )
     if "text" in task_metadata.modalities and input_column is not None:
         return _create_dataloader_from_texts(
@@ -441,4 +482,5 @@
     return DataLoader(
         dataset,
         batch_size=batch_size,
+        num_workers=num_proc if num_proc > 1 else 0,
     )
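The change in this file is mechanical but worth reading as one pattern: every dataloader helper now accepts `num_proc` and forwards it twice, once to `datasets.Dataset.map` for parallel preprocessing and once to the `DataLoader` as `num_workers`, with `num_proc if num_proc > 1 else 0` preserving the single-process default. A minimal sketch of the same pattern outside mteb (the transform and data are hypothetical stand-ins):

```python
from datasets import Dataset
from torch.utils.data import DataLoader


def add_prefix(row: dict) -> dict:
    # Hypothetical transform standing in for helpers like _corpus_to_dict.
    row["text"] = "doc: " + row["text"]
    return row


num_proc = 2  # 1 keeps all work in the main process
ds = Dataset.from_dict({"text": ["a", "b", "c", "d"]})
ds = ds.map(add_prefix, num_proc=num_proc)  # preprocessing fans out across processes
loader = DataLoader(
    ds,
    batch_size=2,
    num_workers=num_proc if num_proc > 1 else 0,  # 0 = load in the main process
)
```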
mteb/_evaluators/any_sts_evaluator.py
CHANGED
@@ -66,6 +66,7 @@ class AnySTSEvaluator(Evaluator):
         model: EncoderProtocol,
         *,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> STSEvaluatorScores:
         logger.info("Running semantic similarity - Encoding samples (1/2)")
         embeddings1 = model.encode(
@@ -73,6 +74,7 @@ class AnySTSEvaluator(Evaluator):
                 self.dataset,
                 self.task_metadata,
                 input_column=self.input_columns[0],
+                num_proc=num_proc,
                 **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
mteb/_evaluators/clustering_evaluator.py
CHANGED
@@ -45,11 +45,13 @@ class ClusteringEvaluator(Evaluator):
         model: EncoderProtocol,
         *,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> list[int]:
         data_loader = create_dataloader(
             self.dataset,
             self.task_metadata,
             input_column=self.input_column_name,
+            num_proc=num_proc,
             **encode_kwargs,
         )
mteb/_evaluators/evaluator.py
CHANGED
@@ -24,7 +24,7 @@ class Evaluator(ABC):

     @abstractmethod
     def __call__(
-        self, model: EncoderProtocol, *, encode_kwargs: EncodeKwargs
+        self, model: EncoderProtocol, *, encode_kwargs: EncodeKwargs, num_proc: int = 1
     ) -> Mapping[str, float] | Iterable[Any]:
         """This is called during training to evaluate the model.

@@ -33,5 +33,6 @@ class Evaluator(ABC):
         Args:
             model: the model to evaluate
             encode_kwargs: kwargs to pass to the model's encode method
+            num_proc: number of processes to use for data loading
         """
         pass
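Because `__call__` gives `num_proc` a default of 1, existing subclasses and callers keep working unchanged; only evaluators that want parallel data loading need to accept and forward it. A sketch of a conforming subclass (hypothetical class and trivial return value, module path as in this diff):

```python
from collections.abc import Mapping
from typing import Any

from mteb._evaluators.evaluator import Evaluator


class NoopEvaluator(Evaluator):
    """Hypothetical evaluator illustrating the widened signature."""

    def __call__(
        self, model: Any, *, encode_kwargs: dict[str, Any], num_proc: int = 1
    ) -> Mapping[str, float]:
        # A real evaluator would forward num_proc to its dataloader helpers here.
        return {"main_score": 0.0}
```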
mteb/_evaluators/image/imagetext_pairclassification_evaluator.py
CHANGED
@@ -91,6 +91,7 @@ class ImageTextPairClassificationEvaluator(Evaluator):
         model: EncoderProtocol,
         *,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> list[torch.Tensor]:
         images = []
         if isinstance(self.images_column_names, str):
@@ -113,6 +114,7 @@ class ImageTextPairClassificationEvaluator(Evaluator):
         text_embeddings = model.encode(
             _create_dataloader_from_texts(
                 texts,
+                num_proc=num_proc,
                 **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
@@ -129,10 +131,15 @@ class ImageTextPairClassificationEvaluator(Evaluator):
             dim=-1,
         ).view(len(self.dataset), self.num_texts_per_sample, -1)

+        def _image_collate_fn(batch):
+            """Collate function for image batches."""
+            return {"image": [item["image"] for item in batch]}
+
         image_embeddings = model.encode(
             DataLoader(
                 CustomImageDataset(images),
-                collate_fn=
+                collate_fn=_image_collate_fn,
+                num_workers=num_proc if num_proc > 1 else 0,
             ),
             task_metadata=self.task_metadata,
             hf_subset=self.hf_subset,
mteb/_evaluators/pair_classification_evaluator.py
CHANGED
@@ -91,6 +91,7 @@ class PairClassificationEvaluator(Evaluator):
         self,
         model: EncoderProtocol,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> PairClassificationDistances:
         logger.info("Running pair classification - Encoding samples (1/2)")
         embeddings1 = model.encode(
@@ -98,6 +99,7 @@ class PairClassificationEvaluator(Evaluator):
                 self.dataset,
                 task_metadata=self.task_metadata,
                 input_column=self.input1_column_name,
+                num_proc=num_proc,
                 **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
@@ -112,6 +114,7 @@ class PairClassificationEvaluator(Evaluator):
                 self.dataset,
                 task_metadata=self.task_metadata,
                 input_column=self.input2_column_name,
+                num_proc=num_proc,
                 **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
mteb/_evaluators/retrieval_evaluator.py
CHANGED
@@ -55,6 +55,7 @@ class RetrievalEvaluator(Evaluator):
         self,
         search_model: SearchProtocol,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> RetrievalOutputType:
         logger.info("Running retrieval task - Indexing corpus...")
         search_model.index(
@@ -63,6 +64,7 @@ class RetrievalEvaluator(Evaluator):
             hf_split=self.hf_split,
             hf_subset=self.hf_subset,
             encode_kwargs=encode_kwargs,
+            num_proc=num_proc,
         )
         logger.info("Running retrieval task - Searching queries...")
         return search_model.search(
@@ -73,6 +75,7 @@ class RetrievalEvaluator(Evaluator):
             hf_subset=self.hf_subset,
             encode_kwargs=encode_kwargs,
             top_ranked=self.top_ranked,
+            num_proc=num_proc,
         )

     def evaluate(
mteb/_evaluators/sklearn_evaluator.py
CHANGED
@@ -54,18 +54,20 @@ class SklearnEvaluator(Evaluator):
         self.evaluator_model = evaluator_model

     def create_dataloaders(
-        self, encode_kwargs: EncodeKwargs
+        self, encode_kwargs: EncodeKwargs, num_proc: int
     ) -> tuple[DataLoader[BatchedInput], DataLoader[BatchedInput]]:
         dataloader_train = create_dataloader(
             self.train_dataset,
             self.task_metadata,
             input_column=self.values_column_name,
+            num_proc=num_proc,
             **encode_kwargs,
         )
         dataloader_test = create_dataloader(
             self.eval_dataset,
             self.task_metadata,
             input_column=self.values_column_name,
+            num_proc=num_proc,
             **encode_kwargs,
         )
         return dataloader_train, dataloader_test
@@ -76,6 +78,7 @@ class SklearnEvaluator(Evaluator):
         *,
         encode_kwargs: EncodeKwargs,
         test_cache: Array | None = None,
+        num_proc: int = 1,
     ) -> tuple[np.ndarray, Array]:
         """Classification evaluation by training a sklearn classifier on the embeddings of the training set and evaluating on the embeddings of the test set.

@@ -83,6 +86,7 @@ class SklearnEvaluator(Evaluator):
             model: Encoder
             encode_kwargs: encode kwargs
             test_cache: embeddings of the test set, if already computed
+            num_proc: number of processes to use

         Returns:
             Tuple of test predictions and embeddings
@@ -90,6 +94,7 @@ class SklearnEvaluator(Evaluator):
         """
         dataloader_train, dataloader_test = self.create_dataloaders(
             encode_kwargs=encode_kwargs,
+            num_proc=num_proc,
         )

         logger.info("Running - Encoding samples...")
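Note that, unlike the other helpers in this release, `create_dataloaders` gains `num_proc` without a default, so code calling it directly must now pass the argument explicitly; the evaluator's own `__call__` supplies it from its `num_proc=1` default. A sketch of a direct call (assuming `evaluator` is an already-constructed SklearnEvaluator):

```python
train_loader, test_loader = evaluator.create_dataloaders(
    encode_kwargs={}, num_proc=1
)
```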
mteb/_evaluators/text/bitext_mining_evaluator.py
CHANGED
@@ -41,6 +41,7 @@ class BitextMiningEvaluator(Evaluator):
         model: EncoderProtocol,
         *,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> dict[str, list[dict[str, float]]]:
         pair_elements = {p for pair in self.pairs for p in pair}
         if isinstance(self.sentences, Dataset):
@@ -55,6 +56,7 @@ class BitextMiningEvaluator(Evaluator):
         for sub in tqdm(subsets):
             dataloader = _create_dataloader_from_texts(
                 self.sentences[sub],
+                num_proc=num_proc,
                 **encode_kwargs,
             )
             embeddings[sub] = model.encode(
mteb/_evaluators/text/summarization_evaluator.py
CHANGED
@@ -100,6 +100,7 @@ class SummarizationEvaluator(Evaluator):
         model: EncoderProtocol,
         *,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> SummarizationDistances:
         # Get the human & machine summaries for the text in one go for all
         human_lens = [len(human_summaries) for human_summaries in self.human_summaries]
@@ -115,6 +116,7 @@ class SummarizationEvaluator(Evaluator):
                     for human_summaries in self.human_summaries
                     for summary in human_summaries
                 ],
+                num_proc=num_proc,
                 **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
mteb/_evaluators/zeroshot_classification_evaluator.py
CHANGED
@@ -48,11 +48,13 @@ class ZeroShotClassificationEvaluator(Evaluator):
         model: EncoderProtocol,
         *,
         encode_kwargs: EncodeKwargs,
+        num_proc: int = 1,
     ) -> Array:
         dataloader = create_dataloader(
             self.dataset,
             input_column=self.input_column_name,
             task_metadata=self.task_metadata,
+            num_proc=num_proc,
             **encode_kwargs,
         )
mteb/abstasks/abstask.py
CHANGED
@@ -116,11 +116,14 @@ class AbsTask(ABC):
             logger.warning(msg)
             warnings.warn(msg)

-    def dataset_transform(self):
+    def dataset_transform(self, num_proc: int = 1):
         """A transform operations applied to the dataset after loading.

         This method is useful when the dataset from Huggingface is not in an `mteb` compatible format.
         Override this method if your dataset requires additional transformation.
+
+        Args:
+            num_proc: Number of processes to use for the transformation.
         """
         pass

@@ -132,6 +135,7 @@ class AbsTask(ABC):
         *,
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
+        num_proc: int = 1,
         **kwargs: Any,
     ) -> Mapping[HFSubset, ScoresDict]:
         """Evaluates an MTEB compatible model on the task.
@@ -142,6 +146,7 @@ class AbsTask(ABC):
             subsets_to_run: List of huggingface subsets (HFSubsets) to evaluate. If None, all subsets are evaluated.
             encode_kwargs: Additional keyword arguments that are passed to the model's `encode` method.
             prediction_folder: Folder to save model predictions
+            num_proc: Number of processes to use for loading the dataset or processing.
             kwargs: Additional keyword arguments that are passed to the _evaluate_subset method.

         Returns:
@@ -197,6 +202,7 @@ class AbsTask(ABC):
                 hf_subset=hf_subset,
                 encode_kwargs=encode_kwargs,
                 prediction_folder=prediction_folder,
+                num_proc=num_proc,
                 **kwargs,
             )
             self._add_main_score(scores[hf_subset])
@@ -212,6 +218,7 @@ class AbsTask(ABC):
         hf_subset: str,
         encode_kwargs: EncodeKwargs,
         prediction_folder: Path | None = None,
+        num_proc: int = 1,
         **kwargs: Any,
     ) -> ScoresDict:
         raise NotImplementedError(
@@ -316,11 +323,15 @@ class AbsTask(ABC):
         )  # only take the specified test split.
         return dataset_dict

-    def load_data(self) -> None:
+    def load_data(self, num_proc: int = 1, **kwargs: Any) -> None:
         """Loads dataset from HuggingFace hub

         This is the main loading function for Task. Do not overwrite this, instead we recommend using `dataset_transform`, which is called after the
         dataset is loaded using `datasets.load_dataset`.
+
+        Args:
+            num_proc: Number of processes to use for loading the dataset.
+            kwargs: Additional keyword arguments passed to the load_dataset function. Keep for forward compatibility.
         """
         if self.data_loaded:
             return
@@ -333,11 +344,12 @@ class AbsTask(ABC):
                 self.dataset[hf_subset] = load_dataset(
                     name=hf_subset,
                     **self.metadata.dataset,
+                    num_proc=num_proc,
                 )
         else:
             # some of monolingual datasets explicitly adding the split name to the dataset name
-            self.dataset = load_dataset(**self.metadata.dataset)
-        self.dataset_transform()
+            self.dataset = load_dataset(**self.metadata.dataset, num_proc=num_proc)
+        self.dataset_transform(num_proc=num_proc)
         self.data_loaded = True

     def fast_load(self) -> None:
@@ -360,12 +372,13 @@ class AbsTask(ABC):
             self.dataset[lang] = DatasetDict(subset)

     def calculate_descriptive_statistics(
-        self, overwrite_results: bool = False
+        self, overwrite_results: bool = False, num_proc: int = 1
     ) -> dict[str, DescriptiveStatistics]:
         """Calculates descriptive statistics from the dataset.

         Args:
             overwrite_results: Whether to overwrite existing results. If False and results already exist, the existing results will be loaded from cache.
+            num_proc: Number of processes to use for loading the dataset.

         Returns:
             A dictionary containing descriptive statistics for each split.
@@ -379,7 +392,7 @@ class AbsTask(ABC):
             return existing_stats

         if not self.data_loaded:
-            self.load_data()
+            self.load_data(num_proc=num_proc)

         descriptive_stats: dict[str, DescriptiveStatistics] = {}
         hf_subset_stat: Literal["hf_subset_descriptive_stats"] = (
@@ -517,7 +530,7 @@ class AbsTask(ABC):
         scores["main_score"] = scores[self.metadata.main_score]

     def _upload_dataset_to_hub(
-        self, repo_name: str, fields: list[str] | dict[str, str]
+        self, repo_name: str, fields: list[str] | dict[str, str], num_proc: int = 1
     ) -> None:
         if self.dataset is None:
             raise ValueError("Dataset not loaded")
@@ -542,7 +555,10 @@ class AbsTask(ABC):
                 )
                 sentences = DatasetDict(sentences)
                 sentences.push_to_hub(
-                    repo_name, config, commit_message=f"Add {config} dataset"
+                    repo_name,
+                    config,
+                    commit_message=f"Add {config} dataset",
+                    num_proc=num_proc,
                 )
         else:
             sentences = {}
@@ -559,16 +575,19 @@ class AbsTask(ABC):
                     {field: self.dataset[split][field] for field in fields}
                 )
             sentences = DatasetDict(sentences)
-            sentences.push_to_hub(repo_name, commit_message="Add dataset")
+            sentences.push_to_hub(
+                repo_name, commit_message="Add dataset", num_proc=num_proc
+            )

-    def _push_dataset_to_hub(self, repo_name: str) -> None:
+    def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
         raise NotImplementedError

-    def push_dataset_to_hub(self, repo_name: str) -> None:
+    def push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
         """Push the dataset to the HuggingFace Hub.

         Args:
             repo_name: The name of the repository to push the dataset to.
+            num_proc: Number of processes to use for loading the dataset.

         Examples:
             >>> import mteb
@@ -580,7 +599,7 @@ class AbsTask(ABC):
         if not self.data_loaded:
             self.load_data()

-        self._push_dataset_to_hub(repo_name)
+        self._push_dataset_to_hub(repo_name, num_proc)
         # dataset repo not creating when pushing card
         self.metadata.push_dataset_card_to_hub(repo_name)
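Taken together with the dataloader changes above, the new keyword can now be threaded from task-level entry points all the way down to `datasets`. A short usage sketch (the task name is an arbitrary example drawn from this package's task list):

```python
import mteb

task = mteb.get_task("BUCC.v2")  # any registered task name works here
task.load_data(num_proc=4)  # parallel datasets.load_dataset + dataset_transform
stats = task.calculate_descriptive_statistics(num_proc=4)
```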