EuroEval euroeval-16.3.0-py3-none-any.whl → euroeval-16.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (78)
  1. euroeval/__init__.py +9 -2
  2. euroeval/benchmark_config_factory.py +51 -50
  3. euroeval/benchmark_modules/base.py +9 -21
  4. euroeval/benchmark_modules/fresh.py +2 -1
  5. euroeval/benchmark_modules/hf.py +101 -71
  6. euroeval/benchmark_modules/litellm.py +115 -53
  7. euroeval/benchmark_modules/vllm.py +107 -92
  8. euroeval/benchmarker.py +144 -121
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +86 -8
  12. euroeval/constants.py +9 -0
  13. euroeval/data_loading.py +80 -29
  14. euroeval/data_models.py +338 -330
  15. euroeval/dataset_configs/__init__.py +12 -3
  16. euroeval/dataset_configs/bulgarian.py +56 -0
  17. euroeval/dataset_configs/czech.py +75 -0
  18. euroeval/dataset_configs/danish.py +55 -93
  19. euroeval/dataset_configs/dutch.py +48 -87
  20. euroeval/dataset_configs/english.py +45 -77
  21. euroeval/dataset_configs/estonian.py +42 -34
  22. euroeval/dataset_configs/faroese.py +19 -60
  23. euroeval/dataset_configs/finnish.py +36 -69
  24. euroeval/dataset_configs/french.py +39 -75
  25. euroeval/dataset_configs/german.py +45 -82
  26. euroeval/dataset_configs/greek.py +64 -0
  27. euroeval/dataset_configs/icelandic.py +54 -91
  28. euroeval/dataset_configs/italian.py +42 -79
  29. euroeval/dataset_configs/latvian.py +28 -35
  30. euroeval/dataset_configs/lithuanian.py +28 -26
  31. euroeval/dataset_configs/norwegian.py +72 -115
  32. euroeval/dataset_configs/polish.py +33 -61
  33. euroeval/dataset_configs/portuguese.py +33 -66
  34. euroeval/dataset_configs/serbian.py +64 -0
  35. euroeval/dataset_configs/slovak.py +55 -0
  36. euroeval/dataset_configs/spanish.py +42 -77
  37. euroeval/dataset_configs/swedish.py +52 -90
  38. euroeval/dataset_configs/ukrainian.py +64 -0
  39. euroeval/exceptions.py +1 -1
  40. euroeval/finetuning.py +24 -17
  41. euroeval/generation.py +15 -14
  42. euroeval/generation_utils.py +8 -8
  43. euroeval/languages.py +395 -323
  44. euroeval/logging_utils.py +250 -0
  45. euroeval/metrics/base.py +0 -3
  46. euroeval/metrics/huggingface.py +21 -6
  47. euroeval/metrics/llm_as_a_judge.py +6 -4
  48. euroeval/metrics/pipeline.py +17 -9
  49. euroeval/metrics/speed.py +0 -3
  50. euroeval/model_cache.py +17 -19
  51. euroeval/model_config.py +4 -5
  52. euroeval/model_loading.py +3 -0
  53. euroeval/prompt_templates/__init__.py +2 -0
  54. euroeval/prompt_templates/classification.py +206 -0
  55. euroeval/prompt_templates/linguistic_acceptability.py +99 -42
  56. euroeval/prompt_templates/multiple_choice.py +102 -38
  57. euroeval/prompt_templates/named_entity_recognition.py +172 -51
  58. euroeval/prompt_templates/reading_comprehension.py +119 -42
  59. euroeval/prompt_templates/sentiment_classification.py +110 -40
  60. euroeval/prompt_templates/summarization.py +85 -40
  61. euroeval/prompt_templates/token_classification.py +279 -0
  62. euroeval/scores.py +11 -10
  63. euroeval/speed_benchmark.py +5 -6
  64. euroeval/task_group_utils/multiple_choice_classification.py +2 -4
  65. euroeval/task_group_utils/question_answering.py +24 -16
  66. euroeval/task_group_utils/sequence_classification.py +48 -35
  67. euroeval/task_group_utils/text_to_text.py +19 -9
  68. euroeval/task_group_utils/token_classification.py +21 -17
  69. euroeval/tasks.py +44 -1
  70. euroeval/tokenisation_utils.py +33 -22
  71. euroeval/types.py +10 -9
  72. euroeval/utils.py +35 -149
  73. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/METADATA +196 -39
  74. euroeval-16.5.0.dist-info/RECORD +81 -0
  75. euroeval-16.3.0.dist-info/RECORD +0 -71
  76. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/WHEEL +0 -0
  77. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/entry_points.txt +0 -0
  78. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/licenses/LICENSE +0 -0
euroeval/tokenisation_utils.py CHANGED
@@ -1,15 +1,16 @@
 """Utility functions related to tokenisation."""
 
+import collections.abc as c
 import logging
 import re
 import typing as t
 
 import torch
-from transformers import MistralCommonTokenizer
+from transformers.tokenization_mistral_common import MistralCommonTokenizer
 
 from .enums import GenerativeType
 from .exceptions import InvalidModel
-from .utils import log_once
+from .logging_utils import log, log_once
 
 if t.TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -18,9 +19,6 @@ if t.TYPE_CHECKING:
     from .data_models import DatasetConfig, ModelConfig
 
 
-logger = logging.getLogger("euroeval")
-
-
 def get_special_token_metadata(tokeniser: "PreTrainedTokenizerBase") -> dict:
     """Get the special token metadata for a tokeniser.
 
@@ -74,7 +72,7 @@ def get_special_token_metadata(tokeniser: "PreTrainedTokenizerBase") -> dict:
 
 
 def should_prompts_be_stripped(
-    labels_to_be_generated: list[str], tokeniser: "PreTrainedTokenizer"
+    labels_to_be_generated: c.Sequence[str], tokeniser: "PreTrainedTokenizer"
 ) -> bool:
     """Determine if we should strip the prompts for few-shot evaluation.
 
@@ -113,7 +111,7 @@ def should_prompts_be_stripped(
 
 
 def should_prefix_space_be_added_to_labels(
-    labels_to_be_generated: list[str], tokeniser: "PreTrainedTokenizer"
+    labels_to_be_generated: c.Sequence[str], tokeniser: "PreTrainedTokenizer"
 ) -> bool:
     """Determine if we should add a prefix space to the labels.
 
@@ -182,7 +180,7 @@ def get_bos_token(
             "The model does not have a beginning-of-sequence token. Please ensure that "
             "this has been set in the tokeniser's configuration. Using no BOS token."
             " This may lead to unexpected behavior in the model.",
-            level=logging.INFO,
+            level=logging.WARNING,
         )
         return None, None
 
@@ -223,14 +221,14 @@ def get_eos_token(
             "The model does not have an end-of-sequence token. Please ensure that this "
             "has been set in the tokeniser's configuration. Using no EOS token. This "
             "may lead to unexpected behavior in the model.",
-            level=logging.INFO,
+            level=logging.WARNING,
        )
        return None, None
 
    log_once(
        f"End-of-sequence token was not set, but detected it as {eos_token!r} with "
        f"ID {eos_token_id}.",
-        level=logging.DEBUG,
+        level=logging.WARNING,
    )
    return eos_token, eos_token_id
 
@@ -306,7 +304,7 @@ def get_pad_token(
             "Could not identify a padding token for the model. Please ensure that "
             "this has been set in the tokeniser's configuration. Using no padding "
             "token. This may lead to unexpected behavior in the model.",
-            level=logging.INFO,
+            level=logging.WARNING,
        )
        return None, None
 
@@ -320,7 +318,7 @@
 
 def get_end_of_chat_token_ids(
     tokeniser: "PreTrainedTokenizer", generative_type: GenerativeType | None
-) -> list[int] | None:
+) -> c.Sequence[int] | None:
     """Get the end token ID for chat models.
 
     This is only relevant for tokenisers with a chat template.
@@ -358,12 +356,16 @@
             x_token_index = idx
             break
     else:
-        logger.debug("Could not locate the end-of-chat token for the model.")
+        log(
+            "Could not locate the end-of-chat token for the model.", level=logging.DEBUG
+        )
         return None
 
     end_of_chat_tokens = token_ids[x_token_index + 1 :]
     if len(end_of_chat_tokens) == 0:
-        logger.debug("Could not locate the end-of-chat token for the model.")
+        log(
+            "Could not locate the end-of-chat token for the model.", level=logging.DEBUG
+        )
         return None
 
     log_once(
@@ -432,13 +434,19 @@ def get_first_label_token_mapping(
 
     # Tokenise some text containing each label, which we will use to extract the
     # first token of each label
-    all_tokens: list[list[str]]
+    all_tokens: c.Sequence[c.Sequence[str]]
     if not has_chat_template(tokeniser=tokeniser):
         add_prefix_space = should_prefix_space_be_added_to_labels(
             labels_to_be_generated=local_labels, tokeniser=tokeniser
         )
         all_tokens = [
-            tokeniser.tokenize(text=f" {label}" if add_prefix_space else label)
+            [
+                tokeniser.decode(token_id)
+                for token_id in tokeniser.encode(
+                    text=f" {label}" if add_prefix_space else label,
+                    add_special_tokens=False,
+                )
+            ]
             for label in local_labels
         ]
     else:
@@ -465,7 +473,7 @@
         all_tokens = [
             [
                 re.sub(
-                    pattern=r"^[^a-zæøåüöä0-9]+|[^a-zæøåüöä0-9]+$",
+                    pattern=r"^[^a-zæøåüöä0-9 ]+|[^a-zæøåüöä0-9 ]+$",
                     repl="",
                     string=token.lower(),
                 )
@@ -477,11 +485,13 @@
     # Extract the first token of each label
     first_tokens: list[str] = list()
     for token_list, label in zip(all_tokens, local_labels):
-        matching_tokens = [tok for tok in token_list if tok and label.startswith(tok)]
+        matching_tokens = [
+            tok for tok in token_list if tok and label.startswith(tok.strip())
+        ]
         if not matching_tokens:
             if log_metadata:
                 log_once(
-                    f"No matching token found in token_list for label '{label}', so "
+                    f"No matching token found in token_list for label {label!r}, so "
                     "we will not use logprobs with the model.",
                     level=logging.DEBUG,
                 )
@@ -506,7 +516,8 @@
         log_once(
             "We will not use logprobs with the model since the first tokens of the "
             "labels are not distinct. The first tokens for the labels "
-            f"{local_labels} are {first_tokens}"
+            f"{local_labels} are {first_tokens}",
+            level=logging.DEBUG,
         )
         return False
 
@@ -547,12 +558,12 @@ def has_chat_template(tokeniser: "PreTrainedTokenizer") -> bool:
 
 
 def apply_chat_template(
-    conversation: list[dict[str, str]],
+    conversation: c.Sequence[dict[str, str]],
     tokeniser: "PreTrainedTokenizer",
     tokenise: bool,
     add_generation_prompt: bool,
     **extra_kwargs,
-) -> str | list[int]:
+) -> str | c.Sequence[int]:
     """Apply the chat template to a prompt.
 
     Args:
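Context for the `get_first_label_token_mapping` hunk above: `tokeniser.tokenize` returns raw vocabulary pieces, which for byte-level BPE tokenisers carry markers such as `Ġ` that the plain-text labels never start with, while decoding each ID from `tokeniser.encode` yields ordinary text. A minimal sketch of the difference, assuming the `transformers` library and a GPT-2-style tokeniser ("gpt2" is a stand-in, not taken from this diff):

    # Compare raw tokeniser pieces with decoded token IDs (illustrative only).
    from transformers import AutoTokenizer

    tokeniser = AutoTokenizer.from_pretrained("gpt2")

    # Old approach: raw pieces keep byte-level markers, e.g. ['Ġpositive'],
    # which "positive".startswith(...) can never match.
    pieces = tokeniser.tokenize(" positive")

    # New approach: decoding each ID gives plain text, e.g. [' positive'],
    # which label.startswith(tok.strip()) matches directly.
    ids = tokeniser.encode(" positive", add_special_tokens=False)
    decoded = [tokeniser.decode(token_id) for token_id in ids]

    print(pieces, decoded)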
euroeval/types.py CHANGED
@@ -1,5 +1,6 @@
 """Types used throughout the project."""
 
+import collections.abc as c
 import typing as t
 
 from transformers.trainer_utils import EvalPrediction
@@ -10,9 +11,9 @@ if t.TYPE_CHECKING:
 
     from .data_models import BenchmarkConfig, GenerativeModelOutput
 
-ScoreDict: t.TypeAlias = dict[str, dict[str, float] | list[dict[str, float]]]
-Predictions: t.TypeAlias = "NDArray | list[str] | list[list[str]]"
-Labels: t.TypeAlias = "NDArray | list[str] | list[list[str]]"
+ScoreDict: t.TypeAlias = dict[str, dict[str, float] | c.Sequence[dict[str, float]]]
+Predictions: t.TypeAlias = "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]"
+Labels: t.TypeAlias = "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]"
 
 
 class ComputeMetricsFunction(t.Protocol):
@@ -22,8 +23,8 @@ class ComputeMetricsFunction(t.Protocol):
         self,
         model_outputs_and_labels: EvalPrediction
         | tuple[
-            "NDArray | list[str] | list[list[str]]",
-            "NDArray | list[str] | list[list[str]]",
+            "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]",
+            "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]",
         ],
         dataset: "Dataset",
         benchmark_config: "BenchmarkConfig",
@@ -48,7 +49,7 @@ class ExtractLabelsFunction(t.Protocol):
 
     def __call__(
         self, input_batch: dict[str, list], model_output: "GenerativeModelOutput"
-    ) -> list[str]:
+    ) -> c.Sequence[str]:
         """Extract the labels from the generated output.
 
         Args:
@@ -63,7 +64,7 @@ class ExtractLabelsFunction(t.Protocol):
         ...
 
 
-def is_list_of_int(x: object) -> t.TypeGuard[list[int]]:
+def is_list_of_int(x: object) -> t.TypeGuard[c.Sequence[int]]:
     """Check if an object is a list of integers.
 
     Args:
@@ -76,7 +77,7 @@ def is_list_of_int(x: object) -> t.TypeGuard[list[int]]:
     return isinstance(x, list) and all(isinstance(i, int) for i in x)
 
 
-def is_list_of_list_of_int(x: object) -> t.TypeGuard[list[list[int]]]:
+def is_list_of_list_of_int(x: object) -> t.TypeGuard[c.Sequence[c.Sequence[int]]]:
     """Check if an object is a list of list of integers.
 
     Args:
@@ -93,7 +94,7 @@ def is_list_of_list_of_int(x: object) -> t.TypeGuard[list[list[int]]]:
     )
 
 
-def is_list_of_str(x: object) -> t.TypeGuard[list[str]]:
+def is_list_of_str(x: object) -> t.TypeGuard[c.Sequence[str]]:
     """Check if an object is a list of integers.
 
     Args:
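The `t.TypeGuard` return types above let a static type checker narrow a plain `object` to the guarded sequence type in the branch where the check returns True. A small self-contained usage sketch (re-implementing one helper for illustration; the wrapper function `total` is hypothetical, not from the diff):

    import collections.abc as c
    import typing as t

    def is_list_of_int(x: object) -> t.TypeGuard[c.Sequence[int]]:
        # True only when x is a list whose elements are all ints.
        return isinstance(x, list) and all(isinstance(i, int) for i in x)

    def total(x: object) -> int:
        if is_list_of_int(x):
            # Here the checker treats x as c.Sequence[int].
            return sum(x)
        raise TypeError("expected a list of ints")

    print(total([1, 2, 3]))  # 6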
euroeval/utils.py CHANGED
@@ -1,6 +1,7 @@
 """Utility functions to be used in other scripts."""
 
 import asyncio
+import collections.abc as c
 import gc
 import importlib
 import importlib.metadata
@@ -11,30 +12,23 @@ import re
 import socket
 import sys
 import typing as t
-import warnings
-from functools import cache
 from pathlib import Path
 
 import demjson3
 import huggingface_hub as hf_hub
-import litellm
 import numpy as np
 import torch
-from datasets.utils import disable_progress_bar
-from transformers import logging as tf_logging
 
+from .caching_utils import cache_arguments
+from .constants import T
 from .exceptions import InvalidBenchmark, InvalidModel, NaNValueInModelOutput
+from .logging_utils import log, log_once
 
 if t.TYPE_CHECKING:
-    from types import TracebackType
-
     from .data_models import ModelIdComponents
     from .types import Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
     """Create cache directory for a model.
 
@@ -149,69 +143,9 @@ def enforce_reproducibility(seed: int = 4242) -> np.random.Generator:
     return rng
 
 
-def block_terminal_output() -> None:
-    """Blocks libraries from writing output to the terminal.
-
-    This filters warnings from some libraries, sets the logging level to ERROR for some
-    libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
-    disables most of the logging from the `transformers` library.
-    """
-    if os.getenv("FULL_LOG") == "1":
-        return
-
-    # Ignore miscellaneous warnings
-    warnings.filterwarnings("ignore", category=UserWarning)
-    warnings.filterwarnings("ignore", category=FutureWarning)
-    logging.getLogger("absl").setLevel(logging.CRITICAL)
-
-    # Disable matplotlib logging
-    logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
-
-    # Disable PyTorch logging
-    logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
-    warnings.filterwarnings(action="ignore", module="torch*")
-    os.environ["TORCH_LOGS"] = "-all"
-
-    # Disable huggingface_hub logging
-    logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
-
-    # Disable LiteLLM logging
-    logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
-    logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
-    logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
-    logging.getLogger("openai").setLevel(logging.CRITICAL)
-    logging.getLogger("httpx").setLevel(logging.CRITICAL)
-    litellm.suppress_debug_info = True
-
-    # Disable vLLM logging
-    logging.getLogger("vllm").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
-    logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
-        logging.CRITICAL
-    )
-    os.environ["LOG_LEVEL"] = "CRITICAL"
-    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
-
-    # Disable datasets logging
-    logging.getLogger("datasets").setLevel(logging.CRITICAL)
-    logging.getLogger("filelock").setLevel(logging.CRITICAL)
-    disable_progress_bar()
-
-    # Disable evaluate logging
-    warnings.filterwarnings("ignore", module="seqeval*")
-
-    # Disable most of the `transformers` logging
-    tf_logging._default_log_level = logging.CRITICAL
-    tf_logging.set_verbosity(logging.CRITICAL)
-    logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
-    logging.getLogger("accelerate").setLevel(logging.CRITICAL)
-
-
-def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type | None:
+def get_class_by_name(
+    class_name: str | c.Sequence[str], module_name: str
+) -> t.Type | None:
     """Get a class by its name.
 
     Args:
@@ -240,9 +174,10 @@ def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type |
 
     if error_messages:
         errors = "\n- " + "\n- ".join(error_messages)
-        logger.debug(
+        log(
             f"Could not find the class with the name(s) {', '.join(class_name)}. The "
-            f"following error messages were raised: {errors}"
+            f"following error messages were raised: {errors}",
+            level=logging.DEBUG,
         )
 
     # If the class could not be found, return None
@@ -264,49 +199,27 @@ def get_min_cuda_compute_capability() -> float | None:
     return float(f"{major}.{minor}")
 
 
-@cache
+@cache_arguments(disable_condition=lambda: hasattr(sys, "_called_from_test"))
 def internet_connection_available() -> bool:
     """Checks if internet connection is available by pinging google.com.
 
     Returns:
         Whether or not internet connection is available.
     """
+    internet_available: bool = False
+
     try:
         s = socket.create_connection(("1.1.1.1", 80))
         s.close()
-        return True
-
-    # We want to only catch exceptions related to socket connections, but as we cannot
-    # import these here as they're developer dependencies, we check the exception name
-    # instead. If the exception is not related to socket connections, we reraise it.
+        internet_available = True
+    except OSError:
+        pass
     except Exception as e:
         pytest_socket_errors = ["SocketConnectBlockedError", "SocketBlockedError"]
-        if type(e).__name__ in pytest_socket_errors or isinstance(e, OSError):
-            return False
-        raise e
-
-
-class HiddenPrints:
-    """Context manager which removes all terminal output."""
-
-    def __enter__(self) -> None:
-        """Enter the context manager."""
-        self._original_stdout = sys.stdout
-        self._original_stderr = sys.stderr
-        sys.stdout = open(os.devnull, "w")
-        sys.stderr = open(os.devnull, "w")
-
-    def __exit__(
-        self,
-        exc_type: t.Type[BaseException],
-        exc_val: BaseException,
-        exc_tb: "TracebackType",
-    ) -> None:
-        """Exit the context manager."""
-        sys.stdout.close()
-        sys.stderr.close()
-        sys.stdout = self._original_stdout
-        sys.stderr = self._original_stderr
+        if type(e).__name__ not in pytest_socket_errors:
+            raise e
+
+    return internet_available
 
 
 def raise_if_model_output_contains_nan_values(model_output: "Predictions") -> None:
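The `cache_arguments` decorator used in the hunk above comes from the new `euroeval/caching_utils.py` (+79 lines), which is not included in this diff, so its internals are unknown here. A hypothetical sketch, inferred only from the call sites `@cache_arguments()` and `@cache_arguments(disable_condition=...)`, of what such a decorator could look like:

    # Hypothetical sketch; the real euroeval/caching_utils.py is not shown in
    # this diff, and every internal detail below is an assumption.
    import collections.abc as c
    import functools
    import typing as t

    T = t.TypeVar("T")

    def cache_arguments(
        disable_condition: c.Callable[[], bool] | None = None,
    ) -> c.Callable[[c.Callable[..., T]], c.Callable[..., T]]:
        def decorator(func: c.Callable[..., T]) -> c.Callable[..., T]:
            cached_func = functools.cache(func)

            @functools.wraps(func)
            def wrapper(*args: t.Any, **kwargs: t.Any) -> T:
                # Bypass the cache when the condition holds, e.g. under pytest.
                if disable_condition is not None and disable_condition():
                    return func(*args, **kwargs)
                return cached_func(*args, **kwargs)

            return wrapper

        return decorator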
@@ -364,34 +277,6 @@ def unscramble(scrambled_text: str) -> str:
     return unscrambled
 
 
-@cache
-def log_once(message: str, level: int = logging.INFO) -> None:
-    """Log a message once.
-
-    This is ensured by caching the input/output pairs of this function, using the
-    `functools.cache` decorator.
-
-    Args:
-        message:
-            The message to log.
-        level:
-            The logging level. Defaults to logging.INFO.
-    """
-    match level:
-        case logging.DEBUG:
-            logger.debug(message)
-        case logging.INFO:
-            logger.info(message)
-        case logging.WARNING:
-            logger.warning(message)
-        case logging.ERROR:
-            logger.error(message)
-        case logging.CRITICAL:
-            logger.critical(message)
-        case _:
-            raise ValueError(f"Invalid logging level: {level}")
-
-
 def get_package_version(package_name: str) -> str | None:
     """Get the version of a package.
 
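The `log_once` removed above now lives in the new `euroeval/logging_utils.py` (+250 lines), which this diff does not show. Its de-duplication trick is plain memoisation: caching on the `(message, level)` pair turns repeated calls into no-ops. A minimal recreation of the removed implementation's idea, using `logging.Logger.log` in place of the original `match` statement:

    import functools
    import logging

    logger = logging.getLogger("euroeval")

    @functools.cache
    def log_once(message: str, level: int = logging.INFO) -> None:
        # functools.cache memoises on (message, level), so a repeated call
        # with the same arguments never reaches the logger again.
        logger.log(level, message)

    logging.basicConfig(level=logging.INFO)
    log_once("Loading model...")  # emitted
    log_once("Loading model...")  # silently skipped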
@@ -408,9 +293,6 @@ def get_package_version(package_name: str) -> str | None:
     return None
 
 
-T = t.TypeVar("T", bound=object)
-
-
 def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
     """Run a coroutine, ensuring that the event loop is always closed when we're done.
 
@@ -464,37 +346,41 @@ def extract_json_dict_from_string(s: str) -> dict | None:
     """
     json_regex = r"\{[^{}]*?\}"
     if (json_match := re.search(pattern=json_regex, string=s, flags=re.DOTALL)) is None:
-        logger.debug(
+        log(
             "The model output does not contain any JSON dictionary, so cannot parse "
-            f"it. Skipping. Here is the output: {s!r}"
+            f"it. Skipping. Here is the output: {s!r}",
+            level=logging.DEBUG,
         )
         return None
     json_string = json_match.group()
     try:
         json_output = demjson3.decode(txt=json_string)
     except demjson3.JSONDecodeError:
-        logger.debug(
+        log(
             "The model output is not valid JSON, so cannot parse it. Skipping. "
-            f"Here is the output: {json_string!r}"
+            f"Here is the output: {json_string!r}",
+            level=logging.DEBUG,
         )
         return None
     if not isinstance(json_output, dict):
-        logger.debug(
+        log(
             "The model output is not a JSON dictionary, so cannot parse "
-            f"it. Skipping. Here is the output: {json_string!r}"
+            f"it. Skipping. Here is the output: {json_string!r}",
+            level=logging.DEBUG,
         )
         return None
     elif not all(isinstance(key, str) for key in json_output.keys()):
-        logger.debug(
+        log(
             "The model output is not a JSON dictionary with string keys, "
             "so cannot parse it. Skipping. Here is the output: "
-            f"{json_string!r}"
+            f"{json_string!r}",
+            level=logging.DEBUG,
         )
         return None
     return json_output
 
 
-@cache
+@cache_arguments()
 def get_hf_token(api_key: str | None) -> str | bool:
     """Get the Hugging Face token.
 
@@ -538,8 +424,8 @@
 
 
 def extract_multiple_choice_labels(
-    prompt: str, candidate_labels: list[str]
-) -> list[str]:
+    prompt: str, candidate_labels: c.Sequence[str]
+) -> c.Sequence[str]:
     """Extract multiple choice labels from a prompt.
 
     Args:
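A theme running through this release is the swap from `list[...]` to `collections.abc.Sequence[...]` in signatures, as in the hunk above. `Sequence` is read-only and covariant, so annotated functions accept lists, tuples, and other sequence types alike. A short sketch of the practical effect (illustrative names, not from the diff):

    import collections.abc as c

    def first_upper(labels: c.Sequence[str]) -> str:
        # Only reads from labels, so any sequence type is acceptable.
        return labels[0].upper()

    print(first_upper(["positive", "negative"]))  # a list works
    print(first_upper(("positive", "negative")))  # so does a tuple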