EuroEval 16.3.0__py3-none-any.whl → 16.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of EuroEval has been flagged as potentially problematic.

Files changed (64)
  1. euroeval/__init__.py +3 -2
  2. euroeval/benchmark_config_factory.py +0 -4
  3. euroeval/benchmark_modules/base.py +3 -16
  4. euroeval/benchmark_modules/fresh.py +2 -1
  5. euroeval/benchmark_modules/hf.py +99 -62
  6. euroeval/benchmark_modules/litellm.py +101 -41
  7. euroeval/benchmark_modules/vllm.py +91 -83
  8. euroeval/benchmarker.py +84 -78
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/constants.py +6 -0
  12. euroeval/data_loading.py +14 -11
  13. euroeval/data_models.py +12 -4
  14. euroeval/dataset_configs/__init__.py +2 -0
  15. euroeval/dataset_configs/czech.py +79 -0
  16. euroeval/dataset_configs/danish.py +10 -11
  17. euroeval/dataset_configs/dutch.py +0 -1
  18. euroeval/dataset_configs/english.py +0 -1
  19. euroeval/dataset_configs/estonian.py +11 -1
  20. euroeval/dataset_configs/finnish.py +0 -1
  21. euroeval/dataset_configs/french.py +0 -1
  22. euroeval/dataset_configs/german.py +0 -1
  23. euroeval/dataset_configs/italian.py +0 -1
  24. euroeval/dataset_configs/latvian.py +0 -1
  25. euroeval/dataset_configs/lithuanian.py +9 -3
  26. euroeval/dataset_configs/norwegian.py +0 -1
  27. euroeval/dataset_configs/polish.py +0 -1
  28. euroeval/dataset_configs/portuguese.py +0 -1
  29. euroeval/dataset_configs/slovak.py +60 -0
  30. euroeval/dataset_configs/spanish.py +0 -1
  31. euroeval/dataset_configs/swedish.py +10 -12
  32. euroeval/finetuning.py +21 -15
  33. euroeval/generation.py +10 -10
  34. euroeval/generation_utils.py +2 -3
  35. euroeval/logging_utils.py +250 -0
  36. euroeval/metrics/base.py +0 -3
  37. euroeval/metrics/huggingface.py +9 -5
  38. euroeval/metrics/llm_as_a_judge.py +5 -3
  39. euroeval/metrics/pipeline.py +17 -9
  40. euroeval/metrics/speed.py +0 -3
  41. euroeval/model_cache.py +11 -14
  42. euroeval/model_config.py +4 -5
  43. euroeval/model_loading.py +3 -0
  44. euroeval/prompt_templates/linguistic_acceptability.py +21 -3
  45. euroeval/prompt_templates/multiple_choice.py +25 -1
  46. euroeval/prompt_templates/named_entity_recognition.py +51 -11
  47. euroeval/prompt_templates/reading_comprehension.py +31 -3
  48. euroeval/prompt_templates/sentiment_classification.py +23 -1
  49. euroeval/prompt_templates/summarization.py +26 -6
  50. euroeval/scores.py +7 -7
  51. euroeval/speed_benchmark.py +3 -5
  52. euroeval/task_group_utils/multiple_choice_classification.py +0 -3
  53. euroeval/task_group_utils/question_answering.py +0 -3
  54. euroeval/task_group_utils/sequence_classification.py +43 -31
  55. euroeval/task_group_utils/text_to_text.py +17 -8
  56. euroeval/task_group_utils/token_classification.py +10 -9
  57. euroeval/tokenisation_utils.py +14 -12
  58. euroeval/utils.py +29 -146
  59. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/METADATA +4 -4
  60. euroeval-16.4.0.dist-info/RECORD +75 -0
  61. euroeval-16.3.0.dist-info/RECORD +0 -71
  62. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
  63. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
  64. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/logging_utils.py ADDED
@@ -0,0 +1,250 @@
+ """Utility functions related to logging."""
+
+ import datetime as dt
+ import logging
+ import os
+ import sys
+ import warnings
+ from io import TextIOWrapper
+
+ import litellm
+ from datasets.utils import disable_progress_bars as disable_datasets_progress_bars
+ from evaluate import disable_progress_bar as disable_evaluate_progress_bar
+ from huggingface_hub.utils.tqdm import (
+     disable_progress_bars as disable_hf_hub_progress_bars,
+ )
+ from termcolor import colored
+ from tqdm.auto import tqdm
+ from transformers import logging as tf_logging
+
+ from .caching_utils import cache_arguments
+
+ logger = logging.getLogger("euroeval")
+
+
+ def get_pbar(*tqdm_args, **tqdm_kwargs) -> tqdm:
+     """Get a progress bar for vLLM with custom hard-coded arguments.
+
+     Args:
+         *tqdm_args:
+             Positional arguments to pass to tqdm.
+         **tqdm_kwargs:
+             Additional keyword arguments to pass to tqdm.
+
+     Returns:
+         A tqdm progress bar.
+     """
+     tqdm_kwargs = dict(colour="yellow", ascii="—▰", leave=False) | tqdm_kwargs
+     tqdm_kwargs["desc"] = colored(
+         text=tqdm_kwargs.get("desc", "Processing"), color="light_yellow"
+     )
+     return tqdm(*tqdm_args, **tqdm_kwargs)
+
+
+ def log(message: str, level: int, colour: str | None = None) -> None:
+     """Log a message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level.
+         colour:
+             The colour to use for the message. If None, a default colour will be used
+             based on the logging level.
+
+     Raises:
+         ValueError:
+             If the logging level is invalid.
+     """
+     match level:
+         case logging.DEBUG:
+             message = colored(
+                 text=(
+                     "[DEBUG] "
+                     + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                     + f" · {message}"
+                 ),
+                 color=colour or "light_blue",
+             )
+             logger.debug(message)
+         case logging.INFO:
+             if colour is not None:
+                 message = colored(text=message, color=colour)
+             logger.info(message)
+         case logging.WARNING:
+             message = colored(text=message, color=colour or "light_red")
+             logger.warning(message)
+         case logging.ERROR:
+             message = colored(text=message, color=colour or "red")
+             logger.error(message)
+         case logging.CRITICAL:
+             message = colored(text=message, color=colour or "red")
+             logger.critical(message)
+         case _:
+             raise ValueError(f"Invalid logging level: {level}")
+
+
+ @cache_arguments("message")
+ def log_once(message: str, level: int = logging.INFO, prefix: str = "") -> None:
+     """Log a message once.
+
+     This is ensured by caching the "message" argument and only logging it the first
+     time this function is called with that message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+         prefix:
+             A prefix to add to the message, which is not considered when determining
+             if the message has been logged before.
+     """
+     log(message=prefix + message, level=level)
+
+
+ def block_terminal_output() -> None:
+     """Blocks libraries from writing output to the terminal.
+
+     This filters warnings from some libraries, sets the logging level to ERROR for
+     some libraries, disables tokeniser progress bars when using Hugging Face
+     tokenisers, and disables most of the logging from the `transformers` library.
+     """
+     if os.getenv("FULL_LOG") == "1":
+         return
+
+     # Ignore miscellaneous warnings
+     warnings.filterwarnings("ignore", category=UserWarning)
+     warnings.filterwarnings("ignore", category=FutureWarning)
+     logging.getLogger("absl").setLevel(logging.CRITICAL)
+
+     # Disable matplotlib logging
+     logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
+
+     # Disable PyTorch logging
+     logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
+     warnings.filterwarnings(action="ignore", module="torch*")
+     os.environ["TORCH_LOGS"] = "-all"
+
+     # Disable huggingface_hub logging
+     logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
+     disable_hf_hub_progress_bars()
+
+     # Disable LiteLLM logging
+     logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
+     logging.getLogger("openai").setLevel(logging.CRITICAL)
+     logging.getLogger("httpx").setLevel(logging.CRITICAL)
+     litellm.suppress_debug_info = True
+
+     # Disable vLLM logging
+     logging.getLogger("vllm").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
+     logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
+         logging.CRITICAL
+     )
+     os.environ["LOG_LEVEL"] = "CRITICAL"
+     os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+     # Disable flashinfer logging
+     os.environ["FLASHINFER_LOGGING_LEVEL"] = "CRITICAL"
+
+     # Disable datasets logging
+     logging.getLogger("datasets").setLevel(logging.CRITICAL)
+     logging.getLogger("filelock").setLevel(logging.CRITICAL)
+     disable_datasets_progress_bars()
+
+     # Disable evaluate logging
+     warnings.filterwarnings("ignore", module="seqeval*")
+     disable_evaluate_progress_bar()
+
+     # Disable most of the `transformers` logging
+     tf_logging._default_log_level = logging.CRITICAL
+     tf_logging.set_verbosity(logging.CRITICAL)
+     logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
+     logging.getLogger("accelerate").setLevel(logging.CRITICAL)
+
+
+ class no_terminal_output:
+     """Context manager that suppresses all terminal output."""
+
+     def __init__(self, disable: bool = False) -> None:
+         """Initialise the context manager.
+
+         Args:
+             disable:
+                 If True, this context manager does nothing.
+         """
+         self.disable = disable
+         self.nothing_file: TextIOWrapper | None = None
+         self._cpp_stdout_file: int | None = None
+         self._cpp_stderr_file: int | None = None
+         try:
+             self._cpp_stdout_file = os.dup(sys.stdout.fileno())
+             self._cpp_stderr_file = os.dup(sys.stderr.fileno())
+         except OSError:
+             self._log_windows_warning()
+
+     def _log_windows_warning(self) -> None:
+         """Log a warning about Windows not supporting blocking terminal output."""
+         log_once(
+             "Your operating system (probably Windows) does not support blocking "
+             "terminal output, so expect more messy output - sorry!",
+             level=logging.WARNING,
+         )
+
+     def __enter__(self) -> None:
+         """Suppress all terminal output."""
+         if not self.disable:
+             self.nothing_file = open(os.devnull, "w")
+             try:
+                 os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stdout.fileno())
+                 os.dup2(fd=self.nothing_file.fileno(), fd2=sys.stderr.fileno())
+             except OSError:
+                 self._log_windows_warning()
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: type[BaseException] | None,
+     ) -> None:
+         """Re-enable terminal output."""
+         if not self.disable:
+             if self.nothing_file is not None:
+                 self.nothing_file.close()
+             try:
+                 if self._cpp_stdout_file is not None:
+                     os.dup2(fd=self._cpp_stdout_file, fd2=sys.stdout.fileno())
+                 if self._cpp_stderr_file is not None:
+                     os.dup2(fd=self._cpp_stderr_file, fd2=sys.stderr.fileno())
+             except OSError:
+                 self._log_windows_warning()
+
+
+ def adjust_logging_level(verbose: bool, ignore_testing: bool = False) -> int:
+     """Adjust the logging level based on verbosity.
+
+     Args:
+         verbose:
+             Whether to output additional output.
+         ignore_testing:
+             Whether to ignore the testing flag.
+
+     Returns:
+         The logging level that was set.
+     """
+     if hasattr(sys, "_called_from_test") and not ignore_testing:
+         logging_level = logging.CRITICAL
+     elif verbose:
+         logging_level = logging.DEBUG
+     else:
+         logging_level = logging.INFO
+     logger.setLevel(logging_level)
+     return logging_level
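
The `log_once` helper above delegates its deduplication to the new `cache_arguments` decorator from euroeval/caching_utils.py (+79 lines, not shown in this section). As a hedged sketch of the semantics, assuming the decorator simply remembers which values of the named arguments it has already seen:

# Hedged sketch of a cache_arguments-style decorator. The real implementation
# in euroeval/caching_utils.py is not included in this diff, so the body
# below is an assumption about its behaviour, not the package's actual code.
import functools
import inspect


def cache_arguments(*arg_names: str):
    def decorator(func):
        seen: set[tuple] = set()
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Bind arguments to parameter names, so that f("hi") and
            # f(message="hi") produce the same cache key.
            bound = signature.bind(*args, **kwargs)
            key = tuple(bound.arguments.get(name) for name in arg_names)
            if key in seen:
                return None  # already called with these arguments: do nothing
            seen.add(key)
            return func(*args, **kwargs)

        return wrapper

    return decorator

Under this reading, calling log_once("Loading model...") twice logs only on the first call, and the prefix argument stays out of the cache key, matching the docstring above.
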
euroeval/metrics/base.py CHANGED
@@ -2,7 +2,6 @@
 
  import abc
  import collections.abc as c
- import logging
  import typing as t
 
  if t.TYPE_CHECKING:
@@ -10,8 +9,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class Metric(abc.ABC):
      """Abstract base class for all metrics."""
euroeval/metrics/huggingface.py CHANGED
@@ -1,7 +1,6 @@
  """All the Hugging Face metrics used in EuroEval."""
 
  import collections.abc as c
- import logging
  import typing as t
  from pathlib import Path
 
@@ -9,7 +8,7 @@ import evaluate
  import numpy as np
  from datasets import DownloadConfig
 
- from ..utils import HiddenPrints
+ from ..logging_utils import no_terminal_output
  from .base import Metric
 
  if t.TYPE_CHECKING:
@@ -18,8 +17,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class HuggingFaceMetric(Metric):
      """A metric which is implemented in the `evaluate` package.
@@ -126,7 +123,7 @@ class HuggingFaceMetric(Metric):
 
          assert self.metric is not None
 
-         with HiddenPrints():
+         with no_terminal_output(disable=benchmark_config.verbose):
              results = self.metric.compute(
                  predictions=predictions, references=references, **self.compute_kwargs
              )
@@ -145,6 +142,13 @@ class HuggingFaceMetric(Metric):
 
          return score
 
+     def __del__(self) -> None:
+         """Clean up the metric from memory."""
+         if self.metric is not None:
+             if self.metric.writer is not None:
+                 self.metric.writer.close()
+             del self.metric
+
 
 
  mcc_metric = HuggingFaceMetric(
      name="mcc",
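
One note on the `no_terminal_output` swap above: unlike the old `HiddenPrints`, which (judging by its name) worked at the Python level, the new context manager redirects the underlying file descriptors with os.dup2, so output written directly to fd 1/2 by C and C++ extensions is silenced as well. A minimal usage sketch:

# Usage sketch: fd-level redirection also catches native-library output,
# which a plain sys.stdout patch would miss.
from euroeval.logging_utils import no_terminal_output

with no_terminal_output(disable=False):
    print("swallowed")  # Python-level writes go to os.devnull
    # native writes straight to fd 1/2 are suppressed here too
print("visible again")  # the saved descriptors are restored on exit
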
euroeval/metrics/llm_as_a_judge.py CHANGED
@@ -8,6 +8,7 @@ from pathlib import Path
  from pydantic import BaseModel, Field
 
  from ..exceptions import InvalidBenchmark
+ from ..logging_utils import log
  from ..model_cache import ModelCache
  from ..utils import extract_json_dict_from_string
  from .base import Metric
@@ -17,8 +18,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class LLMAsAJudgeMetric(Metric):
      """Use an LLM to judge the quality of the predictions."""
@@ -190,7 +189,10 @@ class LLMAsAJudgeMetric(Metric):
          # Calculate the scores using the scoring function
          scores = [self.scoring_fn(output) for output in outputs]
          if not scores:
-             logger.warning(f"No scores were calculated for {self.pretty_name}.")
+             log(
+                 f"No scores were calculated for {self.pretty_name}.",
+                 level=logging.WARNING,
+             )
              return None
          return sum(scores) / len(scores)
 
euroeval/metrics/pipeline.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
  from scipy.special import expit as sigmoid
 
  from ..exceptions import InvalidBenchmark
+ from ..logging_utils import log, no_terminal_output
  from ..utils import unscramble
  from .base import Metric
 
@@ -20,8 +21,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
 
  T = t.TypeVar("T", bound=int | float | str | bool)
@@ -121,16 +120,22 @@ class PipelineMetric(Metric):
              The calculated metric score, or None if the score should be ignored.
          """
          if self.pipeline is None:
-             self.pipeline = self._download_pipeline()
+             self.pipeline = self._download_pipeline(
+                 cache_dir=benchmark_config.cache_dir
+             )
          if self.preprocessing_fn is not None:
              predictions = self.preprocessing_fn(
                  predictions=predictions, dataset=dataset
              )
          return self.pipeline_scoring_function(self.pipeline, predictions)
 
-     def _download_pipeline(self) -> "Pipeline":
+     def _download_pipeline(self, cache_dir: str) -> "Pipeline":
          """Download the scikit-learn pipeline from the given URL.
 
+         Args:
+             cache_dir:
+                 The directory to use for caching the downloaded pipeline.
+
          Returns:
              The downloaded scikit-learn pipeline.
 
@@ -138,10 +143,13 @@ class PipelineMetric(Metric):
          InvalidBenchmark:
              If the loading of the pipeline fails for any reason.
          """
-         logger.debug(f"Loading pipeline from {self.pipeline_repo}...")
-         folder_path = hf_hub.HfApi(
-             token=unscramble("HjccJFhIozVymqXDVqTUTXKvYhZMTbfIjMxG_")
-         ).snapshot_download(repo_id=self.pipeline_repo, repo_type="model")
+         log(f"Loading pipeline from {self.pipeline_repo}...", level=logging.DEBUG)
+         with no_terminal_output():
+             folder_path = hf_hub.HfApi(
+                 token=unscramble("XbjeOLhwebEaSaDUMqqaPaPIhgOcyOfDpGnX_")
+             ).snapshot_download(
+                 repo_id=self.pipeline_repo, repo_type="model", cache_dir=cache_dir
+             )
          model_path = Path(folder_path, self.pipeline_file_name)
          try:
              with model_path.open(mode="rb") as f:
@@ -150,7 +158,7 @@ class PipelineMetric(Metric):
              raise InvalidBenchmark(
                  f"Failed to load pipeline from {self.pipeline_repo!r}: {e}"
              ) from e
-         logger.debug(f"Successfully loaded pipeline: {pipeline}")
+         log(f"Successfully loaded pipeline: {pipeline}", level=logging.DEBUG)
          return pipeline
 
 
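
The `cache_dir` threading above means the pickled pipeline snapshot now lands in EuroEval's own cache directory rather than the default Hugging Face one. The underlying call is the public huggingface_hub API; a standalone sketch with illustrative values:

# Illustrative only: the repo id and cache directory below are made-up
# placeholders, not values from the EuroEval codebase.
from huggingface_hub import snapshot_download

folder_path = snapshot_download(
    repo_id="some-org/some-pipeline-repo",
    repo_type="model",
    cache_dir=".euroeval_cache",  # defaults to ~/.cache/huggingface/hub if omitted
)
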
euroeval/metrics/speed.py CHANGED
@@ -1,7 +1,6 @@
  """Inference speed metric."""
 
  import collections.abc as c
- import logging
  import typing as t
 
  from .base import Metric
@@ -11,8 +10,6 @@ if t.TYPE_CHECKING:
 
      from ..data_models import BenchmarkConfig, DatasetConfig
 
- logger: logging.Logger = logging.getLogger("euroeval")
-
 
  class SpeedMetric(Metric):
      """Speed metric."""
euroeval/model_cache.py CHANGED
@@ -8,11 +8,9 @@
  from collections import defaultdict
  from dataclasses import asdict
 
- from tqdm.auto import tqdm
-
  from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
  from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
- from .utils import log_once
+ from .logging_utils import get_pbar, log, log_once
 
  if t.TYPE_CHECKING:
      from pathlib import Path
@@ -20,9 +18,6 @@
 
      from datasets import Dataset
 
- logger = logging.getLogger("euroeval")
-
-
  class ModelCache:
      """A cache for model outputs.
 
@@ -65,9 +60,10 @@
              with self.cache_path.open() as f:
                  json_cache = json.load(f)
          except json.JSONDecodeError:
-             logger.warning(
+             log(
                  f"Failed to load the cache from {self.cache_path}. The cache will be "
-                 f"re-initialised."
+                 f"re-initialised.",
+                 level=logging.WARNING,
              )
              json_cache = dict()
              with self.cache_path.open("w") as f:
@@ -89,9 +85,10 @@
              with self.cache_path.open("w") as f:
                  json.dump(dumpable_cache, f)
          except KeyError:
-             logger.warning(
+             log(
                  f"Failed to load the cache from {self.cache_path}. The cache will be "
-                 f"re-initialised."
+                 f"re-initialised.",
+                 level=logging.WARNING,
              )
              self.cache = dict()
              with self.cache_path.open("w") as f:
@@ -172,18 +169,18 @@
 
          # Double check that the number of inputs and outputs match
          if not len(model_inputs) == len(model_output.sequences):
-             logger.warning(
+             log(
                  f"Number of model inputs ({len(model_inputs)}) does not match the "
                  f"number of model outputs ({len(model_output.sequences)}). We will not "
-                 f"cache the model outputs."
+                 f"cache the model outputs.",
+                 level=logging.WARNING,
              )
              return
 
          # Store the generated sequences in the cache, one by one
-         with tqdm(
+         with get_pbar(
              iterable=model_inputs,
              desc="Caching model outputs",
-             leave=False,
              disable=hasattr(sys, "_called_from_test"),
          ) as pbar:
              for sample_idx, model_input in enumerate(pbar):
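
For reference, `get_pbar` forwards its arguments to tqdm after applying the yellow colour scheme and leave=False defaults from logging_utils.py, so the explicit leave=False above becomes redundant. A tiny self-contained usage sketch:

# Any iterable works in place of range(3); this mirrors the call pattern
# used in ModelCache above.
from euroeval.logging_utils import get_pbar

with get_pbar(iterable=range(3), desc="Caching model outputs") as pbar:
    for item in pbar:
        pass  # each cached model output would be processed here
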
euroeval/model_config.py CHANGED
@@ -5,14 +5,12 @@
 
  from . import benchmark_modules
  from .exceptions import InvalidModel, NeedsEnvironmentVariable, NeedsExtraInstalled
+ from .logging_utils import log
 
  if t.TYPE_CHECKING:
      from .data_models import BenchmarkConfig, ModelConfig
 
 
- logger = logging.getLogger("euroeval")
-
-
  def get_model_config(
      model_id: str, benchmark_config: "BenchmarkConfig"
  ) -> "ModelConfig":
@@ -51,9 +49,10 @@
          elif isinstance(exists_or_err, NeedsEnvironmentVariable):
              needs_env_vars.append(exists_or_err.env_var)
          elif exists_or_err is True:
-             logger.debug(
+             log(
                  f"The model {model_id!r} was identified by the "
-                 f"{benchmark_module.__name__} benchmark module."
+                 f"{benchmark_module.__name__} benchmark module.",
+                 logging.DEBUG,
              )
              model_config = benchmark_module.get_model_config(
                  model_id=model_id, benchmark_config=benchmark_config
euroeval/model_loading.py CHANGED
@@ -10,6 +10,7 @@ from .benchmark_modules import (
  )
  from .enums import InferenceBackend, ModelType
  from .exceptions import InvalidModel
+ from .logging_utils import log_once
 
  if t.TYPE_CHECKING:
      from .benchmark_modules import BenchmarkModule
@@ -34,6 +35,8 @@
      Returns:
          The model.
      """
+     log_once(f"Loading the model {model_config.model_id}...")
+
      # The order matters; the first model type that matches will be used. For this
      # reason, they have been ordered in terms of the most common model types.
      model_class: t.Type[BenchmarkModule]
euroeval/prompt_templates/linguistic_acceptability.py CHANGED
@@ -4,6 +4,7 @@ import typing as t
 
  from ..data_models import PromptConfig
  from ..languages import (
+     CS,
      DA,
      DE,
      EN,
@@ -22,6 +23,7 @@ from ..languages import (
      NO,
      PL,
      PT,
+     SK,
      SV,
  )
 
@@ -29,6 +31,13 @@ if t.TYPE_CHECKING:
      from ..data_models import Language
 
  LA_TEMPLATES: dict["Language", PromptConfig] = {
+     CS: PromptConfig(
+         default_prompt_label_mapping=dict(correct="ano", incorrect="ne"),
+         default_prompt_prefix="Následující jsou věty a zda jsou gramaticky správné.",
+         default_prompt_template="Věta: {text}\nGramaticky správná: {label}",
+         default_instruction_prompt="Věta: {text}\n\nUrčete, zda je věta gramaticky "
+         "správná nebo ne. Odpovězte {labels_str}, a nic jiné.",
+     ),
      DA: PromptConfig(
          default_prompt_label_mapping=dict(correct="ja", incorrect="nej"),
          default_prompt_prefix="Følgende er sætninger og om de er grammatisk korrekte.",
@@ -71,11 +80,11 @@ LA_TEMPLATES: dict["Language", PromptConfig] = {
      ),
      PL: PromptConfig(
          default_prompt_label_mapping=dict(correct="tak", incorrect="nie"),
-         default_prompt_prefix="Poniżej znajdują się teksty i czy są "
+         default_prompt_prefix="Poniżej znajdują się teksty i informacja, czy są "
          "gramatycznie poprawne.",
          default_prompt_template="Tekst: {text}\nGramatycznie poprawny: {label}",
-         default_instruction_prompt="Tekst: {text}\n\nOkreśl czy tekst jest "
-         "gramatycznie poprawny czy nie. Odpowiedz {labels_str}, i nic więcej.",
+         default_instruction_prompt="Tekst: {text}\n\nOkreśl, czy tekst jest "
+         "gramatycznie poprawny. Odpowiedz używając wyłącznie {labels_str}.",
      ),
      PT: PromptConfig(
          default_prompt_label_mapping=dict(correct="sim", incorrect="não"),
@@ -174,6 +183,15 @@ LA_TEMPLATES: dict["Language", PromptConfig] = {
          default_instruction_prompt="Setning: {text}\n\nBestem om setningen er "
          "grammatisk korrekt eller ikke. Svar med {labels_str}, og ikke noe annet.",
      ),
+     SK: PromptConfig(
+         default_prompt_label_mapping=dict(correct="áno", incorrect="nie"),
+         default_prompt_prefix="Nasledujú vety a či sú gramaticky správne.",
+         default_prompt_template="Veta: {text}\nGramaticky správna: {label}",
+         default_instruction_prompt=(
+             "Veta: {text}\n\nUrčite, či je veta gramaticky správna alebo nie. "
+             "Odpovedzte so {labels_str}, a nič iné."
+         ),
+     ),
      SV: PromptConfig(
          default_prompt_label_mapping=dict(correct="ja", incorrect="nej"),
          default_prompt_prefix="Följande är meningar och huruvida de är grammatiskt "
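
Since the prompt-rendering code lives elsewhere in EuroEval and is not part of this diff, here is a hedged sketch of how the new Czech template could expand into a few-shot prompt, with made-up example sentences:

# Illustrative only: the example sentences and this rendering loop are not
# from the EuroEval codebase; they just show how the {text}/{label}
# placeholders and the label mapping fit together.
template = "Věta: {text}\nGramaticky správná: {label}"
label_mapping = dict(correct="ano", incorrect="ne")
examples = [("To je dobrý nápad.", "correct"), ("Já jsou doma.", "incorrect")]

prompt = "Následující jsou věty a zda jsou gramaticky správné.\n\n"
prompt += "\n\n".join(
    template.format(text=text, label=label_mapping[label])
    for text, label in examples
)
print(prompt)
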
euroeval/prompt_templates/multiple_choice.py CHANGED
@@ -4,6 +4,7 @@ import typing as t
 
  from ..data_models import PromptConfig
  from ..languages import (
+     CS,
      DA,
      DE,
      EN,
@@ -21,6 +22,7 @@ from ..languages import (
      NO,
      PL,
      PT,
+     SK,
      SV,
  )
 
@@ -29,6 +31,17 @@ if t.TYPE_CHECKING:
 
  # TODO: Missing Faroese
  MULTIPLE_CHOICE_TEMPLATES: dict["Language", PromptConfig] = {
+     CS: PromptConfig(
+         default_prompt_prefix=(
+             "Následující jsou otázky s výběrem z více možností (s odpověďmi)."
+         ),
+         default_prompt_template="Otázka: {text}\nOdpověď: {label}",
+         default_instruction_prompt=(
+             "Otázka: {text}\n\nOdpovězte na výše uvedenou otázku "
+             "pomocí {labels_str}, a nic jiného."
+         ),
+         default_prompt_label_mapping="auto",
+     ),
      DA: PromptConfig(
          default_prompt_prefix="Følgende er multiple choice spørgsmål (med svar).",
          default_prompt_template="Spørgsmål: {text}\nSvar: {label}",
@@ -155,7 +168,18 @@ MULTIPLE_CHOICE_TEMPLATES: dict["Language", PromptConfig] = {
          "(z odpowiedziami).",
          default_prompt_template="Pytanie: {text}\nOdpowiedź: {label}",
          default_instruction_prompt="Pytanie: {text}\n\nOdpowiedz na powyższe pytanie, "
-         "odpowiadając {labels_str}, i nic więcej.",
+         "używając {labels_str} i niczego więcej.",
+         default_prompt_label_mapping="auto",
+     ),
+     SK: PromptConfig(
+         default_prompt_prefix=(
+             "Nasledujú otázky s viacerými možnosťami (s odpoveďami)."
+         ),
+         default_prompt_template="Otázka: {text}\nOdpoveď: {label}",
+         default_instruction_prompt=(
+             "Otázka: {text}\n\n"
+             "Odpovedzte na nasledujúcu otázku použitím {labels_str}, a nič iné."
+         ),
          default_prompt_label_mapping="auto",
      ),
      SV: PromptConfig(