mteb 2.7.16__py3-none-any.whl → 2.7.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. mteb/_create_dataloaders.py +16 -16
  2. mteb/_evaluators/any_sts_evaluator.py +1 -1
  3. mteb/_evaluators/classification_metrics.py +10 -1
  4. mteb/_evaluators/clustering_evaluator.py +1 -1
  5. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +2 -2
  6. mteb/_evaluators/pair_classification_evaluator.py +3 -2
  7. mteb/_evaluators/retrieval_evaluator.py +1 -1
  8. mteb/_evaluators/retrieval_metrics.py +9 -7
  9. mteb/_evaluators/sklearn_evaluator.py +13 -6
  10. mteb/_evaluators/text/bitext_mining_evaluator.py +1 -1
  11. mteb/_evaluators/text/summarization_evaluator.py +1 -1
  12. mteb/_evaluators/zeroshot_classification_evaluator.py +1 -1
  13. mteb/abstasks/_stratification.py +13 -8
  14. mteb/abstasks/abstask.py +4 -4
  15. mteb/abstasks/classification.py +6 -4
  16. mteb/abstasks/clustering.py +1 -1
  17. mteb/abstasks/clustering_legacy.py +1 -1
  18. mteb/abstasks/image/image_text_pair_classification.py +1 -1
  19. mteb/abstasks/multilabel_classification.py +7 -5
  20. mteb/abstasks/pair_classification.py +1 -1
  21. mteb/abstasks/regression.py +3 -2
  22. mteb/abstasks/retrieval.py +8 -5
  23. mteb/abstasks/retrieval_dataset_loaders.py +27 -8
  24. mteb/abstasks/sts.py +1 -1
  25. mteb/abstasks/text/bitext_mining.py +2 -2
  26. mteb/abstasks/text/reranking.py +1 -1
  27. mteb/abstasks/text/summarization.py +1 -1
  28. mteb/abstasks/zeroshot_classification.py +1 -1
  29. mteb/benchmarks/benchmark.py +131 -3
  30. mteb/evaluate.py +2 -2
  31. mteb/leaderboard/figures.py +2 -1
  32. mteb/leaderboard/table.py +10 -2
  33. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -3
  34. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +3 -3
  35. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +8 -3
  36. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  37. mteb/models/model_implementations/bedrock_models.py +4 -4
  38. mteb/models/model_implementations/bm25.py +2 -2
  39. mteb/models/model_implementations/mcinext_models.py +2 -2
  40. mteb/models/model_implementations/openai_models.py +2 -1
  41. mteb/models/model_implementations/pylate_models.py +4 -4
  42. mteb/models/model_implementations/random_baseline.py +4 -3
  43. mteb/models/model_implementations/seed_models.py +7 -2
  44. mteb/models/model_implementations/voyage_models.py +1 -1
  45. mteb/models/models_protocols.py +2 -2
  46. mteb/models/search_wrappers.py +4 -4
  47. mteb/tasks/bitext_mining/multilingual/bible_nlp_bitext_mining.py +1 -1
  48. mteb/tasks/bitext_mining/multilingual/flores_bitext_mining.py +1 -1
  49. mteb/tasks/bitext_mining/multilingual/in22_conv_bitext_mining.py +1 -1
  50. mteb/tasks/bitext_mining/multilingual/in22_gen_bitext_mining.py +1 -1
  51. mteb/tasks/bitext_mining/multilingual/ntrex_bitext_mining.py +1 -1
  52. mteb/tasks/bitext_mining/multilingual/roma_tales_bitext_mining.py +1 -1
  53. mteb/tasks/classification/ben/bengali_document_classification.py +2 -2
  54. mteb/tasks/classification/ces/czech_product_review_sentiment_classification.py +2 -2
  55. mteb/tasks/classification/ces/czech_so_me_sentiment_classification.py +1 -1
  56. mteb/tasks/classification/multilingual/hin_dialect_classification.py +1 -1
  57. mteb/tasks/classification/multilingual/indic_lang_classification.py +1 -1
  58. mteb/tasks/classification/multilingual/indic_sentiment_classification.py +1 -1
  59. mteb/tasks/classification/multilingual/language_classification.py +1 -1
  60. mteb/tasks/classification/multilingual/south_african_lang_classification.py +1 -1
  61. mteb/tasks/classification/multilingual/turkic_classification.py +1 -1
  62. mteb/tasks/classification/slk/slovak_movie_review_sentiment_classification.py +2 -2
  63. mteb/tasks/classification/swa/swahili_news_classification.py +2 -2
  64. mteb/tasks/clustering/deu/ten_k_gnad_clustering_p2p.py +1 -1
  65. mteb/tasks/clustering/deu/ten_k_gnad_clustering_s2s.py +1 -1
  66. mteb/tasks/clustering/multilingual/mlsum_clustering_p2p.py +2 -2
  67. mteb/tasks/clustering/multilingual/mlsum_clustering_s2s.py +2 -2
  68. mteb/tasks/clustering/nob/vg_hierarchical_clustering.py +2 -2
  69. mteb/tasks/image_text_pair_classification/eng/image_co_de.py +1 -1
  70. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  71. mteb/tasks/instruction_reranking/multilingual/m_follow_ir.py +2 -2
  72. mteb/tasks/multichoice/eng/cv_bench.py +4 -4
  73. mteb/tasks/multilabel_classification/nld/covid_disinformation_nl_multi_label_classification.py +1 -1
  74. mteb/tasks/pair_classification/eng/pub_chem_smilespc.py +1 -1
  75. mteb/tasks/pair_classification/multilingual/pub_chem_wiki_pair_classification.py +1 -1
  76. mteb/tasks/pair_classification/multilingual/rte3.py +1 -1
  77. mteb/tasks/retrieval/ara/sadeem_question_retrieval.py +1 -1
  78. mteb/tasks/retrieval/code/code_edit_search_retrieval.py +1 -1
  79. mteb/tasks/retrieval/code/code_rag.py +8 -8
  80. mteb/tasks/retrieval/code/code_search_net_cc_retrieval.py +1 -1
  81. mteb/tasks/retrieval/code/coir_code_search_net_retrieval.py +1 -1
  82. mteb/tasks/retrieval/code/ds1000_retrieval.py +1 -1
  83. mteb/tasks/retrieval/code/fresh_stack_retrieval.py +1 -1
  84. mteb/tasks/retrieval/code/human_eval_retrieval.py +1 -1
  85. mteb/tasks/retrieval/code/mbpp_retrieval.py +1 -1
  86. mteb/tasks/retrieval/code/wiki_sql_retrieval.py +1 -1
  87. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +2 -2
  88. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  89. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  90. mteb/tasks/retrieval/deu/german_gov_service_retrieval.py +1 -1
  91. mteb/tasks/retrieval/deu/german_qu_ad_retrieval.py +1 -1
  92. mteb/tasks/retrieval/ell/greek_civics_qa.py +1 -1
  93. mteb/tasks/retrieval/eng/bright_retrieval.py +1 -1
  94. mteb/tasks/retrieval/eng/chat_doctor_retrieval.py +1 -1
  95. mteb/tasks/retrieval/eng/fin_qa_retrieval.py +1 -1
  96. mteb/tasks/retrieval/eng/finance_bench_retrieval.py +1 -1
  97. mteb/tasks/retrieval/eng/hateful_memes_i2t_retrieval.py +1 -1
  98. mteb/tasks/retrieval/eng/hateful_memes_t2i_retrieval.py +1 -1
  99. mteb/tasks/retrieval/eng/hc3_finance_retrieval.py +1 -1
  100. mteb/tasks/retrieval/eng/lemb_narrative_qa_retrieval.py +1 -1
  101. mteb/tasks/retrieval/eng/lemb_needle_retrieval.py +1 -1
  102. mteb/tasks/retrieval/eng/lemb_passkey_retrieval.py +1 -1
  103. mteb/tasks/retrieval/eng/lemb_summ_screen_fd_retrieval.py +1 -1
  104. mteb/tasks/retrieval/eng/lemb_wikim_qa_retrieval.py +1 -1
  105. mteb/tasks/retrieval/eng/lembqm_sum_retrieval.py +1 -1
  106. mteb/tasks/retrieval/eng/lit_search_retrieval.py +1 -1
  107. mteb/tasks/retrieval/eng/memotion_i2t_retrieval.py +1 -1
  108. mteb/tasks/retrieval/eng/memotion_t2i_retrieval.py +1 -1
  109. mteb/tasks/retrieval/eng/ml_questions.py +1 -1
  110. mteb/tasks/retrieval/eng/nano_argu_ana_retrieval.py +1 -1
  111. mteb/tasks/retrieval/eng/nano_climate_fever_retrieval.py +1 -1
  112. mteb/tasks/retrieval/eng/nano_db_pedia_retrieval.py +1 -1
  113. mteb/tasks/retrieval/eng/nano_fever_retrieval.py +1 -1
  114. mteb/tasks/retrieval/eng/nano_fi_qa2018_retrieval.py +1 -1
  115. mteb/tasks/retrieval/eng/nano_hotpot_qa_retrieval.py +1 -1
  116. mteb/tasks/retrieval/eng/nano_msmarco_retrieval.py +1 -1
  117. mteb/tasks/retrieval/eng/nano_nf_corpus_retrieval.py +1 -1
  118. mteb/tasks/retrieval/eng/nano_nq_retrieval.py +1 -1
  119. mteb/tasks/retrieval/eng/nano_quora_retrieval.py +1 -1
  120. mteb/tasks/retrieval/eng/nano_sci_fact_retrieval.py +1 -1
  121. mteb/tasks/retrieval/eng/nano_scidocs_retrieval.py +1 -1
  122. mteb/tasks/retrieval/eng/nano_touche2020_retrieval.py +1 -1
  123. mteb/tasks/retrieval/eng/narrative_qa_retrieval.py +1 -1
  124. mteb/tasks/retrieval/eng/r2_med_retrieval.py +8 -8
  125. mteb/tasks/retrieval/eng/sci_mmir_i2t_retrieval.py +1 -1
  126. mteb/tasks/retrieval/eng/sci_mmir_t2i_retrieval.py +1 -1
  127. mteb/tasks/retrieval/eng/vidore_bench_retrieval.py +10 -10
  128. mteb/tasks/retrieval/fra/f_qu_ad_retrieval.py +1 -1
  129. mteb/tasks/retrieval/fra/syntec_retrieval.py +1 -1
  130. mteb/tasks/retrieval/hun/hun_sum2.py +1 -1
  131. mteb/tasks/retrieval/kat/georgian_faq_retrieval.py +1 -1
  132. mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt19.py +1 -1
  133. mteb/tasks/retrieval/multilingual/cross_lingual_semantic_discrimination_wmt21.py +1 -1
  134. mteb/tasks/retrieval/multilingual/cur_ev1_retrieval.py +1 -1
  135. mteb/tasks/retrieval/multilingual/jina_vdr_bench_retrieval.py +1 -1
  136. mteb/tasks/retrieval/multilingual/miracl_vision_retrieval.py +1 -1
  137. mteb/tasks/retrieval/multilingual/mr_tidy_retrieval.py +1 -1
  138. mteb/tasks/retrieval/multilingual/public_health_qa_retrieval.py +1 -1
  139. mteb/tasks/retrieval/multilingual/ru_sci_bench_retrieval.py +2 -2
  140. mteb/tasks/retrieval/multilingual/statcan_dialogue_dataset_retrieval.py +1 -1
  141. mteb/tasks/retrieval/multilingual/vdr_multilingual_retrieval.py +1 -1
  142. mteb/tasks/retrieval/multilingual/vidore2_bench_retrieval.py +5 -5
  143. mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +1 -0
  144. mteb/tasks/retrieval/multilingual/wit_t2i_retrieval.py +1 -1
  145. mteb/tasks/retrieval/multilingual/x_flickr30k_co_t2i_retrieval.py +1 -1
  146. mteb/tasks/retrieval/multilingual/x_qu_ad_retrieval.py +1 -1
  147. mteb/tasks/retrieval/multilingual/xm3600_t2i_retrieval.py +1 -1
  148. mteb/tasks/retrieval/nld/cqa_dupstack_android_nl_retrieval.py +1 -1
  149. mteb/tasks/retrieval/nld/cqa_dupstack_english_nl_retrieval.py +1 -1
  150. mteb/tasks/retrieval/nld/cqa_dupstack_gaming_nl_retrieval.py +1 -1
  151. mteb/tasks/retrieval/nld/cqa_dupstack_gis_nl_retrieval.py +1 -1
  152. mteb/tasks/retrieval/nld/cqa_dupstack_mathematica_nl_retrieval.py +1 -1
  153. mteb/tasks/retrieval/nld/cqa_dupstack_physics_nl_retrieval.py +1 -1
  154. mteb/tasks/retrieval/nld/cqa_dupstack_programmers_nl_retrieval.py +1 -1
  155. mteb/tasks/retrieval/nld/cqa_dupstack_stats_nl_retrieval.py +1 -1
  156. mteb/tasks/retrieval/nld/cqa_dupstack_tex_nl_retrieval.py +1 -1
  157. mteb/tasks/retrieval/nld/cqa_dupstack_unix_nl_retrieval.py +1 -1
  158. mteb/tasks/retrieval/nld/cqa_dupstack_webmasters_nl_retrieval.py +1 -1
  159. mteb/tasks/retrieval/nld/cqa_dupstack_wordpress_nl_retrieval.py +1 -1
  160. mteb/tasks/retrieval/nob/norquad.py +2 -2
  161. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  162. mteb/tasks/retrieval/slk/slovak_sum_retrieval.py +1 -1
  163. mteb/tasks/retrieval/vie/vie_qu_ad_retrieval.py +1 -1
  164. mteb/tasks/sts/multilingual/sem_rel24_sts.py +1 -1
  165. mteb/tasks/sts/multilingual/sts_benchmark_multilingual_sts.py +1 -1
  166. mteb/tasks/sts/por/assin2_sts.py +1 -1
  167. mteb/types/_encoder_io.py +3 -2
  168. {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/METADATA +1 -1
  169. {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/RECORD +173 -173
  170. {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/WHEEL +0 -0
  171. {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/entry_points.txt +0 -0
  172. {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/licenses/LICENSE +0 -0
  173. {mteb-2.7.16.dist-info → mteb-2.7.18.dist-info}/top_level.txt +0 -0
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
30
30
  def _create_dataloader_from_texts(
31
31
  text: list[str],
32
32
  batch_size: int = 32,
33
- num_proc: int = 1,
33
+ num_proc: int | None = None,
34
34
  **kwargs: Any,
35
35
  ) -> DataLoader[TextInput]:
36
36
  """Create a dataloader from a list of text.
@@ -48,7 +48,7 @@ def _create_dataloader_from_texts(
48
48
  return DataLoader(
49
49
  dataset,
50
50
  batch_size=batch_size,
51
- num_workers=num_proc if num_proc > 1 else 0,
51
+ num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
52
52
  )
53
53
 
54
54
 
@@ -74,7 +74,7 @@ def _corpus_to_dict(
74
74
  def _create_dataloader_for_retrieval_corpus(
75
75
  dataset: Dataset,
76
76
  batch_size: int = 32,
77
- num_proc: int = 1,
77
+ num_proc: int | None = None,
78
78
  ) -> DataLoader[CorpusInput]:
79
79
  """Create a dataloader from a corpus.
80
80
 
@@ -94,7 +94,7 @@ def _create_dataloader_for_retrieval_corpus(
94
94
  return DataLoader(
95
95
  new_ds,
96
96
  batch_size=batch_size,
97
- num_workers=num_proc if num_proc > 1 else 0,
97
+ num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
98
98
  )
99
99
 
100
100
 
@@ -111,7 +111,7 @@ def _combine_queries_with_instruction_text(row: dict[str, str]) -> dict[str, str
111
111
  def _create_text_dataloader_for_queries(
112
112
  queries: QueryDatasetType,
113
113
  batch_size: int = 32,
114
- num_proc: int = 1,
114
+ num_proc: int | None = None,
115
115
  ) -> DataLoader[QueryInput]:
116
116
  """Create a dataloader from a list of queries.
117
117
 
@@ -131,7 +131,7 @@ def _create_text_dataloader_for_queries(
131
131
  return DataLoader(
132
132
  queries,
133
133
  batch_size=batch_size,
134
- num_workers=num_proc if num_proc > 1 else 0,
134
+ num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
135
135
  )
136
136
 
137
137
 
@@ -200,7 +200,7 @@ def _convert_conv_history_to_query(
200
200
  def _create_dataloader_for_queries_conversation(
201
201
  queries: QueryDatasetType,
202
202
  batch_size: int = 32,
203
- num_proc: int = 1,
203
+ num_proc: int | None = None,
204
204
  ) -> DataLoader[QueryInput]:
205
205
  """Create a dataloader from a list of queries.
206
206
 
@@ -220,7 +220,7 @@ def _create_dataloader_for_queries_conversation(
220
220
  ),
221
221
  collate_fn=_custom_collate_fn,
222
222
  batch_size=batch_size,
223
- num_workers=num_proc if num_proc > 1 else 0,
223
+ num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
224
224
  )
225
225
 
226
226
 
@@ -265,7 +265,7 @@ def _prepare_image_dataset(
265
265
  dataset: Dataset,
266
266
  image_column_name: str | None = None,
267
267
  transform: Callable[[Any], Any] | None = None,
268
- num_proc: int = 1,
268
+ num_proc: int | None = None,
269
269
  ) -> Dataset:
270
270
  """Prepare the image dataset by converting images to RGB and applying transformations."""
271
271
  if (
@@ -315,7 +315,7 @@ def _create_image_dataloader(
315
315
  batch_size: int = 32,
316
316
  transform: Callable[[Any], Any] | None = None,
317
317
  collate_fn: Callable[[list[dict[str, Any]]], dict[str, Any]] = _custom_collate_fn,
318
- num_proc: int = 1,
318
+ num_proc: int | None = None,
319
319
  ) -> DataLoader[ImageInput]:
320
320
  """Creates a DataLoader with the image dataset prepared using the explicit transformation.
321
321
 
@@ -341,14 +341,14 @@ def _create_image_dataloader(
341
341
  batch_size=batch_size,
342
342
  collate_fn=collate_fn,
343
343
  shuffle=False,
344
- num_workers=num_proc if num_proc > 1 else 0,
344
+ num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
345
345
  )
346
346
 
347
347
 
348
348
  def _create_text_queries_dataloader(
349
349
  dataset: Dataset,
350
350
  batch_size: int = 32,
351
- num_proc: int = 1,
351
+ num_proc: int | None = None,
352
352
  ) -> DataLoader[QueryInput]:
353
353
  if not isinstance(dataset["text"][0], list):
354
354
  return _create_text_dataloader_for_queries(
@@ -368,7 +368,7 @@ def _create_queries_dataloader(
368
368
  task_metadata: TaskMetadata,
369
369
  input_column: str | None = None,
370
370
  batch_size: int = 32,
371
- num_proc: int = 1,
371
+ num_proc: int | None = None,
372
372
  ) -> DataLoader[QueryInput | ImageInput]:
373
373
  """Create a dataloader for queries."""
374
374
  queries_type = task_metadata.get_modalities(PromptType.query)
@@ -393,7 +393,7 @@ def _create_document_dataloader(
393
393
  task_metadata: TaskMetadata,
394
394
  input_column: str | None = None,
395
395
  batch_size: int = 32,
396
- num_proc: int = 1,
396
+ num_proc: int | None = None,
397
397
  ) -> DataLoader[CorpusInput | ImageInput]:
398
398
  """Create a dataloader for documents.
399
399
 
@@ -430,7 +430,7 @@ def create_dataloader(
430
430
  prompt_type: PromptType | None = None,
431
431
  input_column: str | None = None,
432
432
  batch_size: int = 32,
433
- num_proc: int = 1,
433
+ num_proc: int | None = None,
434
434
  **kwargs: Any,
435
435
  ) -> DataLoader[BatchedInput]:
436
436
  """Create a dataloader from a dataset.
@@ -482,5 +482,5 @@ def create_dataloader(
482
482
  return DataLoader(
483
483
  dataset,
484
484
  batch_size=batch_size,
485
- num_workers=num_proc if num_proc > 1 else 0,
485
+ num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
486
486
  )
@@ -66,7 +66,7 @@ class AnySTSEvaluator(Evaluator):
66
66
  model: EncoderProtocol,
67
67
  *,
68
68
  encode_kwargs: EncodeKwargs,
69
- num_proc: int = 1,
69
+ num_proc: int | None = None,
70
70
  ) -> STSEvaluatorScores:
71
71
  logger.info("Running semantic similarity - Encoding samples (1/2)")
72
72
  embeddings1 = model.encode(
@@ -1,7 +1,16 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
1
5
  import numpy as np
2
6
 
7
+ if TYPE_CHECKING:
8
+ from numpy.typing import NDArray
9
+
3
10
 
4
- def hamming_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
11
+ def hamming_score(
12
+ y_true: NDArray[np.integer], y_pred: NDArray[np.integer | np.floating]
13
+ ) -> float:
5
14
  """Compute the Hamming score (a.k.a. label-based accuracy) for multilabel classification.
6
15
 
7
16
  The Hamming score is the fraction of labels that are correctly predicted for each sample,
@@ -45,7 +45,7 @@ class ClusteringEvaluator(Evaluator):
45
45
  model: EncoderProtocol,
46
46
  *,
47
47
  encode_kwargs: EncodeKwargs,
48
- num_proc: int = 1,
48
+ num_proc: int | None = None,
49
49
  ) -> list[int]:
50
50
  data_loader = create_dataloader(
51
51
  self.dataset,
@@ -91,7 +91,7 @@ class ImageTextPairClassificationEvaluator(Evaluator):
91
91
  model: EncoderProtocol,
92
92
  *,
93
93
  encode_kwargs: EncodeKwargs,
94
- num_proc: int = 1,
94
+ num_proc: int | None = None,
95
95
  ) -> list[torch.Tensor]:
96
96
  images = []
97
97
  if isinstance(self.images_column_names, str):
@@ -139,7 +139,7 @@ class ImageTextPairClassificationEvaluator(Evaluator):
139
139
  DataLoader(
140
140
  CustomImageDataset(images),
141
141
  collate_fn=_image_collate_fn,
142
- num_workers=num_proc if num_proc > 1 else 0,
142
+ num_workers=num_proc if num_proc is not None and num_proc > 1 else 0,
143
143
  ),
144
144
  task_metadata=self.task_metadata,
145
145
  hf_subset=self.hf_subset,
@@ -16,6 +16,7 @@ from mteb.similarity_functions import compute_pairwise_similarity
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from datasets import Dataset
19
+ from numpy.typing import NDArray
19
20
 
20
21
  from mteb.abstasks.task_metadata import TaskMetadata
21
22
  from mteb.models import EncoderProtocol
@@ -91,7 +92,7 @@ class PairClassificationEvaluator(Evaluator):
91
92
  self,
92
93
  model: EncoderProtocol,
93
94
  encode_kwargs: EncodeKwargs,
94
- num_proc: int = 1,
95
+ num_proc: int | None = None,
95
96
  ) -> PairClassificationDistances:
96
97
  logger.info("Running pair classification - Encoding samples (1/2)")
97
98
  embeddings1 = model.encode(
@@ -155,7 +156,7 @@ class PairClassificationEvaluator(Evaluator):
155
156
  hf_split: str,
156
157
  hf_subset: str,
157
158
  **encode_kwargs: Any,
158
- ) -> np.ndarray:
159
+ ) -> NDArray[np.floating]:
159
160
  index_map = {}
160
161
  all_unique_texts: list[str] = []
161
162
  all_texts_indexes = []
@@ -55,7 +55,7 @@ class RetrievalEvaluator(Evaluator):
55
55
  self,
56
56
  search_model: SearchProtocol,
57
57
  encode_kwargs: EncodeKwargs,
58
- num_proc: int = 1,
58
+ num_proc: int | None = None,
59
59
  ) -> RetrievalOutputType:
60
60
  logger.info("Running retrieval task - Indexing corpus...")
61
61
  search_model.index(
@@ -15,6 +15,8 @@ from mteb.types import RetrievalEvaluationResult
15
15
  if TYPE_CHECKING:
16
16
  from collections.abc import Mapping
17
17
 
18
+ from numpy.typing import NDArray
19
+
18
20
  from mteb.types import RelevantDocumentsType
19
21
 
20
22
  logger = logging.getLogger(__name__)
@@ -273,9 +275,9 @@ def confidence_scores(sim_scores: list[float]) -> dict[str, float]:
273
275
 
274
276
 
275
277
  def nauc(
276
- conf_scores: np.ndarray,
277
- metrics: np.ndarray,
278
- abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
278
+ conf_scores: NDArray[np.floating],
279
+ metrics: NDArray[np.floating],
280
+ abstention_rates: NDArray[np.floating] = np.linspace(0, 1, 11)[:-1],
279
281
  ) -> float:
280
282
  """Computes normalized Area Under the Curve (nAUC) on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997
281
283
 
@@ -295,10 +297,10 @@ def nauc(
295
297
  """
296
298
 
297
299
  def abstention_curve(
298
- conf_scores: np.ndarray,
299
- metrics: np.ndarray,
300
- abstention_rates: np.ndarray = np.linspace(0, 1, 11)[:-1],
301
- ) -> np.ndarray:
300
+ conf_scores: NDArray[np.floating],
301
+ metrics: NDArray[np.floating],
302
+ abstention_rates: NDArray[np.floating] = np.linspace(0, 1, 11)[:-1],
303
+ ) -> NDArray[np.floating]:
302
304
  """Computes the raw abstention curve for a given set of evaluated instances and corresponding confidence scores
303
305
 
304
306
  Args:
@@ -10,6 +10,7 @@ from .evaluator import Evaluator
10
10
  if TYPE_CHECKING:
11
11
  import numpy as np
12
12
  from datasets import Dataset
13
+ from numpy.typing import NDArray
13
14
  from torch.utils.data import DataLoader
14
15
  from typing_extensions import Self
15
16
 
@@ -21,11 +22,15 @@ logger = logging.getLogger(__name__)
21
22
 
22
23
 
23
24
  class SklearnModelProtocol(Protocol):
24
- def fit(self, X: Array, y: np.ndarray | list[int]) -> None: ... # noqa: N803
25
- def predict(self, X: Array) -> np.ndarray: ... # noqa: N803
25
+ def fit(
26
+ self, X: Array, y: NDArray[np.integer | np.floating] | list[int | float]
27
+ ) -> None: ...
28
+ def predict(self, X: Array) -> NDArray[np.integer | np.floating]: ...
26
29
  def get_params(self) -> dict[str, Any]: ...
27
30
  def set_params(self, random_state: int, **kwargs: dict[str, Any]) -> Self: ...
28
- def score(self, X: Array, y: np.ndarray | list[int]) -> float: ... # noqa: N803
31
+ def score(
32
+ self, X: Array, y: NDArray[np.integer | np.floating] | list[int | float]
33
+ ) -> float: ...
29
34
 
30
35
 
31
36
  class SklearnEvaluator(Evaluator):
@@ -54,7 +59,9 @@ class SklearnEvaluator(Evaluator):
54
59
  self.evaluator_model = evaluator_model
55
60
 
56
61
  def create_dataloaders(
57
- self, encode_kwargs: EncodeKwargs, num_proc: int
62
+ self,
63
+ encode_kwargs: EncodeKwargs,
64
+ num_proc: int | None,
58
65
  ) -> tuple[DataLoader[BatchedInput], DataLoader[BatchedInput]]:
59
66
  dataloader_train = create_dataloader(
60
67
  self.train_dataset,
@@ -78,8 +85,8 @@ class SklearnEvaluator(Evaluator):
78
85
  *,
79
86
  encode_kwargs: EncodeKwargs,
80
87
  test_cache: Array | None = None,
81
- num_proc: int = 1,
82
- ) -> tuple[np.ndarray, Array]:
88
+ num_proc: int | None = None,
89
+ ) -> tuple[NDArray[np.integer | np.floating], Array]:
83
90
  """Classification evaluation by training a sklearn classifier on the embeddings of the training set and evaluating on the embeddings of the test set.
84
91
 
85
92
  Args:
@@ -41,7 +41,7 @@ class BitextMiningEvaluator(Evaluator):
41
41
  model: EncoderProtocol,
42
42
  *,
43
43
  encode_kwargs: EncodeKwargs,
44
- num_proc: int = 1,
44
+ num_proc: int | None = None,
45
45
  ) -> dict[str, list[dict[str, float]]]:
46
46
  pair_elements = {p for pair in self.pairs for p in pair}
47
47
  if isinstance(self.sentences, Dataset):
@@ -100,7 +100,7 @@ class SummarizationEvaluator(Evaluator):
100
100
  model: EncoderProtocol,
101
101
  *,
102
102
  encode_kwargs: EncodeKwargs,
103
- num_proc: int = 1,
103
+ num_proc: int | None = None,
104
104
  ) -> SummarizationDistances:
105
105
  # Get the human & machine summaries for the text in one go for all
106
106
  human_lens = [len(human_summaries) for human_summaries in self.human_summaries]
@@ -48,7 +48,7 @@ class ZeroShotClassificationEvaluator(Evaluator):
48
48
  model: EncoderProtocol,
49
49
  *,
50
50
  encode_kwargs: EncodeKwargs,
51
- num_proc: int = 1,
51
+ num_proc: int | None = None,
52
52
  ) -> Array:
