EuroEval 15.2.0 (euroeval-15.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (40)
  1. euroeval/__init__.py +72 -0
  2. euroeval/benchmark_config_factory.py +358 -0
  3. euroeval/benchmark_modules/__init__.py +7 -0
  4. euroeval/benchmark_modules/base.py +354 -0
  5. euroeval/benchmark_modules/fresh.py +286 -0
  6. euroeval/benchmark_modules/hf.py +1185 -0
  7. euroeval/benchmark_modules/litellm.py +905 -0
  8. euroeval/benchmark_modules/vllm.py +1171 -0
  9. euroeval/benchmarker.py +1074 -0
  10. euroeval/callbacks.py +72 -0
  11. euroeval/cli.py +281 -0
  12. euroeval/constants.py +50 -0
  13. euroeval/data_loading.py +96 -0
  14. euroeval/data_models.py +474 -0
  15. euroeval/dataset_configs.py +2001 -0
  16. euroeval/enums.py +144 -0
  17. euroeval/exceptions.py +191 -0
  18. euroeval/finetuning.py +324 -0
  19. euroeval/generation.py +296 -0
  20. euroeval/human_evaluation.py +737 -0
  21. euroeval/languages.py +200 -0
  22. euroeval/model_cache.py +253 -0
  23. euroeval/model_config.py +77 -0
  24. euroeval/model_loading.py +78 -0
  25. euroeval/scores.py +90 -0
  26. euroeval/speed_benchmark.py +124 -0
  27. euroeval/task_utils/__init__.py +1 -0
  28. euroeval/task_utils/multiple_choice_classification.py +176 -0
  29. euroeval/task_utils/question_answering.py +698 -0
  30. euroeval/task_utils/sequence_classification.py +237 -0
  31. euroeval/task_utils/text_to_text.py +150 -0
  32. euroeval/task_utils/token_classification.py +464 -0
  33. euroeval/tasks.py +202 -0
  34. euroeval/types.py +97 -0
  35. euroeval/utils.py +574 -0
  36. euroeval-15.2.0.dist-info/METADATA +234 -0
  37. euroeval-15.2.0.dist-info/RECORD +40 -0
  38. euroeval-15.2.0.dist-info/WHEEL +4 -0
  39. euroeval-15.2.0.dist-info/entry_points.txt +4 -0
  40. euroeval-15.2.0.dist-info/licenses/LICENSE +21 -0
euroeval/types.py ADDED
@@ -0,0 +1,97 @@
+ """Types used throughout the project."""
+
+ import typing as t
+
+ from numpy.typing import NDArray
+
+ if t.TYPE_CHECKING:
+     from .data_models import GenerativeModelOutput
+
+
+ ScoreDict = dict[str, dict[str, float] | list[dict[str, float]]]
+ Predictions = NDArray | list[str] | list[list[str]]
+ Labels = NDArray | list[str] | list[list[str]]
+
+
+ class ComputeMetricsFunction(t.Protocol):
+     """A function used to compute the metrics."""
+
+     def __call__(
+         self,
+         model_outputs_and_labels: tuple[
+             NDArray | list[str] | list[list[str]],
+             NDArray | list[str] | list[list[str]],
+         ],
+     ) -> dict[str, float]:
+         """Compute the metrics.
+
+         Args:
+             model_outputs_and_labels:
+                 The model outputs and labels.
+
+         Returns:
+             The computed metrics.
+         """
+         ...
+
+
+ class ExtractLabelsFunction(t.Protocol):
+     """A function used to extract the labels from the generated output."""
+
+     def __call__(
+         self, input_batch: dict[str, list], model_output: "GenerativeModelOutput"
+     ) -> list[str]:
+         """Extract the labels from the generated output.
+
+         Args:
+             input_batch:
+                 The input batch.
+             model_output:
+                 The model output.
+
+         Returns:
+             The extracted labels.
+         """
+         ...
+
+
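Because `t.Protocol` classes are matched structurally, any function with the same signature satisfies `ComputeMetricsFunction`. As an illustrative sketch (not part of the package; the exact-match metric and all names below are invented):

    from euroeval.types import ComputeMetricsFunction, Labels, Predictions

    def exact_match(
        model_outputs_and_labels: tuple[Predictions, Labels],
    ) -> dict[str, float]:
        # Compare each prediction to its label and report the mean accuracy.
        outputs, labels = model_outputs_and_labels
        matches = [pred == label for pred, label in zip(outputs, labels)]
        return {"exact_match": sum(matches) / len(matches)}

    fn: ComputeMetricsFunction = exact_match  # passes a static structural type check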
+ def is_list_of_int(x: object) -> t.TypeGuard[list[int]]:
+     """Check if an object is a list of integers.
+
+     Args:
+         x:
+             The object to check.
+
+     Returns:
+         Whether the object is a list of integers.
+     """
+     return isinstance(x, list) and all(isinstance(i, int) for i in x)
+
+
+ def is_list_of_list_of_int(x: object) -> t.TypeGuard[list[list[int]]]:
+     """Check if an object is a list of lists of integers.
+
+     Args:
+         x:
+             The object to check.
+
+     Returns:
+         Whether the object is a list of lists of integers.
+     """
+     return (
+         isinstance(x, list)
+         and all(isinstance(i, list) for i in x)
+         and all(isinstance(j, int) for i in x for j in i)
+     )
+
+
+ def is_list_of_str(x: object) -> t.TypeGuard[list[str]]:
+     """Check if an object is a list of strings.
+
+     Args:
+         x:
+             The object to check.
+
+     Returns:
+         Whether the object is a list of strings.
+     """
+     return isinstance(x, list) and all(isinstance(i, str) for i in x)
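The `t.TypeGuard` return types let static checkers narrow an `object` to the concrete list type after a successful runtime check. A small usage sketch (the `flatten` helper is invented for illustration):

    from euroeval.types import is_list_of_int, is_list_of_list_of_int

    def flatten(token_ids: object) -> list[int]:
        # After each guard passes, type checkers treat `token_ids` as the narrowed type.
        if is_list_of_int(token_ids):
            return token_ids
        if is_list_of_list_of_int(token_ids):
            return [token for sequence in token_ids for token in sequence]
        raise TypeError(f"Unsupported type: {type(token_ids)}")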
euroeval/utils.py ADDED
@@ -0,0 +1,574 @@
+ """Utility functions to be used in other scripts."""
+
+ import gc
+ import importlib
+ import importlib.util
+ import logging
+ import os
+ import random
+ import re
+ import sys
+ import typing as t
+ import warnings
+ from functools import cache
+ from pathlib import Path
+ from types import TracebackType
+
+ import litellm
+ import numpy as np
+ import pkg_resources
+ import requests
+ import torch
+ from datasets.utils import disable_progress_bar
+ from requests.exceptions import RequestException
+ from transformers import PreTrainedTokenizer
+ from transformers import logging as tf_logging
+
+ from .exceptions import InvalidModel, NaNValueInModelOutput
+
+ if importlib.util.find_spec("ray") is not None:
+     import ray
+
+ if t.TYPE_CHECKING:
+     from .types import Predictions
+
+
+ logger = logging.getLogger("euroeval")
+
+
+ def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
+     """Create cache directory for a model.
+
+     Args:
+         cache_dir:
+             The cache directory.
+         model_id:
+             The model ID.
+
+     Returns:
+         The path to the cache directory.
+     """
+     # Replace '/' to avoid nested directories, since model IDs can contain '/'
+     _model_id = model_id.replace("/", "--")
+     cache_dir_path = Path(cache_dir) / "model_cache" / _model_id
+     return str(cache_dir_path)
+
+
+ def clear_memory() -> None:
+     """Clears the memory of unused items."""
+     for gc_generation in range(3):
+         gc.collect(generation=gc_generation)
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     if torch.backends.mps.is_available():
+         torch.mps.empty_cache()
+
+
+ def enforce_reproducibility(seed: int = 4242) -> np.random.Generator:
+     """Ensures reproducibility of experiments.
+
+     Args:
+         seed:
+             Seed for the random number generator.
+
+     Returns:
+         A NumPy random number generator seeded with `seed`.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     rng = np.random.default_rng(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+     torch.backends.cudnn.benchmark = False
+     torch.backends.cudnn.deterministic = True
+     torch.use_deterministic_algorithms(True, warn_only=True)
+     return rng
+
+
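The function seeds Python's, NumPy's and PyTorch's global generators in one call and also returns a dedicated NumPy generator for local use, so a typical caller would do something like:

    from euroeval.utils import enforce_reproducibility

    rng = enforce_reproducibility(seed=4242)
    shuffled_indices = rng.permutation(10)  # reproducible across runs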
+ def is_module_installed(module: str) -> bool:
+     """Check if a module is installed.
+
+     This is used when dealing with spaCy models, as these are installed as separate
+     Python packages.
+
+     Args:
+         module:
+             The name of the module.
+
+     Returns:
+         Whether the module is installed or not.
+     """
+     # Get list of all modules, including their versions
+     installed_modules_with_versions = list(pkg_resources.working_set)
+
+     # Strip the module versions from the list of modules. Also make the modules lower
+     # case and replace dashes with underscores
+     installed_modules = [
+         re.sub("[0-9. ]", "", str(module)).lower().replace("-", "_")
+         for module in installed_modules_with_versions
+     ]
+
+     # Check if the module is installed by checking if the module name is in the list
+     return module.lower() in installed_modules
+
+
+ def block_terminal_output() -> None:
+     """Blocks libraries from writing output to the terminal.
+
+     This filters warnings from some libraries, sets the logging level to CRITICAL
+     for some libraries, disables tokeniser progress bars when using Hugging Face
+     tokenisers, and disables most of the logging from the `transformers` library.
+     """
+     # Ignore miscellaneous warnings
+     warnings.filterwarnings("ignore", category=UserWarning)
+     warnings.filterwarnings("ignore", category=FutureWarning)
+     warnings.filterwarnings(
+         "ignore",
+         module="torch.nn.parallel*",
+         message="Was asked to gather along dimension 0, but all input tensors were "
+         "scalars; will instead unsqueeze and return a vector.",
+     )
+     warnings.filterwarnings("ignore", module="seqeval*")
+
+     # Up the logging level, to disable outputs
+     logging.getLogger("filelock").setLevel(logging.CRITICAL)
+     logging.getLogger("absl").setLevel(logging.CRITICAL)
+     logging.getLogger("datasets").setLevel(logging.CRITICAL)
+     logging.getLogger("openai").setLevel(logging.CRITICAL)
+     logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.CRITICAL)
+     logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
+     logging.getLogger("httpx").setLevel(logging.CRITICAL)
+     logging.getLogger("ray._private.worker").setLevel(logging.CRITICAL)
+     logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
+     logging.getLogger("accelerate").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
+     logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
+
+     # This suppresses vLLM logging
+     os.environ["LOG_LEVEL"] = "CRITICAL"
+     os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+     if importlib.util.find_spec("ray") is not None:
+         ray._private.worker._worker_logs_enabled = False
+
+     # Disable the tokeniser progress bars
+     disable_progress_bar()
+
+     # Disable most of the `transformers` logging
+     tf_logging._default_log_level = logging.CRITICAL
+     tf_logging.set_verbosity(logging.CRITICAL)
+     logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
+
+     # Disable logging from `litellm`
+     litellm.suppress_debug_info = True
+
+
+ def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type | None:
+     """Get a class by its name.
+
+     Args:
+         class_name:
+             The name of the class, written in kebab-case. The corresponding class
+             name must be the same, but written in PascalCase, and lying in a module
+             with the same name, but written in snake_case. If a list of strings is
+             passed, the first class that is found is returned.
+         module_name:
+             The name of the module where the class is located.
+
+     Returns:
+         The class. If the class is not found, None is returned.
+     """
+     if isinstance(class_name, str):
+         class_name = [class_name]
+
+     error_messages = list()
+     for name in class_name:
+         try:
+             module = importlib.import_module(name=module_name)
+             class_: t.Type = getattr(module, name)
+             return class_
+         except (ModuleNotFoundError, AttributeError) as e:
+             error_messages.append(str(e))
+
+     if error_messages:
+         errors = "\n- " + "\n- ".join(error_messages)
+         logger.debug(
+             f"Could not find the class with the name(s) {', '.join(class_name)}. "
+             f"The following error messages were raised: {errors}"
+         )
+
+     # If the class could not be found, return None
+     return None
+
+
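A usage sketch, assuming the naming convention described in the docstring (the class names below are hypothetical and not taken from the package):

    from euroeval.utils import get_class_by_name

    cls = get_class_by_name(
        class_name=["HuggingFaceEncoderModel", "VLLMModel"],  # hypothetical names
        module_name="euroeval.benchmark_modules",
    )
    if cls is None:
        print("No matching class was found")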
+ def kebab_to_pascal(kebab_string: str) -> str:
+     """Converts a kebab-case string to PascalCase.
+
+     Args:
+         kebab_string:
+             The kebab-case string.
+
+     Returns:
+         The PascalCase string.
+     """
+     return "".join(word.title() for word in kebab_string.split("-"))
+
+
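For example, since `str.title` capitalises each hyphen-separated word:

    >>> kebab_to_pascal("sequence-classification")
    'SequenceClassification'
    >>> kebab_to_pascal("text-to-text")
    'TextToText'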
+ def internet_connection_available() -> bool:
+     """Checks if an internet connection is available by requesting google.com.
+
+     Returns:
+         Whether or not an internet connection is available.
+     """
+     try:
+         requests.get("https://www.google.com")
+         return True
+     except RequestException:
+         return False
+
+
+ def get_special_token_metadata(tokenizer: "PreTrainedTokenizer") -> dict:
+     """Get the special token metadata for a tokenizer.
+
+     Args:
+         tokenizer:
+             The tokenizer.
+
+     Returns:
+         The special token metadata.
+     """
+     # Create some test input IDs, to check if the tokenizer is adding special tokens
+     test_input_ids = tokenizer("Test").input_ids
+
+     # Extract the CLS token IDs from the tokenizer, if it's using them
+     has_cls_token = True
+     if tokenizer.cls_token_id in test_input_ids:
+         cls_token_id = tokenizer.cls_token_id
+         cls_token = tokenizer.cls_token
+     elif tokenizer.bos_token_id in test_input_ids:
+         cls_token_id = tokenizer.bos_token_id
+         cls_token = tokenizer.bos_token
+     elif tokenizer.cls_token is not None:
+         cls_token_id = tokenizer.cls_token_id
+         cls_token = tokenizer.cls_token
+         has_cls_token = False
+     else:
+         cls_token_id = tokenizer.bos_token_id
+         cls_token = tokenizer.bos_token
+         has_cls_token = False
+
+     # Extract the SEP token IDs from the tokenizer, if it's using them
+     has_sep_token = True
+     if tokenizer.sep_token_id in test_input_ids:
+         sep_token = tokenizer.sep_token
+     elif tokenizer.eos_token_id in test_input_ids:
+         sep_token = tokenizer.eos_token
+     elif tokenizer.sep_token is not None:
+         sep_token = tokenizer.sep_token
+         has_sep_token = False
+     else:
+         sep_token = tokenizer.eos_token
+         has_sep_token = False
+
+     return dict(
+         cls_token_id=cls_token_id,
+         cls_token=cls_token,
+         sep_token=sep_token,
+         has_cls_token=has_cls_token,
+         has_sep_token=has_sep_token,
+     )
+
+
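For an encoder-style tokenizer that inserts special tokens, the returned metadata reflects them directly. A hedged sketch (assumes Hugging Face Hub access; `bert-base-cased` is used purely as a familiar example):

    from transformers import AutoTokenizer
    from euroeval.utils import get_special_token_metadata

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    metadata = get_special_token_metadata(tokenizer)
    # For BERT-style tokenizers one would expect cls_token='[CLS]',
    # sep_token='[SEP]', has_cls_token=True and has_sep_token=True.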
+ class HiddenPrints:
+     """Context manager which removes all terminal output."""
+
+     def __enter__(self) -> None:
+         """Enter the context manager."""
+         self._original_stdout = sys.stdout
+         self._original_stderr = sys.stderr
+         sys.stdout = open(os.devnull, "w")
+         sys.stderr = open(os.devnull, "w")
+
+     def __exit__(
+         self,
+         exc_type: t.Type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: TracebackType | None,
+     ) -> None:
+         """Exit the context manager."""
+         sys.stdout.close()
+         sys.stderr.close()
+         sys.stdout = self._original_stdout
+         sys.stderr = self._original_stderr
+
+
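Usage follows the standard context-manager pattern: output inside the block is discarded, and the original streams are restored on exit:

    from euroeval.utils import HiddenPrints

    with HiddenPrints():
        print("this is swallowed")
    print("this is printed as usual")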
+ def raise_if_model_output_contains_nan_values(model_output: "Predictions") -> None:
+     """Raise an exception if the model output contains NaN values.
+
+     Args:
+         model_output:
+             The model output to check.
+
+     Raises:
+         NaNValueInModelOutput:
+             If the model output contains NaN values.
+     """
+     if isinstance(model_output, np.ndarray):
+         if model_output.dtype == np.float32 and np.isnan(model_output).any():
+             raise NaNValueInModelOutput()
+     elif len(model_output) > 0:
+         # NaN is the only value that is not equal to itself, so `x != x` detects
+         # stray NaN values even in lists that are typed as containing strings
+         if isinstance(model_output[0], str):
+             if any(x != x for x in model_output):
+                 raise NaNValueInModelOutput()
+         elif len(model_output[0]) > 0:
+             if any(x != x for sublist in model_output for x in sublist):
+                 raise NaNValueInModelOutput()
+
+
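The `x != x` comparisons work because NaN is the only common value that compares unequal to itself, which catches stray floats even in nominally string-typed lists:

    import math

    nan = float("nan")
    print(nan != nan)       # True: NaN never equals itself
    print("abc" != "abc")   # False: ordinary values equal themselves
    print(math.isnan(nan))  # True: the equivalent explicit check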
+ def should_prompts_be_stripped(
+     labels_to_be_generated: list[str], tokenizer: "PreTrainedTokenizer"
+ ) -> bool:
+     """Determine if we should strip the prompts for few-shot evaluation.
+
+     This is the case if the tokenizer needs to include the space as part of the
+     label token. The strategy is thus to tokenize a label with a preceding colon
+     (as in the prompts), i.e., ": positive", and check if the tokenization starts
+     with the tokens of ": ". If this is the case, then we should not strip the
+     prompts, since the tokenizer produces the whitespace token separately.
+
+     Args:
+         labels_to_be_generated:
+             The labels that are to be generated.
+         tokenizer:
+             The tokenizer used to tokenize the labels.
+
+     Returns:
+         Whether we should strip the prompts.
+     """
+     strip_prompts = True
+     for label in labels_to_be_generated:
+         colon_tokens = tokenizer(": ", add_special_tokens=False).input_ids
+         label_tokens = tokenizer(": " + label, add_special_tokens=False).input_ids
+
+         if isinstance(colon_tokens, torch.Tensor):
+             colon_tokens = list(colon_tokens.squeeze(0))
+         if isinstance(label_tokens, torch.Tensor):
+             label_tokens = list(label_tokens.squeeze(0))
+
+         label_tokens_start_with_colon_tokens = (
+             label_tokens[: len(colon_tokens)] == colon_tokens
+         )
+         if label_tokens_start_with_colon_tokens:
+             strip_prompts = False
+
+     return strip_prompts
+
+
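The answer therefore varies by tokenizer family. A hedged sketch (assumes Hub access; GPT-2 is an illustrative choice, and the result depends on how its BPE handles leading whitespace):

    from transformers import AutoTokenizer
    from euroeval.utils import should_prompts_be_stripped

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Typically True for BPE tokenizers that fold the leading space into the label
    # token, meaning the few-shot prompts should end without a trailing space.
    print(should_prompts_be_stripped(["positive", "negative"], tokenizer))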
+ # TODO: This is currently not used - maybe remove.
+ def should_prefix_space_be_added_to_labels(
+     labels_to_be_generated: list[str], tokenizer: "PreTrainedTokenizer"
+ ) -> bool:
+     """Determine if we should add a prefix space to the labels.
+
+     This is the case if the prompts are stripped and the tokenizer doesn't
+     automatically add prefix whitespaces to the labels.
+
+     Args:
+         labels_to_be_generated:
+             The labels that are to be generated.
+         tokenizer:
+             The tokenizer used to tokenize the labels.
+
+     Returns:
+         Whether we should add a prefix space to the labels.
+     """
+     if not should_prompts_be_stripped(
+         labels_to_be_generated=labels_to_be_generated, tokenizer=tokenizer
+     ):
+         return False
+
+     whitespace_token = tokenizer.convert_ids_to_tokens(
+         ids=tokenizer(" ", add_special_tokens=False).input_ids[0]
+     )[0]
+
+     add_prefix_space = True
+     for label in labels_to_be_generated:
+         label_tokens = tokenizer(label, add_special_tokens=False).input_ids
+         if isinstance(label_tokens, torch.Tensor):
+             label_tokens = list(label_tokens.squeeze(0))
+         first_label_token: int = int(label_tokens[0])
+         first_character_of_label = tokenizer.convert_ids_to_tokens(first_label_token)[0]
+         has_prefix_space = first_character_of_label == whitespace_token
+         if has_prefix_space:
+             add_prefix_space = False
+             break
+
+     return add_prefix_space
+
+
+ def get_bos_token(tokenizer: "PreTrainedTokenizer") -> tuple[str, int]:
+     """Get the beginning-of-sequence token from a tokenizer.
+
+     Args:
+         tokenizer:
+             The tokenizer.
+
+     Returns:
+         A pair (token, token_id) representing the beginning-of-sequence token and
+         its token ID.
+     """
+     if isinstance(tokenizer.bos_token, str) and isinstance(tokenizer.bos_token_id, int):
+         return tokenizer.bos_token, tokenizer.bos_token_id
+
+     vocab: dict[str, int] = tokenizer.get_vocab()
+
+     candidate_bos_tokens = ["<s>", "<|begin_of_text|>", "[CLS]"]
+     for candidate_bos_token in candidate_bos_tokens:
+         if candidate_bos_token in vocab:
+             bos_token = candidate_bos_token
+             bos_token_id = vocab[bos_token]
+             break
+     else:
+         raise InvalidModel(
+             "The model does not have a beginning-of-sequence token. Please ensure "
+             "that this has been set in the tokenizer's configuration."
+         )
+
+     return bos_token, bos_token_id
+
+
+ def get_eos_token(tokenizer: "PreTrainedTokenizer") -> tuple[str, int]:
+     """Get the end-of-sequence token from a tokenizer.
+
+     Args:
+         tokenizer:
+             The tokenizer.
+
+     Returns:
+         A pair (token, token_id) representing the end-of-sequence token and its
+         token ID.
+     """
+     if isinstance(tokenizer.eos_token, str) and isinstance(tokenizer.eos_token_id, int):
+         return tokenizer.eos_token, tokenizer.eos_token_id
+
+     vocab: dict[str, int] = tokenizer.get_vocab()
+
+     candidate_eos_tokens = ["</s>", "<|end_of_text|>", "[SEP]"]
+     for candidate_eos_token in candidate_eos_tokens:
+         if candidate_eos_token in vocab:
+             eos_token = candidate_eos_token
+             eos_token_id = vocab[eos_token]
+             break
+     else:
+         raise InvalidModel(
+             "The model does not have an end-of-sequence token. Please ensure that "
+             "this has been set in the tokenizer's configuration."
+         )
+
+     return eos_token, eos_token_id
+
+
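Both helpers first trust the tokenizer's configured attributes and only fall back to probing the vocabulary for well-known token strings. A hedged usage sketch (assumes Hub access; the model ID is illustrative):

    from transformers import AutoTokenizer
    from euroeval.utils import get_bos_token, get_eos_token

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    bos_token, bos_token_id = get_bos_token(tokenizer)
    eos_token, eos_token_id = get_eos_token(tokenizer)
    print(bos_token, bos_token_id, eos_token, eos_token_id)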
+ def get_end_of_chat_token_ids(tokenizer: "PreTrainedTokenizer") -> list[int] | None:
+     """Get the end token ID for chat models.
+
+     This is only relevant for tokenizers with a chat template.
+
+     Args:
+         tokenizer:
+             The tokenizer.
+
+     Returns:
+         The token IDs used to end chats, or None if the tokenizer does not have a
+         chat template.
+
+     Raises:
+         ValueError:
+             If the end-of-chat token could not be located.
+     """
+     if tokenizer.chat_template is None:
+         return None
+
+     user_message: dict[t.Literal["role", "content"], str] = dict()
+     user_message["role"] = "user"
+     user_message["content"] = "X"
+     token_ids = tokenizer.apply_chat_template(conversation=[user_message])
+     assert isinstance(token_ids, list)
+
+     for idx, token in enumerate(tokenizer.convert_ids_to_tokens(token_ids)):
+         token_id = tokenizer.convert_tokens_to_ids(token)
+         assert isinstance(token_id, int)
+         token = tokenizer.decode([token_id])
+         if "X" in token:
+             x_token_index = idx
+             break
+     else:
+         raise ValueError("Could not locate the end-of-chat token for the model.")
+
+     end_of_chat_tokens = token_ids[x_token_index + 1 :]
+     if len(end_of_chat_tokens) == 0:
+         return None
+     return end_of_chat_tokens
+
+
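In other words, the function renders a single user message containing only "X" through the chat template and returns whatever token IDs the template appends after that message. A hedged sketch (assumes Hub access and a tokenizer that ships a chat template; the model ID is illustrative):

    from transformers import AutoTokenizer
    from euroeval.utils import get_end_of_chat_token_ids

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    end_ids = get_end_of_chat_token_ids(tokenizer)
    if end_ids is not None:
        print(tokenizer.decode(end_ids))  # the template's end-of-turn marker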
+ def scramble(text: str) -> str:
+     """Scramble a string in a bijective manner.
+
+     Args:
+         text:
+             The string to scramble.
+
+     Returns:
+         The scrambled string.
+     """
+     rng = np.random.default_rng(seed=4242)
+     permutation = rng.permutation(x=len(text))
+     scrambled = "".join(text[i] for i in permutation)
+     return scrambled
+
+
+ def unscramble(scrambled_text: str) -> str:
+     """Unscramble a string in a bijective manner.
+
+     Args:
+         scrambled_text:
+             The scrambled string to unscramble.
+
+     Returns:
+         The unscrambled string.
+     """
+     rng = np.random.default_rng(seed=4242)
+     permutation = rng.permutation(x=len(scrambled_text))
+     inverse_permutation = np.argsort(permutation)
+     unscrambled = "".join(scrambled_text[i] for i in inverse_permutation)
+     return unscrambled
+
+
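Because both functions derive the identical permutation from the fixed seed 4242, `unscramble` exactly inverts `scramble`:

    from euroeval.utils import scramble, unscramble

    secret = scramble("hello world")
    assert unscramble(secret) == "hello world"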
+ @cache
+ def log_once(message: str, level: int = logging.INFO) -> None:
+     """Log a message once.
+
+     This is ensured by caching the input/output pairs of this function, using the
+     `functools.cache` decorator.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+     """
+     match level:
+         case logging.DEBUG:
+             logger.debug(message)
+         case logging.INFO:
+             logger.info(message)
+         case logging.WARNING:
+             logger.warning(message)
+         case logging.ERROR:
+             logger.error(message)
+         case logging.CRITICAL:
+             logger.critical(message)
+         case _:
+             raise ValueError(f"Invalid logging level: {level}")
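Since `functools.cache` memoises on the `(message, level)` pair, repeated identical calls hit the cache and never reach the logger, while new messages are logged normally (assuming the "euroeval" logger has a handler configured):

    import logging
    from euroeval.utils import log_once

    log_once("Starting evaluation")  # logged
    log_once("Starting evaluation")  # cached, not logged again
    log_once("Running low on VRAM", level=logging.WARNING)  # new message, logged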