EuroEval 16.3.0-py3-none-any.whl → 16.5.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of EuroEval might be problematic.

Files changed (78)
  1. euroeval/__init__.py +9 -2
  2. euroeval/benchmark_config_factory.py +51 -50
  3. euroeval/benchmark_modules/base.py +9 -21
  4. euroeval/benchmark_modules/fresh.py +2 -1
  5. euroeval/benchmark_modules/hf.py +101 -71
  6. euroeval/benchmark_modules/litellm.py +115 -53
  7. euroeval/benchmark_modules/vllm.py +107 -92
  8. euroeval/benchmarker.py +144 -121
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +86 -8
  12. euroeval/constants.py +9 -0
  13. euroeval/data_loading.py +80 -29
  14. euroeval/data_models.py +338 -330
  15. euroeval/dataset_configs/__init__.py +12 -3
  16. euroeval/dataset_configs/bulgarian.py +56 -0
  17. euroeval/dataset_configs/czech.py +75 -0
  18. euroeval/dataset_configs/danish.py +55 -93
  19. euroeval/dataset_configs/dutch.py +48 -87
  20. euroeval/dataset_configs/english.py +45 -77
  21. euroeval/dataset_configs/estonian.py +42 -34
  22. euroeval/dataset_configs/faroese.py +19 -60
  23. euroeval/dataset_configs/finnish.py +36 -69
  24. euroeval/dataset_configs/french.py +39 -75
  25. euroeval/dataset_configs/german.py +45 -82
  26. euroeval/dataset_configs/greek.py +64 -0
  27. euroeval/dataset_configs/icelandic.py +54 -91
  28. euroeval/dataset_configs/italian.py +42 -79
  29. euroeval/dataset_configs/latvian.py +28 -35
  30. euroeval/dataset_configs/lithuanian.py +28 -26
  31. euroeval/dataset_configs/norwegian.py +72 -115
  32. euroeval/dataset_configs/polish.py +33 -61
  33. euroeval/dataset_configs/portuguese.py +33 -66
  34. euroeval/dataset_configs/serbian.py +64 -0
  35. euroeval/dataset_configs/slovak.py +55 -0
  36. euroeval/dataset_configs/spanish.py +42 -77
  37. euroeval/dataset_configs/swedish.py +52 -90
  38. euroeval/dataset_configs/ukrainian.py +64 -0
  39. euroeval/exceptions.py +1 -1
  40. euroeval/finetuning.py +24 -17
  41. euroeval/generation.py +15 -14
  42. euroeval/generation_utils.py +8 -8
  43. euroeval/languages.py +395 -323
  44. euroeval/logging_utils.py +250 -0
  45. euroeval/metrics/base.py +0 -3
  46. euroeval/metrics/huggingface.py +21 -6
  47. euroeval/metrics/llm_as_a_judge.py +6 -4
  48. euroeval/metrics/pipeline.py +17 -9
  49. euroeval/metrics/speed.py +0 -3
  50. euroeval/model_cache.py +17 -19
  51. euroeval/model_config.py +4 -5
  52. euroeval/model_loading.py +3 -0
  53. euroeval/prompt_templates/__init__.py +2 -0
  54. euroeval/prompt_templates/classification.py +206 -0
  55. euroeval/prompt_templates/linguistic_acceptability.py +99 -42
  56. euroeval/prompt_templates/multiple_choice.py +102 -38
  57. euroeval/prompt_templates/named_entity_recognition.py +172 -51
  58. euroeval/prompt_templates/reading_comprehension.py +119 -42
  59. euroeval/prompt_templates/sentiment_classification.py +110 -40
  60. euroeval/prompt_templates/summarization.py +85 -40
  61. euroeval/prompt_templates/token_classification.py +279 -0
  62. euroeval/scores.py +11 -10
  63. euroeval/speed_benchmark.py +5 -6
  64. euroeval/task_group_utils/multiple_choice_classification.py +2 -4
  65. euroeval/task_group_utils/question_answering.py +24 -16
  66. euroeval/task_group_utils/sequence_classification.py +48 -35
  67. euroeval/task_group_utils/text_to_text.py +19 -9
  68. euroeval/task_group_utils/token_classification.py +21 -17
  69. euroeval/tasks.py +44 -1
  70. euroeval/tokenisation_utils.py +33 -22
  71. euroeval/types.py +10 -9
  72. euroeval/utils.py +35 -149
  73. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/METADATA +196 -39
  74. euroeval-16.5.0.dist-info/RECORD +81 -0
  75. euroeval-16.3.0.dist-info/RECORD +0 -71
  76. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/WHEEL +0 -0
  77. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/entry_points.txt +0 -0
  78. {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/licenses/LICENSE +0 -0
euroeval/logging_utils.py ADDED
@@ -0,0 +1,250 @@
+ """Utility functions related to logging."""
+
+ import datetime as dt
+ import logging
+ import os
+ import sys
+ import warnings
+ from io import TextIOWrapper
+
+ import litellm
+ from datasets.utils import disable_progress_bars as disable_datasets_progress_bars
+ from evaluate import disable_progress_bar as disable_evaluate_progress_bar
+ from huggingface_hub.utils.tqdm import (
+     disable_progress_bars as disable_hf_hub_progress_bars,
+ )
+ from termcolor import colored
+ from tqdm.auto import tqdm
+ from transformers import logging as tf_logging
+
+ from .caching_utils import cache_arguments
+
+ logger = logging.getLogger("euroeval")
+
+
+ def get_pbar(*tqdm_args, **tqdm_kwargs) -> tqdm:
+     """Get a progress bar for vLLM with custom hard-coded arguments.
+
+     Args:
+         *tqdm_args:
+             Positional arguments to pass to tqdm.
+         **tqdm_kwargs:
+             Additional keyword arguments to pass to tqdm.
+
+     Returns:
+         A tqdm progress bar.
+     """
+     tqdm_kwargs = dict(colour="yellow", ascii="—▰", leave=False) | tqdm_kwargs
+     tqdm_kwargs["desc"] = colored(
+         text=tqdm_kwargs.get("desc", "Processing"), color="light_yellow"
+     )
+     return tqdm(*tqdm_args, **tqdm_kwargs)
+
+
+ def log(message: str, level: int, colour: str | None = None) -> None:
+     """Log a message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+         colour:
+             The colour to use for the message. If None, a default colour will be used
+             based on the logging level.
+
+     Raises:
+         ValueError:
+             If the logging level is invalid.
+     """
+     match level:
+         case logging.DEBUG:
+             message = colored(
+                 text=(
+                     "[DEBUG] "
+                     + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                     + f" · {message}"
+                 ),
+                 color=colour or "light_blue",
+             )
+             logger.debug(message)
+         case logging.INFO:
+             if colour is not None:
+                 message = colored(text=message, color=colour)
+             logger.info(message)
+         case logging.WARNING:
+             message = colored(text=message, color=colour or "light_red")
+             logger.warning(message)
+         case logging.ERROR:
+             message = colored(text=message, color=colour or "red")
+             logger.error(message)
+         case logging.CRITICAL:
+             message = colored(text=message, color=colour or "red")
+             logger.critical(message)
+         case _:
+             raise ValueError(f"Invalid logging level: {level}")
+
+
+ @cache_arguments("message")
+ def log_once(message: str, level: int = logging.INFO, prefix: str = "") -> None:
+     """Log a message once.
+
+     This is ensured by caching the "message" argument and only logging it the first time
+     this function is called with that message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+         prefix:
+             A prefix to add to the message, which is not considered when determining if
+             the message has been logged before.
+     """
+     log(message=prefix + message, level=level)
+
+
+ def block_terminal_output() -> None:
+     """Blocks libraries from writing output to the terminal.
+
+     This filters warnings from some libraries, sets the logging level to ERROR for some
+     libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
+     disables most of the logging from the `transformers` library.
+     """
+     if os.getenv("FULL_LOG") == "1":
+         return
+
+     # Ignore miscellaneous warnings
+     warnings.filterwarnings("ignore", category=UserWarning)
+     warnings.filterwarnings("ignore", category=FutureWarning)
+     logging.getLogger("absl").setLevel(logging.CRITICAL)
+
+     # Disable matplotlib logging
+     logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
+
+     # Disable PyTorch logging
+     logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
+     warnings.filterwarnings(action="ignore", module="torch*")
+     os.environ["TORCH_LOGS"] = "-all"
+
+     # Disable huggingface_hub logging
+     logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
+     disable_hf_hub_progress_bars()
+
+     # Disable LiteLLM logging
+     logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
+     logging.getLogger("openai").setLevel(logging.CRITICAL)
+     logging.getLogger("httpx").setLevel(logging.CRITICAL)
+     litellm.suppress_debug_info = True
+
+     # Disable vLLM logging
+     logging.getLogger("vllm").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
+     logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
+         logging.CRITICAL
+     )
+     os.environ["LOG_LEVEL"] = "CRITICAL"
+     os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+     # Disable flashinfer logging
+     os.environ["FLASHINFER_LOGGING_LEVEL"] = "CRITICAL"
+
+     # Disable datasets logging
+     logging.getLogger("datasets").setLevel(logging.CRITICAL)
+     logging.getLogger("filelock").setLevel(logging.CRITICAL)
+     disable_datasets_progress_bars()
+
+     # Disable evaluate logging
+     warnings.filterwarnings("ignore", module="seqeval*")
+     disable_evaluate_progress_bar()
+
+     # Disable most of the `transformers` logging
+     tf_logging._default_log_level = logging.CRITICAL
+     tf_logging.set_verbosity(logging.CRITICAL)
+     logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
+     logging.getLogger("accelerate").setLevel(logging.CRITICAL)
+
+
+ class no_terminal_output:
+     """Context manager that suppresses all terminal output."""
+
+     def __init__(self, disable: bool = False) -> None:
+         """Initialise the context manager.
+
+         Args:
+             disable:
+                 If True, this context manager does nothing.
+         """
+         self.disable = disable
+         self.nothing_file: TextIOWrapper | None = None
+         self._cpp_stdout_file: int | None = None
+         self._cpp_stderr_file: int | None = None
+         try:
+             self._cpp_stdout_file = os.dup(sys.stdout.fileno())
+             self._cpp_stderr_file = os.dup(sys.stderr.fileno())
+         except OSError:
+             self._log_windows_warning()
+
+     def _log_windows_warning(self) -> None:
+         """Log a warning about Windows not supporting blocking terminal output."""
+         log_once(
+             "Your operating system (probably Windows) does not support blocking "
+             "terminal output, so expect more messy output - sorry!",
+             level=logging.WARNING,
+         )
+
+     def __enter__(self) -> None:
+         """Suppress all terminal output."""
+         if not self.disable:
+             self.nothing_file = open(os.devnull, "w")
+             try:
+                 os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stdout.fileno())
+                 os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stderr.fileno())
+             except OSError:
+                 self._log_windows_warning()
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: type[BaseException] | None,
+     ) -> None:
+         """Re-enable terminal output."""
+         if not self.disable:
+             if self.nothing_file is not None:
+                 self.nothing_file.close()
+             try:
+                 if self._cpp_stdout_file is not None:
+                     os.dup2(fd=self._cpp_stdout_file, fd2=sys.stdout.fileno())
+                 if self._cpp_stderr_file is not None:
+                     os.dup2(fd=self._cpp_stderr_file, fd2=sys.stderr.fileno())
+             except OSError:
+                 self._log_windows_warning()
+
+
+ def adjust_logging_level(verbose: bool, ignore_testing: bool = False) -> int:
+     """Adjust the logging level based on verbosity.
+
+     Args:
+         verbose:
+             Whether to output additional output.
+         ignore_testing:
+             Whether to ignore the testing flag.
+
+     Returns:
+         The logging level that was set.
+     """
+     if hasattr(sys, "_called_from_test") and not ignore_testing:
+         logging_level = logging.CRITICAL
+     elif verbose:
+         logging_level = logging.DEBUG
+     else:
+         logging_level = logging.INFO
+     logger.setLevel(logging_level)
+     return logging_level
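For reference, a minimal usage sketch of the logging helpers added above, based only on the code shown in this diff (whether these helpers are intended as public API is an assumption):

import logging

from euroeval.logging_utils import (
    adjust_logging_level,
    get_pbar,
    log_once,
    no_terminal_output,
)

# Set the shared "euroeval" logger to DEBUG (or CRITICAL when sys._called_from_test is set).
adjust_logging_level(verbose=True)

# Per the docstring above, the second call is skipped because the message argument is cached.
log_once("Starting evaluation...", level=logging.INFO)
log_once("Starting evaluation...", level=logging.INFO)

# Redirect stdout/stderr to os.devnull for a noisy block; pass disable=True to keep output.
with no_terminal_output():
    print("This never reaches the terminal.")

# Yellow, auto-clearing progress bar in the package's default styling.
for _ in get_pbar(range(3), desc="Demo"):
    pass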
euroeval/metrics/base.py CHANGED
@@ -2,7 +2,6 @@
 
  import abc
  import collections.abc as c
- import logging
  import typing as t
 
  if t.TYPE_CHECKING:
@@ -10,8 +9,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class Metric(abc.ABC):
      """Abstract base class for all metrics."""
euroeval/metrics/huggingface.py CHANGED
@@ -1,7 +1,6 @@
  """All the Hugging Face metrics used in EuroEval."""
 
  import collections.abc as c
- import logging
  import typing as t
  from pathlib import Path
 
@@ -9,7 +8,7 @@ import evaluate
  import numpy as np
  from datasets import DownloadConfig
 
- from ..utils import HiddenPrints
+ from ..logging_utils import no_terminal_output
  from .base import Metric
 
  if t.TYPE_CHECKING:
@@ -18,8 +17,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class HuggingFaceMetric(Metric):
      """A metric which is implemented in the `evaluate` package.
@@ -124,9 +121,12 @@ class HuggingFaceMetric(Metric):
          if self.metric is None:
              self.download(cache_dir=benchmark_config.cache_dir)
 
-         assert self.metric is not None
+         assert self.metric is not None, (
+             "Metric has not been downloaded. Please call download() before using the "
+             "__call__ method."
+         )
 
-         with HiddenPrints():
+         with no_terminal_output(disable=benchmark_config.verbose):
              results = self.metric.compute(
                  predictions=predictions, references=references, **self.compute_kwargs
              )
@@ -143,8 +143,23 @@
          if isinstance(score, np.floating):
              score = float(score)
 
+         self.close()
          return score
 
+     def close(self) -> None:
+         """Close any resources held by the metric."""
+         if self.metric is not None:
+             if self.metric.filelock is not None:
+                 self.metric.filelock.release(force=True)
+             if self.metric.writer is not None:
+                 self.metric.writer.finalize(close_stream=True)
+
+     def __del__(self) -> None:
+         """Clean up the metric from memory."""
+         if self.metric is not None:
+             self.close()
+             del self.metric
+
 
  mcc_metric = HuggingFaceMetric(
      name="mcc",
euroeval/metrics/llm_as_a_judge.py CHANGED
@@ -8,7 +8,7 @@ from pathlib import Path
  from pydantic import BaseModel, Field
 
  from ..exceptions import InvalidBenchmark
- from ..model_cache import ModelCache
+ from ..logging_utils import log
  from ..utils import extract_json_dict_from_string
  from .base import Metric
 
@@ -17,8 +17,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class LLMAsAJudgeMetric(Metric):
      """Use an LLM to judge the quality of the predictions."""
@@ -112,6 +110,7 @@ class LLMAsAJudgeMetric(Metric):
          """
          # Importing here to avoid circular imports
          from ..benchmark_modules import LiteLLMModel
+         from ..model_cache import ModelCache
 
          if not predictions or not references:
              return None
@@ -190,7 +189,10 @@
          # Calculate the scores using the scoring function
          scores = [self.scoring_fn(output) for output in outputs]
          if not scores:
-             logger.warning(f"No scores were calculated for {self.pretty_name}.")
+             log(
+                 f"No scores were calculated for {self.pretty_name}.",
+                 level=logging.WARNING,
+             )
              return None
          return sum(scores) / len(scores)
 
euroeval/metrics/pipeline.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
  from scipy.special import expit as sigmoid
 
  from ..exceptions import InvalidBenchmark
+ from ..logging_utils import log, no_terminal_output
  from ..utils import unscramble
  from .base import Metric
 
@@ -20,8 +21,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  T = t.TypeVar("T", bound=int | float | str | bool)
 
@@ -121,16 +120,22 @@ class PipelineMetric(Metric):
              The calculated metric score, or None if the score should be ignored.
          """
          if self.pipeline is None:
-             self.pipeline = self._download_pipeline()
+             self.pipeline = self._download_pipeline(
+                 cache_dir=benchmark_config.cache_dir
+             )
          if self.preprocessing_fn is not None:
              predictions = self.preprocessing_fn(
                  predictions=predictions, dataset=dataset
              )
          return self.pipeline_scoring_function(self.pipeline, predictions)
 
-     def _download_pipeline(self) -> "Pipeline":
+     def _download_pipeline(self, cache_dir: str) -> "Pipeline":
          """Download the scikit-learn pipeline from the given URL.
 
+         Args:
+             cache_dir:
+                 The directory to use for caching the downloaded pipeline.
+
          Returns:
              The downloaded scikit-learn pipeline.
 
@@ -138,10 +143,13 @@
              InvalidBenchmark:
                  If the loading of the pipeline fails for any reason.
          """
-         logger.debug(f"Loading pipeline from {self.pipeline_repo}...")
-         folder_path = hf_hub.HfApi(
-             token=unscramble("HjccJFhIozVymqXDVqTUTXKvYhZMTbfIjMxG_")
-         ).snapshot_download(repo_id=self.pipeline_repo, repo_type="model")
+         log(f"Loading pipeline from {self.pipeline_repo}...", level=logging.DEBUG)
+         with no_terminal_output():
+             folder_path = hf_hub.HfApi(
+                 token=unscramble("XbjeOLhwebEaSaDUMqqaPaPIhgOcyOfDpGnX_")
+             ).snapshot_download(
+                 repo_id=self.pipeline_repo, repo_type="model", cache_dir=cache_dir
+             )
          model_path = Path(folder_path, self.pipeline_file_name)
          try:
              with model_path.open(mode="rb") as f:
@@ -150,7 +158,7 @@
              raise InvalidBenchmark(
                  f"Failed to load pipeline from {self.pipeline_repo!r}: {e}"
              ) from e
-         logger.debug(f"Successfully loaded pipeline: {pipeline}")
+         log(f"Successfully loaded pipeline: {pipeline}", level=logging.DEBUG)
          return pipeline
 
 
euroeval/metrics/speed.py CHANGED
@@ -1,7 +1,6 @@
  """Inference speed metric."""
 
  import collections.abc as c
- import logging
  import typing as t
 
  from .base import Metric
@@ -11,8 +10,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class SpeedMetric(Metric):
      """Speed metric."""
euroeval/model_cache.py CHANGED
@@ -1,5 +1,6 @@
  """ModelCache class for caching model outputs."""
 
+ import collections.abc as c
  import hashlib
  import json
  import logging
@@ -8,11 +9,9 @@ import typing as t
  from collections import defaultdict
  from dataclasses import asdict
 
- from tqdm.auto import tqdm
-
  from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
  from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
- from .utils import log_once
+ from .logging_utils import get_pbar, log, log_once
 
  if t.TYPE_CHECKING:
      from pathlib import Path
@@ -20,9 +19,6 @@ if t.TYPE_CHECKING:
      from datasets import Dataset
 
 
- logger = logging.getLogger("euroeval")
-
-
  class ModelCache:
      """A cache for model outputs.
 
@@ -65,9 +61,10 @@ class ModelCache:
              with self.cache_path.open() as f:
                  json_cache = json.load(f)
          except json.JSONDecodeError:
-             logger.warning(
+             log(
                  f"Failed to load the cache from {self.cache_path}. The cache will be "
-                 f"re-initialised."
+                 f"re-initialised.",
+                 level=logging.WARNING,
              )
              json_cache = dict()
              with self.cache_path.open("w") as f:
@@ -89,15 +86,16 @@
              with self.cache_path.open("w") as f:
                  json.dump(dumpable_cache, f)
          except KeyError:
-             logger.warning(
+             log(
                  f"Failed to load the cache from {self.cache_path}. The cache will be "
-                 f"re-initialised."
+                 f"re-initialised.",
+                 level=logging.WARNING,
              )
              self.cache = dict()
              with self.cache_path.open("w") as f:
                  json.dump(dict(), f)
 
-     def _hash_key(self, key: str | list[dict[str, str]]) -> str:
+     def _hash_key(self, key: str | c.Sequence[dict[str, str]]) -> str:
          """Hash the key to use as an index in the cache.
 
          Args:
@@ -110,7 +108,7 @@
          return hashlib.md5(string=str(key).encode()).hexdigest()
 
      def __getitem__(
-         self, key: str | list[dict[str, str]]
+         self, key: str | c.Sequence[dict[str, str]]
      ) -> SingleGenerativeModelOutput:
          """Get an item from the cache.
 
@@ -125,7 +123,7 @@
          return self.cache[hashed_key]
 
      def __setitem__(
-         self, key: str | list[dict[str, str]], value: SingleGenerativeModelOutput
+         self, key: str | c.Sequence[dict[str, str]], value: SingleGenerativeModelOutput
      ) -> None:
          """Set an item in the cache.
 
@@ -143,7 +141,7 @@
          self.cache_path.unlink()
          del self.cache
 
-     def __contains__(self, key: str | list[dict[str, str]]) -> bool:
+     def __contains__(self, key: str | c.Sequence[dict[str, str]]) -> bool:
          """Check if a key is in the cache.
 
          Args:
@@ -172,18 +170,18 @@
 
          # Double check that the number of inputs and outputs match
          if not len(model_inputs) == len(model_output.sequences):
-             logger.warning(
+             log(
                  f"Number of model inputs ({len(model_inputs)}) does not match the "
                  f"number of model outputs ({len(model_output.sequences)}). We will not "
-                 f"cache the model outputs."
+                 f"cache the model outputs.",
+                 level=logging.WARNING,
             )
              return
 
          # Store the generated sequences in the cache, one by one
-         with tqdm(
+         with get_pbar(
              iterable=model_inputs,
              desc="Caching model outputs",
-             leave=False,
              disable=hasattr(sys, "_called_from_test"),
          ) as pbar:
              for sample_idx, model_input in enumerate(pbar):
@@ -261,7 +259,7 @@ def load_cached_model_outputs(
          The model output containing the cached sequences.
      """
      input_column = "messages" if "messages" in cached_dataset.column_names else "text"
-     cached_model_outputs: list[SingleGenerativeModelOutput] = [
+     cached_model_outputs: c.Sequence[SingleGenerativeModelOutput] = [
          cache[prompt] for prompt in cached_dataset[input_column]
      ]
 
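A short illustration of what the key-type loosening above permits: cache keys may be plain prompt strings or any sequence of chat-message dicts, since lookups go through _hash_key, which hashes the stringified key. The helper below merely mirrors that scheme for illustration and is not part of the package API:

import hashlib

def hash_key(key) -> str:
    # Mirrors ModelCache._hash_key from the diff: md5 over the stringified key,
    # so plain strings and message sequences share one code path.
    return hashlib.md5(str(key).encode()).hexdigest()

plain_prompt = "Summarise this article."
chat_prompt = [{"role": "user", "content": "Summarise this article."}]

print(hash_key(plain_prompt))  # deterministic 32-character hex digest
print(hash_key(chat_prompt))   # a different digest, since the key's string form differs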
euroeval/model_config.py CHANGED
@@ -5,14 +5,12 @@ import typing as t
 
  from . import benchmark_modules
  from .exceptions import InvalidModel, NeedsEnvironmentVariable, NeedsExtraInstalled
+ from .logging_utils import log
 
  if t.TYPE_CHECKING:
      from .data_models import BenchmarkConfig, ModelConfig
 
 
- logger = logging.getLogger("euroeval")
-
-
  def get_model_config(
      model_id: str, benchmark_config: "BenchmarkConfig"
  ) -> "ModelConfig":
@@ -51,9 +49,10 @@
          elif isinstance(exists_or_err, NeedsEnvironmentVariable):
              needs_env_vars.append(exists_or_err.env_var)
          elif exists_or_err is True:
-             logger.debug(
+             log(
                  f"The model {model_id!r} was identified by the "
-                 f"{benchmark_module.__name__} benchmark module."
+                 f"{benchmark_module.__name__} benchmark module.",
+                 logging.DEBUG,
              )
              model_config = benchmark_module.get_model_config(
                  model_id=model_id, benchmark_config=benchmark_config
euroeval/model_loading.py CHANGED
@@ -10,6 +10,7 @@ from .benchmark_modules import (
  )
  from .enums import InferenceBackend, ModelType
  from .exceptions import InvalidModel
+ from .logging_utils import log_once
 
  if t.TYPE_CHECKING:
      from .benchmark_modules import BenchmarkModule
@@ -34,6 +35,8 @@ def load_model(
      Returns:
          The model.
      """
+     log_once(f"\nLoading the model {model_config.model_id}...")
+
      # The order matters; the first model type that matches will be used. For this
      # reason, they have been ordered in terms of the most common model types.
      model_class: t.Type[BenchmarkModule]
euroeval/prompt_templates/__init__.py CHANGED
@@ -1,8 +1,10 @@
  """The different prompt templates used in EuroEval."""
 
+ from .classification import CLASSIFICATION_TEMPLATES
  from .linguistic_acceptability import LA_TEMPLATES
  from .multiple_choice import MULTIPLE_CHOICE_TEMPLATES
  from .named_entity_recognition import NER_TEMPLATES
  from .reading_comprehension import RC_TEMPLATES
  from .sentiment_classification import SENT_TEMPLATES
  from .summarization import SUMM_TEMPLATES
+ from .token_classification import TOKEN_CLASSIFICATION_TEMPLATES