mteb 2.5.2__py3-none-any.whl → 2.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. mteb/__init__.py +2 -0
  2. mteb/_create_dataloaders.py +17 -18
  3. mteb/_evaluators/any_sts_evaluator.py +3 -3
  4. mteb/_evaluators/clustering_evaluator.py +2 -2
  5. mteb/_evaluators/evaluator.py +4 -2
  6. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +10 -8
  7. mteb/_evaluators/pair_classification_evaluator.py +5 -3
  8. mteb/_evaluators/retrieval_evaluator.py +2 -2
  9. mteb/_evaluators/retrieval_metrics.py +18 -17
  10. mteb/_evaluators/sklearn_evaluator.py +11 -10
  11. mteb/_evaluators/text/bitext_mining_evaluator.py +27 -18
  12. mteb/_evaluators/text/summarization_evaluator.py +23 -18
  13. mteb/_evaluators/zeroshot_classification_evaluator.py +5 -3
  14. mteb/abstasks/_data_filter/filters.py +1 -1
  15. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  16. mteb/abstasks/_statistics_calculation.py +18 -10
  17. mteb/abstasks/_stratification.py +18 -18
  18. mteb/abstasks/abstask.py +35 -28
  19. mteb/abstasks/aggregate_task_metadata.py +1 -9
  20. mteb/abstasks/aggregated_task.py +10 -29
  21. mteb/abstasks/classification.py +15 -10
  22. mteb/abstasks/clustering.py +19 -15
  23. mteb/abstasks/clustering_legacy.py +10 -10
  24. mteb/abstasks/image/image_text_pair_classification.py +7 -4
  25. mteb/abstasks/multilabel_classification.py +23 -19
  26. mteb/abstasks/pair_classification.py +20 -11
  27. mteb/abstasks/regression.py +4 -4
  28. mteb/abstasks/retrieval.py +28 -24
  29. mteb/abstasks/retrieval_dataset_loaders.py +2 -2
  30. mteb/abstasks/sts.py +8 -5
  31. mteb/abstasks/task_metadata.py +31 -33
  32. mteb/abstasks/text/bitext_mining.py +39 -28
  33. mteb/abstasks/text/reranking.py +8 -6
  34. mteb/abstasks/text/summarization.py +10 -5
  35. mteb/abstasks/zeroshot_classification.py +8 -4
  36. mteb/benchmarks/benchmark.py +4 -2
  37. mteb/benchmarks/benchmarks/__init__.py +4 -0
  38. mteb/benchmarks/benchmarks/benchmarks.py +112 -11
  39. mteb/benchmarks/get_benchmark.py +14 -55
  40. mteb/cache.py +182 -29
  41. mteb/cli/_display_tasks.py +2 -2
  42. mteb/cli/build_cli.py +110 -14
  43. mteb/cli/generate_model_card.py +43 -23
  44. mteb/deprecated_evaluator.py +63 -49
  45. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
  46. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
  47. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
  48. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
  49. mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
  50. mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
  51. mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
  52. mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
  53. mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
  54. mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
  55. mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
  56. mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
  57. mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
  58. mteb/evaluate.py +44 -33
  59. mteb/filter_tasks.py +25 -26
  60. mteb/get_tasks.py +29 -30
  61. mteb/languages/language_scripts.py +5 -3
  62. mteb/leaderboard/app.py +162 -34
  63. mteb/load_results.py +12 -12
  64. mteb/models/abs_encoder.py +10 -6
  65. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  66. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  67. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
  68. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
  69. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  70. mteb/models/get_model_meta.py +21 -3
  71. mteb/models/instruct_wrapper.py +28 -8
  72. mteb/models/model_implementations/align_models.py +1 -1
  73. mteb/models/model_implementations/andersborges.py +4 -4
  74. mteb/models/model_implementations/ara_models.py +1 -1
  75. mteb/models/model_implementations/arctic_models.py +8 -8
  76. mteb/models/model_implementations/b1ade_models.py +1 -1
  77. mteb/models/model_implementations/bge_models.py +45 -21
  78. mteb/models/model_implementations/bica_model.py +3 -3
  79. mteb/models/model_implementations/blip2_models.py +2 -2
  80. mteb/models/model_implementations/blip_models.py +16 -16
  81. mteb/models/model_implementations/bm25.py +4 -4
  82. mteb/models/model_implementations/bmretriever_models.py +6 -4
  83. mteb/models/model_implementations/cadet_models.py +1 -1
  84. mteb/models/model_implementations/cde_models.py +11 -4
  85. mteb/models/model_implementations/clip_models.py +6 -6
  86. mteb/models/model_implementations/clips_models.py +3 -3
  87. mteb/models/model_implementations/codefuse_models.py +5 -5
  88. mteb/models/model_implementations/codesage_models.py +3 -3
  89. mteb/models/model_implementations/cohere_models.py +5 -5
  90. mteb/models/model_implementations/cohere_v.py +2 -2
  91. mteb/models/model_implementations/colpali_models.py +3 -3
  92. mteb/models/model_implementations/colqwen_models.py +8 -8
  93. mteb/models/model_implementations/colsmol_models.py +2 -2
  94. mteb/models/model_implementations/conan_models.py +1 -1
  95. mteb/models/model_implementations/dino_models.py +42 -42
  96. mteb/models/model_implementations/e5_instruct.py +23 -4
  97. mteb/models/model_implementations/e5_models.py +9 -9
  98. mteb/models/model_implementations/e5_v.py +6 -6
  99. mteb/models/model_implementations/eagerworks_models.py +1 -1
  100. mteb/models/model_implementations/emillykkejensen_models.py +6 -6
  101. mteb/models/model_implementations/en_code_retriever.py +1 -1
  102. mteb/models/model_implementations/euler_models.py +2 -2
  103. mteb/models/model_implementations/fa_models.py +9 -9
  104. mteb/models/model_implementations/facebookai.py +14 -2
  105. mteb/models/model_implementations/geogpt_models.py +1 -1
  106. mteb/models/model_implementations/gme_v_models.py +6 -5
  107. mteb/models/model_implementations/google_models.py +1 -1
  108. mteb/models/model_implementations/granite_vision_embedding_models.py +1 -1
  109. mteb/models/model_implementations/gritlm_models.py +2 -2
  110. mteb/models/model_implementations/gte_models.py +25 -13
  111. mteb/models/model_implementations/hinvec_models.py +1 -1
  112. mteb/models/model_implementations/ibm_granite_models.py +30 -6
  113. mteb/models/model_implementations/inf_models.py +2 -2
  114. mteb/models/model_implementations/jasper_models.py +2 -2
  115. mteb/models/model_implementations/jina_clip.py +48 -10
  116. mteb/models/model_implementations/jina_models.py +18 -11
  117. mteb/models/model_implementations/kblab.py +12 -6
  118. mteb/models/model_implementations/kennethenevoldsen_models.py +4 -4
  119. mteb/models/model_implementations/kfst.py +1 -1
  120. mteb/models/model_implementations/kowshik24_models.py +1 -1
  121. mteb/models/model_implementations/lgai_embedding_models.py +1 -1
  122. mteb/models/model_implementations/linq_models.py +1 -1
  123. mteb/models/model_implementations/listconranker.py +1 -1
  124. mteb/models/model_implementations/llm2clip_models.py +6 -6
  125. mteb/models/model_implementations/llm2vec_models.py +8 -8
  126. mteb/models/model_implementations/mcinext_models.py +4 -1
  127. mteb/models/model_implementations/mdbr_models.py +17 -3
  128. mteb/models/model_implementations/misc_models.py +68 -68
  129. mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
  130. mteb/models/model_implementations/mme5_models.py +1 -1
  131. mteb/models/model_implementations/moco_models.py +4 -4
  132. mteb/models/model_implementations/mod_models.py +1 -1
  133. mteb/models/model_implementations/model2vec_models.py +14 -14
  134. mteb/models/model_implementations/moka_models.py +1 -1
  135. mteb/models/model_implementations/nbailab.py +3 -3
  136. mteb/models/model_implementations/no_instruct_sentence_models.py +2 -2
  137. mteb/models/model_implementations/nomic_models.py +30 -15
  138. mteb/models/model_implementations/nomic_models_vision.py +1 -1
  139. mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +15 -9
  140. mteb/models/model_implementations/nvidia_models.py +151 -19
  141. mteb/models/model_implementations/octen_models.py +61 -2
  142. mteb/models/model_implementations/openclip_models.py +13 -13
  143. mteb/models/model_implementations/opensearch_neural_sparse_models.py +5 -5
  144. mteb/models/model_implementations/ops_moa_models.py +1 -1
  145. mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
  146. mteb/models/model_implementations/pawan_models.py +1 -1
  147. mteb/models/model_implementations/piccolo_models.py +1 -1
  148. mteb/models/model_implementations/pixie_models.py +56 -0
  149. mteb/models/model_implementations/promptriever_models.py +4 -4
  150. mteb/models/model_implementations/pylate_models.py +10 -9
  151. mteb/models/model_implementations/qodo_models.py +2 -2
  152. mteb/models/model_implementations/qtack_models.py +1 -1
  153. mteb/models/model_implementations/qwen3_models.py +3 -3
  154. mteb/models/model_implementations/qzhou_models.py +2 -2
  155. mteb/models/model_implementations/random_baseline.py +3 -3
  156. mteb/models/model_implementations/rasgaard_models.py +2 -2
  157. mteb/models/model_implementations/reasonir_model.py +1 -1
  158. mteb/models/model_implementations/repllama_models.py +3 -3
  159. mteb/models/model_implementations/rerankers_custom.py +12 -6
  160. mteb/models/model_implementations/rerankers_monot5_based.py +17 -17
  161. mteb/models/model_implementations/richinfoai_models.py +1 -1
  162. mteb/models/model_implementations/ru_sentence_models.py +20 -20
  163. mteb/models/model_implementations/ruri_models.py +10 -10
  164. mteb/models/model_implementations/salesforce_models.py +3 -3
  165. mteb/models/model_implementations/samilpwc_models.py +1 -1
  166. mteb/models/model_implementations/sarashina_embedding_models.py +2 -2
  167. mteb/models/model_implementations/searchmap_models.py +1 -1
  168. mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +113 -146
  169. mteb/models/model_implementations/sentence_transformers_models.py +124 -22
  170. mteb/models/model_implementations/shuu_model.py +1 -1
  171. mteb/models/model_implementations/siglip_models.py +20 -20
  172. mteb/models/model_implementations/slm_models.py +416 -0
  173. mteb/models/model_implementations/spartan8806_atles_champion.py +1 -1
  174. mteb/models/model_implementations/stella_models.py +17 -4
  175. mteb/models/model_implementations/tarka_models.py +2 -2
  176. mteb/models/model_implementations/text2vec_models.py +9 -3
  177. mteb/models/model_implementations/ua_sentence_models.py +1 -1
  178. mteb/models/model_implementations/uae_models.py +7 -1
  179. mteb/models/model_implementations/vdr_models.py +1 -1
  180. mteb/models/model_implementations/vi_vn_models.py +6 -6
  181. mteb/models/model_implementations/vlm2vec_models.py +3 -3
  182. mteb/models/model_implementations/voyage_models.py +84 -0
  183. mteb/models/model_implementations/voyage_v.py +9 -7
  184. mteb/models/model_implementations/youtu_models.py +1 -1
  185. mteb/models/model_implementations/yuan_models.py +1 -1
  186. mteb/models/model_implementations/yuan_models_en.py +1 -1
  187. mteb/models/model_meta.py +80 -31
  188. mteb/models/models_protocols.py +22 -6
  189. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
  190. mteb/models/search_wrappers.py +33 -18
  191. mteb/models/sentence_transformer_wrapper.py +50 -25
  192. mteb/models/vllm_wrapper.py +327 -0
  193. mteb/py.typed +0 -0
  194. mteb/results/benchmark_results.py +29 -21
  195. mteb/results/model_result.py +52 -22
  196. mteb/results/task_result.py +80 -58
  197. mteb/similarity_functions.py +11 -7
  198. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  199. mteb/tasks/classification/est/estonian_valence.py +1 -1
  200. mteb/tasks/classification/kur/kurdish_sentiment_classification.py +2 -2
  201. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  202. mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
  203. mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
  204. mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
  205. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  206. mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
  207. mteb/tasks/retrieval/code/code_rag.py +12 -12
  208. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  209. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  210. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  211. mteb/tasks/retrieval/eng/__init__.py +2 -0
  212. mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
  213. mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
  214. mteb/tasks/retrieval/kor/__init__.py +15 -1
  215. mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
  216. mteb/tasks/retrieval/multilingual/__init__.py +2 -0
  217. mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
  218. mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +90 -100
  219. mteb/tasks/retrieval/nob/norquad.py +2 -2
  220. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  221. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  222. mteb/tasks/retrieval/vie/__init__.py +14 -6
  223. mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +39 -0
  224. mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +39 -0
  225. mteb/tasks/retrieval/vie/fevervn_retrieval.py +39 -0
  226. mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +39 -0
  227. mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +48 -0
  228. mteb/tasks/retrieval/vie/nqvn_retrieval.py +39 -0
  229. mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
  230. mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
  231. mteb/types/__init__.py +2 -0
  232. mteb/types/_encoder_io.py +12 -0
  233. mteb/types/_result.py +2 -1
  234. mteb/types/statistics.py +9 -3
  235. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/METADATA +15 -4
  236. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/RECORD +240 -219
  237. mteb/models/model_implementations/mxbai_models.py +0 -111
  238. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
  239. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
  240. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
  241. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
mteb/__init__.py CHANGED
@@ -3,6 +3,7 @@ from importlib.metadata import version
 from mteb import types
 from mteb.abstasks import AbsTask
 from mteb.abstasks.task_metadata import TaskMetadata
+from mteb.cache import ResultCache
 from mteb.deprecated_evaluator import MTEB
 from mteb.evaluate import evaluate
 from mteb.filter_tasks import filter_tasks
@@ -33,6 +34,7 @@ __all__ = [
     "CrossEncoderProtocol",
     "EncoderProtocol",
     "IndexEncoderSearchProtocol",
+    "ResultCache",
     "SearchProtocol",
     "SentenceTransformerEncoderWrapper",
     "TaskMetadata",
mteb/_create_dataloaders.py CHANGED
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from collections.abc import Callable
 from typing import Any, cast
 
@@ -22,7 +23,7 @@ logger = logging.getLogger(__name__)
 def _create_dataloader_from_texts(
     text: list[str],
     batch_size: int = 32,
-    **kwargs: dict[str, Any],
+    **kwargs: Any,
 ) -> DataLoader[TextInput]:
     """Create a dataloader from a list of text.
 
@@ -113,11 +114,8 @@ def _create_text_dataloader_for_queries(
     )
 
 
-_warned_about_user_role = False
-
-
 def _convert_conv_history_to_query(
-    row: dict[str, list[str] | Conversation],
+    row: dict[str, str | list[str] | Conversation],
 ) -> dict[str, str | Conversation]:
     """Convert a conversation history to a single query string.
 
@@ -127,21 +125,18 @@ def _convert_conv_history_to_query(
     Returns:
         The updated row with the "query" and "text" fields set to the conversation string, and the "conversation" field set to the list of ConversationTurn.
     """
-    global _warned_about_user_role
-
     conversation = row["text"]
     # if it's a list of strings, just join them
     if isinstance(conversation, list) and isinstance(conversation[0], str):
-        conversation = cast(list[str], conversation)
-        conv_str = "; ".join(conversation)
+        conversation_ = cast(list[str], conversation)
+        conv_str = "; ".join(conversation_)
         current_conversation = [
-            ConversationTurn(role="user", content=message) for message in conversation
+            ConversationTurn(role="user", content=message) for message in conversation_
        ]
-        if not _warned_about_user_role:
-            logger.warning(
-                "Conversations are a list of strings. Used 'user' role for all turns."
-            )
-            _warned_about_user_role = True
+        warnings.warn(
+            "Conversations are a list of strings. Used 'user' role for all turns.",
+            category=UserWarning,
+        )
     # otherwise, it's a list of dictionaries, which we need to convert to strings
     elif isinstance(conversation, list) and isinstance(conversation[0], dict):
         conv = []
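The module-level `_warned_about_user_role` flag becomes unnecessary with `warnings.warn`: under Python's default warning filter, a `UserWarning` emitted from the same code location is shown only once per session. A self-contained sketch of that behaviour:

```python
import warnings

def convert(items: list[str]) -> str:
    # Mirrors the new code path: warn, then join the turns.
    warnings.warn(
        "Conversations are a list of strings. Used 'user' role for all turns.",
        category=UserWarning,
    )
    return "; ".join(items)

# With the default filter, the warning prints only for the first call,
# so the hand-rolled once-only flag from 2.5.2 is no longer needed.
for _ in range(3):
    convert(["hi", "how are you?"])
```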
@@ -178,7 +173,7 @@ def _convert_conv_history_to_query(
 
     row["text"] = conv_str
     row["conversation"] = current_conversation
-    return row
+    return cast(dict[str, str | list[ConversationTurn]], row)
 
 
 def _create_dataloader_for_queries_conversation(
@@ -196,7 +191,8 @@ def _create_dataloader_for_queries_conversation(
     """
     return DataLoader(
         queries.map(
-            _convert_conv_history_to_query, desc="Converting conversations to queries"
+            _convert_conv_history_to_query,
+            desc="Converting conversations to queries",
         ),
         collate_fn=_custom_collate_fn,
         batch_size=batch_size,
@@ -366,6 +362,9 @@
         task_metadata: Metadata of the task to determine the document type.
         input_column: The column to use as input. If None, it will use the first column that matches the modality.
         batch_size: Batch size for the dataloader.
+
+    Returns:
+        A dataloader for the documents.
     """
     document_type = task_metadata.get_modalities(PromptType.document)
     if document_type == ["text"]:  # text only
@@ -388,7 +387,7 @@
     prompt_type: PromptType | None = None,
     input_column: str | None = None,
     batch_size: int = 32,
-    **kwargs: dict[str, Any],
+    **kwargs: Any,
 ) -> DataLoader[BatchedInput]:
     """Create a dataloader from a dataset.
 
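Both `_create_dataloader_from_texts` and `create_dataloader` change `**kwargs: dict[str, Any]` to `**kwargs: Any`. The annotation on `**kwargs` describes each individual keyword value, not the whole dict, so the old form wrongly required every keyword argument to be a dict. A sketch of the difference:

```python
from typing import Any

def old_style(**kwargs: dict[str, Any]) -> None:
    # Each keyword value must itself be a dict under this annotation,
    # so a strict checker rejects old_style(batch_size=32).
    ...

def new_style(**kwargs: Any) -> None:
    # kwargs is inferred as dict[str, Any]; values of any type pass.
    ...

new_style(batch_size=32, show_progress_bar=True)  # accepted
```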
mteb/_evaluators/any_sts_evaluator.py CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, TypedDict
+from typing import TypedDict
 
 from datasets import Dataset
 from sklearn.metrics.pairwise import (
@@ -12,7 +12,7 @@ from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
 from mteb.similarity_functions import compute_pairwise_similarity
-from mteb.types import PromptType
+from mteb.types import EncodeKwargs, PromptType
 
 from .evaluator import Evaluator
 
@@ -60,7 +60,7 @@ class AnySTSEvaluator(Evaluator):
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs: dict[str, Any],
+        encode_kwargs: EncodeKwargs,
     ) -> STSEvaluatorScores:
         logger.info("Running semantic similarity - Encoding samples (1/2)")
         embeddings1 = model.encode(
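This release systematically swaps loose `encode_kwargs: dict[str, Any]` annotations for the `EncodeKwargs` type from `mteb.types` (the file list shows `mteb/types/_encoder_io.py` gaining 12 lines, presumably this definition). A hypothetical sketch of the pattern; the field names below are assumptions, not mteb's actual definition:

```python
from typing import TypedDict

class EncodeKwargsSketch(TypedDict, total=False):
    # Hypothetical fields for illustration; see mteb/types/_encoder_io.py
    # for the real EncodeKwargs definition.
    batch_size: int
    show_progress_bar: bool

def run(*, encode_kwargs: EncodeKwargsSketch) -> None:
    batch_size = encode_kwargs.get("batch_size", 32)
    print(batch_size)

run(encode_kwargs={"batch_size": 16})  # keys and value types are now checked
```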
mteb/_evaluators/clustering_evaluator.py CHANGED
@@ -1,5 +1,4 @@
 import logging
-from typing import Any
 
 from datasets import Dataset
 from sklearn import cluster
@@ -7,6 +6,7 @@ from sklearn import cluster
 from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
+from mteb.types import EncodeKwargs
 
 from .evaluator import Evaluator
 
@@ -38,7 +38,7 @@ class ClusteringEvaluator(Evaluator):
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs: dict[str, Any],
+        encode_kwargs: EncodeKwargs,
     ) -> list[int]:
         data_loader = create_dataloader(
             self.dataset,
mteb/_evaluators/evaluator.py CHANGED
@@ -1,8 +1,10 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping
 from typing import Any
 
 from mteb.abstasks.abstask import _set_seed
 from mteb.models import EncoderProtocol
+from mteb.types import EncodeKwargs
 
 
 class Evaluator(ABC):
@@ -17,8 +19,8 @@ class Evaluator(ABC):
 
     @abstractmethod
     def __call__(
-        self, model: EncoderProtocol, *, encode_kwargs: dict[str, Any]
-    ) -> dict[str, float]:
+        self, model: EncoderProtocol, *, encode_kwargs: EncodeKwargs
+    ) -> Mapping[str, float] | Iterable[Any]:
         """This is called during training to evaluate the model.
 
         It returns scores.
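The abstract `__call__` also widens its return annotation from `dict[str, float]` to `Mapping[str, float] | Iterable[Any]`, which lets subclasses such as `ClusteringEvaluator` (which returns `list[int]`) satisfy the contract without a cast. A stand-in sketch of the signature shape, including what the keyword-only `*` marker enforces:

```python
from collections.abc import Iterable, Mapping
from typing import Any

class Scorer:
    # Mirrors the new Evaluator.__call__ shape: encode_kwargs is
    # keyword-only, and the return type admits mappings or sequences.
    def __call__(
        self, model: Any, *, encode_kwargs: dict
    ) -> Mapping[str, float] | Iterable[Any]:
        return {"accuracy": 1.0}

scorer = Scorer()
scorer("model", encode_kwargs={})  # OK
# scorer("model", {})              # TypeError: encode_kwargs is keyword-only
```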
mteb/_evaluators/image/imagetext_pairclassification_evaluator.py CHANGED
@@ -1,20 +1,22 @@
 from __future__ import annotations
 
 import logging
+from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
 import torch
 import torch.nn.functional as F
-from datasets import Dataset
 from torch.utils.data import DataLoader
 
 from mteb._create_dataloaders import (
+    _create_dataloader_from_texts,
     _transform_image_to_rgb,
 )
 from mteb._evaluators.evaluator import Evaluator
 from mteb._requires_package import requires_image_dependencies
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models.models_protocols import EncoderProtocol
+from mteb.types import EncodeKwargs
 
 if TYPE_CHECKING:
     from PIL.Image import Image
@@ -61,8 +63,8 @@ class ImageTextPairClassificationEvaluator(Evaluator):
     def __init__(
         self,
         dataset,
-        images_column_names: str | list[str],
-        texts_column_names: str | list[str],
+        images_column_names: str | Sequence[str],
+        texts_column_names: str | Sequence[str],
         num_images_per_sample: int,
         num_texts_per_sample: int,
         task_metadata: TaskMetadata,
@@ -82,10 +84,11 @@ class ImageTextPairClassificationEvaluator(Evaluator):
         self.hf_split = hf_split
         self.hf_subset = hf_subset
 
-    def __call__(
+    def __call__(  # type: ignore[override]
         self,
         model: EncoderProtocol,
-        encode_kwargs: dict[str, Any],
+        *,
+        encode_kwargs: EncodeKwargs,
     ) -> list[torch.Tensor]:
         images = []
         if isinstance(self.images_column_names, str):
@@ -106,8 +109,8 @@ class ImageTextPairClassificationEvaluator(Evaluator):
                 texts.append(row[col])
 
         text_embeddings = model.encode(
-            DataLoader(
-                Dataset.from_dict({"text": texts}),
+            _create_dataloader_from_texts(
+                texts,
                 **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
@@ -128,7 +131,6 @@ class ImageTextPairClassificationEvaluator(Evaluator):
             DataLoader(
                 CustomImageDataset(images),
                 collate_fn=lambda x: {"image": [item["image"] for item in x]},
-                **encode_kwargs,
             ),
             task_metadata=self.task_metadata,
             hf_subset=self.hf_subset,
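Dropping `**encode_kwargs` from the image `DataLoader` call reads like a bug fix: `torch.utils.data.DataLoader` has a fixed constructor signature, so forwarding arbitrary encode options into it fails as soon as a key it does not recognise appears. A sketch (the `show_progress_bar` key is a hypothetical encode option):

```python
from torch.utils.data import DataLoader

data = list(range(4))
DataLoader(data, batch_size=2)  # fine: batch_size is a DataLoader parameter

try:
    # Any non-DataLoader key forwarded from encode_kwargs would blow up:
    DataLoader(data, show_progress_bar=True)
except TypeError as err:
    print(err)  # unexpected keyword argument 'show_progress_bar'
```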
mteb/_evaluators/pair_classification_evaluator.py CHANGED
@@ -14,7 +14,7 @@ from mteb._evaluators.evaluator import Evaluator
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
 from mteb.similarity_functions import compute_pairwise_similarity
-from mteb.types import PromptType
+from mteb.types import EncodeKwargs, PromptType
 
 logger = logging.getLogger(__name__)
 
@@ -85,7 +85,7 @@ class PairClassificationEvaluator(Evaluator):
     def __call__(
         self,
         model: EncoderProtocol,
-        encode_kwargs: dict[str, Any],
+        encode_kwargs: EncodeKwargs,
     ) -> PairClassificationDistances:
         logger.info("Running pair classification - Encoding samples (1/2)")
         embeddings1 = model.encode(
@@ -148,7 +148,9 @@
         hf_subset: str,
         **encode_kwargs: Any,
     ) -> np.ndarray:
-        index_map, all_unique_texts, all_texts_indexes = {}, [], []
+        index_map = {}
+        all_unique_texts: list[str] = []
+        all_texts_indexes = []
         for text in all_texts:
             text_hash = hash(text)
             if text_hash not in index_map:
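Splitting the chained assignment is the standard way to attach a type annotation, since annotations are not allowed on tuple-unpacking targets:

```python
# Invalid: Python rejects annotations inside tuple unpacking, so
# "index_map, all_unique_texts, all_texts_indexes = {}, [], []"
# cannot carry a type for all_unique_texts.

# Valid: one statement per target, each annotatable where it helps.
index_map = {}
all_unique_texts: list[str] = []
all_texts_indexes = []
```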
mteb/_evaluators/retrieval_evaluator.py CHANGED
@@ -1,11 +1,11 @@
 import logging
 from collections.abc import Sequence
-from typing import Any
 
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import SearchProtocol
 from mteb.types import (
     CorpusDatasetType,
+    EncodeKwargs,
     QueryDatasetType,
     RelevantDocumentsType,
     RetrievalEvaluationResult,
@@ -48,7 +48,7 @@ class RetrievalEvaluator(Evaluator):
     def __call__(  # type: ignore[override]
         self,
         search_model: SearchProtocol,
-        encode_kwargs: dict[str, Any],
+        encode_kwargs: EncodeKwargs,
     ) -> RetrievalOutputType:
         logger.info("Running retrieval task - Indexing corpus...")
         search_model.index(
mteb/_evaluators/retrieval_metrics.py CHANGED
@@ -1,5 +1,6 @@
 import logging
 from collections import defaultdict
+from collections.abc import Mapping
 from typing import Any
 
 import numpy as np
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 def mrr(
     qrels: RelevantDocumentsType,
-    results: dict[str, dict[str, float]],
+    results: Mapping[str, Mapping[str, float]],
     k_values: list[int],
 ) -> dict[str, list[float]]:
     mrr_metrics = defaultdict(list)
@@ -32,7 +33,7 @@
             doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
         }
         for k in k_values:
-            rr = 0
+            rr = 0.0
             for rank, hit in enumerate(top_hits[query_id][0:k]):
                 if hit[0] in query_relevant_docs:
                     rr = 1.0 / (rank + 1)
@@ -45,8 +46,8 @@ def recall_cap(
     qrels: RelevantDocumentsType,
     results: dict[str, dict[str, float]],
     k_values: list[int],
-) -> dict[str, list[float]]:
-    capped_recall = defaultdict(list)
+) -> dict[str, list[float | None]]:
+    capped_recall: dict[str, list[float | None]] = defaultdict(list)
 
     k_max = max(k_values)
 
@@ -139,7 +140,7 @@ def calculate_pmrr(original_run, new_run, changed_qrels):
     changes = []
     for qid in changed_qrels.keys():
         if qid + "-og" not in original_run or qid + "-changed" not in new_run:
-            logging.warning(f"Query {qid} not found in the runs for calculating p-MRR")
+            logger.warning(f"Query {qid} not found in the runs for calculating p-MRR")
             continue
         original_qid_run = original_run[qid + "-og"]
         new_qid_run = new_run[qid + "-changed"]
@@ -188,7 +189,7 @@ def evaluate_p_mrr_change(
     Returns:
         A dictionary with the scores, including "p-MRR", "og" and "changed" keys.
     """
-    followir_scores = defaultdict(dict)
+    followir_scores: dict[str, float | dict[str, float]] = defaultdict(dict)
 
     qrels_sep = {
         "og": {k: v for k, v in qrels.items() if k.endswith("-og")},
@@ -227,7 +228,7 @@
             ndcg, _map, recall, precision, naucs, avg_mrr, naucs_mrr, cv_recall, {}
         )
         for key, value in scores_dict.items():
-            followir_scores[name][key] = value
+            followir_scores[name][key] = value  # type: ignore[index]
 
     return followir_scores
 
@@ -254,8 +255,8 @@ def confidence_scores(sim_scores: list[float]) -> dict[str, float]:
     sim_scores_sorted = sorted(sim_scores)[::-1]
 
     cs_max = sim_scores_sorted[0]
-    cs_std = np.std(sim_scores)
-    cs_diff1 = None
+    cs_std = float(np.std(sim_scores))
+    cs_diff1 = 0.0
     if len(sim_scores) > 1:
         cs_diff1 = sim_scores_sorted[0] - sim_scores_sorted[1]
     elif len(sim_scores) == 1:
@@ -410,7 +411,7 @@ def make_score_dict(
     cv_recall: dict[str, float],
     task_scores: dict[str, float],
     previous_results_model_meta: dict[str, Any] | None = None,
-) -> dict[str, float]:
+) -> dict[str, Any]:
     return {
         **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
         **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
@@ -528,7 +529,7 @@ def max_over_subqueries(
 
 
 def calculate_retrieval_scores(
-    results: dict[str, dict[str, float]],
+    results: Mapping[str, Mapping[str, float]],
     qrels: RelevantDocumentsType,
     k_values: list[int],
     skip_first_result: bool = False,
@@ -576,7 +577,7 @@
 
 
 def evaluate_abstention(
-    results: dict[str, dict[str, float]],
+    results: Mapping[str, Mapping[str, float]],
     metric_scores: dict[str, list[float]],
 ) -> dict[str, float]:
     """Computes normalized Area Under the Curve on a set of evaluated instances as presented in the paper https://arxiv.org/abs/2402.12997
@@ -591,21 +592,21 @@
     all_sim_scores = [list(results[qid].values()) for qid in list(results.keys())]
     all_conf_scores = [confidence_scores(sim_scores) for sim_scores in all_sim_scores]
     conf_fcts = list(all_conf_scores[0].keys())
-    all_conf_scores = {
+    all_conf_scores_ = {
         fct: np.array([x[fct] for x in all_conf_scores]) for fct in conf_fcts
     }
-    metric_scores = {k: np.array(v) for k, v in metric_scores.items()}
+    metric_scores_ = {k: np.array(v) for k, v in metric_scores.items()}
     naucs = {}
 
-    for metric_name, scores in metric_scores.items():
-        for fct, conf_scores in all_conf_scores.items():
+    for metric_name, scores in metric_scores_.items():
+        for fct, conf_scores in all_conf_scores_.items():
             naucs[f"nAUC_{metric_name}_{fct}"] = nauc(conf_scores, scores)
 
     return naucs
 
 
 def calculate_cv_recall(
-    results: dict[str, dict[str, float]],
+    results: Mapping[str, Mapping[str, float]],
     qrels: RelevantDocumentsType,
     k_values: list[int],
     skip_first_result: bool = False,
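The recurring widening of `results` from `dict[str, dict[str, float]]` to `Mapping[str, Mapping[str, float]]` follows standard variance rules: `dict` is invariant in its value type, while `Mapping` is covariant, so the `Mapping` form accepts more caller-side types at no runtime cost. A sketch:

```python
from collections import defaultdict
from collections.abc import Mapping

def as_dict(results: dict[str, dict[str, float]]) -> None: ...
def as_mapping(results: Mapping[str, Mapping[str, float]]) -> None: ...

run: dict[str, defaultdict[str, float]] = {"q1": defaultdict(float, {"d1": 1.0})}

as_mapping(run)  # accepted: Mapping is covariant in its value type
as_dict(run)     # works at runtime, but type checkers reject it: dict is invariant
```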
mteb/_evaluators/sklearn_evaluator.py CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Protocol
+from typing import Any, Protocol, cast
 
 import numpy as np
 from datasets import Dataset
@@ -9,7 +9,7 @@ from typing_extensions import Self
 from mteb._create_dataloaders import create_dataloader
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
-from mteb.types import BatchedInput
+from mteb.types import Array, BatchedInput, EncodeKwargs
 
 from .evaluator import Evaluator
 
@@ -17,11 +17,11 @@ logger = logging.getLogger(__name__)
 
 
 class SklearnModelProtocol(Protocol):
-    def fit(self, X: np.ndarray, y: np.ndarray | list[int]) -> None: ...  # noqa: N803
-    def predict(self, X: np.ndarray) -> np.ndarray: ...  # noqa: N803
+    def fit(self, X: Array, y: np.ndarray | list[int]) -> None: ...  # noqa: N803
+    def predict(self, X: Array) -> np.ndarray: ...  # noqa: N803
     def get_params(self) -> dict[str, Any]: ...
-    def set_params(self, **kwargs: dict[str, Any]) -> Self: ...
-    def score(self, X: np.ndarray, y: np.ndarray | list[int]) -> float: ...  # noqa: N803
+    def set_params(self, random_state: int, **kwargs: dict[str, Any]) -> Self: ...
+    def score(self, X: Array, y: np.ndarray | list[int]) -> float: ...  # noqa: N803
 
 
 class SklearnEvaluator(Evaluator):
@@ -50,7 +50,7 @@ class SklearnEvaluator(Evaluator):
         self.evaluator_model = evaluator_model
 
     def create_dataloaders(
-        self, encode_kwargs: dict[str, Any]
+        self, encode_kwargs: EncodeKwargs
     ) -> tuple[DataLoader[BatchedInput], DataLoader[BatchedInput]]:
         dataloader_train = create_dataloader(
             self.train_dataset,
@@ -70,9 +70,9 @@
         self,
         model: EncoderProtocol,
         *,
-        encode_kwargs: dict[str, Any],
-        test_cache: np.ndarray | None = None,
-    ) -> tuple[np.ndarray, np.ndarray]:
+        encode_kwargs: EncodeKwargs,
+        test_cache: Array | None = None,
+    ) -> tuple[np.ndarray, Array]:
         """Classification evaluation by training a sklearn classifier on the embeddings of the training set and evaluating on the embeddings of the test set.
 
         Args:
@@ -104,6 +104,7 @@
             hf_subset=self.hf_subset,
             **encode_kwargs,
         )
+        test_cache = cast(Array, test_cache)
 
         logger.info("Running - Fitting classifier...")
         y_train = self.train_dataset[self.label_column_name]
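`SklearnModelProtocol` is structural, so any scikit-learn estimator exposing these five methods can be passed without inheriting from anything; the new explicit `random_state` parameter on `set_params` documents that the evaluator reseeds its classifier. A small runtime sketch:

```python
from sklearn.linear_model import LogisticRegression

# LogisticRegression duck-types the protocol: it provides
# fit/predict/get_params/set_params/score, and its set_params accepts
# random_state, matching the tightened signature.
clf = LogisticRegression(max_iter=1000)
clf.set_params(random_state=42)
print(clf.get_params()["random_state"])  # 42
```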
mteb/_evaluators/text/bitext_mining_evaluator.py CHANGED
@@ -1,7 +1,5 @@
 import logging
-from typing import Any
 
-import numpy as np
 import torch
 from datasets import Dataset
 from tqdm.auto import tqdm
@@ -10,6 +8,7 @@ from mteb._create_dataloaders import _create_dataloader_from_texts
 from mteb._evaluators.evaluator import Evaluator
 from mteb.abstasks.task_metadata import TaskMetadata
 from mteb.models import EncoderProtocol
+from mteb.types import Array, EncodeKwargs
 
 logger = logging.getLogger(__name__)
 
@@ -33,7 +32,10 @@ class BitextMiningEvaluator(Evaluator):
         self.task_metadata = task_metadata
 
     def __call__(
-        self, model: EncoderProtocol, *, encode_kwargs: dict[str, Any]
+        self,
+        model: EncoderProtocol,
+        *,
+        encode_kwargs: EncodeKwargs,
     ) -> dict[str, list[dict[str, float]]]:
         pair_elements = {p for pair in self.pairs for p in pair}
         if isinstance(self.sentences, Dataset):
@@ -69,11 +71,11 @@
 
     def _similarity_search(
         self,
-        query_embeddings: np.ndarray,
-        corpus_embeddings: np.ndarray,
+        query_embeddings: Array,
+        corpus_embeddings: Array,
         model: EncoderProtocol,
         query_chunk_size: int = 100,
-        corpus_chunk_size: int = 500000,
+        corpus_chunk_size: int = 500_000,
     ) -> list[dict[str, float]]:
         """This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings.
 
@@ -104,13 +106,15 @@
         ):
             query_embeddings = query_embeddings.to(corpus_embeddings.device)
 
-        queries_result_list = [[] for _ in range(len(query_embeddings))]
+        queries_result_list: list[list[dict[str, float]]] = [
+            [] for _ in range(len(query_embeddings))
+        ]
 
         for query_start_idx in range(0, len(query_embeddings), query_chunk_size):
            # Iterate over chunks of the corpus
            for corpus_start_idx in range(0, len(corpus_embeddings), corpus_chunk_size):
                # Compute cosine similarities
-                similarity_scores = model.similarity(  # type: ignore
+                similarity_scores = model.similarity(
                    query_embeddings[
                        query_start_idx : query_start_idx + query_chunk_size
                    ],
@@ -120,15 +124,17 @@
                )
 
                # Get top-k scores
-                cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
-                    torch.tensor(similarity_scores),
-                    1,
-                    dim=1,
-                    largest=True,
-                    sorted=False,
+                cos_scores_top_k_values_tensor, cos_scores_top_k_idx_tensor = (
+                    torch.topk(
+                        torch.tensor(similarity_scores),
+                        1,
+                        dim=1,
+                        largest=True,
+                        sorted=False,
+                    )
                )
-                cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
-                cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
+                cos_scores_top_k_values = cos_scores_top_k_values_tensor.cpu().tolist()
+                cos_scores_top_k_idx = cos_scores_top_k_idx_tensor.cpu().tolist()
 
                for query_itr in range(len(similarity_scores)):
                    for sub_corpus_id, score in zip(
@@ -141,11 +147,14 @@
                            {"corpus_id": corpus_id, "score": score}
                        )
 
+        result_queries_list: list[dict[str, float]] = [
+            {} for _ in range(len(query_embeddings))
+        ]
        # Sort and strip to top_k results
        for idx in range(len(queries_result_list)):
            queries_result_list[idx] = sorted(
                queries_result_list[idx], key=lambda x: x["score"], reverse=True
            )
-            queries_result_list[idx] = queries_result_list[idx][0]
+            result_queries_list[idx] = queries_result_list[idx][0]
 
-        return queries_result_list
+        return result_queries_list
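The `torch.topk` call at the heart of `_similarity_search` keeps only the single best-scoring corpus entry per query row; the `*_tensor` renames simply keep the tensors and their `.tolist()` results under separate names for the type checker. A self-contained sketch of the `k=1` lookup:

```python
import torch

# Two queries scored against three corpus entries, as in _similarity_search.
similarity_scores = torch.tensor([[0.25, 0.75, 0.5],
                                  [0.875, 0.125, 0.5]])

# k=1: best match per query row; sorted=False skips unnecessary ordering.
values_tensor, idx_tensor = torch.topk(
    similarity_scores, 1, dim=1, largest=True, sorted=False
)
print(values_tensor.cpu().tolist())  # [[0.75], [0.875]]
print(idx_tensor.cpu().tolist())     # [[1], [0]]
```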