EuroEval: euroeval-15.12.0-py3-none-any.whl → euroeval-16.7.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/task_group_utils/token_classification.py
CHANGED

@@ -1,32 +1,35 @@
 """Utility functions related to the token-classification task group."""
 
+import collections.abc as c
 import logging
-import re
 import typing as t
 from copy import deepcopy
 
-import demjson3
 import numpy as np
 
 from ..exceptions import InvalidBenchmark
-from ..utils import raise_if_model_output_contains_nan_values
+from ..logging_utils import log
+from ..utils import (
+    extract_json_dict_from_string,
+    raise_if_model_output_contains_nan_values,
+)
 
 if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
     from transformers.tokenization_utils import PreTrainedTokenizer
     from transformers.tokenization_utils_base import BatchEncoding
     from transformers.trainer_utils import EvalPrediction
 
-    from ..data_models import DatasetConfig, GenerativeModelOutput
+    from ..data_models import BenchmarkConfig, DatasetConfig, GenerativeModelOutput
     from ..types import Labels, Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def compute_metrics(
     model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
     has_misc_tags: bool,
     dataset_config: "DatasetConfig",
+    benchmark_config: "BenchmarkConfig",
+    dataset: "Dataset",
 ) -> dict[str, float]:
     """Compute the metrics needed for evaluation.
 
@@ -38,6 +41,11 @@ def compute_metrics(
             Whether the dataset has MISC tags.
         dataset_config:
             The configuration of the dataset.
+        benchmark_config:
+            The configuration of the benchmark.
+        dataset:
+            The dataset used for evaluation. This is only used in case any additional
+            metadata is used to compute the metrics.
 
     Returns:
         A dictionary with the names of the metrics as keys and the metric values as
@@ -52,7 +60,9 @@ def compute_metrics(
 
     predictions: list[list[str]]
     if not isinstance(model_outputs[0][0], str):
-        raw_predictions: list[list[int]] = np.argmax(model_outputs, axis=-1).tolist()
+        raw_predictions: c.Sequence[c.Sequence[int]] = np.argmax(
+            model_outputs, axis=-1
+        ).tolist()
 
         # Remove ignored index (special tokens)
         predictions = [
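For orientation, the `np.argmax` step in the hunk above reduces per-token logit vectors to label IDs. A standalone sketch with illustrative shapes and values:

```python
import numpy as np

# One sequence, three tokens, three labels (say "o", "b-per", "i-per"): the
# model emits a logit vector per token, and argmax over the last axis picks
# the highest-scoring label ID for each token.
logits = [[[2.0, 0.1, 0.3], [0.2, 3.1, 0.4], [0.1, 0.2, 2.8]]]
raw_predictions = np.argmax(logits, axis=-1).tolist()
print(raw_predictions)  # [[0, 1, 2]]
```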
@@ -136,7 +146,13 @@ def compute_metrics(
         for metric in dataset_config.task.metrics
         if metric.name == "micro_f1"
     )
-    micro_f1_score = metric(predictions=predictions, references=list(labels))
+    micro_f1_score = metric(
+        predictions=predictions,
+        references=list(labels),
+        dataset=dataset,
+        dataset_config=dataset_config,
+        benchmark_config=benchmark_config,
+    )
 
     # Compute the metrics without MISC tags
     # We manually set the F1 metric to be 100% if both the labels and the models
@@ -158,7 +174,11 @@ def compute_metrics(
         if metric.name == "micro_f1_no_misc"
     )
     micro_f1_no_misc_score = metric(
-        predictions=predictions_no_misc, references=labels_no_misc
+        predictions=predictions_no_misc,
+        references=labels_no_misc,
+        dataset=dataset,
+        dataset_config=dataset_config,
+        benchmark_config=benchmark_config,
    )
 
     # Raise error if the metrics are invalid
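Both call sites now thread the whole evaluation context into the metric object. A minimal sketch of a callable satisfying the new convention — hypothetical; the real interface lives in the new `euroeval/metrics/base.py`:

```python
import typing as t


class DummyAccuracyMetric:
    """Hypothetical metric matching the new call convention."""

    name = "dummy_accuracy"

    def __call__(
        self,
        predictions: t.Sequence,
        references: t.Sequence,
        dataset: t.Any,  # extra metadata that some metrics need
        dataset_config: t.Any,
        benchmark_config: t.Any,
    ) -> float:
        # Plain accuracy; the context arguments are accepted but unused here.
        return sum(p == r for p, r in zip(predictions, references)) / len(references)


score = DummyAccuracyMetric()(
    predictions=["b-per", "o"],
    references=["b-per", "i-per"],
    dataset=None,
    dataset_config=None,
    benchmark_config=None,
)
print(score)  # 0.5
```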
@@ -172,7 +192,7 @@ def extract_labels_from_generation(
     input_batch: dict[str, list],
     model_output: "GenerativeModelOutput",
     dataset_config: "DatasetConfig",
-) -> list[t.Any]:
+) -> c.Sequence[t.Any]:
     """Extract the predicted labels from the generated output.
 
     Args:
@@ -187,55 +207,31 @@ def extract_labels_from_generation(
     Returns:
         The predicted labels.
     """
-    raw_predictions = model_output.sequences
-
-    # Attempt to extract the JSON dictionary from the predictions
-    json_regex = r"\{[^{}]+?\}"
-    json_matches = [
-        re.search(pattern=json_regex, string=raw_prediction, flags=re.DOTALL)
-        or raw_prediction
-        for raw_prediction in raw_predictions
-    ]
-    raw_predictions = [
-        json_match.group() if isinstance(json_match, re.Match) else json_match
-        for json_match in json_matches
-    ]
-
     tokens = input_batch["tokens"]
     predicted_labels: list[list[str]] = [["o"] * len(token_ids) for token_ids in tokens]
-    for idx, raw_prediction in enumerate(raw_predictions):
-        try:
-            json_output = demjson3.decode(txt=raw_prediction)
-            if not isinstance(json_output, dict):
-                logger.debug(
-                    "The model output is not a JSON dictionary, so cannot parse "
-                    f"it. Skipping. Here is the output: {raw_prediction}"
-                )
-                continue
-            elif not all(isinstance(key, str) for key in json_output.keys()):
-                logger.debug(
-                    "The model output is not a JSON dictionary with string keys, "
-                    "so cannot parse it. Skipping. Here is the output: "
-                    f"{raw_prediction}"
-                )
-                continue
-            elif not all(isinstance(value, list) for value in json_output.values()):
-                logger.debug(
-                    "The model output is not a JSON dictionary with list values, "
-                    "so cannot parse it. Skipping. Here is the output: "
-                    f"{raw_prediction}"
-                )
-                continue
-            prediction_dict: dict[str, list[str]] = json_output
-        except demjson3.JSONDecodeError:
-            logger.debug(
-                "The model output is not valid JSON, so cannot parse it. Skipping. "
-                f"Here is the output: {raw_prediction!r}"
-            )
+    for idx, raw_prediction in enumerate(model_output.sequences):
+        prediction_dict = extract_json_dict_from_string(s=raw_prediction)
+        if prediction_dict is None:
             continue
 
         prompt_label_mapping = dataset_config.prompt_label_mapping
         for prompt_tag_name, named_entities in prediction_dict.items():
+            if not isinstance(named_entities, list):
+                log(
+                    "The model produced an invalid format for the named entities. "
+                    f"Expected a list but got {type(named_entities)}. Skipping.",
+                    level=logging.DEBUG,
+                )
+                continue
+            try:
+                named_entities = [str(ne) for ne in named_entities]
+            except Exception:
+                log(
+                    "The model produced an invalid format for the named entities. "
+                    f"Expected a list of strings but got {named_entities}. Skipping.",
+                    level=logging.DEBUG,
+                )
+                continue
             try:
                 tag_name = [
                     tag[2:]
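The inline regex-plus-`demjson3` parsing is consolidated into a shared `extract_json_dict_from_string` helper. A rough sketch of what such a helper does — the real implementation lives in `euroeval/utils.py` and is likely more lenient about malformed JSON:

```python
import json
import re


def extract_json_dict_from_string(s: str) -> dict | None:
    """Pull the first {...} blob out of a string and parse it, else None."""
    match = re.search(r"\{[^{}]+?\}", s, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group())
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


print(extract_json_dict_from_string('Sure! {"person": ["Merkel"]}'))
# {'person': ['Merkel']}
```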
@@ -243,9 +239,10 @@ def extract_labels_from_generation(
                     if prompt_tag == prompt_tag_name
                 ][0]
             except IndexError:
-                logger.debug(
+                log(
                     "The model produced an invalid prompt tag name, "
-                    f"{prompt_tag_name}. Skipping."
+                    f"{prompt_tag_name}. Skipping.",
+                    level=logging.DEBUG,
                 )
                 continue
 
@@ -265,49 +262,49 @@ def extract_labels_from_generation(
 
 
 def tokenize_and_align_labels(
-    examples: dict, tokenizer: "PreTrainedTokenizer", label2id: dict[str, int]
+    examples: dict, tokeniser: "PreTrainedTokenizer", label2id: dict[str, int]
 ) -> "BatchEncoding":
     """Tokenise all texts and align the labels with them.
 
     Args:
         examples:
             The examples to be tokenised.
-        tokenizer:
-            A pretrained tokenizer.
+        tokeniser:
+            A pretrained tokeniser.
         label2id:
             A dictionary that converts NER tags to IDs.
 
     Returns:
         A dictionary containing the tokenized data as well as labels.
     """
-    # Tokenize the texts. We use the `is_split_into_words` argument here because
+    # Tokenise the texts. We use the `is_split_into_words` argument here because
     # the texts in our dataset are lists of words (with a label for each word)
-    tokenized_inputs = tokenizer(
+    tokenized_inputs = tokeniser(
         examples["tokens"], is_split_into_words=True, truncation=True, padding=True
     )
 
     # Extract a mapping between all the tokens and their corresponding word. If the
-    # tokenizer is of a "fast" variant then this can be accessed through the
+    # tokeniser is of a "fast" variant then this can be accessed through the
     # `word_ids` method. Otherwise, we have to extract it manually.
     all_labels: list[list[int]] = list()
-    labels: list[str]
-    word_ids: list[int | None]
+    labels: c.Sequence[str]
+    word_ids: c.Sequence[int | None]
     for i, labels in enumerate(examples["labels"]):
-        # Try to get the word IDs from the tokenizer
+        # Try to get the word IDs from the tokeniser
         try:
             word_ids = tokenized_inputs.word_ids(batch_index=i)
 
-        # If the tokenizer is not of a "fast" variant, we have to extract the word
+        # If the tokeniser is not of a "fast" variant, we have to extract the word
         # IDs manually
         except ValueError:
             # Get the list of words in the document
-            words: list[str] = examples["tokens"][i]
+            words: c.Sequence[str] = examples["tokens"][i]
 
             # Get the list of token IDs in the document
-            tok_ids: list[int] = tokenized_inputs.input_ids[i]
+            tok_ids: c.Sequence[int] = tokenized_inputs.input_ids[i]
 
             # Decode the token IDs
-            tokens = tokenizer.convert_ids_to_tokens(tok_ids)
+            tokens = tokeniser.convert_ids_to_tokens(tok_ids)
             assert isinstance(tokens, list)
 
             # Remove prefixes from the tokens
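The alignment hinges on the fast tokenisers' `word_ids` method, which maps every subword token back to its source word. A standalone illustration (any fast tokeniser checkpoint works; the exact output depends on the vocabulary):

```python
from transformers import AutoTokenizer

tokeniser = AutoTokenizer.from_pretrained("bert-base-cased")
encoding = tokeniser(
    ["Angela", "Merkel", "visited", "Aarhus"], is_split_into_words=True
)
# None marks special tokens; a repeated index marks a word that was split
# into several subword tokens.
print(encoding.word_ids(batch_index=0))  # e.g. [None, 0, 1, 2, 3, 3, None]
```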
@@ -319,14 +316,14 @@ def tokenize_and_align_labels(
                 tokens[tok_idx] = tok[len(prefix) :]
 
             # Replace UNK tokens with the correct word
-            tokens = handle_unk_tokens(tokenizer=tokenizer, tokens=tokens, words=words)
+            tokens = handle_unk_tokens(tokeniser=tokeniser, tokens=tokens, words=words)
 
-            # Get list of special tokens. Some tokenizers do not record these
+            # Get list of special tokens. Some tokenisers do not record these
             # properly, which is why we convert the values to their indices and
             # then back to strings
             sp_toks = [
-                tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(sp_tok))
-                for sp_tok in tokenizer.special_tokens_map.values()
+                tokeniser.convert_ids_to_tokens(tokeniser.convert_tokens_to_ids(sp_tok))
+                for sp_tok in tokeniser.special_tokens_map.values()
             ]
 
             # Replace special tokens with `None`
@@ -350,7 +347,7 @@ def tokenize_and_align_labels(
             if len(word_idxs) != len(token_idxs):
                 raise InvalidBenchmark(
                     "The tokens could not be aligned with the words during manual "
-                    "word-token alignment. It seems that the tokenizer is neither "
+                    "word-token alignment. It seems that the tokeniser is neither "
                     "of the fast variant nor of a SentencePiece/WordPiece variant."
                 )
 
@@ -380,9 +377,9 @@ def tokenize_and_align_labels(
                 label = labels[word_id]
                 try:
                     label_id = label2id[label.lower()]
-                except KeyError:
+                except KeyError as e:
                     msg = f"The label {label} was not found in the model's config."
-                    raise InvalidBenchmark(msg)
+                    raise InvalidBenchmark(msg) from e
                 label_ids.append(label_id)
 
             # For the other tokens in a word, we set the label to -100
@@ -397,13 +394,13 @@ def tokenize_and_align_labels(
 
 
 def handle_unk_tokens(
-    tokenizer: "PreTrainedTokenizer", tokens: list[str], words: list[str]
-) -> list[str]:
+    tokeniser: "PreTrainedTokenizer", tokens: list[str], words: c.Sequence[str]
+) -> c.Sequence[str]:
     """Replace unknown tokens in the tokens with the corresponding word.
 
     Args:
-        tokenizer:
-            The tokenizer used to tokenize the words.
+        tokeniser:
+            The tokeniser used to tokenise the words.
         tokens:
             The list of tokens.
         words:
@@ -413,15 +410,15 @@ def handle_unk_tokens(
         The list of tokens with unknown tokens replaced by the corresponding word.
     """
     # Locate the token indices of the unknown tokens
-    token_unk_idxs = [i for i, tok in enumerate(tokens) if tok == tokenizer.unk_token]
+    token_unk_idxs = [i for i, tok in enumerate(tokens) if tok == tokeniser.unk_token]
 
     # Locate the word indices of the words which contain an unknown token
     word_unk_idxs = [
         i
         for i, word in enumerate(words)
-        if tokenizer.unk_token
-        in tokenizer.convert_ids_to_tokens(
-            tokenizer.encode(word, add_special_tokens=False)
+        if tokeniser.unk_token
+        in tokeniser.convert_ids_to_tokens(
+            tokeniser.encode(word, add_special_tokens=False)
         )
     ]
 
@@ -430,9 +427,9 @@ def handle_unk_tokens(
         # Fetch the word
         word = words[word_idx]
 
-        # Tokenize the word, which is now a list containing at least one UNK token
-        tokens_with_unk = tokenizer.convert_ids_to_tokens(
-            tokenizer.encode(word, add_special_tokens=False)
+        # Tokenise the word, which is now a list containing at least one UNK token
+        tokens_with_unk = tokeniser.convert_ids_to_tokens(
+            tokeniser.encode(word, add_special_tokens=False)
         )
 
         # Iterate over the tokens in the word
@@ -441,10 +438,10 @@ def handle_unk_tokens(
             # of the content of this token from the word. The result of the `word`
             # variable will be the content of the UNK token.
             # NOTE: This is a bit hacky and not bulletproof. For instance, if the
-            # word is "1925-1950" and the tokenizer splits it into ["[UNK]", "-",
+            # word is "1925-1950" and the tokeniser splits it into ["[UNK]", "-",
             # "19", "50"], then the result will be 2519 instead of 1925. This
             # happens almost never, however, so we can live with it.
-            if possible_unk_token != tokenizer.unk_token:
+            if possible_unk_token != tokeniser.unk_token:
                 word = word.replace(possible_unk_token, "", 1)
 
         # Replace the token with the word
euroeval/tasks.py
CHANGED

@@ -1,35 +1,29 @@
 """All benchmarks tasks used in EuroEval."""
 
 from . import metrics as m
+from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
 from .data_models import Task
-from .enums import TaskGroup
+from .enums import GenerativeType, ModelType, TaskGroup
 from .prompt_templates import (
+    CLASSIFICATION_TEMPLATES,
     LA_TEMPLATES,
     MULTIPLE_CHOICE_TEMPLATES,
     NER_TEMPLATES,
     RC_TEMPLATES,
     SENT_TEMPLATES,
     SUMM_TEMPLATES,
+    TOKEN_CLASSIFICATION_TEMPLATES,
 )
 
-
-def get_all_tasks() -> dict[str, Task]:
-    """Get a list of all the dataset tasks.
-
-    Returns:
-        A mapping between names of dataset tasks and their configurations.
-    """
-    return {cfg.name: cfg for cfg in globals().values() if isinstance(cfg, Task)}
-
-
 LA = Task(
     name="linguistic-acceptability",
     task_group=TaskGroup.SEQUENCE_CLASSIFICATION,
     template_dict=LA_TEMPLATES,
     metrics=[m.mcc_metric, m.macro_f1_metric],
     default_num_few_shot_examples=12,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["correct", "incorrect"],
+    uses_logprobs=True,
 )
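The `get_all_tasks` registry function is gone, but the tasks remain plain module-level `Task` objects, so the old mapping can still be rebuilt externally if needed:

```python
import euroeval.tasks
from euroeval.data_models import Task

# Equivalent to the removed get_all_tasks(): map task names to their configs.
all_tasks = {
    cfg.name: cfg for cfg in vars(euroeval.tasks).values() if isinstance(cfg, Task)
}
print(sorted(all_tasks))
```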
@@ -51,6 +45,7 @@ NER = Task(
         "b-misc",
         "i-misc",
     ],
+    uses_structured_output=True,
 )
 
 
@@ -71,8 +66,9 @@ SENT = Task(
     template_dict=SENT_TEMPLATES,
     metrics=[m.mcc_metric, m.macro_f1_metric],
     default_num_few_shot_examples=12,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["positive", "neutral", "negative"],
+    uses_logprobs=True,
 )
 
 
@@ -84,6 +80,7 @@ SUMM = Task(
     default_num_few_shot_examples=1,
     default_max_generated_tokens=256,
     default_labels=[],
+    default_allowed_model_types=[ModelType.GENERATIVE],
 )
 
 
@@ -93,8 +90,10 @@ KNOW = Task(
     template_dict=MULTIPLE_CHOICE_TEMPLATES,
     metrics=[m.mcc_metric, m.accuracy_metric],
     default_num_few_shot_examples=5,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["a", "b", "c", "d"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
 )
 
 
@@ -104,8 +103,10 @@ MCRC = Task(
     template_dict=MULTIPLE_CHOICE_TEMPLATES,
     metrics=[m.mcc_metric, m.accuracy_metric],
     default_num_few_shot_examples=5,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["a", "b", "c", "d"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
 )
 
 
@@ -115,8 +116,29 @@ COMMON_SENSE = Task(
     template_dict=MULTIPLE_CHOICE_TEMPLATES,
     metrics=[m.mcc_metric, m.accuracy_metric],
     default_num_few_shot_examples=5,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["a", "b", "c", "d"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
+)
+
+
+EUROPEAN_VALUES = Task(
+    name="european-values",
+    task_group=TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION,
+    template_dict=MULTIPLE_CHOICE_TEMPLATES,
+    metrics=[m.european_values_metric],
+    default_num_few_shot_examples=0,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
+    default_labels=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    default_allowed_generative_types=[
+        GenerativeType.INSTRUCTION_TUNED,
+        GenerativeType.REASONING,
+    ],
+    requires_zero_shot=True,
+    uses_logprobs=True,
+    default_allow_invalid_model_outputs=False,
 )
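The `uses_logprobs` flag on the classification-style tasks presumably lets the evaluator read the label off the log-probabilities of the generated label token instead of parsing free-form text. A schematic of that scoring idea (label set and numbers are made up):

```python
labels = ["a", "b", "c", "d"]
first_token_logprobs = {"a": -2.3, "b": -0.4, "d": -1.9}  # as returned by a model API

# Pick the candidate label with the highest log-probability; labels the model
# assigned no mass to fall back to -inf.
prediction = max(labels, key=lambda lbl: first_token_logprobs.get(lbl, float("-inf")))
print(prediction)  # 'b'
```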
@@ -129,3 +151,40 @@ SPEED = Task(
     default_max_generated_tokens=5,
     default_labels=[],
 )
+
+
+# Used for custom datasets
+
+TEXT_CLASSIFICATION = Task(
+    name="classification",
+    task_group=TaskGroup.SEQUENCE_CLASSIFICATION,
+    template_dict=CLASSIFICATION_TEMPLATES,
+    metrics=[m.mcc_metric, m.macro_f1_metric],
+    default_num_few_shot_examples=12,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
+    default_labels=None,
+    uses_logprobs=True,
+)
+
+TOKEN_CLASSIFICATION = Task(
+    name="token-classification",
+    task_group=TaskGroup.TOKEN_CLASSIFICATION,
+    template_dict=TOKEN_CLASSIFICATION_TEMPLATES,
+    metrics=[m.micro_f1_metric],
+    default_num_few_shot_examples=8,
+    default_max_generated_tokens=128,
+    default_labels=None,
+    uses_structured_output=True,
+)
+
+MULTIPLE_CHOICE = Task(
+    name="multiple-choice",
+    task_group=TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION,
+    template_dict=MULTIPLE_CHOICE_TEMPLATES,
+    metrics=[m.mcc_metric, m.accuracy_metric],
+    default_num_few_shot_examples=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
+    default_labels=None,
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
+)
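All three custom-dataset tasks ship with `default_labels=None`, so a user-supplied dataset configuration is expected to provide its own label set. A quick sanity check of the exported objects:

```python
from euroeval.tasks import MULTIPLE_CHOICE, TEXT_CLASSIFICATION, TOKEN_CLASSIFICATION

for task in (TEXT_CLASSIFICATION, TOKEN_CLASSIFICATION, MULTIPLE_CHOICE):
    # default_labels is None for all three: the custom dataset supplies labels.
    print(task.name, task.default_labels)
```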