53
53
  dataloader = create_dataloader(
54
54
  self.dataset,
@@ -38,21 +38,26 @@ Bibtex:
38
38
  }
39
39
  """
40
40
 
41
+ from __future__ import annotations
42
+
41
43
  import itertools
42
- from typing import Any
44
+ from typing import TYPE_CHECKING, Any
43
45
 
44
46
  import numpy as np
45
47
  import scipy.sparse as sp
46
48
  from sklearn.model_selection._split import _BaseKFold
47
49
  from sklearn.utils import check_random_state
48
50
 
51
+ if TYPE_CHECKING:
52
+ from numpy.typing import NDArray
53
+
49
54
 
50
55
  def _iterative_train_test_split(
51
- X: np.ndarray, # noqa: N803
52
- y: np.ndarray,
56
+ X: NDArray[np.integer],
57
+ y: NDArray[np.integer],
53
58
  test_size: float,
54
59
  random_state: int | None = None,
55
- ) -> tuple[np.ndarray, np.ndarray]:
60
+ ) -> tuple[NDArray[np.integer], NDArray[np.integer]]:
56
61
  """Iteratively stratified train/test split
57
62
 
58
63
  Slighltly modified from:
@@ -79,8 +84,8 @@ def _iterative_train_test_split(
79
84
 
80
85
 
81
86
  def _fold_tie_break(
82
- desired_samples_per_fold: np.ndarray,
83
- M: np.ndarray, # noqa: N803
87
+ desired_samples_per_fold: NDArray[np.floating],
88
+ M: NDArray[np.integer], # noqa: N803
84
89
  random_state: np.random.RandomState,
85
90
  ):
86
91
  """Helper function to split a tie between folds with same desirability of a given sample
@@ -179,7 +184,7 @@ class IterativeStratification(_BaseKFold):
179
184
  ]
180
185
 
181
186
  def _prepare_stratification(
182
- self, y: np.ndarray
187
+ self, y: NDArray[np.integer]
183
188
  ) -> tuple[
184
189
  list[list[int]],
185
190
  dict[int, bool],
@@ -301,7 +306,7 @@ class IterativeStratification(_BaseKFold):
301
306
  self.desired_samples_per_fold[fold_selected] -= 1
302
307
  folds[fold_selected].append(row)
303
308
 
304
- def _iter_test_indices(self, X, y=None, groups=None): # noqa: N803
309
+ def _iter_test_indices(self, X, y=None, groups=None):
305
310
  """Internal method for providing scikit-learn's split with folds
306
311
 
307
312
  Args:
mteb/abstasks/abstask.py CHANGED
@@ -116,7 +116,7 @@ class AbsTask(ABC):
116
116
  logger.warning(msg)
117
117
  warnings.warn(msg)
118
118
 
119
- def dataset_transform(self, num_proc: int = 1, **kwargs: Any) -> None:
119
+ def dataset_transform(self, num_proc: int | None = None, **kwargs: Any) -> None:
120
120
  """A transform operations applied to the dataset after loading.
121
121
 
122
122
  This method is useful when the dataset from Huggingface is not in an `mteb` compatible format.
@@ -136,7 +136,7 @@ class AbsTask(ABC):
136
136
  *,
137
137
  encode_kwargs: EncodeKwargs,
138
138
  prediction_folder: Path | None = None,
139
- num_proc: int = 1,
139
+ num_proc: int | None = None,
140
140
  **kwargs: Any,
141
141
  ) -> Mapping[HFSubset, ScoresDict]:
142
142
  """Evaluates an MTEB compatible model on the task.
@@ -219,7 +219,7 @@ class AbsTask(ABC):
219
219
  hf_subset: str,
220
220
  encode_kwargs: EncodeKwargs,
221
221
  prediction_folder: Path | None = None,
222
- num_proc: int = 1,
222
+ num_proc: int | None = None,
223
223
  **kwargs: Any,
224
224
  ) -> ScoresDict:
