eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
5
|
+
AccuracyLoglikelihood,
|
|
6
|
+
AccuracyNormLoglikelihood,
|
|
7
|
+
)
|
|
8
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType
|
|
9
|
+
from eval_framework.tasks.utils import get_n_letters
|
|
10
|
+
|
|
11
|
+
# Subject names used as ChemBench.SUBJECTS (presumably the dataset's config names — confirm against the HF repo).
CHEMBENCH_SUBJECTS = """
    analytical_chemistry
    chemical_preference
    general_chemistry
    inorganic_chemistry
    materials_science
    organic_chemistry
    physical_chemistry
    technical_chemistry
    toxicity_and_safety
""".split()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ChemBench(BaseTask[str]):
    """ChemBench dataset: https://huggingface.co/datasets/jablonkagroup/ChemBench

    Only multiple-choice items with exactly one correct option are kept. The options are listed under
    letter keys, and those letters (with a leading space) are the candidate completions scored by
    log-likelihood.
    """

    NAME = "ChemBench"
    DATASET_PATH = "jablonkagroup/ChemBench"
    SAMPLE_SPLIT = "train"  # Only has train split
    FEWSHOT_SPLIT = "train"  # Only has train split
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = CHEMBENCH_SUBJECTS
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "Fewshot is not supported for ChemBench"
        super().__init__(num_fewshot)

        # Option labels (presumably "A".."P"); at most 16 options can be labelled.
        self.keys = get_n_letters(16)

    @staticmethod
    def _get_target_scores(item: dict[str, Any]) -> dict[str, Any]:
        """Decode the JSON ``target_scores`` of the item's first example (option text -> score)."""
        return json.loads(item["examples"][0]["target_scores"])

    @classmethod
    def _get_correct_indices(cls, item: dict[str, Any]) -> list[int]:
        """Return the positions (in ``target_scores`` order) of the options whose score equals 1.0."""
        scores = cls._get_target_scores(item).values()
        return [i for i, score in enumerate(scores) if score == 1.0]

    def _load_dataset(self, subject: str) -> None:
        super()._load_dataset(subject)
        # Keep only the multiple-choice items with exactly one correct answer. The metric check is evaluated
        # first, so target_scores is only decoded for multiple-choice items.
        for split in list(self.dataset):
            self.dataset[split] = [
                item
                for item in self.dataset[split]
                if item.get("metrics") == ["multiple_choice_grade"] and len(self._get_correct_indices(item)) == 1
            ]

    def _get_subject_name(self, item: dict[str, Any]) -> str:
        # e.g. "organic_chemistry" -> "organic chemistry"
        return item["subject"].replace("_", " ")

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        return (
            "The following is a question about chemistry. Please answer by responding with the letter of the correct "
            "answer."
        )

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Render the question followed by one ``"<letter>. <option>"`` line per option."""
        question = item["examples"][0]["input"].strip()
        options = self._get_target_scores(item)
        choices = "".join(f"{key}. {choice}\n" for key, choice in zip(self.keys, options))
        return f"Question: {question}\n{choices}"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        ground_truth = self._get_ground_truth(item)
        return f"{self._get_cue_text(item)}{ground_truth}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the letter of the single correct option, prefixed with a space (e.g. ``" B"``)."""
        correct_answers = self._get_correct_indices(item)
        assert len(correct_answers) == 1, f"Expected exactly one correct answer, but got {len(correct_answers)}"
        return f" {self.keys[correct_answers[0]]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return one space-prefixed letter per option of the item."""
        n_options = len(self._get_target_scores(item))
        return [f" {key}" for key in self.keys[:n_options]]
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
4
|
+
AccuracyLoglikelihood,
|
|
5
|
+
AccuracyNormLoglikelihood,
|
|
6
|
+
)
|
|
7
|
+
from eval_framework.metrics.loglikelihood.confidence_weighted_accuracy import ConfidenceWeightedAccuracy
|
|
8
|
+
from eval_framework.metrics.loglikelihood.dcs import DistributionalCorrectnessScore
|
|
9
|
+
from eval_framework.metrics.loglikelihood.ternary import TernaryScore
|
|
10
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class COPA(BaseTask[str]):
    """COPA dataset: https://huggingface.co/datasets/aps/super_glue

    The premise is turned into an open sentence ending in "because"/"therefore", and the two alternatives
    (with a lower-cased first letter) are scored as its continuation.
    """

    NAME = "COPA"
    DATASET_PATH = "aps/super_glue"
    SAMPLE_SPLIT = "validation"  # 100 examples (same split as lm-eval)
    # Few-shot targets need gold labels. The SuperGLUE "test" split is unlabeled (label == -1), which made
    # every few-shot target silently fall back to choice2; "train" is labeled.
    FEWSHOT_SPLIT = "train"  # 400 examples
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["copa"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["because", "therefore"]
    LANGUAGE = Language.ENG

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the premise as an open sentence ending with the causal connector and a space."""
        connector = {
            "cause": "because",
            "effect": "therefore",
        }[item["question"]]
        # Drop the premise's final character (presumably the sentence-ending period) before the connector.
        return item["premise"].strip()[:-1] + f" {connector} "

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the correct alternative in completion form.

        Raises:
            ValueError: if the item's label is not 0 or 1 (e.g. an unlabeled split was used).
        """
        label = item["label"]
        if label not in (0, 1):
            raise ValueError(f"COPA item has no usable label (got {label!r}); unlabeled splits cannot provide targets")
        correct_choice = item["choice1"] if label == 0 else item["choice2"]
        return self.convert_choice(correct_choice)

    def convert_choice(self, choice: str) -> str:
        """Lower-case the first character so the alternative continues the premise sentence."""
        # Slicing (rather than choice[0]) keeps an empty string from raising IndexError.
        return choice[:1].lower() + choice[1:]

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        choices = [self.convert_choice(item["choice1"]), self.convert_choice(item["choice2"])]
        return choices
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class COPA_IDK(COPA):
    """COPA variant that adds an "I do not know." completion.

    It also reports ConfidenceWeightedAccuracy, DistributionalCorrectnessScore and TernaryScore alongside the
    standard log-likelihood accuracies.
    """

    NAME = "COPA_IDK"
    METRICS = [
        AccuracyLoglikelihood,
        AccuracyNormLoglikelihood,
        ConfidenceWeightedAccuracy,
        DistributionalCorrectnessScore,
        TernaryScore,
    ]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        instruction = (
            "Complete the sentence only if you are confident, since mistakes may be penalised, while correct "
            "answers receive points. It is acceptable to answer with 'I do not know' if you are unsure, and "
            "you will receive 0 points."
        )
        return instruction

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        # The abstention option is always appended last, after the parent's alternatives (if any).
        base_choices = super()._get_possible_completions(item)
        abstain = ["I do not know."]
        if not base_choices:
            return abstain
        return [*base_choices, *abstain]
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import re
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from eval_framework.metrics.base import BaseMetric
|
|
7
|
+
from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
|
|
8
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, Sample
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DUC(BaseTask[str], ABC):
    """https://huggingface.co/datasets/midas/duc2001

    Shared base for the DUC-2001 keyphrase tasks. Subclasses provide ``_get_ground_truth`` returning the
    list of reference keyphrases; the first reference is used as the few-shot target.
    """

    DATASET_PATH: str = "midas/duc2001"
    SAMPLE_SPLIT: str = "test"
    FEWSHOT_SPLIT: str = "test"
    RESPONSE_TYPE: ResponseType = ResponseType.COMPLETION
    METRICS: list[type[BaseMetric]] = [AccuracyCompletion]
    SUBJECTS: list[str] = ["raw"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Text", "Keyphrase"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Stop sequences; they are also cut off in post_process_generated_completion.
        self.stop_sequences: list[str] = ["Text:"]
        self.max_tokens = 50  # longest keyphrase is less than 50 characters long

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Truncate the completion at the first occurrence of each stop sequence and strip whitespace."""
        text = completion_text
        for stop in self.stop_sequences:
            head, found, _ = text.partition(stop)
            if found:
                text = head
        return text.strip()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # The document is presumably a list of tokens: re-join it, then glue punctuation back onto the
        # preceding word ("word ." -> "word.").
        joined = re.sub(r"\s+([.,!?;:])", r"\1", " ".join(item["document"]))
        return f"Text: {joined}\nKeyphrase:"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        references = self._get_ground_truth(item)
        assert references is not None
        assert isinstance(references, list)
        first_reference = references[0]
        return f" {first_reference}"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class DUC_EXTRACTIVE(DUC):
    """DUC-2001 keyphrase task whose targets are the item's ``extractive_keyphrases``."""

    NAME = "DUC Extractive"
    SUBJECTS: list[str] = ["raw"]

    def _get_system_prompt_text(self, item: dict[str, Any]) -> str:
        prompt = (
            "You are an AI model tasked with extracting keyphrases from a text document. "
            "Keyphrases should capture main ideas or significant topics exactly as worded in the text."
        )
        return prompt

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        references = item["extractive_keyphrases"]
        return references
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DUC_ABSTRACTIVE(DUC):
    """DUC-2001 keyphrase task whose targets are the item's ``abstractive_keyphrases``.

    Items without any abstractive keyphrase are dropped when the dataset is loaded.
    """

    NAME = "DUC Abstractive"
    SUBJECTS: list[str] = ["raw"]

    def _get_ground_truth(self, item: dict[str, Any]) -> list[str]:
        references = item["abstractive_keyphrases"]
        return references

    def _load_dataset(self, subject: str) -> None:
        """Load the subject, filter out items lacking abstractive keyphrases, and shuffle the sample split."""
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=subject)
        self.dataset = {}
        wanted_splits = [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]

        for split, data in hf_dataset.items():
            # not all samples have abstractive keyphrases
            kept = [row for row in data if len(row["abstractive_keyphrases"]) > 0]

            if split == self.SAMPLE_SPLIT:
                # Fixed seed so the evaluation order is reproducible.
                self.rnd = random.Random(RANDOM_SEED)
                self.rnd.shuffle(kept)

            if split in wanted_splits:
                self.dataset[split] = kept

    def _get_system_prompt_text(self, item: dict[str, Any]) -> str:
        prompt = (
            "You are an AI model tasked with generating abstractive keyphrases "
            "that capture the main ideas of the text without using exact wording."
        )
        return prompt

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        # NOTE(review): this asks for paraphrasing rather than keyphrases, unlike the system prompt — confirm intent.
        return "Paraphrase the following texts to improve clarity and relevance."
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import pycountry
|
|
7
|
+
from datasets import DownloadConfig, load_dataset
|
|
8
|
+
from huggingface_hub import HfApi
|
|
9
|
+
from huggingface_hub.errors import RevisionNotFoundError
|
|
10
|
+
|
|
11
|
+
from eval_framework.metrics.completion.bleu import BLEU
|
|
12
|
+
from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, Sample, SubjectType
|
|
13
|
+
|
|
14
|
+
# FLORES-200 language codes ("<ISO 639-3>_<script>") used to build the translation-pair subjects.
# Note: there are many more languages in the dataset, but we only consider these for now
FLORES_LANGUAGES = """
    deu_Latn
    eng_Latn
    fin_Latn
    fra_Latn
    nld_Latn
""".split()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Flores200(BaseTask[str]):
    """FLORES-200 dataset: https://huggingface.co/datasets/facebook/flores

    Each subject is a ``"<source>-<target>"`` language pair such as ``"eng_Latn-deu_Latn"``.
    """

    NAME = "FLoRes-200"
    DATASET_PATH = "facebook/flores"
    SAMPLE_SPLIT = "devtest"
    FEWSHOT_SPLIT = "dev"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [BLEU]
    SUBJECTS = [f"{s}-{t}" for s in FLORES_LANGUAGES for t in FLORES_LANGUAGES if s != t]
    PERTURBATION_UNMODIFIABLE_WORDS = ["sentence"]
    LANGUAGE = {
        "deu_Latn": Language.DEU,
        "eng_Latn": Language.ENG,
        "fin_Latn": Language.FIN,
        "fra_Latn": Language.FRA,
        "nld_Latn": Language.NLD,
    }

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        self.stop_sequences = ["\n"]

    @staticmethod
    def _split_subject(subject: str) -> tuple[str, str]:
        """Split a ``"<source>-<target>"`` subject into its source and target FLORES codes."""
        parts = subject.split("-")
        return parts[0], parts[1]

    def _load_hf_dataset(self, **kwargs: Any) -> Any:
        """Override to handle FLORES-200 encoding issues by using parquet files.

        The parquet route (``trust_remote_code=False``) is tried first; if it fails, the dataset's loading
        script is used as a fallback.

        Raises:
            RevisionNotFoundError: if ``HF_REVISION`` is set but does not exist for the dataset.
        """
        import logging

        logger = logging.getLogger(__name__)
        path = kwargs.get("path", self.DATASET_PATH)

        # Check if the HF_REVISION is valid before loading the dataset
        if self.HF_REVISION:
            try:
                HfApi().dataset_info(repo_id=path, revision=self.HF_REVISION, timeout=100.0)
            except RevisionNotFoundError:
                raise
            except Exception as e:
                # Best effort: any other failure (presumably connectivity) is logged and loading proceeds.
                logger.warning("Could not verify revision %r of %r: %s", self.HF_REVISION, path, e)

        cache_dir: str = os.environ.get("HF_DATASET_CACHE_DIR", f"{Path.home()}/.cache/huggingface/datasets")
        download_config = DownloadConfig(cache_dir=cache_dir, max_retries=5)

        # First, try to load using parquet files to bypass the problematic loading script
        try:
            return load_dataset(
                path,
                name=kwargs.get("name"),
                split=kwargs.get("split"),
                data_files=None,  # Let it auto-discover parquet files
                revision=self.HF_REVISION,
                trust_remote_code=False,  # Disable the loading script!
                cache_dir=cache_dir,
                download_config=download_config,
            )
        except Exception as e:
            logger.warning("Parquet loading of %r failed (%s); falling back to the loading script.", path, e)
            # SECURITY NOTE: trust_remote_code=True executes the dataset repository's loading script.
            return load_dataset(
                **kwargs,
                revision=self.HF_REVISION,
                trust_remote_code=True,
                cache_dir=cache_dir,
                download_config=download_config,
            )

    def _load_dataset(self, subject: SubjectType) -> None:
        # Store the subject (language pair) for use in other methods
        self.subject = subject

        # The "all" config holds every language; the subject only decides which sentence fields are used.
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name="all")
        self.dataset = {}
        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            if split not in (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT):
                continue
            data_list = list(data)

            # Add the subject to each item so _get_instruction_text can use it
            for item in data_list:
                item["subject"] = subject

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)
            self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return ``"<Source language> sentence: <text>\\n<Target language> sentence:"``."""
        source_key, target_key = self._split_subject(item["subject"])
        source_language = pycountry.languages.get(alpha_3=source_key.split("_")[0]).name
        target_language = pycountry.languages.get(alpha_3=target_key.split("_")[0]).name
        source = item[f"sentence_{source_key}"]
        return f"{source_language} sentence: {source}\n{target_language} sentence:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        _, target_key = self._split_subject(item["subject"])
        return item[f"sentence_{target_key}"]

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        ground_truth = self._get_ground_truth(item)
        # Validate before formatting: f" {None}" would otherwise silently become " None".
        assert isinstance(ground_truth, str), f"Expected a string target, got {type(ground_truth).__name__}"
        return f" {ground_truth}"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        return completion_text.strip()
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from itertools import product
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from eval_framework.metrics.completion.bleu import BLEU
|
|
6
|
+
from eval_framework.metrics.completion.chrf import CHRF
|
|
7
|
+
from eval_framework.metrics.completion.comet import COMET
|
|
8
|
+
from eval_framework.shared.types import BaseMetricContext, UntemplatedPrompt
|
|
9
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
|
|
10
|
+
|
|
11
|
+
# Flores-Plus config name ("<ISO 639-3>_<ISO 15924>") -> English language name shown in the prompt.
LANG_MAP = dict(
    deu_Latn="German",
    eng_Latn="English",
    fra_Latn="French",
    ita_Latn="Italian",
    nld_Latn="Dutch",
    pol_Latn="Polish",
    rus_Cyrl="Russian",
    spa_Latn="Spanish",
    ukr_Cyrl="Ukrainian",
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FloresPlus(BaseTask[str]):
    """Flores-Plus dataset: https://huggingface.co/datasets/openlanguagedata/flores_plus

    Each subject is a ``"<source>-<target>"`` pair of config names (keys of ``LANG_MAP``). The source and
    target configs are loaded separately and aligned row by row.
    """

    NAME = "Flores-Plus"
    DATASET_PATH = "openlanguagedata/flores_plus"
    SAMPLE_SPLIT = "dev"
    FEWSHOT_SPLIT = "devtest"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [BLEU, CHRF, COMET]
    SUBJECTS = [f"{s}-{t}" for s, t in product(LANG_MAP, LANG_MAP) if s != t]
    PERTURBATION_UNMODIFIABLE_WORDS = ["sentence"]
    LANGUAGE = {
        "deu_Latn": Language.DEU,
        "eng_Latn": Language.ENG,
        "fra_Latn": Language.FRA,
        "ita_Latn": Language.ITA,
        "nld_Latn": Language.NLD,
        "pol_Latn": Language.POL,
        "rus_Cyrl": Language.RUS,
        "spa_Latn": Language.SPA,
        "ukr_Cyrl": Language.UKR,
    }

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        self.stop_sequences = ["\n"]

    def _load_dataset(self, subject: str) -> None:
        """Load both language configs and pair their rows into source/target items.

        Raises:
            ValueError: if the two configs differ in length or their rows are not aligned by ``id``.
        """
        parts = subject.split("-")
        hf_dataset_src = self._load_hf_dataset(path=self.DATASET_PATH, name=parts[0])
        hf_dataset_tgt = self._load_hf_dataset(path=self.DATASET_PATH, name=parts[1])
        self.dataset = {}
        # NOTE(review): other tasks seed with RANDOM_SEED; 42 is kept here to preserve the current sample order.
        self.rnd = random.Random(42)

        for split in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
            data_list = []
            # strict=True: a length mismatch between the configs would otherwise silently drop the tail.
            for item_src, item_tgt in zip(hf_dataset_src[split], hf_dataset_tgt[split], strict=True):
                if item_src["id"] != item_tgt["id"]:
                    raise ValueError(
                        f"Misaligned rows in split {split!r} of {subject!r}: {item_src['id']!r} != {item_tgt['id']!r}"
                    )
                iso_src = f"{item_src['iso_639_3']}_{item_src['iso_15924']}"
                iso_tgt = f"{item_tgt['iso_639_3']}_{item_tgt['iso_15924']}"
                data_list.append(
                    {"iso_source": iso_src, "iso_target": iso_tgt, "source": item_src["text"], "target": item_tgt["text"]}
                )
            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)
            self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        target_language = LANG_MAP[item["iso_target"]]
        instruction = f"Translate the following text into {target_language}:\n{item['source']}"
        return instruction

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        return item["target"]

    def _get_context(self, item: dict[str, Any]) -> BaseMetricContext | list[BaseMetricContext] | None:
        # The untemplated source sentence is passed on to the metrics via UntemplatedPrompt.
        return UntemplatedPrompt(untemplated_prompt=item["source"])

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        return completion_text.strip()
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import logging
|
|
3
|
+
import random
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
|
|
8
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
9
|
+
AccuracyLoglikelihood,
|
|
10
|
+
AccuracyNormLoglikelihood,
|
|
11
|
+
)
|
|
12
|
+
from eval_framework.metrics.loglikelihood.confidence_weighted_accuracy import ConfidenceWeightedAccuracy
|
|
13
|
+
from eval_framework.metrics.loglikelihood.dcs import DistributionalCorrectnessScore
|
|
14
|
+
from eval_framework.metrics.loglikelihood.ternary import TernaryScore
|
|
15
|
+
from eval_framework.tasks.base import NO_SUBJECT, RANDOM_SEED, BaseTask, Language, ResponseType, Sample, SubjectType
|
|
16
|
+
from eval_framework.tasks.utils import get_n_letters
|
|
17
|
+
|
|
18
|
+
# Module-level logger; GPQA._load_dataset uses it to report how many samples were excluded.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class GPQA(BaseTask[str]):
    """GPQA dataset: https://huggingface.co/datasets/Idavidrein/gpqa

    Multiple-choice task scored by loglikelihood. For every item, the correct answer and the
    three incorrect answers are placed in a deterministic order (seeded from a hash of the
    option texts) and labelled with the letters returned by ``get_n_letters(4)``.
    """

    NAME = "GPQA"
    DATASET_PATH = "Idavidrein/gpqa"
    SAMPLE_SPLIT = "train"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["gpqa_extended"]  # ["gpqa_diamond", "gpqa_extended", "gpqa_main", "gpqa_experts"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question"] + get_n_letters(4)
    LANGUAGE = Language.ENG

    # Questions dropped from the sample and few-shot splits. The single entry is a sample whose
    # prompt is too long because it contains DNA sequences.
    _EXCLUDED_QUESTIONS = frozenset(
        {
            "Hello, you are embarking on a new project. You need to produce the HP1alpha protein in E. coli. Which of these plasmids will you choose?"  # noqa: E501
        }
    )

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        self.stop_sequences = ["Question:"]
        self.keys = get_n_letters(4)
        # Maps 1-based option numbers ("1", "2", ...) to their letter labels.
        self.num_to_letter = {str(i): letter for i, letter in enumerate(self.keys, start=1)}
        # Kept for backward compatibility. Option shuffling now uses a fresh per-item RNG
        # (see _get_possible_completions_marked), so this instance is no longer mutated.
        self.rnd_choice_shuffle = random.Random(RANDOM_SEED)

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load ``subject`` from the HF hub into ``self.dataset``.

        The sample split is shuffled with an RNG seeded by ``RANDOM_SEED``. Questions listed
        in ``_EXCLUDED_QUESTIONS`` are removed from the sample and few-shot splits. Only those
        two splits are stored.
        """
        name = subject if subject != NO_SUBJECT else None

        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
        self.dataset = {}

        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            data_list = list(data)

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            if split not in (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT):
                continue

            data_list_filtered = [item for item in data_list if item["Question"] not in self._EXCLUDED_QUESTIONS]
            num_excluded = len(data_list) - len(data_list_filtered)
            if num_excluded > 0:
                logger.info("Excluded %d samples from %s split.", num_excluded, split)
            assert num_excluded < 2, "we expect to remove max one item"

            self.dataset[split] = data_list_filtered

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        """Return the fixed system prompt that precedes every question."""
        system_prompt_text = (
            "Here are some example questions from experts. "
            "An explanation is given before the final answer. "
            "Answer the final question yourself, giving your reasoning beforehand."
        )
        return system_prompt_text

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Render the question followed by the lettered options, one per line."""
        choices, _ = self._get_possible_completions_marked(item)
        prompt = f"Question: {item['Question'].strip()}\n"
        prompt += "\n".join(choices) + "\n"
        return prompt

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Return the cue plus the correct label, e.g. ``"Answer: (B)"``, for few-shot examples."""
        ground_truth = self._get_ground_truth(item)
        assert ground_truth is not None
        return f"{self._get_cue_text(item)}{ground_truth}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the correct option label with a leading space, e.g. ``" (B)"``.

        The format matches the candidates produced by ``_get_possible_completions``.
        """
        _, correct_answer_position = self._get_possible_completions_marked(item)
        # Build the label from the key itself rather than slicing the formatted choice, so
        # the result does not depend on each key being a single character.
        return f" ({self.keys[correct_answer_position]})"

    def _get_possible_completions_marked(self, item: dict[str, Any]) -> tuple[list[str], int]:
        """Return the lettered options for ``item`` and the index of the correct one.

        The option order is pseudo-random but fixed for a given item. A SHA-256 hash of the
        preprocessed option texts seeds a fresh RNG, so the result does not depend on call
        order or on which other items were processed before.
        """
        choices = [self._preprocess(item[f"Incorrect Answer {x}"]) for x in range(1, 4)]
        correct_answer = self._preprocess(item["Correct Answer"])
        # Hash the (unshuffled) option texts to get a stable seed for this item.
        hash_object = hashlib.sha256(f"{choices} {correct_answer}".encode())
        rng = random.Random(int(hash_object.hexdigest(), 16))
        rng.shuffle(choices)
        # Any of the len(choices) + 1 slots may hold the correct answer.
        correct_answer_position = rng.randint(0, len(choices))
        choices.insert(correct_answer_position, correct_answer)
        choices = [f"({self.keys[i]}) {choice}" for i, choice in enumerate(choices)]
        return choices, correct_answer_position

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return the candidate completions scored by loglikelihood: ``" (A)"``, ``" (B)"``, ..."""
        return [f" ({x})" for x in self.keys]

    @staticmethod
    def _preprocess(text: str | None) -> str:
        """Normalize an answer string.

        ``None`` becomes a single space. Otherwise the text is stripped, ``" [title]"`` is
        turned into ``". "``, bracketed spans are removed, and each double space is replaced
        by a single space.
        """
        if text is None:
            return " "
        text = text.strip()
        text = text.replace(" [title]", ". ")
        text = re.sub(r"\[.*?\]", "", text)
        # Removing bracketed spans can leave two adjacent spaces; collapse them.
        text = text.replace("  ", " ")
        return text
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class GPQA_IDK(GPQA):
    """GPQA variant in which the model may abstain by choosing ``" (?)"``.

    Besides the loglikelihood accuracies it also reports ConfidenceWeightedAccuracy,
    DistributionalCorrectnessScore and TernaryScore.
    """

    NAME = "GPQA_IDK"
    METRICS = [
        AccuracyLoglikelihood,
        AccuracyNormLoglikelihood,
        ConfidenceWeightedAccuracy,
        DistributionalCorrectnessScore,
        TernaryScore,
    ]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        """Return the system prompt explaining that '?' is an accepted, zero-point answer."""
        sentences = (
            "Answer only if you are confident, since mistakes may be penalised, while correct answers receive points. ",
            "It is acceptable to answer with '?' if you are unsure, and you will receive 0 points.",
        )
        return "".join(sentences)

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return the parent's lettered candidates followed by the abstention candidate ``" (?)"``."""
        lettered = super()._get_possible_completions(item) or []
        return [*lettered, " (?)"]
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class GPQA_COT(GPQA):
    """Chain-of-thought GPQA variant scored on generated text.

    The prompt asks the model to reason and finish with "Therefore, the answer is (X)". The
    letter X is extracted from the completion and compared with the correct option letter.
    Few-shot prompting is not supported.
    """

    NAME = "GPQA_COT"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AccuracyCompletion]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Therefore", "the", "answer", "is", "ANSWER_LETTER"] + get_n_letters(
        4
    )
    # Matches the mandated closing phrase and captures the answer letter (A-J are accepted).
    ANS_RE = re.compile(r"Therefore, the answer is \(([ABCDEFGHIJ])\)")

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "Fewshot is not supported for GPQA_COT"
        # GPQA.__init__ already sets stop_sequences (["Question:"]), keys (get_n_letters(4)),
        # num_to_letter and rnd_choice_shuffle to the values this variant needs. Re-assigning
        # them here would only duplicate that setup.
        super().__init__(num_fewshot)

    def _extract_answer(self, completion: str) -> str:
        """Return the first answer letter matched by ``ANS_RE``, or ``"[invalid]"`` if none matches."""
        match = self.ANS_RE.search(completion)
        return match.group(1) if match else "[invalid]"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Cut the completion at the first occurrence of each stop sequence, then extract the answer letter."""
        for stop_sequence in self.stop_sequences:
            # partition(...)[0] returns the whole text when the separator is absent.
            completion_text = completion_text.partition(stop_sequence)[0]
        return self._extract_answer(completion_text)

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        return ""

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the full CoT instruction: task description, question, options and a reminder."""
        # using the reasoning prompt from "Figure 44 of Tülu 3 paper: https://arxiv.org/pdf/2411.15124"
        choices, _ = self._get_possible_completions_marked(item)
        instruction_text = (
            "Answer the following multiple-choice question by giving the correct answer letter in parentheses. "
            "Provide CONCISE reasoning for the answer, and make sure to finish the response with "
            '"Therefore, the answer is (ANSWER_LETTER)" where (ANSWER_LETTER) is one of (A), (B), (C), (D), (E), etc.'
        )
        instruction_text += f"\n\nQuestion: {item['Question'].strip()}\n"
        instruction_text += "\n".join(choices)
        instruction_text += (
            "\n\nAnswer the above question and REMEMBER to finish your response with the exact phrase "
            '"Therefore, the answer is (ANSWER_LETTER)" where (ANSWER_LETTER) is one of (A), (B), (C), (D), (E), etc.'
        )
        return instruction_text

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return ""

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the bare letter of the correct option (e.g. ``"B"``) to compare with ``_extract_answer``."""
        _, correct_answer_position = self._get_possible_completions_marked(item)
        # Use the key directly instead of indexing a character of the formatted "(X) ..." choice.
        return self.keys[correct_answer_position]
|