EuroEval 15.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of EuroEval might be problematic. Click here for more details.
- euroeval/__init__.py +72 -0
- euroeval/benchmark_config_factory.py +358 -0
- euroeval/benchmark_modules/__init__.py +7 -0
- euroeval/benchmark_modules/base.py +354 -0
- euroeval/benchmark_modules/fresh.py +286 -0
- euroeval/benchmark_modules/hf.py +1185 -0
- euroeval/benchmark_modules/litellm.py +905 -0
- euroeval/benchmark_modules/vllm.py +1171 -0
- euroeval/benchmarker.py +1074 -0
- euroeval/callbacks.py +72 -0
- euroeval/cli.py +281 -0
- euroeval/constants.py +50 -0
- euroeval/data_loading.py +96 -0
- euroeval/data_models.py +474 -0
- euroeval/dataset_configs.py +2001 -0
- euroeval/enums.py +144 -0
- euroeval/exceptions.py +191 -0
- euroeval/finetuning.py +324 -0
- euroeval/generation.py +296 -0
- euroeval/human_evaluation.py +737 -0
- euroeval/languages.py +200 -0
- euroeval/model_cache.py +253 -0
- euroeval/model_config.py +77 -0
- euroeval/model_loading.py +78 -0
- euroeval/scores.py +90 -0
- euroeval/speed_benchmark.py +124 -0
- euroeval/task_utils/__init__.py +1 -0
- euroeval/task_utils/multiple_choice_classification.py +176 -0
- euroeval/task_utils/question_answering.py +698 -0
- euroeval/task_utils/sequence_classification.py +237 -0
- euroeval/task_utils/text_to_text.py +150 -0
- euroeval/task_utils/token_classification.py +464 -0
- euroeval/tasks.py +202 -0
- euroeval/types.py +97 -0
- euroeval/utils.py +574 -0
- euroeval-15.2.0.dist-info/METADATA +234 -0
- euroeval-15.2.0.dist-info/RECORD +40 -0
- euroeval-15.2.0.dist-info/WHEEL +4 -0
- euroeval-15.2.0.dist-info/entry_points.txt +4 -0
- euroeval-15.2.0.dist-info/licenses/LICENSE +21 -0
euroeval/languages.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""List of languages and their ISO 639-1 codes.
|
|
2
|
+
|
|
3
|
+
Taken from https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes.
|
|
4
|
+
|
|
5
|
+
Last updated 19 June 2022.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .data_models import Language
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_all_languages() -> dict[str, Language]:
    """Collect every language defined in this module.

    Every module-level object that is a `Language` instance is included, keyed by
    its language code.

    Returns:
        A dictionary mapping each language code to its `Language` configuration.
    """
    languages: dict[str, Language] = dict()
    for candidate in globals().values():
        if not isinstance(candidate, Language):
            continue
        languages[candidate.code] = candidate
    return languages
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# One constant per ISO 639-1 language code. `get_all_languages` discovers these
# automatically by scanning the module globals, so adding a constant here is all
# that is needed to register a new language.
AB = Language(code="ab", name="Abkhazian")
AA = Language(code="aa", name="Afar")
AF = Language(code="af", name="Afrikaans")
AK = Language(code="ak", name="Akan")
SQ = Language(code="sq", name="Albanian")
AM = Language(code="am", name="Amharic")
AR = Language(code="ar", name="Arabic")
AN = Language(code="an", name="Aragonese")
HY = Language(code="hy", name="Armenian")
AS = Language(code="as", name="Assamese")
AV = Language(code="av", name="Avaric")
AE = Language(code="ae", name="Avestan")
AY = Language(code="ay", name="Aymara")
AZ = Language(code="az", name="Azerbaijani")
BM = Language(code="bm", name="Bambara")
BA = Language(code="ba", name="Bashkir")
EU = Language(code="eu", name="Basque")
BE = Language(code="be", name="Belarusian")
BN = Language(code="bn", name="Bengali")
BI = Language(code="bi", name="Bislama")
BS = Language(code="bs", name="Bosnian")
BR = Language(code="br", name="Breton")
BG = Language(code="bg", name="Bulgarian")
MY = Language(code="my", name="Burmese")
CA = Language(code="ca", name="Catalan")
CH = Language(code="ch", name="Chamorro")
CE = Language(code="ce", name="Chechen")
NY = Language(code="ny", name="Chichewa")
ZH = Language(code="zh", name="Chinese")
CU = Language(code="cu", name="Church Slavic")
CV = Language(code="cv", name="Chuvash")
KW = Language(code="kw", name="Cornish")
CO = Language(code="co", name="Corsican")
CR = Language(code="cr", name="Cree")
HR = Language(code="hr", name="Croatian")
CS = Language(code="cs", name="Czech")
DA = Language(code="da", name="Danish")
DV = Language(code="dv", name="Divehi")
NL = Language(code="nl", name="Dutch")
DZ = Language(code="dz", name="Dzongkha")
EN = Language(code="en", name="English")
EO = Language(code="eo", name="Esperanto")
ET = Language(code="et", name="Estonian")
EE = Language(code="ee", name="Ewe")
FO = Language(code="fo", name="Faroese")
FJ = Language(code="fj", name="Fijian")
FI = Language(code="fi", name="Finnish")
FR = Language(code="fr", name="French")
FY = Language(code="fy", name="Western Frisian")
FF = Language(code="ff", name="Fulah")
GD = Language(code="gd", name="Gaelic")
GL = Language(code="gl", name="Galician")
LG = Language(code="lg", name="Ganda")
KA = Language(code="ka", name="Georgian")
DE = Language(code="de", name="German")
EL = Language(code="el", name="Greek")
KL = Language(code="kl", name="Greenlandic")
GN = Language(code="gn", name="Guarani")
GU = Language(code="gu", name="Gujarati")
HT = Language(code="ht", name="Haitian")
HA = Language(code="ha", name="Hausa")
HE = Language(code="he", name="Hebrew")
HZ = Language(code="hz", name="Herero")
HI = Language(code="hi", name="Hindi")
HO = Language(code="ho", name="Hiri Motu")
HU = Language(code="hu", name="Hungarian")
IS = Language(code="is", name="Icelandic")
IO = Language(code="io", name="Ido")
IG = Language(code="ig", name="Igbo")
ID = Language(code="id", name="Indonesian")
IA = Language(code="ia", name="Interlingua")
IE = Language(code="ie", name="Interlingue")
IU = Language(code="iu", name="Inuktitut")
IK = Language(code="ik", name="Inupiaq")
GA = Language(code="ga", name="Irish")
IT = Language(code="it", name="Italian")
JA = Language(code="ja", name="Japanese")
JV = Language(code="jv", name="Javanese")
KN = Language(code="kn", name="Kannada")
KR = Language(code="kr", name="Kanuri")
KS = Language(code="ks", name="Kashmiri")
KK = Language(code="kk", name="Kazakh")
KM = Language(code="km", name="Central Khmer")
KI = Language(code="ki", name="Kikuyu")
RW = Language(code="rw", name="Kinyarwanda")
KY = Language(code="ky", name="Kirghiz")
KV = Language(code="kv", name="Komi")
KG = Language(code="kg", name="Kongo")
KO = Language(code="ko", name="Korean")
KJ = Language(code="kj", name="Kuanyama")
KU = Language(code="ku", name="Kurdish")
LO = Language(code="lo", name="Lao")
LA = Language(code="la", name="Latin")
LV = Language(code="lv", name="Latvian")
LI = Language(code="li", name="Limburgan")
LN = Language(code="ln", name="Lingala")
LT = Language(code="lt", name="Lithuanian")
LU = Language(code="lu", name="Luba-Katanga")
LB = Language(code="lb", name="Luxembourgish")
MK = Language(code="mk", name="Macedonian")
MG = Language(code="mg", name="Malagasy")
MS = Language(code="ms", name="Malay")
ML = Language(code="ml", name="Malayalam")
MT = Language(code="mt", name="Maltese")
GV = Language(code="gv", name="Manx")
MI = Language(code="mi", name="Maori")
MR = Language(code="mr", name="Marathi")
MH = Language(code="mh", name="Marshallese")
MN = Language(code="mn", name="Mongolian")
NA = Language(code="na", name="Nauru")
NV = Language(code="nv", name="Navajo")
ND = Language(code="nd", name="Northern Ndebele")
NR = Language(code="nr", name="South Ndebele")
NG = Language(code="ng", name="Ndonga")
NE = Language(code="ne", name="Nepali")
NO = Language(code="no", name="Norwegian")
NB = Language(code="nb", name="Norwegian Bokmål")
NN = Language(code="nn", name="Norwegian Nynorsk")
II = Language(code="ii", name="Sichuan Yi")
OC = Language(code="oc", name="Occitan")
OJ = Language(code="oj", name="Ojibwa")
OR = Language(code="or", name="Oriya")
OM = Language(code="om", name="Oromo")
OS = Language(code="os", name="Ossetian")
PI = Language(code="pi", name="Pali")
PS = Language(code="ps", name="Pashto")
FA = Language(code="fa", name="Persian")
PL = Language(code="pl", name="Polish")
PT = Language(code="pt", name="Portuguese")
PA = Language(code="pa", name="Punjabi")
QU = Language(code="qu", name="Quechua")
RO = Language(code="ro", name="Romanian")
RM = Language(code="rm", name="Romansh")
RN = Language(code="rn", name="Rundi")
RU = Language(code="ru", name="Russian")
SE = Language(code="se", name="Northern Sami")
SM = Language(code="sm", name="Samoan")
SG = Language(code="sg", name="Sango")
SA = Language(code="sa", name="Sanskrit")
SC = Language(code="sc", name="Sardinian")
SR = Language(code="sr", name="Serbian")
SN = Language(code="sn", name="Shona")
SD = Language(code="sd", name="Sindhi")
SI = Language(code="si", name="Sinhala")
SK = Language(code="sk", name="Slovak")
SL = Language(code="sl", name="Slovenian")
SO = Language(code="so", name="Somali")
ST = Language(code="st", name="Sotho")
ES = Language(code="es", name="Spanish")
SU = Language(code="su", name="Sundanese")
SW = Language(code="sw", name="Swahili")
SS = Language(code="ss", name="Swati")
SV = Language(code="sv", name="Swedish")
TL = Language(code="tl", name="Tagalog")
TY = Language(code="ty", name="Tahitian")
TG = Language(code="tg", name="Tajik")
TA = Language(code="ta", name="Tamil")
TT = Language(code="tt", name="Tatar")
TE = Language(code="te", name="Telugu")
TH = Language(code="th", name="Thai")
BO = Language(code="bo", name="Tibetan")
TI = Language(code="ti", name="Tigrinya")
TO = Language(code="to", name="Tonga")
TS = Language(code="ts", name="Tsonga")
TN = Language(code="tn", name="Tswana")
TR = Language(code="tr", name="Turkish")
TK = Language(code="tk", name="Turkmen")
TW = Language(code="tw", name="Twi")
UG = Language(code="ug", name="Uighur")
UK = Language(code="uk", name="Ukrainian")
UR = Language(code="ur", name="Urdu")
UZ = Language(code="uz", name="Uzbek")
VE = Language(code="ve", name="Venda")
VI = Language(code="vi", name="Vietnamese")
VO = Language(code="vo", name="Volapük")
WA = Language(code="wa", name="Walloon")
CY = Language(code="cy", name="Welsh")
WO = Language(code="wo", name="Wolof")
XH = Language(code="xh", name="Xhosa")
YI = Language(code="yi", name="Yiddish")
YO = Language(code="yo", name="Yoruba")
ZA = Language(code="za", name="Zhuang")
ZU = Language(code="zu", name="Zulu")
|
euroeval/model_cache.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""ModelCache class for caching model outputs."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import sys
|
|
7
|
+
import typing as t
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from dataclasses import asdict
|
|
10
|
+
|
|
11
|
+
from tqdm.auto import tqdm
|
|
12
|
+
|
|
13
|
+
from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
|
|
14
|
+
|
|
15
|
+
if t.TYPE_CHECKING:
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
from datasets import Dataset
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Log under the package-wide "euroeval" logger name
logger = logging.getLogger(name="euroeval")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ModelCache:
|
|
25
|
+
"""A cache for model outputs.
|
|
26
|
+
|
|
27
|
+
Attributes:
|
|
28
|
+
model_cache_dir:
|
|
29
|
+
The directory to store the cache in.
|
|
30
|
+
cache_path:
|
|
31
|
+
The path to the cache file.
|
|
32
|
+
cache:
|
|
33
|
+
The model output cache.
|
|
34
|
+
max_generated_tokens:
|
|
35
|
+
The maximum number of tokens to generate for each example.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
def __init__(
|
|
39
|
+
self, model_cache_dir: "Path", cache_name: str, max_generated_tokens: int
|
|
40
|
+
) -> None:
|
|
41
|
+
"""Initialize the model output cache.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
model_cache_dir:
|
|
45
|
+
The directory to store the cache in.
|
|
46
|
+
cache_name:
|
|
47
|
+
The name of the cache file.
|
|
48
|
+
max_generated_tokens:
|
|
49
|
+
The maximum number of tokens to generate for each example.
|
|
50
|
+
"""
|
|
51
|
+
self.model_cache_dir = model_cache_dir
|
|
52
|
+
self.model_cache_dir.mkdir(parents=True, exist_ok=True)
|
|
53
|
+
self.cache_path = self.model_cache_dir / cache_name.replace("/", "--")
|
|
54
|
+
self.max_generated_tokens = max_generated_tokens
|
|
55
|
+
|
|
56
|
+
def load(self) -> None:
|
|
57
|
+
"""Load the model output cache."""
|
|
58
|
+
if not self.cache_path.exists():
|
|
59
|
+
with self.cache_path.open("w") as f:
|
|
60
|
+
json.dump(dict(), f)
|
|
61
|
+
|
|
62
|
+
try:
|
|
63
|
+
with self.cache_path.open() as f:
|
|
64
|
+
json_cache = json.load(f)
|
|
65
|
+
except json.JSONDecodeError:
|
|
66
|
+
logger.warning(
|
|
67
|
+
f"Failed to load the cache from {self.cache_path}. The cache will be "
|
|
68
|
+
f"re-initialised."
|
|
69
|
+
)
|
|
70
|
+
json_cache = dict()
|
|
71
|
+
with self.cache_path.open("w") as f:
|
|
72
|
+
json.dump(dict(), f)
|
|
73
|
+
|
|
74
|
+
cache: dict[str, SingleGenerativeModelOutput] = dict()
|
|
75
|
+
for key in json_cache:
|
|
76
|
+
cache[key] = SingleGenerativeModelOutput(**json_cache[key])
|
|
77
|
+
|
|
78
|
+
self.cache = cache
|
|
79
|
+
|
|
80
|
+
def save(self) -> None:
|
|
81
|
+
"""Save the model output cache to disk."""
|
|
82
|
+
dumpable_cache: dict[str, dict] = defaultdict(dict)
|
|
83
|
+
for key, value in self.cache.items():
|
|
84
|
+
dumpable_cache[key] = asdict(value)
|
|
85
|
+
|
|
86
|
+
try:
|
|
87
|
+
with self.cache_path.open("w") as f:
|
|
88
|
+
json.dump(dumpable_cache, f)
|
|
89
|
+
except KeyError:
|
|
90
|
+
logger.warning(
|
|
91
|
+
f"Failed to load the cache from {self.cache_path}. The cache will be "
|
|
92
|
+
f"re-initialised."
|
|
93
|
+
)
|
|
94
|
+
self.cache = dict()
|
|
95
|
+
with self.cache_path.open("w") as f:
|
|
96
|
+
json.dump(dict(), f)
|
|
97
|
+
|
|
98
|
+
def _hash_key(self, key: str | list[dict[str, str]]) -> str:
|
|
99
|
+
"""Hash the key to use as an index in the cache.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
key:
|
|
103
|
+
The key to hash.
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
The hashed key.
|
|
107
|
+
"""
|
|
108
|
+
return hashlib.md5(string=str(key).encode()).hexdigest()
|
|
109
|
+
|
|
110
|
+
def __getitem__(
|
|
111
|
+
self, key: str | list[dict[str, str]]
|
|
112
|
+
) -> SingleGenerativeModelOutput:
|
|
113
|
+
"""Get an item from the cache.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
key:
|
|
117
|
+
The key to use to index the cache.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
The model output.
|
|
121
|
+
"""
|
|
122
|
+
hashed_key = self._hash_key(key=key)
|
|
123
|
+
return self.cache[hashed_key]
|
|
124
|
+
|
|
125
|
+
def __setitem__(
|
|
126
|
+
self, key: str | list[dict[str, str]], value: SingleGenerativeModelOutput
|
|
127
|
+
) -> None:
|
|
128
|
+
"""Set an item in the cache.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
key:
|
|
132
|
+
The key to use to index the cache.
|
|
133
|
+
value:
|
|
134
|
+
The value to set in the cache.
|
|
135
|
+
"""
|
|
136
|
+
hashed_key = self._hash_key(key=key)
|
|
137
|
+
self.cache[hashed_key] = value
|
|
138
|
+
|
|
139
|
+
def remove(self) -> None:
|
|
140
|
+
"""Remove the cache from memory and delete it from disk."""
|
|
141
|
+
self.cache_path.unlink()
|
|
142
|
+
del self.cache
|
|
143
|
+
|
|
144
|
+
def __contains__(self, key: str | list[dict[str, str]]) -> bool:
|
|
145
|
+
"""Check if a key is in the cache.
|
|
146
|
+
|
|
147
|
+
Args:
|
|
148
|
+
key:
|
|
149
|
+
The key to check.
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
Whether the key is in the cache.
|
|
153
|
+
"""
|
|
154
|
+
hashed_key = self._hash_key(key=key)
|
|
155
|
+
return hashed_key in self.cache
|
|
156
|
+
|
|
157
|
+
def add_to_cache(
|
|
158
|
+
self, model_inputs: dict, model_output: GenerativeModelOutput
|
|
159
|
+
) -> None:
|
|
160
|
+
"""Add the model input/output to the cache.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
model_inputs:
|
|
164
|
+
The model inputs.
|
|
165
|
+
model_output:
|
|
166
|
+
The model output.
|
|
167
|
+
"""
|
|
168
|
+
input_column = "messages" if "messages" in model_inputs else "text"
|
|
169
|
+
model_inputs = model_inputs[input_column]
|
|
170
|
+
|
|
171
|
+
# Store the generated sequences in the cache, one by one
|
|
172
|
+
with tqdm(
|
|
173
|
+
iterable=model_inputs,
|
|
174
|
+
desc="Caching model outputs",
|
|
175
|
+
leave=False,
|
|
176
|
+
disable=hasattr(sys, "_called_from_test"),
|
|
177
|
+
) as pbar:
|
|
178
|
+
for sample_idx, model_input in enumerate(pbar):
|
|
179
|
+
# Extract the scores from the model output, to be cached. We only store
|
|
180
|
+
# the indices of the top scores, to save space. Further, we only store
|
|
181
|
+
# the scores if the generated sequence is shorter than the maximum
|
|
182
|
+
# length
|
|
183
|
+
if model_output.scores is not None and self.max_generated_tokens < 8:
|
|
184
|
+
assert model_output.scores is not None
|
|
185
|
+
scores = model_output.scores[sample_idx]
|
|
186
|
+
else:
|
|
187
|
+
scores = None
|
|
188
|
+
self[model_input] = SingleGenerativeModelOutput(
|
|
189
|
+
sequence=model_output.sequences[sample_idx], scores=scores
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def split_dataset_into_cached_and_non_cached(
    dataset: "Dataset", cache: "ModelCache"
) -> tuple["Dataset", "Dataset"]:
    """Split a dataset into a cached and non-cached part.

    Only the first occurrence of each input that is not yet cached goes into the
    non-cached part. All other rows go into the cached part: inputs that are
    already cached, and later duplicates of non-cached inputs. Both parts keep
    the row order of the original dataset.

    Args:
        dataset:
            The dataset to split. The "messages" column is used if present,
            otherwise the "text" column.
        cache:
            The model output cache.

    Returns:
        The cached and non-cached parts of the dataset.
    """
    input_column = "messages" if "messages" in dataset.column_names else "text"

    # Duplicates are detected via `str(...)`, because chat messages (lists of
    # dicts) are unhashable. The cache also keys entries on `str(key)`, so two
    # rows count as duplicates exactly when they would share a cache entry. A set
    # keeps each lookup O(1) instead of scanning a list for every row.
    seen_inputs: set[str] = set()
    non_cached_ids: list[int] = list()
    for idx, dataset_input in enumerate(dataset[input_column]):
        input_repr = str(dataset_input)
        if input_repr in seen_inputs or dataset_input in cache:
            continue
        seen_inputs.add(input_repr)
        non_cached_ids.append(idx)

    # Sorted (ascending) index lists preserve the original row order; iterating a
    # set of ints does not guarantee that
    non_cached_id_set = set(non_cached_ids)
    cached_ids = [idx for idx in range(len(dataset)) if idx not in non_cached_id_set]

    cached = dataset.select(cached_ids)
    non_cached = dataset.select(non_cached_ids)
    return cached, non_cached
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def load_cached_model_outputs(
    cached_dataset: "Dataset", cache: "ModelCache"
) -> GenerativeModelOutput:
    """Load the cached model outputs.

    Args:
        cached_dataset:
            The dataset containing the cached examples. The "messages" column is
            used if present, otherwise the "text" column.
        cache:
            The model output cache.

    Returns:
        The model output containing the cached sequences, in the same order as
        `cached_dataset`. Scores are included only if the first cached output has
        scores.
    """
    input_column = "messages" if "messages" in cached_dataset.column_names else "text"
    cached_model_outputs: list[SingleGenerativeModelOutput] = [
        cache[prompt] for prompt in cached_dataset[input_column]
    ]

    cached_sequences = [model_output.sequence for model_output in cached_model_outputs]

    # An empty cached part (e.g. nothing cached yet) has no first output to
    # inspect, so return an empty result instead of raising IndexError
    if not cached_model_outputs or cached_model_outputs[0].scores is None:
        return GenerativeModelOutput(sequences=cached_sequences)

    # The first output decides whether scores are returned; any later output
    # without scores contributes an empty list
    cached_scores = [model_output.scores or [] for model_output in cached_model_outputs]
    return GenerativeModelOutput(sequences=cached_sequences, scores=cached_scores)
|
euroeval/model_config.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Functions related to getting the model configuration."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import typing as t
|
|
5
|
+
|
|
6
|
+
from . import benchmark_modules
|
|
7
|
+
from .exceptions import InvalidModel, NeedsEnvironmentVariable, NeedsExtraInstalled
|
|
8
|
+
|
|
9
|
+
if t.TYPE_CHECKING:
|
|
10
|
+
from .data_models import BenchmarkConfig, ModelConfig
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Log under the package-wide "euroeval" logger name
logger = logging.getLogger(name="euroeval")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_model_config(
    model_id: str, benchmark_config: "BenchmarkConfig"
) -> "ModelConfig":
    """Fetches configuration for a model.

    Every concrete `BenchmarkModule` subclass exposed by `benchmark_modules` is
    asked whether it recognises the model. High-priority modules are asked first,
    and the configuration comes from the first module that recognises it.

    Args:
        model_id:
            The model ID.
        benchmark_config:
            The configuration of the benchmark.

    Returns:
        The model configuration.

    Raises:
        InvalidModel:
            If no benchmark module recognises the model. The message lists any
            missing extras and/or environment variables that prevented a module
            from checking whether the model exists.
    """
    all_benchmark_modules = [
        cls
        for cls in benchmark_modules.__dict__.values()
        if isinstance(cls, type)
        and issubclass(cls, benchmark_modules.BenchmarkModule)
        and cls is not benchmark_modules.BenchmarkModule
    ]
    all_benchmark_modules.sort(key=lambda cls: cls.high_priority, reverse=True)

    needs_extras: list[str] = list()
    needs_env_vars: list[str] = list()
    for benchmark_module in all_benchmark_modules:
        exists_or_err = benchmark_module.model_exists(
            model_id=model_id, benchmark_config=benchmark_config
        )
        if isinstance(exists_or_err, NeedsExtraInstalled):
            needs_extras.append(exists_or_err.extra)
        elif isinstance(exists_or_err, NeedsEnvironmentVariable):
            needs_env_vars.append(exists_or_err.env_var)
        elif exists_or_err is True:
            logger.debug(
                f"The model {model_id!r} was identified by the "
                f"{benchmark_module.__name__} benchmark module."
            )
            return benchmark_module.get_model_config(
                model_id=model_id, benchmark_config=benchmark_config
            )

    # Several modules may report the same missing extra or variable; list each
    # one only once, keeping the order in which they were reported
    unique_extras = list(dict.fromkeys(needs_extras))
    unique_env_vars = list(dict.fromkeys(needs_env_vars))

    msg = f"Model {model_id} not found."
    # Both hints are added independently, so a missing environment variable is
    # still reported when an extra is missing too
    if unique_extras:
        msg += (
            " However, it is possible that the model exists, but a package "
            "needs to be installed to check if it exists. Please try running "
            f"`pip install euroeval[{','.join(unique_extras)}]` or `pip install "
            "euroeval[all]`, and try again."
        )
    if unique_env_vars:
        msg += (
            " However, it is possible that the model exists, but an environment "
            "variable needs to be set to check if it exists. Please set the "
            f"environment variables {','.join(unique_env_vars)} and try again."
        )
    raise InvalidModel(msg)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Functions related to the loading of models."""