225
225
  raise NotImplementedError(
@@ -324,7 +324,7 @@ class AbsTask(ABC):
324
324
  ) # only take the specified test split.
325
325
  return dataset_dict
326
326
 
327
- def load_data(self, num_proc: int = 1, **kwargs: Any) -> None:
327
+ def load_data(self, num_proc: int | None = None, **kwargs: Any) -> None:
328
328
  """Loads dataset from HuggingFace hub
329
329
 
330
330
  This is the main loading function for Task. Do not overwrite this, instead we recommend using `dataset_transform`, which is called after the
@@ -31,6 +31,8 @@ from .abstask import AbsTask
31
31
  if TYPE_CHECKING:
32
32
  from pathlib import Path
33
33
 
34
+ from numpy.typing import NDArray
35
+
34
36
  from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol
35
37
  from mteb.models import MTEBModels
36
38
  from mteb.types import EncodeKwargs, HFSubset, ScoresDict
@@ -136,7 +138,7 @@ class AbsTaskClassification(AbsTask):
136
138
  *,
137
139
  encode_kwargs: EncodeKwargs,
138
140
  prediction_folder: Path | None = None,
139
- num_proc: int = 1,
141
+ num_proc: int | None = None,
140
142
  **kwargs: Any,
141
143
  ) -> dict[HFSubset, ScoresDict]:
142
144
  """Evaluate a model on the classification task.
