mteb 2.7.17__py3-none-any.whl → 2.7.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mteb/_create_dataloaders.py +16 -16
- mteb/_evaluators/any_sts_evaluator.py +1 -1
- mteb/_evaluators/clustering_evaluator.py +1 -1
- mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +2 -2
- mteb/_evaluators/pair_classification_evaluator.py +1 -1
- mteb/_evaluators/retrieval_evaluator.py +1 -1
- mteb/_evaluators/sklearn_evaluator.py +4 -2
- mteb/_evaluators/text/bitext_mining_evaluator.py +1 -1
- mteb/_evaluators/text/summarization_evaluator.py +1 -1
- mteb/_evaluators/zeroshot_classification_evaluator.py +1 -1
- mteb/abstasks/abstask.py +4 -4
- mteb/abstasks/classification.py +2 -2
- mteb/abstasks/clustering.py +1 -1
- mteb/abstasks/clustering_legacy.py +1 -1
- mteb/abstasks/image/image_text_pair_classification.py +1 -1
- mteb/abstasks/multilabel_classification.py +1 -1
- mteb/abstasks/pair_classification.py +1 -1
- mteb/abstasks/retrieval.py +8 -5
- mteb/abstasks/retrieval_dataset_loaders.py +27 -8
- mteb/abstasks/sts.py +1 -1
- mteb/abstasks/text/bitext_mining.py +2 -2
- mteb/abstasks/text/reranking.py +1 -1
- mteb/abstasks/text/summarization.py +1 -1
- mteb/abstasks/zeroshot_classification.py +1 -1
- mteb/evaluate.py +2 -2
- mteb/models/model_implementations/bm25.py +2 -2
- mteb/models/model_implementations/ict_time_and_querit_models.py +115 -0
- mteb/models/model_implementations/pylate_models.py +4 -4
- mteb/models/models_protocols.py +2 -2
- mteb/models/search_wrappers.py +4 -4
- mteb/tasks/bitext_mining/multilingual/bible_nlp_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
- mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +1 -1
- mteb/tasks/classification/ben/bengali_document_classification.py +2 -2
- mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/hin_dialect_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/indic_sentiment_classification.py +1 -1
- mteb/tasks/classification/multilingual/language_classification.py +1 -1
- mteb/tasks/classification/multilingual/south_african_lang_classification.py +1 -1
- mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
- mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +2 -2
- mteb/tasks/classification/swa/swahili_news_classification.py +2 -2
- mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py +1 -1
- mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py +1 -1
- mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
- mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
- mteb/tasks/clustering/nob/vg_hierarchical_clustering.py +2 -2
- mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
- mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
- mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
- mteb/tasks/multichoice/eng/cv_bench.py +4 -4
- mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +1 -1
- mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +1 -1
- mteb/tasks/pair_classification/multilingual/pub_chem_wiki_pair_classification.py +1 -1
- mteb/tasks/pair_classification/multilingual/rte3.py +1 -1
- mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/code/code_rag.py +8 -8
- mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
- mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
- mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
- mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
- mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
- mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
- mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
- mteb/tasks/retrieval/dan/dan_fever_retrieval.py +2 -2
- mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
- mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
- mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
- mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
- mteb/tasks/retrieval/eng/bright_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/ml_questions.py +1 -1
- mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
- mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
- mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
- mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
- mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
- mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
- mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -2
- mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +5 -5
- mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
- mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
- mteb/tasks/retrieval/nob/norquad.py +2 -2
- mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
- mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
- mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
- mteb/tasks/sts/multilingual/sem_rel24_sts.py +1 -1
- mteb/tasks/sts/multilingual/sts_benchmark_multilingual_sts.py +1 -1
- mteb/tasks/sts/por/assin2_sts.py +1 -1
- mteb/types/_encoder_io.py +1 -1
- {mteb-2.7.17.dist-info → mteb-2.7.19.dist-info}/METADATA +1 -1
- {mteb-2.7.17.dist-info → mteb-2.7.19.dist-info}/RECORD +156 -155
- {mteb-2.7.17.dist-info → mteb-2.7.19.dist-info}/WHEEL +0 -0
- {mteb-2.7.17.dist-info → mteb-2.7.19.dist-info}/entry_points.txt +0 -0
- {mteb-2.7.17.dist-info → mteb-2.7.19.dist-info}/licenses/LICENSE +0 -0
- {mteb-2.7.17.dist-info → mteb-2.7.19.dist-info}/top_level.txt +0 -0
mteb/_create_dataloaders.py
CHANGED
|
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|
|
30
30
|
def _create_dataloader_from_texts(
|
|
31
31
|
text: list[str],
|
|
32
32
|
batch_size: int = 32,
|
|
33
|
-
num_proc: int =
|
|
33
|
+
num_proc: int | None = None,
|
|
34
34
|
**kwargs: Any,
|
|
35
35
|
) -> DataLoader[TextInput]:
|
|
36
36
|
"""Create a dataloader from a list of text.
|
|
@@ -48,7 +48,7 @@ def _create_dataloader_from_texts(
|
|
|
48
48
|
return DataLoader(
|
|
49
49
|
dataset,
|
|
50
50
|
batch_size=batch_size,
|
|
51
|
-
num_workers=num_proc if num_proc > 1 else 0,
|
|
51
|
+
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
|
|
52
52
|
)
|
|
53
53
|
|
|
54
54
|
|
|
@@ -74,7 +74,7 @@ def _corpus_to_dict(
|
|
|
74
74
|
def _create_dataloader_for_retrieval_corpus(
|
|
75
75
|
dataset: Dataset,
|
|
76
76
|
batch_size: int = 32,
|
|
77
|
-
num_proc: int =
|
|
77
|
+
num_proc: int | None = None,
|
|
78
78
|
) -> DataLoader[CorpusInput]:
|
|
79
79
|
"""Create a dataloader from a corpus.
|
|
80
80
|
|
|
@@ -94,7 +94,7 @@ def _create_dataloader_for_retrieval_corpus(
|
|
|
94
94
|
return DataLoader(
|
|
95
95
|
new_ds,
|
|
96
96
|
batch_size=batch_size,
|
|
97
|
-
num_workers=num_proc if num_proc > 1 else 0,
|
|
97
|
+
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
|
|
98
98
|
)
|
|
99
99
|
|
|
100
100
|
|
|
@@ -111,7 +111,7 @@ def _combine_queries_with_instruction_text(row: dict[str, str]) -> dict[str, str
|
|
|
111
111
|
def _create_text_dataloader_for_queries(
|
|
112
112
|
queries: QueryDatasetType,
|
|
113
113
|
batch_size: int = 32,
|
|
114
|
-
num_proc: int =
|
|
114
|
+
num_proc: int | None = None,
|
|
115
115
|
) -> DataLoader[QueryInput]:
|
|
116
116
|
"""Create a dataloader from a list of queries.
|
|
117
117
|
|
|
@@ -131,7 +131,7 @@ def _create_text_dataloader_for_queries(
|
|
|
131
131
|
return DataLoader(
|
|
132
132
|
queries,
|
|
133
133
|
batch_size=batch_size,
|
|
134
|
-
num_workers=num_proc if num_proc > 1 else 0,
|
|
134
|
+
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
|
|
135
135
|
)
|
|
136
136
|
|
|
137
137
|
|
|
@@ -200,7 +200,7 @@ def _convert_conv_history_to_query(
|
|
|
200
200
|
def _create_dataloader_for_queries_conversation(
|
|
201
201
|
queries: QueryDatasetType,
|
|
202
202
|
batch_size: int = 32,
|
|
203
|
-
num_proc: int =
|
|
203
|
+
num_proc: int | None = None,
|
|
204
204
|
) -> DataLoader[QueryInput]:
|
|
205
205
|
"""Create a dataloader from a list of queries.
|
|
206
206
|
|
|
@@ -220,7 +220,7 @@ def _create_dataloader_for_queries_conversation(
|
|
|
220
220
|
),
|
|
221
221
|
collate_fn=_custom_collate_fn,
|
|
222
222
|
batch_size=batch_size,
|
|
223
|
-
num_workers=num_proc if num_proc > 1 else 0,
|
|
223
|
+
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
|
|
224
224
|
)
|
|
225
225
|
|
|
226
226
|
|
|
@@ -265,7 +265,7 @@ def _prepare_image_dataset(
|
|
|
265
265
|
dataset: Dataset,
|
|
266
266
|
image_column_name: str | None = None,
|
|
267
267
|
transform: Callable[[Any], Any] | None = None,
|
|
268
|
-
num_proc: int =
|
|
268
|
+
num_proc: int | None = None,
|
|
269
269
|
) -> Dataset:
|
|
270
270
|
"""Prepare the image dataset by converting images to RGB and applying transformations."""
|
|
271
271
|
if (
|
|
@@ -315,7 +315,7 @@ def _create_image_dataloader(
|
|
|
315
315
|
batch_size: int = 32,
|
|
316
316
|
transform: Callable[[Any], Any] | None = None,
|
|
317
317
|
collate_fn: Callable[[list[dict[str, Any]]], dict[str, Any]] = _custom_collate_fn,
|
|
318
|
-
num_proc: int =
|
|
318
|
+
num_proc: int | None = None,
|
|
319
319
|
) -> DataLoader[ImageInput]:
|
|
320
320
|
"""Creates a DataLoader with the image dataset prepared using the explicit transformation.
|
|
321
321
|
|
|
@@ -341,14 +341,14 @@ def _create_image_dataloader(
|
|
|
341
341
|
batch_size=batch_size,
|
|
342
342
|
collate_fn=collate_fn,
|
|
343
343
|
shuffle=False,
|
|
344
|
-
num_workers=num_proc if num_proc > 1 else 0,
|
|
344
|
+
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
|
|
345
345
|
)
|
|
346
346
|
|
|
347
347
|
|
|
348
348
|
def _create_text_queries_dataloader(
|
|
349
349
|
dataset: Dataset,
|
|
350
350
|
batch_size: int = 32,
|
|
351
|
-
num_proc: int =
|
|
351
|
+
num_proc: int | None = None,
|
|
352
352
|
) -> DataLoader[QueryInput]:
|
|
353
353
|
if not isinstance(dataset["text"][0], list):
|
|
354
354
|
return _create_text_dataloader_for_queries(
|
|
@@ -368,7 +368,7 @@ def _create_queries_dataloader(
|
|
|
368
368
|
task_metadata: TaskMetadata,
|
|
369
369
|
input_column: str | None = None,
|
|
370
370
|
batch_size: int = 32,
|
|
371
|
-
num_proc: int =
|
|
371
|
+
num_proc: int | None = None,
|
|
372
372
|
) -> DataLoader[QueryInput | ImageInput]:
|
|
373
373
|
"""Create a dataloader for queries."""
|
|
374
374
|
queries_type = task_metadata.get_modalities(PromptType.query)
|
|
@@ -393,7 +393,7 @@ def _create_document_dataloader(
|
|
|
393
393
|
task_metadata: TaskMetadata,
|
|
394
394
|
input_column: str | None = None,
|
|
395
395
|
batch_size: int = 32,
|
|
396
|
-
num_proc: int =
|
|
396
|
+
num_proc: int | None = None,
|
|
397
397
|
) -> DataLoader[CorpusInput | ImageInput]:
|
|
398
398
|
"""Create a dataloader for documents.
|
|
399
399
|
|
|
@@ -430,7 +430,7 @@ def create_dataloader(
|
|
|
430
430
|
prompt_type: PromptType | None = None,
|
|
431
431
|
input_column: str | None = None,
|
|
432
432
|
batch_size: int = 32,
|
|
433
|
-
num_proc: int =
|
|
433
|
+
num_proc: int | None = None,
|
|
434
434
|
**kwargs: Any,
|
|
435
435
|
) -> DataLoader[BatchedInput]:
|
|
436
436
|
"""Create a dataloader from a dataset.
|
|
@@ -482,5 +482,5 @@ def create_dataloader(
|
|
|
482
482
|
return DataLoader(
|
|
483
483
|
dataset,
|
|
484
484
|
batch_size=batch_size,
|
|
485
|
-
num_workers=num_proc if num_proc > 1 else 0,
|
|
485
|
+
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
|
|
486
486
|
)
|
|
@@ -66,7 +66,7 @@ class AnySTSEvaluator(Evaluator):
|
|
|
66
66
|
model: EncoderProtocol,
|
|
67
67
|
*,
|
|
68
68
|
encode_kwargs: EncodeKwargs,
|
|
69
|
-
num_proc: int =
|
|
69
|
+
num_proc: int | None = None,
|
|
70
70
|
) -> STSEvaluatorScores:
|
|
71
71
|
logger.info("Running semantic similarity - Encoding samples (1/2)")
|
|
72
72
|
embeddings1 = model.encode(
|
|
@@ -91,7 +91,7 @@ class ImageTextPairClassificationEvaluator(Evaluator):
|
|
|
91
91
|
model: EncoderProtocol,
|
|
92
92
|
*,
|
|
93
93
|
encode_kwargs: EncodeKwargs,
|
|
94
|
-
num_proc: int =
|
|
94
|
+
num_proc: int | None = None,
|
|
95
95
|
) -> list[torch.Tensor]:
|
|
96
96
|
images = []
|
|
97
97
|
if isinstance(self.images_column_names, str):
|
|
@@ -139,7 +139,7 @@ class ImageTextPairClassificationEvaluator(Evaluator):
|
|
|
139
139
|
DataLoader(
|
|
140
140
|
CustomImageDataset(images),
|
|
141
141
|
collate_fn=_image_collate_fn,
|
|
142
|
-
num_workers=num_proc if num_proc > 1 else 0,
|
|
142
|
+
num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
|
|
143
143
|
),
|
|
144
144
|
task_metadata=self.task_metadata,
|
|
145
145
|
hf_subset=self.hf_subset,
|
|
@@ -92,7 +92,7 @@ class PairClassificationEvaluator(Evaluator):
|
|
|
92
92
|
self,
|
|
93
93
|
model: EncoderProtocol,
|
|
94
94
|
encode_kwargs: EncodeKwargs,
|
|
95
|
-
num_proc: int =
|
|
95
|
+
num_proc: int | None = None,
|
|
96
96
|
) -> PairClassificationDistances:
|
|
97
97
|
logger.info("Running pair classification - Encoding samples (1/2)")
|
|
98
98
|
embeddings1 = model.encode(
|
|
@@ -55,7 +55,7 @@ class RetrievalEvaluator(Evaluator):
|
|
|
55
55
|
self,
|
|
56
56
|
search_model: SearchProtocol,
|
|
57
57
|
encode_kwargs: EncodeKwargs,
|
|
58
|
-
num_proc: int =
|
|
58
|
+
num_proc: int | None = None,
|
|
59
59
|
) -> RetrievalOutputType:
|
|
60
60
|
logger.info("Running retrieval task - Indexing corpus...")
|
|
61
61
|
search_model.index(
|
|
@@ -59,7 +59,9 @@ class SklearnEvaluator(Evaluator):
|
|
|
59
59
|
self.evaluator_model = evaluator_model
|
|
60
60
|
|
|
61
61
|
def create_dataloaders(
|
|
62
|
-
self,
|
|
62
|
+
self,
|
|
63
|
+
encode_kwargs: EncodeKwargs,
|
|
64
|
+
num_proc: int | None,
|
|
63
65
|
) -> tuple[DataLoader[BatchedInput], DataLoader[BatchedInput]]:
|
|
64
66
|
dataloader_train = create_dataloader(
|
|
65
67
|
self.train_dataset,
|
|
@@ -83,7 +85,7 @@ class SklearnEvaluator(Evaluator):
|
|
|
83
85
|
*,
|
|
84
86
|
encode_kwargs: EncodeKwargs,
|
|
85
87
|
test_cache: Array | None = None,
|
|
86
|
-
num_proc: int =
|
|
88
|
+
num_proc: int | None = None,
|
|
87
89
|
) -> tuple[NDArray[np.integer | np.floating], Array]:
|
|
88
90
|
"""Classification evaluation by training a sklearn classifier on the embeddings of the training set and evaluating on the embeddings of the test set.
|
|
89
91
|
|
|
@@ -41,7 +41,7 @@ class BitextMiningEvaluator(Evaluator):
|
|
|
41
41
|
model: EncoderProtocol,
|
|
42
42
|
*,
|
|
43
43
|
encode_kwargs: EncodeKwargs,
|
|
44
|
-
num_proc: int =
|
|
44
|
+
num_proc: int | None = None,
|
|
45
45
|
) -> dict[str, list[dict[str, float]]]:
|
|
46
46
|
pair_elements = {p for pair in self.pairs for p in pair}
|
|
47
47
|
if isinstance(self.sentences, Dataset):
|
|
@@ -100,7 +100,7 @@ class SummarizationEvaluator(Evaluator):
|
|
|
100
100
|
model: EncoderProtocol,
|
|
101
101
|
*,
|
|
102
102
|
encode_kwargs: EncodeKwargs,
|
|
103
|
-
num_proc: int =
|
|
103
|
+
num_proc: int | None = None,
|
|
104
104
|
) -> SummarizationDistances:
|
|
105
105
|
# Get the human & machine summaries for the text in one go for all
|
|
106
106
|
human_lens = [len(human_summaries) for human_summaries in self.human_summaries]
|
mteb/abstasks/abstask.py
CHANGED
|
@@ -116,7 +116,7 @@ class AbsTask(ABC):
|
|
|
116
116
|
logger.warning(msg)
|
|
117
117
|
warnings.warn(msg)
|
|
118
118
|
|
|
119
|
-
def dataset_transform(self, num_proc: int =
|
|
119
|
+
def dataset_transform(self, num_proc: int | None = None, **kwargs: Any) -> None:
|
|
120
120
|
"""A transform operations applied to the dataset after loading.
|
|
121
121
|
|
|
122
122
|
This method is useful when the dataset from Huggingface is not in an `mteb` compatible format.
|
|
@@ -136,7 +136,7 @@ class AbsTask(ABC):
|
|
|
136
136
|
*,
|
|
137
137
|
encode_kwargs: EncodeKwargs,
|
|
138
138
|
prediction_folder: Path | None = None,
|
|
139
|
-
num_proc: int =
|
|
139
|
+
num_proc: int | None = None,
|
|
140
140
|
**kwargs: Any,
|
|
141
141
|
) -> Mapping[HFSubset, ScoresDict]:
|
|
142
142
|
"""Evaluates an MTEB compatible model on the task.
|
|
@@ -219,7 +219,7 @@ class AbsTask(ABC):
|
|
|
219
219
|
hf_subset: str,
|
|
220
220
|
encode_kwargs: EncodeKwargs,
|
|
221
221
|
prediction_folder: Path | None = None,
|
|
222
|
-
num_proc: int =
|
|
222
|
+
num_proc: int | None = None,
|
|
223
223
|
**kwargs: Any,
|
|
224
224
|
) -> ScoresDict:
|
|
225
225
|
raise NotImplementedError(
|
|
@@ -324,7 +324,7 @@ class AbsTask(ABC):
|
|
|
324
324
|
) # only take the specified test split.
|
|
325
325
|
return dataset_dict
|
|
326
326
|
|
|
327
|
-
def load_data(self, num_proc: int =
|
|
327
|
+
def load_data(self, num_proc: int | None = None, **kwargs: Any) -> None:
|
|
328
328
|
"""Loads dataset from HuggingFace hub
|
|
329
329
|
|
|
330
330
|
This is the main loading function for Task. Do not overwrite this, instead we recommend using `dataset_transform`, which is called after the
|
mteb/abstasks/classification.py
CHANGED
|
@@ -138,7 +138,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
138
138
|
*,
|
|
139
139
|
encode_kwargs: EncodeKwargs,
|
|
140
140
|
prediction_folder: Path | None = None,
|
|
141
|
-
num_proc: int =
|
|
141
|
+
num_proc: int | None = None,
|
|
142
142
|
**kwargs: Any,
|
|
143
143
|
) -> dict[HFSubset, ScoresDict]:
|
|
144
144
|
"""Evaluate a model on the classification task.
|
|
@@ -201,7 +201,7 @@ class AbsTaskClassification(AbsTask):
|
|
|
201
201
|
hf_split: str,
|
|
202
202
|
hf_subset: str,
|
|
203
203
|
prediction_folder: Path | None = None,
|
|
204
|
-
num_proc: int =
|
|
204
|
+
num_proc: int | None = None,
|
|
205
205
|
**kwargs: Any,
|
|
206
206
|
) -> FullClassificationMetrics:
|
|
207
207
|
if not isinstance(model, EncoderProtocol):
|
mteb/abstasks/clustering.py
CHANGED
|
@@ -169,7 +169,7 @@ class AbsTaskClustering(AbsTask):
|
|
|
169
169
|
hf_split: str,
|
|
170
170
|
hf_subset: str,
|
|
171
171
|
prediction_folder: Path | None = None,
|
|
172
|
-
num_proc: int =
|
|
172
|
+
num_proc: int | None = None,
|
|
173
173
|
**kwargs: Any,
|
|
174
174
|
) -> ScoresDict:
|
|
175
175
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -95,7 +95,7 @@ class AbsTaskClusteringLegacy(AbsTask):
|
|
|
95
95
|
hf_split: str,
|
|
96
96
|
hf_subset: str,
|
|
97
97
|
prediction_folder: Path | None = None,
|
|
98
|
-
num_proc: int =
|
|
98
|
+
num_proc: int | None = None,
|
|
99
99
|
**kwargs: Any,
|
|
100
100
|
) -> ScoresDict:
|
|
101
101
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -134,7 +134,7 @@ class AbsTaskImageTextPairClassification(AbsTask):
|
|
|
134
134
|
hf_split: str,
|
|
135
135
|
hf_subset: str,
|
|
136
136
|
prediction_folder: Path | None = None,
|
|
137
|
-
num_proc: int =
|
|
137
|
+
num_proc: int | None = None,
|
|
138
138
|
**kwargs: Any,
|
|
139
139
|
) -> ImageTextPairClassificationMetrics:
|
|
140
140
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -95,7 +95,7 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
|
|
|
95
95
|
hf_split: str,
|
|
96
96
|
hf_subset: str,
|
|
97
97
|
prediction_folder: Path | None = None,
|
|
98
|
-
num_proc: int =
|
|
98
|
+
num_proc: int | None = None,
|
|
99
99
|
**kwargs: Any,
|
|
100
100
|
) -> FullMultilabelClassificationMetrics:
|
|
101
101
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -97,7 +97,7 @@ class AbsTaskPairClassification(AbsTask):
|
|
|
97
97
|
hf_subset: str,
|
|
98
98
|
encode_kwargs: EncodeKwargs,
|
|
99
99
|
prediction_folder: Path | None = None,
|
|
100
|
-
num_proc: int =
|
|
100
|
+
num_proc: int | None = None,
|
|
101
101
|
**kwargs,
|
|
102
102
|
) -> dict[str, float]:
|
|
103
103
|
if not isinstance(model, EncoderProtocol):
|
mteb/abstasks/retrieval.py
CHANGED
|
@@ -148,7 +148,10 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
148
148
|
)
|
|
149
149
|
)
|
|
150
150
|
|
|
151
|
-
def convert_v1_dataset_format_to_v2(
|
|
151
|
+
def convert_v1_dataset_format_to_v2(
|
|
152
|
+
self,
|
|
153
|
+
num_proc: int | None,
|
|
154
|
+
) -> None:
|
|
152
155
|
"""Convert dataset from v1 (from `self.queries`, `self.document`) format to v2 format (`self.dotaset`)."""
|
|
153
156
|
# check if dataset is `v1` version
|
|
154
157
|
if not hasattr(self, "queries"):
|
|
@@ -257,7 +260,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
257
260
|
if hasattr(self, "top_ranked"):
|
|
258
261
|
del self.top_ranked
|
|
259
262
|
|
|
260
|
-
def load_data(self, num_proc: int =
|
|
263
|
+
def load_data(self, num_proc: int | None = None, **kwargs) -> None:
|
|
261
264
|
"""Load the dataset for the retrieval task."""
|
|
262
265
|
if self.data_loaded:
|
|
263
266
|
return
|
|
@@ -301,7 +304,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
301
304
|
*,
|
|
302
305
|
encode_kwargs: EncodeKwargs,
|
|
303
306
|
prediction_folder: Path | None = None,
|
|
304
|
-
num_proc: int =
|
|
307
|
+
num_proc: int | None = None,
|
|
305
308
|
**kwargs: Any,
|
|
306
309
|
) -> Mapping[HFSubset, ScoresDict]:
|
|
307
310
|
"""Evaluate the model on the retrieval task.
|
|
@@ -342,7 +345,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
342
345
|
hf_split: str,
|
|
343
346
|
hf_subset: str,
|
|
344
347
|
prediction_folder: Path | None = None,
|
|
345
|
-
num_proc: int =
|
|
348
|
+
num_proc: int | None = None,
|
|
346
349
|
**kwargs,
|
|
347
350
|
) -> ScoresDict:
|
|
348
351
|
"""Evaluate a model on a specific subset of the data.
|
|
@@ -473,7 +476,7 @@ class AbsTaskRetrieval(AbsTask):
|
|
|
473
476
|
split: str,
|
|
474
477
|
hf_subset: str | None = None,
|
|
475
478
|
compute_overall: bool = False,
|
|
476
|
-
num_proc: int =
|
|
479
|
+
num_proc: int | None = None,
|
|
477
480
|
) -> RetrievalDescriptiveStatistics:
|
|
478
481
|
self.convert_v1_dataset_format_to_v2(num_proc)
|
|
479
482
|
if hf_subset and hf_subset in self.dataset:
|
|
@@ -78,7 +78,7 @@ class RetrievalDatasetLoader:
|
|
|
78
78
|
|
|
79
79
|
def load(
|
|
80
80
|
self,
|
|
81
|
-
num_proc: int =
|
|
81
|
+
num_proc: int | None = None,
|
|
82
82
|
) -> RetrievalSplitData:
|
|
83
83
|
"""Loads the dataset split for the specified configuration.
|
|
84
84
|
|
|
@@ -128,7 +128,11 @@ class RetrievalDatasetLoader:
|
|
|
128
128
|
f"Split {self.split} not found in {splits}. Please specify a valid split."
|
|
129
129
|
)
|
|
130
130
|
|
|
131
|
-
def _load_dataset_split(
|
|
131
|
+
def _load_dataset_split(
|
|
132
|
+
self,
|
|
133
|
+
config: str,
|
|
134
|
+
num_proc: int | None,
|
|
135
|
+
) -> Dataset:
|
|
132
136
|
return load_dataset(
|
|
133
137
|
self.hf_repo,
|
|
134
138
|
config,
|
|
@@ -138,7 +142,10 @@ class RetrievalDatasetLoader:
|
|
|
138
142
|
num_proc=num_proc,
|
|
139
143
|
)
|
|
140
144
|
|
|
141
|
-
def _load_corpus(
|
|
145
|
+
def _load_corpus(
|
|
146
|
+
self,
|
|
147
|
+
num_proc: int | None,
|
|
148
|
+
) -> CorpusDatasetType:
|
|
142
149
|
config = f"{self.config}-corpus" if self.config is not None else "corpus"
|
|
143
150
|
logger.info("Loading corpus subset: %s", config)
|
|
144
151
|
|
|
@@ -151,7 +158,10 @@ class RetrievalDatasetLoader:
|
|
|
151
158
|
logger.debug("Doc Example: %s", corpus_ds[0])
|
|
152
159
|
return corpus_ds
|
|
153
160
|
|
|
154
|
-
def _load_queries(
|
|
161
|
+
def _load_queries(
|
|
162
|
+
self,
|
|
163
|
+
num_proc: int | None,
|
|
164
|
+
) -> QueryDatasetType:
|
|
155
165
|
config = f"{self.config}-queries" if self.config is not None else "queries"
|
|
156
166
|
logger.info("Loading queries subset: %s", config)
|
|
157
167
|
|
|
@@ -168,7 +178,10 @@ class RetrievalDatasetLoader:
|
|
|
168
178
|
|
|
169
179
|
return queries_ds
|
|
170
180
|
|
|
171
|
-
def _load_qrels(
|
|
181
|
+
def _load_qrels(
|
|
182
|
+
self,
|
|
183
|
+
num_proc: int | None,
|
|
184
|
+
) -> RelevantDocumentsType:
|
|
172
185
|
config = f"{self.config}-qrels" if self.config is not None else "default"
|
|
173
186
|
|
|
174
187
|
logger.info("Loading qrels subset: %s", config)
|
|
@@ -203,7 +216,10 @@ class RetrievalDatasetLoader:
|
|
|
203
216
|
logger.info("Loaded %d %s qrels.", len(qrels_dict), self.split.upper())
|
|
204
217
|
return qrels_dict
|
|
205
218
|
|
|
206
|
-
def _load_top_ranked(
|
|
219
|
+
def _load_top_ranked(
|
|
220
|
+
self,
|
|
221
|
+
num_proc: int | None,
|
|
222
|
+
) -> TopRankedDocumentsType:
|
|
207
223
|
config = (
|
|
208
224
|
f"{self.config}-top_ranked" if self.config is not None else "top_ranked"
|
|
209
225
|
)
|
|
@@ -226,7 +242,10 @@ class RetrievalDatasetLoader:
|
|
|
226
242
|
logger.info(f"Top ranked loaded: {len(top_ranked_ds)}")
|
|
227
243
|
return top_ranked_dict
|
|
228
244
|
|
|
229
|
-
def _load_instructions(
|
|
245
|
+
def _load_instructions(
|
|
246
|
+
self,
|
|
247
|
+
num_proc: int | None,
|
|
248
|
+
) -> InstructionDatasetType:
|
|
230
249
|
config = (
|
|
231
250
|
f"{self.config}-instruction" if self.config is not None else "instruction"
|
|
232
251
|
)
|
|
@@ -246,7 +265,7 @@ class RetrievalDatasetLoader:
|
|
|
246
265
|
def _combine_queries_with_instructions_datasets(
|
|
247
266
|
queries_dataset: QueryDatasetType,
|
|
248
267
|
instruction_dataset: InstructionDatasetType | dict[str, str],
|
|
249
|
-
num_proc: int,
|
|
268
|
+
num_proc: int | None,
|
|
250
269
|
) -> Dataset:
|
|
251
270
|
if isinstance(instruction_dataset, Dataset):
|
|
252
271
|
instruction_to_query_idx = {
|
mteb/abstasks/sts.py
CHANGED
|
@@ -82,7 +82,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
82
82
|
*,
|
|
83
83
|
encode_kwargs: EncodeKwargs,
|
|
84
84
|
prediction_folder: Path | None = None,
|
|
85
|
-
num_proc: int =
|
|
85
|
+
num_proc: int | None = None,
|
|
86
86
|
**kwargs: Any,
|
|
87
87
|
) -> dict[HFSubset, ScoresDict]:
|
|
88
88
|
"""Added load for "parallel" datasets"""
|
|
@@ -155,7 +155,7 @@ class AbsTaskBitextMining(AbsTask):
|
|
|
155
155
|
encode_kwargs: EncodeKwargs,
|
|
156
156
|
prediction_folder: Path | None = None,
|
|
157
157
|
parallel: bool = False,
|
|
158
|
-
num_proc: int =
|
|
158
|
+
num_proc: int | None = None,
|
|
159
159
|
**kwargs,
|
|
160
160
|
) -> BitextMiningMetrics | dict[str, BitextMiningMetrics]:
|
|
161
161
|
pairs = self._get_pairs(parallel)
|
mteb/abstasks/text/reranking.py
CHANGED
|
@@ -34,7 +34,7 @@ class AbsTaskReranking(AbsTaskRetrieval):
|
|
|
34
34
|
For dataformat and other information, see [AbsTaskRetrieval][mteb.abstasks.retrieval.AbsTaskRetrieval].
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
|
-
def load_data(self, num_proc: int =
|
|
37
|
+
def load_data(self, num_proc: int | None = None, **kwargs) -> None:
|
|
38
38
|
"""Load the dataset."""
|
|
39
39
|
if self.data_loaded:
|
|
40
40
|
return
|
|
@@ -94,7 +94,7 @@ class AbsTaskSummarization(AbsTask):
|
|
|
94
94
|
hf_subset: str,
|
|
95
95
|
encode_kwargs: EncodeKwargs,
|
|
96
96
|
prediction_folder: Path | None = None,
|
|
97
|
-
num_proc: int =
|
|
97
|
+
num_proc: int | None = None,
|
|
98
98
|
**kwargs,
|
|
99
99
|
) -> SummarizationMetrics:
|
|
100
100
|
if not isinstance(model, EncoderProtocol):
|
|
@@ -127,7 +127,7 @@ class AbsTaskZeroShotClassification(AbsTask):
|
|
|
127
127
|
hf_subset: str,
|
|
128
128
|
encode_kwargs: EncodeKwargs,
|
|
129
129
|
prediction_folder: Path | None = None,
|
|
130
|
-
num_proc: int =
|
|
130
|
+
num_proc: int | None = None,
|
|
131
131
|
**kwargs,
|
|
132
132
|
) -> ZeroShotClassificationMetrics:
|
|
133
133
|
if not isinstance(model, EncoderProtocol):
|
mteb/evaluate.py
CHANGED
|
@@ -91,7 +91,7 @@ def _evaluate_task(
|
|
|
91
91
|
encode_kwargs: EncodeKwargs,
|
|
92
92
|
prediction_folder: Path | None,
|
|
93
93
|
public_only: bool | None,
|
|
94
|
-
num_proc: int =
|
|
94
|
+
num_proc: int | None = None,
|
|
95
95
|
) -> TaskResult | TaskError:
|
|
96
96
|
"""The core logic to run a model on a given task. See `evaluate` for more details.
|
|
97
97
|
|
|
@@ -282,7 +282,7 @@ def evaluate(
|
|
|
282
282
|
prediction_folder: Path | str | None = None,
|
|
283
283
|
show_progress_bar: bool = True,
|
|
284
284
|
public_only: bool | None = None,
|
|
285
|
-
num_proc: int =
|
|
285
|
+
num_proc: int | None = None,
|
|
286
286
|
) -> ModelResult:
|
|
287
287
|
"""This function runs a model on a given task and returns the results.
|
|
288
288
|
|
|
@@ -54,7 +54,7 @@ def bm25_loader(model_name, **kwargs) -> SearchProtocol:
|
|
|
54
54
|
hf_split: str,
|
|
55
55
|
hf_subset: str,
|
|
56
56
|
encode_kwargs: EncodeKwargs,
|
|
57
|
-
num_proc: int =
|
|
57
|
+
num_proc: int | None = None,
|
|
58
58
|
) -> None:
|
|
59
59
|
logger.info("Encoding Corpus...")
|
|
60
60
|
corpus_texts = [
|
|
@@ -81,7 +81,7 @@ def bm25_loader(model_name, **kwargs) -> SearchProtocol:
|
|
|
81
81
|
top_k: int,
|
|
82
82
|
encode_kwargs: EncodeKwargs,
|
|
83
83
|
top_ranked: TopRankedDocumentsType | None = None,
|
|
84
|
-
num_proc: int =
|
|
84
|
+
num_proc: int | None = None,
|
|
85
85
|
) -> RetrievalOutputType:
|
|
86
86
|
logger.info("Encoding Queries...")
|
|
87
87
|
query_ids = list(queries["id"])
|