EuroEval 15.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of EuroEval might be problematic. Click here for more details.
- euroeval/__init__.py +72 -0
- euroeval/benchmark_config_factory.py +358 -0
- euroeval/benchmark_modules/__init__.py +7 -0
- euroeval/benchmark_modules/base.py +354 -0
- euroeval/benchmark_modules/fresh.py +286 -0
- euroeval/benchmark_modules/hf.py +1185 -0
- euroeval/benchmark_modules/litellm.py +905 -0
- euroeval/benchmark_modules/vllm.py +1171 -0
- euroeval/benchmarker.py +1074 -0
- euroeval/callbacks.py +72 -0
- euroeval/cli.py +281 -0
- euroeval/constants.py +50 -0
- euroeval/data_loading.py +96 -0
- euroeval/data_models.py +474 -0
- euroeval/dataset_configs.py +2001 -0
- euroeval/enums.py +144 -0
- euroeval/exceptions.py +191 -0
- euroeval/finetuning.py +324 -0
- euroeval/generation.py +296 -0
- euroeval/human_evaluation.py +737 -0
- euroeval/languages.py +200 -0
- euroeval/model_cache.py +253 -0
- euroeval/model_config.py +77 -0
- euroeval/model_loading.py +78 -0
- euroeval/scores.py +90 -0
- euroeval/speed_benchmark.py +124 -0
- euroeval/task_utils/__init__.py +1 -0
- euroeval/task_utils/multiple_choice_classification.py +176 -0
- euroeval/task_utils/question_answering.py +698 -0
- euroeval/task_utils/sequence_classification.py +237 -0
- euroeval/task_utils/text_to_text.py +150 -0
- euroeval/task_utils/token_classification.py +464 -0
- euroeval/tasks.py +202 -0
- euroeval/types.py +97 -0
- euroeval/utils.py +574 -0
- euroeval-15.2.0.dist-info/METADATA +234 -0
- euroeval-15.2.0.dist-info/RECORD +40 -0
- euroeval-15.2.0.dist-info/WHEEL +4 -0
- euroeval-15.2.0.dist-info/entry_points.txt +4 -0
- euroeval-15.2.0.dist-info/licenses/LICENSE +21 -0
euroeval/types.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Types used throughout the project."""
|
|
2
|
+
|
|
3
|
+
import typing as t
|
|
4
|
+
|
|
5
|
+
from numpy.typing import NDArray
|
|
6
|
+
|
|
7
|
+
if t.TYPE_CHECKING:
|
|
8
|
+
from .data_models import GenerativeModelOutput
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
ScoreDict = dict[str, dict[str, float] | list[dict[str, float]]]
|
|
12
|
+
Predictions = NDArray | list[str] | list[list[str]]
|
|
13
|
+
Labels = NDArray | list[str] | list[list[str]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ComputeMetricsFunction(t.Protocol):
|
|
17
|
+
"""A function used to compute the metrics."""
|
|
18
|
+
|
|
19
|
+
def __call__(
|
|
20
|
+
self,
|
|
21
|
+
model_outputs_and_labels: tuple[
|
|
22
|
+
NDArray | list[str] | list[list[str]], NDArray | list[str] | list[list[str]]
|
|
23
|
+
],
|
|
24
|
+
) -> dict[str, float]:
|
|
25
|
+
"""Compute the metrics.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
model_outputs_and_labels:
|
|
29
|
+
The model outputs and labels.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
The computed metrics.
|
|
33
|
+
"""
|
|
34
|
+
...
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ExtractLabelsFunction(t.Protocol):
    """Structural type for callables that turn generated output into labels."""

    def __call__(
        self,
        input_batch: dict[str, list],
        model_output: "GenerativeModelOutput",
    ) -> list[str]:
        """Extract the labels from the generated output.

        Args:
            input_batch:
                The input batch, as a dictionary mapping string keys to lists.
            model_output:
                The output produced by the generative model for that batch.

        Returns:
            The extracted labels, one string per entry.
        """
        ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def is_list_of_int(x: object) -> t.TypeGuard[list[int]]:
|
|
58
|
+
"""Check if an object is a list of integers.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
x:
|
|
62
|
+
The object to check.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
Whether the object is a list of integers.
|
|
66
|
+
"""
|
|
67
|
+
return isinstance(x, list) and all(isinstance(i, int) for i in x)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def is_list_of_list_of_int(x: object) -> t.TypeGuard[list[list[int]]]:
|
|
71
|
+
"""Check if an object is a list of list of integers.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
x:
|
|
75
|
+
The object to check.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
Whether the object is a list of list of integers.
|
|
79
|
+
"""
|
|
80
|
+
return (
|
|
81
|
+
isinstance(x, list)
|
|
82
|
+
and all(isinstance(i, list) for i in x)
|
|
83
|
+
and all(isinstance(j, int) for i in x for j in i)
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def is_list_of_str(x: object) -> t.TypeGuard[list[str]]:
|
|
88
|
+
"""Check if an object is a list of integers.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
x:
|
|
92
|
+
The object to check.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Whether the object is a list of strings.
|
|
96
|
+
"""
|
|
97
|
+
return isinstance(x, list) and all(isinstance(i, str) for i in x)
|
euroeval/utils.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
1
|
+
"""Utility functions to be used in other scripts."""
|
|
2
|
+
|
|
3
|
+
import gc
|
|
4
|
+
import importlib
|
|
5
|
+
import importlib.util
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import random
|
|
9
|
+
import re
|
|
10
|
+
import sys
|
|
11
|
+
import typing as t
|
|
12
|
+
import warnings
|
|
13
|
+
from functools import cache
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from types import TracebackType
|
|
16
|
+
|
|
17
|
+
import litellm
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pkg_resources
|
|
20
|
+
import requests
|
|
21
|
+
import torch
|
|
22
|
+
from datasets.utils import disable_progress_bar
|
|
23
|
+
from requests.exceptions import RequestException
|
|
24
|
+
from transformers import PreTrainedTokenizer
|
|
25
|
+
from transformers import logging as tf_logging
|
|
26
|
+
|
|
27
|
+
from .exceptions import InvalidModel, NaNValueInModelOutput
|
|
28
|
+
|
|
29
|
+
if importlib.util.find_spec("ray") is not None:
|
|
30
|
+
import ray
|
|
31
|
+
|
|
32
|
+
if t.TYPE_CHECKING:
|
|
33
|
+
from .types import Predictions
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Package-wide logger shared by the helpers in this module
logger: logging.Logger = logging.getLogger(name="euroeval")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
    """Build the path of the cache directory for a model.

    Only the path is computed; the directory itself is not created here.

    Args:
        cache_dir:
            The cache directory.
        model_id:
            The model ID.

    Returns:
        The path to the cache directory, as `<cache_dir>/model_cache/<model_id>` with
        every '/' in the model ID replaced by '--'.
    """
    # Flatten IDs such as 'org/model' so they do not produce nested directories
    flattened_id = model_id.replace("/", "--")
    return str(Path(cache_dir, "model_cache", flattened_id))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def clear_memory() -> None:
    """Free memory held by unreachable objects and by accelerator caches.

    The garbage collector is run for generations 0, 1 and 2 in turn, after which
    the CUDA and MPS caching allocators are emptied when those backends are available.
    """
    for generation in (0, 1, 2):
        gc.collect(generation=generation)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def enforce_reproducibility(seed: int = 4242) -> np.random.Generator:
    """Make experiments reproducible by seeding RNGs and enabling determinism.

    Args:
        seed:
            Seed for the random number generators. Defaults to 4242.

    Returns:
        A fresh NumPy random generator seeded with `seed`.
    """
    # Seed the global random number generators of every library in use
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask CUDA/cuBLAS/cuDNN and PyTorch for deterministic behaviour
    os.environ.update(CUDA_LAUNCH_BLOCKING="1", CUBLAS_WORKSPACE_CONFIG=":4096:8")
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True, warn_only=True)

    return np.random.default_rng(seed)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def is_module_installed(module: str) -> bool:
    """Check if a module is installed.

    This is used when dealing with spaCy models, as these are installed as separate
    Python packages. The check is made against the names of the installed
    distributions. Names are compared case-insensitively, and runs of '-', '_' and
    '.' are treated as equivalent, so e.g. 'en-core-web-sm' and 'en_core_web_sm'
    match.

    Args:
        module:
            The name of the module (distribution).

    Returns:
        Whether the module is installed or not.
    """
    from importlib.metadata import distributions

    def _normalise(name: str) -> str:
        # PEP 503-style normalisation, using '_' as separator so that the result
        # matches importable module names. Unlike the previous implementation,
        # digits are kept, so names such as 'h5py' are not mangled.
        return re.sub(r"[-_.]+", "_", name).lower()

    installed_modules = {
        _normalise(name)
        for dist in distributions()
        # Broken installs can lack a 'Name' field; skip those instead of crashing
        if (name := dist.metadata.get("Name")) is not None
    }
    return _normalise(module) in installed_modules
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def block_terminal_output() -> None:
    """Stop third-party libraries from writing output to the terminal.

    This installs warning filters, raises the level of a number of third-party
    loggers to CRITICAL, and sets environment variables intended to quieten vLLM. It
    also disables Ray worker logs when Ray is installed, turns off the `datasets`
    progress bars, and mutes most logging from `transformers` and `litellm`.
    """
    # Warning filters, installed in the same order as before (order sets precedence)
    for category in (UserWarning, FutureWarning):
        warnings.filterwarnings("ignore", category=category)
    warnings.filterwarnings(
        "ignore",
        module="torch.nn.parallel*",
        message="Was asked to gather along dimension 0, but all input tensors were "
        "scalars; will instead unsqueeze and return a vector.",
    )
    warnings.filterwarnings("ignore", module="seqeval*")

    # Loggers that are raised to CRITICAL so that they effectively stay silent
    silenced_loggers = (
        "filelock",
        "absl",
        "datasets",
        "openai",
        "torch.distributed.distributed_c10d",
        "torch.nn.parallel.distributed",
        "vllm",
        "vllm.engine.llm_engine",
        "vllm.transformers_utils.tokenizer",
        "vllm.core.scheduler",
        "vllm.model_executor.weight_utils",
        "httpx",
        "ray._private.worker",
        "matplotlib.font_manager",
        "accelerate",
        "LiteLLM",
        "huggingface_hub",
    )
    for name in silenced_loggers:
        logging.getLogger(name).setLevel(logging.CRITICAL)

    # Environment variables used to suppress vLLM logging
    os.environ["LOG_LEVEL"] = "CRITICAL"
    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"

    # `ray` is only imported at module level when it is installed
    if importlib.util.find_spec("ray") is not None:
        ray._private.worker._worker_logs_enabled = False

    # Progress bars from the `datasets` library
    disable_progress_bar()

    # Most of the `transformers` logging
    tf_logging._default_log_level = logging.CRITICAL
    tf_logging.set_verbosity(logging.CRITICAL)
    logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)

    # `litellm` debug information
    litellm.suppress_debug_info = True
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type | None:
|
|
171
|
+
"""Get a class by its name.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
class_name:
|
|
175
|
+
The name of the class, written in kebab-case. The corresponding class name
|
|
176
|
+
must be the same, but written in PascalCase, and lying in a module with the
|
|
177
|
+
same name, but written in snake_case. If a list of strings is passed, the
|
|
178
|
+
first class that is found is returned.
|
|
179
|
+
module_name:
|
|
180
|
+
The name of the module where the class is located.
|
|
181
|
+
|
|
182
|
+
Returns:
|
|
183
|
+
The class. If the class is not found, None is returned.
|
|
184
|
+
"""
|
|
185
|
+
if isinstance(class_name, str):
|
|
186
|
+
class_name = [class_name]
|
|
187
|
+
|
|
188
|
+
error_messages = list()
|
|
189
|
+
for name in class_name:
|
|
190
|
+
try:
|
|
191
|
+
module = importlib.import_module(name=module_name)
|
|
192
|
+
class_: t.Type = getattr(module, name)
|
|
193
|
+
return class_
|
|
194
|
+
except (ModuleNotFoundError, AttributeError) as e:
|
|
195
|
+
error_messages.append(str(e))
|
|
196
|
+
|
|
197
|
+
if error_messages:
|
|
198
|
+
errors = "\n- " + "\n- ".join(error_messages)
|
|
199
|
+
logger.debug(
|
|
200
|
+
f"Could not find the class with the name(s) {', '.join(class_name)}. The "
|
|
201
|
+
f"following error messages were raised: {errors}"
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
# If the class could not be found, return None
|
|
205
|
+
return None
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def kebab_to_pascal(kebab_string: str) -> str:
    """Convert a kebab-case string to PascalCase.

    Each dash-separated word is passed through `str.title`, and the words are then
    concatenated.

    Args:
        kebab_string:
            The kebab-case string.

    Returns:
        The PascalCase string.
    """
    words = kebab_string.split("-")
    return "".join(map(str.title, words))
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def internet_connection_available(timeout: float = 5.0) -> bool:
    """Check if an internet connection is available.

    This is done by sending an HTTP GET request to google.com. Any request error,
    including a timeout, is taken to mean that no connection is available.

    Args:
        timeout:
            Maximum number of seconds to wait for the server. Without a timeout,
            `requests` can block indefinitely on a stalled connection. Defaults to 5.

    Returns:
        Whether or not internet connection is available.
    """
    try:
        requests.get("https://www.google.com", timeout=timeout)
    except RequestException:
        # Timeouts are subclasses of RequestException, so they land here as well
        return False
    return True
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def get_special_token_metadata(tokenizer: "PreTrainedTokenizer") -> dict:
    """Get the special token metadata for a tokenizer.

    The tokenizer is run on a dummy text to see which special tokens it actually
    inserts.

    For the start token:
        - A CLS token is preferred over a BOS token.
        - If neither is inserted, the CLS token is used when it is defined, and the
          BOS token otherwise.
        - `has_cls_token` is True only when the chosen token was actually inserted.

    For the end token, SEP and EOS are handled in the same way, and `has_sep_token`
    records whether the chosen token was actually inserted.

    Args:
        tokenizer:
            The tokenizer.

    Returns:
        A dictionary with the keys `cls_token_id`, `cls_token`, `sep_token`,
        `has_cls_token` and `has_sep_token`.
    """
    # Tokenise a dummy text to find out which special tokens are actually added
    probe_ids = tokenizer("Test").input_ids

    # Start-of-sequence token: CLS, then BOS, then fall back to whichever is defined
    if tokenizer.cls_token_id in probe_ids:
        cls_token_id, cls_token, has_cls_token = (
            tokenizer.cls_token_id,
            tokenizer.cls_token,
            True,
        )
    elif tokenizer.bos_token_id in probe_ids:
        cls_token_id, cls_token, has_cls_token = (
            tokenizer.bos_token_id,
            tokenizer.bos_token,
            True,
        )
    elif tokenizer.cls_token is not None:
        cls_token_id, cls_token, has_cls_token = (
            tokenizer.cls_token_id,
            tokenizer.cls_token,
            False,
        )
    else:
        cls_token_id, cls_token, has_cls_token = (
            tokenizer.bos_token_id,
            tokenizer.bos_token,
            False,
        )

    # End-of-sequence token: SEP, then EOS, then fall back to whichever is defined
    if tokenizer.sep_token_id in probe_ids:
        sep_token, has_sep_token = tokenizer.sep_token, True
    elif tokenizer.eos_token_id in probe_ids:
        sep_token, has_sep_token = tokenizer.eos_token, True
    elif tokenizer.sep_token is not None:
        sep_token, has_sep_token = tokenizer.sep_token, False
    else:
        sep_token, has_sep_token = tokenizer.eos_token, False

    return {
        "cls_token_id": cls_token_id,
        "cls_token": cls_token,
        "sep_token": sep_token,
        "has_cls_token": has_cls_token,
        "has_sep_token": has_sep_token,
    }
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class HiddenPrints:
|
|
287
|
+
"""Context manager which removes all terminal output."""
|
|
288
|
+
|
|
289
|
+
def __enter__(self) -> None:
|
|
290
|
+
"""Enter the context manager."""
|
|
291
|
+
self._original_stdout = sys.stdout
|
|
292
|
+
self._original_stderr = sys.stderr
|
|
293
|
+
sys.stdout = open(os.devnull, "w")
|
|
294
|
+
sys.stderr = open(os.devnull, "w")
|
|
295
|
+
|
|
296
|
+
def __exit__(
|
|
297
|
+
self,
|
|
298
|
+
exc_type: t.Type[BaseException],
|
|
299
|
+
exc_val: BaseException,
|
|
300
|
+
exc_tb: TracebackType,
|
|
301
|
+
) -> None:
|
|
302
|
+
"""Exit the context manager."""
|
|
303
|
+
sys.stdout.close()
|
|
304
|
+
sys.stderr.close()
|
|
305
|
+
sys.stdout = self._original_stdout
|
|
306
|
+
sys.stderr = self._original_stderr
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def raise_if_model_output_contains_nan_values(model_output: "Predictions") -> None:
    """Raise an exception if the model output contains NaN values.

    Args:
        model_output:
            The model output to check: a NumPy array, a list of strings or a list
            of lists.

    Raises:
        NaNValueInModelOutput:
            If the model output contains NaN values.
    """
    if isinstance(model_output, np.ndarray):
        # Only floating-point arrays can hold NaN. Check every float precision
        # (float16/32/64/...), not just float32.
        if np.issubdtype(model_output.dtype, np.floating) and np.isnan(
            model_output
        ).any():
            raise NaNValueInModelOutput()
        return

    if len(model_output) == 0:
        return

    # NaN is the only value that is not equal to itself, so `x != x` detects NaN
    # floats mixed into otherwise non-float lists
    if isinstance(model_output[0], str):
        if any(x != x for x in model_output):
            raise NaNValueInModelOutput()
    elif any(x != x for sublist in model_output for x in sublist):
        # Scan every sub-list, even when the first one happens to be empty
        raise NaNValueInModelOutput()
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def should_prompts_be_stripped(
    labels_to_be_generated: list[str], tokenizer: "PreTrainedTokenizer"
) -> bool:
    """Determine if we should strip the prompts for few-shot evaluation.

    This is the case if the tokenizer needs to include the space as part of the label
    token. The strategy is thus to tokenize a label with a preceding colon (as in the
    prompts), i.e., ": positive", and check if the tokenization starts with the tokens
    of ": ". If this is the case for any label, then we should not strip the prompts,
    since the tokenizer produces the whitespace token separately.

    Args:
        labels_to_be_generated:
            The labels that are to be generated.
        tokenizer:
            The tokenizer used to tokenize the labels.

    Returns:
        Whether we should strip the prompts.
    """
    # The tokenization of ": " does not depend on the label, so compute it once
    colon_tokens = tokenizer(": ", add_special_tokens=False).input_ids
    if isinstance(colon_tokens, torch.Tensor):
        colon_tokens = list(colon_tokens.squeeze(0))

    for label in labels_to_be_generated:
        label_tokens = tokenizer(": " + label, add_special_tokens=False).input_ids
        if isinstance(label_tokens, torch.Tensor):
            label_tokens = list(label_tokens.squeeze(0))

        # One label with a separate whitespace token settles the answer
        if label_tokens[: len(colon_tokens)] == colon_tokens:
            return False

    return True
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
# TODO: This is currently not used - maybe remove.
def should_prefix_space_be_added_to_labels(
    labels_to_be_generated: list[str], tokenizer: "PreTrainedTokenizer"
) -> bool:
    """Determine if we should add a prefix space to the labels.

    This is the case if the prompts are stripped and no label's first token starts
    with the same character as the tokenizer's whitespace token.

    Args:
        labels_to_be_generated:
            The labels that are to be generated.
        tokenizer:
            The tokenizer used to tokenize the labels.

    Returns:
        Whether we should add a prefix space to the labels.
    """
    prompts_stripped = should_prompts_be_stripped(
        labels_to_be_generated=labels_to_be_generated, tokenizer=tokenizer
    )
    if not prompts_stripped:
        return False

    # First character of the token the tokenizer uses for a single space
    space_ids = tokenizer(" ", add_special_tokens=False).input_ids
    whitespace_token = tokenizer.convert_ids_to_tokens(ids=space_ids[0])[0]

    def _starts_with_whitespace(label: str) -> bool:
        token_ids = tokenizer(label, add_special_tokens=False).input_ids
        if isinstance(token_ids, torch.Tensor):
            token_ids = list(token_ids.squeeze(0))
        first_piece = tokenizer.convert_ids_to_tokens(int(token_ids[0]))
        return first_piece[0] == whitespace_token

    # Stop at the first label that already begins with whitespace
    return not any(_starts_with_whitespace(label) for label in labels_to_be_generated)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def get_bos_token(tokenizer: "PreTrainedTokenizer") -> tuple[str, int]:
    """Get the beginning-of-sequence token from a tokenizer.

    The tokenizer's own BOS token is used when both the token and its ID are set.
    Otherwise the vocabulary is searched for '<s>', '<|begin_of_text|>' and '[CLS]',
    in that order.

    Args:
        tokenizer:
            The tokenizer.

    Returns:
        A pair (token, token_id) representing the beginning-of-sequence token and its
        token ID.

    Raises:
        InvalidModel:
            If no beginning-of-sequence token could be found.
    """
    if isinstance(tokenizer.bos_token, str) and isinstance(tokenizer.bos_token_id, int):
        return tokenizer.bos_token, tokenizer.bos_token_id

    vocab: dict[str, int] = tokenizer.get_vocab()
    found = next(
        (
            candidate
            for candidate in ("<s>", "<|begin_of_text|>", "[CLS]")
            if candidate in vocab
        ),
        None,
    )
    if found is None:
        raise InvalidModel(
            "The model does not have a beginning-of-sequence token. Please ensure that "
            "this has been set in the tokenizer's configuration."
        )
    return found, vocab[found]
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def get_eos_token(tokenizer: "PreTrainedTokenizer") -> tuple[str, int]:
    """Get the end-of-sequence token from a tokenizer.

    The tokenizer's own EOS token is used when both the token and its ID are set.
    Otherwise the vocabulary is searched for '</s>', '<|end_of_text|>' and '[SEP]',
    in that order.

    Args:
        tokenizer:
            The tokenizer.

    Returns:
        A pair (token, token_id) representing the end-of-sequence token and its token
        ID.

    Raises:
        InvalidModel:
            If no end-of-sequence token could be found.
    """
    if isinstance(tokenizer.eos_token, str) and isinstance(tokenizer.eos_token_id, int):
        return tokenizer.eos_token, tokenizer.eos_token_id

    vocab: dict[str, int] = tokenizer.get_vocab()
    found = next(
        (
            candidate
            for candidate in ("</s>", "<|end_of_text|>", "[SEP]")
            if candidate in vocab
        ),
        None,
    )
    if found is None:
        raise InvalidModel(
            "The model does not have an end-of-sequence token. Please ensure that this "
            "has been set in the tokenizer's configuration."
        )
    return found, vocab[found]
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def get_end_of_chat_token_ids(tokenizer: "PreTrainedTokenizer") -> list[int] | None:
|
|
475
|
+
"""Get the end token ID for chat models.
|
|
476
|
+
|
|
477
|
+
This is only relevant for tokenizers with a chat template.
|
|
478
|
+
|
|
479
|
+
Args:
|
|
480
|
+
tokenizer:
|
|
481
|
+
The tokenizer.
|
|
482
|
+
|
|
483
|
+
Returns:
|
|
484
|
+
The token IDs used to end chats, or None if the tokenizer does not have a chat
|
|
485
|
+
template.
|
|
486
|
+
|
|
487
|
+
Raises:
|
|
488
|
+
ValueError:
|
|
489
|
+
If the end-of-chat token could not be located.
|
|
490
|
+
"""
|
|
491
|
+
if tokenizer.chat_template is None:
|
|
492
|
+
return None
|
|
493
|
+
|
|
494
|
+
user_message: dict[t.Literal["role", "content"], str] = dict()
|
|
495
|
+
user_message["role"] = "user"
|
|
496
|
+
user_message["content"] = "X"
|
|
497
|
+
token_ids = tokenizer.apply_chat_template(conversation=[user_message])
|
|
498
|
+
assert isinstance(token_ids, list)
|
|
499
|
+
|
|
500
|
+
for idx, token in enumerate(tokenizer.convert_ids_to_tokens(token_ids)):
|
|
501
|
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
|
502
|
+
assert isinstance(token_id, int)
|
|
503
|
+
token = tokenizer.decode([token_id])
|
|
504
|
+
if "X" in token:
|
|
505
|
+
x_token_index = idx
|
|
506
|
+
break
|
|
507
|
+
else:
|
|
508
|
+
raise ValueError("Could not locate the end-of-chat token for the model.")
|
|
509
|
+
|
|
510
|
+
end_of_chat_tokens = token_ids[x_token_index + 1 :]
|
|
511
|
+
if len(end_of_chat_tokens) == 0:
|
|
512
|
+
return None
|
|
513
|
+
return end_of_chat_tokens
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def scramble(text: str) -> str:
    """Scramble a string in a bijective manner.

    The characters are reordered with a permutation drawn from a NumPy generator
    seeded with 4242. The permutation therefore depends only on the length of the
    string, which is what makes the operation reversible.

    Args:
        text:
            The string to scramble.

    Returns:
        The scrambled string.
    """
    order = np.random.default_rng(seed=4242).permutation(x=len(text))
    return "".join(map(text.__getitem__, order))
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def unscramble(scrambled_text: str) -> str:
    """Unscramble a string in a bijective manner.

    The same seeded permutation used for scrambling (NumPy generator with seed 4242,
    sized by the string length) is regenerated and inverted.

    Args:
        scrambled_text:
            The scrambled string to unscramble.

    Returns:
        The unscrambled string.
    """
    length = len(scrambled_text)
    permutation = np.random.default_rng(seed=4242).permutation(x=length)

    # Invert the permutation: position permutation[i] came from scrambled index i
    inverse = np.empty_like(permutation)
    inverse[permutation] = np.arange(length)

    return "".join(scrambled_text[int(index)] for index in inverse)
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
@cache
def log_once(message: str, level: int = logging.INFO) -> None:
    """Log a message only once per (message, level) combination.

    Repeated calls with the same arguments are answered from the `functools.cache`
    memo, so the underlying logger is only invoked the first time.

    Args:
        message:
            The message to log.
        level:
            The logging level. Must be one of the standard levels DEBUG, INFO,
            WARNING, ERROR or CRITICAL. Defaults to logging.INFO.

    Raises:
        ValueError:
            If the logging level is not one of the standard levels.
    """
    method_by_level = {
        logging.DEBUG: "debug",
        logging.INFO: "info",
        logging.WARNING: "warning",
        logging.ERROR: "error",
        logging.CRITICAL: "critical",
    }
    method_name = method_by_level.get(level)
    if method_name is None:
        raise ValueError(f"Invalid logging level: {level}")
    getattr(logger, method_name)(message)
|