@@ -199,7 +201,7 @@ class AbsTaskClassification(AbsTask):
199
201
  hf_split: str,
200
202
  hf_subset: str,
201
203
  prediction_folder: Path | None = None,
202
- num_proc: int = 1,
204
+ num_proc: int | None = None,
203
205
  **kwargs: Any,
204
206
  ) -> FullClassificationMetrics:
205
207
  if not isinstance(model, EncoderProtocol):
@@ -270,8 +272,8 @@ class AbsTaskClassification(AbsTask):
270
272
 
271
273
  def _calculate_scores(
272
274
  self,
273
- y_test: np.ndarray | list[int],
274
- y_pred: np.ndarray,
275
+ y_test: NDArray[np.integer] | list[int],
276
+ y_pred: NDArray[np.integer | np.floating] | list[int],
275
277
  ) -> ClassificationMetrics:
276
278
  scores = ClassificationMetrics(
277
279
  accuracy=accuracy_score(y_test, y_pred),
@@ -169,7 +169,7 @@ class AbsTaskClustering(AbsTask):
169
169
  hf_split: str,
170
170
  hf_subset: str,
171
171
  prediction_folder: Path | None = None,
172
- num_proc: int = 1,
172
+ num_proc: int | None = None,
173
173
  **kwargs: Any,
174
174
  ) -> ScoresDict:
