EuroEval 15.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of EuroEval might be problematic. Click here for more details.

Files changed (40) hide show
  1. euroeval/__init__.py +72 -0
  2. euroeval/benchmark_config_factory.py +358 -0
  3. euroeval/benchmark_modules/__init__.py +7 -0
  4. euroeval/benchmark_modules/base.py +354 -0
  5. euroeval/benchmark_modules/fresh.py +286 -0
  6. euroeval/benchmark_modules/hf.py +1185 -0
  7. euroeval/benchmark_modules/litellm.py +905 -0
  8. euroeval/benchmark_modules/vllm.py +1171 -0
  9. euroeval/benchmarker.py +1074 -0
  10. euroeval/callbacks.py +72 -0
  11. euroeval/cli.py +281 -0
  12. euroeval/constants.py +50 -0
  13. euroeval/data_loading.py +96 -0
  14. euroeval/data_models.py +474 -0
  15. euroeval/dataset_configs.py +2001 -0
  16. euroeval/enums.py +144 -0
  17. euroeval/exceptions.py +191 -0
  18. euroeval/finetuning.py +324 -0
  19. euroeval/generation.py +296 -0
  20. euroeval/human_evaluation.py +737 -0
  21. euroeval/languages.py +200 -0
  22. euroeval/model_cache.py +253 -0
  23. euroeval/model_config.py +77 -0
  24. euroeval/model_loading.py +78 -0
  25. euroeval/scores.py +90 -0
  26. euroeval/speed_benchmark.py +124 -0
  27. euroeval/task_utils/__init__.py +1 -0
  28. euroeval/task_utils/multiple_choice_classification.py +176 -0
  29. euroeval/task_utils/question_answering.py +698 -0
  30. euroeval/task_utils/sequence_classification.py +237 -0
  31. euroeval/task_utils/text_to_text.py +150 -0
  32. euroeval/task_utils/token_classification.py +464 -0
  33. euroeval/tasks.py +202 -0
  34. euroeval/types.py +97 -0
  35. euroeval/utils.py +574 -0
  36. euroeval-15.2.0.dist-info/METADATA +234 -0
  37. euroeval-15.2.0.dist-info/RECORD +40 -0
  38. euroeval-15.2.0.dist-info/WHEEL +4 -0
  39. euroeval-15.2.0.dist-info/entry_points.txt +4 -0
  40. euroeval-15.2.0.dist-info/licenses/LICENSE +21 -0
euroeval/languages.py ADDED
@@ -0,0 +1,200 @@
1
+ """List of languages and their ISO 639-1 codes.
2
+
3
+ Taken from https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes.
4
+
5
+ Last updated 19 June 2022.
6
+ """
7
+
8
+ from .data_models import Language
9
+
10
+
11
def get_all_languages() -> dict[str, Language]:
    """Collect every language defined at module level.

    Returns:
        A dictionary mapping each language code to its `Language` object.
    """
    languages: dict[str, Language] = {}
    # Every module-level `Language` instance counts as a supported language
    for candidate in globals().values():
        if isinstance(candidate, Language):
            languages[candidate.code] = candidate
    return languages
18
+
19
+
20
# One module-level constant per ISO 639-1 language, named after the upper-cased
# code. `get_all_languages` discovers these by scanning the module globals, so
# adding a language only requires adding a constant here. Akan (`ak`) and
# Javanese (`jv`) are valid ISO 639-1 codes that were previously missing.
AB = Language(code="ab", name="Abkhazian")
AA = Language(code="aa", name="Afar")
AF = Language(code="af", name="Afrikaans")
AK = Language(code="ak", name="Akan")
SQ = Language(code="sq", name="Albanian")
AM = Language(code="am", name="Amharic")
AR = Language(code="ar", name="Arabic")
AN = Language(code="an", name="Aragonese")
HY = Language(code="hy", name="Armenian")
AS = Language(code="as", name="Assamese")
AV = Language(code="av", name="Avaric")
AE = Language(code="ae", name="Avestan")
AY = Language(code="ay", name="Aymara")
AZ = Language(code="az", name="Azerbaijani")
BM = Language(code="bm", name="Bambara")
BA = Language(code="ba", name="Bashkir")
EU = Language(code="eu", name="Basque")
BE = Language(code="be", name="Belarusian")
BN = Language(code="bn", name="Bengali")
BI = Language(code="bi", name="Bislama")
BS = Language(code="bs", name="Bosnian")
BR = Language(code="br", name="Breton")
BG = Language(code="bg", name="Bulgarian")
MY = Language(code="my", name="Burmese")
CA = Language(code="ca", name="Catalan")
CH = Language(code="ch", name="Chamorro")
CE = Language(code="ce", name="Chechen")
NY = Language(code="ny", name="Chichewa")
ZH = Language(code="zh", name="Chinese")
CU = Language(code="cu", name="Church Slavic")
CV = Language(code="cv", name="Chuvash")
KW = Language(code="kw", name="Cornish")
CO = Language(code="co", name="Corsican")
CR = Language(code="cr", name="Cree")
HR = Language(code="hr", name="Croatian")
CS = Language(code="cs", name="Czech")
DA = Language(code="da", name="Danish")
DV = Language(code="dv", name="Divehi")
NL = Language(code="nl", name="Dutch")
DZ = Language(code="dz", name="Dzongkha")
EN = Language(code="en", name="English")
EO = Language(code="eo", name="Esperanto")
ET = Language(code="et", name="Estonian")
EE = Language(code="ee", name="Ewe")
FO = Language(code="fo", name="Faroese")
FJ = Language(code="fj", name="Fijian")
FI = Language(code="fi", name="Finnish")
FR = Language(code="fr", name="French")
FY = Language(code="fy", name="Western Frisian")
FF = Language(code="ff", name="Fulah")
GD = Language(code="gd", name="Gaelic")
GL = Language(code="gl", name="Galician")
LG = Language(code="lg", name="Ganda")
KA = Language(code="ka", name="Georgian")
DE = Language(code="de", name="German")
EL = Language(code="el", name="Greek")
KL = Language(code="kl", name="Greenlandic")
GN = Language(code="gn", name="Guarani")
GU = Language(code="gu", name="Gujarati")
HT = Language(code="ht", name="Haitian")
HA = Language(code="ha", name="Hausa")
HE = Language(code="he", name="Hebrew")
HZ = Language(code="hz", name="Herero")
HI = Language(code="hi", name="Hindi")
HO = Language(code="ho", name="Hiri Motu")
HU = Language(code="hu", name="Hungarian")
IS = Language(code="is", name="Icelandic")
IO = Language(code="io", name="Ido")
IG = Language(code="ig", name="Igbo")
ID = Language(code="id", name="Indonesian")
IA = Language(code="ia", name="Interlingua")
IE = Language(code="ie", name="Interlingue")
IU = Language(code="iu", name="Inuktitut")
IK = Language(code="ik", name="Inupiaq")
GA = Language(code="ga", name="Irish")
IT = Language(code="it", name="Italian")
JA = Language(code="ja", name="Japanese")
JV = Language(code="jv", name="Javanese")
KN = Language(code="kn", name="Kannada")
KR = Language(code="kr", name="Kanuri")
KS = Language(code="ks", name="Kashmiri")
KK = Language(code="kk", name="Kazakh")
KM = Language(code="km", name="Central Khmer")
KI = Language(code="ki", name="Kikuyu")
RW = Language(code="rw", name="Kinyarwanda")
KY = Language(code="ky", name="Kirghiz")
KV = Language(code="kv", name="Komi")
KG = Language(code="kg", name="Kongo")
KO = Language(code="ko", name="Korean")
KJ = Language(code="kj", name="Kuanyama")
KU = Language(code="ku", name="Kurdish")
LO = Language(code="lo", name="Lao")
LA = Language(code="la", name="Latin")
LV = Language(code="lv", name="Latvian")
LI = Language(code="li", name="Limburgan")
LN = Language(code="ln", name="Lingala")
LT = Language(code="lt", name="Lithuanian")
LU = Language(code="lu", name="Luba-Katanga")
LB = Language(code="lb", name="Luxembourgish")
MK = Language(code="mk", name="Macedonian")
MG = Language(code="mg", name="Malagasy")
MS = Language(code="ms", name="Malay")
ML = Language(code="ml", name="Malayalam")
MT = Language(code="mt", name="Maltese")
GV = Language(code="gv", name="Manx")
MI = Language(code="mi", name="Maori")
MR = Language(code="mr", name="Marathi")
MH = Language(code="mh", name="Marshallese")
MN = Language(code="mn", name="Mongolian")
NA = Language(code="na", name="Nauru")
NV = Language(code="nv", name="Navajo")
ND = Language(code="nd", name="Northern Ndebele")
NR = Language(code="nr", name="South Ndebele")
NG = Language(code="ng", name="Ndonga")
NE = Language(code="ne", name="Nepali")
NO = Language(code="no", name="Norwegian")
NB = Language(code="nb", name="Norwegian Bokmål")
NN = Language(code="nn", name="Norwegian Nynorsk")
II = Language(code="ii", name="Sichuan Yi")
OC = Language(code="oc", name="Occitan")
OJ = Language(code="oj", name="Ojibwa")
OR = Language(code="or", name="Oriya")
OM = Language(code="om", name="Oromo")
OS = Language(code="os", name="Ossetian")
PI = Language(code="pi", name="Pali")
PS = Language(code="ps", name="Pashto")
FA = Language(code="fa", name="Persian")
PL = Language(code="pl", name="Polish")
PT = Language(code="pt", name="Portuguese")
PA = Language(code="pa", name="Punjabi")
QU = Language(code="qu", name="Quechua")
RO = Language(code="ro", name="Romanian")
RM = Language(code="rm", name="Romansh")
RN = Language(code="rn", name="Rundi")
RU = Language(code="ru", name="Russian")
SE = Language(code="se", name="Northern Sami")
SM = Language(code="sm", name="Samoan")
SG = Language(code="sg", name="Sango")
SA = Language(code="sa", name="Sanskrit")
SC = Language(code="sc", name="Sardinian")
SR = Language(code="sr", name="Serbian")
SN = Language(code="sn", name="Shona")
SD = Language(code="sd", name="Sindhi")
SI = Language(code="si", name="Sinhala")
SK = Language(code="sk", name="Slovak")
SL = Language(code="sl", name="Slovenian")
SO = Language(code="so", name="Somali")
ST = Language(code="st", name="Sotho")
ES = Language(code="es", name="Spanish")
SU = Language(code="su", name="Sundanese")
SW = Language(code="sw", name="Swahili")
SS = Language(code="ss", name="Swati")
SV = Language(code="sv", name="Swedish")
TL = Language(code="tl", name="Tagalog")
TY = Language(code="ty", name="Tahitian")
TG = Language(code="tg", name="Tajik")
TA = Language(code="ta", name="Tamil")
TT = Language(code="tt", name="Tatar")
TE = Language(code="te", name="Telugu")
TH = Language(code="th", name="Thai")
BO = Language(code="bo", name="Tibetan")
TI = Language(code="ti", name="Tigrinya")
TO = Language(code="to", name="Tonga")
TS = Language(code="ts", name="Tsonga")
TN = Language(code="tn", name="Tswana")
TR = Language(code="tr", name="Turkish")
TK = Language(code="tk", name="Turkmen")
TW = Language(code="tw", name="Twi")
UG = Language(code="ug", name="Uighur")
UK = Language(code="uk", name="Ukrainian")
UR = Language(code="ur", name="Urdu")
UZ = Language(code="uz", name="Uzbek")
VE = Language(code="ve", name="Venda")
VI = Language(code="vi", name="Vietnamese")
VO = Language(code="vo", name="Volapük")
WA = Language(code="wa", name="Walloon")
CY = Language(code="cy", name="Welsh")
WO = Language(code="wo", name="Wolof")
XH = Language(code="xh", name="Xhosa")
YI = Language(code="yi", name="Yiddish")
YO = Language(code="yo", name="Yoruba")
ZA = Language(code="za", name="Zhuang")
ZU = Language(code="zu", name="Zulu")
euroeval/model_cache.py ADDED
@@ -0,0 +1,253 @@
1
+ """ModelCache class for caching model outputs."""
2
+
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ import sys
7
+ import typing as t
8
+ from collections import defaultdict
9
+ from dataclasses import asdict
10
+
11
+ from tqdm.auto import tqdm
12
+
13
+ from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
14
+
15
+ if t.TYPE_CHECKING:
16
+ from pathlib import Path
17
+
18
+ from datasets import Dataset
19
+
20
+
21
+ logger = logging.getLogger("euroeval")
22
+
23
+
24
+ class ModelCache:
25
+ """A cache for model outputs.
26
+
27
+ Attributes:
28
+ model_cache_dir:
29
+ The directory to store the cache in.
30
+ cache_path:
31
+ The path to the cache file.
32
+ cache:
33
+ The model output cache.
34
+ max_generated_tokens:
35
+ The maximum number of tokens to generate for each example.
36
+ """
37
+
38
+ def __init__(
39
+ self, model_cache_dir: "Path", cache_name: str, max_generated_tokens: int
40
+ ) -> None:
41
+ """Initialize the model output cache.
42
+
43
+ Args:
44
+ model_cache_dir:
45
+ The directory to store the cache in.
46
+ cache_name:
47
+ The name of the cache file.
48
+ max_generated_tokens:
49
+ The maximum number of tokens to generate for each example.
50
+ """
51
+ self.model_cache_dir = model_cache_dir
52
+ self.model_cache_dir.mkdir(parents=True, exist_ok=True)
53
+ self.cache_path = self.model_cache_dir / cache_name.replace("/", "--")
54
+ self.max_generated_tokens = max_generated_tokens
55
+
56
+ def load(self) -> None:
57
+ """Load the model output cache."""
58
+ if not self.cache_path.exists():
59
+ with self.cache_path.open("w") as f:
60
+ json.dump(dict(), f)
61
+
62
+ try:
63
+ with self.cache_path.open() as f:
64
+ json_cache = json.load(f)
65
+ except json.JSONDecodeError:
66
+ logger.warning(
67
+ f"Failed to load the cache from {self.cache_path}. The cache will be "
68
+ f"re-initialised."
69
+ )
70
+ json_cache = dict()
71
+ with self.cache_path.open("w") as f:
72
+ json.dump(dict(), f)
73
+
74
+ cache: dict[str, SingleGenerativeModelOutput] = dict()
75
+ for key in json_cache:
76
+ cache[key] = SingleGenerativeModelOutput(**json_cache[key])
77
+
78
+ self.cache = cache
79
+
80
+ def save(self) -> None:
81
+ """Save the model output cache to disk."""
82
+ dumpable_cache: dict[str, dict] = defaultdict(dict)
83
+ for key, value in self.cache.items():
84
+ dumpable_cache[key] = asdict(value)
85
+
86
+ try:
87
+ with self.cache_path.open("w") as f:
88
+ json.dump(dumpable_cache, f)
89
+ except KeyError:
90
+ logger.warning(
91
+ f"Failed to load the cache from {self.cache_path}. The cache will be "
92
+ f"re-initialised."
93
+ )
94
+ self.cache = dict()
95
+ with self.cache_path.open("w") as f:
96
+ json.dump(dict(), f)
97
+
98
+ def _hash_key(self, key: str | list[dict[str, str]]) -> str:
99
+ """Hash the key to use as an index in the cache.
100
+
101
+ Args:
102
+ key:
103
+ The key to hash.
104
+
105
+ Returns:
106
+ The hashed key.
107
+ """
108
+ return hashlib.md5(string=str(key).encode()).hexdigest()
109
+
110
+ def __getitem__(
111
+ self, key: str | list[dict[str, str]]
112
+ ) -> SingleGenerativeModelOutput:
113
+ """Get an item from the cache.
114
+
115
+ Args:
116
+ key:
117
+ The key to use to index the cache.
118
+
119
+ Returns:
120
+ The model output.
121
+ """
122
+ hashed_key = self._hash_key(key=key)
123
+ return self.cache[hashed_key]
124
+
125
+ def __setitem__(
126
+ self, key: str | list[dict[str, str]], value: SingleGenerativeModelOutput
127
+ ) -> None:
128
+ """Set an item in the cache.
129
+
130
+ Args:
131
+ key:
132
+ The key to use to index the cache.
133
+ value:
134
+ The value to set in the cache.
135
+ """
136
+ hashed_key = self._hash_key(key=key)
137
+ self.cache[hashed_key] = value
138
+
139
+ def remove(self) -> None:
140
+ """Remove the cache from memory and delete it from disk."""
141
+ self.cache_path.unlink()
142
+ del self.cache
143
+
144
+ def __contains__(self, key: str | list[dict[str, str]]) -> bool:
145
+ """Check if a key is in the cache.
146
+
147
+ Args:
148
+ key:
149
+ The key to check.
150
+
151
+ Returns:
152
+ Whether the key is in the cache.
153
+ """
154
+ hashed_key = self._hash_key(key=key)
155
+ return hashed_key in self.cache
156
+
157
+ def add_to_cache(
158
+ self, model_inputs: dict, model_output: GenerativeModelOutput
159
+ ) -> None:
160
+ """Add the model input/output to the cache.
161
+
162
+ Args:
163
+ model_inputs:
164
+ The model inputs.
165
+ model_output:
166
+ The model output.
167
+ """
168
+ input_column = "messages" if "messages" in model_inputs else "text"
169
+ model_inputs = model_inputs[input_column]
170
+
171
+ # Store the generated sequences in the cache, one by one
172
+ with tqdm(
173
+ iterable=model_inputs,
174
+ desc="Caching model outputs",
175
+ leave=False,
176
+ disable=hasattr(sys, "_called_from_test"),
177
+ ) as pbar:
178
+ for sample_idx, model_input in enumerate(pbar):
179
+ # Extract the scores from the model output, to be cached. We only store
180
+ # the indices of the top scores, to save space. Further, we only store
181
+ # the scores if the generated sequence is shorter than the maximum
182
+ # length
183
+ if model_output.scores is not None and self.max_generated_tokens < 8:
184
+ assert model_output.scores is not None
185
+ scores = model_output.scores[sample_idx]
186
+ else:
187
+ scores = None
188
+ self[model_input] = SingleGenerativeModelOutput(
189
+ sequence=model_output.sequences[sample_idx], scores=scores
190
+ )
191
+
192
+
193
def split_dataset_into_cached_and_non_cached(
    dataset: "Dataset", cache: "ModelCache"
) -> tuple["Dataset", "Dataset"]:
    """Split a dataset into a cached and non-cached part.

    Args:
        dataset:
            The dataset to split. The inputs are read from the "messages" column
            if it exists, and otherwise from the "text" column.
        cache:
            The model output cache.

    Returns:
        The cached and non-cached parts of the dataset, each in the original
        dataset order.
    """
    input_column = "messages" if "messages" in dataset.column_names else "text"

    # Get the sample indices of the non-cached examples, which are unique with
    # respect to the input column. The inputs can be unhashable (lists of
    # dictionaries), so their string representation is used to detect duplicates
    # in constant time; this is also what the cache keys its entries on.
    non_cached_ids: list[int] = []
    seen_inputs: set[str] = set()
    for idx, model_input in enumerate(dataset[input_column]):
        if model_input in cache:
            continue
        dedup_key = str(model_input)
        if dedup_key in seen_inputs:
            continue
        seen_inputs.add(dedup_key)
        non_cached_ids.append(idx)

    # The cached examples are the ones that are not in the non-cached examples. This
    # means that if the dataset has duplicates, only a single copy of the duplicate
    # will be put in the non-cached part, and the rest in the cached part.
    non_cached_id_set = set(non_cached_ids)
    cached_ids = [idx for idx in range(len(dataset)) if idx not in non_cached_id_set]

    cached = dataset.select(cached_ids)
    non_cached = dataset.select(non_cached_ids)
    return cached, non_cached
