EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/utils.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Utility functions to be used in other scripts."""
2
2
 
3
3
  import asyncio
4
+ import collections.abc as c
4
5
  import gc
5
6
  import importlib
6
7
  import importlib.metadata
@@ -8,34 +9,28 @@ import importlib.util
8
9
  import logging
9
10
  import os
10
11
  import random
12
+ import re
13
+ import socket
11
14
  import sys
12
15
  import typing as t
13
- import warnings
14
- from functools import cache
15
16
  from pathlib import Path
17
+ from types import ModuleType
16
18
 
17
- import litellm
19
+ import demjson3
20
+ import huggingface_hub as hf_hub
18
21
  import numpy as np
19
- import requests
20
22
  import torch
21
- from datasets.utils import disable_progress_bar
22
- from requests.exceptions import RequestException
23
- from transformers import logging as tf_logging
24
23
 
25
- from .exceptions import NaNValueInModelOutput
26
-
27
- if importlib.util.find_spec("ray") is not None:
28
- import ray
24
+ from .caching_utils import cache_arguments
25
+ from .constants import T
26
+ from .exceptions import InvalidBenchmark, InvalidModel, NaNValueInModelOutput
27
+ from .logging_utils import log, log_once
29
28
 
30
29
  if t.TYPE_CHECKING:
31
- from types import TracebackType
32
-
30
+ from .data_models import ModelIdComponents
33
31
  from .types import Predictions
34
32
 
35
33
 
36
- logger = logging.getLogger("euroeval")
37
-
38
-
39
34
  def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
40
35
  """Create cache directory for a model.
41
36
 
@@ -54,6 +49,72 @@ def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
54
49
  return str(cache_dir_path)
55
50
 
56
51
 
52
def resolve_model_path(download_dir: str) -> str:
    """Resolve the path to the directory containing the model config files and weights.

    The files are looked up in the Hugging Face Hub `cache_dir` layout, i.e., under
    `<download_dir>/models--<name>/snapshots/`, where `<name>` is the last component
    of `download_dir`. If the files are spread over several directories, a
    `model_files` directory containing symlinks to all of them is created next to the
    directory containing `config.json`, and that directory is returned instead.

    Args:
        download_dir:
            The download directory.

    Returns:
        The path to the model.

    Raises:
        InvalidModel:
            If the model path is not valid, or if required files are missing.
    """
    download_path = Path(download_dir)

    # Get the 'path safe' version of the model id, which is the last dir in the path
    model_id_path = download_path.name
    model_name = model_id_path.removeprefix("models--")

    # Hf hub `cache_dir` puts the files in models--`model_id_path`/snapshots
    snapshots_dir = download_path / f"models--{model_id_path}" / "snapshots"
    if not snapshots_dir.exists():
        raise InvalidModel(
            f"Attempted to load models from the {snapshots_dir} directory, "
            "but it does not exist."
        )

    # Get all files in the snapshots. Files inside a `model_files` directory are the
    # symlinks created by an earlier call to this function, so they are skipped to
    # avoid seeing every file twice
    found_files = [
        found_file
        for found_file in snapshots_dir.rglob("*")
        if found_file.is_file()
        and "model_files" not in found_file.relative_to(snapshots_dir).parts[:-1]
    ]
    if not found_files:
        raise InvalidModel(f"No model files found at {snapshots_dir}")

    # Make sure that no file name occurs more than once, as we would then not know
    # which version to use (and could not symlink both of them below). Comparing the
    # paths themselves would never find duplicates, as `rglob` yields unique paths
    file_names = [found_file.name for found_file in found_files]
    if len(file_names) != len(set(file_names)):
        raise InvalidModel(
            f"Found multiple model config files for {model_name} at {snapshots_dir}"
        )

    # Check that found_files contains at least a 'config.json'
    config_file = next(
        (found_file for found_file in found_files if found_file.name == "config.json"),
        None,
    )
    if config_file is None:
        raise InvalidModel(
            f"Missing required file 'config.json' for {model_name} at {snapshots_dir}"
        )
    model_path = config_file.parent

    # As a precaution we also check that all of the files are in the same directory
    # if not we create a new dir with symlinks to all of the files from all snapshots
    # this is especially useful for vllm where we can only specify one folder and e.g.,
    # the safetensors version of the weights was added in an unmerged PR
    if any(found_file.parent != found_files[0].parent for found_file in found_files):
        new_model_path = model_path.parent / "model_files"
        new_model_path.mkdir(exist_ok=True)
        for found_file in found_files:
            link_path = new_model_path / found_file.name
            # Replace links left over from an earlier call, so that repeated calls do
            # not fail with `FileExistsError` and the links reflect the current files
            if link_path.is_symlink():
                link_path.unlink()
            # Use an absolute target: a relative one would be resolved relative to the
            # link's own directory, giving a broken link
            link_path.symlink_to(found_file.absolute())
        model_path = new_model_path

    return str(model_path)
116
+
117
+
57
118
  def clear_memory() -> None:
58
119
  """Clears the memory of unused items."""
59
120
  for gc_generation in range(3):
@@ -84,67 +145,9 @@ def enforce_reproducibility(seed: int = 4242) -> np.random.Generator:
84
145
  return rng
85
146
 
86
147
 
87
- def block_terminal_output() -> None:
88
- """Blocks libraries from writing output to the terminal.
89
-
90
- This filters warnings from some libraries, sets the logging level to ERROR for some
91
- libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
92
- disables most of the logging from the `transformers` library.
93
- """
94
- # Ignore miscellaneous warnings
95
- warnings.filterwarnings("ignore", category=UserWarning)
96
- warnings.filterwarnings("ignore", category=FutureWarning)
97
- warnings.filterwarnings(
98
- "ignore",
99
- module="torch.nn.parallel*",
100
- message="Was asked to gather along dimension 0, but all input tensors were "
101
- "scalars; will instead unsqueeze and return a vector.",
102
- )
103
- warnings.filterwarnings("ignore", module="seqeval*")
104
-
105
- # Up the logging level, to disable outputs
106
- logging.getLogger("filelock").setLevel(logging.CRITICAL)
107
- logging.getLogger("absl").setLevel(logging.CRITICAL)
108
- logging.getLogger("datasets").setLevel(logging.CRITICAL)
109
- logging.getLogger("openai").setLevel(logging.CRITICAL)
110
- logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.CRITICAL)
111
- logging.getLogger("torch.nn.parallel.distributed").setLevel(logging.CRITICAL)
112
- logging.getLogger("vllm").setLevel(logging.CRITICAL)
113
- logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
114
- logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
115
- logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
116
- logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
117
- logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
118
- logging.getLogger("httpx").setLevel(logging.CRITICAL)
119
- logging.getLogger("ray._private.worker").setLevel(logging.CRITICAL)
120
- logging.getLogger("ray._private.services").setLevel(logging.CRITICAL)
121
- logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
122
- logging.getLogger("accelerate").setLevel(logging.CRITICAL)
123
- logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
124
- logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
125
- logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
126
- logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
127
-
128
- # This suppresses vLLM logging
129
- os.environ["LOG_LEVEL"] = "CRITICAL"
130
- os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
131
-
132
- if importlib.util.find_spec("ray") is not None:
133
- ray._private.worker._worker_logs_enabled = False
134
-
135
- # Disable the tokeniser progress bars
136
- disable_progress_bar()
137
-
138
- # Disable most of the `transformers` logging
139
- tf_logging._default_log_level = logging.CRITICAL
140
- tf_logging.set_verbosity(logging.CRITICAL)
141
- logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
142
-
143
- # Disable logging from `litellm`
144
- litellm.suppress_debug_info = True
145
-
146
-
147
- def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type | None:
148
+ def get_class_by_name(
149
+ class_name: str | c.Sequence[str], module_name: str
150
+ ) -> t.Type | None:
148
151
  """Get a class by its name.
149
152
 
150
153
  Args:
@@ -173,9 +176,10 @@ def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type |
173
176
 
174
177
  if error_messages:
175
178
  errors = "\n- " + "\n- ".join(error_messages)
176
- logger.debug(
179
+ log(
177
180
  f"Could not find the class with the name(s) {', '.join(class_name)}. The "
178
- f"following error messages were raised: {errors}"
181
+ f"following error messages were raised: {errors}",
182
+ level=logging.DEBUG,
179
183
  )
180
184
 
181
185
  # If the class could not be found, return None
@@ -197,40 +201,27 @@ def get_min_cuda_compute_capability() -> float | None:
197
201
  return float(f"{major}.{minor}")
198
202
 
199
203
 
204
@cache_arguments(disable_condition=lambda: hasattr(sys, "_called_from_test"))
def internet_connection_available() -> bool:
    """Checks if internet connection is available by connecting to 1.1.1.1 on port 80.

    The decorator presumably caches the result, except when `sys._called_from_test`
    is set (NOTE(review): confirm against `cache_arguments`).

    Returns:
        Whether or not internet connection is available.
    """
    # Without an explicit timeout, `create_connection` falls back to the global socket
    # default (blocking), so on a machine without network access the call can hang
    # until the OS-level connect timeout expires
    timeout_seconds = 5
    try:
        with socket.create_connection(("1.1.1.1", 80), timeout=timeout_seconds):
            return True
    except OSError:
        # Includes connection refusals, unreachable networks and timeouts
        return False
    except Exception as e:
        # `pytest-socket` blocks network access in tests by raising its own errors,
        # which are not `OSError`s; treat those as "no internet" and re-raise others
        pytest_socket_errors = ["SocketConnectBlockedError", "SocketBlockedError"]
        if type(e).__name__ not in pytest_socket_errors:
            raise
        return False
234
225
 
235
226
 
236
227
  def raise_if_model_output_contains_nan_values(model_output: "Predictions") -> None:
@@ -288,34 +279,6 @@ def unscramble(scrambled_text: str) -> str:
288
279
  return unscrambled
289
280
 
290
281
 
291
- @cache
292
- def log_once(message: str, level: int = logging.INFO) -> None:
293
- """Log a message once.
294
-
295
- This is ensured by caching the input/output pairs of this function, using the
296
- `functools.cache` decorator.
297
-
298
- Args:
299
- message:
300
- The message to log.
301
- level:
302
- The logging level. Defaults to logging.INFO.
303
- """
304
- match level:
305
- case logging.DEBUG:
306
- logger.debug(message)
307
- case logging.INFO:
308
- logger.info(message)
309
- case logging.WARNING:
310
- logger.warning(message)
311
- case logging.ERROR:
312
- logger.error(message)
313
- case logging.CRITICAL:
314
- logger.critical(message)
315
- case _:
316
- raise ValueError(f"Invalid logging level: {level}")
317
-
318
-
319
282
  def get_package_version(package_name: str) -> str | None:
320
283
  """Get the version of a package.