175
175
  if not isinstance(model, EncoderProtocol):
@@ -95,7 +95,7 @@ class AbsTaskClusteringLegacy(AbsTask):
95
95
  hf_split: str,
96
96
  hf_subset: str,
97
97
  prediction_folder: Path | None = None,
98
- num_proc: int = 1,
98
+ num_proc: int | None = None,
99
99
  **kwargs: Any,
100
100
  ) -> ScoresDict:
101
101
  if not isinstance(model, EncoderProtocol):
@@ -134,7 +134,7 @@ class AbsTaskImageTextPairClassification(AbsTask):
134
134
  hf_split: str,
135
135
  hf_subset: str,
136
136
  prediction_folder: Path | None = None,
137
- num_proc: int = 1,
137
+ num_proc: int | None = None,
138
138
  **kwargs: Any,
139
139
  ) -> ImageTextPairClassificationMetrics:
140
140
  if not isinstance(model, EncoderProtocol):
@@ -23,6 +23,8 @@ from .classification import AbsTaskClassification
23
23
  if TYPE_CHECKING:
24
24
  from pathlib import Path
25
25
 
26
+ from numpy.typing import NDArray
27
+
26
28
  from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol
27
29
  from mteb.models import MTEBModels
28
30
  from mteb.types import Array, EncodeKwargs
@@ -32,10 +34,10 @@ logger = logging.getLogger(__name__)
32
34
 
33
35
  def _evaluate_classifier(
34
36
  embeddings_train: Array,
35
- y_train: np.ndarray,
37
+ y_train: NDArray[np.integer],
36
38
  embeddings_test: Array,
37
39
  classifier: SklearnModelProtocol,
38
- ) -> tuple[np.ndarray, SklearnModelProtocol]:
40
+ ) -> tuple[NDArray[np.integer | np.floating], SklearnModelProtocol]:
39
41
  classifier_copy: SklearnModelProtocol = clone(classifier)
40
42
  classifier_copy.fit(embeddings_train, y_train)
41
43
  return classifier_copy.predict(embeddings_test), classifier_copy
@@ -93,7 +95,7 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
93
95
  hf_split: str,
94
96
  hf_subset: str,
95
97
  prediction_folder: Path | None = None,
96
- num_proc: int = 1,
98
+ num_proc: int | None = None,
97
99
  **kwargs: Any,
98
100
  ) -> FullMultilabelClassificationMetrics:
99
101
  if not isinstance(model, EncoderProtocol):
@@ -208,8 +210,8 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
208
210
 
209
211
  def _calculate_scores( # type: ignore[override]
210
212
  self,
211
- y_test: np.ndarray,
212
- y_pred: np.ndarray,
213
+ y_test: NDArray[np.integer],
214
+ y_pred: NDArray[np.integer | np.floating],
213
215
  x_test_embedding: Array,
214
216
  current_classifier: SklearnModelProtocol,
215
217
  ) -> MultilabelClassificationMetrics:
