EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/logging_utils.py (new file)
@@ -0,0 +1,268 @@
+ """Utility functions related to logging."""
+
+ import datetime as dt
+ import logging
+ import os
+ import sys
+ import warnings
+ from io import TextIOWrapper
+
+ import litellm
+ from datasets.utils import disable_progress_bars as disable_datasets_progress_bars
+ from evaluate import disable_progress_bar as disable_evaluate_progress_bar
+ from huggingface_hub.utils.tqdm import (
+     disable_progress_bars as disable_hf_hub_progress_bars,
+ )
+ from termcolor import colored
+ from tqdm.auto import tqdm
+ from transformers import logging as tf_logging
+
+ from .caching_utils import cache_arguments
+
+ logger = logging.getLogger("euroeval")
+
+
+ def get_pbar(*tqdm_args, **tqdm_kwargs) -> tqdm:
+     """Get a progress bar for vLLM with custom hard-coded arguments.
+
+     Args:
+         *tqdm_args:
+             Positional arguments to pass to tqdm.
+         **tqdm_kwargs:
+             Additional keyword arguments to pass to tqdm.
+
+     Returns:
+         A tqdm progress bar.
+     """
+     tqdm_kwargs = dict(colour="yellow", ascii="—▰", leave=False) | tqdm_kwargs
+     tqdm_kwargs["desc"] = colored(
+         text=tqdm_kwargs.get("desc", "Processing"), color="light_yellow"
+     )
+     return tqdm(*tqdm_args, **tqdm_kwargs)
+
+
+ def log(message: str, level: int, colour: str | None = None) -> None:
+     """Log a message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+         colour:
+             The colour to use for the message. If None, a default colour will be used
+             based on the logging level.
+
+     Raises:
+         ValueError:
+             If the logging level is invalid.
+     """
+     match level:
+         case logging.DEBUG:
+             message = colored(
+                 text=(
+                     "[DEBUG] "
+                     + dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                     + f" · {message}"
+                 ),
+                 color=colour or "light_blue",
+             )
+             logger.debug(message)
+         case logging.INFO:
+             if colour is not None:
+                 message = colored(text=message, color=colour)
+             logger.info(message)
+         case logging.WARNING:
+             message = colored(text=message, color=colour or "light_red")
+             logger.warning(message)
+         case logging.ERROR:
+             message = colored(text=message, color=colour or "red")
+             logger.error(message)
+         case logging.CRITICAL:
+             message = colored(text=message, color=colour or "red")
+             logger.critical(message)
+         case _:
+             raise ValueError(f"Invalid logging level: {level}")
+
+
+ @cache_arguments("message")
+ def log_once(message: str, level: int = logging.INFO, prefix: str = "") -> None:
+     """Log a message once.
+
+     This is ensured by caching the "message" argument and only logging it the first time
+     this function is called with that message.
+
+     Args:
+         message:
+             The message to log.
+         level:
+             The logging level. Defaults to logging.INFO.
+         prefix:
+             A prefix to add to the message, which is not considered when determining if
+             the message has been logged before.
+     """
+     log(message=prefix + message, level=level)
+
+
+ def block_terminal_output() -> None:
+     """Blocks libraries from writing output to the terminal.
+
+     This filters warnings from some libraries, sets the logging level to ERROR for some
+     libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
+     disables most of the logging from the `transformers` library.
+     """
+     if os.getenv("FULL_LOG") == "1":
+         return
+
+     # Ignore miscellaneous warnings
+     warnings.filterwarnings("ignore", category=UserWarning)
+     warnings.filterwarnings("ignore", category=FutureWarning)
+     logging.getLogger("absl").setLevel(logging.CRITICAL)
+
+     # Disable matplotlib logging
+     logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
+
+     # Disable PyTorch logging
+     logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
+     warnings.filterwarnings(action="ignore", module="torch*")
+     os.environ["TORCH_LOGS"] = "-all"
+
+     # Disable huggingface_hub logging
+     logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
+     disable_hf_hub_progress_bars()
+
+     # Disable LiteLLM logging
+     logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
+     logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
+     logging.getLogger("openai").setLevel(logging.CRITICAL)
+     logging.getLogger("httpx").setLevel(logging.CRITICAL)
+     litellm.suppress_debug_info = True
+
+     # Disable vLLM logging
+     logging.getLogger("vllm").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
+     logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
+     logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
+         logging.CRITICAL
+     )
+     os.environ["LOG_LEVEL"] = "CRITICAL"
+     os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+     # Disable flashinfer logging
+     os.environ["FLASHINFER_LOGGING_LEVEL"] = "CRITICAL"
+
+     # Disable datasets logging
+     logging.getLogger("datasets").setLevel(logging.CRITICAL)
+     logging.getLogger("filelock").setLevel(logging.CRITICAL)
+     disable_datasets_progress_bars()
+
+     # Disable evaluate logging
+     warnings.filterwarnings("ignore", module="seqeval*")
+     disable_evaluate_progress_bar()
+
+     # Disable most of the `transformers` logging
+     tf_logging._default_log_level = logging.CRITICAL
+     tf_logging.set_verbosity(logging.CRITICAL)
+     logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
+     logging.getLogger("accelerate").setLevel(logging.CRITICAL)
+
+
+ class no_terminal_output:
+     """Context manager that suppresses all terminal output."""
+
+     def __init__(self, disable: bool = False) -> None:
+         """Initialise the context manager.
+
+         Args:
+             disable:
+                 If True, this context manager does nothing.
+         """
+         self.disable = disable
+         self.devnull_file: TextIOWrapper | None = None
+         self._original_stdout_fd: int | None = None
+         self._original_stderr_fd: int | None = None
+
+     def _log_windows_warning(self) -> None:
+         """Log a warning about Windows not supporting blocking terminal output."""
+         log_once(
+             "Your operating system (probably Windows) does not support blocking "
+             "terminal output, so expect more messy output - sorry!",
+             level=logging.WARNING,
+         )
+
+     def __enter__(self) -> None:
+         """Suppress all terminal output."""
+         if self.disable:
+             return
+
+         try:
+             # Save original FDs by duplicating them
+             self._original_stdout_fd = os.dup(sys.stdout.fileno())
+             self._original_stderr_fd = os.dup(sys.stderr.fileno())
+
+             # Open /dev/null
+             self.devnull_file = open(os.devnull, "w")
+
+             # Redirect stdout/stderr to /dev/null
+             os.dup2(self.devnull_file.fileno(), sys.stdout.fileno())
+             os.dup2(self.devnull_file.fileno(), sys.stderr.fileno())
+
+         except OSError:
+             self._log_windows_warning()
+             # If setup fails, clean up any resources we might have acquired
+             self.__exit__(None, None, None)
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: type[BaseException] | None,
+     ) -> None:
+         """Re-enable terminal output."""
+         if self.disable:
+             return
+
+         # Restore stdout/stderr from our saved FDs
+         try:
+             if self._original_stdout_fd is not None:
+                 os.dup2(self._original_stdout_fd, sys.stdout.fileno())
+             if self._original_stderr_fd is not None:
+                 os.dup2(self._original_stderr_fd, sys.stderr.fileno())
+         except OSError:
+             self._log_windows_warning()
+         finally:
+             # Close the duplicated FDs we created
+             if self._original_stdout_fd is not None:
+                 os.close(self._original_stdout_fd)
+             if self._original_stderr_fd is not None:
+                 os.close(self._original_stderr_fd)
+
+             # Close the /dev/null file
+             if self.devnull_file is not None:
+                 self.devnull_file.close()
+
+
+ def adjust_logging_level(verbose: bool, ignore_testing: bool = False) -> int:
+     """Adjust the logging level based on verbosity.
+
+     Args:
+         verbose:
+             Whether to output additional output.
+         ignore_testing:
+             Whether to ignore the testing flag.
+
+     Returns:
+         The logging level that was set.
+     """
+     if hasattr(sys, "_called_from_test") and not ignore_testing:
+         logging_level = logging.CRITICAL
+     elif verbose:
+         logging_level = logging.DEBUG
+     else:
+         logging_level = logging.INFO
+     logger.setLevel(logging_level)
+     return logging_level
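
The new `euroeval/logging_utils.py` module centralises progress bars, coloured logging, and output suppression. The snippet below is a minimal usage sketch and is not part of the package diff; it only assumes the public helpers shown above (`block_terminal_output`, `adjust_logging_level`, `log_once`, `no_terminal_output`).

    # Minimal usage sketch (not from the diff); assumes the module layout shown above.
    import logging

    from euroeval.logging_utils import (
        adjust_logging_level,
        block_terminal_output,
        log_once,
        no_terminal_output,
    )

    block_terminal_output()              # silence third-party loggers and progress bars
    adjust_logging_level(verbose=False)  # sets the "euroeval" logger to INFO

    log_once("Starting evaluation")      # logged
    log_once("Starting evaluation")      # cached by message, so not logged again

    with no_terminal_output():
        print("this never reaches the terminal")  # stdout is redirected to /dev/null
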
euroeval/metrics/__init__.py (new file)
@@ -0,0 +1,6 @@
+ """All the metrics used in EuroEval."""
+
+ from .huggingface import *  # noqa: F403
+ from .llm_as_a_judge import *  # noqa: F403
+ from .pipeline import *  # noqa: F403
+ from .speed import *  # noqa: F403
euroeval/metrics/base.py (new file)
@@ -0,0 +1,85 @@
+ """The abstract base class for all metrics."""
+
+ import abc
+ import collections.abc as c
+ import typing as t
+
+ if t.TYPE_CHECKING:
+     from datasets.arrow_dataset import Dataset
+
+     from ..data_models import BenchmarkConfig, DatasetConfig
+
+
+ class Metric(abc.ABC):
+     """Abstract base class for all metrics."""
+
+     def __init__(
+         self,
+         name: str,
+         pretty_name: str,
+         postprocessing_fn: t.Callable[[float], tuple[float, str]] | None = None,
+     ) -> None:
+         """Initialise the metric.
+
+         Args:
+             name:
+                 The name of the metric in snake_case.
+             pretty_name:
+                 The pretty name of the metric, used for display purposes.
+             postprocessing_fn:
+                 A function to apply to the metric scores after they are computed,
+                 taking the score to the postprocessed score along with its string
+                 representation. Defaults to x -> (100 * x, f"{x:.2%}").
+         """
+         self.name = name
+         self.pretty_name = pretty_name
+         self.postprocessing_fn = (
+             postprocessing_fn
+             if postprocessing_fn is not None
+             else lambda x: (100 * x, f"{x:.2%}")
+         )
+
+     def download(self, cache_dir: str) -> "Metric":
+         """Initiates the download of the metric if needed.
+
+         Args:
+             cache_dir:
+                 The directory where the metric will be downloaded to.
+
+         Returns:
+             The metric object itself.
+         """
+         return self
+
+     @abc.abstractmethod
+     def __call__(
+         self,
+         predictions: c.Sequence,
+         references: c.Sequence,
+         dataset: "Dataset",
+         dataset_config: "DatasetConfig",
+         benchmark_config: "BenchmarkConfig",
+     ) -> float | None:
+         """Calculate the metric score.
+
+         Args:
+             predictions:
+                 The model predictions.
+             references:
+                 The ground truth references.
+             dataset:
+                 The dataset used for evaluation. This is only used in case any
+                 additional metadata is used to compute the metrics.
+             dataset_config:
+                 The dataset configuration.
+             benchmark_config:
+                 The benchmark configuration.
+
+         Returns:
+             The calculated metric score, or None if the score should be ignored.
+         """
+         ...
+
+     def __hash__(self) -> int:
+         """Return a hash of the metric configuration."""
+         return hash(self.name)
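
As a rough illustration of the contract this base class defines, the hypothetical subclass below implements `__call__` as a plain exact-match score. The class name and constructor arguments are invented for the example; only the `Metric` interface itself comes from the diff above.

    # Hypothetical subclass, for illustration only (not part of the package).
    import collections.abc as c
    import typing as t

    from euroeval.metrics.base import Metric

    if t.TYPE_CHECKING:
        from datasets.arrow_dataset import Dataset

        from euroeval.data_models import BenchmarkConfig, DatasetConfig


    class ExactMatchMetric(Metric):
        """Fraction of predictions that exactly equal their reference."""

        def __call__(
            self,
            predictions: c.Sequence,
            references: c.Sequence,
            dataset: "Dataset",
            dataset_config: "DatasetConfig",
            benchmark_config: "BenchmarkConfig",
        ) -> float | None:
            matches = sum(p == r for p, r in zip(predictions, references))
            return matches / len(references) if references else None


    # The default postprocessing_fn maps a raw score of 0.87 to (87.0, "87.00%").
    metric = ExactMatchMetric(name="exact_match", pretty_name="Exact match")
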
euroeval/metrics/huggingface.py (new file)
@@ -0,0 +1,216 @@
+ """All the Hugging Face metrics used in EuroEval."""
+
+ import collections.abc as c
+ import typing as t
+ from pathlib import Path
+
+ import evaluate
+ import numpy as np
+ from datasets import DownloadConfig, DownloadMode
+
+ from ..logging_utils import no_terminal_output
+ from .base import Metric
+
+ if t.TYPE_CHECKING:
+     from datasets.arrow_dataset import Dataset
+     from evaluate import EvaluationModule
+
+     from ..data_models import BenchmarkConfig, DatasetConfig
+
+
+ class HuggingFaceMetric(Metric):
+     """A metric which is implemented in the `evaluate` package.
+
+     Attributes:
+         name:
+             The name of the metric in snake_case.
+         pretty_name:
+             The pretty name of the metric, used for display purposes.
+         huggingface_id:
+             The Hugging Face ID of the metric.
+         results_key:
+             The name of the key used to extract the metric scores from the results
+             dictionary.
+         compute_kwargs:
+             Keyword arguments to pass to the metric's compute function. Defaults to
+             an empty dictionary.
+     """
+
+     def __init__(
+         self,
+         name: str,
+         pretty_name: str,
+         huggingface_id: str,
+         results_key: str,
+         compute_kwargs: dict[str, t.Any] | None = None,
+         postprocessing_fn: t.Callable[[float], tuple[float, str]] | None = None,
+     ) -> None:
+         """Initialise the Hugging Face metric.
+
+         Args:
+             name:
+                 The name of the metric in snake_case.
+             pretty_name:
+                 The pretty name of the metric, used for display purposes.
+             huggingface_id:
+                 The Hugging Face ID of the metric.
+             results_key:
+                 The name of the key used to extract the metric scores from the results
+                 dictionary.
+             compute_kwargs:
+                 Keyword arguments to pass to the metric's compute function. Defaults to
+                 an empty dictionary.
+             postprocessing_fn:
+                 A function to apply to the metric scores after they are computed, taking
+                 the score to the postprocessed score along with its string
+                 representation. Defaults to x -> (100 * x, f"{x:.2%}").
+         """
+         super().__init__(
+             name=name, pretty_name=pretty_name, postprocessing_fn=postprocessing_fn
+         )
+         self.huggingface_id = huggingface_id
+         self.results_key = results_key
+         self.compute_kwargs: dict[str, t.Any] = (
+             dict() if compute_kwargs is None else compute_kwargs
+         )
+         self.metric: "EvaluationModule | None" = None
+
+     def download(self, cache_dir: str) -> "HuggingFaceMetric":
+         """Initiates the download of the metric if needed.
+
+         Args:
+             cache_dir:
+                 The directory where the metric will be downloaded to.
+
+         Returns:
+             The metric object itself.
+         """
+         metric_cache_dir = Path(cache_dir) / "metrics"
+         download_config = DownloadConfig(cache_dir=metric_cache_dir)
+         self.metric = evaluate.load(
+             path=self.huggingface_id,
+             download_config=download_config,
+             download_mode=DownloadMode.REUSE_CACHE_IF_EXISTS,
+             cache_dir=metric_cache_dir.as_posix(),
+         )
+         return self
+
+     def __call__(
+         self,
+         predictions: c.Sequence,
+         references: c.Sequence,
+         dataset: "Dataset",
+         dataset_config: "DatasetConfig",
+         benchmark_config: "BenchmarkConfig",
+     ) -> float | None:
+         """Calculate the metric score.
+
+         Args:
+             predictions:
+                 The model predictions.
+             references:
+                 The ground truth references.
+             dataset:
+                 The dataset used for evaluation. This is only used in case any
+                 additional metadata is used to compute the metrics.
+             dataset_config:
+                 The dataset configuration.
+             benchmark_config:
+                 The benchmark configuration.
+
+         Returns:
+             The calculated metric score, or None if the score should be ignored.
+         """
+         if self.metric is None:
+             self.download(cache_dir=benchmark_config.cache_dir)
+
+         assert self.metric is not None, (
+             "Metric has not been downloaded. Please call download() before using the "
+             "__call__ method."
+         )
+
+         with no_terminal_output(disable=benchmark_config.verbose):
+             results = self.metric.compute(
+                 predictions=predictions, references=references, **self.compute_kwargs
+             )
+
+         # The metric returns None if we are running on multi-GPU and the current
+         # process is not the main process
+         if results is None:
+             return None
+
+         # Convert the results to a float score
+         score = results[self.results_key]
+         if isinstance(score, list):
+             score = sum(score) / len(score)
+         if isinstance(score, np.floating):
+             score = float(score)
+
+         return score
+
+
+ mcc_metric = HuggingFaceMetric(
+     name="mcc",
+     pretty_name="Matthew's Correlation Coefficient",
+     huggingface_id="matthews_correlation",
+     results_key="matthews_correlation",
+ )
+
+ macro_f1_metric = HuggingFaceMetric(
+     name="macro_f1",
+     pretty_name="Macro-average F1-score",
+     huggingface_id="f1",
+     results_key="f1",
+     compute_kwargs=dict(average="macro"),
+ )
+
+ micro_f1_metric = HuggingFaceMetric(
+     name="micro_f1",
+     pretty_name="Micro-average F1-score with MISC tags",
+     huggingface_id="seqeval",
+     results_key="overall_f1",
+ )
+
+ micro_f1_no_misc_metric = HuggingFaceMetric(
+     name="micro_f1_no_misc",
+     pretty_name="Micro-average F1-score without MISC tags",
+     huggingface_id="seqeval",
+     results_key="overall_f1",
+ )
+
+ f1_metric = HuggingFaceMetric(
+     name="f1",
+     pretty_name="F1-score",
+     huggingface_id="squad_v2",
+     results_key="f1",
+     postprocessing_fn=lambda x: (x, f"{x:.2f}%"),
+ )
+
+ em_metric = HuggingFaceMetric(
+     name="em",
+     pretty_name="Exact Match",
+     huggingface_id="squad_v2",
+     results_key="exact",
+     postprocessing_fn=lambda x: (x, f"{x:.2f}%"),
+ )
+
+ bert_score_metric = HuggingFaceMetric(
+     name="bertscore",
+     pretty_name="BERTScore",
+     huggingface_id="bertscore",
+     results_key="f1",
+     compute_kwargs=dict(
+         model_type="microsoft/mdeberta-v3-base", device="auto", batch_size=1
+     ),
+ )
+
+ rouge_l_metric = HuggingFaceMetric(
+     name="rouge_l", pretty_name="ROUGE-L", huggingface_id="rouge", results_key="rougeL"
+ )
+
+ accuracy_metric = HuggingFaceMetric(
+     name="accuracy",
+     pretty_name="Accuracy",
+     huggingface_id="accuracy",
+     results_key="accuracy",
+ )
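
To show how these `HuggingFaceMetric` instances are wired to the `evaluate` package, here is a small sketch that is not part of the diff. The `pearsonr` metric and the cache directory are assumptions chosen for the example; in a real benchmark run, `__call__` would receive EuroEval's dataset and config objects instead of the direct `compute()` call used here.

    # Illustrative sketch only; metric choice and cache directory are assumptions.
    from euroeval.metrics.huggingface import HuggingFaceMetric

    pearson_metric = HuggingFaceMetric(
        name="pearsonr",
        pretty_name="Pearson correlation",
        huggingface_id="pearsonr",
        results_key="pearsonr",
    )

    # download() wraps evaluate.load() and caches the module under "<cache_dir>/metrics".
    pearson_metric.download(cache_dir=".euroeval_cache")
    assert pearson_metric.metric is not None

    # Poke the underlying `evaluate` module directly, bypassing __call__.
    results = pearson_metric.metric.compute(
        predictions=[1.0, 2.0, 3.0], references=[1.0, 2.5, 2.8]
    )
    print(results["pearsonr"])
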