mteb 2.5.2__py3-none-any.whl → 2.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. mteb/_create_dataloaders.py +10 -15
  2. mteb/_evaluators/any_sts_evaluator.py +1 -4
  3. mteb/_evaluators/evaluator.py +2 -1
  4. mteb/_evaluators/image/imagetext_pairclassification_evaluator.py +5 -6
  5. mteb/_evaluators/pair_classification_evaluator.py +3 -1
  6. mteb/_evaluators/retrieval_metrics.py +17 -16
  7. mteb/_evaluators/sklearn_evaluator.py +9 -8
  8. mteb/_evaluators/text/bitext_mining_evaluator.py +23 -16
  9. mteb/_evaluators/text/summarization_evaluator.py +20 -16
  10. mteb/abstasks/_data_filter/filters.py +1 -1
  11. mteb/abstasks/_data_filter/task_pipelines.py +3 -0
  12. mteb/abstasks/_statistics_calculation.py +18 -10
  13. mteb/abstasks/_stratification.py +18 -18
  14. mteb/abstasks/abstask.py +33 -27
  15. mteb/abstasks/aggregate_task_metadata.py +1 -9
  16. mteb/abstasks/aggregated_task.py +7 -26
  17. mteb/abstasks/classification.py +10 -4
  18. mteb/abstasks/clustering.py +18 -14
  19. mteb/abstasks/clustering_legacy.py +8 -8
  20. mteb/abstasks/image/image_text_pair_classification.py +5 -3
  21. mteb/abstasks/multilabel_classification.py +20 -16
  22. mteb/abstasks/pair_classification.py +18 -9
  23. mteb/abstasks/regression.py +3 -3
  24. mteb/abstasks/retrieval.py +12 -9
  25. mteb/abstasks/sts.py +6 -3
  26. mteb/abstasks/task_metadata.py +22 -19
  27. mteb/abstasks/text/bitext_mining.py +36 -25
  28. mteb/abstasks/text/reranking.py +7 -5
  29. mteb/abstasks/text/summarization.py +8 -3
  30. mteb/abstasks/zeroshot_classification.py +5 -2
  31. mteb/benchmarks/benchmark.py +2 -2
  32. mteb/cache.py +27 -22
  33. mteb/cli/_display_tasks.py +2 -2
  34. mteb/cli/build_cli.py +15 -10
  35. mteb/cli/generate_model_card.py +10 -7
  36. mteb/deprecated_evaluator.py +60 -46
  37. mteb/evaluate.py +39 -30
  38. mteb/filter_tasks.py +25 -26
  39. mteb/get_tasks.py +29 -30
  40. mteb/languages/language_scripts.py +5 -3
  41. mteb/leaderboard/app.py +1 -1
  42. mteb/load_results.py +12 -12
  43. mteb/models/abs_encoder.py +7 -5
  44. mteb/models/cache_wrappers/cache_backend_protocol.py +3 -5
  45. mteb/models/cache_wrappers/cache_backends/_hash_utils.py +5 -4
  46. mteb/models/cache_wrappers/cache_backends/faiss_cache.py +6 -2
  47. mteb/models/cache_wrappers/cache_backends/numpy_cache.py +43 -25
  48. mteb/models/cache_wrappers/cache_wrapper.py +2 -2
  49. mteb/models/get_model_meta.py +8 -1
  50. mteb/models/instruct_wrapper.py +11 -5
  51. mteb/models/model_implementations/andersborges.py +2 -2
  52. mteb/models/model_implementations/blip_models.py +8 -8
  53. mteb/models/model_implementations/bm25.py +1 -1
  54. mteb/models/model_implementations/clip_models.py +3 -3
  55. mteb/models/model_implementations/cohere_models.py +1 -1
  56. mteb/models/model_implementations/cohere_v.py +2 -2
  57. mteb/models/model_implementations/dino_models.py +23 -23
  58. mteb/models/model_implementations/emillykkejensen_models.py +3 -3
  59. mteb/models/model_implementations/gme_v_models.py +4 -3
  60. mteb/models/model_implementations/jina_clip.py +1 -1
  61. mteb/models/model_implementations/jina_models.py +1 -1
  62. mteb/models/model_implementations/kennethenevoldsen_models.py +2 -2
  63. mteb/models/model_implementations/llm2clip_models.py +3 -3
  64. mteb/models/model_implementations/mcinext_models.py +4 -1
  65. mteb/models/model_implementations/moco_models.py +2 -2
  66. mteb/models/model_implementations/model2vec_models.py +1 -1
  67. mteb/models/model_implementations/nomic_models.py +8 -8
  68. mteb/models/model_implementations/openclip_models.py +7 -7
  69. mteb/models/model_implementations/random_baseline.py +3 -3
  70. mteb/models/model_implementations/rasgaard_models.py +1 -1
  71. mteb/models/model_implementations/repllama_models.py +2 -2
  72. mteb/models/model_implementations/rerankers_custom.py +3 -3
  73. mteb/models/model_implementations/rerankers_monot5_based.py +3 -3
  74. mteb/models/model_implementations/siglip_models.py +10 -10
  75. mteb/models/model_implementations/vlm2vec_models.py +1 -1
  76. mteb/models/model_implementations/voyage_v.py +4 -4
  77. mteb/models/model_meta.py +14 -13
  78. mteb/models/search_encoder_index/search_indexes/faiss_search_index.py +9 -6
  79. mteb/models/search_wrappers.py +26 -12
  80. mteb/models/sentence_transformer_wrapper.py +19 -14
  81. mteb/py.typed +0 -0
  82. mteb/results/benchmark_results.py +28 -20
  83. mteb/results/model_result.py +52 -22
  84. mteb/results/task_result.py +55 -58
  85. mteb/similarity_functions.py +11 -7
  86. mteb/tasks/classification/dan/dk_hate_classification.py +1 -1
  87. mteb/tasks/classification/est/estonian_valence.py +1 -1
  88. mteb/tasks/classification/multilingual/scala_classification.py +1 -1
  89. mteb/tasks/image_text_pair_classification/eng/sugar_crepe.py +1 -1
  90. mteb/tasks/retrieval/code/code_rag.py +12 -12
  91. mteb/tasks/retrieval/dan/dan_fever_retrieval.py +1 -1
  92. mteb/tasks/retrieval/dan/tv2_nordretrieval.py +2 -2
  93. mteb/tasks/retrieval/dan/twitter_hjerne_retrieval.py +2 -2
  94. mteb/tasks/retrieval/nob/norquad.py +2 -2
  95. mteb/tasks/retrieval/nob/snl_retrieval.py +2 -2
  96. mteb/tasks/retrieval/tur/tur_hist_quad.py +1 -1
  97. mteb/types/_result.py +2 -1
  98. mteb/types/statistics.py +9 -3
  99. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/METADATA +1 -1
  100. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/RECORD +104 -103
  101. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/WHEEL +0 -0
  102. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/entry_points.txt +0 -0
  103. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/licenses/LICENSE +0 -0
  104. {mteb-2.5.2.dist-info → mteb-2.5.4.dist-info}/top_level.txt +0 -0
mteb/abstasks/abstask.py CHANGED
@@ -1,10 +1,11 @@
 import json
 import logging
+import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from copy import copy
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Literal, cast

 import numpy as np
 from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
@@ -78,8 +79,8 @@ class AbsTask(ABC):
     """

     metadata: TaskMetadata
-    abstask_prompt: str | None = None
-    _eval_splits: list[str] | None = None
+    abstask_prompt: str
+    _eval_splits: Sequence[str] | None = None
     dataset: dict[HFSubset, DatasetDict] | None = None
     data_loaded: bool = False
     hf_subsets: list[HFSubset]
@@ -102,9 +103,9 @@ class AbsTask(ABC):
     def check_if_dataset_is_superseded(self) -> None:
         """Check if the dataset is superseded by a newer version."""
         if self.superseded_by:
-            logger.warning(
-                f"Dataset '{self.metadata.name}' is superseded by '{self.superseded_by}', you might consider using the newer version of the dataset."
-            )
+            msg = f"Dataset '{self.metadata.name}' is superseded by '{self.superseded_by}'. We recommend using the newer version of the dataset unless you are running a specific benchmark. See `get_task('{self.superseded_by}').metadata.description` to get a description of the task and changes."
+            logger.warning(msg)
+            warnings.warn(msg)

     def dataset_transform(self):
         """A transform operations applied to the dataset after loading.
@@ -123,7 +124,7 @@
         encode_kwargs: dict[str, Any],
         prediction_folder: Path | None = None,
         **kwargs: Any,
-    ) -> dict[HFSubset, ScoresDict]:
+    ) -> Mapping[HFSubset, ScoresDict]:
         """Evaluates an MTEB compatible model on the task.

         Args:
@@ -195,12 +196,12 @@
     @abstractmethod
     def _evaluate_subset(
         self,
-        model: EncoderProtocol,
+        model: MTEBModels,
         data_split: Dataset,
         *,
-        encode_kwargs: dict[str, Any],
         hf_split: str,
         hf_subset: str,
+        encode_kwargs: dict[str, Any],
         prediction_folder: Path | None = None,
         **kwargs: Any,
     ) -> ScoresDict:
@@ -210,7 +211,7 @@

     def _save_task_predictions(
         self,
-        predictions: dict[str, Any] | list[Any],
+        predictions: Mapping[str, Any] | list[Any],
         model: MTEBModels,
         prediction_folder: Path,
         hf_split: str,
@@ -226,7 +227,7 @@
             hf_subset: The subset of the dataset (e.g. "en").
         """
         predictions_path = self._predictions_path(prediction_folder)
-        existing_results = {
+        existing_results: dict[str, Any] = {
             "mteb_model_meta": {
                 "model_name": model.mteb_model_meta.name,
                 "revision": model.mteb_model_meta.revision,
@@ -326,7 +327,7 @@
             )
         else:
             # some of monolingual datasets explicitly adding the split name to the dataset name
-            self.dataset = load_dataset(**self.metadata.dataset)  # type: ignore
+            self.dataset = load_dataset(**self.metadata.dataset)
         self.dataset_transform()
         self.data_loaded = True

@@ -362,15 +363,19 @@
         """
         from mteb.abstasks import AbsTaskClassification

-        if self.metadata.descriptive_stat_path.exists() and not overwrite_results:
+        existing_stats = self.metadata.descriptive_stats
+
+        if existing_stats is not None and not overwrite_results:
             logger.info("Loading metadata descriptive statistics from cache.")
-            return self.metadata.descriptive_stats
+            return existing_stats

         if not self.data_loaded:
             self.load_data()

         descriptive_stats: dict[str, DescriptiveStatistics] = {}
-        hf_subset_stat = "hf_subset_descriptive_stats"
+        hf_subset_stat: Literal["hf_subset_descriptive_stats"] = (
+            "hf_subset_descriptive_stats"
+        )
         eval_splits = self.metadata.eval_splits
         if isinstance(self, AbsTaskClassification):
             eval_splits.append(self.train_split)
@@ -381,7 +386,7 @@
             logger.info(f"Processing metadata for split {split}")
             if self.metadata.is_multilingual:
                 descriptive_stats[split] = (
-                    self._calculate_descriptive_statistics_from_split(
+                    self._calculate_descriptive_statistics_from_split(  # type: ignore[assignment]
                         split, compute_overall=True
                     )
                 )
@@ -400,7 +405,7 @@
                 descriptive_stats[split][hf_subset_stat][hf_subset] = split_details
             else:
                 split_details = self._calculate_descriptive_statistics_from_split(split)
-                descriptive_stats[split] = split_details
+                descriptive_stats[split] = split_details  # type: ignore[assignment]

         with self.metadata.descriptive_stat_path.open("w") as f:
             json.dump(descriptive_stats, f, indent=4)
@@ -437,7 +442,7 @@

         return self.metadata.languages

-    def filter_eval_splits(self, eval_splits: list[str] | None) -> Self:
+    def filter_eval_splits(self, eval_splits: Sequence[str] | None) -> Self:
         """Filter the evaluation splits of the task.

         Args:
@@ -451,9 +456,9 @@

     def filter_languages(
         self,
-        languages: list[str] | None,
-        script: list[str] | None = None,
-        hf_subsets: list[HFSubset] | None = None,
+        languages: Sequence[str] | None,
+        script: Sequence[str] | None = None,
+        hf_subsets: Sequence[HFSubset] | None = None,
         exclusive_language_filter: bool = False,
     ) -> Self:
         """Filter the languages of the task.
@@ -499,12 +504,14 @@
         self.hf_subsets = subsets_to_keep
         return self

-    def _add_main_score(self, scores: dict[HFSubset, ScoresDict]) -> None:
+    def _add_main_score(self, scores: ScoresDict) -> None:
         scores["main_score"] = scores[self.metadata.main_score]

     def _upload_dataset_to_hub(
         self, repo_name: str, fields: list[str] | dict[str, str]
     ) -> None:
+        if self.dataset is None:
+            raise ValueError("Dataset not loaded")
         if self.metadata.is_multilingual:
             for config in self.metadata.eval_langs:
                 logger.info(f"Converting {config} of {self.metadata.name}")
@@ -574,7 +581,7 @@
         return False

     @property
-    def eval_splits(self) -> list[str]:
+    def eval_splits(self) -> Sequence[str]:
         """Returns the evaluation splits of the task."""
         if self._eval_splits:
             return self._eval_splits
@@ -607,9 +614,8 @@
             self.data_loaded = False
             logger.info(f"Unloaded dataset {self.metadata.name} from memory.")
         else:
-            logger.warning(
-                f"Dataset {self.metadata.name} is not loaded, cannot unload it."
-            )
+            msg = f"Dataset `{self.metadata.name}` is not loaded, cannot unload it."
+            logger.warning(msg)

     @property
     def superseded_by(self) -> str | None:
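
Note on the check_if_dataset_is_superseded() change above: superseded tasks now raise a Python warning (warnings.warn, default UserWarning category) in addition to the log message. A minimal sketch of how a caller could keep the old, log-only behaviour; the task name below is a placeholder, only get_task and check_if_dataset_is_superseded are taken from the diff:

import warnings

import mteb

task = mteb.get_task("SomeSupersededTask")  # placeholder name, not a real task
with warnings.catch_warnings():
    # Match the start of the new warning message emitted by the code above.
    warnings.filterwarnings("ignore", message=r"Dataset .* is superseded by")
    task.check_if_dataset_is_superseded()
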
mteb/abstasks/aggregate_task_metadata.py CHANGED
@@ -5,7 +5,6 @@ from pydantic import ConfigDict, Field, model_validator
 from typing_extensions import Self

 from mteb.types import (
-    HFSubset,
     ISOLanguageScript,
     Languages,
     Licenses,
@@ -60,14 +59,7 @@ class AggregateTaskMetadata(TaskMetadata):
     reference: str | None = None
     bibtex_citation: str | None = None

-    @property
-    def hf_subsets_to_langscripts(self) -> dict[HFSubset, list[ISOLanguageScript]]:
-        """Return a dictionary mapping huggingface subsets to languages."""
-        if isinstance(self.eval_langs, dict):
-            return self.eval_langs
-        return {"default": self.eval_langs}  # type: ignore
-
-    @model_validator(mode="after")  # type: ignore
+    @model_validator(mode="after")
     def _compute_unfilled_cases(self) -> Self:
         if not self.eval_langs:
             self.eval_langs = self._compute_eval_langs()
mteb/abstasks/aggregated_task.py CHANGED
@@ -1,10 +1,11 @@
 import logging
+import warnings
+from collections.abc import Mapping
 from pathlib import Path
 from typing import Any

 import numpy as np
 from datasets import Dataset, DatasetDict
-from typing_extensions import Self

 from mteb.models.models_protocols import MTEBModels
 from mteb.results.task_result import TaskResult
@@ -32,7 +33,7 @@ class AbsTaskAggregate(AbsTask):

     def task_results_to_scores(
         self, task_results: list[TaskResult]
-    ) -> dict[str, dict[HFSubset, ScoresDict]]:
+    ) -> dict[str, Mapping[HFSubset, ScoresDict]]:
         """The function that aggregated scores. Can be redefined to allow for custom aggregations.

         Args:
@@ -41,7 +42,7 @@ class AbsTaskAggregate(AbsTask):
         Returns:
             A dictionary with the aggregated scores.
         """
-        scores = {}
+        scores: dict[str, Mapping[HFSubset, ScoresDict]] = {}
         subsets = (
             self.metadata.eval_langs.keys()
             if isinstance(self.metadata.eval_langs, dict)
@@ -113,33 +114,13 @@
         )
         mteb_versions = {tr.mteb_version for tr in task_results}
         if len(mteb_versions) != 1:
-            logger.warning(
-                f"All tasks of {self.metadata.name} is not run using the same version."
-            )
+            msg = f"All tasks of {self.metadata.name} is not run using the same version. different versions found are: {mteb_versions}"
+            logger.warning(msg)
+            warnings.warn(msg)
             task_res.mteb_version = None
         task_res.mteb_version = task_results[0].mteb_version
         return task_res

-    def check_if_dataset_is_superseded(self) -> None:
-        """Check if the dataset is superseded by a newer version"""
-        if self.superseded_by:
-            logger.warning(
-                f"Dataset '{self.metadata.name}' is superseded by '{self.superseded_by}', you might consider using the newer version of the dataset."
-            )
-
-    def filter_eval_splits(self, eval_splits: list[str] | None) -> Self:
-        """Filter the evaluation splits of the task.
-
-        Args:
-            eval_splits: List of splits to evaluate on. If None, all splits in metadata
-                are used.
-
-        Returns:
-            The task with filtered evaluation splits.
-        """
-        self._eval_splits = eval_splits
-        return self
-
     def evaluate(
         self,
         model: MTEBModels,
mteb/abstasks/classification.py CHANGED
@@ -143,6 +143,9 @@ class AbsTaskClassification(AbsTask):
         if not self.data_loaded:
             self.load_data()

+        if self.dataset is None:
+            raise RuntimeError("Dataset not loaded.")
+
         if "random_state" in self.evaluator_model.get_params():
             self.evaluator_model = self.evaluator_model.set_params(
                 random_state=self.seed
@@ -175,11 +178,11 @@
             )
             self._add_main_score(scores[hf_subset])

-        return scores
+        return scores  # type: ignore[return-value]

     def _evaluate_subset(
         self,
-        model: EncoderProtocol,
+        model: MTEBModels,
         data_split: DatasetDict,
         *,
         encode_kwargs: dict[str, Any],
@@ -188,6 +191,9 @@ class AbsTaskClassification(AbsTask):
         prediction_folder: Path | None = None,
         **kwargs: Any,
     ) -> FullClassificationMetrics:
+        if not isinstance(model, EncoderProtocol):
+            raise TypeError("Expected model to be an instance of EncoderProtocol")
+
         train_split = data_split[self.train_split]
         eval_split = data_split[hf_split]

@@ -237,7 +243,7 @@ class AbsTaskClassification(AbsTask):
             # ap will be none for non binary classification tasks
             k: (
                 float(np.mean(values))
-                if (values := [s[k] for s in scores if s[k] is not None])
+                if (values := [s[k] for s in scores if s[k] is not None])  # type: ignore[literal-required]
                 else np.nan
             )
             for k in scores[0].keys()
@@ -245,7 +251,7 @@
         logger.info(f"Running {self.metadata.name} - Finished.")
         return FullClassificationMetrics(
             scores_per_experiment=scores,
-            **avg_scores,
+            **avg_scores,  # type: ignore[typeddict-item]
         )

     def _calculate_scores(
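
The _evaluate_subset signatures in this release widen the model parameter to MTEBModels and add an isinstance() guard against EncoderProtocol, as in the hunks above and in the files below. A rough sketch of that guard pattern using a stand-in runtime-checkable protocol (not mteb's actual EncoderProtocol definition):

from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class _Encoder(Protocol):  # hypothetical stand-in for mteb's EncoderProtocol
    def encode(self, inputs, **kwargs) -> np.ndarray: ...


def _require_encoder(model: object) -> None:
    # isinstance() on a runtime-checkable Protocol only verifies that the
    # method exists, not its signature; this mirrors the guard added above.
    if not isinstance(model, _Encoder):
        raise TypeError("Expected model to be an instance of EncoderProtocol")
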
mteb/abstasks/clustering.py CHANGED
@@ -3,7 +3,7 @@ import logging
 import random
 from collections import defaultdict
 from pathlib import Path
-from typing import Any
+from typing import Any, cast

 import numpy as np
 from datasets import Dataset, DatasetDict
@@ -11,8 +11,8 @@ from sklearn.cluster import MiniBatchKMeans
 from sklearn.metrics.cluster import v_measure_score

 from mteb._create_dataloaders import create_dataloader
-from mteb.models import EncoderProtocol
-from mteb.types import HFSubset, ScoresDict
+from mteb.models import EncoderProtocol, MTEBModels
+from mteb.types import Array, HFSubset, ScoresDict
 from mteb.types.statistics import (
     ImageStatistics,
     LabelStatistics,
@@ -34,7 +34,7 @@ MultilingualDataset = dict[HFSubset, DatasetDict]


 def _evaluate_clustering_bootstrapped(
-    embeddings: np.ndarray,
+    embeddings: Array,
     labels: list[list[str]],
     n_clusters: int,
     cluster_size: int,
@@ -61,21 +61,21 @@ def _evaluate_clustering_bootstrapped(
     max_depth = max(map(len, labels))
     # Evaluate on each level til max depth
     for i_level in range(max_depth):
-        level_labels = []
+        level_labels: list[str | int] = []
         # Assign -1 to gold label if the level is not there
         for label in labels:
             if len(label) > i_level:
                 level_labels.append(label[i_level])
             else:
                 level_labels.append(-1)
-        level_labels = np.array(level_labels)
+        np_level_labels = np.array(level_labels)
         valid_idx = np.array(
-            [level_label != -1 for level_label in level_labels]
+            [level_label != -1 for level_label in np_level_labels]
         )  # Could be level_labels != -1 but fails with FutureWarning: elementwise comparison failed
-        level_labels = level_labels[valid_idx]
+        np_level_labels = np_level_labels[valid_idx]
         level_embeddings = embeddings[valid_idx]
         clustering_model = MiniBatchKMeans(
-            n_clusters=np.unique(level_labels).size,
+            n_clusters=np.unique(np_level_labels).size,
             batch_size=kmean_batch_size,
             init="k-means++",
             n_init=1,  # default when kmeans++ is used
@@ -87,7 +87,7 @@ def _evaluate_clustering_bootstrapped(
             cluster_indices = rng_state.choices(range(n_embeddings), k=cluster_size)

             _embeddings = level_embeddings[cluster_indices]
-            _labels = level_labels[cluster_indices]
+            _labels = np_level_labels[cluster_indices]
             cluster_assignment = clustering_model.fit_predict(_embeddings)
             v_measure = v_measure_score(_labels, cluster_assignment)
             v_measures[f"Level {i_level}"].append(v_measure)
@@ -153,7 +153,7 @@ class AbsTaskClustering(AbsTask):

     def _evaluate_subset(
         self,
-        model: EncoderProtocol,
+        model: MTEBModels,
         data_split: Dataset,
         *,
         encode_kwargs: dict[str, Any],
@@ -162,6 +162,10 @@ class AbsTaskClustering(AbsTask):
         prediction_folder: Path | None = None,
         **kwargs: Any,
     ) -> ScoresDict:
+        if not isinstance(model, EncoderProtocol):
+            raise TypeError(
+                "Expected encoder model to be an instance of EncoderProtocol."
+            )
         if (
             self.max_document_to_embed is not None
             and self.max_fraction_of_documents_to_embed is not None
@@ -182,13 +186,13 @@ class AbsTaskClustering(AbsTask):
                 self.max_fraction_of_documents_to_embed * len(data_split)
             )
         else:
-            max_documents_to_embed = self.max_document_to_embed
+            max_documents_to_embed = cast(int, self.max_document_to_embed)

-        max_documents_to_embed = min(len(data_split), max_documents_to_embed)  # type: ignore
+        max_documents_to_embed = min(len(data_split), max_documents_to_embed)
         example_indices = self.rng_state.sample(
             range(len(data_split)), k=max_documents_to_embed
         )
-        downsampled_dataset = data_split.select(example_indices)  # type: ignore
+        downsampled_dataset = data_split.select(example_indices)

         downsampled_dataset = downsampled_dataset.select_columns(
             [self.input_column_name, self.label_column_name]
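
The renamed np_level_labels variable above feeds the same bootstrapped v-measure loop as before. For orientation, an illustrative sketch of that sklearn pattern on synthetic data (sizes and parameters here are made up, not mteb's defaults):

import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics.cluster import v_measure_score

rng = np.random.default_rng(42)
embeddings = rng.normal(size=(256, 32))  # stand-in embeddings
labels = rng.integers(0, 4, size=256)    # stand-in gold labels

clustering_model = MiniBatchKMeans(
    n_clusters=np.unique(labels).size,
    batch_size=64,  # illustrative; the real code passes kmean_batch_size
    init="k-means++",
    n_init=1,  # default when k-means++ is used
)
idx = rng.choice(len(embeddings), size=128, replace=True)  # one bootstrap sample
assignment = clustering_model.fit_predict(embeddings[idx])
print(v_measure_score(labels[idx], assignment))
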
mteb/abstasks/clustering_legacy.py CHANGED
@@ -8,7 +8,7 @@ from scipy.optimize import linear_sum_assignment
 from sklearn import metrics

 from mteb._evaluators import ClusteringEvaluator
-from mteb.models import EncoderProtocol
+from mteb.models import EncoderProtocol, MTEBModels
 from mteb.types import ScoresDict
 from mteb.types.statistics import (
     ImageStatistics,
@@ -80,7 +80,7 @@ class AbsTaskClusteringLegacy(AbsTask):

     def _evaluate_subset(
         self,
-        model: EncoderProtocol,
+        model: MTEBModels,
         data_split: Dataset,
         *,
         encode_kwargs: dict[str, Any],
@@ -89,6 +89,9 @@ class AbsTaskClusteringLegacy(AbsTask):
         prediction_folder: Path | None = None,
         **kwargs: Any,
     ) -> ScoresDict:
+        if not isinstance(model, EncoderProtocol):
+            raise TypeError("Expected model to be an instance of EncoderProtocol")
+
         data_split = data_split.select_columns(
             [self.input_column_name, self.label_column_name]
         )
@@ -139,9 +142,6 @@ class AbsTaskClusteringLegacy(AbsTask):
             }
             return scores

-        data_split = data_split.select_columns(
-            [self.input_column_name, self.label_column_name]
-        )
         evaluator = self.evaluator(
             data_split,
             input_column_name=self.input_column_name,
@@ -151,10 +151,10 @@ class AbsTaskClusteringLegacy(AbsTask):
             hf_subset=hf_subset,
             **kwargs,
         )
-        clusters = evaluator(model, encode_kwargs=encode_kwargs)
+        evaluate_clusters = evaluator(model, encode_kwargs=encode_kwargs)
         if prediction_folder:
             self._save_task_predictions(
-                clusters,
+                evaluate_clusters,
                 model,
                 prediction_folder,
                 hf_subset=hf_subset,
@@ -163,7 +163,7 @@ class AbsTaskClusteringLegacy(AbsTask):

         return self._compute_metrics(
             data_split[self.label_column_name],
-            clusters,
+            evaluate_clusters,
         )

     def _compute_metrics(
mteb/abstasks/image/image_text_pair_classification.py CHANGED
@@ -12,7 +12,7 @@ from mteb.abstasks._statistics_calculation import (
     calculate_text_statistics,
 )
 from mteb.abstasks.abstask import AbsTask
-from mteb.models.models_protocols import EncoderProtocol
+from mteb.models.models_protocols import EncoderProtocol, MTEBModels
 from mteb.types.statistics import (
     ImageStatistics,
     SplitDescriptiveStatistics,
@@ -116,7 +116,7 @@ class AbsTaskImageTextPairClassification(AbsTask):

     def _evaluate_subset(
         self,
-        model: EncoderProtocol,
+        model: MTEBModels,
         data_split: Dataset,
         *,
         encode_kwargs: dict[str, Any],
@@ -125,6 +125,8 @@ class AbsTaskImageTextPairClassification(AbsTask):
         prediction_folder: Path | None = None,
         **kwargs: Any,
     ) -> ImageTextPairClassificationMetrics:
+        if not isinstance(model, EncoderProtocol):
+            raise TypeError("Expected model to be an instance of EncoderProtocol")
         select_columns = []
         for columns in (self.images_column_names, self.texts_column_names):
             if isinstance(columns, str):
@@ -154,7 +156,7 @@ class AbsTaskImageTextPairClassification(AbsTask):
             hf_subset=hf_subset,
             **kwargs,
         )
-        scores = evaluator(model, encode_kwargs=encode_kwargs)
+        scores: list[torch.Tensor] = evaluator(model, encode_kwargs=encode_kwargs)  # type: ignore[assignment]
         if prediction_folder:
             self._save_task_predictions(
                 [score.tolist() for score in scores],
mteb/abstasks/multilabel_classification.py CHANGED
@@ -16,7 +16,8 @@ from typing_extensions import override
 from mteb._create_dataloaders import create_dataloader
 from mteb._evaluators.classification_metrics import hamming_score
 from mteb._evaluators.sklearn_evaluator import SklearnModelProtocol
-from mteb.models import EncoderProtocol
+from mteb.models import EncoderProtocol, MTEBModels
+from mteb.types import Array

 from .classification import AbsTaskClassification

@@ -24,14 +25,14 @@ logger = logging.getLogger(__name__)


 def _evaluate_classifier(
-    embeddings_train: np.ndarray,
+    embeddings_train: Array,
     y_train: np.ndarray,
-    embeddings_test: np.ndarray,
+    embeddings_test: Array,
     classifier: SklearnModelProtocol,
 ) -> tuple[np.ndarray, SklearnModelProtocol]:
-    classifier: SklearnModelProtocol = clone(classifier)
-    classifier.fit(embeddings_train, y_train)
-    return classifier.predict(embeddings_test), classifier
+    classifier_copy: SklearnModelProtocol = clone(classifier)
+    classifier_copy.fit(embeddings_train, y_train)
+    return classifier_copy.predict(embeddings_test), classifier_copy


 class MultilabelClassificationMetrics(TypedDict):
@@ -72,14 +73,14 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
         evaluator: Classifier to use for evaluation. Must implement the SklearnModelProtocol.
     """

-    evaluator: SklearnModelProtocol = KNeighborsClassifier(n_neighbors=5)
+    evaluator: SklearnModelProtocol = KNeighborsClassifier(n_neighbors=5)  # type: ignore[assignment]
     input_column_name: str = "text"
     label_column_name: str = "label"

     @override
-    def _evaluate_subset(
+    def _evaluate_subset(  # type: ignore[override]
         self,
-        model: EncoderProtocol,
+        model: MTEBModels,
         data_split: DatasetDict,
         *,
         encode_kwargs: dict[str, Any],
@@ -88,6 +89,9 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
         prediction_folder: Path | None = None,
         **kwargs: Any,
     ) -> FullMultilabelClassificationMetrics:
+        if not isinstance(model, EncoderProtocol):
+            raise TypeError("Expected model to be an instance of EncoderProtocol")
+
         if isinstance(data_split, DatasetDict):
             data_split = data_split.select_columns(
                 [self.input_column_name, self.label_column_name]
@@ -185,19 +189,20 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
         )

         avg_scores: dict[str, Any] = {
-            k: np.mean([s[k] for s in scores]) for k in scores[0].keys()
+            k: np.mean([s[k] for s in scores])  # type: ignore[literal-required]
+            for k in scores[0].keys()
         }
         logger.info("Running multilabel classification - Finished.")
         return FullMultilabelClassificationMetrics(
             scores_per_experiment=scores,
-            **avg_scores,
+            **avg_scores,  # type: ignore[typeddict-item]
         )

-    def _calculate_scores(
+    def _calculate_scores(  # type: ignore[override]
         self,
         y_test: np.ndarray,
         y_pred: np.ndarray,
-        x_test_embedding: np.ndarray,
+        x_test_embedding: Array,
         current_classifier: SklearnModelProtocol,
     ) -> MultilabelClassificationMetrics:
         accuracy = current_classifier.score(x_test_embedding, y_test)
@@ -232,10 +237,9 @@ class AbsTaskMultilabelClassification(AbsTaskClassification):
         """
         sample_indices = []
         if idxs is None:
-            idxs = np.arange(len(y))
+            idxs = list(np.arange(len(y)))
             self.np_rng.shuffle(idxs)
-            idxs = idxs.tolist()
-        label_counter = defaultdict(int)
+        label_counter: dict[int, int] = defaultdict(int)
         for i in idxs:
             if any((label_counter[label] < samples_per_label) for label in y[i]):
                 sample_indices.append(i)