EuroEval: euroeval-15.12.0-py3-none-any.whl → euroeval-16.7.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. euroeval/__init__.py +32 -14
  2. euroeval/benchmark_config_factory.py +92 -180
  3. euroeval/benchmark_modules/base.py +49 -39
  4. euroeval/benchmark_modules/fresh.py +35 -21
  5. euroeval/benchmark_modules/hf.py +280 -244
  6. euroeval/benchmark_modules/litellm.py +752 -312
  7. euroeval/benchmark_modules/vllm.py +570 -268
  8. euroeval/benchmarker.py +651 -528
  9. euroeval/caching_utils.py +79 -0
  10. euroeval/callbacks.py +5 -7
  11. euroeval/cli.py +49 -38
  12. euroeval/constants.py +44 -25
  13. euroeval/data_loading.py +111 -55
  14. euroeval/data_models.py +490 -323
  15. euroeval/dataset_configs/__init__.py +26 -4
  16. euroeval/dataset_configs/bosnian.py +39 -0
  17. euroeval/dataset_configs/bulgarian.py +56 -0
  18. euroeval/dataset_configs/croatian.py +56 -0
  19. euroeval/dataset_configs/czech.py +75 -0
  20. euroeval/dataset_configs/danish.py +78 -50
  21. euroeval/dataset_configs/dutch.py +74 -44
  22. euroeval/dataset_configs/english.py +71 -36
  23. euroeval/dataset_configs/estonian.py +111 -0
  24. euroeval/dataset_configs/faroese.py +25 -18
  25. euroeval/dataset_configs/finnish.py +63 -26
  26. euroeval/dataset_configs/french.py +65 -32
  27. euroeval/dataset_configs/german.py +77 -36
  28. euroeval/dataset_configs/greek.py +64 -0
  29. euroeval/dataset_configs/icelandic.py +68 -57
  30. euroeval/dataset_configs/italian.py +68 -36
  31. euroeval/dataset_configs/latvian.py +87 -0
  32. euroeval/dataset_configs/lithuanian.py +64 -0
  33. euroeval/dataset_configs/norwegian.py +98 -72
  34. euroeval/dataset_configs/polish.py +96 -0
  35. euroeval/dataset_configs/portuguese.py +63 -40
  36. euroeval/dataset_configs/serbian.py +64 -0
  37. euroeval/dataset_configs/slovak.py +55 -0
  38. euroeval/dataset_configs/slovene.py +56 -0
  39. euroeval/dataset_configs/spanish.py +68 -34
  40. euroeval/dataset_configs/swedish.py +82 -41
  41. euroeval/dataset_configs/ukrainian.py +64 -0
  42. euroeval/enums.py +12 -6
  43. euroeval/exceptions.py +21 -1
  44. euroeval/finetuning.py +34 -26
  45. euroeval/generation.py +76 -41
  46. euroeval/generation_utils.py +169 -34
  47. euroeval/languages.py +1020 -188
  48. euroeval/logging_utils.py +268 -0
  49. euroeval/metrics/__init__.py +6 -0
  50. euroeval/metrics/base.py +85 -0
  51. euroeval/metrics/huggingface.py +216 -0
  52. euroeval/metrics/llm_as_a_judge.py +260 -0
  53. euroeval/metrics/pipeline.py +289 -0
  54. euroeval/metrics/speed.py +48 -0
  55. euroeval/model_cache.py +40 -21
  56. euroeval/model_config.py +4 -5
  57. euroeval/model_loading.py +3 -0
  58. euroeval/prompt_templates/__init__.py +2 -0
  59. euroeval/prompt_templates/classification.py +206 -0
  60. euroeval/prompt_templates/linguistic_acceptability.py +157 -22
  61. euroeval/prompt_templates/multiple_choice.py +159 -17
  62. euroeval/prompt_templates/named_entity_recognition.py +318 -21
  63. euroeval/prompt_templates/reading_comprehension.py +207 -16
  64. euroeval/prompt_templates/sentiment_classification.py +205 -22
  65. euroeval/prompt_templates/summarization.py +122 -22
  66. euroeval/prompt_templates/token_classification.py +279 -0
  67. euroeval/scores.py +20 -9
  68. euroeval/speed_benchmark.py +11 -12
  69. euroeval/task_group_utils/multiple_choice_classification.py +21 -12
  70. euroeval/task_group_utils/question_answering.py +101 -73
  71. euroeval/task_group_utils/sequence_classification.py +144 -61
  72. euroeval/task_group_utils/text_to_text.py +33 -12
  73. euroeval/task_group_utils/token_classification.py +86 -89
  74. euroeval/tasks.py +75 -16
  75. euroeval/tokenisation_utils.py +603 -0
  76. euroeval/types.py +17 -11
  77. euroeval/utils.py +332 -137
  78. euroeval-16.7.1.dist-info/METADATA +623 -0
  79. euroeval-16.7.1.dist-info/RECORD +84 -0
  80. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
  81. euroeval/human_evaluation.py +0 -737
  82. euroeval/metrics.py +0 -452
  83. euroeval/tokenization_utils.py +0 -498
  84. euroeval-15.12.0.dist-info/METADATA +0 -285
  85. euroeval-15.12.0.dist-info/RECORD +0 -63
  86. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
  87. {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/task_group_utils/token_classification.py CHANGED
@@ -1,32 +1,35 @@
 """Utility functions related to the token-classification task group."""
 
+import collections.abc as c
 import logging
-import re
 import typing as t
 from copy import deepcopy
 
-import demjson3
 import numpy as np
 
 from ..exceptions import InvalidBenchmark
-from ..utils import raise_if_model_output_contains_nan_values
+from ..logging_utils import log
+from ..utils import (
+    extract_json_dict_from_string,
+    raise_if_model_output_contains_nan_values,
+)
 
 if t.TYPE_CHECKING:
+    from datasets.arrow_dataset import Dataset
     from transformers.tokenization_utils import PreTrainedTokenizer
     from transformers.tokenization_utils_base import BatchEncoding
     from transformers.trainer_utils import EvalPrediction
 
-    from ..data_models import DatasetConfig, GenerativeModelOutput
+    from ..data_models import BenchmarkConfig, DatasetConfig, GenerativeModelOutput
     from ..types import Labels, Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def compute_metrics(
     model_outputs_and_labels: "tuple[Predictions, Labels] | EvalPrediction",
     has_misc_tags: bool,
     dataset_config: "DatasetConfig",
+    benchmark_config: "BenchmarkConfig",
+    dataset: "Dataset",
 ) -> dict[str, float]:
     """Compute the metrics needed for evaluation.
 
@@ -38,6 +41,11 @@ def compute_metrics(
             Whether the dataset has MISC tags.
         dataset_config:
             The configuration of the dataset.
+        benchmark_config:
+            The configuration of the benchmark.
+        dataset:
+            The dataset used for evaluation. This is only used in case any additional
+            metadata is used to compute the metrics.
 
     Returns:
         A dictionary with the names of the metrics as keys and the metric values as
@@ -52,7 +60,9 @@
 
     predictions: list[list[str]]
     if not isinstance(model_outputs[0][0], str):
-        raw_predictions: list[list[int]] = np.argmax(model_outputs, axis=-1).tolist()
+        raw_predictions: c.Sequence[c.Sequence[int]] = np.argmax(
+            model_outputs, axis=-1
+        ).tolist()
 
         # Remove ignored index (special tokens)
         predictions = [
@@ -136,7 +146,13 @@ def compute_metrics(
         for metric in dataset_config.task.metrics
         if metric.name == "micro_f1"
     )
-    micro_f1_score = metric(predictions=predictions, references=list(labels))
+    micro_f1_score = metric(
+        predictions=predictions,
+        references=list(labels),
+        dataset=dataset,
+        dataset_config=dataset_config,
+        benchmark_config=benchmark_config,
+    )
 
     # Compute the metrics without MISC tags
     # We manually set the F1 metric to be 100% if both the labels and the models
@@ -158,7 +174,11 @@ def compute_metrics(
         if metric.name == "micro_f1_no_misc"
     )
     micro_f1_no_misc_score = metric(
-        predictions=predictions_no_misc, references=labels_no_misc
+        predictions=predictions_no_misc,
+        references=labels_no_misc,
+        dataset=dataset,
+        dataset_config=dataset_config,
+        benchmark_config=benchmark_config,
     )
 
     # Raise error if the metrics are invalid
@@ -172,7 +192,7 @@ def extract_labels_from_generation(
     input_batch: dict[str, list],
     model_output: "GenerativeModelOutput",
     dataset_config: "DatasetConfig",
-) -> list[t.Any]:
+) -> c.Sequence[t.Any]:
     """Extract the predicted labels from the generated output.
 
     Args:
@@ -187,55 +207,31 @@ def extract_labels_from_generation(
     Returns:
         The predicted labels.
     """
-    raw_predictions = model_output.sequences
-
-    # Attempt to extract the JSON dictionary from the predictions
-    json_regex = r"\{[^{}]+?\}"
-    json_matches = [
-        re.search(pattern=json_regex, string=raw_prediction, flags=re.DOTALL)
-        or raw_prediction
-        for raw_prediction in raw_predictions
-    ]
-    raw_predictions = [
-        json_match.group() if isinstance(json_match, re.Match) else json_match
-        for json_match in json_matches
-    ]
-
     tokens = input_batch["tokens"]
     predicted_labels: list[list[str]] = [["o"] * len(token_ids) for token_ids in tokens]
-    for idx, raw_prediction in enumerate(raw_predictions):
-        try:
-            json_output = demjson3.decode(txt=raw_prediction)
-            if not isinstance(json_output, dict):
-                logger.debug(
-                    "The model output is not a JSON dictionary, so cannot parse "
-                    f"it. Skipping. Here is the output: {raw_prediction}"
-                )
-                continue
-            elif not all(isinstance(key, str) for key in json_output.keys()):
-                logger.debug(
-                    "The model output is not a JSON dictionary with string keys, "
-                    "so cannot parse it. Skipping. Here is the output: "
-                    f"{raw_prediction}"
-                )
-                continue
-            elif not all(isinstance(value, list) for value in json_output.values()):
-                logger.debug(
-                    "The model output is not a JSON dictionary with list values, "
-                    "so cannot parse it. Skipping. Here is the output: "
-                    f"{raw_prediction}"
-                )
-                continue
-            prediction_dict: dict[str, list[str]] = json_output
-        except demjson3.JSONDecodeError:
-            logger.debug(
-                "The model output is not valid JSON, so cannot parse it. Skipping. "
-                f"Here is the output: {raw_prediction!r}"
-            )
+    for idx, raw_prediction in enumerate(model_output.sequences):
+        prediction_dict = extract_json_dict_from_string(s=raw_prediction)
+        if prediction_dict is None:
            continue
 
        prompt_label_mapping = dataset_config.prompt_label_mapping
        for prompt_tag_name, named_entities in prediction_dict.items():
+            if not isinstance(named_entities, list):
+                log(
+                    "The model produced an invalid format for the named entities. "
+                    f"Expected a list but got {type(named_entities)}. Skipping.",
+                    level=logging.DEBUG,
+                )
+                continue
+            try:
+                named_entities = [str(ne) for ne in named_entities]
+            except Exception:
+                log(
+                    "The model produced an invalid format for the named entities. "
+                    f"Expected a list of strings but got {named_entities}. Skipping.",
+                    level=logging.DEBUG,
+                )
+                continue
            try:
                tag_name = [
                    tag[2:]
@@ -243,9 +239,10 @@ def extract_labels_from_generation(
                    if prompt_tag == prompt_tag_name
                ][0]
            except IndexError:
-                logger.debug(
+                log(
                    "The model produced an invalid prompt tag name, "
-                    f"{prompt_tag_name}. Skipping."
+                    f"{prompt_tag_name}. Skipping.",
+                    level=logging.DEBUG,
                )
                continue
 
@@ -265,49 +262,49 @@
 
 
 def tokenize_and_align_labels(
-    examples: dict, tokenizer: "PreTrainedTokenizer", label2id: dict[str, int]
+    examples: dict, tokeniser: "PreTrainedTokenizer", label2id: dict[str, int]
 ) -> "BatchEncoding":
     """Tokenise all texts and align the labels with them.
 
     Args:
         examples:
             The examples to be tokenised.
-        tokenizer:
-            A pretrained tokenizer.
+        tokeniser:
+            A pretrained tokeniser.
         label2id:
             A dictionary that converts NER tags to IDs.
 
     Returns:
         A dictionary containing the tokenized data as well as labels.
     """
-    # Tokenize the texts. We use the `is_split_into_words` argument here because
+    # Tokenise the texts. We use the `is_split_into_words` argument here because
     # the texts in our dataset are lists of words (with a label for each word)
-    tokenized_inputs = tokenizer(
+    tokenized_inputs = tokeniser(
         examples["tokens"], is_split_into_words=True, truncation=True, padding=True
     )
 
     # Extract a mapping between all the tokens and their corresponding word. If the
-    # tokenizer is of a "fast" variant then this can be accessed through the
+    # tokeniser is of a "fast" variant then this can be accessed through the
     # `word_ids` method. Otherwise, we have to extract it manually.
     all_labels: list[list[int]] = list()
-    labels: list[str]
-    word_ids: list[int | None]
+    labels: c.Sequence[str]
+    word_ids: c.Sequence[int | None]
     for i, labels in enumerate(examples["labels"]):
-        # Try to get the word IDs from the tokenizer
+        # Try to get the word IDs from the tokeniser
        try:
            word_ids = tokenized_inputs.word_ids(batch_index=i)
 
-        # If the tokenizer is not of a "fast" variant, we have to extract the word
+        # If the tokeniser is not of a "fast" variant, we have to extract the word
        # IDs manually
        except ValueError:
            # Get the list of words in the document
-            words: list[str] = examples["tokens"][i]
+            words: c.Sequence[str] = examples["tokens"][i]
 
            # Get the list of token IDs in the document
-            tok_ids: list[int] = tokenized_inputs.input_ids[i]
+            tok_ids: c.Sequence[int] = tokenized_inputs.input_ids[i]
 
            # Decode the token IDs
-            tokens = tokenizer.convert_ids_to_tokens(tok_ids)
+            tokens = tokeniser.convert_ids_to_tokens(tok_ids)
            assert isinstance(tokens, list)
 
            # Remove prefixes from the tokens
@@ -319,14 +316,14 @@ def tokenize_and_align_labels(
                        tokens[tok_idx] = tok[len(prefix) :]
 
            # Replace UNK tokens with the correct word
-            tokens = handle_unk_tokens(tokenizer=tokenizer, tokens=tokens, words=words)
+            tokens = handle_unk_tokens(tokeniser=tokeniser, tokens=tokens, words=words)
 
-            # Get list of special tokens. Some tokenizers do not record these
+            # Get list of special tokens. Some tokenisers do not record these
            # properly, which is why we convert the values to their indices and
            # then back to strings
            sp_toks = [
-                tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(sp_tok))
-                for sp_tok in tokenizer.special_tokens_map.values()
+                tokeniser.convert_ids_to_tokens(tokeniser.convert_tokens_to_ids(sp_tok))
+                for sp_tok in tokeniser.special_tokens_map.values()
            ]
 
            # Replace special tokens with `None`
@@ -350,7 +347,7 @@ def tokenize_and_align_labels(
            if len(word_idxs) != len(token_idxs):
                raise InvalidBenchmark(
                    "The tokens could not be aligned with the words during manual "
-                    "word-token alignment. It seems that the tokenizer is neither "
+                    "word-token alignment. It seems that the tokeniser is neither "
                    "of the fast variant nor of a SentencePiece/WordPiece variant."
                )
 
@@ -380,9 +377,9 @@ def tokenize_and_align_labels(
                label = labels[word_id]
                try:
                    label_id = label2id[label.lower()]
-                except KeyError:
+                except KeyError as e:
                    msg = f"The label {label} was not found in the model's config."
-                    raise InvalidBenchmark(msg)
+                    raise InvalidBenchmark(msg) from e
                label_ids.append(label_id)
 
            # For the other tokens in a word, we set the label to -100
@@ -397,13 +394,13 @@ def tokenize_and_align_labels(
 
 
 def handle_unk_tokens(
-    tokenizer: "PreTrainedTokenizer", tokens: list[str], words: list[str]
-) -> list[str]:
+    tokeniser: "PreTrainedTokenizer", tokens: list[str], words: c.Sequence[str]
+) -> c.Sequence[str]:
     """Replace unknown tokens in the tokens with the corresponding word.
 
     Args:
-        tokenizer:
-            The tokenizer used to tokenize the words.
+        tokeniser:
+            The tokeniser used to tokenise the words.
         tokens:
             The list of tokens.
         words:
@@ -413,15 +410,15 @@ def handle_unk_tokens(
         The list of tokens with unknown tokens replaced by the corresponding word.
     """
     # Locate the token indices of the unknown tokens
-    token_unk_idxs = [i for i, tok in enumerate(tokens) if tok == tokenizer.unk_token]
+    token_unk_idxs = [i for i, tok in enumerate(tokens) if tok == tokeniser.unk_token]
 
     # Locate the word indices of the words which contain an unknown token
     word_unk_idxs = [
         i
         for i, word in enumerate(words)
-        if tokenizer.unk_token
-        in tokenizer.convert_ids_to_tokens(
-            tokenizer.encode(word, add_special_tokens=False)
+        if tokeniser.unk_token
+        in tokeniser.convert_ids_to_tokens(
+            tokeniser.encode(word, add_special_tokens=False)
        )
    ]
 
@@ -430,9 +427,9 @@ def handle_unk_tokens(
        # Fetch the word
        word = words[word_idx]
 
-        # Tokenize the word, which is now a list containing at least one UNK token
-        tokens_with_unk = tokenizer.convert_ids_to_tokens(
-            tokenizer.encode(word, add_special_tokens=False)
+        # Tokenise the word, which is now a list containing at least one UNK token
+        tokens_with_unk = tokeniser.convert_ids_to_tokens(
+            tokeniser.encode(word, add_special_tokens=False)
        )
 
        # Iterate over the tokens in the word
@@ -441,10 +438,10 @@ def handle_unk_tokens(
            # of the content of this token from the word. The result of the `word`
            # variable will be the content of the UNK token.
            # NOTE: This is a bit hacky and not bulletproof. For instance, if the
-            # word is "1925-1950" and the tokenizer splits it into ["[UNK]", "-",
+            # word is "1925-1950" and the tokeniser splits it into ["[UNK]", "-",
            # "19", "50"], then the result will be 2519 instead of 1925. This
            # happens almost never, however, so we can live with it.
-            if possible_unk_token != tokenizer.unk_token:
+            if possible_unk_token != tokeniser.unk_token:
                word = word.replace(possible_unk_token, "", 1)
 
            # Replace the token with the word
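
A note on the parsing change above: the inline regex plus demjson3 handling in extract_labels_from_generation is folded into a single call to extract_json_dict_from_string, imported from euroeval.utils (that module also changed in this release but is not shown in this section). The sketch below only illustrates the contract the new call site relies on, namely returning a dict with string keys when a JSON object can be recovered and None otherwise; the strict json.loads parsing and the flat-object regex are assumptions carried over from the removed code, and the real helper in euroeval/utils.py may well parse more leniently.

import json
import re


def extract_json_dict_from_string(s: str) -> dict | None:
    """Illustrative stand-in for the helper imported from euroeval.utils."""
    # Grab the first flat JSON object (no nested braces), as the removed regex did.
    match = re.search(r"\{[^{}]+?\}", s, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group())
    except json.JSONDecodeError:
        return None
    # The NER call site expects a mapping with string keys, so anything else
    # counts as a failed parse.
    if not isinstance(parsed, dict) or not all(isinstance(k, str) for k in parsed):
        return None
    return parsed


print(extract_json_dict_from_string('Sure! {"person": ["Alice"], "location": []}'))
# {'person': ['Alice'], 'location': []}
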
euroeval/tasks.py CHANGED
@@ -1,35 +1,29 @@
 """All benchmarks tasks used in EuroEval."""
 
 from . import metrics as m
+from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
 from .data_models import Task
-from .enums import TaskGroup
+from .enums import GenerativeType, ModelType, TaskGroup
 from .prompt_templates import (
+    CLASSIFICATION_TEMPLATES,
     LA_TEMPLATES,
     MULTIPLE_CHOICE_TEMPLATES,
     NER_TEMPLATES,
     RC_TEMPLATES,
     SENT_TEMPLATES,
     SUMM_TEMPLATES,
+    TOKEN_CLASSIFICATION_TEMPLATES,
 )
 
-
-def get_all_tasks() -> dict[str, Task]:
-    """Get a list of all the dataset tasks.
-
-    Returns:
-        A mapping between names of dataset tasks and their configurations.
-    """
-    return {cfg.name: cfg for cfg in globals().values() if isinstance(cfg, Task)}
-
-
 LA = Task(
     name="linguistic-acceptability",
     task_group=TaskGroup.SEQUENCE_CLASSIFICATION,
     template_dict=LA_TEMPLATES,
     metrics=[m.mcc_metric, m.macro_f1_metric],
     default_num_few_shot_examples=12,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["correct", "incorrect"],
+    uses_logprobs=True,
 )
 
 
@@ -51,6 +45,7 @@ NER = Task(
         "b-misc",
         "i-misc",
     ],
+    uses_structured_output=True,
 )
 
 
@@ -71,8 +66,9 @@ SENT = Task(
     template_dict=SENT_TEMPLATES,
     metrics=[m.mcc_metric, m.macro_f1_metric],
     default_num_few_shot_examples=12,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["positive", "neutral", "negative"],
+    uses_logprobs=True,
 )
 
 
@@ -84,6 +80,7 @@ SUMM = Task(
     default_num_few_shot_examples=1,
     default_max_generated_tokens=256,
     default_labels=[],
+    default_allowed_model_types=[ModelType.GENERATIVE],
 )
 
 
@@ -93,8 +90,10 @@ KNOW = Task(
     template_dict=MULTIPLE_CHOICE_TEMPLATES,
     metrics=[m.mcc_metric, m.accuracy_metric],
     default_num_few_shot_examples=5,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["a", "b", "c", "d"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
 )
 
 
@@ -104,8 +103,10 @@ MCRC = Task(
     template_dict=MULTIPLE_CHOICE_TEMPLATES,
     metrics=[m.mcc_metric, m.accuracy_metric],
     default_num_few_shot_examples=5,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["a", "b", "c", "d"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
 )
 
 
@@ -115,8 +116,29 @@ COMMON_SENSE = Task(
     template_dict=MULTIPLE_CHOICE_TEMPLATES,
     metrics=[m.mcc_metric, m.accuracy_metric],
     default_num_few_shot_examples=5,
-    default_max_generated_tokens=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
     default_labels=["a", "b", "c", "d"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
+)
+
+
+EUROPEAN_VALUES = Task(
+    name="european-values",
+    task_group=TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION,
+    template_dict=MULTIPLE_CHOICE_TEMPLATES,
+    metrics=[m.european_values_metric],
+    default_num_few_shot_examples=0,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
+    default_labels=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"],
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    default_allowed_generative_types=[
+        GenerativeType.INSTRUCTION_TUNED,
+        GenerativeType.REASONING,
+    ],
+    requires_zero_shot=True,
+    uses_logprobs=True,
+    default_allow_invalid_model_outputs=False,
 )
 
 
@@ -129,3 +151,40 @@ SPEED = Task(
     default_max_generated_tokens=5,
     default_labels=[],
 )
+
+
+# Used for custom datasets
+
+TEXT_CLASSIFICATION = Task(
+    name="classification",
+    task_group=TaskGroup.SEQUENCE_CLASSIFICATION,
+    template_dict=CLASSIFICATION_TEMPLATES,
+    metrics=[m.mcc_metric, m.macro_f1_metric],
+    default_num_few_shot_examples=12,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
+    default_labels=None,
+    uses_logprobs=True,
+)
+
+TOKEN_CLASSIFICATION = Task(
+    name="token-classification",
+    task_group=TaskGroup.TOKEN_CLASSIFICATION,
+    template_dict=TOKEN_CLASSIFICATION_TEMPLATES,
+    metrics=[m.micro_f1_metric],
+    default_num_few_shot_examples=8,
+    default_max_generated_tokens=128,
+    default_labels=None,
+    uses_structured_output=True,
+)
+
+MULTIPLE_CHOICE = Task(
+    name="multiple-choice",
+    task_group=TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION,
+    template_dict=MULTIPLE_CHOICE_TEMPLATES,
+    metrics=[m.mcc_metric, m.accuracy_metric],
+    default_num_few_shot_examples=5,
+    default_max_generated_tokens=NUM_GENERATION_TOKENS_FOR_CLASSIFICATION,
+    default_labels=None,
+    default_allowed_model_types=[ModelType.GENERATIVE],
+    uses_logprobs=True,
+)
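
A note on the removed get_all_tasks() helper at the top of this file: its deleted body shows exactly what it returned, a mapping from task names to Task configurations collected from the module's globals. Downstream code that relied on it can rebuild the same mapping as sketched below; whether EuroEval 16.x ships an official replacement elsewhere is not visible in this diff.

from euroeval import tasks
from euroeval.data_models import Task

# Rebuild the name -> Task mapping the removed helper used to return, by
# scanning the module attributes for Task instances.
all_tasks: dict[str, Task] = {
    cfg.name: cfg for cfg in vars(tasks).values() if isinstance(cfg, Task)
}

# The new custom-dataset tasks appear alongside the existing ones, e.g.
# "classification", "token-classification" and "multiple-choice".
print(sorted(all_tasks))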