euroeval-15.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (40)
  1. euroeval/__init__.py +72 -0
  2. euroeval/benchmark_config_factory.py +358 -0
  3. euroeval/benchmark_modules/__init__.py +7 -0
  4. euroeval/benchmark_modules/base.py +354 -0
  5. euroeval/benchmark_modules/fresh.py +286 -0
  6. euroeval/benchmark_modules/hf.py +1185 -0
  7. euroeval/benchmark_modules/litellm.py +905 -0
  8. euroeval/benchmark_modules/vllm.py +1171 -0
  9. euroeval/benchmarker.py +1074 -0
  10. euroeval/callbacks.py +72 -0
  11. euroeval/cli.py +281 -0
  12. euroeval/constants.py +50 -0
  13. euroeval/data_loading.py +96 -0
  14. euroeval/data_models.py +474 -0
  15. euroeval/dataset_configs.py +2001 -0
  16. euroeval/enums.py +144 -0
  17. euroeval/exceptions.py +191 -0
  18. euroeval/finetuning.py +324 -0
  19. euroeval/generation.py +296 -0
  20. euroeval/human_evaluation.py +737 -0
  21. euroeval/languages.py +200 -0
  22. euroeval/model_cache.py +253 -0
  23. euroeval/model_config.py +77 -0
  24. euroeval/model_loading.py +78 -0
  25. euroeval/scores.py +90 -0
  26. euroeval/speed_benchmark.py +124 -0
  27. euroeval/task_utils/__init__.py +1 -0
  28. euroeval/task_utils/multiple_choice_classification.py +176 -0
  29. euroeval/task_utils/question_answering.py +698 -0
  30. euroeval/task_utils/sequence_classification.py +237 -0
  31. euroeval/task_utils/text_to_text.py +150 -0
  32. euroeval/task_utils/token_classification.py +464 -0
  33. euroeval/tasks.py +202 -0
  34. euroeval/types.py +97 -0
  35. euroeval/utils.py +574 -0
  36. euroeval-15.2.0.dist-info/METADATA +234 -0
  37. euroeval-15.2.0.dist-info/RECORD +40 -0
  38. euroeval-15.2.0.dist-info/WHEEL +4 -0
  39. euroeval-15.2.0.dist-info/entry_points.txt +4 -0
  40. euroeval-15.2.0.dist-info/licenses/LICENSE +21 -0
euroeval/task_utils/question_answering.py
@@ -0,0 +1,698 @@
1
+ """Utility functions related to the question-answering task group."""
2
+
3
+ import collections.abc as c
4
+ import logging
5
+ import typing as t
6
+ from collections import defaultdict
7
+
8
+ import evaluate
9
+ import numpy as np
10
+ from evaluate import EvaluationModule
11
+ from transformers import PreTrainedTokenizer
12
+ from transformers.trainer import Trainer
13
+
14
+ from ..data_models import BenchmarkConfig, DatasetConfig, GenerativeModelOutput
15
+ from ..utils import (
16
+ get_special_token_metadata,
17
+ raise_if_model_output_contains_nan_values,
18
+ )
19
+
20
+ if t.TYPE_CHECKING:
21
+ import torch.nn as nn
22
+ from datasets.arrow_dataset import Dataset
23
+ from transformers import (
24
+ BaseImageProcessor,
25
+ EvalPrediction,
26
+ FeatureExtractionMixin,
27
+ PreTrainedModel,
28
+ PreTrainedTokenizerBase,
29
+ ProcessorMixin,
30
+ TrainerCallback,
31
+ TrainingArguments,
32
+ )
33
+ from transformers.tokenization_utils_base import BatchEncoding
34
+
35
+ from ..types import Labels, Predictions
36
+
37
+ logger = logging.getLogger("euroeval")
38
+
39
+
40
+ class QuestionAnsweringTrainer(Trainer):
41
+ """Trainer subclass for question answering tasks."""
42
+
43
+ def __init__(
44
+ self,
45
+ model: "PreTrainedModel | nn.Module",
46
+ processing_class: "PreTrainedTokenizerBase",
47
+ args: "TrainingArguments",
48
+ train_dataset: "Dataset",
49
+ eval_dataset: "Dataset",
50
+ compute_metrics: "c.Callable[[EvalPrediction], dict[str, float]]",
51
+ callbacks: "list[TrainerCallback]",
52
+ data_collator: "c.Callable",
53
+ ) -> None:
54
+ """Initialize the trainer."""
55
+ super().__init__(
56
+ model=model,
57
+ processing_class=processing_class,
58
+ args=args,
59
+ train_dataset=train_dataset,
60
+ eval_dataset=eval_dataset,
61
+ compute_metrics=compute_metrics,
62
+ callbacks=callbacks,
63
+ data_collator=data_collator,
64
+ )
65
+
66
+ # Get the CLS token id for the tokenizer
67
+ if self.tokenizer is not None:
68
+ assert isinstance(self.tokenizer, PreTrainedTokenizer)
69
+ special_token_metadata = get_special_token_metadata(self.tokenizer)
70
+ self.cls_token_id = special_token_metadata["cls_token_id"]
71
+
72
+ # Set the label names
73
+ self.label_names = ["start_positions", "end_positions"]
74
+
75
+ def evaluate(
76
+ self,
77
+ eval_dataset: "Dataset | None" = None,
78
+ orig_eval_dataset: "Dataset | None" = None,
79
+ ignore_keys: list[str] | None = None,
80
+ metric_key_prefix: str = "eval",
81
+ ) -> dict[str, float] | None:
82
+ """Evaluate the model on the given dataset.
83
+
84
+ Args:
85
+ eval_dataset:
86
+ The dataset to evaluate on. If None, then use the stored evaluation
87
+ dataset.
88
+ orig_eval_dataset:
89
+ The original evaluation dataset, before any postprocessing. If None,
90
+ then use the stored original evaluation dataset.
91
+ ignore_keys:
92
+ The keys to ignore when computing the metrics.
93
+ metric_key_prefix:
94
+ The prefix to use for the metric keys.
95
+
96
+ Returns:
97
+ The metrics computed on the evaluation dataset.
98
+ """
99
+ eval_dataloader = self.get_eval_dataloader(eval_dataset)
100
+
101
+ # Temporarily disable metric computation; we will do it ourselves after the loop.
102
+ compute_metrics = self.compute_metrics # type: ignore[has-type]
103
+ self.compute_metrics = None
104
+ eval_loop = (
105
+ self.prediction_loop
106
+ if self.args.use_legacy_prediction_loop
107
+ else self.evaluation_loop
108
+ )
109
+ try:
110
+ output = eval_loop(
111
+ eval_dataloader,
112
+ description="Evaluation",
113
+ prediction_loss_only=True if compute_metrics is None else None,
114
+ ignore_keys=ignore_keys,
115
+ metric_key_prefix=metric_key_prefix,
116
+ )
117
+ finally:
118
+ self.compute_metrics = compute_metrics
119
+
120
+ if orig_eval_dataset is not None:
121
+ preds_and_labels = postprocess_predictions_and_labels(
122
+ predictions=output.predictions,
123
+ dataset=orig_eval_dataset,
124
+ prepared_dataset=eval_dataset,
125
+ cls_token_index=self.cls_token_id,
126
+ )
127
+ output.metrics.update(self.compute_metrics(preds_and_labels))
128
+
129
+ # Prefix all keys with metric_key_prefix + '_'
130
+ for key in list(output.metrics.keys()):
131
+ if not key.startswith(f"{metric_key_prefix}_"):
132
+ output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(
133
+ key
134
+ )
135
+
136
+ # Only the main node logs the results by default
137
+ if self.args.should_log:
138
+ self.log(output.metrics)
139
+
140
+ self.control = self.callback_handler.on_evaluate(
141
+ self.args,
142
+ self.state,
143
+ self.control, # type: ignore[has-type]
144
+ output.metrics,
145
+ )
146
+ return output.metrics
147
+
148
+
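The evaluate() override above runs the standard prediction loop with metric computation disabled, then maps the raw start/end logits back onto the original (untokenized) examples via postprocess_predictions_and_labels() before computing the metrics. A minimal usage sketch, where qa_trainer, prepared_val and orig_val are hypothetical names for an already-constructed QuestionAnsweringTrainer, the tokenized validation features and the original validation examples:

# Illustrative sketch only; not part of the package source.
metrics = qa_trainer.evaluate(
    eval_dataset=prepared_val,        # tokenized features (hypothetical variable)
    orig_eval_dataset=orig_val,       # original examples (hypothetical variable)
    metric_key_prefix="val",
)
# Every returned key carries the prefix, e.g. "val_loss"; the exact metric names
# depend on the dataset configuration.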
149
+ def compute_metrics(
150
+ model_outputs_and_labels: tuple["Predictions", "Labels"],
151
+ dataset_config: "DatasetConfig",
152
+ benchmark_config: "BenchmarkConfig",
153
+ ) -> dict[str, float]:
154
+ """Compute the metrics needed for evaluation.
155
+
156
+ Args:
157
+ model_outputs_and_labels:
158
+ The first sequence contains the model outputs and the second sequence
159
+ contains the true labels.
160
+ dataset_config:
161
+ The configuration of the dataset.
162
+ benchmark_config:
163
+ The configuration of the benchmark.
164
+
165
+ Returns:
166
+ A dictionary with the names of the metrics as keys and the metric values as
167
+ values.
168
+ """
169
+ model_outputs, labels = model_outputs_and_labels
170
+ raise_if_model_output_contains_nan_values(model_output=model_outputs)
171
+
172
+ metrics = {
173
+ metric_cfg.name: (
174
+ evaluate.load(
175
+ path=metric_cfg.huggingface_id, cache_dir=benchmark_config.cache_dir
176
+ )
177
+ if metric_cfg.huggingface_id != ""
178
+ else None
179
+ )
180
+ for metric_cfg in dataset_config.task.metrics
181
+ }
182
+
183
+ model_output_dtype = np.asarray(model_outputs).dtype
184
+ if model_output_dtype in [np.float16, np.float32, np.float64]:
185
+ predictions = np.asarray(model_outputs).argmax(axis=-1)
186
+ else:
187
+ predictions = model_outputs
188
+
189
+ results: dict[str, float] = dict()
190
+ for cfg in dataset_config.task.metrics:
191
+ metric = metrics[cfg.name]
192
+ assert isinstance(metric, EvaluationModule)
193
+ score_dict: dict[str, float] | None = metric.compute(
194
+ predictions=predictions, references=labels, **cfg.compute_kwargs
195
+ )
196
+
197
+ # The metric returns None if we are running on multi-GPU and the current
198
+ # process is not the main process
199
+ if score_dict is not None:
200
+ scores = score_dict[cfg.results_key]
201
+ if isinstance(scores, list):
202
+ scores = sum(scores) / len(scores)
203
+ results[cfg.name] = scores
204
+
205
+ return results
206
+
207
+
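The compute_metrics() helper above loads each configured Hugging Face metric once and feeds it the predictions and references built elsewhere in this module. Which metric is loaded depends on the dataset config's huggingface_id; the dictionaries used here (prediction_text plus no_answer_probability, and answers with text/answer_start) follow the format of the Hugging Face "squad_v2" metric, so a sketch using that metric as an assumed example:

# Illustrative sketch; assumes the "squad_v2" metric, whose input format matches
# the prediction/reference dictionaries built in this module.
import evaluate

squad_v2 = evaluate.load("squad_v2")
predictions = [
    dict(id="q1", prediction_text="copenhagen", no_answer_probability=0.0)
]
references = [
    dict(id="q1", answers=dict(text=["Copenhagen"], answer_start=[0]))
]
scores = squad_v2.compute(predictions=predictions, references=references)
# `scores` contains keys such as "exact" and "f1"; compute_metrics() keeps only the
# configured results_key for each metric.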
208
+ def extract_labels_from_generation(
209
+ input_batch: dict[str, list], model_output: "GenerativeModelOutput"
210
+ ) -> list[t.Any]:
211
+ """Extract the predicted labels from the generated output.
212
+
213
+ Args:
214
+ input_batch:
215
+ The input batch, where the keys are the feature names and the values
216
+ are lists with the feature values.
217
+ model_output:
218
+ The raw generated output of the model.
219
+
220
+ Returns:
221
+ The predicted labels.
222
+ """
223
+ raw_predictions = model_output.sequences
224
+ predictions = [
225
+ dict(id=id, prediction_text=predicted_answer.lower(), no_answer_probability=0.0)
226
+ for id, predicted_answer in zip(input_batch["id"], raw_predictions)
227
+ ]
228
+ return predictions
229
+
230
+
231
+ def prepare_train_examples(
232
+ examples: "BatchEncoding", tokenizer: "PreTrainedTokenizer"
233
+ ) -> "BatchEncoding":
234
+ """Prepare the features for training.
235
+
236
+ Args:
237
+ examples:
238
+ The examples to prepare.
239
+ tokenizer:
240
+ The tokenizer to use to prepare the examples.
241
+
242
+ Returns:
243
+ The prepared examples.
244
+ """
245
+ # Some of the questions have lots of whitespace on the left, which is not useful
246
+ # and will make the truncation of the context fail (the tokenized question will
247
+ # take up a lot of space). So we remove that leading whitespace
248
+ examples["question"] = [q.lstrip() for q in examples["question"]]
249
+
250
+ # Extract special token metadata from the tokenizer
251
+ special_token_metadata = get_special_token_metadata(tokenizer=tokenizer)
252
+ has_cls_token = special_token_metadata["has_cls_token"]
253
+ has_sep_token = special_token_metadata["has_sep_token"]
254
+ cls_token_id = special_token_metadata["cls_token_id"]
255
+ cls_token = special_token_metadata["cls_token"]
256
+ sep_token = special_token_metadata["sep_token"]
257
+
258
+ # If the tokenizer is not adding special tokens, then we add them manually
259
+ if not has_cls_token and not has_sep_token:
260
+ examples["question"] = [
261
+ f"{cls_token}{q}{sep_token}" for q in examples["question"]
262
+ ]
263
+ examples["context"] = [f"{c}{sep_token}" for c in examples["context"]]
264
+
265
+ # Set the stride used during tokenization, when the context is long enough to be
266
+ # split into several features. Since we are always keeping the question tokens, we
267
+ # need to make sure that the stride does not exceed the resulting maximum context
268
+ # length.
269
+ max_question_tokens = max(len(tokenizer(q).input_ids) for q in examples["question"])
270
+ num_special_tokens = int(has_cls_token) + int(has_sep_token)
271
+ stride = tokenizer.model_max_length // 4
272
+ max_length = tokenizer.model_max_length - stride
273
+ stride = min(stride, max_length - max_question_tokens - num_special_tokens)
274
+ max_length = tokenizer.model_max_length - stride
275
+
276
+ # Tokenize our examples with truncation and padding, but keep the overflows using a
277
+ # stride. This results in one example possibly giving several features when a
278
+ # context is long, each of those features having a context that overlaps a bit the
279
+ # context of the previous feature.
280
+ tokenized_examples = tokenizer(
281
+ text=examples["question"],
282
+ text_pair=examples["context"],
283
+ truncation="only_second",
284
+ max_length=max_length,
285
+ stride=stride,
286
+ return_overflowing_tokens=True,
287
+ return_offsets_mapping=True,
288
+ padding="max_length",
289
+ )
290
+
291
+ # Since one example might give us several features if it has a long context, we
292
+ # need a map from a feature to its corresponding example. This key gives us just
293
+ # that
294
+ sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
295
+
296
+ # The offset mappings will give us a map from token to character position in the
297
+ # original context. This will help us compute the start_positions and
298
+ # end_positions.
299
+ offset_mapping = tokenized_examples.pop("offset_mapping")
300
+
301
+ # Initialise the start- and end positions of the answers
302
+ tokenized_examples["start_positions"] = list()
303
+ tokenized_examples["end_positions"] = list()
304
+
305
+ for i, offsets in enumerate(offset_mapping):
306
+ # Get the input IDs for the current example
307
+ input_ids = tokenized_examples.input_ids[i]
308
+
309
+ # We will label impossible answers with the index of the CLS token
310
+ cls_index = input_ids.index(cls_token_id)
311
+
312
+ # Grab the sequence corresponding to that example (to know what is the context
313
+ # and what is the question).
314
+ sequence_ids = tokenized_examples.sequence_ids(i)
315
+
316
+ # Manually ensure that the special tokens are set to None in `sequence_ids`
317
+ for special_token in tokenizer.special_tokens_map.keys():
318
+ if hasattr(tokenizer, f"{special_token}_id"):
319
+ special_token_id = getattr(tokenizer, f"{special_token}_id")
320
+ if special_token_id is not None:
321
+ sequence_ids = [
322
+ None if token_id == special_token_id else seq_id
323
+ for token_id, seq_id in zip(input_ids, sequence_ids)
324
+ ]
325
+
326
+ # One example can give several spans; this is the index of the example
327
+ # containing this span of text.
328
+ sample_index = sample_mapping[i]
329
+ answers = examples["answers"][sample_index]
330
+
331
+ # If no answers are given, set the cls_index as answer.
332
+ if len(answers["answer_start"]) == 0:
333
+ tokenized_examples.start_positions.append(cls_index)
334
+ tokenized_examples.end_positions.append(cls_index)
335
+
336
+ else:
337
+ # Start/end character index of the answer in the text.
338
+ start_char = answers["answer_start"][0]
339
+ end_char = start_char + len(answers["text"][0])
340
+
341
+ # Start token index of the current span in the text.
342
+ token_start_index = 0
343
+ while sequence_ids[token_start_index] != 1:
344
+ token_start_index += 1
345
+
346
+ # End token index of the current span in the text.
347
+ token_end_index = len(input_ids) - 1
348
+ while sequence_ids[token_end_index] != 1:
349
+ token_end_index -= 1
350
+
351
+ # Detect if the answer is out of the span (in which case this feature is
352
+ # labeled with the CLS index).
353
+ if not (
354
+ offsets[token_start_index][0] <= start_char
355
+ and offsets[token_end_index][1] >= end_char
356
+ ):
357
+ tokenized_examples.start_positions.append(cls_index)
358
+ tokenized_examples.end_positions.append(cls_index)
359
+
360
+ # Otherwise move the token_start_index and token_end_index to the two ends
361
+ # of the answer. Note: we could go after the last offset if the answer is
362
+ # the last word (edge case).
363
+ else:
364
+ while (
365
+ token_start_index <= token_end_index
366
+ and offsets[token_start_index][0] <= start_char
367
+ ):
368
+ token_start_index += 1
369
+ token_start_index -= 1
370
+ tokenized_examples.start_positions.append(token_start_index)
371
+ while (
372
+ token_start_index <= token_end_index
373
+ and offsets[token_end_index][1] >= end_char
374
+ ):
375
+ token_end_index -= 1
376
+ token_end_index += 1
377
+ tokenized_examples.end_positions.append(token_end_index)
378
+ assert token_end_index >= token_start_index
379
+
380
+ return tokenized_examples
381
+
382
+
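The stride arithmetic shared by prepare_train_examples() and prepare_test_examples() caps the stride so that it never exceeds the maximum context length left over after the question and the special tokens. A worked example, assuming a hypothetical tokenizer with a model_max_length of 512 and both a CLS and a SEP token:

# Illustrative sketch of the stride / max_length computation above.
model_max_length = 512
num_special_tokens = 2

def stride_and_max_length(max_question_tokens: int) -> tuple[int, int]:
    stride = model_max_length // 4                       # 128
    max_length = model_max_length - stride               # 384
    stride = min(stride, max_length - max_question_tokens - num_special_tokens)
    max_length = model_max_length - stride
    return stride, max_length

print(stride_and_max_length(32))   # (128, 384): short questions keep the full stride
print(stride_and_max_length(300))  # (82, 430): very long questions shrink the stride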
383
+ def prepare_test_examples(
384
+ examples: "BatchEncoding", tokenizer: "PreTrainedTokenizer"
385
+ ) -> "BatchEncoding":
386
+ """Prepare test examples.
387
+
388
+ Args:
389
+ examples:
390
+ Dictionary of test examples.
391
+ tokenizer:
392
+ The tokenizer used to preprocess the examples.
393
+
394
+ Returns:
395
+ The prepared test examples.
396
+ """
397
+ # Some of the questions have lots of whitespace on the left, which is not useful
398
+ # and will make the truncation of the context fail (the tokenized question will
399
+ # take up a lot of space). So we remove that leading whitespace
400
+ examples["question"] = [q.lstrip() for q in examples["question"]]
401
+
402
+ # Extract special token metadata from the tokenizer
403
+ special_token_metadata = get_special_token_metadata(tokenizer=tokenizer)
404
+ has_cls_token = special_token_metadata["has_cls_token"]
405
+ has_sep_token = special_token_metadata["has_sep_token"]
406
+ cls_token = special_token_metadata["cls_token"]
407
+ sep_token = special_token_metadata["sep_token"]
408
+
409
+ # If the tokenizer is not adding special tokens, then we add them manually
410
+ if not has_cls_token and not has_sep_token:
411
+ examples["question"] = [
412
+ f"{cls_token}{q}{sep_token}" for q in examples["question"]
413
+ ]
414
+ examples["context"] = [f"{c}{sep_token}" for c in examples["context"]]
415
+
416
+ # Set the stride used during tokenization, when the context is long enough to be
417
+ # split into several features. Since we are always keeping the question tokens, we
418
+ # need to make sure that the stride does not exceed the resulting maximum context
419
+ # length.
420
+ max_question_tokens = max(len(tokenizer(q).input_ids) for q in examples["question"])
421
+ num_special_tokens = int(has_cls_token) + int(has_sep_token)
422
+ stride = tokenizer.model_max_length // 4
423
+ max_length = tokenizer.model_max_length - stride
424
+ stride = min(stride, max_length - max_question_tokens - num_special_tokens)
425
+ max_length = tokenizer.model_max_length - stride
426
+
427
+ # Tokenize our examples with truncation and maybe padding, but keep the overflows
428
+ # using a stride. This results in one example possibly giving several features when
429
+ # a context is long, each of those features having a context that overlaps a bit
430
+ # the context of the previous feature.
431
+ tokenized_examples = tokenizer(
432
+ text=examples["question"],
433
+ text_pair=examples["context"],
434
+ truncation="only_second",
435
+ max_length=max_length,
436
+ stride=stride,
437
+ return_overflowing_tokens=True,
438
+ return_offsets_mapping=True,
439
+ padding="max_length",
440
+ )
441
+
442
+ # Since one example might give us several features if it has a long context, we
443
+ # need a map from a feature to its corresponding example. This key gives us just
444
+ # that.
445
+ sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
446
+
447
+ # We keep the id of the example that gave us this feature, and we will adjust the offset mappings below.
448
+ tokenized_examples["id"] = list()
449
+
450
+ for i in range(len(tokenized_examples.input_ids)):
451
+ # Grab the sequence corresponding to that example (to know what is the context
452
+ # and what is the question).
453
+ sequence_ids = tokenized_examples.sequence_ids(i)
454
+ context_index = 1
455
+
456
+ # One example can give several spans; this is the index of the example
457
+ # containing this span of text.
458
+ sample_index = sample_mapping[i]
459
+ tokenized_examples.id.append(examples["id"][sample_index])
460
+
461
+ # Set the offset_mapping entries that are not part of the context to (-1, -1), so
462
+ # it's easy to determine whether a token position is part of the context or not.
463
+ tokenized_examples.offset_mapping[i] = [
464
+ (o if sequence_ids[k] == context_index else (-1, -1))
465
+ for k, o in enumerate(tokenized_examples.offset_mapping[i])
466
+ ]
467
+
468
+ return tokenized_examples
469
+
470
+
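The masking step at the end of prepare_test_examples() marks every non-context position in the offset mapping with (-1, -1), which is what find_valid_answers() later checks in order to skip answer spans that would start or end outside the context. A self-contained toy illustration of that list comprehension:

# Toy illustration of the offset-mapping masking above (sequence id 1 == context).
sequence_ids = [None, 0, 0, None, 1, 1, 1, None]
offset_mapping = [(0, 0), (0, 3), (4, 9), (0, 0), (0, 5), (6, 11), (12, 18), (0, 0)]
masked = [
    (o if sequence_ids[k] == 1 else (-1, -1))
    for k, o in enumerate(offset_mapping)
]
# masked == [(-1, -1), (-1, -1), (-1, -1), (-1, -1), (0, 5), (6, 11), (12, 18), (-1, -1)]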
471
+ def postprocess_predictions_and_labels(
472
+ predictions: list,
473
+ dataset: "Dataset",
474
+ prepared_dataset: "Dataset",
475
+ cls_token_index: int,
476
+ ) -> tuple[list[dict], list[dict]]:
477
+ """Postprocess the predictions and labels, to allow easier metric computation.
478
+
479
+ Args:
480
+ predictions:
481
+ A pair of (start_logits, end_logits) predictions.
482
+ dataset:
483
+ The dataset containing the examples.
484
+ prepared_dataset:
485
+ The dataset containing the prepared examples.
486
+ cls_token_index:
487
+ The index of the CLS token.
488
+
489
+ Returns:
490
+ The postprocessed predictions and labels.
491
+ """
492
+ # Extract the logits from the predictions
493
+ all_start_logits = predictions[0]
494
+ all_end_logits = predictions[1]
495
+
496
+ # Build a map from an example to its corresponding features, being the blocks of
497
+ # text from the context that we're feeding into the model. An example can have
498
+ # multiple features/blocks if it has a long context.
499
+ id_to_index = {k: i for i, k in enumerate(dataset["id"])}
500
+ features_per_example = defaultdict(list)
501
+ for i, feature in enumerate(prepared_dataset):
502
+ id = feature["id"]
503
+ example_index = id_to_index[id]
504
+ features_per_example[example_index].append(i)
505
+
506
+ # Loop over all the examples
507
+ predictions = list()
508
+ labels = list()
509
+ for example_index, example in enumerate(dataset):
510
+ # Extract the best valid answer associated with the current example
511
+ best_answer = find_best_answer(
512
+ all_start_logits=all_start_logits,
513
+ all_end_logits=all_end_logits,
514
+ prepared_dataset=prepared_dataset,
515
+ feature_indices=features_per_example[example_index],
516
+ context=example["context"],
517
+ max_answer_length=30,
518
+ num_best_logits=20,
519
+ min_null_score=0.0,
520
+ cls_token_index=cls_token_index,
521
+ )
522
+
523
+ # Create the final prediction dictionary, to be added to the list of
524
+ # predictions
525
+ prediction = dict(
526
+ id=example["id"], prediction_text=best_answer, no_answer_probability=0.0
527
+ )
528
+
529
+ # Add the answer to the list of predictions
530
+ predictions.append(prediction)
531
+
532
+ # Create the associated reference dictionary, to be added to the list of
533
+ # references
534
+ label = dict(
535
+ id=example["id"],
536
+ answers=dict(
537
+ text=example["answers"]["text"],
538
+ answer_start=example["answers"]["answer_start"],
539
+ ),
540
+ )
541
+
542
+ # Add the reference to the list of labels
543
+ labels.append(label)
544
+
545
+ return predictions, labels
546
+
547
+
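Because a long context is split into several overlapping features, postprocess_predictions_and_labels() first groups the prepared features by the original example they came from, then picks the best answer across that group. A small self-contained illustration of the grouping logic:

# Toy illustration of the example-to-features grouping above.
from collections import defaultdict

dataset_ids = ["q1", "q2"]            # ids of the original examples
prepared_ids = ["q1", "q1", "q2"]     # ids of the features after overflow splitting
id_to_index = {k: i for i, k in enumerate(dataset_ids)}
features_per_example = defaultdict(list)
for i, feature_id in enumerate(prepared_ids):
    features_per_example[id_to_index[feature_id]].append(i)
# features_per_example == {0: [0, 1], 1: [2]}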
548
+ def find_best_answer(
549
+ all_start_logits: np.ndarray,
550
+ all_end_logits: np.ndarray,
551
+ prepared_dataset: "Dataset",
552
+ feature_indices: list[int],
553
+ context: str,
554
+ max_answer_length: int,
555
+ num_best_logits: int,
556
+ min_null_score: float,
557
+ cls_token_index: int,
558
+ ) -> str:
559
+ """Find the best answer for a given example.
560
+
561
+ Args:
562
+ all_start_logits:
563
+ The start logits for all the features.
564
+ all_end_logits:
565
+ The end logits for all the features.
566
+ prepared_dataset:
567
+ The dataset containing the prepared examples.
568
+ feature_indices:
569
+ The indices of the features associated with the current example.
570
+ context:
571
+ The context of the example.
572
+ max_answer_length:
573
+ The maximum length of the answer.
574
+ num_best_logits:
575
+ The number of best logits to consider.
576
+ min_null_score:
577
+ The minimum score an answer can have.
578
+ cls_token_index:
579
+ The index of the CLS token.
580
+
581
+ Returns:
582
+ The best answer for the example.
583
+ """
584
+ # Loop through all the features associated with the current example
585
+ valid_answers = list()
586
+ for feature_index in feature_indices:
587
+ # Get the features associated with the current example
588
+ features = prepared_dataset[feature_index]
589
+
590
+ # Get the predictions of the model for this feature
591
+ start_logits = all_start_logits[feature_index]
592
+ end_logits = all_end_logits[feature_index]
593
+
594
+ # Update minimum null prediction
595
+ cls_index = features["input_ids"].index(cls_token_index)
596
+ feature_null_score = (start_logits[cls_index] + end_logits[cls_index]).item()
597
+ if min_null_score < feature_null_score:
598
+ min_null_score = feature_null_score
599
+
600
+ # Find the valid answers for the feature
601
+ valid_answers_for_feature = find_valid_answers(
602
+ start_logits=start_logits,
603
+ end_logits=end_logits,
604
+ offset_mapping=features["offset_mapping"],
605
+ context=context,
606
+ max_answer_length=max_answer_length,
607
+ num_best_logits=num_best_logits,
608
+ min_null_score=min_null_score,
609
+ )
610
+ valid_answers.extend(valid_answers_for_feature)
611
+
612
+ # In the very rare edge case where we do not have a single non-null prediction, we
613
+ # return an empty string to avoid failure
614
+ if not valid_answers:
615
+ return ""
616
+
617
+ # Otherwise, we select the answer with the largest score as the best answer, and
618
+ # return it
619
+ best_answer_dict = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
620
+ return best_answer_dict["text"]
621
+
622
+
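In find_best_answer() above, the CLS position acts as the "no answer" candidate: its start and end logits define a null score, and only spans that score above the running best null score are kept. A toy illustration, assuming the CLS token sits at index 0:

# Toy illustration of the null-score update in find_best_answer().
import numpy as np

start_logits = np.array([3.0, 0.5, 0.2])   # index 0 is the CLS position here
end_logits = np.array([2.5, 0.4, 0.1])
cls_index = 0
feature_null_score = (start_logits[cls_index] + end_logits[cls_index]).item()  # 5.5
# Candidate spans whose start+end logit sum does not exceed 5.5 are discarded by
# find_valid_answers(), and find_best_answer() falls back to "" if nothing survives.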
623
+ def find_valid_answers(
624
+ start_logits: np.ndarray,
625
+ end_logits: np.ndarray,
626
+ offset_mapping: list[tuple[int, int]],
627
+ context: str,
628
+ max_answer_length: int,
629
+ num_best_logits: int,
630
+ min_null_score: float,
631
+ ) -> list[dict]:
632
+ """Find the valid answers from the start and end indexes.
633
+
634
+ Args:
635
+ start_logits:
636
+ The logits for the start of the answer.
637
+ end_logits:
638
+ The logits for the end of the answer.
639
+ offset_mapping:
640
+ The offset mapping, being a list of pairs of integers for each token index,
641
+ containing the start and end character index in the original context.
642
+ context:
643
+ The context of the example.
644
+ max_answer_length:
645
+ The maximum length of the answer.
646
+ num_best_logits:
647
+ The number of best logits to consider. Note that this function will run in
648
+ O(`num_best_logits` ^ 2) time.
649
+ min_null_score:
650
+ The minimum score an answer can have.
651
+
652
+ Returns:
653
+ A list of the valid answers, each being a dictionary with keys "text" and
654
+ "score", the score being the sum of the start and end logits.
655
+ """
656
+ # Fetch the top-k predictions for the start- and end token indices
657
+ start_indexes = np.argsort(start_logits)[-1 : -num_best_logits - 1 : -1].tolist()
658
+ end_indexes = np.argsort(end_logits)[-1 : -num_best_logits - 1 : -1].tolist()
659
+
660
+ # We loop over all combinations of starting and ending indexes for valid answers
661
+ valid_answers = list()
662
+ for start_index in start_indexes:
663
+ for end_index in end_indexes:
664
+ # If the starting or ending index is out-of-scope, meaning that they are
665
+ # either out of bounds or correspond to part of the input_ids that are not
666
+ # in the context, then we skip this index
667
+ if (
668
+ start_index >= len(offset_mapping)
669
+ or end_index >= len(offset_mapping)
670
+ or tuple(offset_mapping[start_index]) == (-1, -1)
671
+ or tuple(offset_mapping[end_index]) == (-1, -1)
672
+ ):
673
+ continue
674
+
675
+ # Do not consider answers with a length that is either negative or greater
676
+ # than the maximum answer length
677
+ max_val = max_answer_length + start_index - 1
678
+ if end_index < start_index or end_index > max_val:
679
+ continue
680
+
681
+ # If we got to this point then the answer is valid, so we store the
682
+ # corresponding start- and end character indices in the original context,
683
+ # and from these extract the answer
684
+ start_char = offset_mapping[start_index][0]
685
+ end_char = offset_mapping[end_index][1]
686
+ text = context[start_char:end_char]
687
+
688
+ # Compute the score of the answer, being the sum of the start and end
689
+ # logits. Intuitively, this indicates how likely the answer is to be
690
+ # correct, and allows us to pick the best valid answer.
691
+ score = start_logits[start_index] + end_logits[end_index]
692
+
693
+ # Add the answer to the list of valid answers, if the score is greater
694
+ # than the minimum null score
695
+ if score > min_null_score:
696
+ valid_answers.append(dict(score=score, text=text))
697
+
698
+ return valid_answers
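Finally, a self-contained sketch (not part of the package source) of find_valid_answers() on a toy feature. The offset mapping marks the two special-token positions with (-1, -1), exactly as prepare_test_examples() does for non-context positions:

# Illustrative sketch: calling find_valid_answers() on toy logits.
import numpy as np
from euroeval.task_utils.question_answering import find_valid_answers

context = "Copenhagen is the capital of Denmark"
offset_mapping = [(-1, -1), (0, 10), (11, 13), (14, 17), (18, 25), (26, 28), (29, 36), (-1, -1)]
start_logits = np.array([0.1, 5.0, 0.2, 0.1, 0.3, 0.1, 0.2, 0.0])
end_logits = np.array([0.1, 4.0, 0.1, 0.2, 0.1, 0.1, 0.3, 0.0])

answers = find_valid_answers(
    start_logits=start_logits,
    end_logits=end_logits,
    offset_mapping=offset_mapping,
    context=context,
    max_answer_length=30,
    num_best_logits=2,
    min_null_score=0.0,
)
# The highest-scoring candidate is "Copenhagen" with score 9.0; the (start=4, end=1)
# pair is rejected because its end index precedes its start index, and any pair
# pointing at a (-1, -1) position would likewise be skipped.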