|
|
2
|
+
|
|
3
|
+
import typing as t
|
|
4
|
+
|
|
5
|
+
from .benchmark_modules import (
|
|
6
|
+
FreshEncoderModel,
|
|
7
|
+
HuggingFaceEncoderModel,
|
|
8
|
+
LiteLLMModel,
|
|
9
|
+
VLLMModel,
|
|
10
|
+
)
|
|
11
|
+
from .constants import GENERATIVE_DATASET_TASK_GROUPS
|
|
12
|
+
from .enums import InferenceBackend, ModelType
|
|
13
|
+
from .exceptions import InvalidBenchmark, InvalidModel
|
|
14
|
+
|
|
15
|
+
if t.TYPE_CHECKING:
|
|
16
|
+
from .benchmark_modules import BenchmarkModule
|
|
17
|
+
from .data_models import BenchmarkConfig, DatasetConfig, ModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def load_model(
|
|
21
|
+
model_config: "ModelConfig",
|
|
22
|
+
dataset_config: "DatasetConfig",
|
|
23
|
+
benchmark_config: "BenchmarkConfig",
|
|
24
|
+
) -> "BenchmarkModule":
|
|
25
|
+
"""Load a model.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
model_config:
|
|
29
|
+
The model configuration.
|
|
30
|
+
dataset_config:
|
|
31
|
+
The dataset configuration.
|
|
32
|
+
benchmark_config:
|
|
33
|
+
The benchmark configuration.
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
The model.
|
|
37
|
+
"""
|
|
38
|
+
# The order matters; the first model type that matches will be used. For this
|
|
39
|
+
# reason, they have been ordered in terms of the most common model types.
|
|
40
|
+
model_class: t.Type[BenchmarkModule]
|
|
41
|
+
match (model_config.model_type, model_config.inference_backend, model_config.fresh):
|
|
42
|
+
case (ModelType.GENERATIVE, InferenceBackend.VLLM, False):
|
|
43
|
+
model_class = VLLMModel
|
|
44
|
+
case (ModelType.ENCODER, InferenceBackend.TRANSFORMERS, False):
|
|
45
|
+
model_class = HuggingFaceEncoderModel
|
|
46
|
+
case (ModelType.GENERATIVE, InferenceBackend.LITELLM, False):
|
|
47
|
+
model_class = LiteLLMModel
|
|
48
|
+
case (ModelType.ENCODER, InferenceBackend.TRANSFORMERS, True):
|
|
49
|
+
model_class = FreshEncoderModel
|
|
50
|
+
case (_, _, True):
|
|
51
|
+
raise InvalidModel(
|
|
52
|
+
"Cannot load a freshly initialised model with the model type "
|
|
53
|
+
f"{model_config.model_type!r} and inference backend "
|
|
54
|
+
f"{model_config.inference_backend!r}."
|
|
55
|
+
)
|
|
56
|
+
case _:
|
|
57
|
+
raise InvalidModel(
|
|
58
|
+
f"Cannot load model with model type {model_config.model_type!r} and "
|
|
59
|
+
f"inference backend {model_config.inference_backend!r}."
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# Refuse to benchmark non-generative models on generative tasks
|
|
63
|
+
if (
|
|
64
|
+
dataset_config.task.task_group in GENERATIVE_DATASET_TASK_GROUPS
|
|
65
|
+
and not model_config.model_type == ModelType.GENERATIVE
|
|
66
|
+
):
|
|
67
|
+
raise InvalidBenchmark(
|
|
68
|
+
f"Cannot benchmark non-generative model {model_config.model_id!r} on "
|
|
69
|
+
f"generative task {dataset_config.task.name!r}."
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
model = model_class(
|
|
73
|
+
model_config=model_config,
|
|
74
|
+
dataset_config=dataset_config,
|
|
75
|
+
benchmark_config=benchmark_config,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
return model
|