ScandEval 16.12.0-py3-none-any.whl → 16.13.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scandeval/async_utils.py +46 -0
- scandeval/benchmark_config_factory.py +26 -2
- scandeval/benchmark_modules/fresh.py +2 -1
- scandeval/benchmark_modules/hf.py +50 -12
- scandeval/benchmark_modules/litellm.py +25 -15
- scandeval/benchmark_modules/vllm.py +3 -3
- scandeval/benchmarker.py +15 -33
- scandeval/cli.py +2 -4
- scandeval/constants.py +5 -0
- scandeval/custom_dataset_configs.py +152 -0
- scandeval/data_loading.py +87 -31
- scandeval/data_models.py +396 -225
- scandeval/dataset_configs/__init__.py +51 -25
- scandeval/dataset_configs/albanian.py +1 -1
- scandeval/dataset_configs/belarusian.py +47 -0
- scandeval/dataset_configs/bulgarian.py +1 -1
- scandeval/dataset_configs/catalan.py +1 -1
- scandeval/dataset_configs/croatian.py +1 -1
- scandeval/dataset_configs/danish.py +3 -2
- scandeval/dataset_configs/dutch.py +7 -6
- scandeval/dataset_configs/english.py +4 -3
- scandeval/dataset_configs/estonian.py +8 -7
- scandeval/dataset_configs/faroese.py +1 -1
- scandeval/dataset_configs/finnish.py +5 -4
- scandeval/dataset_configs/french.py +6 -5
- scandeval/dataset_configs/german.py +4 -3
- scandeval/dataset_configs/greek.py +1 -1
- scandeval/dataset_configs/hungarian.py +1 -1
- scandeval/dataset_configs/icelandic.py +4 -3
- scandeval/dataset_configs/italian.py +4 -3
- scandeval/dataset_configs/latvian.py +2 -2
- scandeval/dataset_configs/lithuanian.py +1 -1
- scandeval/dataset_configs/norwegian.py +6 -5
- scandeval/dataset_configs/polish.py +4 -3
- scandeval/dataset_configs/portuguese.py +5 -4
- scandeval/dataset_configs/romanian.py +2 -2
- scandeval/dataset_configs/serbian.py +1 -1
- scandeval/dataset_configs/slovene.py +1 -1
- scandeval/dataset_configs/spanish.py +4 -3
- scandeval/dataset_configs/swedish.py +4 -3
- scandeval/dataset_configs/ukrainian.py +1 -1
- scandeval/generation_utils.py +6 -6
- scandeval/metrics/llm_as_a_judge.py +1 -1
- scandeval/metrics/pipeline.py +1 -1
- scandeval/model_cache.py +34 -4
- scandeval/prompt_templates/linguistic_acceptability.py +9 -0
- scandeval/prompt_templates/multiple_choice.py +9 -0
- scandeval/prompt_templates/named_entity_recognition.py +21 -0
- scandeval/prompt_templates/reading_comprehension.py +10 -0
- scandeval/prompt_templates/sentiment_classification.py +11 -0
- scandeval/string_utils.py +157 -0
- scandeval/task_group_utils/sequence_classification.py +2 -5
- scandeval/task_group_utils/token_classification.py +2 -4
- scandeval/utils.py +6 -323
- scandeval-16.13.0.dist-info/METADATA +334 -0
- scandeval-16.13.0.dist-info/RECORD +94 -0
- scandeval-16.12.0.dist-info/METADATA +0 -667
- scandeval-16.12.0.dist-info/RECORD +0 -90
- {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/WHEEL +0 -0
- {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/entry_points.txt +0 -0
- {scandeval-16.12.0.dist-info → scandeval-16.13.0.dist-info}/licenses/LICENSE +0 -0
scandeval/dataset_configs/swedish.py
CHANGED

@@ -68,9 +68,10 @@ VALEU_SV_CONFIG = DatasetConfig(
     source="EuroEval/european-values-sv",
     task=EUROPEAN_VALUES,
     languages=[SWEDISH],
-    …
+    train_split=None,
+    val_split=None,
     bootstrap_samples=False,
-    …
+    instruction_prompt="{text}",
 )

@@ -127,7 +128,7 @@ WINOGRANDE_SV_CONFIG = DatasetConfig(
     source="EuroEval/winogrande-sv",
     task=COMMON_SENSE,
     languages=[SWEDISH],
-    …
+    labels=["a", "b"],
     unofficial=True,
 )
scandeval/generation_utils.py
CHANGED

@@ -13,8 +13,8 @@ from datasets import Dataset
 from .enums import GenerativeType, TaskGroup
 from .exceptions import InvalidBenchmark, InvalidModel
 from .logging_utils import log_once
+from .string_utils import extract_multiple_choice_labels
 from .tokenisation_utils import apply_chat_template
-from .utils import extract_multiple_choice_labels

 if t.TYPE_CHECKING:
     from datasets import DatasetDict

@@ -102,7 +102,7 @@ def extract_few_shot_examples(
         )
         label = next(labels)
         possible_examples = shuffled_train.filter(
-            lambda x: x["label"].lower() == label.lower()
+            lambda x: str(x["label"]).lower() == label.lower()
         )
         assert isinstance(possible_examples, Dataset), (
             f"Expected `possible_examples` to be a Dataset, but got "

@@ -142,7 +142,7 @@ def extract_few_shot_examples(
     while len(few_shot_examples) < num_few_shots and len(shuffled_train) > 0:
         label = next(labels)
         possible_examples = shuffled_train.filter(
-            lambda x: label in [tag.lower() for tag in x["labels"]]
+            lambda x: label in [str(tag).lower() for tag in x["labels"]]
         )
         assert isinstance(possible_examples, Dataset), (
             f"Expected `possible_examples` to be a Dataset, but got "

@@ -274,7 +274,7 @@ def apply_prompt(
     few_shot_sections = [
         create_prompt(
             text=example["text"].replace("\n", " ").strip(),
-            label=example["label"].replace("\n", " ").strip(),
+            label=str(example["label"]).replace("\n", " ").strip(),
             labels_str=labels_str,
         )
         for example in few_shot_examples

@@ -292,7 +292,7 @@ def apply_prompt(
     few_shot_sections = [
         create_prompt(
             text=example["text"].replace("\n", " ").strip(),
-            label=example["label"].replace("\n", " ").strip(),
+            label=str(example["label"]).replace("\n", " ").strip(),
             labels_str=dataset_config.get_labels_str(
                 labels=extract_multiple_choice_labels(
                     prompt=example["text"],

@@ -337,7 +337,7 @@ def apply_prompt(
             prompt_label: list() for prompt_label in prompt_labels
         }
         for token, label in zip(example["tokens"], example["labels"]):
-            label = label.lower()
+            label = str(label).lower()
            if label == "o":
                 continue
             prompt_label = dataset_config.prompt_label_mapping[label]
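Note: the `str(...)` casts added throughout `generation_utils.py` guard against label columns that are not strings (e.g. integer class labels in Hugging Face datasets), where calling `.lower()` or `.replace()` directly would raise. A minimal, self-contained illustration (not ScandEval code):

```python
# Integer labels have no .lower() method, so the bare call fails:
label = 1
try:
    label.lower()
except AttributeError as error:
    print(error)  # 'int' object has no attribute 'lower'

# Casting to str first handles string and non-string labels uniformly:
print(str(label).lower())       # "1"
print(str("Positive").lower())  # "positive"
```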
scandeval/metrics/llm_as_a_judge.py
CHANGED

@@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, ValidationError

 from ..exceptions import InvalidBenchmark
 from ..logging_utils import log
-from ..utils import extract_json_dict_from_string
+from ..string_utils import extract_json_dict_from_string
 from .base import Metric

 if t.TYPE_CHECKING:
scandeval/metrics/pipeline.py
CHANGED

@@ -12,7 +12,7 @@ from scipy.special import expit as sigmoid

 from ..exceptions import InvalidBenchmark
 from ..logging_utils import log, no_terminal_output
-from ..utils import unscramble
+from ..string_utils import unscramble
 from .base import Metric

 if t.TYPE_CHECKING:
scandeval/model_cache.py
CHANGED

@@ -5,9 +5,9 @@ import hashlib
 import json
 import logging
 import sys
-import typing as t
 from collections import defaultdict
 from dataclasses import asdict
+from pathlib import Path

 from datasets import Dataset

@@ -15,9 +15,6 @@ from .constants import NUM_GENERATION_TOKENS_FOR_CLASSIFICATION
 from .data_models import GenerativeModelOutput, SingleGenerativeModelOutput
 from .logging_utils import get_pbar, log, log_once

-if t.TYPE_CHECKING:
-    from pathlib import Path
-

 class ModelCache:
     """A cache for model outputs.

@@ -295,3 +292,36 @@ def load_cached_model_outputs(

     cached_scores = [model_output.scores or [] for model_output in cached_model_outputs]
     return GenerativeModelOutput(sequences=cached_sequences, scores=cached_scores)
+
+
+def create_model_cache_dir(cache_dir: str, model_id: str) -> str:
+    """Create cache directory for a model.
+
+    Args:
+        cache_dir:
+            The cache directory.
+        model_id:
+            The model ID.
+
+    Returns:
+        The path to the cache directory.
+    """
+    # If the model ID is a path, we just use that as the cache dir
+    if Path(model_id).is_dir():
+        log_once(
+            f"Since the model {model_id!r} is a local model, we will use the model "
+            "directory directly as the model cache directory.",
+            level=logging.DEBUG,
+        )
+        return model_id
+
+    # Otherwise, we create a cache dir based on the model ID
+    model_cache_dir = Path(
+        cache_dir, "model_cache", model_id.replace("/", "--")
+    ).as_posix()
+    log_once(
+        f"Using the model cache directory {model_cache_dir!r} for the model "
+        f"{model_id!r}.",
+        level=logging.DEBUG,
+    )
+    return model_cache_dir
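Note: the new `create_model_cache_dir` helper returns local model directories unchanged and maps hub-style model IDs onto filesystem-safe subdirectories by replacing `/` with `--`. A sketch of the resulting path mapping using only `pathlib` (`expected_cache_dir` and the model ID below are illustrative names, not ScandEval API):

```python
from pathlib import Path

def expected_cache_dir(cache_dir: str, model_id: str) -> str:
    # Mirrors the non-local branch of create_model_cache_dir:
    # <cache_dir>/model_cache/<model ID with "/" replaced by "--">
    return Path(cache_dir, "model_cache", model_id.replace("/", "--")).as_posix()

print(expected_cache_dir(".scandeval_cache", "mistralai/Mistral-7B-v0.1"))
# .scandeval_cache/model_cache/mistralai--Mistral-7B-v0.1
```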
scandeval/prompt_templates/linguistic_acceptability.py
CHANGED

@@ -5,6 +5,7 @@ import typing as t
 from ..data_models import PromptConfig
 from ..languages import (
     ALBANIAN,
+    BELARUSIAN,
     BULGARIAN,
     CATALAN,
     CROATIAN,

@@ -49,6 +50,14 @@ LA_TEMPLATES: dict["Language", PromptConfig] = {
         default_instruction_prompt="Fjali: {text}\n\nPërcaktoni nëse fjalia është "
         "gramatikisht e saktë apo jo. Përgjigjuni me {labels_str}, dhe asgjë tjetër.",
     ),
+    BELARUSIAN: PromptConfig(
+        default_prompt_label_mapping=dict(correct="так", incorrect="не"),
+        default_prompt_prefix="Ніжэй прыведзены сказы і ці з'яўляюцца яны "
+        "граматычна правільнымі.",
+        default_prompt_template="Сказ: {text}\nГраматычна правільны: {label}",
+        default_instruction_prompt="Сказ: {text}\n\nВызначце, ці сказ граматычна "
+        "правільны ці не. Адкажыце толькі {labels_str}, і нічога іншага.",
+    ),
     BULGARIAN: PromptConfig(
         default_prompt_label_mapping=dict(correct="да", incorrect="не"),
         default_prompt_prefix="Следват изречения и дали са граматически правилни.",
scandeval/prompt_templates/multiple_choice.py
CHANGED

@@ -5,6 +5,7 @@ import typing as t
 from ..data_models import PromptConfig
 from ..languages import (
     ALBANIAN,
+    BELARUSIAN,
     BULGARIAN,
     CATALAN,
     CROATIAN,

@@ -49,6 +50,14 @@ MULTIPLE_CHOICE_TEMPLATES: dict["Language", PromptConfig] = {
         "mësipërme duke u përgjigjur me {labels_str}, dhe asgjë tjetër.",
         default_prompt_label_mapping="auto",
     ),
+    BELARUSIAN: PromptConfig(
+        default_prompt_prefix="Ніжэй прыведзены пытанні з некалькімі варыянтамі "
+        "адказу (з адказамі).",
+        default_prompt_template="Пытанне: {text}\nАдказ: {label}",
+        default_instruction_prompt="Пытанне: {text}\n\nАдкажыце на пытанне вышэй, "
+        "адказаўшы {labels_str}, і нічога іншага.",
+        default_prompt_label_mapping="auto",
+    ),
     BULGARIAN: PromptConfig(
         default_prompt_prefix="Следват въпроси с множествен избор (с отговори).",
         default_prompt_template="Въпрос: {text}\nОтговор: {label}",
scandeval/prompt_templates/named_entity_recognition.py
CHANGED

@@ -5,6 +5,7 @@ import typing as t
 from ..data_models import PromptConfig
 from ..languages import (
     ALBANIAN,
+    BELARUSIAN,
     BOSNIAN,
     BULGARIAN,
     CATALAN,

@@ -62,6 +63,26 @@ NER_TEMPLATES: dict["Language", PromptConfig] = {
         "{labels_str}. Vlerat duhet të jenë lista të entiteteve të emërtuara të atij "
         "lloji, saktësisht ashtu siç shfaqen në fjali.",
     ),
+    BELARUSIAN: PromptConfig(
+        default_prompt_label_mapping={
+            "b-per": "асоба",
+            "i-per": "асоба",
+            "b-loc": "месца",
+            "i-loc": "месца",
+            "b-org": "арганізацыя",
+            "i-org": "арганізацыя",
+            "b-misc": "рознае",
+            "i-misc": "рознае",
+        },
+        default_prompt_prefix="Ніжэй прыведзены сказы і JSON-слоўнікі з іменаванымі "
+        "сутнасцямі, якія прысутнічаюць у дадзеным сказе.",
+        default_prompt_template="Сказ: {text}\nІменаваныя сутнасці: {label}",
+        default_instruction_prompt="Сказ: {text}\n\n"
+        "Ідэнтыфікуйце іменаваныя сутнасці ў сказе. Вы павінны вывесці гэта як "
+        "JSON-слоўнік з ключамі {labels_str}. Значэнні павінны быць спісамі "
+        "іменаваных сутнасцей гэтага тыпу, дакладна такімі, як яны з'яўляюцца ў "
+        "сказе.",
+    ),
     BOSNIAN: PromptConfig(
         default_prompt_label_mapping={
             "b-per": "osoba",
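Note: as in the other NER templates, the Belarusian `default_prompt_label_mapping` collapses the `b-*` and `i-*` BIO tags of each entity type onto the same surface word, and `o` (outside) tags are skipped when prompts are built (see the `if label == "o": continue` branch in `generation_utils.py` above). A quick self-contained check of that collapse:

```python
# Mapping copied from the Belarusian NER entry above.
mapping = {
    "b-per": "асоба", "i-per": "асоба",
    "b-loc": "месца", "i-loc": "месца",
    "b-org": "арганізацыя", "i-org": "арганізацыя",
    "b-misc": "рознае", "i-misc": "рознае",
}
tags = ["b-per", "i-per", "o", "b-loc"]
surface = [mapping[tag] for tag in tags if tag != "o"]
print(surface)  # ['асоба', 'асоба', 'месца']
```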
scandeval/prompt_templates/reading_comprehension.py
CHANGED

@@ -5,6 +5,7 @@ import typing as t
 from ..data_models import PromptConfig
 from ..languages import (
     ALBANIAN,
+    BELARUSIAN,
     BOSNIAN,
     BULGARIAN,
     CATALAN,

@@ -50,6 +51,15 @@ RC_TEMPLATES: dict["Language", PromptConfig] = {
         "rreth tekstit të mësipërm me maksimum 3 fjalë.\n\nPyetje: {question}",
         default_prompt_label_mapping=dict(),
     ),
+    BELARUSIAN: PromptConfig(
+        default_prompt_prefix="Ніжэй прыведзены тэксты з адпаведнымі пытаннямі і "
+        "адказамі.",
+        default_prompt_template="Тэкст: {text}\nПытанне: {question}\nАдказ "
+        "максімум 3 словамі: {label}",
+        default_instruction_prompt="Тэкст: {text}\n\nАдкажыце на наступнае пытанне "
+        "пра тэкст вышэй максімум 3 словамі.\n\nПытанне: {question}",
+        default_prompt_label_mapping=dict(),
+    ),
     BOSNIAN: PromptConfig(
         default_prompt_prefix="Slijede tekstovi s pitanjima i odgovorima.",
         default_prompt_template="Tekst: {text}\nPitanje: {question}\nOdgovor s "
scandeval/prompt_templates/sentiment_classification.py
CHANGED

@@ -5,6 +5,7 @@ import typing as t
 from ..data_models import PromptConfig
 from ..languages import (
     ALBANIAN,
+    BELARUSIAN,
     BOSNIAN,
     BULGARIAN,
     CATALAN,

@@ -52,6 +53,16 @@ SENT_TEMPLATES: dict["Language", PromptConfig] = {
         default_instruction_prompt="Dokument: {text}\n\nKlasifikoni ndjenjën në "
         "dokument. Përgjigjuni vetëm me {labels_str}, dhe asgjë tjetër.",
     ),
+    BELARUSIAN: PromptConfig(
+        default_prompt_label_mapping=dict(
+            positive="станоўчы", neutral="нейтральны", negative="адмоўны"
+        ),
+        default_prompt_prefix="Ніжэй прыведзены дакументы і іх сентымент, які можа "
+        "быць {labels_str}.",
+        default_prompt_template="Дакумент: {text}\nСентымент: {label}",
+        default_instruction_prompt="Дакумент: {text}\n\nКласіфікуйце сентымент у "
+        "дакуменце. Адкажыце толькі {labels_str}, і нічога іншага.",
+    ),
     BOSNIAN: PromptConfig(
         default_prompt_label_mapping=dict(
             positive="pozitivno", neutral="neutralno", negative="negativno"
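Note: each `PromptConfig` above is a plain template bundle: few-shot prompts are built by filling `default_prompt_template` per example under `default_prompt_prefix`, while `default_instruction_prompt` serves chat-style models. A rough rendering sketch for the new Belarusian sentiment template (the real assembly lives in `generation_utils.apply_prompt` and differs in detail; the `labels_str` rendering here is assumed):

```python
prefix = "Ніжэй прыведзены дакументы і іх сентымент, які можа быць {labels_str}."
template = "Дакумент: {text}\nСентымент: {label}"
labels_str = "'станоўчы', 'нейтральны' ці 'адмоўны'"  # assumed label-list rendering

few_shot = template.format(text="Цудоўны фільм!", label="станоўчы")
target = template.format(text="Нудная кніга.", label="").rstrip()

print("\n\n".join([prefix.format(labels_str=labels_str), few_shot, target]))
```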
scandeval/string_utils.py
ADDED

@@ -0,0 +1,157 @@
+"""Utility functions related to string manipulation or structuring."""
+
+import collections.abc as c
+import logging
+import re
+import typing as t
+
+import demjson3
+import numpy as np
+
+from .exceptions import InvalidBenchmark, InvalidModel
+from .logging_utils import log
+
+if t.TYPE_CHECKING:
+    from .data_models import ModelIdComponents
+
+
+def scramble(text: str) -> str:
+    """Scramble a string in a bijective manner.
+
+    Args:
+        text:
+            The string to scramble.
+
+    Returns:
+        The scrambled string.
+    """
+    rng = np.random.default_rng(seed=4242)
+    permutation = rng.permutation(x=len(text))
+    scrambled = "".join(text[i] for i in permutation)
+    return scrambled
+
+
+def unscramble(scrambled_text: str) -> str:
+    """Unscramble a string in a bijective manner.
+
+    Args:
+        scrambled_text:
+            The scrambled string to unscramble.
+
+    Returns:
+        The unscrambled string.
+    """
+    rng = np.random.default_rng(seed=4242)
+    permutation = rng.permutation(x=len(scrambled_text))
+    inverse_permutation = np.argsort(permutation)
+    unscrambled = "".join(scrambled_text[i] for i in inverse_permutation)
+    return unscrambled
+
+
+def extract_json_dict_from_string(s: str) -> dict | None:
+    """Extract a JSON dictionary from a string.
+
+    Args:
+        s:
+            The string to extract the JSON dictionary from.
+
+    Returns:
+        The extracted JSON dictionary, or None if no JSON dictionary could be found.
+    """
+    json_regex = r"\{[^{}]*?\}"
+    if (json_match := re.search(pattern=json_regex, string=s, flags=re.DOTALL)) is None:
+        log(
+            "The model output does not contain any JSON dictionary, so cannot parse "
+            f"it. Skipping. Here is the output: {s!r}",
+            level=logging.DEBUG,
+        )
+        return None
+    json_string = json_match.group()
+    try:
+        json_output = demjson3.decode(txt=json_string)
+    except demjson3.JSONDecodeError:
+        log(
+            "The model output is not valid JSON, so cannot parse it. Skipping. "
+            f"Here is the output: {json_string!r}",
+            level=logging.DEBUG,
+        )
+        return None
+    if not isinstance(json_output, dict):
+        log(
+            "The model output is not a JSON dictionary, so cannot parse "
+            f"it. Skipping. Here is the output: {json_string!r}",
+            level=logging.DEBUG,
+        )
+        return None
+    elif not all(isinstance(key, str) for key in json_output.keys()):
+        log(
+            "The model output is not a JSON dictionary with string keys, "
+            "so cannot parse it. Skipping. Here is the output: "
+            f"{json_string!r}",
+            level=logging.DEBUG,
+        )
+        return None
+    return json_output
+
+
+def extract_multiple_choice_labels(
+    prompt: str, candidate_labels: c.Sequence[str]
+) -> c.Sequence[str]:
+    """Extract multiple choice labels from a prompt.
+
+    Args:
+        prompt:
+            The prompt to extract the labels from.
+        candidate_labels:
+            The candidate labels to look for in the prompt.
+
+    Returns:
+        The extracted labels.
+    """
+    sample_candidate_labels: list[str] = list()
+    for candidate_label in candidate_labels:
+        candidate_label_match = re.search(
+            pattern=rf"\b{candidate_label}\. ", string=prompt, flags=re.IGNORECASE
+        )
+        if candidate_label_match is not None:
+            sample_candidate_labels.append(candidate_label)
+    if not sample_candidate_labels:
+        raise InvalidBenchmark(
+            "Could not extract any candidate labels from the prompt. Please ensure "
+            "that the candidate labels are present in the prompt, each followed by a "
+            "dot and a space (e.g., 'a. '). The candidate labels are: "
+            f"{', '.join(candidate_labels)}. Here is the prompt: {prompt!r}"
+        )
+    return sample_candidate_labels
+
+
+def split_model_id(model_id: str) -> "ModelIdComponents":
+    """Split a model ID into its components.
+
+    Args:
+        model_id:
+            The model ID to split.
+
+    Returns:
+        The split model ID.
+
+    Raises:
+        If the model ID is not valid.
+    """
+    # Importing here to avoid circular imports
+    from .data_models import ModelIdComponents
+
+    # Attempt to extract the model ID, revision, and param using regex
+    model_id_match = re.match(pattern=r"^[^@#]+", string=model_id)
+    revision_match = re.search(pattern=r"@([^@#]+)", string=model_id)
+    param_match = re.search(pattern=r"#([^@#]+)", string=model_id)
+
+    # If we cannot extract the model ID, raise an error
+    if model_id_match is None:
+        raise InvalidModel(f"The model ID {model_id!r} is not valid.")
+    model_id = model_id_match.group()
+
+    # Extract the revision and param and return the result
+    revision = revision_match.group(1) if revision_match is not None else "main"
+    param = param_match.group(1) if param_match is not None else None
+    return ModelIdComponents(model_id=model_id, revision=revision, param=param)
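Note: two details of the new module are worth calling out. `scramble`/`unscramble` are exact inverses because both seed `np.random.default_rng` with the same constant and `np.argsort` of a permutation yields its inverse, and `split_model_id` parses the `<model>[@<revision>][#<param>]` ID syntax with three small regexes. A self-contained check of both (the model ID below is hypothetical):

```python
import re

import numpy as np

def scramble(text: str) -> str:
    rng = np.random.default_rng(seed=4242)
    permutation = rng.permutation(x=len(text))
    return "".join(text[i] for i in permutation)

def unscramble(scrambled_text: str) -> str:
    rng = np.random.default_rng(seed=4242)
    permutation = rng.permutation(x=len(scrambled_text))
    inverse_permutation = np.argsort(permutation)  # inverse of the permutation
    return "".join(scrambled_text[i] for i in inverse_permutation)

assert unscramble(scramble("hello ScandEval")) == "hello ScandEval"

# Model-ID parsing, as in split_model_id:
model_id = "org/model@v1.0#4bit"
print(re.match(r"^[^@#]+", model_id).group())      # org/model
print(re.search(r"@([^@#]+)", model_id).group(1))  # v1.0
print(re.search(r"#([^@#]+)", model_id).group(1))  # 4bit
```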
scandeval/task_group_utils/sequence_classification.py
CHANGED

@@ -10,12 +10,9 @@ import numpy as np

 from ..enums import TaskGroup
 from ..exceptions import InvalidBenchmark
+from ..string_utils import extract_multiple_choice_labels
 from ..types import Predictions
-from ..utils import (
-    extract_multiple_choice_labels,
-    log_once,
-    raise_if_model_output_contains_nan_values,
-)
+from ..utils import log_once, raise_if_model_output_contains_nan_values

 if t.TYPE_CHECKING:
     from datasets.arrow_dataset import Dataset
scandeval/task_group_utils/token_classification.py
CHANGED

@@ -9,10 +9,8 @@ import numpy as np

 from ..exceptions import InvalidBenchmark
 from ..logging_utils import log
-from ..utils import (
-    extract_json_dict_from_string,
-    raise_if_model_output_contains_nan_values,
-)
+from ..string_utils import extract_json_dict_from_string
+from ..utils import raise_if_model_output_contains_nan_values

 if t.TYPE_CHECKING:
     from datasets.arrow_dataset import Dataset