@@ -97,7 +97,7 @@ class AbsTaskPairClassification(AbsTask):
97
97
  hf_subset: str,
98
98
  encode_kwargs: EncodeKwargs,
99
99
  prediction_folder: Path | None = None,
100
- num_proc: int = 1,
100
+ num_proc: int | None = None,
101
101
  **kwargs,
102
102
  ) -> dict[str, float]:
103
103
  if not isinstance(model, EncoderProtocol):
@@ -24,6 +24,7 @@ from .classification import AbsTaskClassification
24
24
 
25
25
  if TYPE_CHECKING:
26
26
  from datasets import Dataset
27
+ from numpy.typing import NDArray
27
28
 
28
29
  from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol
29
30
  from mteb.types.statistics import (
@@ -123,8 +124,8 @@ class AbsTaskRegression(AbsTaskClassification):
123
124
 
124
125
  def _calculate_scores( # type: ignore[override]
125
126
  self,
126
- y_test: np.ndarray | list[int],
127
- y_pred: np.ndarray,
127
+ y_test: NDArray[np.floating] | list[float],
128
+ y_pred: NDArray[np.floating] | list[float],
128
129
  ) -> RegressionMetrics:
129
130
  mse = mean_squared_error(y_test, y_pred)
130
131
  return RegressionMetrics(
@@ -148,7 +148,10 @@ class AbsTaskRetrieval(AbsTask):
148
148
  )
149
149
  )
