EuroEval 16.3.0-py3-none-any.whl → 16.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (64)
  1. euroeval/__init__.py +3 -2
  2. euroeval/benchmark_config_factory.py +0 -4
  3. euroeval/benchmark_modules/base.py +3 -16
  4. euroeval/benchmark_modules/fresh.py +2 -1
  5. euroeval/benchmark_modules/hf.py +99 -62
  6. euroeval/benchmark_modules/litellm.py +101 -41
  7. euroeval/benchmark_modules/vllm.py +91 -83
  8. euroeval/benchmarker.py +84 -78
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/constants.py +6 -0
  12. euroeval/data_loading.py +14 -11
  13. euroeval/data_models.py +12 -4
  14. euroeval/dataset_configs/__init__.py +2 -0
  15. euroeval/dataset_configs/czech.py +79 -0
  16. euroeval/dataset_configs/danish.py +10 -11
  17. euroeval/dataset_configs/dutch.py +0 -1
  18. euroeval/dataset_configs/english.py +0 -1
  19. euroeval/dataset_configs/estonian.py +11 -1
  20. euroeval/dataset_configs/finnish.py +0 -1
  21. euroeval/dataset_configs/french.py +0 -1
  22. euroeval/dataset_configs/german.py +0 -1
  23. euroeval/dataset_configs/italian.py +0 -1
  24. euroeval/dataset_configs/latvian.py +0 -1
  25. euroeval/dataset_configs/lithuanian.py +9 -3
  26. euroeval/dataset_configs/norwegian.py +0 -1
  27. euroeval/dataset_configs/polish.py +0 -1
  28. euroeval/dataset_configs/portuguese.py +0 -1
  29. euroeval/dataset_configs/slovak.py +60 -0
  30. euroeval/dataset_configs/spanish.py +0 -1
  31. euroeval/dataset_configs/swedish.py +10 -12
  32. euroeval/finetuning.py +21 -15
  33. euroeval/generation.py +10 -10
  34. euroeval/generation_utils.py +2 -3
  35. euroeval/logging_utils.py +250 -0
  36. euroeval/metrics/base.py +0 -3
  37. euroeval/metrics/huggingface.py +9 -5
  38. euroeval/metrics/llm_as_a_judge.py +5 -3
  39. euroeval/metrics/pipeline.py +17 -9
  40. euroeval/metrics/speed.py +0 -3
  41. euroeval/model_cache.py +11 -14
  42. euroeval/model_config.py +4 -5
  43. euroeval/model_loading.py +3 -0
  44. euroeval/prompt_templates/linguistic_acceptability.py +21 -3
  45. euroeval/prompt_templates/multiple_choice.py +25 -1
  46. euroeval/prompt_templates/named_entity_recognition.py +51 -11
  47. euroeval/prompt_templates/reading_comprehension.py +31 -3
  48. euroeval/prompt_templates/sentiment_classification.py +23 -1
  49. euroeval/prompt_templates/summarization.py +26 -6
  50. euroeval/scores.py +7 -7
  51. euroeval/speed_benchmark.py +3 -5
  52. euroeval/task_group_utils/multiple_choice_classification.py +0 -3
  53. euroeval/task_group_utils/question_answering.py +0 -3
  54. euroeval/task_group_utils/sequence_classification.py +43 -31
  55. euroeval/task_group_utils/text_to_text.py +17 -8
  56. euroeval/task_group_utils/token_classification.py +10 -9
  57. euroeval/tokenisation_utils.py +14 -12
  58. euroeval/utils.py +29 -146
  59. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/METADATA +4 -4
  60. euroeval-16.4.0.dist-info/RECORD +75 -0
  61. euroeval-16.3.0.dist-info/RECORD +0 -71
  62. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
  63. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
  64. {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/benchmarker.py CHANGED
@@ -1,10 +1,11 @@
 """Class that benchmarks language models."""
 
 import contextlib
+import datetime as dt
 import json
 import logging
+import os
 import re
-import sys
 import typing as t
 from pathlib import Path
 from shutil import rmtree
@@ -12,7 +13,6 @@ from time import sleep
 
 from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER
 from torch.distributed import destroy_process_group
-from tqdm.auto import tqdm
 
 from .benchmark_config_factory import build_benchmark_config
 from .constants import GENERATIVE_PIPELINE_TAGS
@@ -23,6 +23,7 @@ from .enums import Device, GenerativeType, ModelType
 from .exceptions import HuggingFaceHubDown, InvalidBenchmark, InvalidModel
 from .finetuning import finetune
 from .generation import generate
+from .logging_utils import adjust_logging_level, get_pbar, log, log_once
 from .model_config import get_model_config
 from .model_loading import load_model
 from .scores import log_scores
@@ -32,7 +33,6 @@ from .utils import (
     enforce_reproducibility,
     get_package_version,
     internet_connection_available,
-    log_once,
     split_model_id,
 )
 
@@ -41,9 +41,6 @@ if t.TYPE_CHECKING:
     from .data_models import BenchmarkConfig, DatasetConfig, ModelConfig
 
 
-logger = logging.getLogger("euroeval")
-
-
 class Benchmarker:
     """Benchmarking all the language models.
 
@@ -200,6 +197,10 @@ class Benchmarker:
                 "Try installing it with `pip install hf_transfer`."
             )
 
+        # If FULL_LOG has been set, then force verbose mode
+        if os.getenv("FULL_LOG", "0") == "1":
+            verbose = True
+
         self.benchmark_config_default_params = BenchmarkConfigParams(
             task=task,
             dataset=dataset,
@@ -301,7 +302,6 @@ class Benchmarker:
         )
         del dataset
 
-        log_once(f"Loading model {model_config.model_id}", level=logging.INFO)
         model = load_model(
             model_config=model_config,
             dataset_config=dataset_config,
@@ -611,7 +611,7 @@
 
         # Get all the model configs
         model_configs: list[ModelConfig] = list()
-        for model_id in tqdm(
+        for model_id in get_pbar(
             iterable=model_ids,
             desc="Fetching model configurations",
             disable=not benchmark_config.verbose or not benchmark_config.progress_bar,
@@ -622,7 +622,7 @@
                 )
                 model_configs.append(model_config)
             except InvalidModel as e:
-                logger.info(e.message)
+                log(e.message, level=logging.ERROR)
 
         # Create a dictionary that takes each model config to the dataset configs that
         # we need to benchmark the model on. Here we remove the datasets that the model
@@ -651,21 +651,22 @@
             for dataset_configs in model_config_to_dataset_configs.values()
         )
         if total_benchmarks == 0:
-            logger.info(
+            log(
                 "No benchmarks to run, as all the selected models have already been "
-                "benchmarked on all the selected datasets."
+                "benchmarked on all the selected datasets.",
+                level=logging.INFO,
             )
             return list()
 
-        logger.info(f"Initiated evaluation of {total_benchmarks:,} benchmarks.")
-
         num_finished_benchmarks = 0
         current_benchmark_results: list[BenchmarkResult] = list()
+        benchmark_params_to_revert: dict[str, t.Any] = dict()
         for model_config in model_configs:
            if not model_config_to_dataset_configs[model_config]:
-                logger.debug(
+                log(
                     f"Skipping model {model_config.model_id!r} because it has "
-                    "already been benchmarked on all valid datasets."
+                    "already been benchmarked on all valid datasets.",
+                    level=logging.DEBUG,
                 )
                 continue
 
@@ -691,7 +692,6 @@
             )
 
             loaded_model: BenchmarkModule | None = None
-            benchmark_params_to_revert: dict[str, t.Any] = dict()
             for dataset_config in model_config_to_dataset_configs[model_config]:
                 # Revert any changes to the benchmark configuration made for the
                 # previous dataset
@@ -704,18 +704,20 @@
                     "val" not in dataset_config.splits
                     and not benchmark_config.evaluate_test_split
                 ):
-                    logger.debug(
+                    log(
                         "The dataset does not have a validation split, so even though "
                         "you requested evaluating the validation split (the default), "
-                        "we will evaluate on the test split."
+                        "we will evaluate on the test split.",
+                        level=logging.DEBUG,
                     )
                     benchmark_params_to_revert["evaluate_test_split"] = False
                     benchmark_config.evaluate_test_split = True
                 if dataset_config.task.requires_zero_shot and benchmark_config.few_shot:
-                    logger.debug(
+                    log(
                         "The task requires zero-shot evaluation, so even though you "
                         "requested few-shot evaluation (the default), we will evaluate "
-                        "zero-shot."
+                        "zero-shot.",
+                        level=logging.DEBUG,
                     )
                     benchmark_params_to_revert["few_shot"] = True
                     benchmark_config.few_shot = False
@@ -723,13 +725,7 @@
                 # We do not re-initialise generative models as their architecture is not
                 # customised to specific datasets
                 if model_config.model_type == ModelType.GENERATIVE:
-                    initial_logging(
-                        model_config=model_config,
-                        dataset_config=dataset_config,
-                        benchmark_config=benchmark_config,
-                    )
                     if loaded_model is None:
-                        logger.info("Loading model...")
                         try:
                             loaded_model = load_model(
                                 model_config=model_config,
@@ -739,7 +735,7 @@
                         except InvalidModel as e:
                             if benchmark_config.raise_errors:
                                 raise e
-                            logger.info(e.message)
+                            log(e.message, level=logging.ERROR)
 
                             # Add the remaining number of benchmarks for the model to
                             # our benchmark counter, since we're skipping the rest of
@@ -759,12 +755,13 @@
                     loaded_model.generative_type
                     not in dataset_config.allowed_generative_types
                 ):
-                    logger.debug(
+                    log(
                         f"Skipping the benchmark of model "
                         f"{model_config.model_id!r}on dataset "
                         f"{dataset_config.name!r} because the model has generative "
                         f"type {loaded_model.generative_type} and the dataset "
-                        f"only allows {dataset_config.allowed_generative_types}."
+                        f"only allows {dataset_config.allowed_generative_types}.",
+                        level=logging.DEBUG,
                     )
                     num_finished_benchmarks += 1
                     continue
@@ -775,6 +772,8 @@
                     model_config=model_config,
                     dataset_config=dataset_config,
                     benchmark_config=benchmark_config,
+                    num_finished_benchmarks=num_finished_benchmarks,
+                    num_total_benchmarks=total_benchmarks,
                 )
 
                 if (
@@ -784,12 +783,12 @@
                     raise benchmark_output_or_err
 
                 elif isinstance(benchmark_output_or_err, InvalidBenchmark):
-                    logger.info(benchmark_output_or_err.message)
+                    log(benchmark_output_or_err.message, level=logging.WARNING)
                     num_finished_benchmarks += 1
                     continue
 
                 elif isinstance(benchmark_output_or_err, InvalidModel):
-                    logger.info(benchmark_output_or_err.message)
+                    log(benchmark_output_or_err.message, level=logging.WARNING)
 
                     # Add the remaining number of benchmarks for the model to our
                     # benchmark counter, since we're skipping the rest of them
@@ -805,15 +804,13 @@
                 record.append_to_results(results_path=self.results_path)
 
                 num_finished_benchmarks += 1
-                logger.info(
-                    f"Finished {num_finished_benchmarks} out of "
-                    f"{total_benchmarks} benchmarks."
-                )
 
             del loaded_model
             if benchmark_config.clear_model_cache:
                 clear_model_cache_fn(cache_dir=benchmark_config.cache_dir)
 
+        log(f"Completed {num_finished_benchmarks:,} benchmarks.\n", level=logging.INFO)
+
         # This avoids the following warning at the end of the benchmarking:
         # Warning: WARNING: process group has NOT been destroyed before we destruct
         # ProcessGroupNCCL. On normal program exit, the application should call
@@ -857,6 +854,8 @@
         model_config: "ModelConfig",
         dataset_config: "DatasetConfig",
         benchmark_config: "BenchmarkConfig",
+        num_finished_benchmarks: int,
+        num_total_benchmarks: int,
     ) -> BenchmarkResult | InvalidBenchmark | InvalidModel:
         """Benchmark a single model on a single dataset.
 
@@ -869,25 +868,29 @@
                 The configuration of the dataset we are evaluating on.
             benchmark_config:
                 The general benchmark configuration.
+            num_finished_benchmarks:
+                The number of benchmarks that have already been completed.
+            num_total_benchmarks:
+                The total number of benchmarks to be completed.
 
         Returns:
             The benchmark result, or an error if the benchmark was unsuccessful.
-        """
-        if model is None:
-            initial_logging(
-                model_config=model_config,
-                dataset_config=dataset_config,
-                benchmark_config=benchmark_config,
-            )
 
-        while True:
+        Raises:
+            RuntimeError:
+                If the MPS fallback is not enabled when required.
+            InvalidBenchmark:
+                If the benchmark was unsuccessful.
+            InvalidModel:
+                If the model is invalid.
+        """
+        for _ in range(num_attempts := 5):
            try:
                # Set random seeds to enforce reproducibility of the randomly
                # initialised weights
                rng = enforce_reproducibility()
 
                if model is None or model_config.model_type != ModelType.GENERATIVE:
-                    logger.info("Loading model...")
                     model = load_model(
                         model_config=model_config,
                         dataset_config=dataset_config,
@@ -895,6 +898,14 @@
                     )
                     assert model is not None
 
+                initial_logging(
+                    model_config=model_config,
+                    dataset_config=dataset_config,
+                    benchmark_config=benchmark_config,
+                    num_finished_benchmarks=num_finished_benchmarks,
+                    num_total_benchmarks=num_total_benchmarks,
+                )
+
                 if dataset_config.task == SPEED:
                     scores = benchmark_speed(
                         model=model, benchmark_config=benchmark_config
@@ -962,14 +973,15 @@
                     few_shot=benchmark_config.few_shot,
                     validation_split=not benchmark_config.evaluate_test_split,
                 )
-                logger.debug(f"Results:\n{results}")
+                log(f"Results:\n{results}", level=logging.DEBUG)
                 return record
 
             except HuggingFaceHubDown:
                 wait_time = 30
-                logger.debug(
+                log(
                     f"The Hugging Face Hub seems to be down. Retrying in {wait_time} "
-                    "seconds."
+                    "seconds.",
+                    level=logging.DEBUG,
                 )
                 sleep(wait_time)
                 continue
@@ -992,12 +1004,18 @@
                 elif benchmark_config.raise_errors:
                     raise e
                 return e
+        else:
+            return InvalidBenchmark(
+                f"Failed to benchmark model {model_config.model_id!r} on dataset "
+                f"{dataset_config.name!r} after {num_attempts} attempts."
+            )
 
     def __call__(self, *args: t.Any, **kwds: t.Any) -> t.Any:  # noqa: ANN401
         """Alias for `self.benchmark()`."""
-        logger.warning(
+        log(
             "Calling the `Benchmarker` class directly is deprecated. Please use the "
-            "`benchmark` function instead. This will be removed in a future version."
+            "`benchmark` function instead. This will be removed in a future version.",
+            level=logging.WARNING,
         )
         return self.benchmark(*args, **kwds)
 
@@ -1050,28 +1068,6 @@ def model_has_been_benchmarked(
     return False
 
 
-def adjust_logging_level(verbose: bool, ignore_testing: bool = False) -> int:
-    """Adjust the logging level based on verbosity.
-
-    Args:
-        verbose:
-            Whether to output additional output.
-        ignore_testing:
-            Whether to ignore the testing flag.
-
-    Returns:
-        The logging level that was set.
-    """
-    if hasattr(sys, "_called_from_test") and not ignore_testing:
-        logging_level = logging.CRITICAL
-    elif verbose:
-        logging_level = logging.DEBUG
-    else:
-        logging_level = logging.INFO
-    logger.setLevel(logging_level)
-    return logging_level
-
-
 def clear_model_cache_fn(cache_dir: str) -> None:
     """Clear the model cache.
 
@@ -1109,6 +1105,8 @@ def initial_logging(
     model_config: "ModelConfig",
     dataset_config: "DatasetConfig",
     benchmark_config: "BenchmarkConfig",
+    num_finished_benchmarks: int,
+    num_total_benchmarks: int,
 ) -> None:
     """Initial logging at the start of the benchmarking process.
 
@@ -1119,6 +1117,10 @@
             The configuration of the dataset we are evaluating on.
         benchmark_config:
             The general benchmark configuration.
+        num_finished_benchmarks:
+            The number of benchmarks that have already been finished.
+        num_total_benchmarks:
+            The total number of benchmarks to be run.
     """
     model_id = model_config.model_id
     if model_config.revision and model_config.revision != "main":
@@ -1135,21 +1137,25 @@
     else:
         eval_type = "Benchmarking"
 
-    logger.info(
-        f"{eval_type} {model_id} on the {split_type} split of "
-        f"{dataset_config.pretty_name}"
+    log_once(
+        f"\n{eval_type} {model_id} on the {split_type} split of "
+        f"{dataset_config.pretty_name} ({num_finished_benchmarks + 1}/"
+        f"{num_total_benchmarks} benchmarks)...",
+        prefix=f"\n[{dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]",
    )
 
    if dataset_config.unofficial:
-        logger.info(
+        log_once(
            f"Note that the {dataset_config.name!r} dataset is unofficial, "
            "meaning that the resulting evaluation will not be included in the "
-            "official leaderboard."
+            "official leaderboard.",
+            level=logging.WARNING,
        )
 
    if benchmark_config.debug:
-        logger.info(
+        log_once(
            "Running in debug mode. This will output additional information, as "
            "well as store the model outputs in the current directory after each "
-            "batch. For this reason, evaluation will be slower."
+            "batch. For this reason, evaluation will be slower.",
+            level=logging.WARNING,
        )
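
Note on the change above: the single-benchmark method switches from an unbounded `while True:` loop to `for _ in range(num_attempts := 5)` with a `for`/`else` clause, so a persistently failing benchmark now gives up with an `InvalidBenchmark` instead of looping forever. A minimal, self-contained sketch of that pattern, with illustrative names (`call_with_retries`, `TransientError`) that are not part of EuroEval:

import time


class TransientError(Exception):
    """Illustrative error type that is worth retrying."""


def call_with_retries(fetch, max_attempts: int = 5):
    """Call `fetch` up to `max_attempts` times, then give up."""
    for attempt in range(max_attempts):
        try:
            return fetch()
        except TransientError:
            time.sleep(2**attempt)  # back off before the next attempt
    else:
        # Runs only if the loop finished without returning, i.e. all attempts failed.
        raise RuntimeError(f"Giving up after {max_attempts} attempts.")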
euroeval/caching_utils.py ADDED
@@ -0,0 +1,79 @@
+"""Caching utility functions."""
+
+import typing as t
+from functools import wraps
+
+from .constants import T
+
+
+def cache_arguments(
+    *arguments: str, disable_condition: t.Callable[[], bool] = lambda: False
+) -> t.Callable[[t.Callable[..., T]], t.Callable[..., T]]:
+    """Cache specified arguments of a function.
+
+    Args:
+        arguments:
+            The list of argument names to cache. If empty, all arguments are cached.
+        disable_condition:
+            A function that checks if cache should be disabled.
+
+    Returns:
+        A decorator that caches the specified arguments of a function.
+    """
+
+    def caching_decorator(func: t.Callable[..., T]) -> t.Callable[..., T]:
+        """Decorator that caches the specified arguments of a function.
+
+        Args:
+            func:
+                The function to decorate.
+
+        Returns:
+            The decorated function.
+        """
+        cache: dict[tuple, T] = dict()
+
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> T:
+            """Wrapper function that caches the specified arguments.
+
+            Args:
+                *args:
+                    The positional arguments to the function.
+                **kwargs:
+                    The keyword arguments to the function.
+
+            Returns:
+                The result of the function.
+
+            Raises:
+                ValueError:
+                    If an argument name is not found in the function parameters.
+            """
+            if not arguments:
+                key = args + tuple(kwargs[k] for k in sorted(kwargs.keys()))
+            else:
+                func_params = func.__code__.co_varnames
+                key_items: list[t.Any] = []
+                for arg_name in arguments:
+                    if arg_name in kwargs:
+                        key_items.append(kwargs[arg_name])
+                    else:
+                        try:
+                            arg_index = func_params.index(arg_name)
+                            key_items.append(args[arg_index])
+                        except (ValueError, IndexError):
+                            raise ValueError(
+                                f"Argument {arg_name} not found in function "
+                                f"{func.__name__} parameters."
+                            )
+                key = tuple(key_items)
+
+            # Do not cache if the condition is met
+            if key not in cache or disable_condition():
+                cache[key] = func(*args, **kwargs)
+            return cache[key]
+
+        return wrapper
+
+    return caching_decorator
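
Note on the new module: `cache_arguments` memoises a function on a chosen subset of its arguments, and only hashable values can form the cache key (which is also why `DatasetConfig.id2label` now returns a hashable mapping, see `data_models.py` below). A small usage sketch; the decorated `label_mapping` function and its values are hypothetical:

from euroeval.caching_utils import cache_arguments


@cache_arguments("dataset_name")
def label_mapping(dataset_name: str, verbose: bool = False) -> dict[int, str]:
    """Hypothetical helper: expensive work keyed only on `dataset_name`."""
    if verbose:
        print(f"Computing label mapping for {dataset_name!r}...")
    return {0: "negative", 1: "positive"}


label_mapping("angry-tweets", verbose=True)   # computed and cached
label_mapping("angry-tweets", verbose=False)  # served from the cache, no recompute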
euroeval/callbacks.py CHANGED
@@ -7,6 +7,8 @@ from collections.abc import Sized
 from tqdm.auto import tqdm
 from transformers.trainer_callback import ProgressCallback
 
+from .logging_utils import get_pbar
+
 if t.TYPE_CHECKING:
     from torch.utils.data import DataLoader
     from transformers.trainer_callback import TrainerControl, TrainerState
@@ -32,11 +34,8 @@ class NeverLeaveProgressCallback(ProgressCallback):
         """Callback actions when training begins."""
         if state.is_local_process_zero:
             desc = "Finetuning model"
-            self.training_bar = tqdm(
-                total=None,
-                leave=False,
-                desc=desc,
-                disable=hasattr(sys, "_called_from_test"),
+            self.training_bar = get_pbar(
+                total=None, desc=desc, disable=hasattr(sys, "_called_from_test")
             )
             self.current_step = 0
 
@@ -67,9 +66,8 @@ class NeverLeaveProgressCallback(ProgressCallback):
         if state.is_local_process_zero and correct_dtype:
             if self.prediction_bar is None:
                 desc = "Evaluating model"
-                self.prediction_bar = tqdm(
+                self.prediction_bar = get_pbar(
                     total=len(eval_dataloader),
-                    leave=False,
                     desc=desc,
                     disable=hasattr(sys, "_called_from_test"),
                 )
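
Note: both progress bars are now created via `get_pbar` from the new `euroeval/logging_utils.py`, which is not included in this excerpt. Judging only from the call sites above (it accepts `iterable`, `total`, `desc` and `disable`, and the explicit `leave=False` disappears), a thin wrapper along these lines would fit; treat this as a guess rather than the actual implementation:

import typing as t

from tqdm.auto import tqdm


def get_pbar(
    iterable: t.Iterable | None = None,
    total: int | None = None,
    desc: str | None = None,
    disable: bool = False,
) -> tqdm:
    """Return a progress bar with the project-wide defaults applied."""
    return tqdm(iterable=iterable, total=total, desc=desc, disable=disable, leave=False)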
euroeval/constants.py CHANGED
@@ -1,7 +1,13 @@
 """Constants used throughout the project."""
 
+from typing import TypeVar
+
 from .enums import TaskGroup
 
+# Type variable used for generic typing
+T = TypeVar("T", bound=object)
+
+
 # This is used as input to generative models; it cannot be a special token
 DUMMY_FILL_VALUE = 100
 
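Note: the new module-level `T = TypeVar("T", bound=object)` lets generic helpers such as the `cache_arguments` decorator tell type checkers that they return whatever type the wrapped callable returns. A minimal illustration; the `first_item` function is only an example, not EuroEval code:

from euroeval.constants import T


def first_item(items: list[T]) -> T:
    """Return the first element while preserving its type for type checkers."""
    return items[0]


value: int = first_item([1, 2, 3])  # a type checker infers T = int here
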
euroeval/data_loading.py CHANGED
@@ -12,6 +12,7 @@ from huggingface_hub.errors import HfHubHTTPError
 from numpy.random import Generator
 
 from .exceptions import HuggingFaceHubDown, InvalidBenchmark
+from .logging_utils import log, no_terminal_output
 from .tasks import EUROPEAN_VALUES
 from .utils import unscramble
 
@@ -20,8 +21,6 @@ if t.TYPE_CHECKING:
     from .data_models import BenchmarkConfig, DatasetConfig
 
 
-logger = logging.getLogger("euroeval")
-
 
 def load_data(
     rng: Generator, dataset_config: "DatasetConfig", benchmark_config: "BenchmarkConfig"
@@ -106,11 +105,12 @@ def load_raw_data(dataset_config: "DatasetConfig", cache_dir: str) -> "DatasetDi
     num_attempts = 5
     for _ in range(num_attempts):
         try:
-            dataset = load_dataset(
-                path=dataset_config.huggingface_id,
-                cache_dir=cache_dir,
-                token=unscramble("HjccJFhIozVymqXDVqTUTXKvYhZMTbfIjMxG_"),
-            )
+            with no_terminal_output():
+                dataset = load_dataset(
+                    path=dataset_config.huggingface_id,
+                    cache_dir=cache_dir,
+                    token=unscramble("XbjeOLhwebEaSaDUMqqaPaPIhgOcyOfDpGnX_"),
+                )
             break
         except (
             FileNotFoundError,
@@ -118,9 +118,11 @@ def load_raw_data(dataset_config: "DatasetConfig", cache_dir: str) -> "DatasetDi
             DatasetsError,
             requests.ConnectionError,
             requests.ReadTimeout,
-        ):
-            logger.debug(
-                f"Failed to load dataset {dataset_config.huggingface_id!r}. Retrying..."
+        ) as e:
+            log(
+                f"Failed to load dataset {dataset_config.huggingface_id!r}, due to "
+                f"the following error: {e}. Retrying...",
+                level=logging.DEBUG,
             )
             time.sleep(1)
             continue
@@ -129,7 +131,8 @@ def load_raw_data(dataset_config: "DatasetConfig", cache_dir: str) -> "DatasetDi
     else:
         raise InvalidBenchmark(
             f"Failed to load dataset {dataset_config.huggingface_id!r} after "
-            f"{num_attempts} attempts."
+            f"{num_attempts} attempts. Run with verbose mode to see the individual "
+            "errors."
         )
     assert isinstance(dataset, DatasetDict)  # type: ignore[used-before-def]
     missing_keys = [key for key in dataset_config.splits if key not in dataset]
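
Note: `load_dataset` is now wrapped in `no_terminal_output()`, another helper from the new `euroeval/logging_utils.py` that is not shown in this diff. Its real implementation is not visible here; a context manager with that behaviour could look roughly like this sketch, which silences stdout and stderr for the duration of the block:

import contextlib
import io
import typing as t


@contextlib.contextmanager
def no_terminal_output() -> t.Iterator[None]:
    """Temporarily redirect stdout and stderr to an in-memory buffer."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):
        yield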
euroeval/data_models.py CHANGED
@@ -558,14 +558,14 @@ class DatasetConfig:
         )
 
     @property
-    def id2label(self) -> dict[int, str]:
+    def id2label(self) -> "HashableDict":
         """The mapping from ID to label."""
-        return {idx: label for idx, label in enumerate(self.labels)}
+        return HashableDict({idx: label for idx, label in enumerate(self.labels)})
 
     @property
-    def label2id(self) -> dict[str, int]:
+    def label2id(self) -> "HashableDict":
         """The mapping from label to ID."""
-        return {label: i for i, label in enumerate(self.labels)}
+        return HashableDict({label: i for i, label in enumerate(self.labels)})
 
     @property
     def num_labels(self) -> int:
@@ -783,3 +783,11 @@ class ModelIdComponents:
     model_id: str
     revision: str
     param: str | None
+
+
+class HashableDict(dict):
+    """A hashable dictionary."""
+
+    def __hash__(self) -> int:  # type: ignore[override]
+        """Return the hash of the dictionary."""
+        return hash(frozenset(self.items()))
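
Note: `id2label` and `label2id` now return the new `HashableDict` instead of a plain `dict`. Plain dicts are unhashable, so they cannot be part of a cache key such as the ones built by `cache_arguments`; hashing the frozenset of items makes two mappings with the same contents hash alike. A short demonstration using the class exactly as defined above:

class HashableDict(dict):
    """A hashable dictionary."""

    def __hash__(self) -> int:  # type: ignore[override]
        """Return the hash of the dictionary."""
        return hash(frozenset(self.items()))


mapping = HashableDict({0: "negative", 1: "positive"})
lookup = {mapping: "seen"}  # usable as a dict key, unlike a plain dict
assert hash(mapping) == hash(HashableDict({1: "positive", 0: "negative"}))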
euroeval/dataset_configs/__init__.py CHANGED
@@ -3,6 +3,7 @@
 from ..data_models import DatasetConfig
 from ..languages import get_all_languages
 from ..tasks import SPEED
+from .czech import *  # noqa: F403
 from .danish import *  # noqa: F403
 from .dutch import *  # noqa: F403
 from .english import *  # noqa: F403
@@ -18,6 +19,7 @@ from .lithuanian import * # noqa: F403
 from .norwegian import *  # noqa: F403
 from .polish import *  # noqa: F403
 from .portuguese import *  # noqa: F403
+from .slovak import *  # noqa: F403
 from .spanish import *  # noqa: F403
 from .swedish import *  # noqa: F403
 