226
+
227
+
228
def load_cached_model_outputs(
    cached_dataset: "Dataset", cache: "ModelCache"
) -> "GenerativeModelOutput":
    """Load the cached model outputs.

    Args:
        cached_dataset:
            The dataset containing the cached examples. The inputs are read from
            the "messages" column if it exists, and otherwise from the "text"
            column.
        cache:
            The model output cache.

    Returns:
        The model output containing the cached sequences, along with the cached
        scores if the first cached example has scores.
    """
    input_column = "messages" if "messages" in cached_dataset.column_names else "text"
    cached_model_outputs: list[SingleGenerativeModelOutput] = [
        cache[prompt] for prompt in cached_dataset[input_column]
    ]

    cached_sequences = [model_output.sequence for model_output in cached_model_outputs]

    # The first example decides whether scores are included. An empty dataset has
    # no first example, so it is treated as having no scores.
    if not cached_model_outputs or cached_model_outputs[0].scores is None:
        return GenerativeModelOutput(sequences=cached_sequences)

    # Examples lacking scores get an empty list, to keep sequences and scores aligned
    cached_scores = [model_output.scores or [] for model_output in cached_model_outputs]
    return GenerativeModelOutput(sequences=cached_sequences, scores=cached_scores)
euroeval/model_config.py ADDED
@@ -0,0 +1,77 @@
1
+ """Functions related to getting the model configuration."""
2
+
3
+ import logging
4
+ import typing as t
5
+
6
+ from . import benchmark_modules
7
+ from .exceptions import InvalidModel, NeedsEnvironmentVariable, NeedsExtraInstalled
8
+
9
+ if t.TYPE_CHECKING:
10
+ from .data_models import BenchmarkConfig, ModelConfig
11
+
12
+
13
+ logger = logging.getLogger("euroeval")
14
+
15
+
16
def get_model_config(
    model_id: str, benchmark_config: "BenchmarkConfig"
) -> "ModelConfig":
    """Find the benchmark module that recognises a model and get its configuration.

    The benchmark modules are tried in order of priority, and the first one which
    reports that the model exists is used to build the model configuration.

    Args:
        model_id:
            The model ID.
        benchmark_config:
            The configuration of the benchmark.

    Returns:
        The model configuration.

    Raises:
        InvalidModel:
            If none of the benchmark modules could find the model.
    """

    def is_concrete_benchmark_module(obj: object) -> bool:
        """Check if an object is a proper subclass of `BenchmarkModule`."""
        return (
            isinstance(obj, type)
            and issubclass(obj, benchmark_modules.BenchmarkModule)
            and obj is not benchmark_modules.BenchmarkModule
        )

    # High-priority modules get the first chance at identifying the model
    candidate_modules = sorted(
        filter(is_concrete_benchmark_module, benchmark_modules.__dict__.values()),
        key=lambda module_cls: module_cls.high_priority,
        reverse=True,
    )

    missing_extras: list[str] = []
    missing_env_vars: list[str] = []
    for module_cls in candidate_modules:
        outcome = module_cls.model_exists(
            model_id=model_id, benchmark_config=benchmark_config
        )

        # Remember why a module could not check the model, for the error message
        if isinstance(outcome, NeedsExtraInstalled):
            missing_extras.append(outcome.extra)
            continue
        if isinstance(outcome, NeedsEnvironmentVariable):
            missing_env_vars.append(outcome.env_var)
            continue

        if outcome is True:
            logger.debug(
                f"The model {model_id!r} was identified by the "
                f"{module_cls.__name__} benchmark module."
            )
            return module_cls.get_model_config(
                model_id=model_id, benchmark_config=benchmark_config
            )

    # No module recognised the model, so build an error with a hint if possible
    error_message = f"Model {model_id} not found."
    if missing_extras:
        extras = ",".join(missing_extras)
        error_message += (
            " However, it is possible that the model exists, but a package "
            "needs to be installed to check if it exists. Please try running "
            f"`pip install euroeval[{extras}]` or `pip install "
            "euroeval[all]`, and try again."
        )
    elif missing_env_vars:
        env_vars = ",".join(missing_env_vars)
        error_message += (
            " However, it is possible that the model exists, but an environment "
            "variable needs to be set to check if it exists. Please set the "
            f"environment variables {env_vars} and try again."
        )
    raise InvalidModel(error_message)
