EuroEval: 16.3.0-py3-none-any.whl → 16.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +3 -2
- euroeval/benchmark_config_factory.py +0 -4
- euroeval/benchmark_modules/base.py +3 -16
- euroeval/benchmark_modules/fresh.py +2 -1
- euroeval/benchmark_modules/hf.py +99 -62
- euroeval/benchmark_modules/litellm.py +101 -41
- euroeval/benchmark_modules/vllm.py +91 -83
- euroeval/benchmarker.py +84 -78
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/constants.py +6 -0
- euroeval/data_loading.py +14 -11
- euroeval/data_models.py +12 -4
- euroeval/dataset_configs/__init__.py +2 -0
- euroeval/dataset_configs/czech.py +79 -0
- euroeval/dataset_configs/danish.py +10 -11
- euroeval/dataset_configs/dutch.py +0 -1
- euroeval/dataset_configs/english.py +0 -1
- euroeval/dataset_configs/estonian.py +11 -1
- euroeval/dataset_configs/finnish.py +0 -1
- euroeval/dataset_configs/french.py +0 -1
- euroeval/dataset_configs/german.py +0 -1
- euroeval/dataset_configs/italian.py +0 -1
- euroeval/dataset_configs/latvian.py +0 -1
- euroeval/dataset_configs/lithuanian.py +9 -3
- euroeval/dataset_configs/norwegian.py +0 -1
- euroeval/dataset_configs/polish.py +0 -1
- euroeval/dataset_configs/portuguese.py +0 -1
- euroeval/dataset_configs/slovak.py +60 -0
- euroeval/dataset_configs/spanish.py +0 -1
- euroeval/dataset_configs/swedish.py +10 -12
- euroeval/finetuning.py +21 -15
- euroeval/generation.py +10 -10
- euroeval/generation_utils.py +2 -3
- euroeval/logging_utils.py +250 -0
- euroeval/metrics/base.py +0 -3
- euroeval/metrics/huggingface.py +9 -5
- euroeval/metrics/llm_as_a_judge.py +5 -3
- euroeval/metrics/pipeline.py +17 -9
- euroeval/metrics/speed.py +0 -3
- euroeval/model_cache.py +11 -14
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/linguistic_acceptability.py +21 -3
- euroeval/prompt_templates/multiple_choice.py +25 -1
- euroeval/prompt_templates/named_entity_recognition.py +51 -11
- euroeval/prompt_templates/reading_comprehension.py +31 -3
- euroeval/prompt_templates/sentiment_classification.py +23 -1
- euroeval/prompt_templates/summarization.py +26 -6
- euroeval/scores.py +7 -7
- euroeval/speed_benchmark.py +3 -5
- euroeval/task_group_utils/multiple_choice_classification.py +0 -3
- euroeval/task_group_utils/question_answering.py +0 -3
- euroeval/task_group_utils/sequence_classification.py +43 -31
- euroeval/task_group_utils/text_to_text.py +17 -8
- euroeval/task_group_utils/token_classification.py +10 -9
- euroeval/tokenisation_utils.py +14 -12
- euroeval/utils.py +29 -146
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/METADATA +4 -4
- euroeval-16.4.0.dist-info/RECORD +75 -0
- euroeval-16.3.0.dist-info/RECORD +0 -71
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/WHEEL +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.4.0.dist-info}/licenses/LICENSE +0 -0
euroeval/dataset_configs/czech.py
ADDED
@@ -0,0 +1,79 @@
+"""All Czech dataset configurations used in EuroEval."""
+
+from ..data_models import DatasetConfig
+from ..languages import CS
+from ..tasks import COMMON_SENSE, KNOW, LA, NER, RC, SENT, SUMM
+
+### Official datasets ###
+
+CSFD_SENTIMENT_CONFIG = DatasetConfig(
+    name="csfd-sentiment",
+    pretty_name="the truncated version of the Czech sentiment classification dataset "
+    "CSFD Sentiment",
+    huggingface_id="EuroEval/csfd-sentiment-mini",
+    task=SENT,
+    languages=[CS],
+)
+
+CS_GEC_CONFIG = DatasetConfig(
+    name="cs-gec",
+    pretty_name="the truncated version of the Czech linguistic acceptability dataset "
+    "CS-GEC",
+    huggingface_id="EuroEval/cs-gec-mini",
+    task=LA,
+    languages=[CS],
+)
+
+PONER_CONFIG = DatasetConfig(
+    name="poner",
+    pretty_name="the truncated version of the Czech named entity recognition dataset "
+    "PONER",
+    huggingface_id="EuroEval/poner-mini",
+    task=NER,
+    languages=[CS],
+)
+
+SQAD_CONFIG = DatasetConfig(
+    name="sqad",
+    pretty_name="the truncated version of the Czech reading comprehension dataset SQAD",
+    huggingface_id="EuroEval/sqad-mini",
+    task=RC,
+    languages=[CS],
+)
+
+CZECH_NEWS_CONFIG = DatasetConfig(
+    name="czech-news",
+    pretty_name="the truncated version of the Czech summarisation dataset",
+    huggingface_id="EuroEval/czech-news-mini",
+    task=SUMM,
+    languages=[CS],
+)
+
+UMIMETO_QA_CONFIG = DatasetConfig(
+    name="umimeto-qa",
+    pretty_name="the Czech knowledge dataset UmimetoQA",
+    huggingface_id="EuroEval/umimeto-qa",
+    task=KNOW,
+    languages=[CS],
+)
+
+HELLASWAG_CS_CONFIG = DatasetConfig(
+    name="hellaswag-cs",
+    pretty_name="the truncated version of the Czech common-sense reasoning dataset "
+    "HellaSwag-cs, translated from the English HellaSwag dataset",
+    huggingface_id="EuroEval/hellaswag-cs-mini",
+    task=COMMON_SENSE,
+    languages=[CS],
+)
+
+
+### Unofficial datasets ###
+
+SCALA_CS_CONFIG = DatasetConfig(
+    name="scala-cs",
+    pretty_name="the Czech part of the linguistic acceptability dataset ScaLA",
+    huggingface_id="EuroEval/scala-cs",
+    task=LA,
+    languages=[CS],
+    unofficial=True,
+)
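For context: each of the new Czech configurations above is selected by its `name` field, and euroeval/dataset_configs/__init__.py (+2 lines in this release) presumably registers the new czech and slovak modules. A minimal usage sketch, assuming the `Benchmarker` entry point from euroeval/benchmarker.py accepts `model` and `dataset` arguments (its exact signature is not part of this diff):

# Hypothetical usage sketch; the Benchmarker call signature is assumed,
# not shown in this diff.
from euroeval import Benchmarker

benchmarker = Benchmarker()

# "csfd-sentiment" and "sqad" are the `name` fields of two of the new
# Czech dataset configurations added above; the model ID is only an example.
benchmarker.benchmark(
    model="google/gemma-2-2b-it",
    dataset=["csfd-sentiment", "sqad"],
)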
euroeval/dataset_configs/danish.py
CHANGED
@@ -32,11 +32,11 @@ DANSK_CONFIG = DatasetConfig(
     languages=[DA],
 )
 
-
-    name="
-    pretty_name="the
-    "dataset
-    huggingface_id="EuroEval/
+MULTI_WIKI_QA_DA_CONFIG = DatasetConfig(
+    name="multi-wiki-qa-da",
+    pretty_name="the truncated version of the Danish part of the reading "
+    "comprehension dataset MultiWikiQA",
+    huggingface_id="EuroEval/multi-wiki-qa-da-mini",
     task=RC,
     languages=[DA],
 )
@@ -129,11 +129,11 @@ BELEBELE_DA_CONFIG = DatasetConfig(
     unofficial=True,
 )
 
-
-    name="
-    pretty_name="the
-    "
-    huggingface_id="EuroEval/
+SCANDIQA_DA_CONFIG = DatasetConfig(
+    name="scandiqa-da",
+    pretty_name="the Danish part of the truncated version of the question answering "
+    "dataset ScandiQA",
+    huggingface_id="EuroEval/scandiqa-da-mini",
     task=RC,
     languages=[DA],
     unofficial=True,
@@ -156,7 +156,6 @@ WINOGRANDE_DA_CONFIG = DatasetConfig(
     huggingface_id="EuroEval/winogrande-da",
     task=COMMON_SENSE,
     languages=[DA],
-    splits=["train", "test"],
     _labels=["a", "b"],
     unofficial=True,
 )
euroeval/dataset_configs/estonian.py
CHANGED
@@ -94,10 +94,20 @@ SCALA_ET_CONFIG = DatasetConfig(
 
 EXAM_ET_CONFIG = DatasetConfig(
     name="exam-et",
-    pretty_name="the Estonian knowledge
+    pretty_name="the Estonian knowledge dataset Exam-et",
     huggingface_id="EuroEval/exam-et",
     task=KNOW,
     languages=[ET],
     _labels=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o"],
     unofficial=True,
 )
+
+MMLU_ET_CONFIG = DatasetConfig(
+    name="mmlu-et",
+    pretty_name="the truncated version of the Estonian knowledge dataset MMLU-et, "
+    "translated from the English MMLU dataset",
+    huggingface_id="EuroEval/mmlu-et-mini",
+    task=KNOW,
+    languages=[ET],
+    unofficial=True,
+)
euroeval/dataset_configs/lithuanian.py
CHANGED
@@ -2,7 +2,7 @@
 
 from ..data_models import DatasetConfig
 from ..languages import LT
-from ..tasks import COMMON_SENSE, KNOW, LA, NER, RC, SENT
+from ..tasks import COMMON_SENSE, KNOW, LA, NER, RC, SENT, SUMM
 
 ### Official datasets ###
 
@@ -41,13 +41,20 @@ MULTI_WIKI_QA_LT_CONFIG = DatasetConfig(
     languages=[LT],
 )
 
+LRYTAS_CONFIG = DatasetConfig(
+    name="lrytas",
+    pretty_name="the truncated version of the Lithuanian summarisation dataset Lrytas",
+    huggingface_id="EuroEval/lrytas-mini",
+    task=SUMM,
+    languages=[LT],
+)
+
 LT_HISTORY_CONFIG = DatasetConfig(
     name="lt-history",
     pretty_name="the Lithuanian knowledge dataset LT-History",
     huggingface_id="EuroEval/lt-history",
     task=KNOW,
     languages=[LT],
-    splits=["train", "test"],
 )
 
 WINOGRANDE_LT_CONFIG = DatasetConfig(
@@ -57,6 +64,5 @@ WINOGRANDE_LT_CONFIG = DatasetConfig(
     huggingface_id="EuroEval/winogrande-lt",
     task=COMMON_SENSE,
     languages=[LT],
-    splits=["train", "test"],
     _labels=["a", "b"],
 )
euroeval/dataset_configs/slovak.py
ADDED
@@ -0,0 +1,60 @@
+"""All Slovak dataset configurations used in EuroEval."""
+
+from ..data_models import DatasetConfig
+from ..languages import SK
+from ..tasks import COMMON_SENSE, KNOW, LA, NER, RC, SENT
+
+### Official datasets ###
+
+CSFD_SENTIMENT_SK_CONFIG = DatasetConfig(
+    name="csfd-sentiment-sk",
+    pretty_name="the truncated version of the Slovak sentiment classification dataset "
+    "CSFD-sentiment-sk",
+    huggingface_id="EuroEval/csfd-sentiment-sk-mini",
+    task=SENT,
+    languages=[SK],
+)
+
+SCALA_SK_CONFIG = DatasetConfig(
+    name="scala-sk",
+    pretty_name="the Slovak part of the linguistic acceptability dataset ScaLA",
+    huggingface_id="EuroEval/scala-sk",
+    task=LA,
+    languages=[SK],
+)
+
+UNER_SK_CONFIG = DatasetConfig(
+    name="uner-sk",
+    pretty_name="the truncated version of the Slovak named entity recognition dataset "
+    "UNER-sk",
+    huggingface_id="EuroEval/uner-sk-mini",
+    task=NER,
+    languages=[SK],
+)
+
+MULTI_WIKI_QA_SK_CONFIG = DatasetConfig(
+    name="multi-wiki-qa-sk",
+    pretty_name="the truncated version of the Slovak part of the reading comprehension "
+    "dataset MultiWikiQA",
+    huggingface_id="EuroEval/multi-wiki-qa-sk-mini",
+    task=RC,
+    languages=[SK],
+)
+
+MMLU_SK_CONFIG = DatasetConfig(
+    name="mmlu-sk",
+    pretty_name="the truncated version of the Slovak knowledge dataset MMLU-sk, "
+    "translated from the English MMLU dataset",
+    huggingface_id="EuroEval/mmlu-sk-mini",
+    task=KNOW,
+    languages=[SK],
+)
+
+WINOGRANDE_SK_CONFIG = DatasetConfig(
+    name="winogrande-sk",
+    pretty_name="the Slovak common-sense reasoning dataset Winogrande-sk, translated "
+    "from the English Winogrande dataset",
+    huggingface_id="EuroEval/winogrande-sk",
+    task=COMMON_SENSE,
+    languages=[SK],
+)
euroeval/dataset_configs/swedish.py
CHANGED
@@ -32,11 +32,11 @@ SUC3_CONFIG = DatasetConfig(
     languages=[SV],
 )
 
-
-    name="
-    pretty_name="the
-    "dataset
-    huggingface_id="EuroEval/
+MULTI_WIKI_QA_SV_CONFIG = DatasetConfig(
+    name="multi-wiki-qa-sv",
+    pretty_name="the truncated version of the Swedish part of the reading "
+    "comprehension dataset MultiWikiQA",
+    huggingface_id="EuroEval/multi-wiki-qa-sv-mini",
     task=RC,
     languages=[SV],
 )
@@ -110,11 +110,11 @@ BELEBELE_SV_CONFIG = DatasetConfig(
     unofficial=True,
 )
 
-
-    name="
-    pretty_name="the
-    "
-    huggingface_id="EuroEval/
+SCANDIQA_SV_CONFIG = DatasetConfig(
+    name="scandiqa-sv",
+    pretty_name="the Swedish part of the truncated version of the question answering "
+    "dataset ScandiQA",
+    huggingface_id="EuroEval/scandiqa-sv-mini",
     task=RC,
     languages=[SV],
     unofficial=True,
@@ -137,7 +137,6 @@ WINOGRANDE_SV_CONFIG = DatasetConfig(
     huggingface_id="EuroEval/winogrande-sv",
     task=COMMON_SENSE,
     languages=[SV],
-    splits=["train", "test"],
     _labels=["a", "b"],
     unofficial=True,
 )
@@ -174,6 +173,5 @@ SKOLPROV_CONFIG = DatasetConfig(
     huggingface_id="EuroEval/skolprov",
     task=KNOW,
     languages=[SV],
-    splits=["train", "test"],
     unofficial=True,
 )
euroeval/finetuning.py
CHANGED
@@ -6,7 +6,6 @@ import typing as t
 from functools import partial
 
 import torch
-from tqdm.auto import tqdm
 from transformers.trainer_callback import (
     EarlyStoppingCallback,
     PrinterCallback,
@@ -18,13 +17,9 @@ from transformers.training_args import OptimizerNames, TrainingArguments
 from .callbacks import NeverLeaveProgressCallback
 from .enums import DataType
 from .exceptions import InvalidBenchmark, NaNValueInModelOutput
+from .logging_utils import block_terminal_output, get_pbar, log, log_once
 from .model_loading import load_model
-from .utils import
-    block_terminal_output,
-    clear_memory,
-    enforce_reproducibility,
-    log_once,
-)
+from .utils import clear_memory, enforce_reproducibility
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict
@@ -32,8 +27,6 @@ if t.TYPE_CHECKING:
     from .benchmark_modules import BenchmarkModule
     from .data_models import BenchmarkConfig, DatasetConfig, ModelConfig
 
-logger = logging.getLogger("euroeval")
-
 
 def finetune(
     model: "BenchmarkModule",
@@ -58,6 +51,10 @@
 
     Returns:
         A list of dicts containing the scores for each metric for each iteration.
+
+    Raises:
+        InvalidBenchmark:
+            If the benchmark could not be completed.
     """
     # Set the data type to use for the model weights
     using_cuda = benchmark_config.device == torch.device("cuda")
@@ -70,7 +67,7 @@
 
     bs: int = benchmark_config.batch_size
     scores: list[dict[str, float]] = list()
-    for idx in
+    for idx in get_pbar(
         iterable=range(benchmark_config.num_iterations),
         desc="Benchmarking",
         disable=not benchmark_config.progress_bar,
@@ -80,7 +77,7 @@
         model_already_initialized = idx == 0
 
         # Run a loop here to deal with automatic reduction of batch size
-
+        for _ in range(num_attempts := 10):
             # Clear GPU memory
             if not model_already_initialized:
                 try:
@@ -112,7 +109,10 @@
                 )
 
                 scores.append(itr_scores)
-
+                log(
+                    f"Test scores for iteration {idx}: {itr_scores}",
+                    level=logging.DEBUG,
+                )
 
                 break
 
@@ -123,9 +123,10 @@
                 if dtype != DataType.FP32:
                     dtype = DataType.FP32
                     model_already_initialized = False
-
+                    log(
                         "NaN value detected in model outputs while using mixed "
-                        "precision. Retrying with full fp32 precision."
+                        "precision. Retrying with full fp32 precision.",
+                        level=logging.DEBUG,
                     )
                 else:
                     raise InvalidBenchmark(
@@ -151,7 +152,12 @@
                 model_already_initialized = False
 
                 bs //= 2
-
+                log(f"Reduced batch size to {bs}", level=logging.DEBUG)
+
+        else:
+            raise InvalidBenchmark(
+                f"Could not benchmark the model after {num_attempts} attempts!"
+            )
 
     return scores
 
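The `get_pbar`, `log`, and `log_once` helpers used above come from the new euroeval/logging_utils.py (+250 lines), which is not included in this diff. Judging only from the call sites visible here, a rough sketch of what such helpers could look like (everything beyond the visible calls is an assumption):

# Rough sketch inferred from the call sites in this diff; the real
# euroeval/logging_utils.py is not shown here and will differ.
import logging
import typing as t

from tqdm.auto import tqdm

logger = logging.getLogger("euroeval")


def log(message: str, level: int = logging.INFO) -> None:
    """Log a message on the shared 'euroeval' logger at the given level."""
    logger.log(level=level, msg=message)


def get_pbar(
    iterable: t.Iterable, desc: str | None = None, disable: bool = False
) -> tqdm:
    """Wrap an iterable in a tqdm progress bar with shared defaults."""
    return tqdm(iterable=iterable, desc=desc, disable=disable, leave=False)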
euroeval/generation.py
CHANGED
@@ -11,12 +11,13 @@ from tqdm.auto import tqdm
 
 from .enums import BatchingPreference, TaskGroup
 from .exceptions import InvalidBenchmark
+from .logging_utils import get_pbar, log, log_once
 from .model_cache import (
     ModelCache,
     load_cached_model_outputs,
     split_dataset_into_cached_and_non_cached,
 )
-from .utils import clear_memory
+from .utils import clear_memory
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict
@@ -29,8 +30,6 @@ if t.TYPE_CHECKING:
         ModelConfig,
     )
 
-logger = logging.getLogger("euroeval")
-
 
 def generate(
     model: "BenchmarkModule",
@@ -78,7 +77,7 @@
     )
 
     scores: list[dict[str, float]] = list()
-    for idx in
+    for idx in get_pbar(
        iterable=range(len(datasets)),
        desc="Benchmarking",
        disable=not benchmark_config.progress_bar,
@@ -90,7 +89,7 @@
            dataset_config=dataset_config,
            benchmark_config=benchmark_config,
        )
-
+        log(f"Test scores for iteration {idx}: {test_scores}", level=logging.DEBUG)
         scores.append(test_scores)
         clear_memory()
 
@@ -142,14 +141,14 @@ def generate_single_iteration(
     itr: t.Iterable
     match model.batching_preference:
         case BatchingPreference.SINGLE_SAMPLE:
-            itr =
+            itr = get_pbar(iterable=non_cached_dataset)
         case BatchingPreference.ALL_AT_ONCE:
             itr = [non_cached_dataset[:]]
         case _:
             num_batches = len(non_cached_dataset) // benchmark_config.batch_size
             if len(non_cached_dataset) % benchmark_config.batch_size != 0:
                 num_batches += 1
-            itr =
+            itr = get_pbar(
                 iterable=mit.batched(
                     iterable=non_cached_dataset, n=benchmark_config.batch_size
                 ),
@@ -297,7 +296,7 @@
                 + "\n"
                 + "\t".join(labels)
             )
-
+            log("\n\n".join(log_msgs), level=logging.DEBUG)
             return
 
         case (
@@ -347,6 +346,7 @@
             if labels[idx]:
                 data_to_log["Label"] = labels[idx]
             data_to_log |= {key.capitalize(): batch[key][idx] for key in metadata_keys}
-
-                "\n".join(f"{key}: {value!r}" for key, value in data_to_log.items())
+            log(
+                "\n".join(f"{key}: {value!r}" for key, value in data_to_log.items()),
+                level=logging.DEBUG,
             )
euroeval/generation_utils.py
CHANGED
@@ -9,8 +9,9 @@ import typing as t
 
 from .enums import GenerativeType, TaskGroup
 from .exceptions import InvalidBenchmark, InvalidModel
+from .logging_utils import log_once
 from .tokenisation_utils import apply_chat_template
-from .utils import extract_multiple_choice_labels
+from .utils import extract_multiple_choice_labels
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict
@@ -18,8 +19,6 @@ if t.TYPE_CHECKING:
 
     from .data_models import BenchmarkConfig, DatasetConfig, ModelConfig
 
-logger = logging.getLogger("euroeval")
-
 
 def extract_few_shot_examples(
     dataset: "DatasetDict",