EuroEval 15.12.0__py3-none-any.whl → 16.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- euroeval/__init__.py +32 -14
- euroeval/benchmark_config_factory.py +92 -180
- euroeval/benchmark_modules/base.py +49 -39
- euroeval/benchmark_modules/fresh.py +35 -21
- euroeval/benchmark_modules/hf.py +280 -244
- euroeval/benchmark_modules/litellm.py +752 -312
- euroeval/benchmark_modules/vllm.py +570 -268
- euroeval/benchmarker.py +651 -528
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +49 -38
- euroeval/constants.py +44 -25
- euroeval/data_loading.py +111 -55
- euroeval/data_models.py +490 -323
- euroeval/dataset_configs/__init__.py +26 -4
- euroeval/dataset_configs/bosnian.py +39 -0
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/croatian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +78 -50
- euroeval/dataset_configs/dutch.py +74 -44
- euroeval/dataset_configs/english.py +71 -36
- euroeval/dataset_configs/estonian.py +111 -0
- euroeval/dataset_configs/faroese.py +25 -18
- euroeval/dataset_configs/finnish.py +63 -26
- euroeval/dataset_configs/french.py +65 -32
- euroeval/dataset_configs/german.py +77 -36
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +68 -57
- euroeval/dataset_configs/italian.py +68 -36
- euroeval/dataset_configs/latvian.py +87 -0
- euroeval/dataset_configs/lithuanian.py +64 -0
- euroeval/dataset_configs/norwegian.py +98 -72
- euroeval/dataset_configs/polish.py +96 -0
- euroeval/dataset_configs/portuguese.py +63 -40
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/slovene.py +56 -0
- euroeval/dataset_configs/spanish.py +68 -34
- euroeval/dataset_configs/swedish.py +82 -41
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/enums.py +12 -6
- euroeval/exceptions.py +21 -1
- euroeval/finetuning.py +34 -26
- euroeval/generation.py +76 -41
- euroeval/generation_utils.py +169 -34
- euroeval/languages.py +1020 -188
- euroeval/logging_utils.py +268 -0
- euroeval/metrics/__init__.py +6 -0
- euroeval/metrics/base.py +85 -0
- euroeval/metrics/huggingface.py +216 -0
- euroeval/metrics/llm_as_a_judge.py +260 -0
- euroeval/metrics/pipeline.py +289 -0
- euroeval/metrics/speed.py +48 -0
- euroeval/model_cache.py +40 -21
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +157 -22
- euroeval/prompt_templates/multiple_choice.py +159 -17
- euroeval/prompt_templates/named_entity_recognition.py +318 -21
- euroeval/prompt_templates/reading_comprehension.py +207 -16
- euroeval/prompt_templates/sentiment_classification.py +205 -22
- euroeval/prompt_templates/summarization.py +122 -22
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +20 -9
- euroeval/speed_benchmark.py +11 -12
- euroeval/task_group_utils/multiple_choice_classification.py +21 -12
- euroeval/task_group_utils/question_answering.py +101 -73
- euroeval/task_group_utils/sequence_classification.py +144 -61
- euroeval/task_group_utils/text_to_text.py +33 -12
- euroeval/task_group_utils/token_classification.py +86 -89
- euroeval/tasks.py +75 -16
- euroeval/tokenisation_utils.py +603 -0
- euroeval/types.py +17 -11
- euroeval/utils.py +332 -137
- euroeval-16.7.1.dist-info/METADATA +623 -0
- euroeval-16.7.1.dist-info/RECORD +84 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/entry_points.txt +0 -1
- euroeval/human_evaluation.py +0 -737
- euroeval/metrics.py +0 -452
- euroeval/tokenization_utils.py +0 -498
- euroeval-15.12.0.dist-info/METADATA +0 -285
- euroeval-15.12.0.dist-info/RECORD +0 -63
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/WHEEL +0 -0
- {euroeval-15.12.0.dist-info → euroeval-16.7.1.dist-info}/licenses/LICENSE +0 -0
euroeval/finetuning.py
CHANGED
@@ -1,11 +1,12 @@
 """Functions related to the finetuning of models."""
 
+import collections.abc as c
 import logging
 import sys
 import typing as t
+from functools import partial
 
 import torch
-from tqdm.auto import tqdm
 from transformers.trainer_callback import (
     EarlyStoppingCallback,
     PrinterCallback,
@@ -17,13 +18,9 @@ from transformers.training_args import OptimizerNames, TrainingArguments
 from .callbacks import NeverLeaveProgressCallback
 from .enums import DataType
 from .exceptions import InvalidBenchmark, NaNValueInModelOutput
+from .logging_utils import block_terminal_output, get_pbar, log, log_once
 from .model_loading import load_model
-from .utils import (
-    block_terminal_output,
-    clear_memory,
-    enforce_reproducibility,
-    log_once,
-)
+from .utils import clear_memory, enforce_reproducibility
 
 if t.TYPE_CHECKING:
     from datasets import DatasetDict
@@ -31,16 +28,14 @@ if t.TYPE_CHECKING:
     from .benchmark_modules import BenchmarkModule
     from .data_models import BenchmarkConfig, DatasetConfig, ModelConfig
 
-logger = logging.getLogger("euroeval")
-
 
 def finetune(
     model: "BenchmarkModule",
-    datasets:
+    datasets: c.Sequence["DatasetDict"],
     model_config: "ModelConfig",
     dataset_config: "DatasetConfig",
     benchmark_config: "BenchmarkConfig",
-) ->
+) -> c.Sequence[dict[str, float]]:
     """Evaluate a model on a dataset through finetuning.
 
     Args:
@@ -57,6 +52,10 @@ def finetune(
 
     Returns:
         A list of dicts containing the scores for each metric for each iteration.
+
+    Raises:
+        InvalidBenchmark:
+            If the benchmark could not be completed.
     """
     # Set the data type to use for the model weights
     using_cuda = benchmark_config.device == torch.device("cuda")
@@ -67,9 +66,9 @@ def finetune(
     else:
         dtype = DataType.FP32
 
-    bs: int = benchmark_config.
+    bs: int = benchmark_config.finetuning_batch_size
     scores: list[dict[str, float]] = list()
-    for idx in
+    for idx in get_pbar(
         iterable=range(benchmark_config.num_iterations),
         desc="Benchmarking",
         disable=not benchmark_config.progress_bar,
@@ -79,7 +78,7 @@ def finetune(
         model_already_initialized = idx == 0
 
         # Run a loop here to deal with automatic reduction of batch size
-
+        for _ in range(num_attempts := 10):
             # Clear GPU memory
             if not model_already_initialized:
                 try:
@@ -111,30 +110,34 @@ def finetune(
                 )
 
                 scores.append(itr_scores)
-
+                log(
+                    f"Test scores for iteration {idx}: {itr_scores}",
+                    level=logging.DEBUG,
+                )
 
                 break
 
             # NaN values can appear in the model output when using mixed precision, as
             # the hidden states get overflowed. In this case we try to disable mixed
             # precision and try again.
-            except NaNValueInModelOutput:
+            except NaNValueInModelOutput as e:
                 if dtype != DataType.FP32:
                     dtype = DataType.FP32
                     model_already_initialized = False
-
+                    log(
                         "NaN value detected in model outputs while using mixed "
-                        "precision. Retrying with full fp32 precision."
+                        "precision. Retrying with full fp32 precision.",
+                        level=logging.DEBUG,
                     )
                 else:
                     raise InvalidBenchmark(
                         "NaN value detected in model outputs, even with mixed "
                         "precision disabled."
-                    )
+                    ) from e
 
             except Exception as e:
                 if "CUDA" not in str(e) and "out of memory" not in str(e):
-                    raise InvalidBenchmark(str(e))
+                    raise InvalidBenchmark(str(e)) from e
 
                 if bs <= 1:
                     msg = "Could not benchmark the model, even with a batch size of 1!"
@@ -145,12 +148,17 @@ def finetune(
                         "environment variable set, as this removes the upper bound "
                         "on the memory usage."
                     )
-                    raise InvalidBenchmark(msg)
+                    raise InvalidBenchmark(msg) from e
 
                 model_already_initialized = False
 
                 bs //= 2
-
+                log(f"Reduced batch size to {bs}", level=logging.DEBUG)
+
+        else:
+            raise InvalidBenchmark(
+                f"Could not benchmark the model after {num_attempts} attempts!"
+            )
 
     return scores
 
@@ -194,11 +202,11 @@ def finetune_single_iteration(
 
     trainer = model.trainer_class(
         model=model.get_pytorch_module(),
-        processing_class=model.
+        processing_class=model.get_tokeniser(),
         args=training_args,
         train_dataset=dataset["train"],
         eval_dataset=dataset["val"],
-        compute_metrics=model.compute_metrics,
+        compute_metrics=partial(model.compute_metrics, dataset=None),
         callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
         data_collator=model.data_collator,
         preprocess_logits_for_metrics=remove_extra_tensors_from_logits,
@@ -244,7 +252,7 @@ def finetune_single_iteration(
         clear_memory()
         raise e
     except (RuntimeError, ValueError, IndexError) as e:
-        raise InvalidBenchmark(str(e))
+        raise InvalidBenchmark(str(e)) from e
 
     return test_scores
 
@@ -283,7 +291,7 @@ def get_training_args(
         logging_strategy = IntervalStrategy.NO
 
     if batch_size is None:
-        batch_size = benchmark_config.
+        batch_size = benchmark_config.finetuning_batch_size
 
     training_args = TrainingArguments(
         output_dir=model_config.model_cache_dir,
euroeval/generation.py
CHANGED
@@ -1,15 +1,17 @@
 """Functions related to text generation of models."""
 
+import collections.abc as c
 import logging
 import sys
 import typing as t
 from pathlib import Path
 
-
+from datasets import Dataset
 from tqdm.auto import tqdm
 
 from .enums import BatchingPreference, TaskGroup
-from .exceptions import InvalidBenchmark
+from .exceptions import InvalidBenchmark, InvalidModel
+from .logging_utils import get_pbar, log, log_once
 from .model_cache import (
     ModelCache,
     load_cached_model_outputs,
@@ -18,7 +20,7 @@ from .model_cache import (
 from .utils import clear_memory
 
 if t.TYPE_CHECKING:
-    from datasets import
+    from datasets import DatasetDict
 
     from .benchmark_modules import BenchmarkModule
     from .data_models import (
@@ -28,16 +30,14 @@ if t.TYPE_CHECKING:
         ModelConfig,
     )
 
-logger = logging.getLogger("euroeval")
-
 
 def generate(
     model: "BenchmarkModule",
-    datasets:
+    datasets: c.Sequence["DatasetDict"],
     model_config: "ModelConfig",
     dataset_config: "DatasetConfig",
     benchmark_config: "BenchmarkConfig",
-) ->
+) -> c.Sequence[dict[str, float]]:
     """Evaluate a model on a dataset through generation.
 
     Args:
@@ -74,11 +74,12 @@ def generate(
         model_cache_dir=model_cache_dir,
         cache_name=cache_name,
         max_generated_tokens=dataset_config.max_generated_tokens,
+        progress_bar=benchmark_config.progress_bar,
     )
 
     scores: list[dict[str, float]] = list()
-    for idx in
-        iterable=range(
+    for idx in get_pbar(
+        iterable=range(len(datasets)),
         desc="Benchmarking",
         disable=not benchmark_config.progress_bar,
     ):
@@ -89,8 +90,7 @@ def generate(
             dataset_config=dataset_config,
             benchmark_config=benchmark_config,
         )
-
-        logger.debug(f"Test scores for iteration {idx}: {test_scores}")
+        log(f"Test scores for iteration {idx}: {test_scores}", level=logging.DEBUG)
         scores.append(test_scores)
         clear_memory()
 
@@ -126,10 +126,15 @@ def generate_single_iteration(
     """
     cache.load()
 
-    # Split up the dataset into a cached and non-cached part
-
-
-
+    # Split up the dataset into a cached and non-cached part, unless we are not
+    # bootstrapping the samples. In that case, we just use the dataset as is.
+    if dataset_config.bootstrap_samples:
+        cached_dataset, non_cached_dataset = split_dataset_into_cached_and_non_cached(
+            dataset=dataset, cache=cache
+        )
+    else:
+        cached_dataset = Dataset.from_dict({})
+        non_cached_dataset = dataset
 
     all_preds: list[str] = list()
 
@@ -137,19 +142,31 @@ def generate_single_iteration(
     itr: t.Iterable
     match model.batching_preference:
         case BatchingPreference.SINGLE_SAMPLE:
-            itr =
+            itr = get_pbar(
+                iterable=non_cached_dataset,
+                disable=not benchmark_config.progress_bar,
+            )
         case BatchingPreference.ALL_AT_ONCE:
            itr = [non_cached_dataset[:]]
         case _:
-
-
-
-            itr = tqdm(
-                iterable=mit.batched(
-                    iterable=non_cached_dataset, n=benchmark_config.batch_size
-                ),
-                total=len(non_cached_dataset) // benchmark_config.batch_size,
+            raise InvalidModel(
+                f"The batching preference {model.batching_preference!r} is "
+                "currently not supported."
             )
+            # NOTE: The code below can be used if we want to support batching for
+            # generative models. But in that case, we have to deal with the naming
+            # of the batch size variable, since it is currently
+            # `finetuning_batch_size`, as it is only used during finetuning of
+            # encoder models.
+            # num_batches = len(non_cached_dataset) // benchmark_config.batch_size
+            # if len(non_cached_dataset) % benchmark_config.batch_size != 0:
+            #     num_batches += 1
+            # itr = get_pbar(
+            #     iterable=mit.batched(
+            #         iterable=non_cached_dataset, n=benchmark_config.batch_size
+            #     ),
+            #     total=len(non_cached_dataset) // benchmark_config.batch_size,
+            # )
 
     # Generate the completions for the non-cached examples
     for batch in itr:
@@ -230,12 +247,17 @@ def generate_single_iteration(
         cached_labels = list(cached_labels)
         ground_truth = non_cached_labels + cached_labels
     else:
-
-        "
+        log_once(
+            "No labels found in the dataset. We assume that this is intentional, and "
+            "will not supply any ground truth labels for evaluation.",
+            level=logging.DEBUG,
         )
+        ground_truth = []
 
     itr_scores: dict[str, float] = model.compute_metrics(
-        model_outputs_and_labels=(all_preds, ground_truth)
+        model_outputs_and_labels=(all_preds, ground_truth),
+        dataset=dataset,
+        benchmark_config=benchmark_config,
     )
 
     return itr_scores
@@ -244,7 +266,7 @@ def generate_single_iteration(
 def debug_log(
     batch: dict[str, t.Any],
     model_output: "GenerativeModelOutput",
-    extracted_labels:
+    extracted_labels: c.Sequence[dict | str | c.Sequence[str]],
     dataset_config: "DatasetConfig",
 ) -> None:
     """Log inputs and outputs for debugging purposes.
@@ -287,16 +309,19 @@ def debug_log(
                 + "\n"
                 + "\t".join(labels)
             )
-
+            log("\n\n".join(log_msgs), level=logging.DEBUG)
             return
 
         case (
             TaskGroup.SEQUENCE_CLASSIFICATION | TaskGroup.MULTIPLE_CHOICE_CLASSIFICATION
         ):
-
-
-
-
+            if "label" in batch:
+                labels = [
+                    dataset_config.prompt_label_mapping.get(label, label).lower()
+                    for label in batch["label"]
+                ]
+            else:
+                labels = [None] * len(extracted_labels)
 
         case TaskGroup.QUESTION_ANSWERING:
             extracted_labels = [
@@ -319,12 +344,22 @@ def debug_log(
     else:
         input_texts = batch["text"]
 
-
-
-
-
-
-
-
-
+    metadata_keys: c.Sequence[str] = [
+        key
+        for key in batch.keys()
+        if key not in ["text", "messages", "label", "labels", "target_text"]
+    ]
+
+    for idx in range(len(input_texts)):
+        data_to_log: dict[str, t.Any] = {
+            "Input": input_texts[idx],
+            "Raw output": model_output.sequences[idx],
+            "Prediction": extracted_labels[idx],
+        }
+        if labels[idx]:
+            data_to_log["Label"] = labels[idx]
+        data_to_log |= {key.capitalize(): batch[key][idx] for key in metadata_keys}
+        log(
+            "\n".join(f"{key}: {value!r}" for key, value in data_to_log.items()),
+            level=logging.DEBUG,
        )
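In `generate_single_iteration()`, the split of the evaluation set into already-cached and not-yet-cached samples is now skipped when `dataset_config.bootstrap_samples` is disabled: the cached half becomes an empty `Dataset.from_dict({})` and the entire dataset is sent to the model. The snippet below only illustrates that idea with a simplified dict-based cache and hypothetical names (`split_by_cache`, the `"text"` column); it is not the package's `split_dataset_into_cached_and_non_cached` implementation.

```python
# Hypothetical sketch of the cached/non-cached split, using the Hugging Face
# `datasets` API. The cache is modelled as a plain dict mapping prompt text to a
# previously generated output; `bootstrap_samples` controls whether to split at all.
from datasets import Dataset


def split_by_cache(
    dataset: Dataset, cache: dict[str, str], bootstrap_samples: bool
) -> tuple[Dataset, Dataset]:
    """Return (cached, non_cached) rows, keyed on the "text" column."""
    if not bootstrap_samples:
        # Mirrors the new behaviour: skip the split, treat everything as
        # non-cached, and query the model for every sample.
        return Dataset.from_dict({}), dataset
    cached = dataset.filter(lambda row: row["text"] in cache)
    non_cached = dataset.filter(lambda row: row["text"] not in cache)
    return cached, non_cached


dataset = Dataset.from_dict({"text": ["first prompt", "second prompt"]})
cache = {"second prompt": "previously generated answer"}
cached, non_cached = split_by_cache(dataset, cache, bootstrap_samples=True)
print(len(cached), len(non_cached))  # 1 1
```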