321
284
 
@@ -332,9 +295,6 @@ def get_package_version(package_name: str) -> str | None:
332
295
  return None
333
296
 
334
297
 
335
- T = t.TypeVar("T", bound=object)
336
-
337
-
338
298
  def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
339
299
  """Run a coroutine, ensuring that the event loop is always closed when we're done.
340
300
 
@@ -348,7 +308,8 @@ def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
348
308
  loop = asyncio.new_event_loop()
349
309
  try:
350
310
  asyncio.set_event_loop(loop)
351
- return loop.run_until_complete(coroutine)
311
+ response = loop.run_until_complete(coroutine)
312
+ return response
352
313
  finally:
353
314
  loop.close()
354
315
  asyncio.set_event_loop(None)
@@ -373,3 +334,237 @@ async def add_semaphore_and_catch_exception(
373
334
  return await coroutine
374
335
  except Exception as exc:
375
336
  return exc
337
+
338
+
339
def extract_json_dict_from_string(s: str) -> dict | None:
    """Extract a JSON dictionary from a string.

    The first `{...}` span without nested braces is decoded with `demjson3`. The
    result is returned only if it is a dictionary whose keys are all strings;
    otherwise the reason is logged at DEBUG level and None is returned.

    Args:
        s:
            The string to extract the JSON dictionary from.

    Returns:
        The extracted JSON dictionary, or None if no JSON dictionary could be found.
    """
    found = re.search(pattern=r"\{[^{}]*?\}", string=s, flags=re.DOTALL)
    if found is None:
        log(
            "The model output does not contain any JSON dictionary, so cannot parse "
            f"it. Skipping. Here is the output: {s!r}",
            level=logging.DEBUG,
        )
        return None

    candidate = found.group()
    try:
        decoded = demjson3.decode(txt=candidate)
    except demjson3.JSONDecodeError:
        log(
            "The model output is not valid JSON, so cannot parse it. Skipping. "
            f"Here is the output: {candidate!r}",
            level=logging.DEBUG,
        )
        return None

    # Only accept dictionaries, and only if every key is a string
    if isinstance(decoded, dict):
        if any(not isinstance(key, str) for key in decoded):
            reason = (
                "The model output is not a JSON dictionary with string keys, "
                "so cannot parse it. Skipping. Here is the output: "
                f"{candidate!r}"
            )
        else:
            return decoded
    else:
        reason = (
            "The model output is not a JSON dictionary, so cannot parse "
            f"it. Skipping. Here is the output: {candidate!r}"
        )
    log(reason, level=logging.DEBUG)
    return None
383
+
384
+
385
@cache_arguments()
def get_hf_token(api_key: str | None) -> str | bool:
    """Get the Hugging Face token.

    The token is taken from `api_key` if given, otherwise from the
    `HUGGINGFACE_API_KEY` environment variable. If neither is set,
    `huggingface_hub.whoami` is called to see whether a local login exists.

    Args:
        api_key:
            The API key to use as the Hugging Face token. If None, we will try to
            extract it in other ways.

    Returns:
        The Hugging Face token, or True if no token is set but the user is logged in, or
        False if no token is set and the user is not logged in.
    """
    # An explicitly passed key always wins
    if api_key is not None:
        log_once(
            "Using the Hugging Face API key passed to the function.",
            level=logging.DEBUG,
        )
        return api_key

    env_token = os.getenv("HUGGINGFACE_API_KEY")
    if env_token is not None:
        log_once(
            "Using the Hugging Face API key from the environment variable "
            "`HUGGINGFACE_API_KEY`.",
            level=logging.DEBUG,
        )
        return env_token

    # Fall back to the locally stored login; only a missing local token is handled
    try:
        hf_hub.whoami()
    except hf_hub.errors.LocalTokenNotFoundError:
        log_once(
            "No Hugging Face API key was set and the user is not logged in to Hugging "
            "Face, so no token will be used.",
            level=logging.DEBUG,
        )
        return False

    log_once(
        "No Hugging Face API key was set, but the user is logged in to Hugging "
        "Face, so using the local token.",
        level=logging.DEBUG,
    )
    return True
426
+
427
+
428
def extract_multiple_choice_labels(
    prompt: str, candidate_labels: c.Sequence[str]
) -> c.Sequence[str]:
    """Extract multiple choice labels from a prompt.

    A candidate label counts as present if it occurs in the prompt (case-insensitively)
    followed by a dot and a space (e.g., 'a. '), and is not directly preceded by a
    word character.

    Args:
        prompt:
            The prompt to extract the labels from.
        candidate_labels:
            The candidate labels to look for in the prompt.

    Returns:
        The candidate labels found in the prompt, in the order of `candidate_labels`.

    Raises:
        InvalidBenchmark:
            If none of the candidate labels could be found in the prompt.
    """
    sample_candidate_labels: list[str] = list()
    for candidate_label in candidate_labels:
        # The label is escaped so that regex metacharacters in it (e.g., '1+' or
        # '(a)') are matched literally. `(?<!\w)` is used rather than `\b`: the two
        # agree for labels starting with a word character, but `\b` in front of a
        # label such as '(a)' would require the preceding character to be a word
        # character
        pattern = rf"(?<!\w){re.escape(candidate_label)}\. "
        if re.search(pattern=pattern, string=prompt, flags=re.IGNORECASE) is not None:
            sample_candidate_labels.append(candidate_label)
    if not sample_candidate_labels:
        raise InvalidBenchmark(
            "Could not extract any candidate labels from the prompt. Please ensure "
            "that the candidate labels are present in the prompt, each followed by a "
            "dot and a space (e.g., 'a. '). The candidate labels are: "
            f"{', '.join(candidate_labels)}. Here is the prompt: {prompt!r}"
        )
    return sample_candidate_labels
457
+
458
+
459
def split_model_id(model_id: str) -> "ModelIdComponents":
    """Split a model ID into its components.

    The model ID proper is everything before the first `@` or `#`. The revision is
    the text following an `@` (defaulting to "main"), and the param is the text
    following a `#` (defaulting to None).

    Args:
        model_id:
            The model ID to split.

    Returns:
        The split model ID.

    Raises:
        InvalidModel:
            If no model ID can be extracted, e.g., if `model_id` is empty or starts
            with `@` or `#`.
    """
    # Importing here to avoid circular imports
    from .data_models import ModelIdComponents

    def _first_group(pattern: str) -> str | None:
        """Return the first capture group of `pattern` in the full model ID, if any."""
        hit = re.search(pattern=pattern, string=model_id)
        return None if hit is None else hit.group(1)

    base = re.match(pattern=r"^[^@#]+", string=model_id)
    if base is None:
        raise InvalidModel(f"The model ID {model_id!r} is not valid.")

    revision = _first_group(r"@([^@#]+)")
    return ModelIdComponents(
        model_id=base.group(),
        revision="main" if revision is None else revision,
        param=_first_group(r"#([^@#]+)"),
    )
489
+
490
+
491
+ def load_custom_datasets_module() -> ModuleType | None:
492
+ """Load the custom datasets module if it exists.
493
+
494
+ Raises:
495
+ RuntimeError:
496
+ If the custom datasets module cannot be loaded.
497
+ """
498
+ custom_datasets_file = Path("custom_datasets.py")
499
+ if custom_datasets_file.exists():
500
+ spec = importlib.util.spec_from_file_location(
501
+ name="custom_datasets_module", location=str(custom_datasets_file.resolve())
502
+ )
503
+ if spec is None:
504
+ log_once(
505
+ "Could not load the spec for the custom datasets file from "
506
+ f"{custom_datasets_file.resolve()}.",
507
+ level=logging.ERROR,
508
+ )
509
+ return None
510
+ module = importlib.util.module_from_spec(spec=spec)
511
+ if spec.loader is None:
512
+ log_once(
513
+ "Could not load the module for the custom datasets file from "
514
+ f"{custom_datasets_file.resolve()}.",
515
+ level=logging.ERROR,
516
+ )
517
+ return None
518
+ spec.loader.exec_module(module)
519
+ return module
520
+ return None
521
+
522
+
523
+ class flash_attention_backend:
524
+ """Context manager to temporarily set the flash attention backend.
525
+
526
+ This sets the `VLLM_ATTENTION_BACKEND` environment variable to `FLASH_ATTN`
527
+ for the duration of the context manager, and restores the previous value afterwards.
528
+ """
529
+
530
+ def __init__(self, disabled: bool = False) -> None:
531
+ """Initialise the context manager.
532
+
533
+ Args:
534
+ disabled:
535
+ If True, this context manager does nothing.
536
+ """
537
+ self.disabled = (
538
+ True if disabled else os.environ["VLLM_ATTENTION_BACKEND"] != "FLASHINFER"
539
+ )
540
+ self.previous_value: str | None = None
541
+
542
+ def __enter__(self) -> None:
543
+ """Enter the context manager."""
544
+ if self.disabled:
545
+ return
546
+ self.previous_value = os.getenv("VLLM_ATTENTION_BACKEND")
547
+ os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
548
+
549
+ def __exit__(
550
+ self,
551
+ exc_type: t.Type[BaseException] | None,
552
+ exc_value: BaseException | None,
553
+ traceback: type[BaseException] | None,
554
+ ) -> None:
555
+ """Exit the context manager.
556
+
557
+ Args:
558
+ exc_type:
559
+ The type of the exception.
560
+ exc_value:
561
+ The value of the exception.
562
+ exc_tb:
563
+ The traceback of the exception.
564
+ """
565
+ if self.disabled:
566
+ return
567
+ if self.previous_value is None:
568
+ os.environ.pop("VLLM_ATTENTION_BACKEND", None)
569
+ else:
570
+ os.environ["VLLM_ATTENTION_BACKEND"] = self.previous_value