mteb 2.5.3__py3-none-any.whl → 2.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. mteb/_create_dataloaders.py +10 -15
  2. mteb/_evaluators/any_sts_evaluator.py +1 -4
  3. mteb/_evaluators/evaluator.py +2 -1
  4. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
  5. mteb/_evaluators/pair_classification_evaluator.py +3 -1
  6. mteb/_evaluators/retrieval_metrics.py +17 -16
  7. mteb/_evaluators/sklearn_evaluator.py +9 -8
  8. mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
  9. mteb/_evaluators/text/summarization_evaluator.py +20 -16
  10. mteb/abstasks/_data_filter/filters.py +1 -1
  11. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  12. mteb/abstasks/_statistics_calculation.py +18 -10
  13. mteb/abstasks/_stratification.py +18 -18
  14. mteb/abstasks/abstask.py +27 -21
  15. mteb/abstasks/aggregate_task_metadata.py +1 -9
  16. mteb/abstasks/aggregated_task.py +3 -16
  17. mteb/abstasks/classification.py +10 -4
  18. mteb/abstasks/clustering.py +18 -14
  19. mteb/abstasks/clustering_legacy.py +8 -8
  20. mteb/abstasks/image/image_text_pair_classification.py +5 -3
  21. mteb/abstasks/multilabel_classification.py +20 -16
  22. mteb/abstasks/pair_classification.py +18 -9
  23. mteb/abstasks/regression.py +3 -3
  24. mteb/abstasks/retrieval.py +12 -9
  25. mteb/abstasks/sts.py +6 -3
  26. mteb/abstasks/task_metadata.py +20 -16
  27. mteb/abstasks/text/bitext_mining.py +36 -25
  28. mteb/abstasks/text/reranking.py +7 -5
  29. mteb/abstasks/text/summarization.py +8 -3
  30. mteb/abstasks/zeroshot_classification.py +5 -2
  31. mteb/benchmarks/benchmark.py +4 -2
  32. mteb/benchmarks/benchmarks/benchmarks.py +22 -1
  33. mteb/benchmarks/get_benchmark.py +14 -55
  34. mteb/cache.py +21 -18
  35. mteb/cli/_display_tasks.py +2 -2
  36. mteb/cli/build_cli.py +8 -8
  37. mteb/cli/generate_model_card.py +39 -20
  38. mteb/deprecated_evaluator.py +56 -43
  39. mteb/evaluate.py +35 -29
  40. mteb/filter_tasks.py +25 -26
  41. mteb/get_tasks.py +25 -27
  42. mteb/languages/language_scripts.py +5 -3
  43. mteb/leaderboard/app.py +1 -1
  44. mteb/load_results.py +12 -12
  45. mteb/models/abs_encoder.py +2 -2
  46. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  47. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  48. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +2 -1
  49. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +30 -13
  50. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  51. mteb/models/get_model_meta.py +8 -1
  52. mteb/models/instruct_wrapper.py +11 -5
  53. mteb/models/model_implementations/andersborges.py +2 -2
  54. mteb/models/model_implementations/blip_models.py +8 -8
  55. mteb/models/model_implementations/bm25.py +1 -1
  56. mteb/models/model_implementations/clip_models.py +3 -3
  57. mteb/models/model_implementations/cohere_models.py +1 -1
  58. mteb/models/model_implementations/cohere_v.py +2 -2
  59. mteb/models/model_implementations/dino_models.py +23 -23
  60. mteb/models/model_implementations/emillykkejensen_models.py +3 -3
  61. mteb/models/model_implementations/jina_clip.py +1 -1
  62. mteb/models/model_implementations/jina_models.py +1 -1
  63. mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
  64. mteb/models/model_implementations/llm2clip_models.py +3 -3
  65. mteb/models/model_implementations/moco_models.py +2 -2
  66. mteb/models/model_implementations/model2vec_models.py +1 -1
  67. mteb/models/model_implementations/nomic_models.py +8 -8
  68. mteb/models/model_implementations/openclip_models.py +7 -7
  69. mteb/models/model_implementations/random_baseline.py +3 -3
  70. mteb/models/model_implementations/rasgaard_models.py +1 -1
  71. mteb/models/model_implementations/repllama_models.py +2 -2
  72. mteb/models/model_implementations/rerankers_custom.py +3 -3
  73. mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
  74. mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +113 -146
  75. mteb/models/model_implementations/siglip_models.py +10 -10
  76. mteb/models/model_implementations/vlm2vec_models.py +1 -1
  77. mteb/models/model_implementations/voyage_v.py +4 -4
  78. mteb/models/model_meta.py +30 -14
  79. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +5 -5
  80. mteb/models/search_wrappers.py +22 -10
  81. mteb/models/sentence_transformer_wrapper.py +9 -4
  82. mteb/py.typed +0 -0
  83. mteb/results/benchmark_results.py +25 -19
  84. mteb/results/model_result.py +49 -21
  85. mteb/results/task_result.py +45 -51
  86. mteb/similarity_functions.py +11 -7
  87. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  88. mteb/tasks/classification/est/estonian_valence.py +1 -1
  89. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  90. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  91. mteb/tasks/retrieval/code/code_rag.py +12 -12
  92. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  93. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  94. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  95. mteb/tasks/retrieval/nob/norquad.py +2 -2
  96. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  97. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  98. mteb/types/_result.py +2 -1
  99. mteb/types/statistics.py +9 -3
  100. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/METADATA +1 -1
  101. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/RECORD +105 -104
  102. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/WHEEL +0 -0
  103. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/entry_points.txt +0 -0
  104. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/licenses/LICENSE +0 -0
  105. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/top_level.txt +0 -0
mteb/models/model_implementations/openclip_models.py
@@ -120,7 +120,7 @@ def openclip_loader(model_name, **kwargs):
 
 
 CLIP_ViT_L_14_DataComp_XL_s13B_b90K = ModelMeta(
-    loader=openclip_loader,  # type: ignore
+    loader=openclip_loader,
     name="laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -146,7 +146,7 @@ CLIP_ViT_L_14_DataComp_XL_s13B_b90K = ModelMeta(
 )
 
 CLIP_ViT_B_32_DataComp_XL_s13B_b90K = ModelMeta(
-    loader=openclip_loader,  # type: ignore
+    loader=openclip_loader,
     name="laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -172,7 +172,7 @@ CLIP_ViT_B_32_DataComp_XL_s13B_b90K = ModelMeta(
 )
 
 CLIP_ViT_B_16_DataComp_XL_s13B_b90K = ModelMeta(
-    loader=openclip_loader,  # type: ignore
+    loader=openclip_loader,
     name="laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -198,7 +198,7 @@ CLIP_ViT_B_16_DataComp_XL_s13B_b90K = ModelMeta(
 )
 
 CLIP_ViT_bigG_14_laion2B_39B_b160k = ModelMeta(
-    loader=openclip_loader,  # type: ignore
+    loader=openclip_loader,
     name="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -224,7 +224,7 @@ CLIP_ViT_bigG_14_laion2B_39B_b160k = ModelMeta(
 )
 
 CLIP_ViT_g_14_laion2B_s34B_b88K = ModelMeta(
-    loader=openclip_loader,  # type: ignore
+    loader=openclip_loader,
     name="laion/CLIP-ViT-g-14-laion2B-s34B-b88K",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -250,7 +250,7 @@ CLIP_ViT_g_14_laion2B_s34B_b88K = ModelMeta(
 )
 
 CLIP_ViT_H_14_laion2B_s32B_b79K = ModelMeta(
-    loader=openclip_loader,  # type: ignore
+    loader=openclip_loader,
     name="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -276,7 +276,7 @@ CLIP_ViT_H_14_laion2B_s32B_b79K = ModelMeta(
 )
 
 CLIP_ViT_L_14_laion2B_s32B_b82K = ModelMeta(
-    loader=openclip_loader,  # type: ignore
+    loader=openclip_loader,
     name="laion/CLIP-ViT-L-14-laion2B-s32B-b82K",
     model_type=["dense"],
     languages=["eng-Latn"],
mteb/models/model_implementations/random_baseline.py
@@ -68,7 +68,7 @@ _common_mock_metadata = dict(
     license="mit",
     max_tokens=np.inf,
     reference=None,
-    similarity_fn_name="cosine",  # type: ignore
+    similarity_fn_name="cosine",
     framework=[],
     use_instructions=False,
     public_training_code=None,  # No training code, as this is a random baseline
@@ -187,7 +187,7 @@ class RandomEncoderBaseline:
 
 
 random_encoder_baseline = ModelMeta(
-    loader=RandomEncoderBaseline,  # type: ignore
+    loader=RandomEncoderBaseline,
     name="baseline/random-encoder-baseline",
     model_type=["dense"],
     modalities=["text", "image"],
@@ -232,7 +232,7 @@ class RandomCrossEncoderBaseline:
 
 
 random_cross_encoder_baseline = ModelMeta(
-    loader=RandomCrossEncoderBaseline,  # type: ignore
+    loader=RandomCrossEncoderBaseline,
     name="baseline/random-cross-encoder-baseline",
     model_type=["cross-encoder"],
     modalities=["text", "image"],
mteb/models/model_implementations/rasgaard_models.py
@@ -4,7 +4,7 @@ from mteb.models.model_implementations.model2vec_models import Model2VecModel
 from mteb.models.model_meta import ModelMeta, ScoringFunction
 
 potion_base_8m = ModelMeta(
-    loader=Model2VecModel,  # type: ignore
+    loader=Model2VecModel,
     name="rasgaard/m2v-dfm-large",
     model_type=["dense"],
     languages=["dan-Latn"],
mteb/models/model_implementations/repllama_models.py
@@ -154,7 +154,7 @@ REPLLAMA_CITATION = """
 """
 
 repllama_llama2_original = ModelMeta(
-    loader=RepLLaMAModel,  # type: ignore
+    loader=RepLLaMAModel,
     loader_kwargs=dict(
         base_model_name_or_path="meta-llama/Llama-2-7b-hf",
         device_map="auto",
@@ -187,7 +187,7 @@ repllama_llama2_original = ModelMeta(
 
 
 repllama_llama2_reproduced = ModelMeta(
-    loader=RepLLaMAModel,  # type: ignore
+    loader=RepLLaMAModel,
     loader_kwargs=dict(
         base_model_name_or_path="meta-llama/Llama-2-7b-hf",
         device_map="auto",
mteb/models/model_implementations/rerankers_custom.py
@@ -214,7 +214,7 @@ class JinaReranker(RerankerWrapper):
 
 
 monobert_large = ModelMeta(
-    loader=MonoBERTReranker,  # type: ignore
+    loader=MonoBERTReranker,
     loader_kwargs=dict(
         fp_options="float16",
     ),
@@ -239,7 +239,7 @@ monobert_large = ModelMeta(
 
 # languages unclear: https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual/discussions/28
 jina_reranker_multilingual = ModelMeta(
-    loader=JinaReranker,  # type: ignore
+    loader=JinaReranker,
     loader_kwargs=dict(
         fp_options="float16",
     ),
@@ -263,7 +263,7 @@ jina_reranker_multilingual = ModelMeta(
 )
 
 bge_reranker_v2_m3 = ModelMeta(
-    loader=BGEReranker,  # type: ignore
+    loader=BGEReranker,
     loader_kwargs=dict(
         fp_options="float16",
     ),
mteb/models/model_implementations/rerankers_monot5_based.py
@@ -343,7 +343,7 @@ monot5_small = ModelMeta(
 )
 
 monot5_base = ModelMeta(
-    loader=MonoT5Reranker,  # type: ignore
+    loader=MonoT5Reranker,
     loader_kwargs=dict(
         fp_options="float16",
     ),
@@ -442,7 +442,7 @@ monot5_3b = ModelMeta(
 )
 
 flant5_base = ModelMeta(
-    loader=FLANT5Reranker,  # type: ignore
+    loader=FLANT5Reranker,
     loader_kwargs=dict(
         fp_options="float16",
     ),
@@ -902,7 +902,7 @@ mt5_base_mmarco_v2 = ModelMeta(
 )
 
 mt5_13b_mmarco_100k = ModelMeta(
-    loader=MonoT5Reranker,  # type: ignore
+    loader=MonoT5Reranker,
     loader_kwargs=dict(
         fp_options="float16",
     ),
mteb/models/model_implementations/seed_1_6_embedding_models_1215.py
@@ -4,13 +4,15 @@ import base64
 import logging
 import os
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from io import BytesIO
 from typing import TYPE_CHECKING, Any
 
 import requests
 import torch
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 
 from mteb._requires_package import requires_package
 from mteb.abstasks.task_metadata import TaskMetadata
@@ -26,114 +28,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
-def pil_to_base64(image, format="jpeg"):
-    if image is None:
-        return None
-    buffer = BytesIO()
-    image.save(buffer, format=format)
-    img_bytes = buffer.getvalue()
-    encoded_bytes = base64.b64encode(img_bytes)
-    return encoded_bytes.decode("utf-8")
-
-
-def multimodal_embedding(image_base64=None, text_content=None):
-    auth_token = os.getenv("VOLCES_AUTH_TOKEN")
-    model_name = "doubao-embedding-vision-251215"
-    api_url = "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
-
-    headers = {
-        "Authorization": f"Bearer {auth_token}",
-        "x-ark-vlm1": "true",
-        "Content-Type": "application/json",
-    }
-
-    if image_base64 is not None and text_content is None:
-        inputs = []
-        for image in image_base64:
-            image_format = "jpeg"
-            image_data = f"data:image/{image_format};base64,{image}"
-            inputs.append({"type": "image_url", "image_url": {"url": image_data}})
-
-        payload = {"model": model_name, "input": inputs}
-    elif image_base64 is None and text_content is not None:
-        payload = {
-            "model": model_name,
-            "input": [
-                {"type": "text", "text": text_content},
-            ],
-        }
-    else:
-        inputs = []
-        for image in image_base64:
-            image_format = "jpeg"
-            image_data = f"data:image/{image_format};base64,{image}"
-            inputs.append({"type": "image_url", "image_url": {"url": image_data}})
-        inputs.append({"type": "text", "text": text_content})
-        payload = {"model": model_name, "input": inputs}
-
-    try:
-        response = requests.post(url=api_url, headers=headers, json=payload, timeout=10)
-
-        response.raise_for_status()
-        return response.json()
-
-    except requests.exceptions.HTTPError as http_err:
-        logger.error(f"HTTP error ({http_err.response.status_code}): {http_err}")
-    except requests.exceptions.JSONDecodeError:
-        logger.error("Error:The response is not in valid JSON format")
-    except requests.exceptions.Timeout:
-        logger.error("Error:Request timeout")
-    except Exception as e:
-        logger.error(f"Unknown error: {str(e)}")
-
-    return None
-
-
-def multi_thread_encode(sentences, batch_size=1, max_workers=8):
-    batches = []
-    for idx in range(0, len(sentences), batch_size):
-        batches.append((idx // batch_size, sentences[idx : idx + batch_size]))
-
-    n_batches = len(batches)
-    results = [None] * n_batches  # Pre-allocated result list
-    all_embeddings = []  # Final ordered embeddings
-
-    def _process_batch(batch_idx, batch_sentences):
-        sentence = batch_sentences[0]
-
-        retries = 5
-        while retries > 0:
-            try:
-                resp = multimodal_embedding(text_content=sentence)
-                embedding = torch.tensor(resp["data"]["embedding"])
-                break
-            except Exception as e:
-                time.sleep(1)
-                logger.warning(f"Retrying... {retries} retries left. Error: {str(e)}")
-                retries -= 1
-                if retries == 0:
-                    raise e
-        return batch_idx, embedding
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = {
-            executor.submit(_process_batch, idx, batch): idx for idx, batch in batches
-        }
-
-        for future in as_completed(futures):
-            batch_idx, embeddings = future.result()
-            results[batch_idx] = embeddings
-
-    for batch_embeddings in results:
-        all_embeddings.append(batch_embeddings)
-
-    all_embeddings = torch.stack(all_embeddings, dim=0)
-    all_embeddings = torch.nn.functional.normalize(all_embeddings, dim=-1)
-
-    return all_embeddings.float().cpu()
-
-
 doubao_embedding_training_data = (
     {
         "PawsXPairClassification",
@@ -166,25 +60,80 @@ class Seed16EmbeddingWrapper(AbsEncoder):
             "pip install mteb[ark]",
             "tiktoken",
         )
-        import tiktoken
 
         self._model_name = model_name
         self._max_tokens = 32768
         self._embed_dim = embed_dim
         self._available_embed_dims = [2048, 1024]
-        self._encoding = tiktoken.get_encoding(tokenizer_name)
 
-    def truncate_text_tokens(self, text: str) -> str:
-        """Truncate a string to have `max_tokens` according to the given encoding.
+    def pil_to_base64(self, image, format="jpeg"):
+        if image is None:
+            return None
+        buffer = BytesIO()
+        image.save(buffer, format=format)
+        img_bytes = buffer.getvalue()
+        encoded_bytes = base64.b64encode(img_bytes)
+        return encoded_bytes.decode("utf-8")
+
+    def multimodal_embedding(self, instruction, image_base64, text_content):
+        auth_token = os.getenv("VOLCES_AUTH_TOKEN")
+        model_name = "doubao-embedding-vision-251215"
+        api_url = "https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
+
+        headers = {
+            "Authorization": f"Bearer {auth_token}",
+            "x-ark-vlm1": "true",
+            "Content-Type": "application/json",
+        }
 
-        Args:
-            text: The input string to be truncated.
+        if text_content is not None and len(text_content) > self._max_tokens:
+            text_content = text_content[: self._max_tokens]
+
+        if image_base64 is not None and text_content is None:
+            inputs = []
+            for image in image_base64:
+                image_format = "jpeg"
+                image_data = f"data:image/{image_format};base64,{image}"
+                inputs.append({"type": "image_url", "image_url": {"url": image_data}})
+
+            payload = {"model": model_name, "input": inputs}
+        elif image_base64 is None and text_content is not None:
+            payload = {
+                "model": model_name,
+                "instruction": instruction,
+                "input": [
+                    {"type": "text", "text": text_content},
+                ],
+            }
+        else:
+            inputs = []
+            for image in image_base64:
+                image_format = "jpeg"
+                image_data = f"data:image/{image_format};base64,{image}"
+                inputs.append({"type": "image_url", "image_url": {"url": image_data}})
+            inputs.append({"type": "text", "text": text_content})
+            payload = {"model": model_name, "input": inputs}
+
+        max_retries = 3
+        retry_count = 0
+
+        while retry_count < max_retries:
+            response = requests.post(
+                url=api_url, headers=headers, json=payload, timeout=30
+            )
 
-        Returns:
-            The truncated string.
-        """
-        truncated_sentence = self._encoding.encode(text)[: self._max_tokens]
-        return self._encoding.decode(truncated_sentence)
+            if response.status_code != 200:
+                retry_count += 1
+                time.sleep(3)
+                continue
+
+            response_json = response.json()
+            return response_json
+
+        raise Exception(
+            f"Request failed with status code {response.status_code}. "
+            f"Response: {response.text}"
+        )
 
     def get_fused_embeddings(
         self,
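Aside on the hunk above: the new `multimodal_embedding` method replaces the old helper's per-exception handling with a bounded retry loop around the POST request. A minimal, self-contained sketch of that retry shape follows; `post_with_retries` and its parameters are illustrative placeholders, not mteb code.

import time

import requests


def post_with_retries(url, payload, headers, max_retries=3, wait_s=3.0):
    """Retry a JSON POST a fixed number of times, then raise (sketch only)."""
    response = None
    for _ in range(max_retries):
        response = requests.post(url=url, headers=headers, json=payload, timeout=30)
        if response.status_code == 200:
            return response.json()
        time.sleep(wait_s)  # brief pause before the next attempt
    raise RuntimeError(
        f"Request failed with status code {response.status_code}: {response.text}"
    )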
@@ -204,59 +153,69 @@ class Seed16EmbeddingWrapper(AbsEncoder):
         if images is not None and texts is not None:
             assert len(texts) == len(images)
             batch_len = len(texts)
-            images_base64 = [pil_to_base64(image) for image in images]
+            images_base64 = [self.pil_to_base64(image) for image in images]
         elif images is None:
             batch_len = len(texts)
             images_base64 = [None for _ in range(batch_len)]
         elif texts is None:
             batch_len = len(images)
-            images_base64 = [pil_to_base64(image) for image in images]
+            images_base64 = [self.pil_to_base64(image) for image in images]
         else:
             raise ValueError("images and texts cannot be None at the same time")
 
-        outputs = []
-        for i in range(batch_len):
+        def process_item(
+            i, prompt_type, task_name, texts, images_base64, multimodal_embedding
+        ):
             if (
                 prompt_type == PromptType("query") or prompt_type is None
             ) and task_name in TASK_NAME_TO_INSTRUCTION:
                 instruction = TASK_NAME_TO_INSTRUCTION[task_name]
                 instruction = instruction.rstrip("{}").rstrip("\n")
-                if texts[i] != "":
-                    input_text = (
-                        "Target_modality:Text.\n Instruction:"
-                        + instruction
-                        + "\n Query:{}"
-                    ).format(texts[i])
-                else:
-                    input_text = (
-                        "Target_modality:Text.\n Instruction:"
-                        + instruction
-                        + "\n Query:"
-                    )
+                instruction = (
+                    "Target_modality:Text.\n Instruction:" + instruction + "\n Query:"
+                )
+                input_text = texts[i]
             else:
                 if texts[i] != "" and images_base64[i] is not None:
-                    instruction = "Instruction: Compress the the text and image into one word.\n Query: {}"
-                    input_text = instruction.format(texts[i])
+                    instruction = "Instruction: Compress the text and image into one word.\n Query:"
+                    input_text = texts[i]
                 elif texts[i] != "":
                     instruction = (
-                        "Instruction: Compress the the text into one word.\n Query: {}"
+                        "Instruction: Compress the text into one word.\n Query:"
                     )
-                    input_text = instruction.format(texts[i])
+                    input_text = texts[i]
                 elif images_base64[i] is not None:
                     instruction = (
-                        "Instruction: Compress the the image into one word.\n Query:"
+                        "Instruction: Compress the image into one word.\n Query:"
                     )
-                    input_text = instruction
+                    input_text = None
                 else:
                     raise ValueError("image and text are both None")
 
             resp = multimodal_embedding(
-                image_base64=[images_base64[i]], text_content=input_text
+                instruction=instruction,
+                image_base64=images_base64[i],
+                text_content=input_text,
             )
             embedding = torch.tensor(resp["data"]["embedding"])
             embedding = torch.reshape(embedding, (1, -1))
+            return embedding
+
+        outputs = []
+        process_partial = partial(
+            process_item,
+            prompt_type=prompt_type,
+            task_name=task_name,
+            texts=texts,
+            images_base64=images_base64,
+            multimodal_embedding=self.multimodal_embedding,
+        )
+        with ThreadPoolExecutor(max_workers=15) as executor:
+            futures = [executor.submit(process_partial, i) for i in range(batch_len)]
+            for future in tqdm(futures, total=batch_len, desc="Encoding"):
+                outputs.append(future.result())
 
-        outputs = torch.stack(outputs, dim=0)
+        outputs = torch.stack(outputs, dim=0).squeeze(1)
 
         if self._embed_dim is not None:
             outputs = outputs[:, : self._embed_dim]
@@ -273,13 +232,21 @@ class Seed16EmbeddingWrapper(AbsEncoder):
         prompt_type: PromptType | None = None,
         **kwargs: Any,
     ) -> Array:
-        sentences = [text for batch in inputs for text in batch["text"]]
-        images = [image for batch in inputs for image in batch["image"]]
+        if "text" in inputs.dataset.features:
+            sentences = [text for batch in inputs for text in batch["text"]]
+        else:
+            sentences = None
+
+        if "image" in inputs.dataset.features:
+            images = [image for batch in inputs for image in batch["image"]]
+        else:
+            images = None
 
         return self.get_fused_embeddings(
             texts=sentences,
             images=images,
             task_name=task_metadata.name,
+            prompt_type=prompt_type,
             **kwargs,
         )
 
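Stepping back from the seed_1_6_embedding_models_1215.py hunks above: the rewritten `get_fused_embeddings` now binds the per-item arguments with `functools.partial`, fans items out over a `ThreadPoolExecutor`, and reads the futures back in submission order so the stacked tensor stays aligned with the inputs. Below is a small self-contained sketch of that pattern with a dummy encoder standing in for the real API call; `fake_encode` and its dimensions are made up for illustration.

from concurrent.futures import ThreadPoolExecutor
from functools import partial

import torch
from tqdm import tqdm


def fake_encode(i, texts):
    """Stand-in for a per-item embedding request; returns a (1, dim) tensor."""
    torch.manual_seed(i)
    return torch.randn(1, 8)


texts = ["first sentence", "second sentence", "third sentence"]
process = partial(fake_encode, texts=texts)

outputs = []
with ThreadPoolExecutor(max_workers=4) as executor:
    # Submit in index order and collect results in that same order,
    # so row i of the output corresponds to texts[i].
    futures = [executor.submit(process, i) for i in range(len(texts))]
    for future in tqdm(futures, total=len(texts), desc="Encoding"):
        outputs.append(future.result())

embeddings = torch.stack(outputs, dim=0).squeeze(1)  # shape: (3, 8)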
mteb/models/model_implementations/siglip_models.py
@@ -123,7 +123,7 @@ siglip_training_datasets = set(
 )
 
 siglip_so400m_patch14_224 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-so400m-patch14-224",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -147,7 +147,7 @@ siglip_so400m_patch14_224 = ModelMeta(
 )
 
 siglip_so400m_patch14_384 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-so400m-patch14-384",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -171,7 +171,7 @@ siglip_so400m_patch14_384 = ModelMeta(
 )
 
 siglip_so400m_patch16_256_i18n = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-so400m-patch16-256-i18n",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -195,7 +195,7 @@ siglip_so400m_patch16_256_i18n = ModelMeta(
 )
 
 siglip_base_patch16_256_multilingual = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-base-patch16-256-multilingual",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -219,7 +219,7 @@ siglip_base_patch16_256_multilingual = ModelMeta(
 )
 
 siglip_base_patch16_256 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-base-patch16-256",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -243,7 +243,7 @@ siglip_base_patch16_256 = ModelMeta(
 )
 
 siglip_base_patch16_512 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-base-patch16-512",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -267,7 +267,7 @@ siglip_base_patch16_512 = ModelMeta(
 )
 
 siglip_base_patch16_384 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-base-patch16-384",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -291,7 +291,7 @@ siglip_base_patch16_384 = ModelMeta(
 )
 
 siglip_base_patch16_224 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-base-patch16-224",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -315,7 +315,7 @@ siglip_base_patch16_224 = ModelMeta(
 )
 
 siglip_large_patch16_256 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-large-patch16-256",
     model_type=["dense"],
     languages=["eng-Latn"],
@@ -339,7 +339,7 @@ siglip_large_patch16_256 = ModelMeta(
 )
 
 siglip_large_patch16_384 = ModelMeta(
-    loader=SiglipModelWrapper,  # type: ignore
+    loader=SiglipModelWrapper,
     name="google/siglip-large-patch16-384",
     model_type=["dense"],
     languages=["eng-Latn"],
mteb/models/model_implementations/vlm2vec_models.py
@@ -41,7 +41,7 @@ class VLM2VecWrapper(AbsEncoder):
             model_name,
             "pip install flash-attn --no-build-isolation",
         ):
-            import flash_attn  # noqa
+            pass
 
         requires_package(self, "peft", model_name, "pip install 'mteb[peft]'")
         from peft import LoraConfig, PeftModel
mteb/models/model_implementations/voyage_v.py
@@ -40,15 +40,15 @@ def _downsample_image(
         logging.info(
             f"Downsampling image from {width}x{height} to {new_width}x{new_height}"
         )
-        return image.resize(new_size, Image.LANCZOS)  # type: ignore
+        return image.resize(new_size, Image.LANCZOS)
     if width > height:
         if width > 10000:
             logging.error("Processing extremely wide images.")
-            return image.resize((10000, height), Image.LANCZOS)  # type: ignore
+            return image.resize((10000, height), Image.LANCZOS)
     else:
         if height > 10000:
             logging.error("Processing extremely high images.")
-            return image.resize((width, 10000), Image.LANCZOS)  # type: ignore
+            return image.resize((width, 10000), Image.LANCZOS)
     return image
 
 
@@ -202,7 +202,7 @@ def voyage_v_loader(model_name, **kwargs):
 
 
 voyage_v = ModelMeta(
-    loader=voyage_v_loader,  # type: ignore
+    loader=voyage_v_loader,
     name="voyageai/voyage-multimodal-3",
     model_type=["dense"],
     languages=[],  # Unknown