mteb 2.5.3__py3-none-any.whl → 2.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. mteb/_create_dataloaders.py +10 -15
  2. mteb/_evaluators/any_sts_evaluator.py +1 -4
  3. mteb/_evaluators/evaluator.py +2 -1
  4. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
  5. mteb/_evaluators/pair_classification_evaluator.py +3 -1
  6. mteb/_evaluators/retrieval_metrics.py +17 -16
  7. mteb/_evaluators/sklearn_evaluator.py +9 -8
  8. mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
  9. mteb/_evaluators/text/summarization_evaluator.py +20 -16
  10. mteb/abstasks/_data_filter/filters.py +1 -1
  11. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  12. mteb/abstasks/_statistics_calculation.py +18 -10
  13. mteb/abstasks/_stratification.py +18 -18
  14. mteb/abstasks/abstask.py +27 -21
  15. mteb/abstasks/aggregate_task_metadata.py +1 -9
  16. mteb/abstasks/aggregated_task.py +3 -16
  17. mteb/abstasks/classification.py +10 -4
  18. mteb/abstasks/clustering.py +18 -14
  19. mteb/abstasks/clustering_legacy.py +8 -8
  20. mteb/abstasks/image/image_text_pair_classification.py +5 -3
  21. mteb/abstasks/multilabel_classification.py +20 -16
  22. mteb/abstasks/pair_classification.py +18 -9
  23. mteb/abstasks/regression.py +3 -3
  24. mteb/abstasks/retrieval.py +12 -9
  25. mteb/abstasks/sts.py +6 -3
  26. mteb/abstasks/task_metadata.py +20 -16
  27. mteb/abstasks/text/bitext_mining.py +36 -25
  28. mteb/abstasks/text/reranking.py +7 -5
  29. mteb/abstasks/text/summarization.py +8 -3
  30. mteb/abstasks/zeroshot_classification.py +5 -2
  31. mteb/benchmarks/benchmark.py +4 -2
  32. mteb/benchmarks/benchmarks/benchmarks.py +22 -1
  33. mteb/benchmarks/get_benchmark.py +14 -55
  34. mteb/cache.py +21 -18
  35. mteb/cli/_display_tasks.py +2 -2
  36. mteb/cli/build_cli.py +8 -8
  37. mteb/cli/generate_model_card.py +39 -20
  38. mteb/deprecated_evaluator.py +56 -43
  39. mteb/evaluate.py +35 -29
  40. mteb/filter_tasks.py +25 -26
  41. mteb/get_tasks.py +25 -27
  42. mteb/languages/language_scripts.py +5 -3
  43. mteb/leaderboard/app.py +1 -1
  44. mteb/load_results.py +12 -12
  45. mteb/models/abs_encoder.py +2 -2
  46. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  47. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  48. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +2 -1
  49. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +30 -13
  50. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  51. mteb/models/get_model_meta.py +8 -1
  52. mteb/models/instruct_wrapper.py +11 -5
  53. mteb/models/model_implementations/andersborges.py +2 -2
  54. mteb/models/model_implementations/blip_models.py +8 -8
  55. mteb/models/model_implementations/bm25.py +1 -1
  56. mteb/models/model_implementations/clip_models.py +3 -3
  57. mteb/models/model_implementations/cohere_models.py +1 -1
  58. mteb/models/model_implementations/cohere_v.py +2 -2
  59. mteb/models/model_implementations/dino_models.py +23 -23
  60. mteb/models/model_implementations/emillykkejensen_models.py +3 -3
  61. mteb/models/model_implementations/jina_clip.py +1 -1
  62. mteb/models/model_implementations/jina_models.py +1 -1
  63. mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
  64. mteb/models/model_implementations/llm2clip_models.py +3 -3
  65. mteb/models/model_implementations/moco_models.py +2 -2
  66. mteb/models/model_implementations/model2vec_models.py +1 -1
  67. mteb/models/model_implementations/nomic_models.py +8 -8
  68. mteb/models/model_implementations/openclip_models.py +7 -7
  69. mteb/models/model_implementations/random_baseline.py +3 -3
  70. mteb/models/model_implementations/rasgaard_models.py +1 -1
  71. mteb/models/model_implementations/repllama_models.py +2 -2
  72. mteb/models/model_implementations/rerankers_custom.py +3 -3
  73. mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
  74. mteb/models/model_implementations/seed_1_6_embedding_models_1215.py +113 -146
  75. mteb/models/model_implementations/siglip_models.py +10 -10
  76. mteb/models/model_implementations/vlm2vec_models.py +1 -1
  77. mteb/models/model_implementations/voyage_v.py +4 -4
  78. mteb/models/model_meta.py +30 -14
  79. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +5 -5
  80. mteb/models/search_wrappers.py +22 -10
  81. mteb/models/sentence_transformer_wrapper.py +9 -4
  82. mteb/py.typed +0 -0
  83. mteb/results/benchmark_results.py +25 -19
  84. mteb/results/model_result.py +49 -21
  85. mteb/results/task_result.py +45 -51
  86. mteb/similarity_functions.py +11 -7
  87. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  88. mteb/tasks/classification/est/estonian_valence.py +1 -1
  89. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  90. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  91. mteb/tasks/retrieval/code/code_rag.py +12 -12
  92. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  93. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  94. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  95. mteb/tasks/retrieval/nob/norquad.py +2 -2
  96. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  97. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  98. mteb/types/_result.py +2 -1
  99. mteb/types/statistics.py +9 -3
  100. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/METADATA +1 -1
  101. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/RECORD +105 -104
  102. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/WHEEL +0 -0
  103. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/entry_points.txt +0 -0
  104. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/licenses/LICENSE +0 -0
  105. {mteb-2.5.3.dist-info → mteb-2.5.5.dist-info}/top_level.txt +0 -0
