mteb 2.5.2__py3-none-any.whl → 2.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. mteb/_create_dataloaders.py +10 -15
  2. mteb/_evaluators/any_sts_evaluator.py +1 -4
  3. mteb/_evaluators/evaluator.py +2 -1
  4. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
  5. mteb/_evaluators/pair_classification_evaluator.py +3 -1
  6. mteb/_evaluators/retrieval_metrics.py +17 -16
  7. mteb/_evaluators/sklearn_evaluator.py +9 -8
  8. mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
  9. mteb/_evaluators/text/summarization_evaluator.py +20 -16
  10. mteb/abstasks/_data_filter/filters.py +1 -1
  11. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  12. mteb/abstasks/_statistics_calculation.py +18 -10
  13. mteb/abstasks/_stratification.py +18 -18
  14. mteb/abstasks/abstask.py +33 -27
  15. mteb/abstasks/aggregate_task_metadata.py +1 -9
  16. mteb/abstasks/aggregated_task.py +7 -26
  17. mteb/abstasks/classification.py +10 -4
  18. mteb/abstasks/clustering.py +18 -14
  19. mteb/abstasks/clustering_legacy.py +8 -8
  20. mteb/abstasks/image/image_text_pair_classification.py +5 -3
  21. mteb/abstasks/multilabel_classification.py +20 -16
  22. mteb/abstasks/pair_classification.py +18 -9
  23. mteb/abstasks/regression.py +3 -3
  24. mteb/abstasks/retrieval.py +12 -9
  25. mteb/abstasks/sts.py +6 -3
  26. mteb/abstasks/task_metadata.py +22 -19
  27. mteb/abstasks/text/bitext_mining.py +36 -25
  28. mteb/abstasks/text/reranking.py +7 -5
  29. mteb/abstasks/text/summarization.py +8 -3
  30. mteb/abstasks/zeroshot_classification.py +5 -2
  31. mteb/benchmarks/benchmark.py +2 -2
  32. mteb/cache.py +27 -22
  33. mteb/cli/_display_tasks.py +2 -2
  34. mteb/cli/build_cli.py +15 -10
  35. mteb/cli/generate_model_card.py +10 -7
  36. mteb/deprecated_evaluator.py +60 -46
  37. mteb/evaluate.py +39 -30
  38. mteb/filter_tasks.py +25 -26
  39. mteb/get_tasks.py +29 -30
  40. mteb/languages/language_scripts.py +5 -3
  41. mteb/leaderboard/app.py +1 -1
  42. mteb/load_results.py +12 -12
  43. mteb/models/abs_encoder.py +7 -5
  44. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  45. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  46. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
  47. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
  48. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  49. mteb/models/get_model_meta.py +8 -1
  50. mteb/models/instruct_wrapper.py +11 -5
  51. mteb/models/model_implementations/andersborges.py +2 -2
  52. mteb/models/model_implementations/blip_models.py +8 -8
  53. mteb/models/model_implementations/bm25.py +1 -1
  54. mteb/models/model_implementations/clip_models.py +3 -3
  55. mteb/models/model_implementations/cohere_models.py +1 -1
  56. mteb/models/model_implementations/cohere_v.py +2 -2
  57. mteb/models/model_implementations/dino_models.py +23 -23
  58. mteb/models/model_implementations/emillykkejensen_models.py +3 -3
  59. mteb/models/model_implementations/gme_v_models.py +4 -3
  60. mteb/models/model_implementations/jina_clip.py +1 -1
  61. mteb/models/model_implementations/jina_models.py +1 -1
  62. mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
  63. mteb/models/model_implementations/llm2clip_models.py +3 -3
  64. mteb/models/model_implementations/mcinext_models.py +4 -1
  65. mteb/models/model_implementations/moco_models.py +2 -2
  66. mteb/models/model_implementations/model2vec_models.py +1 -1
  67. mteb/models/model_implementations/nomic_models.py +8 -8
  68. mteb/models/model_implementations/openclip_models.py +7 -7
  69. mteb/models/model_implementations/random_baseline.py +3 -3
  70. mteb/models/model_implementations/rasgaard_models.py +1 -1
  71. mteb/models/model_implementations/repllama_models.py +2 -2
  72. mteb/models/model_implementations/rerankers_custom.py +3 -3
  73. mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
  74. mteb/models/model_implementations/siglip_models.py +10 -10
  75. mteb/models/model_implementations/vlm2vec_models.py +1 -1
  76. mteb/models/model_implementations/voyage_v.py +4 -4
  77. mteb/models/model_meta.py +14 -13
  78. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
  79. mteb/models/search_wrappers.py +26 -12
  80. mteb/models/sentence_transformer_wrapper.py +19 -14
  81. mteb/py.typed +0 -0
  82. mteb/results/benchmark_results.py +28 -20
  83. mteb/results/model_result.py +52 -22
  84. mteb/results/task_result.py +55 -58
  85. mteb/similarity_functions.py +11 -7
  86. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  87. mteb/tasks/classification/est/estonian_valence.py +1 -1
  88. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  89. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  90. mteb/tasks/retrieval/code/code_rag.py +12 -12
  91. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  92. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  93. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  94. mteb/tasks/retrieval/nob/norquad.py +2 -2
  95. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  96. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  97. mteb/types/_result.py +2 -1
  98. mteb/types/statistics.py +9 -3
  99. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/METADATA +1 -1
  100. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/RECORD +104 -103
  101. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/WHEEL +0 -0
  102. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/entry_points.txt +0 -0
  103. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/licenses/LICENSE +0 -0
  104. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/top_level.txt +0 -0
mteb/evaluate.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import warnings
 from collections.abc import Iterable
 from pathlib import Path
 from time import time
@@ -13,11 +14,10 @@ from mteb._helpful_enum import HelpfulStrEnum
 from mteb.abstasks import AbsTaskRetrieval
 from mteb.abstasks.abstask import AbsTask
 from mteb.abstasks.aggregated_task import AbsTaskAggregate
+from mteb.benchmarks.benchmark import Benchmark
 from mteb.cache import ResultCache
 from mteb.models.model_meta import ModelMeta
 from mteb.models.models_protocols import (
-    CrossEncoderProtocol,
-    EncoderProtocol,
     MTEBModels,
 )
 from mteb.models.sentence_transformer_wrapper import (
@@ -57,27 +57,26 @@ def _sanitize_model(
 ) -> tuple[MTEBModels | ModelMeta, ModelMeta, ModelName, Revision]:
     from sentence_transformers import CrossEncoder, SentenceTransformer
 
+    wrapped_model: MTEBModels | ModelMeta
     if isinstance(model, SentenceTransformer):
-        _mdl = SentenceTransformerEncoderWrapper(model)
-        meta = _mdl.mteb_model_meta
-        _mdl = cast(EncoderProtocol, _mdl)
-        model = _mdl
+        wrapped_model = SentenceTransformerEncoderWrapper(model)
+        meta = wrapped_model.mteb_model_meta
     elif isinstance(model, CrossEncoder):
-        _mdl = CrossEncoderWrapper(model)
-        _mdl = cast(CrossEncoderProtocol, _mdl)
-        meta = _mdl.mteb_model_meta
-        model = _mdl
+        wrapped_model = CrossEncoderWrapper(model)
+        meta = wrapped_model.mteb_model_meta
     elif hasattr(model, "mteb_model_meta"):
-        meta = model.mteb_model_meta  # type: ignore[attr-defined]
+        meta = getattr(model, "mteb_model_meta")
         if not isinstance(meta, ModelMeta):
-            meta = ModelMeta.from_hub(None)
+            meta = ModelMeta._from_hub(None)
+        wrapped_model = cast(MTEBModels | ModelMeta, model)
     else:
-        meta = ModelMeta.from_hub(None) if not isinstance(model, ModelMeta) else model
+        meta = ModelMeta._from_hub(None) if not isinstance(model, ModelMeta) else model
+        wrapped_model = meta
 
     model_name = cast(str, meta.name)
     model_revision = cast(str, meta.revision)
 
-    return model, meta, model_name, model_revision
+    return wrapped_model, meta, model_name, model_revision
 
 
 def _evaluate_task(
@@ -123,7 +122,8 @@ def _evaluate_task(
             prediction_folder=prediction_folder,
             public_only=public_only,
         )
-        result.kg_co2_emissions = tracker.final_emissions
+        if isinstance(result, TaskResult):
+            result.kg_co2_emissions = tracker.final_emissions
         return result
 
     task_results = {}
@@ -136,10 +136,12 @@ def _evaluate_task(
         task.load_data()
     except DatasetNotFoundError as e:
         if not task.metadata.is_public and public_only is None:
-            logger.warning(
+            msg = (
                 f"Dataset for private task '{task.metadata.name}' not found. "
                 "Make sure you have access to the dataset and that you have set up the authentication correctly. To disable this warning set `public_only=False`"
             )
+            logger.warning(msg)
+            warnings.warn(msg)
             return TaskError(
                 task_name=task.metadata.name,
                 exception=str(e),
@@ -147,7 +149,7 @@ def _evaluate_task(
         if public_only is False:
             raise e
 
-    evaluation_time = 0
+    evaluation_time = 0.0
 
     for split, hf_subsets in splits.items():
         tick = time()
@@ -194,12 +196,18 @@ def _check_model_modalities(
         return
 
     model_modalities = set(model.modalities)
+    check_tasks: Iterable[AbsTask] = []
     if isinstance(tasks, AbsTask):
-        tasks = [tasks]
+        check_tasks = [tasks]
+    elif isinstance(tasks, Benchmark):
+        benchmark = cast(Benchmark, tasks)
+        check_tasks = benchmark.tasks
+    else:
+        check_tasks = cast(Iterable[AbsTask], tasks)
 
     warnings, errors = [], []
 
-    for task in tasks:
+    for task in check_tasks:
         # only retrieval tasks have different modalities for query and document and can be run with partial overlaps
         if isinstance(task, AbsTaskRetrieval):
             query_mods = set(task.metadata.get_modalities(PromptType.query))
@@ -332,10 +340,10 @@ def evaluate(
 
     # AbsTaskAggregate is a special case where we have to run multiple tasks and combine the results
     if isinstance(tasks, AbsTaskAggregate):
-        task = cast(AbsTaskAggregate, tasks)
+        aggregated_task = cast(AbsTaskAggregate, tasks)
        results = evaluate(
            model,
-            task.metadata.tasks,
+            aggregated_task.metadata.tasks,
            co2_tracker=co2_tracker,
            raise_error=raise_error,
            encode_kwargs=encode_kwargs,
@@ -345,17 +353,18 @@ def evaluate(
             show_progress_bar=show_progress_bar,
             public_only=public_only,
         )
-        result = task.combine_task_results(results.task_results)
+        combined_results = aggregated_task.combine_task_results(results.task_results)
         return ModelResult(
             model_name=results.model_name,
             model_revision=results.model_revision,
-            task_results=[result],
+            task_results=[combined_results],
         )
 
     if isinstance(tasks, AbsTask):
         task = tasks
     else:
-        results = []
+        tasks = cast(Iterable[AbsTask], tasks)
+        evaluate_results = []
         exceptions = []
         tasks_tqdm = tqdm(
             tasks,
@@ -376,23 +385,23 @@ def evaluate(
                 show_progress_bar=False,
                 public_only=public_only,
             )
-            results.extend(_res.task_results)
+            evaluate_results.extend(_res.task_results)
            if _res.exceptions:
                exceptions.extend(_res.exceptions)
        return ModelResult(
            model_name=_res.model_name,
            model_revision=_res.model_revision,
-            task_results=results,
+            task_results=evaluate_results,
            exceptions=exceptions,
        )
 
     overwrite_strategy = OverwriteStrategy.from_str(overwrite_strategy)
 
-    existing_results = None
+    existing_results: TaskResult | None = None
     if cache and overwrite_strategy != OverwriteStrategy.ALWAYS:
-        results = cache.load_task_result(task.metadata.name, meta)
-        if results:
-            existing_results = results
+        cache_results = cache.load_task_result(task.metadata.name, meta)
+        if cache_results:
+            existing_results = cache_results
 
     if (
         existing_results
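
Taken together, these changes mean a raw SentenceTransformer or CrossEncoder is wrapped into `wrapped_model` internally, and messages that previously went only to the logger (renamed tasks, missing private datasets) are also emitted via `warnings.warn`. A minimal sketch of how this surfaces to a caller, assuming mteb 2.5.4 with sentence-transformers installed; the model and task names are illustrative, not taken from this diff:

```python
import warnings

import mteb
from sentence_transformers import SentenceTransformer

# _sanitize_model wraps the raw SentenceTransformer internally, so it can be
# passed straight to evaluate() without manual wrapping.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
tasks = mteb.get_tasks(tasks=["Banking77Classification"])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    results = mteb.evaluate(model, tasks)

# Messages that were logger-only before 2.5.4 now also arrive here.
for warning in caught:
    print(warning.message)

for task_result in results.task_results:
    print(task_result.task_name)
```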
mteb/filter_tasks.py CHANGED
@@ -1,7 +1,7 @@
 """This script contains functions that are used to get an overview of the MTEB benchmark."""
 
 import logging
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 from typing import overload
 
 from mteb.abstasks import (
@@ -34,14 +34,14 @@ def _check_is_valid_language(lang: str) -> None:
 
 @overload
 def filter_tasks(
-    tasks: Sequence[AbsTask],
+    tasks: Iterable[AbsTask],
     *,
-    languages: list[str] | None = None,
-    script: list[str] | None = None,
-    domains: list[TaskDomain] | None = None,
-    task_types: list[TaskType] | None = None,  # type: ignore
-    categories: list[TaskCategory] | None = None,
-    modalities: list[Modalities] | None = None,
+    languages: Sequence[str] | None = None,
+    script: Sequence[str] | None = None,
+    domains: Iterable[TaskDomain] | None = None,
+    task_types: Iterable[TaskType] | None = None,
+    categories: Iterable[TaskCategory] | None = None,
+    modalities: Iterable[Modalities] | None = None,
     exclusive_modality_filter: bool = False,
     exclude_superseded: bool = False,
     exclude_aggregate: bool = False,
@@ -51,14 +51,14 @@ def filter_tasks(
 
 @overload
 def filter_tasks(
-    tasks: Sequence[type[AbsTask]],
+    tasks: Iterable[type[AbsTask]],
     *,
-    languages: list[str] | None = None,
-    script: list[str] | None = None,
-    domains: list[TaskDomain] | None = None,
-    task_types: list[TaskType] | None = None,  # type: ignore
-    categories: list[TaskCategory] | None = None,
-    modalities: list[Modalities] | None = None,
+    languages: Sequence[str] | None = None,
+    script: Sequence[str] | None = None,
+    domains: Iterable[TaskDomain] | None = None,
+    task_types: Iterable[TaskType] | None = None,
+    categories: Iterable[TaskCategory] | None = None,
+    modalities: Iterable[Modalities] | None = None,
     exclusive_modality_filter: bool = False,
     exclude_superseded: bool = False,
     exclude_aggregate: bool = False,
@@ -67,14 +67,14 @@ def filter_tasks(
 
 
 def filter_tasks(
-    tasks: Sequence[AbsTask] | Sequence[type[AbsTask]],
+    tasks: Iterable[AbsTask] | Iterable[type[AbsTask]],
     *,
-    languages: list[str] | None = None,
-    script: list[str] | None = None,
-    domains: list[TaskDomain] | None = None,
-    task_types: list[TaskType] | None = None,  # type: ignore
-    categories: list[TaskCategory] | None = None,
-    modalities: list[Modalities] | None = None,
+    languages: Sequence[str] | None = None,
+    script: Sequence[str] | None = None,
+    domains: Iterable[TaskDomain] | None = None,
+    task_types: Iterable[TaskType] | None = None,
+    categories: Iterable[TaskCategory] | None = None,
+    modalities: Iterable[Modalities] | None = None,
     exclusive_modality_filter: bool = False,
     exclude_superseded: bool = False,
     exclude_aggregate: bool = False,
@@ -92,7 +92,6 @@ def filter_tasks(
         task_types: A string specifying the type of task e.g. "Classification" or "Retrieval". If None, all tasks are included.
         categories: A list of task categories these include "t2t" (text to text), "t2i" (text to image). See TaskMetadata for the full list.
         exclude_superseded: A boolean flag to exclude datasets which are superseded by another.
-        eval_splits: A list of evaluation splits to include. If None, all splits are included.
         modalities: A list of modalities to include. If None, all modalities are included.
         exclusive_modality_filter: If True, only keep tasks where _all_ filter modalities are included in the
             task's modalities and ALL task modalities are in filter modalities (exact match).
@@ -113,12 +112,12 @@ def filter_tasks(
     """
     langs_to_keep = None
     if languages:
-        [_check_is_valid_language(lang) for lang in languages]
+        [_check_is_valid_language(lang) for lang in languages]  # type: ignore[func-returns-value]
         langs_to_keep = set(languages)
 
     script_to_keep = None
     if script:
-        [_check_is_valid_script(s) for s in script]
+        [_check_is_valid_script(s) for s in script]  # type: ignore[func-returns-value]
         script_to_keep = set(script)
 
     domains_to_keep = None
@@ -178,4 +177,4 @@ def filter_tasks(
 
         _tasks.append(t)
 
-    return _tasks
+    return _tasks  # type: ignore[return-value] # type checker cannot infer the overload return type
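
The practical effect of the Sequence to Iterable loosening above is that callers no longer need to materialize a list first. A small sketch, assuming the function is imported from the module path shown in the file list; the filter values are illustrative:

```python
from mteb import get_tasks
from mteb.filter_tasks import filter_tasks

all_tasks = get_tasks()

# A generator expression is now an acceptable `tasks` argument, and tuples
# work for the filter parameters, since only Iterable/Sequence are required.
classification_tasks = filter_tasks(
    (t for t in all_tasks),
    task_types=("Classification",),
    languages=("eng",),
)
print(len(classification_tasks))
```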
mteb/get_tasks.py CHANGED
@@ -2,8 +2,9 @@
 import difflib
 import logging
+import warnings
 from collections import Counter, defaultdict
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 from typing import Any
 
 import pandas as pd
 
@@ -22,12 +23,11 @@ logger = logging.getLogger(__name__)
 def _gather_tasks() -> tuple[type[AbsTask], ...]:
     import mteb.tasks as tasks
 
-    tasks = [
+    return tuple(
         t
         for t in tasks.__dict__.values()
         if isinstance(t, type) and issubclass(t, AbsTask)
-    ]
-    return tuple(tasks)
+    )
 
 
 def _create_name_to_task_mapping(
@@ -43,7 +43,7 @@ def _create_name_to_task_mapping(
     return metadata_names
 
 
-def _create_similar_tasks(tasks: Sequence[type[AbsTask]]) -> dict[str, list[str]]:
+def _create_similar_tasks(tasks: Iterable[type[AbsTask]]) -> dict[str, list[str]]:
     """Create a dictionary of similar tasks.
 
     Returns:
@@ -194,9 +194,8 @@ class MTEBTasks(tuple[AbsTask]):
             string with a LaTeX table.
         """
         if include_citation_in_name and "name" in properties:
-            properties += ["intext_citation"]
-            df = self.to_dataframe(properties)
-            df["name"] = df["name"] + " " + df["intext_citation"]
+            df = self.to_dataframe(tuple(properties) + ("intext_citation",))
+            df["name"] = df["name"] + " " + df["intext_citation"]  # type: ignore[operator]
             df = df.drop(columns=["intext_citation"])
         else:
             df = self.to_dataframe(properties)
@@ -221,17 +220,17 @@ class MTEBTasks(tuple[AbsTask]):
 
 
 def get_tasks(
-    tasks: list[str] | None = None,
+    tasks: Sequence[str] | None = None,
     *,
-    languages: list[str] | None = None,
-    script: list[str] | None = None,
-    domains: list[TaskDomain] | None = None,
-    task_types: list[TaskType] | None = None,  # type: ignore
-    categories: list[TaskCategory] | None = None,
+    languages: Sequence[str] | None = None,
+    script: Sequence[str] | None = None,
+    domains: Sequence[TaskDomain] | None = None,
+    task_types: Sequence[TaskType] | None = None,
+    categories: Sequence[TaskCategory] | None = None,
     exclude_superseded: bool = True,
-    eval_splits: list[str] | None = None,
+    eval_splits: Sequence[str] | None = None,
     exclusive_language_filter: bool = False,
-    modalities: list[Modalities] | None = None,
+    modalities: Sequence[Modalities] | None = None,
     exclusive_modality_filter: bool = False,
     exclude_aggregate: bool = False,
     exclude_private: bool = True,
@@ -287,7 +286,7 @@ def get_tasks(
         ]
         return MTEBTasks(_tasks)
 
-    _tasks = filter_tasks(
+    tasks_: Sequence[type[AbsTask]] = filter_tasks(
         TASK_LIST,
         languages=languages,
         script=script,
@@ -300,12 +299,12 @@ def get_tasks(
         exclude_aggregate=exclude_aggregate,
         exclude_private=exclude_private,
     )
-    _tasks = [
-        cls().filter_languages(languages, script).filter_eval_splits(eval_splits)
-        for cls in _tasks
-    ]
-
-    return MTEBTasks(_tasks)
+    return MTEBTasks(
+        [
+            cls().filter_languages(languages, script).filter_eval_splits(eval_splits)
+            for cls in tasks_
+        ]
+    )
 
 
 _TASK_RENAMES = {"PersianTextTone": "SynPerTextToneClassification"}
@@ -313,10 +312,10 @@ _TASK_RENAMES = {"PersianTextTone": "SynPerTextToneClassification"}
 
 def get_task(
     task_name: str,
-    languages: list[str] | None = None,
-    script: list[str] | None = None,
-    eval_splits: list[str] | None = None,
-    hf_subsets: list[str] | None = None,
+    languages: Sequence[str] | None = None,
+    script: Sequence[str] | None = None,
+    eval_splits: Sequence[str] | None = None,
+    hf_subsets: Sequence[str] | None = None,
     exclusive_language_filter: bool = False,
 ) -> AbsTask:
     """Get a task by name.
@@ -340,9 +339,9 @@ def get_task(
     """
     if task_name in _TASK_RENAMES:
         _task_name = _TASK_RENAMES[task_name]
-        logger.warning(
-            f"The task with the given name '{task_name}' has been renamed to '{_task_name}'. To prevent this warning use the new name."
-        )
+        msg = f"The task with the given name '{task_name}' has been renamed to '{_task_name}'. To prevent this warning use the new name."
+        logger.warning(msg)
+        warnings.warn(msg)
 
     if task_name not in _TASKS_REGISTRY:
         close_matches = difflib.get_close_matches(task_name, _TASKS_REGISTRY.keys())
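
With every filter argument widened to Sequence, tuples work as well as lists, and requesting a renamed task now triggers `warnings.warn` alongside the log message. A brief sketch; the language codes and task types are illustrative, while the rename pair comes from `_TASK_RENAMES` above:

```python
from mteb import get_task, get_tasks

# Tuples are now valid for all filter arguments.
tasks = get_tasks(
    languages=("dan", "nob"),
    task_types=("Retrieval",),
    eval_splits=("test",),
)

# Requesting "PersianTextTone" would emit warnings.warn in addition to the
# log message; the canonical name below goes through silently.
task = get_task("SynPerTextToneClassification")
```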
mteb/languages/language_scripts.py CHANGED
@@ -1,9 +1,9 @@
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 
 from typing_extensions import Self
 
-from mteb.languages import check_language_code
+from mteb.languages.check_language_code import check_language_code
 
 
 @dataclass
@@ -25,7 +25,9 @@ class LanguageScripts:
 
     @classmethod
     def from_languages_and_scripts(
-        cls, languages: list[str] | None = None, scripts: list[str] | None = None
+        cls,
+        languages: Sequence[str] | None = None,
+        scripts: Sequence[str] | None = None,
     ) -> Self:
         """Create a LanguageScripts object from lists of languages and scripts.
 
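A corresponding sketch for the widened constructor, assuming the class is importable from the module path shown in the file list:

```python
from mteb.languages.language_scripts import LanguageScripts

# Any Sequence is accepted now, so tuples no longer need converting to lists.
language_scripts = LanguageScripts.from_languages_and_scripts(
    languages=("eng", "fra"),
    scripts=("Latn",),
)
```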
mteb/leaderboard/app.py CHANGED
@@ -169,7 +169,7 @@ def _update_task_info(task_names: str) -> gr.DataFrame:
     df = df.drop(columns="reference")
     return gr.DataFrame(
         df,
-        datatype=["markdown"] + ["str"] * (len(df.columns) - 1),  # type: ignore
+        datatype=["markdown"] + ["str"] * (len(df.columns) - 1),
         buttons=["copy", "fullscreen"],
         show_search="filter",
     )
mteb/load_results.py CHANGED
@@ -1,7 +1,7 @@
 import json
 import logging
 import sys
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 from pathlib import Path
 
 from mteb.abstasks.abstask import AbsTask
@@ -45,8 +45,8 @@ def _model_name_and_revision(
 def load_results(
     results_repo: str = "https://github.com/embeddings-benchmark/results",
     download_latest: bool = True,
-    models: Sequence[ModelMeta] | Sequence[str] | None = None,
-    tasks: Sequence[AbsTask] | Sequence[str] | None = None,
+    models: Iterable[ModelMeta] | Sequence[str] | None = None,
+    tasks: Iterable[AbsTask] | Sequence[str] | None = None,
     validate_and_filter: bool = True,
     require_model_meta: bool = True,
     only_main_score: bool = False,
@@ -83,21 +83,21 @@ def load_results(
 
     if models is not None:
         models_to_keep = {}
-        for model_path in models:
-            if isinstance(model_path, ModelMeta):
-                models_to_keep[model_path.name] = model_path.revision
+        for model in models:
+            if isinstance(model, ModelMeta):
+                models_to_keep[model.name] = model.revision
             else:
-                models_to_keep[model_path] = None
+                models_to_keep[model] = None
     else:
         models_to_keep = None
 
-    task_names = {}
+    task_names: dict[str, AbsTask | None] = {}
     if tasks is not None:
-        for task in tasks:
-            if isinstance(task, AbsTask):
-                task_names[task.metadata.name] = task
+        for task_ in tasks:
+            if isinstance(task_, AbsTask):
+                task_names[task_.metadata.name] = task_
             else:
-                task_names[task] = None
+                task_names[task_] = None
 
     model_results = []
     for model_path in model_paths:
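
With `models` and `tasks` widened to Iterable, a generator of ModelMeta objects is accepted directly. A hedged sketch; the model and task names are illustrative, and `get_model_meta` is assumed to be exported at the top level as in other 2.x releases:

```python
import mteb
from mteb.load_results import load_results

# An Iterable such as a generator is now fine for `models`; tasks can still
# be passed as plain name strings.
metas = (mteb.get_model_meta(name) for name in ["intfloat/multilingual-e5-small"])
results = load_results(models=metas, tasks=["Banking77Classification"])
```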
mteb/models/abs_encoder.py CHANGED
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Sequence
 from typing import Any, Literal, cast, get_args, overload
@@ -43,7 +44,7 @@ class AbsEncoder(ABC):
     model: Any
     mteb_model_meta: ModelMeta | None = None
     model_prompts: dict[str, str] | None = None
-    instruction_template: str | Callable[[str, PromptType], str] | None = None
+    instruction_template: str | Callable[[str, PromptType | None], str] | None = None
     prompts_dict: dict[str, str] | None = None
 
     def get_prompt_name(
@@ -110,7 +111,7 @@ class AbsEncoder(ABC):
         if not self.model_prompts:
             return None
         prompt_name = self.get_prompt_name(task_metadata, prompt_type)
-        return self.model_prompts.get(prompt_name)
+        return self.model_prompts.get(prompt_name) if prompt_name else None
 
     @staticmethod
     @overload
@@ -187,6 +188,7 @@ class AbsEncoder(ABC):
             except KeyError:
                 msg = f"Task name {task_name} is not valid. {valid_keys_msg}"
                 logger.warning(msg)
+                warnings.warn(msg)
                 invalid_task_messages.add(msg)
                 invalid_keys.add(task_key)
 
@@ -232,9 +234,9 @@ class AbsEncoder(ABC):
         if isinstance(prompt, dict) and prompt_type:
             if prompt.get(prompt_type.value):
                 return prompt[prompt_type.value]
-            logger.warning(
-                f"Prompt type '{prompt_type}' not found in task metadata for task '{task_metadata.name}'."
-            )
+            msg = f"Prompt type '{prompt_type}' not found in task metadata for task '{task_metadata.name}'."
+            logger.warning(msg)
+            warnings.warn(msg)
             return ""
 
         if prompt:
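
The widened Callable annotation means custom instruction templates should tolerate a missing prompt type. A sketch of a conforming template; the `PromptType` import path is an assumption based on the `mteb.types` package seen elsewhere in this diff, and the template text itself is invented:

```python
from mteb.types import PromptType  # assumed import path


def instruction_template(instruction: str, prompt_type: PromptType | None) -> str:
    # prompt_type can now legitimately be None, e.g. for symmetric tasks
    # without a query/document distinction, so handle that branch explicitly.
    side = prompt_type.value if prompt_type is not None else "text"
    return f"Instruct ({side}): {instruction}\n"
```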
mteb/models/cache_wrappers/cache_backend_protocol.py CHANGED
@@ -5,8 +5,6 @@ from typing import Any, Protocol, runtime_checkable
 
 import numpy as np
 
-from mteb.types import BatchedInput
-
 
 @runtime_checkable
 class CacheBackendProtocol(Protocol):
@@ -26,7 +24,7 @@ class CacheBackendProtocol(Protocol):
             **kwargs: Additional backend-specific arguments.
         """
 
-    def add(self, item: list[BatchedInput], vectors: np.ndarray) -> None:
+    def add(self, item: list[dict[str, Any]], vectors: np.ndarray) -> None:
         """Add a vector to the cache.
 
         Args:
@@ -34,7 +32,7 @@ class CacheBackendProtocol(Protocol):
             vectors: Embedding vector of shape (dim,) or (1, dim).
         """
 
-    def get_vector(self, item: BatchedInput) -> np.ndarray | None:
+    def get_vector(self, item: dict[str, Any]) -> np.ndarray | None:
         """Retrieve the cached vector for the given item.
 
         Args:
@@ -53,5 +51,5 @@ class CacheBackendProtocol(Protocol):
     def close(self) -> None:
         """Release resources or flush data."""
 
-    def __contains__(self, item: BatchedInput) -> bool:
+    def __contains__(self, item: dict[str, Any]) -> bool:
         """Check whether the cache contains an item."""
mteb/models/cache_wrappers/cache_backends/_hash_utils.py CHANGED
@@ -1,12 +1,13 @@
 import hashlib
+from collections.abc import Mapping
+from typing import Any
 
-from mteb.types import BatchedInput
 
-
-def _hash_item(item: BatchedInput) -> str:
+def _hash_item(item: Mapping[str, Any]) -> str:
     item_hash = ""
     if "text" in item:
-        item_hash = hashlib.sha256(item["text"].encode()).hexdigest()
+        item_text: str = item["text"]
+        item_hash = hashlib.sha256(item_text.encode()).hexdigest()
 
     if "image" in item:
         from PIL import Image
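
Because cache items are now plain dictionaries rather than `BatchedInput`, a custom backend can satisfy the `CacheBackendProtocol` above without importing mteb's input types at all. A minimal in-memory sketch covering only the protocol members visible in this diff; the text-only hashing is a deliberate simplification of `_hash_item`:

```python
import hashlib
from typing import Any

import numpy as np


class InMemoryCache:
    """Toy backend structurally matching the protocol members shown above."""

    def __init__(self) -> None:
        self._store: dict[str, np.ndarray] = {}

    @staticmethod
    def _key(item: dict[str, Any]) -> str:
        # Simplified: real items may also carry images (see _hash_item).
        return hashlib.sha256(item["text"].encode()).hexdigest()

    def add(self, item: list[dict[str, Any]], vectors: np.ndarray) -> None:
        # Store one vector per item, accepting (dim,) or (n, dim) input.
        for single, vector in zip(item, np.atleast_2d(vectors)):
            self._store[self._key(single)] = vector

    def get_vector(self, item: dict[str, Any]) -> np.ndarray | None:
        return self._store.get(self._key(item))

    def close(self) -> None:
        self._store.clear()

    def __contains__(self, item: dict[str, Any]) -> bool:
        return self._key(item) in self._store
```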
mteb/models/cache_wrappers/cache_backends/faiss_cache.py CHANGED
@@ -1,6 +1,8 @@
 import json
 import logging
+import warnings
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 
@@ -36,7 +38,7 @@ class FaissCache:
         logger.info(f"Initialized FAISS VectorCacheMap in {self.directory}")
         self.load()
 
-    def add(self, items: list[BatchedInput], vectors: np.ndarray) -> None:
+    def add(self, items: list[dict[str, Any]], vectors: np.ndarray) -> None:
         """Add vector to FAISS index."""
         import faiss
 
@@ -71,7 +73,9 @@ class FaissCache:
         try:
             return self.index.reconstruct(idx)
         except Exception:
-            logger.warning(f"Vector id {idx} missing for hash {item_hash}")
+            msg = f"Vector id {idx} missing for hash {item_hash}"
+            logger.warning(msg)
+            warnings.warn(msg)
             return None
 
     def save(self) -> None:
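
Since cache misses now go through `warnings.warn` as well as the logger, they can be escalated with the standard filters, for example in tests that must not silently recompute embeddings. A sketch; the FaissCache constructor arguments are not shown in this diff, so they are elided:

```python
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("error")  # promote the new cache-miss warning to an exception
    ...  # reads from a FaissCache(...) would now raise on a missing vector id
```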