mteb 2.5.3__py3-none-any.whl → 2.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. mteb/_create_dataloaders.py +10 -15
  2. mteb/_evaluators/any_sts_evaluator.py +1 -4
  3. mteb/_evaluators/evaluator.py +2 -1
  4. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
  5. mteb/_evaluators/pair_classification_evaluator.py +3 -1
  6. mteb/_evaluators/retrieval_metrics.py +17 -16
  7. mteb/_evaluators/sklearn_evaluator.py +9 -8
  8. mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
  9. mteb/_evaluators/text/summarization_evaluator.py +20 -16
  10. mteb/abstasks/_data_filter/filters.py +1 -1
  11. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  12. mteb/abstasks/_statistics_calculation.py +18 -10
  13. mteb/abstasks/_stratification.py +18 -18
  14. mteb/abstasks/abstask.py +27 -21
  15. mteb/abstasks/aggregate_task_metadata.py +1 -9
  16. mteb/abstasks/aggregated_task.py +3 -16
  17. mteb/abstasks/classification.py +10 -4
  18. mteb/abstasks/clustering.py +18 -14
  19. mteb/abstasks/clustering_legacy.py +8 -8
  20. mteb/abstasks/image/image_text_pair_classification.py +5 -3
  21. mteb/abstasks/multilabel_classification.py +20 -16
  22. mteb/abstasks/pair_classification.py +18 -9
  23. mteb/abstasks/regression.py +3 -3
  24. mteb/abstasks/retrieval.py +12 -9
  25. mteb/abstasks/sts.py +6 -3
  26. mteb/abstasks/task_metadata.py +20 -16
  27. mteb/abstasks/text/bitext_mining.py +36 -25
  28. mteb/abstasks/text/reranking.py +7 -5
  29. mteb/abstasks/text/summarization.py +8 -3
  30. mteb/abstasks/zeroshot_classification.py +5 -2
  31. mteb/benchmarks/benchmark.py +2 -2
  32. mteb/cache.py +20 -18
  33. mteb/cli/_display_tasks.py +2 -2
  34. mteb/cli/build_cli.py +5 -5
  35. mteb/cli/generate_model_card.py +6 -4
  36. mteb/deprecated_evaluator.py +56 -43
  37. mteb/evaluate.py +35 -29
  38. mteb/filter_tasks.py +25 -26
  39. mteb/get_tasks.py +25 -27
  40. mteb/languages/language_scripts.py +5 -3
  41. mteb/leaderboard/app.py +1 -1
  42. mteb/load_results.py +12 -12
  43. mteb/models/abs_encoder.py +2 -2
  44. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  45. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  46. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +2 -1
  47. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +30 -13
  48. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  49. mteb/models/get_model_meta.py +8 -1
  50. mteb/models/instruct_wrapper.py +11 -5
  51. mteb/models/model_implementations/andersborges.py +2 -2
  52. mteb/models/model_implementations/blip_models.py +8 -8
  53. mteb/models/model_implementations/bm25.py +1 -1
  54. mteb/models/model_implementations/clip_models.py +3 -3
  55. mteb/models/model_implementations/cohere_models.py +1 -1
  56. mteb/models/model_implementations/cohere_v.py +2 -2
  57. mteb/models/model_implementations/dino_models.py +23 -23
  58. mteb/models/model_implementations/emillykkejensen_models.py +3 -3
  59. mteb/models/model_implementations/jina_clip.py +1 -1
  60. mteb/models/model_implementations/jina_models.py +1 -1
  61. mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
  62. mteb/models/model_implementations/llm2clip_models.py +3 -3
  63. mteb/models/model_implementations/moco_models.py +2 -2
  64. mteb/models/model_implementations/model2vec_models.py +1 -1
  65. mteb/models/model_implementations/nomic_models.py +8 -8
  66. mteb/models/model_implementations/openclip_models.py +7 -7
  67. mteb/models/model_implementations/random_baseline.py +3 -3
  68. mteb/models/model_implementations/rasgaard_models.py +1 -1
  69. mteb/models/model_implementations/repllama_models.py +2 -2
  70. mteb/models/model_implementations/rerankers_custom.py +3 -3
  71. mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
  72. mteb/models/model_implementations/siglip_models.py +10 -10
  73. mteb/models/model_implementations/vlm2vec_models.py +1 -1
  74. mteb/models/model_implementations/voyage_v.py +4 -4
  75. mteb/models/model_meta.py +11 -12
  76. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +5 -5
  77. mteb/models/search_wrappers.py +22 -10
  78. mteb/models/sentence_transformer_wrapper.py +9 -4
  79. mteb/py.typed +0 -0
  80. mteb/results/benchmark_results.py +25 -19
  81. mteb/results/model_result.py +49 -21
  82. mteb/results/task_result.py +45 -51
  83. mteb/similarity_functions.py +11 -7
  84. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  85. mteb/tasks/classification/est/estonian_valence.py +1 -1
  86. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  87. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  88. mteb/tasks/retrieval/code/code_rag.py +12 -12
  89. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  90. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  91. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  92. mteb/tasks/retrieval/nob/norquad.py +2 -2
  93. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  94. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  95. mteb/types/_result.py +2 -1
  96. mteb/types/statistics.py +9 -3
  97. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/METADATA +1 -1
  98. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/RECORD +102 -101
  99. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/WHEEL +0 -0
  100. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/entry_points.txt +0 -0
  101. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/licenses/LICENSE +0 -0
  102. {mteb-2.5.3.dist-info → mteb-2.5.4.dist-info}/top_level.txt +0 -0
mteb/abstasks/regression.py CHANGED
@@ -87,7 +87,7 @@ class AbsTaskRegression(AbsTaskClassification):
  Full details of api in [`SklearnModelProtocol`][mteb._evaluators.sklearn_evaluator.SklearnModelProtocol].
  """
 
- evaluator: type[SklearnModelProtocol] = SklearnEvaluator
+ evaluator: type[SklearnEvaluator] = SklearnEvaluator
  evaluator_model: SklearnModelProtocol = LinearRegression(n_jobs=-1)
 
  train_split: str = "train"
@@ -113,7 +113,7 @@ class AbsTaskRegression(AbsTaskClassification):
  )["train"]
  return train_split_sampled, []
 
- def _calculate_scores(
+ def _calculate_scores( # type: ignore[override]
  self,
  y_test: np.ndarray | list[int],
  y_pred: np.ndarray,
@@ -183,7 +183,7 @@ class AbsTaskRegression(AbsTaskClassification):
 
  return dataset_dict
 
- def _calculate_descriptive_statistics_from_split(
+ def _calculate_descriptive_statistics_from_split( # type: ignore[override]
  self, split: str, hf_subset: str | None = None, compute_overall: bool = False
  ) -> RegressionDescriptiveStatistics:
  train_text = []
mteb/abstasks/retrieval.py CHANGED
@@ -1,7 +1,7 @@
  import json
  import logging
  from collections import defaultdict
- from collections.abc import Callable, Sequence
+ from collections.abc import Callable, Mapping, Sequence
  from pathlib import Path
  from time import time
  from typing import Any, Literal
@@ -286,7 +286,7 @@ class AbsTaskRetrieval(AbsTask):
  encode_kwargs: dict[str, Any],
  prediction_folder: Path | None = None,
  **kwargs,
- ) -> dict[HFSubset, ScoresDict]:
+ ) -> Mapping[HFSubset, ScoresDict]:
  """Evaluate the model on the retrieval task.
 
  Args:
@@ -357,6 +357,8 @@ class AbsTaskRetrieval(AbsTask):
  **kwargs,
  )
 
+ search_model: SearchProtocol
+
  if isinstance(model, EncoderProtocol) and not isinstance(model, SearchProtocol):
  search_model = SearchEncoderWrapper(model)
  elif isinstance(model, CrossEncoderProtocol):
@@ -578,11 +580,12 @@ class AbsTaskRetrieval(AbsTask):
  if isinstance(data[split][subset_item], Dataset):
  sections[split] = data[split][subset_item]
  elif converter is not None:
+ subset_data = data[split][subset_item]
+ if subset_data is None:
+ continue
+
  sections[split] = Dataset.from_list(
- [
- converter(idx, item)
- for idx, item in data[split][subset_item].items()
- ]
+ [converter(idx, item) for idx, item in subset_data.items()]
  )
  else:
  raise ValueError(
@@ -680,7 +683,7 @@ class AbsTaskRetrieval(AbsTask):
 
  top_k_sorted = defaultdict(list)
  for query_id, values in top_ranked.items():
- sorted_keys = sorted(values, key=values.get, reverse=True)
+ sorted_keys = sorted(values, key=lambda k: values[k], reverse=True)
  top_k_sorted[query_id] = sorted_keys[: self._top_k]
 
  self.dataset[subset][split]["top_ranked"] = top_k_sorted
@@ -688,10 +691,10 @@ class AbsTaskRetrieval(AbsTask):
 
 
  def _process_relevant_docs(
- collection: dict[str, dict[str, float]],
+ collection: Mapping[str, Mapping[str, int]],
  hf_subset: str,
  split: str,
- ) -> dict[str, dict[str, float]]:
+ ) -> dict[str, dict[str, int]]:
  """Collections can contain overlapping ids in different splits. Prepend split and subset to avoid this
 
  Returns:
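
Note on the `sorted` change above: `dict.get` is typed as returning an optional value, so strict type checkers reject it as a sort key, while an explicit lambda indexing the dict produces the same ordering. A minimal standalone sketch with toy data, not mteb code:

```python
# Hypothetical relevance scores for one query: doc id -> score.
values: dict[str, float] = {"d1": 0.2, "d2": 0.9, "d3": 0.5}

# Runs fine, but `values.get` is typed as returning `float | None`,
# which type checkers reject as a sort key.
legacy = sorted(values, key=values.get, reverse=True)  # type: ignore[arg-type]

# Same ordering, with a key typed as returning `float`.
strict = sorted(values, key=lambda k: values[k], reverse=True)

assert legacy == strict == ["d2", "d3", "d1"]
```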
mteb/abstasks/sts.py CHANGED
@@ -7,7 +7,7 @@ from scipy.stats import pearsonr, spearmanr
 
  from mteb._evaluators import AnySTSEvaluator
  from mteb._evaluators.any_sts_evaluator import STSEvaluatorScores
- from mteb.models import EncoderProtocol
+ from mteb.models import EncoderProtocol, MTEBModels
  from mteb.types import PromptType
  from mteb.types.statistics import (
  ImageStatistics,
@@ -103,7 +103,7 @@ class AbsTaskSTS(AbsTask):
 
  def _evaluate_subset(
  self,
- model: EncoderProtocol,
+ model: MTEBModels,
  data_split: Dataset,
  encode_kwargs: dict[str, Any],
  hf_split: str,
@@ -111,6 +111,9 @@ class AbsTaskSTS(AbsTask):
  prediction_folder: Path | None = None,
  **kwargs: Any,
  ) -> STSMetrics:
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  normalized_scores = list(map(self._normalize, data_split["score"]))
  data_split = data_split.select_columns(list(self.column_names))
 
@@ -142,7 +145,7 @@ class AbsTaskSTS(AbsTask):
  ) -> STSMetrics:
  def compute_corr(x: list[float], y: list[float]) -> tuple[float, float]:
  """Return (pearson, spearman) correlations between x and y."""
- return pearsonr(x, y)[0], spearmanr(x, y)[0]
+ return float(pearsonr(x, y)[0]), float(spearmanr(x, y)[0])
 
  cosine_pearson, cosine_spearman = compute_corr(
  normalized_scores, scores["cosine_scores"]
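
The `isinstance(model, EncoderProtocol)` guard added here (and in the summarization, zero-shot classification, and bitext-mining tasks below) is the usual runtime-checkable-protocol narrowing pattern: accept the wider `MTEBModels` union, then fail fast if the task needs an encoder. A rough sketch with illustrative protocol names, not the actual mteb definitions:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class Encoder(Protocol):
    def encode(self, sentences: list[str]) -> list[list[float]]: ...


@runtime_checkable
class CrossEncoder(Protocol):
    def predict(self, pairs: list[tuple[str, str]]) -> list[float]: ...


def evaluate_subset(model: Encoder | CrossEncoder) -> None:
    if not isinstance(model, Encoder):
        # Fail fast with a clear message instead of an AttributeError later on.
        raise TypeError("Expected model to be an instance of Encoder")
    # After the guard, type checkers treat `model` as Encoder.
    model.encode(["an example sentence"])
```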
mteb/abstasks/task_metadata.py CHANGED
@@ -2,9 +2,10 @@ import json
  import logging
  from collections.abc import Sequence
  from pathlib import Path
- from typing import Any, Literal
+ from typing import Any, Literal, cast
 
  from huggingface_hub import (
+ CardData,
  DatasetCard,
  DatasetCardData,
  constants,
@@ -150,7 +151,7 @@ _TASK_TYPE = (
  "InstructionReranking",
  ) + MIEB_TASK_TYPE
 
- TaskType = Literal[_TASK_TYPE]
+ TaskType = Literal[_TASK_TYPE] # type: ignore[valid-type]
  """The type of the task. E.g. includes "Classification", "Retrieval" and "Clustering"."""
 
 
@@ -192,8 +193,10 @@ AnnotatorType = Literal[
  """The type of the annotators. Is often important for understanding the quality of a dataset."""
 
 
- PromptDict = TypedDict(
- "PromptDict", {prompt_type.value: str for prompt_type in PromptType}, total=False
+ PromptDict = TypedDict( # type: ignore[misc]
+ "PromptDict",
+ {prompt_type.value: str for prompt_type in PromptType},
+ total=False,
  )
  """A dictionary containing the prompt used for the task.
 
@@ -365,7 +368,7 @@ class TaskMetadata(BaseModel):
  """Return a dictionary mapping huggingface subsets to languages."""
  if isinstance(self.eval_langs, dict):
  return self.eval_langs
- return {"default": self.eval_langs} # type: ignore
+ return {"default": cast(list[str], self.eval_langs)}
 
  @property
  def intext_citation(self, include_cite: bool = True) -> str:
@@ -413,7 +416,7 @@ class TaskMetadata(BaseModel):
  for subset, subset_value in stats.items():
  if subset == "hf_subset_descriptive_stats":
  continue
- n_samples[subset] = subset_value["num_samples"] # type: ignore
+ n_samples[subset] = subset_value["num_samples"]
  return n_samples
 
  @property
@@ -446,7 +449,7 @@ class TaskMetadata(BaseModel):
  Raises:
  ValueError: If the prompt type is not recognized.
  """
- if prompt_type is None:
+ if prompt_type is None or self.category is None:
  return self.modalities
  query_modalities, doc_modalities = self.category.split("2")
  category_to_modality: dict[str, Modalities] = {
@@ -466,7 +469,7 @@ class TaskMetadata(BaseModel):
 
  def _create_dataset_card_data(
  self,
- existing_dataset_card_data: DatasetCardData | None = None,
+ existing_dataset_card_data: CardData | None = None,
  ) -> tuple[DatasetCardData, dict[str, Any]]:
  """Create a DatasetCardData object from the task metadata.
 
@@ -501,12 +504,13 @@ class TaskMetadata(BaseModel):
 
  tags = ["mteb"] + self.modalities
 
- descriptive_stats = self.descriptive_stats
- if descriptive_stats is not None:
- for split, split_stat in descriptive_stats.items():
+ descriptive_stats = ""
+ if self.descriptive_stats is not None:
+ descriptive_stats_ = self.descriptive_stats
+ for split, split_stat in descriptive_stats_.items():
  if len(split_stat.get("hf_subset_descriptive_stats", {})) > 10:
  split_stat.pop("hf_subset_descriptive_stats", {})
- descriptive_stats = json.dumps(descriptive_stats, indent=4)
+ descriptive_stats = json.dumps(descriptive_stats_, indent=4)
 
  dataset_card_data_params = existing_dataset_card_data.to_dict()
  # override the existing values
@@ -694,11 +698,11 @@ class TaskMetadata(BaseModel):
 
  def _hf_languages(self) -> list[str]:
  languages: list[str] = []
- if self.is_multilingual:
- for val in list(self.eval_langs.values()):
+ if self.is_multilingual and isinstance(self.eval_langs, dict):
+ for val in self.eval_langs.values():
  languages.extend(val)
  else:
- languages = self.eval_langs
+ languages = cast(list[str], self.eval_langs)
  # value "python" is not valid. It must be an ISO 639-1, 639-2 or 639-3 code (two/three letters),
  # or a special value like "code", "multilingual".
  readme_langs = []
@@ -710,7 +714,7 @@ class TaskMetadata(BaseModel):
  readme_langs.append(lang_name)
  return sorted(set(readme_langs))
 
- def _hf_license(self) -> str:
+ def _hf_license(self) -> str | None:
  dataset_license = self.license
  if dataset_license:
  license_mapping = {
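
The `PromptDict` reformatting keeps the functional `TypedDict` form whose keys are generated from an enum at import time; static checkers cannot verify dynamically computed keys, which is what the `# type: ignore[misc]` acknowledges. A small self-contained sketch of the same pattern (the enum here is a stand-in, not mteb's `PromptType`):

```python
from enum import Enum
from typing import TypedDict


class Prompt(Enum):
    query = "query"
    document = "document"


# Keys are computed at runtime, so the checker cannot validate them statically;
# at runtime this behaves like a plain dict with optional keys.
PromptDict = TypedDict(  # type: ignore[misc]
    "PromptDict",
    {prompt.value: str for prompt in Prompt},
    total=False,
)

prompts: PromptDict = {"query": "Represent the question for retrieval:"}
```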
mteb/abstasks/text/bitext_mining.py CHANGED
@@ -1,7 +1,7 @@
  import logging
  from collections import defaultdict
  from pathlib import Path
- from typing import Any, ClassVar, TypedDict
+ from typing import Any, ClassVar, TypedDict, cast
 
  from datasets import Dataset, DatasetDict
  from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
@@ -78,6 +78,9 @@ class AbsTaskBitextMining(AbsTask):
  **kwargs: Any,
  ) -> dict[HFSubset, ScoresDict]:
  """Added load for "parallel" datasets"""
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  if not self.data_loaded:
  self.load_data()
 
@@ -87,11 +90,16 @@ class AbsTaskBitextMining(AbsTask):
  if subsets_to_run is not None:
  hf_subsets = [s for s in hf_subsets if s in subsets_to_run]
 
- scores = {}
+ encoder_model = cast(EncoderProtocol, model)
+
+ if self.dataset is None:
+ raise ValueError("Dataset is not loaded.")
+
+ scores: dict[str, BitextMiningMetrics] = {}
  if self.parallel_subsets:
- scores = self._evaluate_subset(
- model,
- self.dataset[split], # type: ignore
+ scores = self._evaluate_subset( # type: ignore[assignment]
+ encoder_model,
+ self.dataset[split],
  parallel=True,
  hf_split=split,
  hf_subset="parallel",
@@ -109,8 +117,8 @@ class AbsTaskBitextMining(AbsTask):
  data_split = self.dataset[split]
  else:
  data_split = self.dataset[hf_subset][split]
- scores[hf_subset] = self._evaluate_subset(
- model,
+ scores[hf_subset] = self._evaluate_subset( # type: ignore[assignment]
+ encoder_model,
  data_split,
  hf_split=split,
  hf_subset=hf_subset,
@@ -119,32 +127,32 @@ class AbsTaskBitextMining(AbsTask):
  **kwargs,
  )
 
- return scores
+ return cast(dict[HFSubset, ScoresDict], scores)
 
  def _get_pairs(self, parallel: bool) -> list[tuple[str, str]]:
  pairs = self._DEFAULT_PAIR
  if parallel:
- pairs = [langpair.split("-") for langpair in self.hf_subsets]
+ pairs = [langpair.split("-") for langpair in self.hf_subsets] # type: ignore[misc]
  return pairs
 
- def _evaluate_subset(
+ def _evaluate_subset( # type: ignore[override]
  self,
  model: EncoderProtocol,
  data_split: Dataset,
  *,
  hf_split: str,
  hf_subset: str,
- parallel: bool = False,
  encode_kwargs: dict[str, Any],
  prediction_folder: Path | None = None,
+ parallel: bool = False,
  **kwargs,
- ) -> ScoresDict:
+ ) -> BitextMiningMetrics | dict[str, BitextMiningMetrics]:
  pairs = self._get_pairs(parallel)
 
  evaluator = BitextMiningEvaluator(
  data_split,
  task_metadata=self.metadata,
- pair_columns=pairs, # type: ignore
+ pair_columns=pairs,
  hf_split=hf_split,
  hf_subset=hf_subset,
  **kwargs,
@@ -168,16 +176,16 @@ class AbsTaskBitextMining(AbsTask):
  )
 
  if parallel:
- metrics = {}
+ parallel_metrics = {}
  for keys, nearest_neighbors in neighbours.items():
- metrics[keys] = self._compute_metrics(nearest_neighbors, gold)
+ parallel_metrics[keys] = self._compute_metrics(nearest_neighbors, gold)
 
- for v in metrics.values():
+ for v in parallel_metrics.values():
  self._add_main_score(v)
- else:
- def_pair_str = "-".join(self._DEFAULT_PAIR[0])
- metrics = self._compute_metrics(neighbours[def_pair_str], gold)
- self._add_main_score(metrics)
+ return parallel_metrics
+ def_pair_str = "-".join(self._DEFAULT_PAIR[0])
+ metrics = self._compute_metrics(neighbours[def_pair_str], gold)
+ self._add_main_score(metrics)
  return metrics
 
  def _compute_metrics(
@@ -250,8 +258,11 @@ class AbsTaskBitextMining(AbsTask):
  )
 
  def _push_dataset_to_hub(self, repo_name: str) -> None:
+ if self.dataset is None:
+ raise ValueError("Dataset is not loaded.")
+
  if self.metadata.is_multilingual:
- dataset = defaultdict(dict)
+ dataset: dict[str, dict[str, list[str]]] = defaultdict(dict)
  for config in self.metadata.eval_langs:
  logger.info(f"Converting {config} of {self.metadata.name}")
 
@@ -266,10 +277,10 @@ class AbsTaskBitextMining(AbsTask):
  for split in self.dataset[config]:
  dataset[split][lang_1] = self.dataset[config][split][sent_1]
  dataset[split][lang_2] = self.dataset[config][split][sent_2]
- for split in dataset:
- dataset[split] = Dataset.from_dict(dataset[split])
- dataset = DatasetDict(dataset)
- dataset.push_to_hub(repo_name)
+ dataset_dict = DatasetDict(
+ {split: Dataset.from_dict(dataset[split]) for split in dataset}
+ )
+ dataset_dict.push_to_hub(repo_name)
  else:
  sentences = {}
  for split in self.dataset:
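
The `_push_dataset_to_hub` change stops re-binding `dataset` from a dict of column lists to a `DatasetDict` and builds the `DatasetDict` in a single expression instead. Roughly, with toy data (the real method also handles the non-multilingual branch):

```python
from datasets import Dataset, DatasetDict

# Per-split columns collected from the parallel subsets: language -> sentences.
dataset: dict[str, dict[str, list[str]]] = {
    "test": {"eng": ["hello"], "deu": ["hallo"]},
}

# One expression, no re-binding of `dataset` to a different type.
dataset_dict = DatasetDict(
    {split: Dataset.from_dict(columns) for split, columns in dataset.items()}
)
# dataset_dict.push_to_hub(repo_name)  # pushed to the Hub in the real method
```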
mteb/abstasks/text/reranking.py CHANGED
@@ -16,7 +16,7 @@ else:
 
  logger = logging.getLogger(__name__)
 
- OLD_FORMAT_RERANKING_TASKS = []
+ OLD_FORMAT_RERANKING_TASKS: list[str] = []
 
 
  @deprecated(
@@ -105,7 +105,9 @@ class AbsTaskReranking(AbsTaskRetrieval):
  )
 
  given_dataset = copy(given_dataset)
- self.dataset = defaultdict(lambda: defaultdict(dict))
+ self.dataset: dict[str, dict[str, RetrievalSplitData]] = defaultdict(
+ lambda: defaultdict(dict) # type: ignore[arg-type]
+ )
 
  hf_subsets = self.hf_subsets
 
@@ -115,19 +117,19 @@ class AbsTaskReranking(AbsTaskRetrieval):
  if hf_subset in cur_dataset:
  cur_dataset = cur_dataset[hf_subset]
  elif "name" in self.metadata.dataset:
- cur_dataset = datasets.load_dataset(**self.metadata.dataset) # type: ignore
+ cur_dataset = datasets.load_dataset(**self.metadata.dataset)
  assert hf_subset == "default", (
  f"Only default subset is supported for {self.metadata.name} since `name` is given in the metadata."
  )
  else:
  cur_dataset = datasets.load_dataset(
  **self.metadata.dataset, name=hf_subset
- ) # type: ignore
+ )
 
  for split in cur_dataset:
  corpus = []
  queries = []
- relevant_docs = defaultdict(dict)
+ relevant_docs: dict[str, dict[str, int]] = defaultdict(dict)
  top_ranked = defaultdict(list)
 
  # Create an enumerated dataset to pass indices
mteb/abstasks/text/summarization.py CHANGED
@@ -12,7 +12,7 @@ from mteb.abstasks._statistics_calculation import (
  calculate_text_statistics,
  )
  from mteb.abstasks.abstask import AbsTask
- from mteb.models import EncoderProtocol
+ from mteb.models import EncoderProtocol, MTEBModels
  from mteb.types.statistics import (
  ScoreStatistics,
  SplitDescriptiveStatistics,
@@ -77,7 +77,7 @@ class AbsTaskSummarization(AbsTask):
 
  def _evaluate_subset(
  self,
- model: EncoderProtocol,
+ model: MTEBModels,
  data_split: Dataset,
  *,
  hf_split: str,
@@ -86,8 +86,13 @@ class AbsTaskSummarization(AbsTask):
  prediction_folder: Path | None = None,
  **kwargs,
  ) -> SummarizationMetrics:
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  normalized_scores = [
- (np.array(x) - self.min_score) / (self.max_score - self.min_score)
+ (
+ (np.array(x) - self.min_score) / (self.max_score - self.min_score)
+ ).tolist()
  for x in data_split[self.relevancy_column_name]
  ]
  evaluator = self.evaluator(
mteb/abstasks/zeroshot_classification.py CHANGED
@@ -7,7 +7,7 @@ from datasets import Dataset
  from sklearn import metrics
 
  from mteb._evaluators import ZeroShotClassificationEvaluator
- from mteb.models import EncoderProtocol
+ from mteb.models import EncoderProtocol, MTEBModels
  from mteb.types.statistics import (
  ImageStatistics,
  LabelStatistics,
@@ -111,7 +111,7 @@ class AbsTaskZeroShotClassification(AbsTask):
 
  def _evaluate_subset(
  self,
- model: EncoderProtocol,
+ model: MTEBModels,
  data_split: Dataset,
  *,
  hf_split: str,
@@ -120,6 +120,9 @@ class AbsTaskZeroShotClassification(AbsTask):
  prediction_folder: Path | None = None,
  **kwargs,
  ) -> ZeroShotClassificationMetrics:
+ if not isinstance(model, EncoderProtocol):
+ raise TypeError("Expected model to be an instance of EncoderProtocol")
+
  candidate_labels = self.get_candidate_labels()
  data_split = data_split.select_columns(
  [self.input_column_name, self.label_column_name]
mteb/benchmarks/benchmark.py CHANGED
@@ -1,6 +1,6 @@
  from __future__ import annotations
 
- from collections.abc import Iterable, Sequence
+ from collections.abc import Iterator, Sequence
  from dataclasses import dataclass, field
  from typing import TYPE_CHECKING, Literal
 
@@ -47,7 +47,7 @@ class Benchmark:
  display_name: str | None = None
  language_view: list[str] | Literal["all"] = field(default_factory=list)
 
- def __iter__(self) -> Iterable[AbsTask]:
+ def __iter__(self) -> Iterator[AbsTask]:
  return iter(self.tasks)
 
  def __len__(self) -> int:
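
`Benchmark.__iter__` is now annotated as returning `Iterator[AbsTask]` rather than `Iterable[AbsTask]`, matching what `iter()` actually returns. A minimal sketch of the same idea (the `Bundle` class is invented for the example):

```python
from collections.abc import Iterator
from dataclasses import dataclass, field


@dataclass
class Bundle:
    tasks: list[str] = field(default_factory=list)

    def __iter__(self) -> Iterator[str]:
        # iter() on a list returns an iterator, which is what __iter__ must produce.
        return iter(self.tasks)


for task_name in Bundle(tasks=["STS12", "NFCorpus"]):
    print(task_name)
```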
mteb/cache.py CHANGED
@@ -5,7 +5,7 @@ import shutil
  import subprocess
  import warnings
  from collections import defaultdict
- from collections.abc import Sequence
+ from collections.abc import Iterable, Sequence
  from pathlib import Path
  from typing import cast
 
@@ -291,8 +291,8 @@ class ResultCache:
 
  def get_cache_paths(
  self,
- models: Sequence[str] | Sequence[ModelMeta] | None = None,
- tasks: Sequence[str] | Sequence[AbsTask] | None = None,
+ models: Sequence[str] | Iterable[ModelMeta] | None = None,
+ tasks: Sequence[str] | Iterable[AbsTask] | None = None,
  require_model_meta: bool = True,
  include_remote: bool = True,
  ) -> list[Path]:
@@ -425,7 +425,7 @@ class ResultCache:
  @staticmethod
  def _filter_paths_by_model_and_revision(
  paths: list[Path],
- models: Sequence[str] | Sequence[ModelMeta] | None = None,
+ models: Sequence[str] | Iterable[ModelMeta] | None = None,
  ) -> list[Path]:
  """Filter a list of paths by model name and optional revision.
 
@@ -435,8 +435,9 @@ class ResultCache:
  if not models:
  return paths
 
- if isinstance(models[0], ModelMeta):
- models = cast(list[ModelMeta], models)
+ first_model = next(iter(models))
+ if isinstance(first_model, ModelMeta):
+ models = cast(Iterable[ModelMeta], models)
  name_and_revision = {
  (m.model_name_as_path(), m.revision or "no_revision_available")
  for m in models
@@ -447,13 +448,14 @@ class ResultCache:
  if (p.parent.parent.name, p.parent.name) in name_and_revision
  ]
 
- model_names = {m.replace("/", "__").replace(" ", "_") for m in models}
+ str_models = cast(Sequence[str], models)
+ model_names = {m.replace("/", "__").replace(" ", "_") for m in str_models}
  return [p for p in paths if p.parent.parent.name in model_names]
 
  @staticmethod
  def _filter_paths_by_task(
  paths: list[Path],
- tasks: Sequence[str] | Sequence[AbsTask] | None = None,
+ tasks: Sequence[str] | Iterable[AbsTask] | None = None,
  ) -> list[Path]:
  if tasks is not None:
  task_names = set()
@@ -469,8 +471,8 @@ class ResultCache:
 
  def load_results(
  self,
- models: Sequence[str] | Sequence[ModelMeta] | None = None,
- tasks: Sequence[str] | Sequence[AbsTask] | Benchmark | str | None = None,
+ models: Sequence[str] | Iterable[ModelMeta] | None = None,
+ tasks: Sequence[str] | Iterable[AbsTask] | str | None = None,
  require_model_meta: bool = True,
  include_remote: bool = True,
  validate_and_filter: bool = False,
@@ -514,7 +516,7 @@ class ResultCache:
  )
  models_results = defaultdict(list)
 
- task_names = {}
+ task_names: dict[str, AbsTask | None] = {}
  if tasks is not None:
  for task in tasks:
  if isinstance(task, AbsTask):
@@ -532,9 +534,11 @@ class ResultCache:
  )
 
  if validate_and_filter:
- task = task_names[task_result.task_name]
+ task_instance = task_names[task_result.task_name]
  try:
- task_result = task_result.validate_and_filter_scores(task=task)
+ task_result = task_result.validate_and_filter_scores(
+ task=task_instance
+ )
  except Exception as e:
  logger.info(
  f"Validation failed for {task_result.task_name} in {model_name} {revision}: {e}"
@@ -544,7 +548,7 @@ class ResultCache:
  models_results[(model_name, revision)].append(task_result)
 
  # create BenchmarkResults object
- models_results = [
+ models_results_object = [
  ModelResult(
  model_name=model_name,
  model_revision=revision,
@@ -553,9 +557,7 @@ class ResultCache:
  for (model_name, revision), task_results in models_results.items()
  ]
 
- benchmark_results = BenchmarkResults(
- model_results=models_results,
+ return BenchmarkResults(
+ model_results=models_results_object,
  benchmark=tasks if isinstance(tasks, Benchmark) else None,
  )
-
- return benchmark_results
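
Widening `Sequence[...]` to `Iterable[...]` in `ResultCache` means the code can no longer index into `models`, hence `next(iter(models))` instead of `models[0]`. A short sketch of the difference, simplified to strings:

```python
from collections.abc import Iterable, Sequence


def first_from_sequence(models: Sequence[str]) -> str:
    return models[0]  # indexing requires a Sequence


def first_from_iterable(models: Iterable[str]) -> str:
    return next(iter(models))  # also works for generators, sets, dict views


assert first_from_iterable(m for m in ("model-a", "model-b")) == "model-a"
```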
mteb/cli/_display_tasks.py CHANGED
@@ -1,4 +1,4 @@
- from collections.abc import Sequence
+ from collections.abc import Iterable, Sequence
 
  from mteb.abstasks import AbsTask
  from mteb.benchmarks import Benchmark
@@ -31,7 +31,7 @@ def _display_benchmarks(benchmarks: Sequence[Benchmark]) -> None:
  _display_tasks(benchmark.tasks, name=name)
 
 
- def _display_tasks(task_list: Sequence[AbsTask], name: str | None = None) -> None:
+ def _display_tasks(task_list: Iterable[AbsTask], name: str | None = None) -> None:
  from rich.console import Console
 
  console = Console()
mteb/cli/build_cli.py CHANGED
@@ -8,12 +8,12 @@ import torch
  from rich.logging import RichHandler
 
  import mteb
+ from mteb.abstasks.abstask import AbsTask
  from mteb.cache import ResultCache
+ from mteb.cli._display_tasks import _display_benchmarks, _display_tasks
  from mteb.cli.generate_model_card import generate_model_card
  from mteb.evaluate import OverwriteStrategy
 
- from ._display_tasks import _display_benchmarks, _display_tasks
-
  logger = logging.getLogger(__name__)
 
 
@@ -54,7 +54,7 @@ def run(args: argparse.Namespace) -> None:
 
  if args.benchmarks:
  benchmarks = mteb.get_benchmarks(names=args.benchmarks)
- tasks = [t for b in benchmarks for t in b.tasks]
+ tasks = tuple(t for b in benchmarks for t in b.tasks)
  else:
  tasks = mteb.get_tasks(
  categories=args.categories,
@@ -290,9 +290,9 @@ def _create_meta(args: argparse.Namespace) -> None:
  "Output path already exists, use --overwrite to overwrite."
  )
 
- tasks = []
+ tasks: list[AbsTask] = []
  if tasks_names is not None:
- tasks = mteb.get_tasks(tasks_names)
+ tasks = list(mteb.get_tasks(tasks_names))
  if benchmarks is not None:
  benchmarks = mteb.get_benchmarks(benchmarks)
  for benchmark in benchmarks: