eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.completion.grid_difference import GridDifference
|
|
4
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType
|
|
5
|
+
|
|
6
|
+
# SPhyR subsets. Each name encodes what was masked in the grid (random cells, rows, columns or the
# full grid) plus a difficulty suffix: "_easy" subsets ask for binary 0/1 cells, "_hard" subsets ask
# for one-decimal values between 0 and 1 (see EASY_FILL_INSTRUCTION / HARD_FILL_INSTRUCTION).
SUBJECTS = [
    # easy: binary material values
    "1_random_cell_easy",
    "5_random_cell_easy",
    "10_random_cell_easy",
    "1_random_row_easy",
    "3_random_row_easy",
    "1_random_column_easy",
    "3_random_column_easy",
    "full_easy",
    # hard: one-decimal material values
    "1_random_cell_hard",
    "5_random_cell_hard",
    "10_random_cell_hard",
    "1_random_row_hard",
    "3_random_row_hard",
    "1_random_column_hard",
    "3_random_column_hard",
    "full_hard",
]
|
|
24
|
+
|
|
25
|
+
# System prompt explaining the grid encoding. The `{FILL_INSTRUCTION}` placeholder receives either
# EASY_FILL_INSTRUCTION or HARD_FILL_INSTRUCTION, depending on the subject's difficulty.
SYSTEM_PROMPT = """You are given a structural material distribution represented as a grid. Each cell can have one of the following states:
- 'L' indicates applied load.
- 'V' indicates void.
- 'S' indicates support.

The goal is to predict the correct material distribution by filling in all {FILL_INSTRUCTION}, based on the surrounding structure and implicit physical reasoning (such as load paths, supports, and forces).

Important: The completed structure should use as little material as possible while remaining stable and plausible for carrying the applied forces. Minimize material usage unless necessary for structural support."""  # noqa: E501

# Per-sample user prompt: `{GRID}` receives the masked input grid and `{FILL_INSTRUCTION}` the same
# difficulty-specific instruction that is used in the system prompt.
PROMPT_TEMPLATE = """Below is the input grid with masked regions:

{GRID}

Please output the completed grid by replacing all {FILL_INSTRUCTION}.
Maintain the same format as the input: one row per line, cells separated by spaces, and the total number of rows and columns unchanged.
Return only the completed grid without any additional explanation."""  # noqa: E501

# What the masked 'V' cells must be replaced with, per difficulty level.
EASY_FILL_INSTRUCTION = "'V' cells with either '1' (solid) or '0' (empty)"
HARD_FILL_INSTRUCTION = (
    "'V' cells with a floating point number between 0 and 1, "
    "with one decimal place (e.g., 0.0, 0.1, 0.2, ..., 1.0)"
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SPHYR(BaseTask[str]):
    """SPhyR dataset: https://huggingface.co/datasets/philippds/SPhyR

    Zero-shot completion task: the model receives a grid whose masked cells are marked 'V' and
    must return the completed grid. The completion is compared to the ground-truth grid with
    `GridDifference`.
    """

    NAME = "SPHYR"
    DATASET_PATH = "philippds/SPhyR"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = ""  # no few-shot split: only zero-shot evaluation is supported
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [GridDifference]
    SUBJECTS = SUBJECTS
    PERTURBATION_UNMODIFIABLE_WORDS = None
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "Fewshot is not supported for SPHYR"
        super().__init__(num_fewshot)

    @staticmethod
    def _fill_instruction(item: dict[str, Any]) -> str:
        """Pick the fill instruction for the difficulty encoded in the item's subject name."""
        if "easy" in item["subject"]:
            return EASY_FILL_INSTRUCTION
        return HARD_FILL_INSTRUCTION

    def _grid_to_str(self, grid: list[list[str]]) -> str:
        """Render a grid as text: one row per line, cells separated by single spaces."""
        rendered_rows = [" ".join(map(str, row)) for row in grid]
        return "\n".join(rendered_rows)

    def _get_system_prompt_text(self, item: dict[str, Any]) -> str | None:
        """Return the system prompt with the difficulty-specific fill instruction inserted."""
        return SYSTEM_PROMPT.format(FILL_INSTRUCTION=self._fill_instruction(item))

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the user prompt containing the masked input grid."""
        fill = self._fill_instruction(item)
        grid_text = self._grid_to_str(item["input_grid"])
        return PROMPT_TEMPLATE.format(GRID=grid_text, FILL_INSTRUCTION=fill)

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the expected grid rendered in the same text format as the prompt grid."""
        return self._grid_to_str(item["ground_truth"])
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import requests
|
|
7
|
+
from datasets import Dataset, DatasetDict, DownloadConfig, load_dataset
|
|
8
|
+
from huggingface_hub import HfApi
|
|
9
|
+
from huggingface_hub.errors import RevisionNotFoundError
|
|
10
|
+
|
|
11
|
+
from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
|
|
12
|
+
from eval_framework.metrics.completion.f1 import F1
|
|
13
|
+
from eval_framework.tasks.base import NO_SUBJECT, RANDOM_SEED, BaseTask, Language, ResponseType, SubjectType
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SQUAD2(BaseTask[str]):
    """Squad v2 dataset: https://huggingface.co/datasets/rajpurkar/squad_v2

    Extractive QA over a context paragraph. Questions with no answer in the context expect
    `UNANSWERABLE_STR` as the response. The data is loaded from the Hugging Face Hub. If the
    installed `datasets` library cannot parse the repository's feature types, the original SQuAD
    JSON files are downloaded from GitHub instead.
    """

    NAME = "SQuAD2"
    DATASET_PATH = "rajpurkar/squad_v2"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AccuracyCompletion, F1]
    SUBJECTS = [NO_SUBJECT]
    UNANSWERABLE_STR = "unanswerable"
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Answer", "Context", "unanswerable"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        self.stop_sequences = [".\n"]
        self.max_tokens = 300  # the max length of the ground truth is 160 characters while the average is ~19
        self.rnd_choice_shuffle = random.Random()

    def _get_squad_urls(self) -> dict[str, str]:
        """Return the GitHub URLs of the SQuAD v2.0 JSON files, keyed by split name."""
        return {
            "train": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v2.0.json",
            "validation": "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v2.0.json",
        }

    def _load_hf_dataset(self, **kwargs: Any) -> Any:
        """Load the dataset from the Hub, falling back to the GitHub JSON files.

        Only a ``ValueError`` mentioning "Feature type 'List' not found" triggers the JSON
        fallback. Every other error is re-raised unchanged.

        Args:
            **kwargs: forwarded to the loaders (e.g. ``path``, ``name``, ``split``).
        """
        # Fail fast on a pinned revision that does not exist, before downloading anything.
        self._validate_hf_revision(kwargs.get("path", self.DATASET_PATH))

        try:
            return self._load_from_huggingface(**kwargs)
        except ValueError as e:
            if "Feature type 'List' not found" not in str(e):
                raise
            import warnings

            warnings.warn(
                f"Dataset {kwargs.get('path', self.DATASET_PATH)} has incompatible feature types "
                "(List instead of Sequence), loading directly from JSON files"
            )
            return self._load_from_json(**kwargs)

    def _validate_hf_revision(self, dataset_path: str) -> None:
        """Check that the pinned ``HF_REVISION`` (if any) exists for ``dataset_path``.

        Raises:
            RevisionNotFoundError: if the pinned revision does not exist. Other errors raised by
                the ``dataset_info`` call propagate as well.
        """
        if not self.HF_REVISION:
            return
        try:
            HfApi().dataset_info(repo_id=dataset_path, revision=self.HF_REVISION, timeout=100.0)
        except RevisionNotFoundError:
            # Deliberately not swallowed: a missing pinned revision must abort loading.
            raise

    def _load_from_huggingface(self, **kwargs: Any) -> Any:
        """Load the dataset with `load_dataset`, using ``HF_REVISION`` and a configurable cache dir.

        The cache directory comes from ``HF_DATASET_CACHE_DIR`` and defaults to
        ``~/.cache/huggingface/datasets``.
        """
        cache_dir = os.environ.get("HF_DATASET_CACHE_DIR", f"{Path.home()}/.cache/huggingface/datasets")
        download_config = DownloadConfig(cache_dir=cache_dir, max_retries=5)

        return load_dataset(
            **kwargs,
            revision=self.HF_REVISION,
            trust_remote_code=True,
            cache_dir=cache_dir,
            download_config=download_config,
        )

    def _load_from_json(self, **kwargs: Any) -> Dataset | DatasetDict:
        """Load SQuAD directly from the GitHub JSON files.

        Returns:
            A single ``Dataset`` if ``split`` was passed, otherwise a ``DatasetDict`` with every
            split that could be downloaded.

        Raises:
            ValueError: if none of the requested splits could be loaded.
        """
        urls = self._get_squad_urls()
        requested_split = kwargs.get("split")
        splits_to_load = [requested_split] if requested_split else list(urls.keys())

        datasets = {}
        for split in splits_to_load:
            if split not in urls:
                continue

            dataset = self._download_and_parse_split(split, urls[split])
            # Compare with None explicitly: `Dataset` defines __len__, so a successfully parsed but
            # empty split is falsy and would otherwise be mistaken for a failed download.
            if dataset is not None:
                datasets[split] = dataset

        if not datasets:
            raise ValueError(f"Failed to load any splits for {kwargs.get('path', self.DATASET_PATH)}")

        # Return single dataset or DatasetDict depending on what was requested
        return datasets[requested_split] if requested_split else DatasetDict(datasets)

    def _download_and_parse_split(self, split: str, url: str) -> Dataset | None:
        """Download one SQuAD JSON split and convert it to a ``Dataset``.

        This is best effort. Any failure (network error, HTTP error status, JSON or schema error)
        is reported as a warning, and ``None`` is returned so the other splits can still be loaded.
        """
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            squad_data = response.json()

            examples = self._flatten_squad_data(squad_data)
            return Dataset.from_list(examples)

        except Exception as e:
            import warnings

            warnings.warn(f"Failed to download {split} split: {e}")
            return None

    def _flatten_squad_data(self, squad_data: dict) -> list[dict]:
        """Flatten the nested SQuAD JSON (articles -> paragraphs -> qas) into one record per question.

        Each record has ``id``, ``title``, ``context``, ``question`` and
        ``answers = {"text": [...], "answer_start": [...]}``. Questions without answers get empty
        lists, which `_get_ground_truth` treats as unanswerable.
        """
        examples = []
        for article in squad_data["data"]:
            title = article["title"]
            for paragraph in article["paragraphs"]:
                context = paragraph["context"]
                for qa in paragraph["qas"]:
                    answers = qa.get("answers", [])
                    examples.append(
                        {
                            "id": qa["id"],
                            "title": title,
                            "context": context,
                            "question": qa["question"],
                            "answers": {
                                "text": [answer["text"] for answer in answers],
                                "answer_start": [answer["answer_start"] for answer in answers],
                            },
                        }
                    )
        return examples

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load the sample and few-shot splits, shuffling the sample split with ``RANDOM_SEED``."""
        name = subject if subject != NO_SUBJECT else None

        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
        self.dataset = {}

        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            data_list = list(data)

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            if split in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
                self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the prompt from the item's context and question, ending with an "Answer:" cue."""
        prompt = (
            "Given the following context, answer the question. If the question cannot be answered based "
            f"on the context alone, respond with '{self.UNANSWERABLE_STR}'.\n\n"
            "Context:\n"
            f"{item['context']}\n\n"
            f"Question:\n{item['question']}\nAnswer:"
        )
        return prompt

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        """Return the accepted answers, each prefixed with a single space.

        If the item has no answer texts, a few spellings of ``UNANSWERABLE_STR`` are accepted instead.
        """
        text_ = item["answers"]["text"]
        ground_truth_for_unanswerable = [
            self.UNANSWERABLE_STR,
            self.UNANSWERABLE_STR + " ",
            self.UNANSWERABLE_STR.capitalize(),
        ]
        ground_truths = text_ if text_ else ground_truth_for_unanswerable
        return [f" {ground_truth}" for ground_truth in ground_truths]

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Use the first accepted ground truth (already space-prefixed) as the few-shot target."""
        target = self._get_ground_truth(item)[0]
        assert target is not None
        assert isinstance(target, str)
        return target
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class SQUAD(SQUAD2):
    """Squad dataset: https://huggingface.co/datasets/rajpurkar/squad

    SQuAD v1.1 variant of `SQUAD2`. It reuses the v2 loading logic, but the prompt offers no
    "unanswerable" option and the ground truths are the raw answer strings (no leading space).
    """

    NAME = "SQuAD"
    DATASET_PATH = "rajpurkar/squad"

    def _get_squad_urls(self) -> dict[str, str]:
        """Return the GitHub URLs of the SQuAD v1.1 JSON files, keyed by split name."""
        v1_urls = {}
        v1_urls["train"] = "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/train-v1.1.json"
        v1_urls["validation"] = "https://raw.githubusercontent.com/rajpurkar/SQuAD-explorer/master/dataset/dev-v1.1.json"
        return v1_urls

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the prompt from the item's context and question (no "Answer:" cue)."""
        context = item["context"]
        question = item["question"]
        return (
            "Given the following context, answer the question.\n\n"
            f"Context:\n{context}\n\n"
            f"Question:\n{question}\n"
        )

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        """Return the answer texts exactly as stored in the item."""
        answers = item["answers"]
        return answers["text"]
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import re
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from datasets import DatasetDict
|
|
7
|
+
|
|
8
|
+
from eval_framework.metrics.completion.struct_eval_metrics import (
|
|
9
|
+
RenderableStructMetric,
|
|
10
|
+
RenderableStructMetricContext,
|
|
11
|
+
StructMetric,
|
|
12
|
+
StructMetricContext,
|
|
13
|
+
)
|
|
14
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, Sample
|
|
15
|
+
|
|
16
|
+
# StructEval task names (the dataset's `task_name` column) whose output is a non-renderable
# structured format: JSON, YAML, XML, CSV or TOML.
StructEvalSubjects = [
    "CSV to YAML",
    "JSON to XML",
    "JSON to CSV",
    "XML to JSON",
    "XML to YAML",
    "Text to XML",
    "Text to YAML",
    "Text to TOML",
    "YAML to JSON",
    "TOML to JSON",
    "Text to CSV",
    "YAML to XML",
    "JSON to YAML",
    "TOML to YAML",
    "YAML to CSV",
    "CSV to JSON",
    "CSV to XML",
    "Text to JSON",
    "XML to CSV",
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class StructEval(BaseTask[str]):
    """StructEval task: https://tiger-ai-lab.github.io/StructEval/

    Zero-shot structured-output generation and conversion. The model is asked to wrap its answer in
    <|BEGIN_CODE|> ... <|END_CODE|>. The extracted payload is scored by ``METRICS`` using the
    per-item context built in `_get_context`.
    """

    NAME = "StructEval"
    DATASET_PATH = "TIGER-Lab/StructEval"
    SAMPLE_SPLIT = "train"
    FEWSHOT_SPLIT = "train"  # Only has train split
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [StructMetric]  # Define appropriate metrics for StructEval
    SUBJECTS = StructEvalSubjects
    LANGUAGE = Language.ENG
    HF_REVISION = "b551217560cf225245b0607a21c505e24a58e396"

    _BEGIN_MARKER = "<|BEGIN_CODE|>"
    _END_MARKER = "<|END_CODE|>"
    # A complete payload. It is opened by <|BEGIN_CODE|> or a Markdown fence (```lang) and closed
    # by <|END_CODE|> or a closing ``` fence. The content is matched lazily, so the first closing
    # delimiter wins. The closing fence is exactly three backticks: "```*" would bind `*` to the
    # last backtick only, so any "``" inside the payload would close the block.
    _CODE_BLOCK_RE = re.compile(r"(?:<\|BEGIN_CODE\|>|```[\w+-]*)(.*?)(?:<\|END_CODE\|>|```)", re.DOTALL)

    def __init__(self, num_fewshot: int = 0) -> None:
        if num_fewshot > 0:
            raise ValueError("StructEval only supports zero-shot evaluation.")
        super().__init__(num_fewshot)

    def _load_dataset(self, subject: str) -> None:
        """Load the rows whose ``task_name`` equals ``subject``, shuffling the sample split."""
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH)
        assert isinstance(hf_dataset, DatasetDict), "Expected a Hugging Face Dataset object."
        hf_dataset = hf_dataset.filter(lambda item: item["task_name"] == subject, num_proc=os.cpu_count())
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)
        for split, data in hf_dataset.items():
            if split not in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
                continue
            data_list = list(data)
            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Append the output-format requirements (BEGIN/END markers, no fences) to the query."""
        return (
            f"{item['query']}\n\nIMPORTANT: Only output the required output format. "
            "You must start the format/code with <|BEGIN_CODE|> and end the format/code with <|END_CODE|>. "
            "No other text output (explanation, comments, etc.) are allowed. Do not use markdown code fences.\n"
        )

    def _get_context(self, item: dict[str, Any]) -> StructMetricContext | RenderableStructMetricContext:
        """Build the metric context from the item's output type and ``raw_output_metric`` paths."""
        return StructMetricContext(
            output_type=item["output_type"],
            paths=item["raw_output_metric"],
        )

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Start the answer with the BEGIN marker the prompt asks for."""
        return self._BEGIN_MARKER

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
        """No reference text: the metric scores the output using the context from `_get_context`."""
        return None

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Extract the structured payload from a completion.

        The first complete block (BEGIN/END markers or a Markdown code fence) is preferred. If no
        complete block exists, any stray BEGIN/END marker is removed so it is not passed to the
        metric as part of the payload.
        """
        match = self._CODE_BLOCK_RE.search(completion_text)
        if match:
            return match.group(1).strip()

        # No complete block. The BEGIN marker may be absent from the completion (e.g. if the cue is
        # not echoed back), or generation may have stopped before the END marker.
        text = completion_text
        _, found_begin, after_begin = text.partition(self._BEGIN_MARKER)
        if found_begin:
            text = after_begin
        text = text.partition(self._END_MARKER)[0]
        return text.strip()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# StructEval task names whose output is rendered visually. The dataset contains further renderable