150
150
 
151
- def convert_v1_dataset_format_to_v2(self, num_proc: int) -> None:
151
+ def convert_v1_dataset_format_to_v2(
152
+ self,
153
+ num_proc: int | None,
154
+ ) -> None:
152
155
  """Convert dataset from v1 (from `self.queries`, `self.document`) format to v2 format (`self.dotaset`)."""
153
156
  # check if dataset is `v1` version
154
157
  if not hasattr(self, "queries"):
@@ -257,7 +260,7 @@ class AbsTaskRetrieval(AbsTask):
257
260
  if hasattr(self, "top_ranked"):
258
261
  del self.top_ranked
259
262
 
260
- def load_data(self, num_proc: int = 1, **kwargs) -> None:
263
+ def load_data(self, num_proc: int | None = None, **kwargs) -> None:
261
264
  """Load the dataset for the retrieval task."""
262
265
  if self.data_loaded:
263
266
  return
@@ -301,7 +304,7 @@ class AbsTaskRetrieval(AbsTask):
301
304
  *,
302
305
  encode_kwargs: EncodeKwargs,
303
306
  prediction_folder: Path | None = None,
304
- num_proc: int = 1,
307
+ num_proc: int | None = None,
305
308
  **kwargs: Any,
306
309
  ) -> Mapping[HFSubset, ScoresDict]:
307
310
  """Evaluate the model on the retrieval task.
@@ -342,7 +345,7 @@ class AbsTaskRetrieval(AbsTask):
342
345
  hf_split: str,
343
346
  hf_subset: str,
344
347
  prediction_folder: Path | None = None,
345
- num_proc: int = 1,
348
+ num_proc: int | None = None,
346
349
  **kwargs,
347
350
  ) -> ScoresDict:
348
351
  """Evaluate a model on a specific subset of the data.
@@ -473,7 +476,7 @@ class AbsTaskRetrieval(AbsTask):
473
476
  split: str,
474
477
  hf_subset: str | None = None,
475
478
  compute_overall: bool = False,
476
- num_proc: int = 1,
479
+ num_proc: int | None = None,
477
480
  ) -> RetrievalDescriptiveStatistics:
478
481
  self.convert_v1_dataset_format_to_v2(num_proc)
479
482
  if hf_subset and hf_subset in self.dataset: