mteb 2.5.2__py3-none-any.whl → 2.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (241)
  1. mteb/__init__.py +2 -0
  2. mteb/_create_dataloaders.py +17 -18
  3. mteb/_evaluators/any_sts_evaluator.py +3 -3
  4. mteb/_evaluators/clustering_evaluator.py +2 -2
  5. mteb/_evaluators/evaluator.py +4 -2
  6. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +10 -8
  7. mteb/_evaluators/pair_classification_evaluator.py +5 -3
  8. mteb/_evaluators/retrieval_evaluator.py +2 -2
  9. mteb/_evaluators/retrieval_metrics.py +18 -17
  10. mteb/_evaluators/sklearn_evaluator.py +11 -10
  11. mteb/_evaluators/text/bitext_mining_evaluator.py +27 -18
  12. mteb/_evaluators/text/summarization_evaluator.py +23 -18
  13. mteb/_evaluators/zeroshot_classification_evaluator.py +5 -3
  14. mteb/abstasks/_data_filter/filters.py +1 -1
  15. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  16. mteb/abstasks/_statistics_calculation.py +18 -10
  17. mteb/abstasks/_stratification.py +18 -18
  18. mteb/abstasks/abstask.py +35 -28
  19. mteb/abstasks/aggregate_task_metadata.py +1 -9
  20. mteb/abstasks/aggregated_task.py +10 -29
  21. mteb/abstasks/classification.py +15 -10
  22. mteb/abstasks/clustering.py +19 -15
  23. mteb/abstasks/clustering_legacy.py +10 -10
  24. mteb/abstasks/image/image_text_pair_classification.py +7 -4
  25. mteb/abstasks/multilabel_classification.py +23 -19
  26. mteb/abstasks/pair_classification.py +20 -11
  27. mteb/abstasks/regression.py +4 -4
  28. mteb/abstasks/retrieval.py +28 -24
  29. mteb/abstasks/retrieval_dataset_loaders.py +2 -2
  30. mteb/abstasks/sts.py +8 -5
  31. mteb/abstasks/task_metadata.py +31 -33
  32. mteb/abstasks/text/bitext_mining.py +39 -28
  33. mteb/abstasks/text/reranking.py +8 -6
  34. mteb/abstasks/text/summarization.py +10 -5
  35. mteb/abstasks/zeroshot_classification.py +8 -4
  36. mteb/benchmarks/benchmark.py +4 -2
  37. mteb/benchmarks/benchmarks/__init__.py +4 -0
  38. mteb/benchmarks/benchmarks/benchmarks.py +112 -11
  39. mteb/benchmarks/get_benchmark.py +14 -55
  40. mteb/cache.py +182 -29
  41. mteb/cli/_display_tasks.py +2 -2
  42. mteb/cli/build_cli.py +110 -14
  43. mteb/cli/generate_model_card.py +43 -23
  44. mteb/deprecated_evaluator.py +63 -49
  45. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2CybersecurityRetrieval.json +32 -0
  46. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EconomicRetrieval.json +32 -0
  47. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2EnergyRetrieval.json +32 -0
  48. mteb/descriptive_stats/Image/DocumentUnderstanding/KoVidore2HrRetrieval.json +32 -0
  49. mteb/descriptive_stats/Retrieval/ChemRxivRetrieval.json +30 -0
  50. mteb/descriptive_stats/Retrieval/EuroPIRQRetrieval.json +116 -0
  51. mteb/descriptive_stats/Retrieval/NanoClimateFEVER-VN.json +30 -0
  52. mteb/descriptive_stats/Retrieval/NanoDBPedia-VN.json +30 -0
  53. mteb/descriptive_stats/Retrieval/NanoFEVER-VN.json +30 -0
  54. mteb/descriptive_stats/Retrieval/NanoHotpotQA-VN.json +30 -0
  55. mteb/descriptive_stats/Retrieval/NanoMSMARCO-VN.json +30 -0
  56. mteb/descriptive_stats/Retrieval/NanoNQ-VN.json +30 -0
  57. mteb/descriptive_stats/Retrieval/TVPLRetrieval.json +30 -0
  58. mteb/evaluate.py +44 -33
  59. mteb/filter_tasks.py +25 -26
  60. mteb/get_tasks.py +29 -30
  61. mteb/languages/language_scripts.py +5 -3
  62. mteb/leaderboard/app.py +162 -34
  63. mteb/load_results.py +12 -12
  64. mteb/models/abs_encoder.py +10 -6
  65. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  66. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  67. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
  68. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
  69. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  70. mteb/models/get_model_meta.py +21 -3
  71. mteb/models/instruct_wrapper.py +28 -8
  72. mteb/models/model_implementations/align_models.py +1 -1
  73. mteb/models/model_implementations/andersborges.py +4 -4
  74. mteb/models/model_implementations/ara_models.py +1 -1
  75. mteb/models/model_implementations/arctic_models.py +8 -8
  76. mteb/models/model_implementations/b1ade_models.py +1 -1
  77. mteb/models/model_implementations/bge_models.py +45 -21
  78. mteb/models/model_implementations/bica_model.py +3 -3
  79. mteb/models/model_implementations/blip2_models.py +2 -2
  80. mteb/models/model_implementations/blip_models.py +16 -16
  81. mteb/models/model_implementations/bm25.py +4 -4
  82. mteb/models/model_implementations/bmretriever_models.py +6 -4
  83. mteb/models/model_implementations/cadet_models.py +1 -1
  84. mteb/models/model_implementations/cde_models.py +11 -4
  85. mteb/models/model_implementations/clip_models.py +6 -6
  86. mteb/models/model_implementations/clips_models.py +3 -3
  87. mteb/models/model_implementations/codefuse_models.py +5 -5
  88. mteb/models/model_implementations/codesage_models.py +3 -3
  89. mteb/models/model_implementations/cohere_models.py +5 -5
  90. mteb/models/model_implementations/cohere_v.py +2 -2
  91. mteb/models/model_implementations/colpali_models.py +3 -3
  92. mteb/models/model_implementations/colqwen_models.py +8 -8
  93. mteb/models/model_implementations/colsmol_models.py +2 -2
  94. mteb/models/model_implementations/conan_models.py +1 -1
  95. mteb/models/model_implementations/dino_models.py +42 -42
  96. mteb/models/model_implementations/e5_instruct.py +23 -4
  97. mteb/models/model_implementations/e5_models.py +9 -9
  98. mteb/models/model_implementations/e5_v.py +6 -6
  99. mteb/models/model_implementations/eagerworks_models.py +1 -1
  100. mteb/models/model_implementations/emillykkejensen_models.py +6 -6
  101. mteb/models/model_implementations/en_code_retriever.py +1 -1
  102. mteb/models/model_implementations/euler_models.py +2 -2
  103. mteb/models/model_implementations/fa_models.py +9 -9
  104. mteb/models/model_implementations/facebookai.py +14 -2
  105. mteb/models/model_implementations/geogpt_models.py +1 -1
  106. mteb/models/model_implementations/gme_v_models.py +6 -5
  107. mteb/models/model_implementations/google_models.py +1 -1
  108. mteb/models/model_implementations/granite_vision_embedding_models.py +1 -1
  109. mteb/models/model_implementations/gritlm_models.py +2 -2
  110. mteb/models/model_implementations/gte_models.py +25 -13
  111. mteb/models/model_implementations/hinvec_models.py +1 -1
  112. mteb/models/model_implementations/ibm_granite_models.py +30 -6
  113. mteb/models/model_implementations/inf_models.py +2 -2
  114. mteb/models/model_implementations/jasper_models.py +2 -2
  115. mteb/models/model_implementations/jina_clip.py +48 -10
  116. mteb/models/model_implementations/jina_models.py +18 -11
  117. mteb/models/model_implementations/kblab.py +12 -6
  118. mteb/models/model_implementations/kennethenevoldsen_models.py +4 -4
  119. mteb/models/model_implementations/kfst.py +1 -1
  120. mteb/models/model_implementations/kowshik24_models.py +1 -1
  121. mteb/models/model_implementations/lgai_embedding_models.py +1 -1
  122. mteb/models/model_implementations/linq_models.py +1 -1
  123. mteb/models/model_implementations/listconranker.py +1 -1
  124. mteb/models/model_implementations/llm2clip_models.py +6 -6
  125. mteb/models/model_implementations/llm2vec_models.py +8 -8
  126. mteb/models/model_implementations/mcinext_models.py +4 -1
  127. mteb/models/model_implementations/mdbr_models.py +17 -3
  128. mteb/models/model_implementations/misc_models.py +68 -68
  129. mteb/models/model_implementations/mixedbread_ai_models.py +332 -0
  130. mteb/models/model_implementations/mme5_models.py +1 -1
  131. mteb/models/model_implementations/moco_models.py +4 -4
  132. mteb/models/model_implementations/mod_models.py +1 -1
  133. mteb/models/model_implementations/model2vec_models.py +14 -14
  134. mteb/models/model_implementations/moka_models.py +1 -1
  135. mteb/models/model_implementations/nbailab.py +3 -3
  136. mteb/models/model_implementations/no_instruct_sentence_models.py +2 -2
  137. mteb/models/model_implementations/nomic_models.py +30 -15
  138. mteb/models/model_implementations/nomic_models_vision.py +1 -1
  139. mteb/models/model_implementations/nvidia_llama_nemoretriever_colemb.py +15 -9
  140. mteb/models/model_implementations/nvidia_models.py +151 -19
  141. mteb/models/model_implementations/octen_models.py +61 -2
  142. mteb/models/model_implementations/openclip_models.py +13 -13
  143. mteb/models/model_implementations/opensearch_neural_sparse_models.py +5 -5
  144. mteb/models/model_implementations/ops_moa_models.py +1 -1
  145. mteb/models/model_implementations/ordalietech_solon_embeddings_mini_beta_1_1.py +1 -1
  146. mteb/models/model_implementations/pawan_models.py +1 -1
  147. mteb/models/model_implementations/piccolo_models.py +1 -1
  148. mteb/models/model_implementations/pixie_models.py +56 -0
  149. mteb/models/model_implementations/promptriever_models.py +4 -4
  150. mteb/models/model_implementations/pylate_models.py +10 -9
  151. mteb/models/model_implementations/qodo_models.py +2 -2
  152. mteb/models/model_implementations/qtack_models.py +1 -1
  153. mteb/models/model_implementations/qwen3_models.py +3 -3
  154. mteb/models/model_implementations/qzhou_models.py +2 -2
  155. mteb/models/model_implementations/random_baseline.py +3 -3
  156. mteb/models/model_implementations/rasgaard_models.py +2 -2
  157. mteb/models/model_implementations/reasonir_model.py +1 -1
  158. mteb/models/model_implementations/repllama_models.py +3 -3
  159. mteb/models/model_implementations/rerankers_custom.py +12 -6
  160. mteb/models/model_implementations/rerankers_monot5_based.py +17 -17
  161. mteb/models/model_implementations/richinfoai_models.py +1 -1
  162. mteb/models/model_implementations/ru_sentence_models.py +20 -20
  163. mteb/models/model_implementations/ruri_models.py +10 -10
  164. mteb/models/model_implementations/salesforce_models.py +3 -3
  165. mteb/models/model_implementations/samilpwc_models.py +1 -1
  166. mteb/models/model_implementations/sarashina_embedding_models.py +2 -2
  167. mteb/models/model_implementations/searchmap_models.py +1 -1
  168. mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +113 -146
  169. mteb/models/model_implementations/sentence_transformers_models.py +124 -22
  170. mteb/models/model_implementations/shuu_model.py +1 -1
  171. mteb/models/model_implementations/siglip_models.py +20 -20
  172. mteb/models/model_implementations/slm_models.py +416 -0
  173. mteb/models/model_implementations/spartan8806_atles_champion.py +1 -1
  174. mteb/models/model_implementations/stella_models.py +17 -4
  175. mteb/models/model_implementations/tarka_models.py +2 -2
  176. mteb/models/model_implementations/text2vec_models.py +9 -3
  177. mteb/models/model_implementations/ua_sentence_models.py +1 -1
  178. mteb/models/model_implementations/uae_models.py +7 -1
  179. mteb/models/model_implementations/vdr_models.py +1 -1
  180. mteb/models/model_implementations/vi_vn_models.py +6 -6
  181. mteb/models/model_implementations/vlm2vec_models.py +3 -3
  182. mteb/models/model_implementations/voyage_models.py +84 -0
  183. mteb/models/model_implementations/voyage_v.py +9 -7
  184. mteb/models/model_implementations/youtu_models.py +1 -1
  185. mteb/models/model_implementations/yuan_models.py +1 -1
  186. mteb/models/model_implementations/yuan_models_en.py +1 -1
  187. mteb/models/model_meta.py +80 -31
  188. mteb/models/models_protocols.py +22 -6
  189. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
  190. mteb/models/search_wrappers.py +33 -18
  191. mteb/models/sentence_transformer_wrapper.py +50 -25
  192. mteb/models/vllm_wrapper.py +327 -0
  193. mteb/py.typed +0 -0
  194. mteb/results/benchmark_results.py +29 -21
  195. mteb/results/model_result.py +52 -22
  196. mteb/results/task_result.py +80 -58
  197. mteb/similarity_functions.py +11 -7
  198. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  199. mteb/tasks/classification/est/estonian_valence.py +1 -1
  200. mteb/tasks/classification/kur/kurdish_sentiment_classification.py +2 -2
  201. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  202. mteb/tasks/clustering/eng/hume_wiki_cities_clustering.py +1 -1
  203. mteb/tasks/clustering/eng/wiki_cities_clustering.py +1 -1
  204. mteb/tasks/clustering/zho/cmteb_clustering.py +2 -2
  205. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  206. mteb/tasks/reranking/multilingual/wikipedia_reranking_multilingual.py +1 -1
  207. mteb/tasks/retrieval/code/code_rag.py +12 -12
  208. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  209. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  210. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  211. mteb/tasks/retrieval/eng/__init__.py +2 -0
  212. mteb/tasks/retrieval/eng/chemrxiv.py +33 -0
  213. mteb/tasks/retrieval/eng/cub200_i2i_retrieval.py +1 -1
  214. mteb/tasks/retrieval/kor/__init__.py +15 -1
  215. mteb/tasks/retrieval/kor/kovidore2_bench_retrieval.py +142 -0
  216. mteb/tasks/retrieval/multilingual/__init__.py +2 -0
  217. mteb/tasks/retrieval/multilingual/euro_pirq_retrieval.py +43 -0
  218. mteb/tasks/retrieval/multilingual/vidore3_bench_retrieval.py +90 -100
  219. mteb/tasks/retrieval/nob/norquad.py +2 -2
  220. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  221. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  222. mteb/tasks/retrieval/vie/__init__.py +14 -6
  223. mteb/tasks/retrieval/vie/climate_fevervn_retrieval.py +39 -0
  224. mteb/tasks/retrieval/vie/db_pedia_vn_retrieval.py +39 -0
  225. mteb/tasks/retrieval/vie/fevervn_retrieval.py +39 -0
  226. mteb/tasks/retrieval/vie/hotpot_qavn_retrieval.py +39 -0
  227. mteb/tasks/retrieval/vie/msmarcovn_retrieval.py +48 -0
  228. mteb/tasks/retrieval/vie/nqvn_retrieval.py +39 -0
  229. mteb/tasks/retrieval/vie/tvpl_retrieval.py +42 -0
  230. mteb/tasks/retrieval/vie/zac_legal_text_retrieval.py +15 -1
  231. mteb/types/__init__.py +2 -0
  232. mteb/types/_encoder_io.py +12 -0
  233. mteb/types/_result.py +2 -1
  234. mteb/types/statistics.py +9 -3
  235. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/METADATA +15 -4
  236. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/RECORD +240 -219
  237. mteb/models/model_implementations/mxbai_models.py +0 -111
  238. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/WHEEL +0 -0
  239. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/entry_points.txt +0 -0
  240. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/licenses/LICENSE +0 -0
  241. {mteb-2.5.2.dist-info → mteb-2.7.2.dist-info}/top_level.txt +0 -0
mteb/models/vllm_wrapper.py ADDED
@@ -0,0 +1,327 @@
+ from __future__ import annotations
+
+ import atexit
+ import gc
+ import logging
+ import os
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any, Literal
+
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader
+
+ from mteb._requires_package import requires_package
+ from mteb.abstasks.task_metadata import TaskMetadata
+ from mteb.models import ModelMeta
+ from mteb.models.abs_encoder import AbsEncoder
+ from mteb.types import Array, BatchedInput, PromptType
+
+ if TYPE_CHECKING:
+     from vllm.config import PoolerConfig  # type: ignore[import-not-found]
+ else:
+     PoolerConfig = dict[str, Any]
+
+ logger = logging.getLogger(__name__)
+
+ Dtype = Literal["half", "float16", "float", "float32", "bfloat16", "auto"]
+
+
+ class VllmWrapperBase:
+     """Wrapper for vllm serving engine."""
+
+     convert = "auto"
+     mteb_model_meta: ModelMeta | None = None
+
+     def __init__(
+         self,
+         model: str | ModelMeta,
+         revision: str | None = None,
+         *,
+         trust_remote_code: bool = True,
+         dtype: Dtype = "auto",
+         head_dtype: Literal["model"] | Dtype | None = None,
+         max_model_len: int | None = None,
+         max_num_batched_tokens: int | None = None,
+         max_num_seqs: int = 128,
+         tensor_parallel_size: int = 1,
+         enable_prefix_caching: bool | None = None,
+         gpu_memory_utilization: float = 0.9,
+         hf_overrides: dict[str, Any] | None = None,
+         pooler_config: PoolerConfig | None = None,
+         enforce_eager: bool = False,
+         **kwargs: Any,
+     ):
+         """Wrapper for vllm serving engine.
+
+         Args:
+             model: model name string.
+             revision: The revision of the model to use.
+             trust_remote_code: Whether to trust remote code execution when loading the model.
+                 Should be True for models with custom code.
+             dtype: Data type for model weights. "auto" will automatically select appropriate
+                 dtype based on hardware and model capabilities. vllm uses flash attention by
+                 default, which does not support fp32. Therefore, it defaults to using fp16 for
+                 inference on fp32 models. Testing has shown a relatively small drop in accuracy.
+                 You can manually opt for fp32, but inference speed will be very slow.
+             head_dtype: "head" refers to the last Linear layer(s) of an LLMs, such as the score
+                 or classifier in a classification model. Uses fp32 for the head by default to
+                 gain extra precision.
+             max_model_len: Maximum sequence length (context window) supported by the model.
+                 If None, uses the model's default maximum length.
+             max_num_batched_tokens: Maximum number of tokens to process in a single batch.
+                 If None, automatically determined.
+             max_num_seqs: Maximum number of sequences to process concurrently.
+             tensor_parallel_size: Number of GPUs for tensor parallelism.
+             enable_prefix_caching: Whether to enable KV cache sharing for common prompt prefixes.
+                 If None, uses the model's default setting.
+             gpu_memory_utilization: Target GPU memory utilization ratio (0.0 to 1.0).
+             hf_overrides: Dictionary mapping Hugging Face configuration keys to override values.
+             pooler_config: Controls the behavior of output pooling in pooling models.
+             enforce_eager: Whether to disable CUDA graph optimization and use eager execution.
+             **kwargs: Additional arguments to pass to the vllm serving engine model.
+         """
+         requires_package(
+             self,
+             "vllm",
+             "Wrapper for vllm serving engine",
+             install_instruction="pip install mteb[vllm]",
+         )
+
+         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+         from vllm import LLM, EngineArgs
+
+         hf_overrides = {} if hf_overrides is None else hf_overrides
+
+         if head_dtype is not None:
+             hf_overrides["head_dtype"] = head_dtype
+
+         model_name = model if isinstance(model, str) else model.name
+
+         if isinstance(model, ModelMeta):
+             logger.info(
+                 "Using revision from model meta. Passed revision will be ignored"
+             )
+             revision = model.revision
+
+         args = EngineArgs(
+             model=model_name,
+             revision=revision,
+             runner="pooling",
+             convert=self.convert,  # type: ignore[arg-type]
+             max_model_len=max_model_len,
+             max_num_batched_tokens=max_num_batched_tokens,
+             max_num_seqs=max_num_seqs,
+             tensor_parallel_size=tensor_parallel_size,
+             enable_prefix_caching=enable_prefix_caching,
+             gpu_memory_utilization=gpu_memory_utilization,
+             hf_overrides=hf_overrides,
+             pooler_config=pooler_config,
+             enforce_eager=enforce_eager,
+             trust_remote_code=trust_remote_code,
+             dtype=dtype,
+             **kwargs,
+         )
+         self.llm = LLM(**vars(args))
+
+         if isinstance(model, str):
+             self.mteb_model_meta = ModelMeta.from_hub(model=model, revision=revision)
+         else:
+             self.mteb_model_meta = model
+
+         atexit.register(self.cleanup)
+
+     def cleanup(self):
+         """Clean up the VLLM distributed runtime environment and release GPU resources."""
+         if self.llm is None:
+             return
+
+         from vllm.distributed import (  # type: ignore[import-not-found]
+             cleanup_dist_env_and_memory,
+         )
+
+         self.llm = None
+         gc.collect()
+         cleanup_dist_env_and_memory()
+
+     def __del__(self):
+         try:
+             self.cleanup()
+         except Exception:
+             pass
+
+
+ class VllmEncoderWrapper(AbsEncoder, VllmWrapperBase):
+     """vLLM wrapper for Encoder models.
+
+     Args:
+         model: model name string or ModelMeta.
+         revision: The revision of the model to use.
+         prompt_dict: A dictionary mapping task names to prompt strings.
+         use_instructions: Whether to use instructions from the prompt_dict.
+             When False, values from prompt_dict are used as static prompts (prefixes).
+             When True, values from prompt_dict are used as instructions to be formatted
+             using the instruction_template.
+         instruction_template: A template or callable to format instructions.
+             Can be a string with '{instruction}' placeholder or a callable that takes
+             the instruction and prompt type and returns a formatted string.
+         apply_instruction_to_documents: Whether to apply instructions to documents prompts.
+         **kwargs: Additional arguments to pass to the vllm serving engine model.
+     """
+
+     convert = "embed"
+
+     def __init__(
+         self,
+         model: str | ModelMeta,
+         revision: str | None = None,
+         prompt_dict: dict[str, str] | None = None,
+         use_instructions: bool = False,
+         instruction_template: (
+             str | Callable[[str, PromptType | None], str] | None
+         ) = None,
+         apply_instruction_to_documents: bool = True,
+         **kwargs: Any,
+     ):
+         if use_instructions and instruction_template is None:
+             raise ValueError(
+                 "To use instructions, an instruction_template must be provided. "
+                 "For example, `Instruction: {instruction}`"
+             )
+
+         if (
+             isinstance(instruction_template, str)
+             and "{instruction}" not in instruction_template
+         ):
+             raise ValueError(
+                 "Instruction template must contain the string '{instruction}'."
+             )
+
+         self.prompts_dict = prompt_dict
+         self.use_instructions = use_instructions
+         self.instruction_template = instruction_template
+         self.apply_instruction_to_passages = apply_instruction_to_documents
+         super().__init__(
+             model,
+             revision,
+             **kwargs,
+         )
+
+     def encode(
+         self,
+         inputs: DataLoader[BatchedInput],
+         *,
+         task_metadata: TaskMetadata,
+         hf_split: str,
+         hf_subset: str,
+         prompt_type: PromptType | None = None,
+         **kwargs: Any,
+     ) -> Array:
+         """Encodes the given sentences using the encoder.
+
+         Args:
+             inputs: The sentences to encode.
+             task_metadata: The metadata of the task. Sentence-transformers uses this to
+                 determine which prompt to use from a specified dictionary.
+             prompt_type: The name type of prompt. (query or passage)
+             hf_split: Split of current task
+             hf_subset: Subset of current task
+             **kwargs: Additional arguments to pass to the encoder.
+
+         Returns:
+             The encoded sentences.
+         """
+         prompt = ""
+         if self.use_instructions and self.prompts_dict is not None:
+             prompt = self.get_task_instruction(task_metadata, prompt_type)
+         elif self.prompts_dict is not None:
+             prompt_name = self.get_prompt_name(task_metadata, prompt_type)
+             if prompt_name is not None:
+                 prompt = self.prompts_dict.get(prompt_name, "")
+
+         if (
+             self.use_instructions
+             and self.apply_instruction_to_passages is False
+             and prompt_type == PromptType.document
+         ):
+             logger.info(
+                 f"No instruction used, because prompt type = {prompt_type.document}"
+             )
+             prompt = ""
+         else:
+             logger.info(
+                 f"Using instruction: '{prompt}' for task: '{task_metadata.name}' prompt type: '{prompt_type}'"
+             )
+
+         prompts = [prompt + text for batch in inputs for text in batch["text"]]
+         outputs = self.llm.encode(
+             prompts, pooling_task="embed", truncate_prompt_tokens=-1
+         )
+         embeddings = torch.stack([output.outputs.data for output in outputs])
+         return embeddings
+
+
+ class VllmCrossEncoderWrapper(VllmWrapperBase):
+     """vLLM wrapper for CrossEncoder models."""
+
+     convert = "classify"
+
+     def __init__(
+         self,
+         model: str | ModelMeta,
+         revision: str | None = None,
+         query_prefix: str = "",
+         document_prefix: str = "",
+         **kwargs: Any,
+     ):
+         super().__init__(
+             model,
+             revision,
+             **kwargs,
+         )
+         self.query_prefix = query_prefix
+         self.document_prefix = document_prefix
+
+     def predict(
+         self,
+         inputs1: DataLoader[BatchedInput],
+         inputs2: DataLoader[BatchedInput],
+         *,
+         task_metadata: TaskMetadata,
+         hf_split: str,
+         hf_subset: str,
+         prompt_type: PromptType | None = None,
+         **kwargs: Any,
+     ) -> Array:
+         """Predicts relevance scores for pairs of inputs. Note that, unlike the encoder, the cross-encoder can compare across inputs.
+
+         Args:
+             inputs1: First Dataloader of inputs to encode. For reranking tasks, these are queries (for text only tasks `QueryDatasetType`).
+             inputs2: Second Dataloader of inputs to encode. For reranking, these are documents (for text only tasks `RetrievalOutputType`).
+             task_metadata: Metadata of the current task.
+             hf_split: Split of current task, allows to know some additional information about current split.
+                 E.g. Current language
+             hf_subset: Subset of current task. Similar to `hf_split` to get more information
+             prompt_type: The name type of prompt. (query or passage)
+             **kwargs: Additional arguments to pass to the cross-encoder.
+
+         Returns:
+             The predicted relevance scores for each inputs pair.
+         """
+         queries = [
+             self.query_prefix + text for batch in inputs1 for text in batch["text"]
+         ]
+         corpus = [
+             self.document_prefix + text for batch in inputs2 for text in batch["text"]
+         ]
+         # TODO: support score prompt
+
+         outputs = self.llm.score(
+             queries,
+             corpus,
+             truncate_prompt_tokens=-1,
+             use_tqdm=False,
+         )
+         scores = np.array([output.outputs.score for output in outputs])
+         return scores
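
Note: the new `mteb/models/vllm_wrapper.py` module above adds `VllmEncoderWrapper` (vLLM pooling runner with convert="embed") and `VllmCrossEncoderWrapper` (convert="classify", scored with `LLM.score`). The following is a minimal usage sketch based only on the signatures shown in this diff; the model ids, instruction template, and engine settings are illustrative assumptions rather than values shipped by the package, and running it requires a GPU plus the `mteb[vllm]` extra.

# Illustrative sketch only (not part of the diff): model ids, template, and
# engine settings below are assumptions chosen for demonstration.
from mteb.models.vllm_wrapper import VllmCrossEncoderWrapper, VllmEncoderWrapper

# Bi-encoder: embeds queries with an instruction prefix, documents without one.
encoder = VllmEncoderWrapper(
    "intfloat/e5-mistral-7b-instruct",     # any HF model vLLM can run as a pooling model
    use_instructions=True,
    instruction_template="Instruct: {instruction}\nQuery: ",
    apply_instruction_to_documents=False,  # prompt_type == document gets no instruction
    gpu_memory_utilization=0.8,            # forwarded to vLLM EngineArgs
    max_num_seqs=64,
)

# Cross-encoder: scores (query, document) pairs via LLM.score.
reranker = VllmCrossEncoderWrapper(
    "BAAI/bge-reranker-v2-m3",
    query_prefix="",
    document_prefix="",
)

# cleanup() is registered with atexit, but it can also be called explicitly to
# free GPU memory and tear down vLLM's distributed environment between runs.
encoder.cleanup()
reranker.cleanup()

In an actual evaluation these wrappers are passed wherever mteb expects an encoder or cross-encoder; the evaluation entry point itself is outside this file.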
mteb/py.typed ADDED
File without changes
mteb/results/benchmark_results.py CHANGED
@@ -1,10 +1,12 @@
+ from __future__ import annotations
+
  import functools
  import json
  import logging
  import warnings
- from collections.abc import Callable, Iterable, Iterator, Sequence
+ from collections.abc import Callable, Iterable, Iterator
  from pathlib import Path
- from typing import Any, Literal
+ from typing import Any, Literal, cast

  import pandas as pd
  from packaging.version import InvalidVersion, Version
@@ -33,11 +35,12 @@ from .model_result import ModelResult, _aggregate_and_pivot
  logger = logging.getLogger(__name__)


- # Global cache for model metas and version parsing
  @functools.lru_cache
  def _get_cached_model_metas() -> dict[str, str | None]:
      """Cache model metas to avoid repeated calls."""
-     return {meta.name: meta.revision for meta in get_model_metas()}
+     return {
+         meta.name: meta.revision for meta in get_model_metas() if meta.name is not None
+     }


  @functools.lru_cache(maxsize=10000)
@@ -77,10 +80,10 @@ class BenchmarkResults(BaseModel):
          task_names: list[str] | None = None,
          languages: list[str] | None = None,
          domains: list[TaskDomain] | None = None,
-         task_types: list[TaskType] | None = None,  # type: ignore
+         task_types: list[TaskType] | None = None,
          modalities: list[Modalities] | None = None,
          is_public: bool | None = None,
-     ) -> Self:
+     ) -> BenchmarkResults:
          # TODO: Same as filter_models
          model_results = [
              res._filter_tasks(
@@ -97,7 +100,7 @@ class BenchmarkResults(BaseModel):
              model_results=[res for res in model_results if res.task_results]
          )

-     def select_tasks(self, tasks: Sequence[AbsTask]) -> Self:
+     def select_tasks(self, tasks: Iterable[AbsTask]) -> BenchmarkResults:
          """Select tasks from the benchmark results.

          Args:
@@ -115,7 +118,7 @@ class BenchmarkResults(BaseModel):
          self,
          names: list[str] | list[ModelMeta],
          revisions: list[str | None] | None = None,
-     ) -> Self:
+     ) -> BenchmarkResults:
          """Get models by name and revision.

          Args:
@@ -128,7 +131,7 @@ class BenchmarkResults(BaseModel):
          models_res = []
          _revisions = revisions if revisions is not None else [None] * len(names)

-         name_rev = {}
+         name_rev: dict[str, str | None] = {}

          if len(names) != len(_revisions):
              raise ValueError(
@@ -137,9 +140,12 @@

          for name, revision in zip(names, _revisions):
              if isinstance(name, ModelMeta):
+                 if name.name is None:
+                     raise ValueError("name in ModelMeta is None. It must be a string.")
                  name_rev[name.name] = name.revision
              else:
-                 name_rev[name] = revision
+                 name_ = cast(str, name)
+                 name_rev[name_] = revision

          for model_res in self.model_results:
              model_name = model_res.model_name
@@ -159,7 +165,7 @@ class BenchmarkResults(BaseModel):
          n_parameters_range: tuple[int | None, int | None] = (None, None),
          use_instructions: bool | None = None,
          zero_shot_on: list[AbsTask] | None = None,
-     ) -> Self:
+     ) -> BenchmarkResults:
          # mostly a utility function for the leaderboard app.
          # I would probably move the filtering of the models outside of this call. No need to call get_model_metas inside the filter.
          # interface would then be the same as the get_models function
@@ -182,7 +188,7 @@ class BenchmarkResults(BaseModel):

          return type(self).model_construct(model_results=new_model_results)

-     def join_revisions(self) -> Self:
+     def join_revisions(self) -> BenchmarkResults:
          """Join revisions of the same model.

          In case of conflicts, the following rules are applied:
@@ -212,10 +218,10 @@ class BenchmarkResults(BaseModel):

          # Use cached model metas
          model_to_main_revision = _get_cached_model_metas()
-         task_df["main_revision"] = task_df["model"].map(model_to_main_revision)  # type: ignore
+         task_df["main_revision"] = task_df["model"].map(model_to_main_revision)

          # Use cached version parsing
-         task_df["mteb_version"] = task_df["mteb_version"].map(_parse_version_cached)  # type: ignore
+         task_df["mteb_version"] = task_df["mteb_version"].map(_parse_version_cached)

          # Filter out rows without scores first
          task_df = task_df[task_df["has_scores"]]
@@ -259,8 +265,8 @@ class BenchmarkResults(BaseModel):
          # so grouping by original revision ensures consistent ModelResult creation
          for (model, model_revision), group in task_df.groupby(["model", "revision"]):
              model_result = ModelResult.model_construct(
-                 model_name=model,
-                 model_revision=model_revision,
+                 model_name=model,  # type: ignore[arg-type]
+                 model_revision=model_revision,  # type: ignore[arg-type]
                  task_results=list(group["task_result"]),
              )
              model_results.append(model_result)
@@ -291,7 +297,7 @@ class BenchmarkResults(BaseModel):
                      {
                          "model": model_res.model_name,
                          "revision": model_res.model_revision,
-                         **model_scores,  # type: ignore
+                         **model_scores,
                      }
                  )
              except Exception as e:
@@ -364,7 +370,9 @@ class BenchmarkResults(BaseModel):
              scores_data.extend(model_result._get_score_for_table())

          if not scores_data:
-             logger.warning("No scores data available. Returning empty DataFrame.")
+             msg = "No scores data available. Returning empty DataFrame."
+             logger.warning(msg)
+             warnings.warn(msg)
              return pd.DataFrame()

          # Create DataFrame
@@ -402,7 +410,7 @@ class BenchmarkResults(BaseModel):

          return self.benchmark._create_summary_table(self)

-     def __iter__(self) -> Iterator[ModelResult]:
+     def __iter__(self) -> Iterator[ModelResult]:  # type: ignore[override]
          return iter(self.model_results)

      def __getitem__(self, index: int) -> ModelResult:
@@ -424,11 +432,11 @@ class BenchmarkResults(BaseModel):
          out_file.write(self.model_dump_json(indent=2))

      @classmethod
-     def from_validated(cls, **data) -> Self:
+     def from_validated(cls, **data: Any) -> BenchmarkResults:
          """Create BenchmarkResults from validated data.

          Args:
-             data: Dictionary containing the data.
+             **data: Arbitrary keyword arguments containing the data.

          Returns:
              An instance of BenchmarkResults.
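
Note: the `join_revisions()` hunks above rely on `functools.lru_cache` helpers so that repeated calls do not re-fetch model metadata or re-parse `mteb_version` strings. Below is a standalone sketch of that caching pattern; `parse_version_cached` is an invented name for illustration, not the mteb-internal helper.

from __future__ import annotations

import functools

from packaging.version import InvalidVersion, Version


@functools.lru_cache(maxsize=10000)
def parse_version_cached(version_str: str | None) -> Version | None:
    """Parse a version string once; identical inputs afterwards hit the cache."""
    if version_str is None:
        return None
    try:
        return Version(version_str)
    except InvalidVersion:
        return None


print(parse_version_cached("2.7.2"))           # parsed on the first call
print(parse_version_cached("2.7.2"))           # served from the lru_cache
print(parse_version_cached("not-a-version"))   # None for unparsable strings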
mteb/results/model_result.py CHANGED
@@ -1,12 +1,14 @@
+ from __future__ import annotations
+
  import logging
  import warnings
- from collections.abc import Callable, Iterable, Sequence
- from typing import Any, Literal
+ from collections.abc import Callable, Iterable
+ from typing import Any, Literal, cast

  import numpy as np
  import pandas as pd
  from pydantic import BaseModel, ConfigDict, Field
- from typing_extensions import Self
+ from typing_extensions import overload

  from mteb.abstasks.abstask import AbsTask
  from mteb.abstasks.task_metadata import (
@@ -58,7 +60,7 @@ def _aggregate_and_pivot(
              index=index_columns,
              columns=columns,
              values="score",
-             aggfunc=aggregation_fn,
+             aggfunc=aggregation_fn,  # type: ignore[arg-type]
          ).reset_index()
      elif format == "long":
          return (
@@ -81,7 +83,7 @@ class ModelResult(BaseModel):
      model_revision: str | None
      task_results: list[TaskResult]
      default_modalities: list[Modalities] = Field(
-         default_factory=lambda: ["text"], alias="modalities"
+         default_factory=lambda: [cast(Modalities, "text")], alias="modalities"
      )
      model_config = (
          ConfigDict(  # to free up the name model_* which is otherwise protected
@@ -95,16 +97,17 @@ class ModelResult(BaseModel):
          return f"ModelResult(model_name={self.model_name}, model_revision={self.model_revision}, task_results=[...](#{n_entries}))"

      @classmethod
-     def from_validated(cls, **data: dict[str, Any]) -> Self:
+     def from_validated(cls, **data: dict[str, Any]) -> ModelResult:
          """Create a ModelResult from validated data.

          Args:
              data: The validated data.
          """
-         data["task_results"] = [
-             TaskResult.from_validated(**res) for res in data["task_results"]
+         data["task_results"] = [  # type: ignore[assignment]
+             TaskResult.from_validated(**res)  # type: ignore[arg-type]
+             for res in data["task_results"]
          ]
-         return cls.model_construct(**data)
+         return cls.model_construct(**data)  # type: ignore[arg-type]

      def _filter_tasks(
          self,
@@ -114,7 +117,7 @@ class ModelResult(BaseModel):
          task_types: list[TaskType] | None = None,
          modalities: list[Modalities] | None = None,
          is_public: bool | None = None,
-     ) -> Self:
+     ) -> ModelResult:
          new_task_results = []
          for task_result in self.task_results:
              if (task_names is not None) and (task_result.task_name not in task_names):
@@ -142,7 +145,7 @@ class ModelResult(BaseModel):
              task_results=new_task_results,
          )

-     def select_tasks(self, tasks: Sequence[AbsTask]) -> Self:
+     def select_tasks(self, tasks: Iterable[AbsTask]) -> ModelResult:
          """Select tasks from the ModelResult based on a list of AbsTask objects.

          Args:
@@ -160,6 +163,28 @@ class ModelResult(BaseModel):
              task_results=new_task_results,
          )

+     @overload
+     def _get_scores(
+         self,
+         splits: list[SplitName] | None = None,
+         languages: list[ISOLanguage | ISOLanguageScript] | None = None,
+         scripts: list[ISOLanguageScript] | None = None,
+         getter: Callable[[ScoresDict], Score] | None = None,
+         aggregation: Callable[[list[Score]], Any] | None = None,
+         format: Literal["wide"] = "wide",
+     ) -> dict: ...
+
+     @overload
+     def _get_scores(
+         self,
+         splits: list[SplitName] | None = None,
+         languages: list[ISOLanguage | ISOLanguageScript] | None = None,
+         scripts: list[ISOLanguageScript] | None = None,
+         getter: Callable[[ScoresDict], Score] | None = None,
+         aggregation: Callable[[list[Score]], Any] | None = None,
+         format: Literal["long"] = "long",
+     ) -> list: ...
+
      def _get_scores(
          self,
          splits: list[SplitName] | None = None,
@@ -177,21 +202,24 @@
              aggregation = aggregation if aggregation is not None else np.mean
          else:
              use_fast = True
+         aggregation = cast(Callable[[list[Score]], Any], aggregation)
+         getter = cast(Callable[[ScoresDict], Score], getter)
+
          if format == "wide":
              scores = {}
              for res in self.task_results:
                  try:
                      if use_fast:
                          scores[res.task_name] = res._get_score_fast(
-                             splits=splits,  # type: ignore
-                             languages=languages,  # type: ignore
+                             splits=splits,
+                             languages=languages,
                          )
                      else:
                          scores[res.task_name] = res.get_score(
                              splits=splits,
                              languages=languages,
-                             aggregation=aggregation,  # type: ignore
-                             getter=getter,  # type: ignore
+                             aggregation=aggregation,
+                             getter=getter,
                              scripts=scripts,
                          )
                  except Exception as e:
@@ -206,14 +234,14 @@
                  if use_fast:
                      score = task_res._get_score_fast(
                          splits=splits,
-                         languages=languages,  # type: ignore
+                         languages=languages,
                      )
                  else:
                      score = task_res.get_score(
                          splits=splits,
                          languages=languages,
-                         aggregation=aggregation,  # type: ignore
-                         getter=getter,  # type: ignore
+                         aggregation=aggregation,
+                         getter=getter,
                          scripts=scripts,
                      )
                  entry = dict(
@@ -292,7 +320,9 @@
          scores_data = self._get_score_for_table()

          if not scores_data:
-             logger.warning("No scores data available. Returning empty DataFrame.")
+             msg = "No scores data available. Returning empty DataFrame."
+             logger.warning(msg)
+             warnings.warn(msg)
              return pd.DataFrame()

          # Create DataFrame
@@ -315,7 +345,7 @@
      def __hash__(self) -> int:
          return id(self)

-     def __iter__(self) -> Iterable[TaskResult]:
+     def __iter__(self) -> Iterable[TaskResult]:  # type: ignore[override]
          return iter(self.task_results)

      def __getitem__(self, index) -> TaskResult:
@@ -368,13 +398,13 @@
          return [task_res.task_name for task_res in self.task_results]

      @property
-     def modalities(self) -> list[str]:
+     def modalities(self) -> list[Modalities]:
          """Get all modalities in the task results.

          Returns:
              A list of modalities in the task results.
          """
-         mods = []
+         mods: list[Modalities] = []
          for task_res in self.task_results:
              task_modalities = getattr(task_res, "modalities", [])
              mods.extend(task_modalities)
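
Note: the `_get_scores` hunks above replace `typing_extensions.Self` returns with explicit class names and add `@overload` stubs so that `format="wide"` type-checks as a `dict` and `format="long"` as a `list`, while a single runtime implementation serves both. Below is a self-contained sketch of the same overload pattern on a toy class; the class and data are illustrative, not mteb code.

from __future__ import annotations

from typing import Literal, overload


class ScoreTable:
    """Toy container used only to illustrate format-dependent return types."""

    def __init__(self, scores: dict[str, float]) -> None:
        self.scores = scores

    @overload
    def get(self, format: Literal["wide"] = "wide") -> dict: ...

    @overload
    def get(self, format: Literal["long"]) -> list: ...

    def get(self, format: Literal["wide", "long"] = "wide") -> dict | list:
        if format == "wide":
            return dict(self.scores)  # one mapping, one entry per task
        return [{"task": name, "score": value} for name, value in self.scores.items()]


table = ScoreTable({"task_a": 0.5, "task_b": 0.75})
wide = table.get()        # static type: dict
long = table.get("long")  # static type: list
print(wide, long)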