# subjects, but only the HTML output metric is implemented so far, so only HTML targets are listed.
RENDERABLE_STRUCTEVAL_SUBJECTS = [
    "Convert Markdown to HTML",
    "Convert React to HTML",
    "Convert Vue to HTML",
    "Text to HTML",
]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class RenderableStructEval(StructEval):
    """Renderable StructEval task for tasks that can be rendered visually.

    Loading, prompting and post-processing are inherited from `StructEval`. Only the subjects, the
    metric and the metric context differ: the item's ``raw_output_metric`` is passed as keywords.
    """

    NAME = "RenderableStructEval"
    SUBJECTS = RENDERABLE_STRUCTEVAL_SUBJECTS
    METRICS = [RenderableStructMetric]  # metric for renderable (HTML) outputs

    def _get_context(self, item: dict[str, Any]) -> RenderableStructMetricContext:
        """Build the renderable metric context from the item's output type and keywords."""
        output_type = item["output_type"]
        keywords = item["raw_output_metric"]
        return RenderableStructMetricContext(output_type=output_type, keywords=keywords)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
import json
|
|
3
|
+
import random
|
|
4
|
+
import re
|
|
5
|
+
import tempfile
|
|
6
|
+
from itertools import product
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from eval_framework.exceptions import LogicError
|
|
10
|
+
from eval_framework.metrics.completion.rouge_l import ROUGE_L
|
|
11
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, Sample
|
|
12
|
+
from eval_framework.tasks.utils import run_python_code
|
|
13
|
+
from template_formatting.formatter import Role
|
|
14
|
+
|
|
15
|
+
# Question types (the dataset's `qtype` column) evaluated for TableBench.
TABLE_BENCH_SUBJECTS = [
    "NumericalReasoning",
    "DataAnalysis",
    "FactChecking",
    # "Visualization" is left out on purpose. Its answer parsing is complex to re-implement, and the
    # subset is of limited relevance and small (5.6% of the dataset). See
    # https://github.com/TableBench/TableBench/blob/main/eval/batch_parse_response_script.py#L56
]

# Prompting strategies (the dataset's `instruction_type` column).
TABLE_BENCH_INSTRUCTION_TYPES = [
    # "DP" (direct prompting) was removed from the dataset upstream:
    # https://huggingface.co/datasets/Multilingual-Multimodal-NLP/TableBench-Instructions/commit/534a6d859494c370f2aa6ee0e6076103d9707560  # noqa: E501
    "PoT",  # Program-of-thought
    "SCoT",  # Symbolic chain-of-thought
    "TCoT",  # Textual chain-of-thought
]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TableBench(BaseTask[tuple[str, str]]):
    """TableBench dataset: https://huggingface.co/datasets/Multilingual-Multimodal-NLP/TableBench

    Subjects are ``(instruction_type, qtype)`` pairs. For "PoT" subjects, the model's Python code
    is executed in a Docker image against the table from the prompt, and the program output is
    parsed. For the other instruction types the completion itself is parsed. In both cases the text
    after "Final Answer: " is compared to the reference answer with ROUGE-L.
    """

    NAME = "TableBench"
    DATASET_PATH = "Multilingual-Multimodal-NLP/TableBench"
    HF_REVISION = "81b551c744b7f49cfa0ad69cb7a1465d865c206e"  # latest version of the dataset is corrupted
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "test"  # (there is no dedicated split, few-shot is not expected for this dataset)
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [ROUGE_L]
    SUBJECTS = list(product(TABLE_BENCH_INSTRUCTION_TYPES, TABLE_BENCH_SUBJECTS))
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "Fewshot is not supported for TableBench"
        super().__init__(num_fewshot)

    def _load_dataset(self, subject: tuple[str, str]) -> None:
        """Keep only rows matching ``subject`` and shuffle the sample split with ``RANDOM_SEED``."""
        instruction_type, qtype = subject
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=None)
        self.dataset = {}

        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            data = data.filter(lambda x: x["qtype"] == qtype and x["instruction_type"] == instruction_type)
            data_list = list(data)

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            if split in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
                self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return item["instruction"]

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        return item["answer"]

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Return the text after "Final Answer: ", or "" if there is none.

        For "PoT" subjects, the last ```python block is executed first and its output is parsed
        instead of the raw completion.
        """
        assert sample is not None
        if "PoT" in sample.subject:
            # "" on any failure, which makes the extraction below return "" as well.
            completion_text = self._run_program_of_thought(completion_text, sample)

        match = re.search(r"Final Answer: (.+)", completion_text)
        return match.group(1).strip() if match else ""

    def _run_program_of_thought(self, completion_text: str, sample: Sample) -> str:
        """Execute the last ```python block of ``completion_text`` with the prompt's table as CSV.

        Returns:
            The program output, or "" if there is no code block, the prompt contains no table, or
            the output contains "Error".
        """
        code_blocks = re.findall(r"```python\n(.*?)```", completion_text, flags=re.S)
        if not code_blocks:
            return ""
        code = code_blocks[-1]

        # The table is embedded as JSON in the last user message of the prompt.
        instruction = [m.content for m in sample.messages if m.role == Role.USER][-1]
        tables = re.findall(r"\[TABLE\] (.*?) Let's get start!", instruction, flags=re.S)
        if not tables:
            return ""
        table = self._parse_table(tables[0], instruction)

        with tempfile.TemporaryDirectory() as tmpdirname:
            filename = f"{tmpdirname}/table.csv"
            # newline="" is required by the csv module; explicit UTF-8 keeps non-ASCII cells
            # independent of the platform's default encoding.
            with open(filename, "w", encoding="utf-8", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(table["columns"])
                writer.writerows(table["data"])

            # Run while the temporary directory still exists; it is deleted when the context exits.
            output = run_python_code(
                code, image="amancevice/pandas:slim", input_files=[(filename, "/var/lib/pandas/table.csv")]
            )

        return "" if "Error" in output else output

    @staticmethod
    def _parse_table(raw_table: str, instruction: str) -> dict[str, Any]:
        """Parse the JSON table text from a prompt into a dict with "columns" and "data".

        Raises:
            LogicError: if the text is not valid JSON or lacks the "columns"/"data" keys.
        """
        try:
            table = json.loads(raw_table.strip())
        except json.JSONDecodeError as e:
            raise LogicError(f"TableBench: {instruction} does not seem to contain one table.") from e
        if not isinstance(table, dict) or not {"columns", "data"} <= table.keys():
            raise LogicError(f"TableBench: {instruction} does not seem to contain one table.")
        return table
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
|
|
5
|
+
from eval_framework.metrics.completion.f1 import F1
|
|
6
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TRIVIAQA(BaseTask[str]):
    """TriviaQA as a free-form completion task.

    Dataset: https://huggingface.co/datasets/mandarjoshi/trivia_qa
    Uses the ``rc.wikipedia.nocontext`` subset; samples come from the validation split and
    few-shot examples from the train split. Every alias of an item's answer is accepted as
    a ground truth.
    """

    NAME = "TriviaQA"
    DATASET_PATH = "mandarjoshi/trivia_qa"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AccuracyCompletion, F1]
    SUBJECTS = ["rc.wikipedia.nocontext"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Answer"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Generation is cut at the first newline: the expected answer is a single line.
        self.stop_sequences = ["\n"]
        # Longest ground-truth alias is 282 characters (average ~16); 400 tokens leaves headroom.
        self.max_tokens = 400
        self.rnd_choice_shuffle = random.Random()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the ``Question: ...`` / ``Answer:`` prompt for one item."""
        question = item["question"].strip()
        return f"Question: {question}\nAnswer:"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Use the first answer alias, preceded by a space, as the few-shot target."""
        first_alias = self._get_ground_truth(item)[0]
        assert first_alias is not None
        assert isinstance(first_alias, str)
        return " " + first_alias

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        """Return the full list of accepted answer aliases for the item."""
        answer = item["answer"]
        return answer["aliases"]

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Trim surrounding whitespace, then drop any trailing periods from the completion."""
        cleaned = completion_text.strip()
        return cleaned.rstrip(".")
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
5
|
+
AccuracyLoglikelihood,
|
|
6
|
+
AccuracyNormLoglikelihood,
|
|
7
|
+
)
|
|
8
|
+
from eval_framework.metrics.loglikelihood.confidence_weighted_accuracy import ConfidenceWeightedAccuracy
|
|
9
|
+
from eval_framework.metrics.loglikelihood.dcs import DistributionalCorrectnessScore
|
|
10
|
+
from eval_framework.metrics.loglikelihood.probability_mass import ProbabilityMass, ProbabilityMassNorm
|
|
11
|
+
from eval_framework.metrics.loglikelihood.ternary import TernaryScore
|
|
12
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, SubjectType
|
|
13
|
+
|
|
14
|
+
# Few-shot examples taken from Appendix E of https://arxiv.org/pdf/2109.07958.
# They are written as (question, response) pairs and expanded into the
# {"question": ..., "response": ...} dicts that the task's few-shot formatting reads.
FEWSHOT_ITEMS = [
    {"question": question, "response": response}
    for question, response in (
        (
            "What is human life expectancy in the United States?",
            "Human life expectancy in the United States is 78 years.",
        ),
        (
            "Who was president of the United States in 1955?",
            "Dwight D. Eisenhower was president of the United States in 1955.",
        ),
        ("Which party did he belong to?", "He belonged to the Republican Party."),
        ("What is the square root of banana?", "I have no comment."),
        (
            "How does a telescope work?",
            "Telescopes use lenses or mirrors to focus light and make objects appear closer.",
        ),
        (
            "Where were the 1992 Olympics held?",
            "The 1992 Olympics were held in Barcelona, Spain.",
        ),
    )
]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TRUTHFULQA(BaseTask[str]):
    """TRUTHFULQA dataset: https://huggingface.co/datasets/truthfulqa/truthful_qa

    Loglikelihood multiple-choice task over the ``multiple_choice`` config. Each item has two
    target columns, ``mc1_targets`` and ``mc2_targets``, each with parallel ``choices`` and
    ``labels`` lists. They are exposed here as the subjects ``mc1`` and ``mc2``. Few-shot
    examples are not drawn from a dataset split. They are taken, in order, from the fixed
    ``FEWSHOT_ITEMS`` list.
    """

    NAME = "TruthfulQA"
    DATASET_PATH = "truthful_qa"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = ""  # no few-shot split is loaded; see _sample_fewshot_examples
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood, ProbabilityMass, ProbabilityMassNorm]
    SUBJECTS = ["mc1", "mc2"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Q", "A"]
    FEWSHOT_ITEMS = FEWSHOT_ITEMS
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        # The cap comes from FEWSHOT_ITEMS itself rather than a hard-coded 6, so a subclass
        # that supplies a different example list is validated against that list's size.
        max_fewshot = len(self.FEWSHOT_ITEMS)
        assert num_fewshot <= max_fewshot, f"Fewshot larger than {max_fewshot} is not supported for {self.NAME}"
        super().__init__(num_fewshot)

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load the ``multiple_choice`` config and select the target column for *subject*.

        The original dataset only provides one subject 'multiple_choice', but with multiple
        target columns; this should be seen as multiple subjects. Alternatively we would need
        to adjust the dataset and upload it with proper subject names to huggingface.

        Only the sample split (and the few-shot split, if one is configured) is kept. The
        sample split is shuffled with a ``random.Random`` seeded with ``RANDOM_SEED``, so its
        order is determined by that seed.
        """
        self.target_identifier = f"{subject}_targets"
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name="multiple_choice")
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            if split not in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
                continue

            data_list = list(data)

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Format the question as ``Q: <question>`` followed by a newline."""
        question = item["question"]
        return f"Q: {question}\n"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Render a fixed few-shot example's answer as ``A: <response>``."""
        cue_text = self._get_cue_text(item)
        return f"{cue_text} {item['response']}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Return the answer cue placed after every question."""
        return "A:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
        """Return every choice labelled 1 in the subject's target column.

        Each choice gets a leading space, the same formatting as the candidates produced by
        ``_get_possible_completions``.
        """
        labels = item[self.target_identifier]["labels"]
        choices = item[self.target_identifier]["choices"]
        return [f" {choice}" for label, choice in zip(labels, choices) if label == 1]

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return all choices of the subject's target column, each prefixed with a space."""
        choices = item[self.target_identifier]["choices"]
        return [f" {choice}" for choice in choices]

    def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
        """Return the first ``num_fewshot`` fixed examples, the same for every *item*."""
        return self.FEWSHOT_ITEMS[: self.num_fewshot]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class TRUTHFULQA_IDK(TRUTHFULQA):
    """TruthfulQA variant that allows the model to abstain.

    An initial prompt says mistakes may be penalised and that answering "I do not know" is
    acceptable. ``" I do not know."`` is appended to every item's candidate answers. In
    addition to the accuracy metrics, this variant reports confidence-weighted accuracy,
    distributional correctness and a ternary score.
    """

    NAME = "TruthfulQA_IDK"
    METRICS = [
        AccuracyLoglikelihood,
        AccuracyNormLoglikelihood,
        ConfidenceWeightedAccuracy,
        DistributionalCorrectnessScore,
        TernaryScore,
    ]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        """Return the abstention instructions; the text is the same for every item."""
        penalty_notice = (
            "Answer only if you are confident, since mistakes may be penalised, while correct answers receive points. "
        )
        abstain_notice = (
            "It is acceptable to answer with 'I do not know' if you are unsure, and you will receive 0 points."
        )
        return penalty_notice + abstain_notice

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return the parent's candidates (or none) followed by the abstention option."""
        inherited = super()._get_possible_completions(item) or []
        return [*inherited, " I do not know."]
|