euroeval/model_loading.py ADDED
@@ -0,0 +1,78 @@
1
+ """Functions related to the loading of models."""
2
+
3
+ import typing as t
4
+
5
+ from .benchmark_modules import (
6
+ FreshEncoderModel,
7
+ HuggingFaceEncoderModel,
8
+ LiteLLMModel,
9
+ VLLMModel,
10
+ )
11
+ from .constants import GENERATIVE_DATASET_TASK_GROUPS
12
+ from .enums import InferenceBackend, ModelType
13
+ from .exceptions import InvalidBenchmark, InvalidModel
14
+
15
+ if t.TYPE_CHECKING:
16
+ from .benchmark_modules import BenchmarkModule
17
+ from .data_models import BenchmarkConfig, DatasetConfig, ModelConfig
18
+
19
+
20
def load_model(
    model_config: "ModelConfig",
    dataset_config: "DatasetConfig",
    benchmark_config: "BenchmarkConfig",
) -> "BenchmarkModule":
    """Instantiate the benchmark module matching a model configuration.

    Args:
        model_config:
            The model configuration.
        dataset_config:
            The dataset configuration.
        benchmark_config:
            The benchmark configuration.

    Returns:
        The model.

    Raises:
        InvalidModel:
            If no benchmark module supports the combination of model type,
            inference backend and freshness of the model.
        InvalidBenchmark:
            If a non-generative model is to be benchmarked on a generative task.
    """
    model_type = model_config.model_type
    backend = model_config.inference_backend
    fresh = model_config.fresh

    # The order matters; the first entry that matches will be used. For this
    # reason, they have been ordered in terms of the most common model types.
    supported_setups = (
        (ModelType.GENERATIVE, InferenceBackend.VLLM, False, VLLMModel),
        (
            ModelType.ENCODER,
            InferenceBackend.TRANSFORMERS,
            False,
            HuggingFaceEncoderModel,
        ),
        (ModelType.GENERATIVE, InferenceBackend.LITELLM, False, LiteLLMModel),
        (ModelType.ENCODER, InferenceBackend.TRANSFORMERS, True, FreshEncoderModel),
    )

    model_class: t.Type[BenchmarkModule] | None = None
    for wanted_type, wanted_backend, wanted_fresh, candidate_class in supported_setups:
        # The freshness flag is compared by identity, so only real booleans match
        if (
            model_type == wanted_type
            and backend == wanted_backend
            and fresh is wanted_fresh
        ):
            model_class = candidate_class
            break

    if model_class is None:
        if fresh is True:
            raise InvalidModel(
                "Cannot load a freshly initialised model with the model type "
                f"{model_config.model_type!r} and inference backend "
                f"{model_config.inference_backend!r}."
            )
        raise InvalidModel(
            f"Cannot load model with model type {model_config.model_type!r} and "
            f"inference backend {model_config.inference_backend!r}."
        )

    # Refuse to benchmark non-generative models on generative tasks
    is_generative_task = dataset_config.task.task_group in GENERATIVE_DATASET_TASK_GROUPS
    if is_generative_task and not model_config.model_type == ModelType.GENERATIVE:
        raise InvalidBenchmark(
            f"Cannot benchmark non-generative model {model_config.model_id!r} on "
            f"generative task {dataset_config.task.name!r}."
        )

    return model_class(
        model_config=model_config,
        dataset_config=dataset_config,
        benchmark_config=benchmark_config,
    )