mteb 2.7.4__py3-none-any.whl → 2.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/_create_dataloaders.py +47 -5
- mteb/_evaluators/any_sts_evaluator.py +2 -0
- mteb/_evaluators/clustering_evaluator.py +2 -0
- mteb/_evaluators/evaluator.py +2 -1
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +8 -1
- mteb/_evaluators/pair_classification_evaluator.py +3 -0
- mteb/_evaluators/retrieval_evaluator.py +3 -0
- mteb/_evaluators/sklearn_evaluator.py +6 -1
- mteb/_evaluators/text/bitext_mining_evaluator.py +2 -0
- mteb/_evaluators/text/summarization_evaluator.py +2 -0
- mteb/_evaluators/zeroshot_classification_evaluator.py +2 -0
- mteb/abstasks/abstask.py +31 -12
- mteb/abstasks/classification.py +10 -3
- mteb/abstasks/clustering.py +6 -2
- mteb/abstasks/clustering_legacy.py +8 -2
- mteb/abstasks/image/image_text_pair_classification.py +6 -2
- mteb/abstasks/multilabel_classification.py +2 -0
- mteb/abstasks/pair_classification.py +8 -2
- mteb/abstasks/retrieval.py +26 -11
- mteb/abstasks/retrieval_dataset_loaders.py +29 -19
- mteb/abstasks/sts.py +10 -3
- mteb/abstasks/text/bitext_mining.py +9 -5
- mteb/abstasks/text/reranking.py +2 -2
- mteb/abstasks/text/summarization.py +2 -1
- mteb/abstasks/zeroshot_classification.py +8 -2
- mteb/evaluate.py +13 -2
- mteb/models/model_implementations/bm25.py +2 -0
- mteb/models/model_implementations/pylate_models.py +10 -0
- mteb/models/models_protocols.py +4 -0
- mteb/models/search_wrappers.py +12 -0
- mteb/tasks/bitext_mining/eng/pub_chem_smiles_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/fas/fa_mteb_summary_retrieval.py +3 -3
- mteb/tasks/bitext_mining/multilingual/bucc_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/norwegian_courts_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +2 -2
- mteb/tasks/bitext_mining/multilingual/web_faq_bitext_mining.py +2 -2
- mteb/tasks/classification/ara/online_store_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/restaurant_review_sentiment_classification.py +1 -1
- mteb/tasks/classification/ara/tweet_sarcasm_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_hate_speech_classification.py +1 -1
- mteb/tasks/classification/ben/bengali_sentiment_analysis.py +1 -1
- mteb/tasks/classification/bul/bulgarian_store_review_sentiment_classfication.py +1 -1
- mteb/tasks/classification/ces/csfdcz_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/dan/ddisco_cohesion_classification.py +1 -1
- mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
- mteb/tasks/classification/deu/german_politicians_twitter_sentiment_classification.py +1 -1
- mteb/tasks/classification/ell/greek_legal_code_classification.py +1 -1
- mteb/tasks/classification/eng/dbpedia_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_chat_classification.py +2 -2
- mteb/tasks/classification/eng/toxic_conversations_classification.py +2 -2
- mteb/tasks/classification/eng/tweet_topic_single_classification.py +1 -1
- mteb/tasks/classification/eng/yahoo_answers_topics_classification.py +1 -1
- mteb/tasks/classification/eng/yelp_review_full_classification.py +2 -2
- mteb/tasks/classification/est/estonian_valence.py +1 -1
- mteb/tasks/classification/fas/fa_mteb_classification.py +6 -6
- mteb/tasks/classification/fas/persian_food_sentiment_classification.py +1 -1
- mteb/tasks/classification/fil/filipino_shopee_reviews_classification.py +1 -1
- mteb/tasks/classification/fin/fin_toxicity_classification.py +1 -1
- mteb/tasks/classification/fra/french_book_reviews.py +2 -2
- mteb/tasks/classification/fra/movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/guj/gujarati_news_classification.py +1 -1
- mteb/tasks/classification/hin/hindi_discourse_classification.py +1 -1
- mteb/tasks/classification/hin/sentiment_analysis_hindi.py +1 -1
- mteb/tasks/classification/ind/indonesian_id_clickbait_classification.py +2 -2
- mteb/tasks/classification/ind/indonesian_mongabay_conservation_classification.py +1 -1
- mteb/tasks/classification/ita/dado_eval_coarse_classification.py +1 -1
- mteb/tasks/classification/ita/ita_casehold_classification.py +1 -1
- mteb/tasks/classification/ita/sardi_stance_classification.py +1 -1
- mteb/tasks/classification/jav/javanese_imdb_classification.py +1 -1
- mteb/tasks/classification/jpn/wrime_classification.py +1 -1
- mteb/tasks/classification/kan/kannada_news_classification.py +2 -2
- mteb/tasks/classification/kor/klue_tc.py +2 -2
- mteb/tasks/classification/kor/kor_fin.py +1 -1
- mteb/tasks/classification/kor/kor_hate_classification.py +1 -1
- mteb/tasks/classification/kor/kor_sarcasm_classification.py +1 -1
- mteb/tasks/classification/mal/malayalam_news_classification.py +1 -1
- mteb/tasks/classification/mar/marathi_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/afri_senti_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/catalonia_tweet_classification.py +1 -1
- mteb/tasks/classification/multilingual/cyrillic_turkic_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_nlp_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/masakha_news_classification.py +1 -1
- mteb/tasks/classification/multilingual/multi_hate_classification.py +1 -1
- mteb/tasks/classification/multilingual/multilingual_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/scala_classification.py +1 -1
- mteb/tasks/classification/multilingual/sib200_classification.py +1 -1
- mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
- mteb/tasks/classification/multilingual/tweet_sentiment_classification.py +1 -1
- mteb/tasks/classification/nep/nepali_news_classification.py +2 -2
- mteb/tasks/classification/nld/dutch_sarcastic_headlines_classification.py +1 -1
- mteb/tasks/classification/nld/vaccin_chat_nl_classification.py +1 -1
- mteb/tasks/classification/ory/odia_news_classification.py +2 -2
- mteb/tasks/classification/pan/punjabi_news_classification.py +1 -1
- mteb/tasks/classification/ron/moroco.py +1 -1
- mteb/tasks/classification/ron/romanian_reviews_sentiment.py +1 -1
- mteb/tasks/classification/ron/romanian_sentiment_classification.py +1 -1
- mteb/tasks/classification/rus/georeview_classification.py +1 -1
- mteb/tasks/classification/rus/headline_classification.py +2 -2
- mteb/tasks/classification/rus/inappropriateness_classification.py +2 -2
- mteb/tasks/classification/rus/ru_reviews_classification.py +2 -2
- mteb/tasks/classification/rus/ru_sci_bench_grnti_classification.py +1 -1
- mteb/tasks/classification/rus/ru_sci_bench_oecd_classification.py +1 -1
- mteb/tasks/classification/rus/ru_toixic_classification_okmlcup.py +1 -1
- mteb/tasks/classification/san/sanskrit_shlokas_classification.py +1 -1
- mteb/tasks/classification/sin/sinhala_news_classification.py +2 -2
- mteb/tasks/classification/sin/sinhala_news_source_classification.py +2 -2
- mteb/tasks/classification/slk/csfdsk_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/slv/frenk_sl_classification.py +1 -1
- mteb/tasks/classification/spa/spanish_news_classification.py +2 -2
- mteb/tasks/classification/ssw/siswati_news_classification.py +1 -1
- mteb/tasks/classification/tam/tamil_news_classification.py +2 -2
- mteb/tasks/classification/tel/telugu_andhra_jyoti_news_classification.py +2 -2
- mteb/tasks/classification/tha/wongnai_reviews_classification.py +1 -1
- mteb/tasks/classification/tur/turkish_movie_sentiment_classification.py +2 -2
- mteb/tasks/classification/ukr/ukr_formality_classification.py +2 -2
- mteb/tasks/classification/vie/toxic_conversations_vn_classification.py +1 -1
- mteb/tasks/classification/vie/vie_student_feedback_classification.py +1 -1
- mteb/tasks/classification/zho/yue_openrice_review_classification.py +2 -2
- mteb/tasks/classification/zul/isi_zulu_news_classification.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_p2p.py +1 -1
- mteb/tasks/clustering/deu/blurbs_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/arxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/arxiv_hierarchical_clustering.py +2 -2
- mteb/tasks/clustering/eng/big_patent_clustering.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/biorxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/medrxiv_clustering_s2s.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering.py +1 -1
- mteb/tasks/clustering/eng/reddit_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering.py +1 -1
- mteb/tasks/clustering/eng/stack_exchange_clustering_p2p.py +1 -1
- mteb/tasks/clustering/eng/twenty_newsgroups_clustering.py +1 -1
- mteb/tasks/clustering/fas/fa_mteb_clustering.py +4 -4
- mteb/tasks/clustering/fra/hal_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
- mteb/tasks/clustering/multilingual/sib200_clustering_s2s.py +1 -1
- mteb/tasks/clustering/multilingual/wiki_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/dutch_news_articles_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/iconclass_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nld/open_tender_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_p2p.py +1 -1
- mteb/tasks/clustering/nld/vabb_clustering_s2s.py +1 -1
- mteb/tasks/clustering/nob/snl_clustering.py +1 -1
- mteb/tasks/clustering/nob/vg_clustering.py +1 -1
- mteb/tasks/clustering/pol/polish_clustering.py +3 -3
- mteb/tasks/clustering/rus/ru_sci_bench_grnti_clustering_p2p.py +1 -1
- mteb/tasks/clustering/rus/ru_sci_bench_oecd_clustering_p2p.py +1 -1
- mteb/tasks/clustering/zho/cmteb_clustering.py +4 -4
- mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
- mteb/tasks/multichoice/eng/cv_bench.py +4 -4
- mteb/tasks/multilabel_classification/ita/emit_classification.py +1 -1
- mteb/tasks/multilabel_classification/mlt/maltese_news_classification.py +1 -1
- mteb/tasks/multilabel_classification/rus/ru_toixic_multilabelclassification_okmlcup.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_group_classification.py +1 -1
- mteb/tasks/multilabel_classification/swe/swedish_patent_cpc_subclass_classification.py +1 -1
- mteb/tasks/pair_classification/ara/ar_entail.py +1 -1
- mteb/tasks/pair_classification/dan/talemaader_pc.py +1 -1
- mteb/tasks/pair_classification/deu/false_friends_de_en_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_ai_sentence_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_synonym_pc.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_wiki_paragraphs_pc.py +1 -1
- mteb/tasks/pair_classification/eng/sprint_duplicate_questions_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_sem_eval2015_pc.py +1 -1
- mteb/tasks/pair_classification/eng/twitter_url_corpus_pc.py +1 -1
- mteb/tasks/pair_classification/fas/fa_mteb_pair_classification.py +5 -5
- mteb/tasks/pair_classification/fas/fars_tail.py +2 -2
- mteb/tasks/pair_classification/hye/armenian_paraphrase_pc.py +1 -1
- mteb/tasks/pair_classification/ita/dis_co_tex_pair_classification.py +1 -1
- mteb/tasks/pair_classification/kor/klue_nli.py +1 -1
- mteb/tasks/pair_classification/multilingual/rte3.py +2 -2
- mteb/tasks/pair_classification/multilingual/xnli.py +1 -1
- mteb/tasks/pair_classification/pol/polish_pc.py +4 -4
- mteb/tasks/pair_classification/por/assin2_rte.py +1 -1
- mteb/tasks/pair_classification/por/sick_br_pc.py +1 -1
- mteb/tasks/pair_classification/rus/terra.py +2 -2
- mteb/tasks/pair_classification/vie/sprint_duplicate_questions_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_sem_eval2015_pcvn.py +1 -1
- mteb/tasks/pair_classification/vie/twitter_url_corpus_pcvn.py +1 -1
- mteb/tasks/pair_classification/zho/cmteb_pair_classification.py +2 -2
- mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +4 -4
- mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
- mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
- mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
- mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
- mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
- mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +1 -1
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
- mteb/tasks/retrieval/eng/bright_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/ml_questions.py +1 -1
- mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
- mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
- mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
- mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
- mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -2
- mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +14 -4
- mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nob/norquad.py +1 -1
- mteb/tasks/retrieval/nob/snl_retrieval.py +1 -1
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
- mteb/tasks/sts/fao/faroese_sts.py +1 -1
- mteb/tasks/sts/fra/sick_fr_sts.py +1 -1
- mteb/tasks/sts/kor/klue_sts.py +1 -1
- mteb/tasks/sts/por/sick_br_sts.py +1 -1
- mteb/tasks/sts/rus/ru_para_phraser_sts.py +1 -1
- mteb/tasks/zeroshot_classification/eng/sci_mmir.py +1 -1
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/METADATA +1 -1
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/RECORD +287 -287
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/WHEEL +0 -0
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/entry_points.txt +0 -0
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.7.4.dist-info → mteb-2.7.6.dist-info}/top_level.txt +0 -0
mteb/abstasks/classification.py
CHANGED
|
@@ -136,6 +136,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
136
136
|
*,
|
|
137
137
|
encode_kwargs: EncodeKwargs,
|
|
138
138
|
prediction_folder: Path | None = None,
|
|
139
|
+
num_proc: int = 1,
|
|
139
140
|
**kwargs: Any,
|
|
140
141
|
) -> dict[HFSubset, ScoresDict]:
|
|
141
142
|
"""Evaluate a model on the classification task.
|
|
@@ -149,7 +150,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
149
150
|
)
|
|
150
151
|
|
|
151
152
|
if not self.data_loaded:
|
|
152
|
-
self.load_data()
|
|
153
|
+
self.load_data(num_proc=num_proc)
|
|
153
154
|
|
|
154
155
|
if self.dataset is None:
|
|
155
156
|
raise RuntimeError("Dataset not loaded.")
|
|
@@ -182,6 +183,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
182
183
|
hf_subset=hf_subset,
|
|
183
184
|
encode_kwargs=encode_kwargs,
|
|
184
185
|
prediction_folder=prediction_folder,
|
|
186
|
+
num_proc=num_proc,
|
|
185
187
|
**kwargs,
|
|
186
188
|
)
|
|
187
189
|
self._add_main_score(scores[hf_subset])
|
|
@@ -197,6 +199,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
197
199
|
hf_split: str,
|
|
198
200
|
hf_subset: str,
|
|
199
201
|
prediction_folder: Path | None = None,
|
|
202
|
+
num_proc: int = 1,
|
|
200
203
|
**kwargs: Any,
|
|
201
204
|
) -> FullClassificationMetrics:
|
|
202
205
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -230,7 +233,10 @@ class AbsTaskClassification(AbsTask):
|
|
|
230
233
|
evaluator_model=self.evaluator_model,
|
|
231
234
|
)
|
|
232
235
|
y_pred, test_cache = evaluator(
|
|
233
|
-
model,
|
|
236
|
+
model,
|
|
237
|
+
encode_kwargs=encode_kwargs,
|
|
238
|
+
test_cache=test_cache,
|
|
239
|
+
num_proc=num_proc,
|
|
234
240
|
)
|
|
235
241
|
if prediction_folder:
|
|
236
242
|
all_predictions.append(y_pred.tolist())
|
|
@@ -372,11 +378,12 @@ class AbsTaskClassification(AbsTask):
|
|
|
372
378
|
label_statistics=label_statistics,
|
|
373
379
|
)
|
|
374
380
|
|
|
375
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
381
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
376
382
|
self._upload_dataset_to_hub(
|
|
377
383
|
repo_name,
|
|
378
384
|
[
|
|
379
385
|
self.input_column_name,
|
|
380
386
|
self.label_column_name,
|
|
381
387
|
],
|
|
388
|
+
num_proc=num_proc,
|
|
382
389
|
)
|
mteb/abstasks/clustering.py
CHANGED
|
@@ -169,6 +169,7 @@ class AbsTaskClustering(AbsTask):
|
|
|
169
169
|
hf_split: str,
|
|
170
170
|
hf_subset: str,
|
|
171
171
|
prediction_folder: Path | None = None,
|
|
172
|
+
num_proc: int = 1,
|
|
172
173
|
**kwargs: Any,
|
|
173
174
|
) -> ScoresDict:
|
|
174
175
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -213,6 +214,7 @@ class AbsTaskClustering(AbsTask):
|
|
|
213
214
|
downsampled_dataset,
|
|
214
215
|
self.metadata,
|
|
215
216
|
input_column=self.input_column_name,
|
|
217
|
+
num_proc=num_proc,
|
|
216
218
|
**encode_kwargs,
|
|
217
219
|
),
|
|
218
220
|
task_metadata=self.metadata,
|
|
@@ -296,9 +298,11 @@ class AbsTaskClustering(AbsTask):
|
|
|
296
298
|
labels_statistics=label_statistics,
|
|
297
299
|
)
|
|
298
300
|
|
|
299
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
301
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
300
302
|
self._upload_dataset_to_hub(
|
|
301
|
-
repo_name,
|
|
303
|
+
repo_name,
|
|
304
|
+
[self.input_column_name, self.label_column_name],
|
|
305
|
+
num_proc=num_proc,
|
|
302
306
|
)
|
|
303
307
|
|
|
304
308
|
|
|
@@ -95,6 +95,7 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
95
95
|
hf_split: str,
|
|
96
96
|
hf_subset: str,
|
|
97
97
|
prediction_folder: Path | None = None,
|
|
98
|
+
num_proc: int = 1,
|
|
98
99
|
**kwargs: Any,
|
|
99
100
|
) -> ScoresDict:
|
|
100
101
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -159,7 +160,11 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
159
160
|
hf_subset=hf_subset,
|
|
160
161
|
**kwargs,
|
|
161
162
|
)
|
|
162
|
-
evaluate_clusters = evaluator(
|
|
163
|
+
evaluate_clusters = evaluator(
|
|
164
|
+
model,
|
|
165
|
+
encode_kwargs=encode_kwargs,
|
|
166
|
+
num_proc=num_proc,
|
|
167
|
+
)
|
|
163
168
|
if prediction_folder:
|
|
164
169
|
self._save_task_predictions(
|
|
165
170
|
evaluate_clusters,
|
|
@@ -238,11 +243,12 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
238
243
|
label_statistics=label_statistics,
|
|
239
244
|
)
|
|
240
245
|
|
|
241
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
246
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
242
247
|
self._upload_dataset_to_hub(
|
|
243
248
|
repo_name,
|
|
244
249
|
[
|
|
245
250
|
self.input_column_name,
|
|
246
251
|
self.label_column_name,
|
|
247
252
|
],
|
|
253
|
+
num_proc=num_proc,
|
|
248
254
|
)
|
|
@@ -134,6 +134,7 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
134
134
|
hf_split: str,
|
|
135
135
|
hf_subset: str,
|
|
136
136
|
prediction_folder: Path | None = None,
|
|
137
|
+
num_proc: int = 1,
|
|
137
138
|
**kwargs: Any,
|
|
138
139
|
) -> ImageTextPairClassificationMetrics:
|
|
139
140
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -167,7 +168,9 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
167
168
|
hf_subset=hf_subset,
|
|
168
169
|
**kwargs,
|
|
169
170
|
)
|
|
170
|
-
scores: list[torch.Tensor] = evaluator(
|
|
171
|
+
scores: list[torch.Tensor] = evaluator(
|
|
172
|
+
model, encode_kwargs=encode_kwargs, num_proc=num_proc
|
|
173
|
+
) # type: ignore[assignment]
|
|
171
174
|
if prediction_folder:
|
|
172
175
|
self._save_task_predictions(
|
|
173
176
|
[score.tolist() for score in scores],
|
|
@@ -215,7 +218,7 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
215
218
|
accuracy=torch.Tensor(all_correct_scores).float().mean().item(),
|
|
216
219
|
)
|
|
217
220
|
|
|
218
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
221
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
219
222
|
text_columns = (
|
|
220
223
|
[self.texts_column_names]
|
|
221
224
|
if isinstance(self.texts_column_names, str)
|
|
@@ -230,4 +233,5 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
230
233
|
self._upload_dataset_to_hub(
|
|
231
234
|
repo_name,
|
|
232
235
|
[*text_columns, *image_columns],
|
|
236
|
+
num_proc=num_proc,
|
|
233
237
|
)
|
|
@@ -93,6 +93,7 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
93
93
|
hf_split: str,
|
|
94
94
|
hf_subset: str,
|
|
95
95
|
prediction_folder: Path | None = None,
|
|
96
|
+
num_proc: int = 1,
|
|
96
97
|
**kwargs: Any,
|
|
97
98
|
) -> FullMultilabelClassificationMetrics:
|
|
98
99
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -125,6 +126,7 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
125
126
|
unique_train_dataset,
|
|
126
127
|
self.metadata,
|
|
127
128
|
input_column=self.input_column_name,
|
|
129
|
+
num_proc=num_proc,
|
|
128
130
|
**encode_kwargs,
|
|
129
131
|
)
|
|
130
132
|
|
|
@@ -96,6 +96,7 @@ class AbsTaskPairClassification(AbsTask):
|
|
|
96
96
|
hf_subset: str,
|
|
97
97
|
encode_kwargs: EncodeKwargs,
|
|
98
98
|
prediction_folder: Path | None = None,
|
|
99
|
+
num_proc: int = 1,
|
|
99
100
|
**kwargs,
|
|
100
101
|
) -> dict[str, float]:
|
|
101
102
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -115,7 +116,11 @@ class AbsTaskPairClassification(AbsTask):
|
|
|
115
116
|
input2_prompt_type=self.input2_prompt_type,
|
|
116
117
|
**kwargs,
|
|
117
118
|
)
|
|
118
|
-
similarity_scores = evaluator(
|
|
119
|
+
similarity_scores = evaluator(
|
|
120
|
+
model,
|
|
121
|
+
encode_kwargs=encode_kwargs,
|
|
122
|
+
num_proc=num_proc,
|
|
123
|
+
)
|
|
119
124
|
|
|
120
125
|
if prediction_folder:
|
|
121
126
|
self._save_task_predictions(
|
|
@@ -248,7 +253,7 @@ class AbsTaskPairClassification(AbsTask):
|
|
|
248
253
|
labels_statistics=calculate_label_statistics(labels),
|
|
249
254
|
)
|
|
250
255
|
|
|
251
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
256
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
252
257
|
# previously pair classification datasets were stored in a single row
|
|
253
258
|
if self.dataset is None:
|
|
254
259
|
# overall this shouldn't happen as we check for dataset before pushing to hub
|
|
@@ -272,6 +277,7 @@ class AbsTaskPairClassification(AbsTask):
|
|
|
272
277
|
self.input2_column_name,
|
|
273
278
|
self.label_column_name,
|
|
274
279
|
],
|
|
280
|
+
num_proc=num_proc,
|
|
275
281
|
)
|
|
276
282
|
|
|
277
283
|
def _compute_metrics_values(
|
mteb/abstasks/retrieval.py
CHANGED
|
@@ -148,7 +148,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
148
148
|
)
|
|
149
149
|
)
|
|
150
150
|
|
|
151
|
-
def convert_v1_dataset_format_to_v2(self):
|
|
151
|
+
def convert_v1_dataset_format_to_v2(self, num_proc: int) -> None:
|
|
152
152
|
"""Convert dataset from v1 (from `self.queries`, `self.document`) format to v2 format (`self.dotaset`)."""
|
|
153
153
|
# check if dataset is `v1` version
|
|
154
154
|
if not hasattr(self, "queries"):
|
|
@@ -215,6 +215,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
215
215
|
_combine_queries_with_instructions_datasets(
|
|
216
216
|
self.dataset[subset][split]["queries"],
|
|
217
217
|
instructions,
|
|
218
|
+
num_proc,
|
|
218
219
|
)
|
|
219
220
|
)
|
|
220
221
|
if hasattr(self, "top_ranked"):
|
|
@@ -240,6 +241,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
240
241
|
_combine_queries_with_instructions_datasets(
|
|
241
242
|
self.dataset[subset][split]["queries"],
|
|
242
243
|
instructions,
|
|
244
|
+
num_proc,
|
|
243
245
|
)
|
|
244
246
|
)
|
|
245
247
|
if hasattr(self, "top_ranked") and self.top_ranked:
|
|
@@ -255,7 +257,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
255
257
|
if hasattr(self, "top_ranked"):
|
|
256
258
|
del self.top_ranked
|
|
257
259
|
|
|
258
|
-
def load_data(self) -> None:
|
|
260
|
+
def load_data(self, num_proc: int = 1, **kwargs) -> None:
|
|
259
261
|
"""Load the dataset for the retrieval task."""
|
|
260
262
|
if self.data_loaded:
|
|
261
263
|
return
|
|
@@ -277,7 +279,9 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
277
279
|
trust_remote_code=trust_remote_code,
|
|
278
280
|
split=split,
|
|
279
281
|
config=hf_subset,
|
|
280
|
-
).load(
|
|
282
|
+
).load(
|
|
283
|
+
num_proc=num_proc,
|
|
284
|
+
)
|
|
281
285
|
|
|
282
286
|
if self.metadata.is_multilingual:
|
|
283
287
|
for lang in self.metadata.eval_langs:
|
|
@@ -286,7 +290,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
286
290
|
else:
|
|
287
291
|
for split in eval_splits:
|
|
288
292
|
_process_data(split)
|
|
289
|
-
self.dataset_transform()
|
|
293
|
+
self.dataset_transform(num_proc=num_proc)
|
|
290
294
|
self.data_loaded = True
|
|
291
295
|
|
|
292
296
|
def evaluate(
|
|
@@ -297,6 +301,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
297
301
|
*,
|
|
298
302
|
encode_kwargs: EncodeKwargs,
|
|
299
303
|
prediction_folder: Path | None = None,
|
|
304
|
+
num_proc: int = 1,
|
|
300
305
|
**kwargs: Any,
|
|
301
306
|
) -> Mapping[HFSubset, ScoresDict]:
|
|
302
307
|
"""Evaluate the model on the retrieval task.
|
|
@@ -308,16 +313,16 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
308
313
|
subsets_to_run: Optional list of subsets to evaluate on
|
|
309
314
|
encode_kwargs: Keyword arguments passed to the encoder
|
|
310
315
|
prediction_folder: Folder to save model predictions
|
|
316
|
+
num_proc: Number of processes to use
|
|
311
317
|
**kwargs: Additional keyword arguments passed to the evaluator
|
|
312
318
|
|
|
313
|
-
|
|
314
319
|
Returns:
|
|
315
320
|
Dictionary mapping subsets to their evaluation scores
|
|
316
321
|
"""
|
|
317
322
|
if not self.data_loaded:
|
|
318
|
-
self.load_data()
|
|
323
|
+
self.load_data(num_proc=num_proc)
|
|
319
324
|
# TODO: convert all tasks directly https://github.com/embeddings-benchmark/mteb/issues/2030
|
|
320
|
-
self.convert_v1_dataset_format_to_v2()
|
|
325
|
+
self.convert_v1_dataset_format_to_v2(num_proc=num_proc)
|
|
321
326
|
|
|
322
327
|
return super().evaluate(
|
|
323
328
|
model,
|
|
@@ -325,6 +330,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
325
330
|
subsets_to_run,
|
|
326
331
|
encode_kwargs=encode_kwargs,
|
|
327
332
|
prediction_folder=prediction_folder,
|
|
333
|
+
num_proc=num_proc,
|
|
328
334
|
**kwargs,
|
|
329
335
|
)
|
|
330
336
|
|
|
@@ -336,6 +342,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
336
342
|
hf_split: str,
|
|
337
343
|
hf_subset: str,
|
|
338
344
|
prediction_folder: Path | None = None,
|
|
345
|
+
num_proc: int = 1,
|
|
339
346
|
**kwargs,
|
|
340
347
|
) -> ScoresDict:
|
|
341
348
|
"""Evaluate a model on a specific subset of the data.
|
|
@@ -347,6 +354,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
347
354
|
hf_split: Split to evaluate on
|
|
348
355
|
hf_subset: Subset to evaluate on
|
|
349
356
|
prediction_folder: Folder with results prediction
|
|
357
|
+
num_proc: Number of processes to use
|
|
350
358
|
**kwargs: Additional keyword arguments passed to the evaluator
|
|
351
359
|
|
|
352
360
|
Returns:
|
|
@@ -386,6 +394,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
386
394
|
results = retriever(
|
|
387
395
|
search_model,
|
|
388
396
|
encode_kwargs=encode_kwargs,
|
|
397
|
+
num_proc=num_proc,
|
|
389
398
|
)
|
|
390
399
|
end_time = time()
|
|
391
400
|
logger.debug(
|
|
@@ -460,9 +469,13 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
460
469
|
return {}
|
|
461
470
|
|
|
462
471
|
def _calculate_descriptive_statistics_from_split(
|
|
463
|
-
self,
|
|
472
|
+
self,
|
|
473
|
+
split: str,
|
|
474
|
+
hf_subset: str | None = None,
|
|
475
|
+
compute_overall: bool = False,
|
|
476
|
+
num_proc: int = 1,
|
|
464
477
|
) -> RetrievalDescriptiveStatistics:
|
|
465
|
-
self.convert_v1_dataset_format_to_v2()
|
|
478
|
+
self.convert_v1_dataset_format_to_v2(num_proc)
|
|
466
479
|
if hf_subset and hf_subset in self.dataset:
|
|
467
480
|
split_data = self.dataset[hf_subset][split]
|
|
468
481
|
queries = split_data["queries"]
|
|
@@ -567,8 +580,8 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
567
580
|
top_ranked_statistics=top_ranked_statistics,
|
|
568
581
|
)
|
|
569
582
|
|
|
570
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
571
|
-
self.convert_v1_dataset_format_to_v2()
|
|
583
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
584
|
+
self.convert_v1_dataset_format_to_v2(num_proc)
|
|
572
585
|
|
|
573
586
|
def _push_section(
|
|
574
587
|
data: dict[str, RetrievalSplitData],
|
|
@@ -608,6 +621,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
608
621
|
repo_name,
|
|
609
622
|
hf_subset_name,
|
|
610
623
|
commit_message=f"Add {hf_subset_name}-{subset_item}",
|
|
624
|
+
num_proc=num_proc,
|
|
611
625
|
)
|
|
612
626
|
|
|
613
627
|
for subset in self.dataset:
|
|
@@ -641,6 +655,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
641
655
|
repo_name,
|
|
642
656
|
f"{subset}-qrels" if subset != "default" else "qrels",
|
|
643
657
|
commit_message=f"Add {subset}-qrels",
|
|
658
|
+
num_proc=num_proc,
|
|
644
659
|
)
|
|
645
660
|
|
|
646
661
|
_push_section(
|
|
@@ -76,28 +76,36 @@ class RetrievalDatasetLoader:
|
|
|
76
76
|
self.config = config if config != "default" else None
|
|
77
77
|
self.dataset_configs = get_dataset_config_names(self.hf_repo, self.revision)
|
|
78
78
|
|
|
79
|
-
def load(
|
|
79
|
+
def load(
|
|
80
|
+
self,
|
|
81
|
+
num_proc: int = 1,
|
|
82
|
+
) -> RetrievalSplitData:
|
|
80
83
|
"""Loads the dataset split for the specified configuration.
|
|
81
84
|
|
|
85
|
+
Args:
|
|
86
|
+
num_proc: The number of processes to use.
|
|
87
|
+
|
|
82
88
|
Returns:
|
|
83
89
|
A dictionary containing the corpus, queries, relevant documents, instructions (if applicable), and top-ranked documents (if applicable).
|
|
84
90
|
"""
|
|
85
91
|
top_ranked = None
|
|
86
92
|
|
|
87
|
-
qrels = self._load_qrels()
|
|
88
|
-
corpus = self._load_corpus()
|
|
89
|
-
queries = self._load_queries()
|
|
93
|
+
qrels = self._load_qrels(num_proc)
|
|
94
|
+
corpus = self._load_corpus(num_proc)
|
|
95
|
+
queries = self._load_queries(num_proc)
|
|
90
96
|
|
|
91
97
|
queries = queries.filter(
|
|
92
98
|
lambda x: x["id"] in qrels.keys(), desc="Filtering queries by qrels"
|
|
93
99
|
)
|
|
94
100
|
|
|
95
101
|
if any(c.endswith("top_ranked") for c in self.dataset_configs):
|
|
96
|
-
top_ranked = self._load_top_ranked()
|
|
102
|
+
top_ranked = self._load_top_ranked(num_proc)
|
|
97
103
|
|
|
98
104
|
if any(c.endswith("instruction") for c in self.dataset_configs):
|
|
99
|
-
instructions = self._load_instructions()
|
|
100
|
-
queries = _combine_queries_with_instructions_datasets(
|
|
105
|
+
instructions = self._load_instructions(num_proc)
|
|
106
|
+
queries = _combine_queries_with_instructions_datasets(
|
|
107
|
+
queries, instructions, num_proc
|
|
108
|
+
)
|
|
101
109
|
|
|
102
110
|
return RetrievalSplitData(
|
|
103
111
|
corpus=corpus,
|
|
@@ -120,20 +128,21 @@ class RetrievalDatasetLoader:
|
|
|
120
128
|
f"Split {self.split} not found in {splits}. Please specify a valid split."
|
|
121
129
|
)
|
|
122
130
|
|
|
123
|
-
def _load_dataset_split(self, config: str) -> Dataset:
|
|
131
|
+
def _load_dataset_split(self, config: str, num_proc: int) -> Dataset:
|
|
124
132
|
return load_dataset(
|
|
125
133
|
self.hf_repo,
|
|
126
134
|
config,
|
|
127
135
|
split=self._get_split(config),
|
|
128
136
|
trust_remote_code=self.trust_remote_code,
|
|
129
137
|
revision=self.revision,
|
|
138
|
+
num_proc=num_proc,
|
|
130
139
|
)
|
|
131
140
|
|
|
132
|
-
def _load_corpus(self) -> CorpusDatasetType:
|
|
141
|
+
def _load_corpus(self, num_proc: int) -> CorpusDatasetType:
|
|
133
142
|
logger.info("Loading Corpus...")
|
|
134
143
|
|
|
135
144
|
config = f"{self.config}-corpus" if self.config is not None else "corpus"
|
|
136
|
-
corpus_ds = self._load_dataset_split(config)
|
|
145
|
+
corpus_ds = self._load_dataset_split(config, num_proc)
|
|
137
146
|
if "_id" in corpus_ds.column_names:
|
|
138
147
|
corpus_ds = corpus_ds.cast_column("_id", Value("string")).rename_column(
|
|
139
148
|
"_id", "id"
|
|
@@ -142,13 +151,13 @@ class RetrievalDatasetLoader:
|
|
|
142
151
|
logger.debug("Doc Example: %s", corpus_ds[0])
|
|
143
152
|
return corpus_ds
|
|
144
153
|
|
|
145
|
-
def _load_queries(self) -> QueryDatasetType:
|
|
154
|
+
def _load_queries(self, num_proc: int) -> QueryDatasetType:
|
|
146
155
|
logger.info("Loading Queries...")
|
|
147
156
|
|
|
148
157
|
config = f"{self.config}-queries" if self.config is not None else "queries"
|
|
149
158
|
if "query" in self.dataset_configs:
|
|
150
159
|
config = "query"
|
|
151
|
-
queries_ds = self._load_dataset_split(config)
|
|
160
|
+
queries_ds = self._load_dataset_split(config, num_proc)
|
|
152
161
|
if "_id" in queries_ds.column_names:
|
|
153
162
|
queries_ds = queries_ds.cast_column("_id", Value("string")).rename_column(
|
|
154
163
|
"_id", "id"
|
|
@@ -159,7 +168,7 @@ class RetrievalDatasetLoader:
|
|
|
159
168
|
|
|
160
169
|
return queries_ds
|
|
161
170
|
|
|
162
|
-
def _load_qrels(self) -> RelevantDocumentsType:
|
|
171
|
+
def _load_qrels(self, num_proc: int) -> RelevantDocumentsType:
|
|
163
172
|
logger.info("Loading qrels...")
|
|
164
173
|
|
|
165
174
|
config = f"{self.config}-qrels" if self.config is not None else "default"
|
|
@@ -171,7 +180,7 @@ class RetrievalDatasetLoader:
|
|
|
171
180
|
"No qrels or default config found. Please specify a valid config or ensure the dataset has qrels."
|
|
172
181
|
)
|
|
173
182
|
|
|
174
|
-
qrels_ds = self._load_dataset_split(config)
|
|
183
|
+
qrels_ds = self._load_dataset_split(config, num_proc)
|
|
175
184
|
qrels_ds = qrels_ds.select_columns(["query-id", "corpus-id", "score"])
|
|
176
185
|
|
|
177
186
|
qrels_ds = qrels_ds.cast(
|
|
@@ -194,13 +203,13 @@ class RetrievalDatasetLoader:
|
|
|
194
203
|
logger.info("Loaded %d %s qrels.", len(qrels_dict), self.split.upper())
|
|
195
204
|
return qrels_dict
|
|
196
205
|
|
|
197
|
-
def _load_top_ranked(self) -> TopRankedDocumentsType:
|
|
206
|
+
def _load_top_ranked(self, num_proc: int) -> TopRankedDocumentsType:
|
|
198
207
|
logger.info("Loading Top Ranked")
|
|
199
208
|
|
|
200
209
|
config = (
|
|
201
210
|
f"{self.config}-top_ranked" if self.config is not None else "top_ranked"
|
|
202
211
|
)
|
|
203
|
-
top_ranked_ds = self._load_dataset_split(config)
|
|
212
|
+
top_ranked_ds = self._load_dataset_split(config, num_proc)
|
|
204
213
|
top_ranked_ds = top_ranked_ds.cast(
|
|
205
214
|
Features(
|
|
206
215
|
{
|
|
@@ -218,13 +227,13 @@ class RetrievalDatasetLoader:
|
|
|
218
227
|
logger.info(f"Top ranked loaded: {len(top_ranked_ds)}")
|
|
219
228
|
return top_ranked_dict
|
|
220
229
|
|
|
221
|
-
def _load_instructions(self) -> InstructionDatasetType:
|
|
230
|
+
def _load_instructions(self, num_proc: int) -> InstructionDatasetType:
|
|
222
231
|
logger.info("Loading Instructions")
|
|
223
232
|
|
|
224
233
|
config = (
|
|
225
234
|
f"{self.config}-instruction" if self.config is not None else "instruction"
|
|
226
235
|
)
|
|
227
|
-
instructions_ds = self._load_dataset_split(config)
|
|
236
|
+
instructions_ds = self._load_dataset_split(config, num_proc)
|
|
228
237
|
instructions_ds = instructions_ds.cast(
|
|
229
238
|
Features(
|
|
230
239
|
{
|
|
@@ -239,6 +248,7 @@ class RetrievalDatasetLoader:
|
|
|
239
248
|
def _combine_queries_with_instructions_datasets(
|
|
240
249
|
queries_dataset: QueryDatasetType,
|
|
241
250
|
instruction_dataset: InstructionDatasetType | dict[str, str],
|
|
251
|
+
num_proc: int,
|
|
242
252
|
) -> Dataset:
|
|
243
253
|
if isinstance(instruction_dataset, Dataset):
|
|
244
254
|
instruction_to_query_idx = {
|
|
@@ -251,4 +261,4 @@ def _combine_queries_with_instructions_datasets(
|
|
|
251
261
|
row["instruction"] = instruction_to_query_idx[row["id"]]
|
|
252
262
|
return row
|
|
253
263
|
|
|
254
|
-
return queries_dataset.map(_add_instruction_to_query)
|
|
264
|
+
return queries_dataset.map(_add_instruction_to_query, num_proc=num_proc)
|
mteb/abstasks/sts.py
CHANGED
|
@@ -118,6 +118,7 @@ class AbsTaskSTS(AbsTask):
|
|
|
118
118
|
hf_split: str,
|
|
119
119
|
hf_subset: str,
|
|
120
120
|
prediction_folder: Path | None = None,
|
|
121
|
+
num_proc: int = 1,
|
|
121
122
|
**kwargs: Any,
|
|
122
123
|
) -> STSMetrics:
|
|
123
124
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -136,7 +137,11 @@ class AbsTaskSTS(AbsTask):
|
|
|
136
137
|
input2_prompt_type=self.input2_prompt_type,
|
|
137
138
|
**kwargs,
|
|
138
139
|
)
|
|
139
|
-
scores = evaluator(
|
|
140
|
+
scores = evaluator(
|
|
141
|
+
model,
|
|
142
|
+
encode_kwargs=encode_kwargs,
|
|
143
|
+
num_proc=num_proc,
|
|
144
|
+
)
|
|
140
145
|
|
|
141
146
|
if prediction_folder:
|
|
142
147
|
self._save_task_predictions(
|
|
@@ -245,9 +250,11 @@ class AbsTaskSTS(AbsTask):
|
|
|
245
250
|
label_statistics=labels_statistics,
|
|
246
251
|
)
|
|
247
252
|
|
|
248
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
253
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
249
254
|
self._upload_dataset_to_hub(
|
|
250
|
-
repo_name,
|
|
255
|
+
repo_name,
|
|
256
|
+
[self.column_names[0], self.column_names[1], "score"],
|
|
257
|
+
num_proc=num_proc,
|
|
251
258
|
)
|
|
252
259
|
|
|
253
260
|
def _normalize(self, x: float) -> float:
|
|
@@ -82,6 +82,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
82
82
|
*,
|
|
83
83
|
encode_kwargs: EncodeKwargs,
|
|
84
84
|
prediction_folder: Path | None = None,
|
|
85
|
+
num_proc: int = 1,
|
|
85
86
|
**kwargs: Any,
|
|
86
87
|
) -> dict[HFSubset, ScoresDict]:
|
|
87
88
|
"""Added load for "parallel" datasets"""
|
|
@@ -89,7 +90,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
89
90
|
raise TypeError("Expected model to be an instance of EncoderProtocol")
|
|
90
91
|
|
|
91
92
|
if not self.data_loaded:
|
|
92
|
-
self.load_data()
|
|
93
|
+
self.load_data(num_proc=num_proc)
|
|
93
94
|
|
|
94
95
|
hf_subsets = self.hf_subsets
|
|
95
96
|
|
|
@@ -112,6 +113,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
112
113
|
hf_subset="parallel",
|
|
113
114
|
encode_kwargs=encode_kwargs,
|
|
114
115
|
prediction_folder=prediction_folder,
|
|
116
|
+
num_proc=num_proc,
|
|
115
117
|
**kwargs,
|
|
116
118
|
)
|
|
117
119
|
else:
|
|
@@ -131,6 +133,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
131
133
|
hf_subset=hf_subset,
|
|
132
134
|
encode_kwargs=encode_kwargs,
|
|
133
135
|
prediction_folder=prediction_folder,
|
|
136
|
+
num_proc=num_proc,
|
|
134
137
|
**kwargs,
|
|
135
138
|
)
|
|
136
139
|
|
|
@@ -152,6 +155,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
152
155
|
encode_kwargs: EncodeKwargs,
|
|
153
156
|
prediction_folder: Path | None = None,
|
|
154
157
|
parallel: bool = False,
|
|
158
|
+
num_proc: int = 1,
|
|
155
159
|
**kwargs,
|
|
156
160
|
) -> BitextMiningMetrics | dict[str, BitextMiningMetrics]:
|
|
157
161
|
pairs = self._get_pairs(parallel)
|
|
@@ -171,7 +175,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
171
175
|
else data_split["gold"]
|
|
172
176
|
)
|
|
173
177
|
|
|
174
|
-
neighbours = evaluator(model, encode_kwargs=encode_kwargs)
|
|
178
|
+
neighbours = evaluator(model, encode_kwargs=encode_kwargs, num_proc=num_proc)
|
|
175
179
|
|
|
176
180
|
if prediction_folder:
|
|
177
181
|
self._save_task_predictions(
|
|
@@ -264,7 +268,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
264
268
|
sentence2_statistics=text2_statistics,
|
|
265
269
|
)
|
|
266
270
|
|
|
267
|
-
def _push_dataset_to_hub(self, repo_name: str) -> None:
|
|
271
|
+
def _push_dataset_to_hub(self, repo_name: str, num_proc: int = 1) -> None:
|
|
268
272
|
if self.dataset is None:
|
|
269
273
|
raise ValueError("Dataset is not loaded.")
|
|
270
274
|
|
|
@@ -287,7 +291,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
287
291
|
dataset_dict = DatasetDict(
|
|
288
292
|
{split: Dataset.from_dict(dataset[split]) for split in dataset}
|
|
289
293
|
)
|
|
290
|
-
dataset_dict.push_to_hub(repo_name)
|
|
294
|
+
dataset_dict.push_to_hub(repo_name, num_proc=num_proc)
|
|
291
295
|
else:
|
|
292
296
|
sentences = {}
|
|
293
297
|
for split in self.dataset:
|
|
@@ -299,4 +303,4 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
299
303
|
}
|
|
300
304
|
)
|
|
301
305
|
sentences = DatasetDict(sentences)
|
|
302
|
-
sentences.push_to_hub(repo_name)
|
|
306
|
+
sentences.push_to_hub(repo_name, num_proc=num_proc)
|
mteb/abstasks/text/reranking.py
CHANGED
|
@@ -34,7 +34,7 @@ class AbsTaskReranking(AbsTaskRetrieval):
|
|
|
34
34
|
For dataformat and other information, see [AbsTaskRetrieval][mteb.abstasks.retrieval.AbsTaskRetrieval].
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
|
-
def load_data(self) -> None:
|
|
37
|
+
def load_data(self, num_proc: int = 1, **kwargs) -> None:
|
|
38
38
|
"""Load the dataset."""
|
|
39
39
|
if self.data_loaded:
|
|
40
40
|
return
|
|
@@ -43,7 +43,7 @@ class AbsTaskReranking(AbsTaskRetrieval):
|
|
|
43
43
|
self.transform_old_dataset_format()
|
|
44
44
|
else:
|
|
45
45
|
# use AbsTaskRetrieval default to load the data
|
|
46
|
-
return super().load_data()
|
|
46
|
+
return super().load_data(num_proc=num_proc)
|
|
47
47
|
|
|
48
48
|
def _process_example(self, example: dict, split: str, query_idx: int) -> dict:
|
|
49
49
|
"""Process a single example from the dataset.
|
|
@@ -94,6 +94,7 @@ class AbsTaskSummarization(AbsTask):
|
|
|
94
94
|
hf_subset: str,
|
|
95
95
|
encode_kwargs: EncodeKwargs,
|
|
96
96
|
prediction_folder: Path | None = None,
|
|
97
|
+
num_proc: int = 1,
|
|
97
98
|
**kwargs,
|
|
98
99
|
) -> SummarizationMetrics:
|
|
99
100
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -115,7 +116,7 @@ class AbsTaskSummarization(AbsTask):
|
|
|
115
116
|
hf_subset=hf_subset,
|
|
116
117
|
**kwargs,
|
|
117
118
|
)
|
|
118
|
-
scores = evaluator(model, encode_kwargs=encode_kwargs)
|
|
119
|
+
scores = evaluator(model, encode_kwargs=encode_kwargs, num_proc=num_proc)
|
|
119
120
|
if prediction_folder:
|
|
120
121
|
self._save_task_predictions(
|
|
121
122
|
scores,
|