mteb/evaluate.py CHANGED
@@ -14,11 +14,10 @@ from mteb._helpful_enum import HelpfulStrEnum
14
14
  from mteb.abstasks import AbsTaskRetrieval
15
15
  from mteb.abstasks.abstask import AbsTask
16
16
  from mteb.abstasks.aggregated_task import AbsTaskAggregate
17
+ from mteb.benchmarks.benchmark import Benchmark
17
18
  from mteb.cache import ResultCache
18
19
  from mteb.models.model_meta import ModelMeta
19
20
  from mteb.models.models_protocols import (
20
- CrossEncoderProtocol,
21
- EncoderProtocol,
22
21
  MTEBModels,
23
22
  )
24
23
  from mteb.models.sentence_transformer_wrapper import (
@@ -58,27 +57,26 @@ def _sanitize_model(
58
57
  ) -> tuple[MTEBModels | ModelMeta, ModelMeta, ModelName, Revision]:
59
58
  from sentence_transformers import CrossEncoder, SentenceTransformer
60
59
 
60
+ wrapped_model: MTEBModels | ModelMeta
61
61
  if isinstance(model, SentenceTransformer):
62
- _mdl = SentenceTransformerEncoderWrapper(model)
63
- meta = _mdl.mteb_model_meta
64
- _mdl = cast(EncoderProtocol, _mdl)
65
- model = _mdl
62
+ wrapped_model = SentenceTransformerEncoderWrapper(model)
63
+ meta = wrapped_model.mteb_model_meta
66
64
  elif isinstance(model, CrossEncoder):
67
- _mdl = CrossEncoderWrapper(model)
68
- _mdl = cast(CrossEncoderProtocol, _mdl)
69
- meta = _mdl.mteb_model_meta
70
- model = _mdl
65
+ wrapped_model = CrossEncoderWrapper(model)
66
+ meta = wrapped_model.mteb_model_meta
71
67
  elif hasattr(model, "mteb_model_meta"):
72
- meta = model.mteb_model_meta # type: ignore[attr-defined]
68
+ meta = getattr(model, "mteb_model_meta")
73
69
  if not isinstance(meta, ModelMeta):
74
- meta = ModelMeta.from_hub(None)
70
+ meta = ModelMeta._from_hub(None)
71
+ wrapped_model = cast(MTEBModels | ModelMeta, model)
75
72
  else:
76
- meta = ModelMeta.from_hub(None) if not isinstance(model, ModelMeta) else model
73
+ meta = ModelMeta._from_hub(None) if not isinstance(model, ModelMeta) else model
74
+ wrapped_model = meta
77
75
 
78
76
  model_name = cast(str, meta.name)
79
77
  model_revision = cast(str, meta.revision)
80
78
 
81
- return model, meta, model_name, model_revision
79
+ return wrapped_model, meta, model_name, model_revision
82
80
 
83
81
 
84
82
  def _evaluate_task(
@@ -124,7 +122,8 @@ def _evaluate_task(
124
122
  prediction_folder=prediction_folder,
125
123
  public_only=public_only,
126
124
  )
127
- result.kg_co2_emissions = tracker.final_emissions
125
+ if isinstance(result, TaskResult):
126
+ result.kg_co2_emissions = tracker.final_emissions
128
127
  return result
129
128
 
130
129
  task_results = {}
@@ -150,7 +149,7 @@ def _evaluate_task(
150
149
  if public_only is False:
151
150
  raise e
152
151
 
153
- evaluation_time = 0
152
+ evaluation_time = 0.0
154
153
 
155
154
  for split, hf_subsets in splits.items():
156
155
  tick = time()
@@ -197,12 +196,18 @@ def _check_model_modalities(
197
196
  return
198
197
 
199
198
  model_modalities = set(model.modalities)
199
+ check_tasks: Iterable[AbsTask] = []
200
200
  if isinstance(tasks, AbsTask):
201
- tasks = [tasks]
201
+ check_tasks = [tasks]
202
+ elif isinstance(tasks, Benchmark):
203
+ benchmark = cast(Benchmark, tasks)
204
+ check_tasks = benchmark.tasks
205
+ else:
206
+ check_tasks = cast(Iterable[AbsTask], tasks)
202
207
 
203
208
  warnings, errors = [], []
204
209
 
205
- for task in tasks:
210
+ for task in check_tasks:
206
211
  # only retrieval tasks have different modalities for query and document and can be run with partial overlaps
207
212
  if isinstance(task, AbsTaskRetrieval):
208
213
  query_mods = set(task.metadata.get_modalities(PromptType.query))
@@ -335,10 +340,10 @@ def evaluate(
335
340
 
336
341
  # AbsTaskAggregate is a special case where we have to run multiple tasks and combine the results
337
342
  if isinstance(tasks, AbsTaskAggregate):
338
- task = cast(AbsTaskAggregate, tasks)
343
+ aggregated_task = cast(AbsTaskAggregate, tasks)
339
344
  results = evaluate(
340
345
  model,
341
- task.metadata.tasks,
346
+ aggregated_task.metadata.tasks,
342
347
  co2_tracker=co2_tracker,
343
348
  raise_error=raise_error,
344
349
  encode_kwargs=encode_kwargs,
@@ -348,17 +353,18 @@ def evaluate(
348
353
  show_progress_bar=show_progress_bar,
349
354
  public_only=public_only,
350
355
  )
351
- result = task.combine_task_results(results.task_results)
356
+ combined_results = aggregated_task.combine_task_results(results.task_results)
352
357
  return ModelResult(
353
358
  model_name=results.model_name,
354
359
  model_revision=results.model_revision,
355
- task_results=[result],
360
+ task_results=[combined_results],
356
361
  )
357
362
 
358
363
  if isinstance(tasks, AbsTask):
359
364
  task = tasks
360
365
  else:
361
- results = []
366
+ tasks = cast(Iterable[AbsTask], tasks)
367
+ evaluate_results = []
362
368
  exceptions = []
363
369
  tasks_tqdm = tqdm(
364
370
  tasks,
@@ -379,23 +385,23 @@ def evaluate(
379
385
  show_progress_bar=False,
380
386
  public_only=public_only,
381
387
  )
382
- results.extend(_res.task_results)
388
+ evaluate_results.extend(_res.task_results)
383
389
  if _res.exceptions:
384
390
  exceptions.extend(_res.exceptions)
385
391
  return ModelResult(
386
392
  model_name=_res.model_name,
387
393
  model_revision=_res.model_revision,
388
- task_results=results,
394
+ task_results=evaluate_results,
389
395
  exceptions=exceptions,
390
396
  )
391
397
 
392
398
  overwrite_strategy = OverwriteStrategy.from_str(overwrite_strategy)
393
399
 
394
- existing_results = None
400
+ existing_results: TaskResult | None = None
395
401
  if cache and overwrite_strategy != OverwriteStrategy.ALWAYS:
396
- results = cache.load_task_result(task.metadata.name, meta)
397
- if results:
398
- existing_results = results
402
+ cache_results = cache.load_task_result(task.metadata.name, meta)
403
+ if cache_results:
404
+ existing_results = cache_results
399
405
 
400
406
  if (
401
407
  existing_results
mteb/filter_tasks.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """This script contains functions that are used to get an overview of the MTEB benchmark."""
2
2
 
3
3
  import logging
4
- from collections.abc import Sequence
4
+ from collections.abc import Iterable, Sequence
5
5
  from typing import overload
6
6
 
7
7
  from mteb.abstasks import (
@@ -34,14 +34,14 @@ def _check_is_valid_language(lang: str) -> None:
34
34
 
35
35
  @overload
36
36
  def filter_tasks(
37
- tasks: Sequence[AbsTask],
37
+ tasks: Iterable[AbsTask],
38
38
  *,
39
- languages: list[str] | None = None,
40
- script: list[str] | None = None,
41
- domains: list[TaskDomain] | None = None,
42
- task_types: list[TaskType] | None = None, # type: ignore
43
- categories: list[TaskCategory] | None = None,
44
- modalities: list[Modalities] | None = None,
39
+ languages: Sequence[str] | None = None,
40
+ script: Sequence[str] | None = None,
41
+ domains: Iterable[TaskDomain] | None = None,
42
+ task_types: Iterable[TaskType] | None = None,
43
+ categories: Iterable[TaskCategory] | None = None,
44
+ modalities: Iterable[Modalities] | None = None,
45
45
  exclusive_modality_filter: bool = False,
46
46
  exclude_superseded: bool = False,
47
47
  exclude_aggregate: bool = False,
@@ -51,14 +51,14 @@ def filter_tasks(
51
51
 
52
52
  @overload
53
53
  def filter_tasks(
54
- tasks: Sequence[type[AbsTask]],
54
+ tasks: Iterable[type[AbsTask]],
55
55
  *,
56
- languages: list[str] | None = None,
57
- script: list[str] | None = None,
58
- domains: list[TaskDomain] | None = None,
59
- task_types: list[TaskType] | None = None, # type: ignore
60
- categories: list[TaskCategory] | None = None,
61
- modalities: list[Modalities] | None = None,
56
+ languages: Sequence[str] | None = None,
57
+ script: Sequence[str] | None = None,
58
+ domains: Iterable[TaskDomain] | None = None,
59
+ task_types: Iterable[TaskType] | None = None,
60
+ categories: Iterable[TaskCategory] | None = None,
61
+ modalities: Iterable[Modalities] | None = None,
62
62
  exclusive_modality_filter: bool = False,
63
63
  exclude_superseded: bool = False,
64
64
  exclude_aggregate: bool = False,
@@ -67,14 +67,14 @@ def filter_tasks(
67
67
 
68
68
 
69
69
  def filter_tasks(
70
- tasks: Sequence[AbsTask] | Sequence[type[AbsTask]],
70
+ tasks: Iterable[AbsTask] | Iterable[type[AbsTask]],
71
71
  *,
72
- languages: list[str] | None = None,
73
- script: list[str] | None = None,
74
- domains: list[TaskDomain] | None = None,
75
- task_types: list[TaskType] | None = None, # type: ignore
76
- categories: list[TaskCategory] | None = None,
77
- modalities: list[Modalities] | None = None,
72
+ languages: Sequence[str] | None = None,
73
+ script: Sequence[str] | None = None,
74
+ domains: Iterable[TaskDomain] | None = None,
75
+ task_types: Iterable[TaskType] | None = None,
76
+ categories: Iterable[TaskCategory] | None = None,
77
+ modalities: Iterable[Modalities] | None = None,
78
78
  exclusive_modality_filter: bool = False,
79
79
  exclude_superseded: bool = False,
80
80
  exclude_aggregate: bool = False,
@@ -92,7 +92,6 @@ def filter_tasks(
92
92
  task_types: A string specifying the type of task e.g. "Classification" or "Retrieval". If None, all tasks are included.
93
93
  categories: A list of task categories these include "t2t" (text to text), "t2i" (text to image). See TaskMetadata for the full list.
94
94
  exclude_superseded: A boolean flag to exclude datasets which are superseded by another.
95
- eval_splits: A list of evaluation splits to include. If None, all splits are included.
96
95
  modalities: A list of modalities to include. If None, all modalities are included.
97
96
  exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
98
97
  task's modalities and ALL task modalities are in filter modalities (exact match).
@@ -113,12 +112,12 @@ def filter_tasks(
113
112
  """
114
113
  langs_to_keep = None
115
114
  if languages:
116
- [_check_is_valid_language(lang) for lang in languages]
115
+ [_check_is_valid_language(lang) for lang in languages] # type: ignore[func-returns-value]
117
116
  langs_to_keep = set(languages)
118
117
 
119
118
  script_to_keep = None
120
119
  if script:
121
- [_check_is_valid_script(s) for s in script]
120
+ [_check_is_valid_script(s) for s in script] # type: ignore[func-returns-value]
122
121
  script_to_keep = set(script)
123
122
 
124
123
  domains_to_keep = None
@@ -178,4 +177,4 @@ def filter_tasks(
178
177
 
179
178
  _tasks.append(t)
180
179
 
181
- return _tasks
180
+ return _tasks # type: ignore[return-value] # type checker cannot infer the overload return type
mteb/get_tasks.py CHANGED
@@ -4,7 +4,7 @@ import difflib
4
4
  import logging
5
5
  import warnings
6
6
  from collections import Counter, defaultdict
7
- from collections.abc import Sequence
7
+ from collections.abc import Iterable, Sequence
8
8
  from typing import Any
9
9
 
10
10
  import pandas as pd
@@ -23,12 +23,11 @@ logger = logging.getLogger(__name__)
23
23
  def _gather_tasks() -> tuple[type[AbsTask], ...]:
24
24
  import mteb.tasks as tasks
25
25
 
26
- tasks = [
26
+ return tuple(
27
27
  t
28
28
  for t in tasks.__dict__.values()
29
29
  if isinstance(t, type) and issubclass(t, AbsTask)
30
- ]
31
- return tuple(tasks)
30
+ )
32
31
 
33
32
 
34
33
  def _create_name_to_task_mapping(
@@ -44,7 +43,7 @@ def _create_name_to_task_mapping(
44
43
  return metadata_names
45
44
 
46
45
 
47
- def _create_similar_tasks(tasks: Sequence[type[AbsTask]]) -> dict[str, list[str]]:
46
+ def _create_similar_tasks(tasks: Iterable[type[AbsTask]]) -> dict[str, list[str]]:
48
47
  """Create a dictionary of similar tasks.
49
48
 
50
49
  Returns:
@@ -195,9 +194,8 @@ class MTEBTasks(tuple[AbsTask]):
195
194
  string with a LaTeX table.
196
195
  """
197
196
  if include_citation_in_name and "name" in properties:
198
- properties += ["intext_citation"]
199
- df = self.to_dataframe(properties)
200
- df["name"] = df["name"] + " " + df["intext_citation"]
197
+ df = self.to_dataframe(tuple(properties) + ("intext_citation",))
198
+ df["name"] = df["name"] + " " + df["intext_citation"] # type: ignore[operator]
201
199
  df = df.drop(columns=["intext_citation"])
202
200
  else:
203
201
  df = self.to_dataframe(properties)
@@ -222,17 +220,17 @@ class MTEBTasks(tuple[AbsTask]):
222
220
 
223
221
 
224
222
  def get_tasks(
225
- tasks: list[str] | None = None,
223
+ tasks: Sequence[str] | None = None,
226
224
  *,
227
- languages: list[str] | None = None,
228
- script: list[str] | None = None,
229
- domains: list[TaskDomain] | None = None,
230
- task_types: list[TaskType] | None = None, # type: ignore
231
- categories: list[TaskCategory] | None = None,
225
+ languages: Sequence[str] | None = None,
226
+ script: Sequence[str] | None = None,
227
+ domains: Sequence[TaskDomain] | None = None,
228
+ task_types: Sequence[TaskType] | None = None,
229
+ categories: Sequence[TaskCategory] | None = None,
232
230
  exclude_superseded: bool = True,
233
- eval_splits: list[str] | None = None,
231
+ eval_splits: Sequence[str] | None = None,
234
232
  exclusive_language_filter: bool = False,
235
- modalities: list[Modalities] | None = None,
233
+ modalities: Sequence[Modalities] | None = None,
236
234
  exclusive_modality_filter: bool = False,
237
235
  exclude_aggregate: bool = False,
238
236
  exclude_private: bool = True,
@@ -288,7 +286,7 @@ def get_tasks(
288
286
  ]
289
287
  return MTEBTasks(_tasks)
290
288
 
291
- _tasks = filter_tasks(
289
+ tasks_: Sequence[type[AbsTask]] = filter_tasks(
292
290
  TASK_LIST,
293
291
  languages=languages,
294
292
  script=script,
@@ -301,12 +299,12 @@ def get_tasks(
301
299
  exclude_aggregate=exclude_aggregate,
302
300
  exclude_private=exclude_private,
303
301
  )
304
- _tasks = [
305
- cls().filter_languages(languages, script).filter_eval_splits(eval_splits)
306
- for cls in _tasks
307
- ]
308
-
309
- return MTEBTasks(_tasks)
302
+ return MTEBTasks(
303
+ [
304
+ cls().filter_languages(languages, script).filter_eval_splits(eval_splits)
305
+ for cls in tasks_
306
+ ]
307
+ )
310
308
 
311
309
 
312
310
  _TASK_RENAMES = {"PersianTextTone": "SynPerTextToneClassification"}
@@ -314,10 +312,10 @@ _TASK_RENAMES = {"PersianTextTone": "SynPerTextToneClassification"}
314
312
 
315
313
  def get_task(
316
314
  task_name: str,
317
- languages: list[str] | None = None,
318
- script: list[str] | None = None,
319
- eval_splits: list[str] | None = None,
320
- hf_subsets: list[str] | None = None,
315
+ languages: Sequence[str] | None = None,
316
+ script: Sequence[str] | None = None,
317
+ eval_splits: Sequence[str] | None = None,
318
+ hf_subsets: Sequence[str] | None = None,
321
319
  exclusive_language_filter: bool = False,
322
320
  ) -> AbsTask:
323
321
  """Get a task by name.
@@ -1,9 +1,9 @@
1
- from collections.abc import Iterable
1
+ from collections.abc import Iterable, Sequence
2
2
  from dataclasses import dataclass
3
3
 
4
4
  from typing_extensions import Self
5
5
 
6
- from mteb.languages import check_language_code
6
+ from mteb.languages.check_language_code import check_language_code
7
7
 
8
8
 
9
9
  @dataclass
@@ -25,7 +25,9 @@ class LanguageScripts:
25
25
 
26
26
  @classmethod
27
27
  def from_languages_and_scripts(
28
- cls, languages: list[str] | None = None, scripts: list[str] | None = None
28
+ cls,
29
+ languages: Sequence[str] | None = None,
30
+ scripts: Sequence[str] | None = None,
29
31
  ) -> Self:
30
32
  """Create a LanguageScripts object from lists of languages and scripts.
31
33
 
mteb/leaderboard/app.py CHANGED
@@ -169,7 +169,7 @@ def _update_task_info(task_names: str) -> gr.DataFrame:
169
169
  df = df.drop(columns="reference")
170
170
  return gr.DataFrame(
171
171
  df,
172
- datatype=["markdown"] + ["str"] * (len(df.columns) - 1), # type: ignore
172
+ datatype=["markdown"] + ["str"] * (len(df.columns) - 1),
173
173
  buttons=["copy", "fullscreen"],
174
174
  show_search="filter",
175
175
  )
mteb/load_results.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import json
2
2
  import logging
3
3
  import sys
4
- from collections.abc import Sequence
4
+ from collections.abc import Iterable, Sequence
5
5
  from pathlib import Path
6
6
 
7
7
  from mteb.abstasks.abstask import AbsTask
@@ -45,8 +45,8 @@ def _model_name_and_revision(
45
45
  def load_results(
46
46
  results_repo: str = "https://github.com/embeddings-benchmark/results",
47
47
  download_latest: bool = True,
48
- models: Sequence[ModelMeta] | Sequence[str] | None = None,
49
- tasks: Sequence[AbsTask] | Sequence[str] | None = None,
48
+ models: Iterable[ModelMeta] | Sequence[str] | None = None,
49
+ tasks: Iterable[AbsTask] | Sequence[str] | None = None,
50
50
  validate_and_filter: bool = True,
51
51
  require_model_meta: bool = True,
52
52
  only_main_score: bool = False,
@@ -83,21 +83,21 @@ def load_results(
83
83
 
84
84
  if models is not None:
85
85
  models_to_keep = {}
86
- for model_path in models:
87
- if isinstance(model_path, ModelMeta):
88
- models_to_keep[model_path.name] = model_path.revision
86
+ for model in models:
87
+ if isinstance(model, ModelMeta):
88
+ models_to_keep[model.name] = model.revision
89
89
  else:
90
- models_to_keep[model_path] = None
90
+ models_to_keep[model] = None
91
91
  else:
92
92
  models_to_keep = None
93
93
 
94
- task_names = {}
94
+ task_names: dict[str, AbsTask | None] = {}
95
95
  if tasks is not None:
96
- for task in tasks:
97
- if isinstance(task, AbsTask):
98
- task_names[task.metadata.name] = task
96
+ for task_ in tasks:
97
+ if isinstance(task_, AbsTask):
98
+ task_names[task_.metadata.name] = task_
99
99
  else:
100
- task_names[task] = None
100
+ task_names[task_] = None
101
101
 
102
102
  model_results = []
103
103
  for model_path in model_paths:
@@ -44,7 +44,7 @@ class AbsEncoder(ABC):
44
44
  model: Any
45
45
  mteb_model_meta: ModelMeta | None = None
46
46
  model_prompts: dict[str, str] | None = None
47
- instruction_template: str | Callable[[str, PromptType], str] | None = None
47
+ instruction_template: str | Callable[[str, PromptType | None], str] | None = None
48
48
  prompts_dict: dict[str, str] | None = None
49
49
 
50
50
  def get_prompt_name(
@@ -111,7 +111,7 @@ class AbsEncoder(ABC):
111
111
  if not self.model_prompts:
112
112
  return None
113
113
  prompt_name = self.get_prompt_name(task_metadata, prompt_type)
114
- return self.model_prompts.get(prompt_name)
114
+ return self.model_prompts.get(prompt_name) if prompt_name else None
115
115
 
116
116
  @staticmethod
117
117
  @overload
@@ -5,8 +5,6 @@ from typing import Any, Protocol, runtime_checkable
5
5
 
6
6
  import numpy as np
7
7
 
8
- from mteb.types import BatchedInput
9
-
10
8
 
11
9
  @runtime_checkable
12
10
  class CacheBackendProtocol(Protocol):
@@ -26,7 +24,7 @@ class CacheBackendProtocol(Protocol):
26
24
  **kwargs: Additional backend-specific arguments.
27
25
  """
28
26
 
29
- def add(self, item: list[BatchedInput], vectors: np.ndarray) -> None:
27
+ def add(self, item: list[dict[str, Any]], vectors: np.ndarray) -> None:
30
28
  """Add a vector to the cache.
31
29
 
32
30
  Args:
@@ -34,7 +32,7 @@ class CacheBackendProtocol(Protocol):
34
32
  vectors: Embedding vector of shape (dim,) or (1, dim).
35
33
  """
36
34
 
37
- def get_vector(self, item: BatchedInput) -> np.ndarray | None:
35
+ def get_vector(self, item: dict[str, Any]) -> np.ndarray | None:
38
36
  """Retrieve the cached vector for the given item.
39
37
 
40
38
  Args:
@@ -53,5 +51,5 @@ class CacheBackendProtocol(Protocol):
53
51
  def close(self) -> None:
54
52
  """Release resources or flush data."""
55
53
 
56
- def __contains__(self, item: BatchedInput) -> bool:
54
+ def __contains__(self, item: dict[str, Any]) -> bool:
57
55
  """Check whether the cache contains an item."""
@@ -1,12 +1,13 @@
1
1
  import hashlib
2
+ from collections.abc import Mapping
3
+ from typing import Any
2
4
 
3
- from mteb.types import BatchedInput
4
5
 
5
-
6
- def _hash_item(item: BatchedInput) -> str:
6
+ def _hash_item(item: Mapping[str, Any]) -> str:
7
7
  item_hash = ""
8
8
  if "text" in item:
9
- item_hash = hashlib.sha256(item["text"].encode()).hexdigest()
9
+ item_text: str = item["text"]
10
+ item_hash = hashlib.sha256(item_text.encode()).hexdigest()
10
11
 
11
12
  if "image" in item:
12
13
  from PIL import Image
@@ -2,6 +2,7 @@ import json
2
2
  import logging
3
3
  import warnings
4
4
  from pathlib import Path
5
+ from typing import Any
5
6
 
6
7
  import numpy as np
7
8
 
@@ -37,7 +38,7 @@ class FaissCache:
37
38
  logger.info(f"Initialized FAISS VectorCacheMap in {self.directory}")
38
39
  self.load()
39
40
 
40
- def add(self, items: list[BatchedInput], vectors: np.ndarray) -> None:
41
+ def add(self, items: list[dict[str, Any]], vectors: np.ndarray) -> None:
41
42
  """Add vector to FAISS index."""
42
43
  import faiss
43
44
 
@@ -2,11 +2,10 @@ import json
2
2
  import logging
3
3
  import warnings
4
4
  from pathlib import Path
5
+ from typing import Any
5
6
 
6
7
  import numpy as np
7
8
 
8
- from mteb.types import BatchedInput
9
-
10
9
  from ._hash_utils import _hash_item
11
10
 
12
11
  logger = logging.getLogger(__name__)
@@ -15,7 +14,7 @@ logger = logging.getLogger(__name__)
15
14
  class NumpyCache:
16
15
  """Generic vector cache for both text and images."""
17
16
 
18
- def __init__(self, directory: str | Path, initial_vectors: int = 100000):
17
+ def __init__(self, directory: str | Path, initial_vectors: int = 100_000):
19
18
  self.directory = Path(directory)
20
19
  self.directory.mkdir(parents=True, exist_ok=True)
21
20
  self.vectors_file = self.directory / "vectors.npy"
@@ -28,7 +27,7 @@ class NumpyCache:
28
27
  logger.info(f"Initialized VectorCacheMap in directory: {self.directory}")
29
28
  self._initialize_vectors_file()
30
29
 
31
- def add(self, item: list[BatchedInput], vectors: np.ndarray) -> None:
30
+ def add(self, items: list[dict[str, Any]], vectors: np.ndarray) -> None:
32
31
  """Add a vector to the cache."""
33
32
  try:
34
33
  if self.vector_dim is None:
@@ -39,7 +38,12 @@ class NumpyCache:
39
38
  self._save_dimension()
40
39
  logger.info(f"Initialized vector dimension to {self.vector_dim}")
41
40
 
42
- for item, vec in zip(item, vectors):
41
+ if self.vectors is None:
42
+ raise RuntimeError(
43
+ "Vectors file not initialized. Call _initialize_vectors_file() first."
44
+ )
45
+
46
+ for item, vec in zip(items, vectors):
43
47
  item_hash = _hash_item(item)
44
48
  if item_hash in self.hash_to_index:
45
49
  msg = f"Hash collision or duplicate item for hash {item_hash}. Overwriting existing vector."
@@ -75,18 +79,26 @@ class NumpyCache:
75
79
  shape=(self.initial_vectors, self.vector_dim),
76
80
  )
77
81
  else:
78
- self.vectors = np.memmap(self.vectors_file, dtype="float32", mode="r+")
79
- self.vectors = self.vectors.reshape(-1, self.vector_dim)
82
+ self.vectors = np.memmap(
83
+ self.vectors_file,
84
+ dtype="float32",
85
+ mode="r+",
86
+ shape=(-1, self.vector_dim),
87
+ )
80
88
  logger.info(f"Vectors file initialized with shape: {self.vectors.shape}")
81
89
 
82
90
  def _double_vectors_file(self) -> None:
91
+ if self.vectors is None or self.vector_dim is None:
92
+ raise RuntimeError(
93
+ "Vectors file not initialized. Call _initialize_vectors_file() first."
94
+ )
83
95
  current_size = len(self.vectors)
84
96
  new_size = current_size * 2
85
97
  logger.info(f"Doubling vectors file from {current_size} to {new_size} vectors")
86
98
  self.vectors.flush()
87
99
  new_vectors = np.memmap(
88
- self.vectors_file,
89
- dtype="float32",
100
+ str(self.vectors_file),
101
+ dtype=np.float32,
90
102
  mode="r+",
91
103
  shape=(new_size, self.vector_dim),
92
104
  )
@@ -147,9 +159,11 @@ class NumpyCache:
147
159
 
148
160
  if self.vector_dim is not None:
149
161
  self.vectors = np.memmap(
150
- self.vectors_file, dtype="float32", mode="r+"
162
+ self.vectors_file,
163
+ dtype="float32",
164
+ mode="r+",
165
+ shape=(-1, self.vector_dim),
151
166
  )
152
- self.vectors = self.vectors.reshape(-1, self.vector_dim)
153
167
  logger.info(f"Loaded vectors file with shape: {self.vectors.shape}")
154
168
  else:
155
169
  msg = "Vector dimension not set. Unable to load vectors file."
@@ -164,8 +178,11 @@ class NumpyCache:
164
178
  logger.error(f"Error loading VectorCacheMap: {str(e)}")
165
179
  raise
166
180
 
167
- def get_vector(self, item: BatchedInput) -> np.ndarray | None:
181
+ def get_vector(self, item: dict[str, Any]) -> np.ndarray | None:
168
182
  """Retrieve vector from index by hash."""
183
+ if self.vectors is None:
184
+ return None
185
+
169
186
  try:
170
187
  item_hash = _hash_item(item)
171
188
  if item_hash not in self.hash_to_index:
@@ -177,7 +194,7 @@ class NumpyCache:
177
194
  logger.error(f"Error retrieving vector for item: {str(e)}")
178
195
  raise
179
196
 
180
- def __contains__(self, item: BatchedInput) -> bool:
197
+ def __contains__(self, item: dict[str, Any]) -> bool:
181
198
  return _hash_item(item) in self.hash_to_index
182
199
 
183
200
  def __del__(self):
@@ -90,9 +90,9 @@ class CachedEmbeddingWrapper:
90
90
  try:
91
91
  cache = self._get_or_create_cache(task_name)
92
92
 
93
- uncached_items: list[BatchedInput] = []
93
+ uncached_items: list[dict[str, Any]] = []
94
94
  uncached_indices: list[int] = []
95
- all_items = inputs.dataset
95
+ all_items: Dataset = inputs.dataset
96
96
  cached_vectors: dict[int, np.ndarray] = {}
97
97
 
98
98
  for i, item in enumerate(all_items):