EuroEval 16.3.0__py3-none-any.whl → 16.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of EuroEval might be problematic.
- euroeval/__init__.py +9 -2
- euroeval/benchmark_config_factory.py +51 -50
- euroeval/benchmark_modules/base.py +9 -21
- euroeval/benchmark_modules/fresh.py +2 -1
- euroeval/benchmark_modules/hf.py +101 -71
- euroeval/benchmark_modules/litellm.py +115 -53
- euroeval/benchmark_modules/vllm.py +107 -92
- euroeval/benchmarker.py +144 -121
- euroeval/caching_utils.py +79 -0
- euroeval/callbacks.py +5 -7
- euroeval/cli.py +86 -8
- euroeval/constants.py +9 -0
- euroeval/data_loading.py +80 -29
- euroeval/data_models.py +338 -330
- euroeval/dataset_configs/__init__.py +12 -3
- euroeval/dataset_configs/bulgarian.py +56 -0
- euroeval/dataset_configs/czech.py +75 -0
- euroeval/dataset_configs/danish.py +55 -93
- euroeval/dataset_configs/dutch.py +48 -87
- euroeval/dataset_configs/english.py +45 -77
- euroeval/dataset_configs/estonian.py +42 -34
- euroeval/dataset_configs/faroese.py +19 -60
- euroeval/dataset_configs/finnish.py +36 -69
- euroeval/dataset_configs/french.py +39 -75
- euroeval/dataset_configs/german.py +45 -82
- euroeval/dataset_configs/greek.py +64 -0
- euroeval/dataset_configs/icelandic.py +54 -91
- euroeval/dataset_configs/italian.py +42 -79
- euroeval/dataset_configs/latvian.py +28 -35
- euroeval/dataset_configs/lithuanian.py +28 -26
- euroeval/dataset_configs/norwegian.py +72 -115
- euroeval/dataset_configs/polish.py +33 -61
- euroeval/dataset_configs/portuguese.py +33 -66
- euroeval/dataset_configs/serbian.py +64 -0
- euroeval/dataset_configs/slovak.py +55 -0
- euroeval/dataset_configs/spanish.py +42 -77
- euroeval/dataset_configs/swedish.py +52 -90
- euroeval/dataset_configs/ukrainian.py +64 -0
- euroeval/exceptions.py +1 -1
- euroeval/finetuning.py +24 -17
- euroeval/generation.py +15 -14
- euroeval/generation_utils.py +8 -8
- euroeval/languages.py +395 -323
- euroeval/logging_utils.py +250 -0
- euroeval/metrics/base.py +0 -3
- euroeval/metrics/huggingface.py +21 -6
- euroeval/metrics/llm_as_a_judge.py +6 -4
- euroeval/metrics/pipeline.py +17 -9
- euroeval/metrics/speed.py +0 -3
- euroeval/model_cache.py +17 -19
- euroeval/model_config.py +4 -5
- euroeval/model_loading.py +3 -0
- euroeval/prompt_templates/__init__.py +2 -0
- euroeval/prompt_templates/classification.py +206 -0
- euroeval/prompt_templates/linguistic_acceptability.py +99 -42
- euroeval/prompt_templates/multiple_choice.py +102 -38
- euroeval/prompt_templates/named_entity_recognition.py +172 -51
- euroeval/prompt_templates/reading_comprehension.py +119 -42
- euroeval/prompt_templates/sentiment_classification.py +110 -40
- euroeval/prompt_templates/summarization.py +85 -40
- euroeval/prompt_templates/token_classification.py +279 -0
- euroeval/scores.py +11 -10
- euroeval/speed_benchmark.py +5 -6
- euroeval/task_group_utils/multiple_choice_classification.py +2 -4
- euroeval/task_group_utils/question_answering.py +24 -16
- euroeval/task_group_utils/sequence_classification.py +48 -35
- euroeval/task_group_utils/text_to_text.py +19 -9
- euroeval/task_group_utils/token_classification.py +21 -17
- euroeval/tasks.py +44 -1
- euroeval/tokenisation_utils.py +33 -22
- euroeval/types.py +10 -9
- euroeval/utils.py +35 -149
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/METADATA +196 -39
- euroeval-16.5.0.dist-info/RECORD +81 -0
- euroeval-16.3.0.dist-info/RECORD +0 -71
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/WHEEL +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/entry_points.txt +0 -0
- {euroeval-16.3.0.dist-info → euroeval-16.5.0.dist-info}/licenses/LICENSE +0 -0
euroeval/tokenisation_utils.py
CHANGED

@@ -1,15 +1,16 @@
 """Utility functions related to tokenisation."""
 
+import collections.abc as c
 import logging
 import re
 import typing as t
 
 import torch
-from transformers import MistralCommonTokenizer
+from transformers.tokenization_mistral_common import MistralCommonTokenizer
 
 from .enums import GenerativeType
 from .exceptions import InvalidModel
-from .utils import log_once
+from .logging_utils import log, log_once
 
 if t.TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
@@ -18,9 +19,6 @@ if t.TYPE_CHECKING:
     from .data_models import DatasetConfig, ModelConfig
 
 
-logger = logging.getLogger("euroeval")
-
-
 def get_special_token_metadata(tokeniser: "PreTrainedTokenizerBase") -> dict:
     """Get the special token metadata for a tokeniser.
 
@@ -74,7 +72,7 @@ def get_special_token_metadata(tokeniser: "PreTrainedTokenizerBase") -> dict:
 
 
 def should_prompts_be_stripped(
-    labels_to_be_generated: list[str], tokeniser: "PreTrainedTokenizer"
+    labels_to_be_generated: c.Sequence[str], tokeniser: "PreTrainedTokenizer"
 ) -> bool:
     """Determine if we should strip the prompts for few-shot evaluation.
 

@@ -113,7 +111,7 @@ def should_prompts_be_stripped(
 
 
 def should_prefix_space_be_added_to_labels(
-    labels_to_be_generated: list[str], tokeniser: "PreTrainedTokenizer"
+    labels_to_be_generated: c.Sequence[str], tokeniser: "PreTrainedTokenizer"
 ) -> bool:
     """Determine if we should add a prefix space to the labels.
 
@@ -182,7 +180,7 @@ def get_bos_token(
             "The model does not have a beginning-of-sequence token. Please ensure that "
             "this has been set in the tokeniser's configuration. Using no BOS token."
             " This may lead to unexpected behavior in the model.",
-            level=logging.
+            level=logging.WARNING,
         )
         return None, None
 

@@ -223,14 +221,14 @@ def get_eos_token(
             "The model does not have an end-of-sequence token. Please ensure that this "
             "has been set in the tokeniser's configuration. Using no EOS token. This "
             "may lead to unexpected behavior in the model.",
-            level=logging.
+            level=logging.WARNING,
         )
         return None, None
 
     log_once(
         f"End-of-sequence token was not set, but detected it as {eos_token!r} with "
         f"ID {eos_token_id}.",
-        level=logging.
+        level=logging.WARNING,
     )
     return eos_token, eos_token_id
 
@@ -306,7 +304,7 @@ def get_pad_token(
             "Could not identify a padding token for the model. Please ensure that "
             "this has been set in the tokeniser's configuration. Using no padding "
             "token. This may lead to unexpected behavior in the model.",
-            level=logging.
+            level=logging.WARNING,
         )
         return None, None
 

@@ -320,7 +318,7 @@ def get_pad_token(
 
 def get_end_of_chat_token_ids(
     tokeniser: "PreTrainedTokenizer", generative_type: GenerativeType | None
-) -> list[int] | None:
+) -> c.Sequence[int] | None:
     """Get the end token ID for chat models.
 
     This is only relevant for tokenisers with a chat template.
@@ -358,12 +356,16 @@ def get_end_of_chat_token_ids(
             x_token_index = idx
             break
     else:
-        logger.debug("Could not locate the end-of-chat token for the model.")
+        log(
+            "Could not locate the end-of-chat token for the model.", level=logging.DEBUG
+        )
         return None
 
     end_of_chat_tokens = token_ids[x_token_index + 1 :]
     if len(end_of_chat_tokens) == 0:
-        logger.debug("Could not locate the end-of-chat token for the model.")
+        log(
+            "Could not locate the end-of-chat token for the model.", level=logging.DEBUG
+        )
         return None
 
     log_once(
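For context, the detection trick that `get_end_of_chat_token_ids` implements can be reproduced standalone: tokenise a one-turn conversation whose only content is "x", find the token that decodes to "x", and treat everything after it as the end-of-chat suffix. A rough sketch, with an illustrative model name not taken from the diff:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
token_ids = tok.apply_chat_template(conversation=[{"role": "user", "content": "x"}])
# Locate the token that decodes to "x"; the remaining ids close the chat turn
x_index = next(
    idx for idx, token_id in enumerate(token_ids) if tok.decode(token_id).strip() == "x"
)
end_of_chat_token_ids = token_ids[x_index + 1 :]
print(tok.decode(end_of_chat_token_ids))  # e.g. "<|im_end|>\n" for ChatML-style models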
@@ -432,13 +434,19 @@ def get_first_label_token_mapping(
 
     # Tokenise some text containing each label, which we will use to extract the
     # first token of each label
-    all_tokens: list[list[str]]
+    all_tokens: c.Sequence[c.Sequence[str]]
     if not has_chat_template(tokeniser=tokeniser):
         add_prefix_space = should_prefix_space_be_added_to_labels(
             labels_to_be_generated=local_labels, tokeniser=tokeniser
         )
         all_tokens = [
-
+            [
+                tokeniser.decode(token_id)
+                for token_id in tokeniser.encode(
+                    text=f" {label}" if add_prefix_space else label,
+                    add_special_tokens=False,
+                )
+            ]
             for label in local_labels
         ]
     else:
@@ -465,7 +473,7 @@ def get_first_label_token_mapping(
         all_tokens = [
             [
                 re.sub(
-                    pattern=r"^[^a-zæøåüöä0-9]+|[^a-zæøåüöä0-9]+$",
+                    pattern=r"^[^a-zæøåüöä0-9 ]+|[^a-zæøåüöä0-9 ]+$",
                     repl="",
                     string=token.lower(),
                 )

@@ -477,11 +485,13 @@ def get_first_label_token_mapping(
     # Extract the first token of each label
     first_tokens: list[str] = list()
     for token_list, label in zip(all_tokens, local_labels):
-        matching_tokens = [
+        matching_tokens = [
+            tok for tok in token_list if tok and label.startswith(tok.strip())
+        ]
         if not matching_tokens:
             if log_metadata:
                 log_once(
-                    f"No matching token found in token_list for label
+                    f"No matching token found in token_list for label {label!r}, so "
                     "we will not use logprobs with the model.",
                     level=logging.DEBUG,
                 )

@@ -506,7 +516,8 @@ def get_first_label_token_mapping(
         log_once(
             "We will not use logprobs with the model since the first tokens of the "
             "labels are not distinct. The first tokens for the labels "
-            f"{local_labels} are {first_tokens}"
+            f"{local_labels} are {first_tokens}",
+            level=logging.DEBUG,
         )
         return False
 
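The new encode/decode round trip above can be exercised directly; a minimal sketch of the resulting first-token check, with illustrative labels and model name:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
labels = ["positive", "negative", "neutral"]
all_tokens = [
    [
        tok.decode(token_id)
        for token_id in tok.encode(f" {label}", add_special_tokens=False)
    ]
    for label in labels
]
first_tokens = [
    next(tok_str for tok_str in token_list if label.startswith(tok_str.strip()))
    for token_list, label in zip(all_tokens, labels)
]
# Logprob-based scoring is only safe when the first tokens are distinct
assert len(set(first_tokens)) == len(first_tokens)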
@@ -547,12 +558,12 @@ def has_chat_template(tokeniser: "PreTrainedTokenizer") -> bool:
 
 
 def apply_chat_template(
-    conversation: list[dict[str, str]],
+    conversation: c.Sequence[dict[str, str]],
     tokeniser: "PreTrainedTokenizer",
     tokenise: bool,
     add_generation_prompt: bool,
     **extra_kwargs,
-) -> str | list[int]:
+) -> str | c.Sequence[int]:
     """Apply the chat template to a prompt.
 
     Args:
euroeval/types.py
CHANGED

@@ -1,5 +1,6 @@
 """Types used throughout the project."""
 
+import collections.abc as c
 import typing as t
 
 from transformers.trainer_utils import EvalPrediction

@@ -10,9 +11,9 @@ if t.TYPE_CHECKING:
 
     from .data_models import BenchmarkConfig, GenerativeModelOutput
 
-ScoreDict: t.TypeAlias = dict[str, dict[str, float] | list[dict[str, float]]]
-Predictions: t.TypeAlias = "NDArray | list[str] | list[list[str]]"
-Labels: t.TypeAlias = "NDArray | list[str] | list[list[str]]"
+ScoreDict: t.TypeAlias = dict[str, dict[str, float] | c.Sequence[dict[str, float]]]
+Predictions: t.TypeAlias = "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]"
+Labels: t.TypeAlias = "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]"
 
 
 class ComputeMetricsFunction(t.Protocol):

@@ -22,8 +23,8 @@ class ComputeMetricsFunction(t.Protocol):
         self,
         model_outputs_and_labels: EvalPrediction
         | tuple[
-            "NDArray | list[str] | list[list[str]]",
-            "NDArray | list[str] | list[list[str]]",
+            "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]",
+            "NDArray | c.Sequence[str] | c.Sequence[c.Sequence[str]]",
         ],
         dataset: "Dataset",
         benchmark_config: "BenchmarkConfig",

@@ -48,7 +49,7 @@ class ExtractLabelsFunction(t.Protocol):
 
     def __call__(
         self, input_batch: dict[str, list], model_output: "GenerativeModelOutput"
-    ) -> list[str]:
+    ) -> c.Sequence[str]:
         """Extract the labels from the generated output.
 
         Args:

@@ -63,7 +64,7 @@ class ExtractLabelsFunction(t.Protocol):
         ...
 
 
-def is_list_of_int(x: object) -> t.TypeGuard[list[int]]:
+def is_list_of_int(x: object) -> t.TypeGuard[c.Sequence[int]]:
     """Check if an object is a list of integers.
 
     Args:

@@ -76,7 +77,7 @@ def is_list_of_int(x: object) -> t.TypeGuard[list[int]]:
     return isinstance(x, list) and all(isinstance(i, int) for i in x)
 
 
-def is_list_of_list_of_int(x: object) -> t.TypeGuard[list[list[int]]]:
+def is_list_of_list_of_int(x: object) -> t.TypeGuard[c.Sequence[c.Sequence[int]]]:
     """Check if an object is a list of list of integers.
 
     Args:

@@ -93,7 +94,7 @@ def is_list_of_list_of_int(x: object) -> t.TypeGuard[list[list[int]]]:
     )
 
 
-def is_list_of_str(x: object) -> t.TypeGuard[list[str]]:
+def is_list_of_str(x: object) -> t.TypeGuard[c.Sequence[str]]:
     """Check if an object is a list of integers.
 
     Args:
euroeval/utils.py
CHANGED

@@ -1,6 +1,7 @@
 """Utility functions to be used in other scripts."""
 
 import asyncio
+import collections.abc as c
 import gc
 import importlib
 import importlib.metadata

@@ -11,30 +12,23 @@ import re
 import socket
 import sys
 import typing as t
-import warnings
-from functools import cache
 from pathlib import Path
 
 import demjson3
 import huggingface_hub as hf_hub
-import litellm
 import numpy as np
 import torch
-from datasets.utils import disable_progress_bar
-from transformers import logging as tf_logging
 
+from .caching_utils import cache_arguments
+from .constants import T
 from .exceptions import InvalidBenchmark, InvalidModel, NaNValueInModelOutput
+from .logging_utils import log, log_once
 
 if t.TYPE_CHECKING:
-    from types import TracebackType
-
     from .data_models import ModelIdComponents
     from .types import Predictions
 
 
-logger = logging.getLogger("euroeval")
-
-
 def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
     """Create cache directory for a model.
 
@@ -149,69 +143,9 @@ def enforce_reproducibility(seed: int = 4242) -> np.random.Generator:
     return rng
 
 
-def block_terminal_output() -> None:
-    """Block terminal output.
-
-    This filters warnings from some libraries, sets the logging level to ERROR for some
-    libraries, disabled tokeniser progress bars when using Hugging Face tokenisers, and
-    disables most of the logging from the `transformers` library.
-    """
-    if os.getenv("FULL_LOG") == "1":
-        return
-
-    # Ignore miscellaneous warnings
-    warnings.filterwarnings("ignore", category=UserWarning)
-    warnings.filterwarnings("ignore", category=FutureWarning)
-    logging.getLogger("absl").setLevel(logging.CRITICAL)
-
-    # Disable matplotlib logging
-    logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)
-
-    # Disable PyTorch logging
-    logging.getLogger("torch.utils.cpp_extension").setLevel(logging.CRITICAL)
-    warnings.filterwarnings(action="ignore", module="torch*")
-    os.environ["TORCH_LOGS"] = "-all"
-
-    # Disable huggingface_hub logging
-    logging.getLogger("huggingface_hub").setLevel(logging.CRITICAL)
-
-    # Disable LiteLLM logging
-    logging.getLogger("LiteLLM").setLevel(logging.CRITICAL)
-    logging.getLogger("LiteLLM Router").setLevel(logging.CRITICAL)
-    logging.getLogger("LiteLLM Proxy").setLevel(logging.CRITICAL)
-    logging.getLogger("openai").setLevel(logging.CRITICAL)
-    logging.getLogger("httpx").setLevel(logging.CRITICAL)
-    litellm.suppress_debug_info = True
-
-    # Disable vLLM logging
-    logging.getLogger("vllm").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.engine.llm_engine").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.transformers_utils.tokenizer").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.core.scheduler").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.model_executor.weight_utils").setLevel(logging.CRITICAL)
-    logging.getLogger("vllm.platforms").setLevel(logging.CRITICAL)
-    logging.getLogger("mistral_common.tokens.tokenizers.tekken").setLevel(
-        logging.CRITICAL
-    )
-    os.environ["LOG_LEVEL"] = "CRITICAL"
-    os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
-
-    # Disable datasets logging
-    logging.getLogger("datasets").setLevel(logging.CRITICAL)
-    logging.getLogger("filelock").setLevel(logging.CRITICAL)
-    disable_progress_bar()
-
-    # Disable evaluate logging
-    warnings.filterwarnings("ignore", module="seqeval*")
-
-    # Disable most of the `transformers` logging
-    tf_logging._default_log_level = logging.CRITICAL
-    tf_logging.set_verbosity(logging.CRITICAL)
-    logging.getLogger("transformers.trainer").setLevel(logging.CRITICAL)
-    logging.getLogger("accelerate").setLevel(logging.CRITICAL)
-
-
-def get_class_by_name(class_name: str | list[str], module_name: str) -> t.Type | None:
+def get_class_by_name(
+    class_name: str | c.Sequence[str], module_name: str
+) -> t.Type | None:
     """Get a class by its name.
 
     Args:

@@ -240,9 +174,10 @@ def get_class_by_name(
 
     if error_messages:
         errors = "\n- " + "\n- ".join(error_messages)
-        logger.debug(
+        log(
             f"Could not find the class with the name(s) {', '.join(class_name)}. The "
-            f"following error messages were raised: {errors}"
+            f"following error messages were raised: {errors}",
+            level=logging.DEBUG,
         )
 
     # If the class could not be found, return None
@@ -264,49 +199,27 @@ def get_min_cuda_compute_capability() -> float | None:
     return float(f"{major}.{minor}")
 
 
-@cache
+@cache_arguments(disable_condition=lambda: hasattr(sys, "_called_from_test"))
 def internet_connection_available() -> bool:
     """Checks if internet connection is available by pinging google.com.
 
     Returns:
         Whether or not internet connection is available.
     """
+    internet_available: bool = False
+
     try:
         s = socket.create_connection(("1.1.1.1", 80))
         s.close()
-
-
-
-    # import these here as they're developer dependencies, we check the exception name
-    # instead. If the exception is not related to socket connections, we reraise it.
+        internet_available = True
+    except OSError:
+        pass
     except Exception as e:
         pytest_socket_errors = ["SocketConnectBlockedError", "SocketBlockedError"]
-        if type(e).__name__ in pytest_socket_errors
-
-
-
-
-class HiddenPrints:
-    """Context manager which removes all terminal output."""
-
-    def __enter__(self) -> None:
-        """Enter the context manager."""
-        self._original_stdout = sys.stdout
-        self._original_stderr = sys.stderr
-        sys.stdout = open(os.devnull, "w")
-        sys.stderr = open(os.devnull, "w")
-
-    def __exit__(
-        self,
-        exc_type: t.Type[BaseException],
-        exc_val: BaseException,
-        exc_tb: "TracebackType",
-    ) -> None:
-        """Exit the context manager."""
-        sys.stdout.close()
-        sys.stderr.close()
-        sys.stdout = self._original_stdout
-        sys.stderr = self._original_stderr
+        if type(e).__name__ not in pytest_socket_errors:
+            raise e
+
+    return internet_available
 
 
 def raise_if_model_output_contains_nan_values(model_output: "Predictions") -> None:
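`cache_arguments` is imported from the new `euroeval/caching_utils.py`, whose body is not shown in this diff. A hypothetical sketch matching the call sites above, namely a `functools.cache`-style decorator with a call-time opt-out. The decorator name and `disable_condition` parameter come from the diff; everything else is a guess:

import functools
import typing as t

C = t.TypeVar("C", bound=t.Callable)


def cache_arguments(
    disable_condition: t.Callable[[], bool] = lambda: False,
) -> t.Callable[[C], C]:
    """Cache a function's results, unless `disable_condition()` is truthy."""

    def decorator(func: C) -> C:
        cached_func = functools.cache(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Bypass the cache entirely when caching is disabled, e.g. in tests
            if disable_condition():
                return func(*args, **kwargs)
            return cached_func(*args, **kwargs)

        return t.cast(C, wrapper)

    return decorator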
@@ -364,34 +277,6 @@ def unscramble(scrambled_text: str) -> str:
     return unscrambled
 
 
-@cache
-def log_once(message: str, level: int = logging.INFO) -> None:
-    """Log a message once.
-
-    This is ensured by caching the input/output pairs of this function, using the
-    `functools.cache` decorator.
-
-    Args:
-        message:
-            The message to log.
-        level:
-            The logging level. Defaults to logging.INFO.
-    """
-    match level:
-        case logging.DEBUG:
-            logger.debug(message)
-        case logging.INFO:
-            logger.info(message)
-        case logging.WARNING:
-            logger.warning(message)
-        case logging.ERROR:
-            logger.error(message)
-        case logging.CRITICAL:
-            logger.critical(message)
-        case _:
-            raise ValueError(f"Invalid logging level: {level}")
-
-
 def get_package_version(package_name: str) -> str | None:
     """Get the version of a package.
 
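`log` and `log_once` now live in the new `euroeval/logging_utils.py`, which this diff does not show. A minimal sketch consistent with the removed implementation above and with the new call sites; the real module (250 lines) certainly does more:

import functools
import logging

logger = logging.getLogger("euroeval")


def log(message: str, level: int = logging.INFO) -> None:
    """Log a message at the given level."""
    logger.log(level, message)


@functools.cache
def log_once(message: str, level: int = logging.INFO) -> None:
    """Log a message only the first time this exact message/level pair is seen."""
    log(message=message, level=level)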
@@ -408,9 +293,6 @@ def get_package_version(package_name: str) -> str | None:
     return None
 
 
-T = t.TypeVar("T", bound=object)
-
-
 def safe_run(coroutine: t.Coroutine[t.Any, t.Any, T]) -> T:
     """Run a coroutine, ensuring that the event loop is always closed when we're done.
 

@@ -464,37 +346,41 @@ def extract_json_dict_from_string(s: str) -> dict | None:
     """
     json_regex = r"\{[^{}]*?\}"
     if (json_match := re.search(pattern=json_regex, string=s, flags=re.DOTALL)) is None:
-        logger.debug(
+        log(
             "The model output does not contain any JSON dictionary, so cannot parse "
-            f"it. Skipping. Here is the output: {s!r}"
+            f"it. Skipping. Here is the output: {s!r}",
+            level=logging.DEBUG,
         )
         return None
     json_string = json_match.group()
     try:
         json_output = demjson3.decode(txt=json_string)
     except demjson3.JSONDecodeError:
-        logger.debug(
+        log(
             "The model output is not valid JSON, so cannot parse it. Skipping. "
-            f"Here is the output: {json_string!r}"
+            f"Here is the output: {json_string!r}",
+            level=logging.DEBUG,
         )
         return None
     if not isinstance(json_output, dict):
-        logger.debug(
+        log(
             "The model output is not a JSON dictionary, so cannot parse "
-            f"it. Skipping. Here is the output: {json_string!r}"
+            f"it. Skipping. Here is the output: {json_string!r}",
+            level=logging.DEBUG,
         )
         return None
     elif not all(isinstance(key, str) for key in json_output.keys()):
-        logger.debug(
+        log(
             "The model output is not a JSON dictionary with string keys, "
             "so cannot parse it. Skipping. Here is the output: "
-            f"{json_string!r}"
+            f"{json_string!r}",
+            level=logging.DEBUG,
         )
         return None
     return json_output
 
 
-@cache
+@cache_arguments()
 def get_hf_token(api_key: str | None) -> str | bool:
     """Get the Hugging Face token.
 
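A standalone illustration of the extraction logic above: grab the first brace-delimited span and parse it leniently with `demjson3`, which accepts JSON dialect quirks such as single quotes and trailing commas in its default non-strict mode (the example string is made up):

import re

import demjson3

output = "Sure! Here is the answer: {'label': 'positive',}"
# Same regex as in extract_json_dict_from_string: first non-nested {...} span
match = re.search(r"\{[^{}]*?\}", output, flags=re.DOTALL)
if match is not None:
    print(demjson3.decode(txt=match.group()))  # {'label': 'positive'}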
@@ -538,8 +424,8 @@ def get_hf_token(api_key: str | None) -> str | bool:
 
 
 def extract_multiple_choice_labels(
-    prompt: str, candidate_labels: list[str]
-) -> list[str]:
+    prompt: str, candidate_labels: c.Sequence[str]
+) -> c.Sequence[str]:
     """Extract multiple choice labels from a prompt.
 
     Args: