eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.completion.aidanbench import AidanBenchMetric
|
|
5
|
+
from eval_framework.metrics.llm.graders.coherence_grader import CoherenceGrader
|
|
6
|
+
from eval_framework.metrics.llm.graders.language import Language as LLMLanguage
|
|
7
|
+
from eval_framework.shared.types import Completion
|
|
8
|
+
from eval_framework.tasks.base import NO_SUBJECT, BaseTask, ResponseType, Sample
|
|
9
|
+
from eval_framework.tasks.base import Language as TaskLanguage
|
|
10
|
+
from eval_framework.utils.helpers import pairwise_cosine_similarity
|
|
11
|
+
from template_formatting.formatter import Message, Role
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from eval_framework.llm.base import BaseLLM
|
|
15
|
+
from eval_framework.shared.types import Error
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Stopping criteria for the iterative AidanBench generation loop: generation for a question ends as soon as an
# answer's coherence score (from the CoherenceGrader) or its novelty score (1 - max cosine similarity to the
# previous answers) drops below the respective threshold.
COHERENCE_THRESHOLD: int = 15
NOVELTY_THRESHOLD: float = 0.15
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AidanBenchOriginal(BaseTask[str]):
    """AidanBench (https://openreview.net/pdf?id=fz969ahcvJ).

    Each question is answered iteratively. After every answer, the model is re-prompted with all of its previous
    answers and asked for one it has not given yet. Generation for a question stops once an answer falls below
    ``COHERENCE_THRESHOLD`` (judged by an LLM grader) or ``NOVELTY_THRESHOLD`` (embedding similarity to the
    earlier answers).

    This variant reproduces the original AidanBench prompt verbatim, including the stray "." after the colon.
    """

    NAME = "AidanBench"
    DATASET_PATH = "Aleph-Alpha-Research/aidanbench"
    SAMPLE_SPLIT = "train"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AidanBenchMetric]
    SUBJECTS = [NO_SUBJECT]
    LANGUAGE = {NO_SUBJECT: TaskLanguage.ENG}

    def __init__(self, num_fewshot: int = 0) -> None:
        """Create the task together with its coherence grader (gpt-4o-mini) and embedding model.

        Args:
            num_fewshot: Must be 0; few-shot prompting is not supported.
        """
        # Function-local import (presumably so the OpenAI client is only needed when this task is used).
        from eval_framework.llm.openai import OpenAIEmbeddingModel, OpenAIModel

        super().__init__(num_fewshot)
        assert num_fewshot == 0, "AidanBench does not support few-shot prompting."
        self._coherence_grader = CoherenceGrader(grading_model=OpenAIModel(model_name="gpt-4o-mini"))
        self._embedding_model = OpenAIEmbeddingModel()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the initial question prompt for ``item["prompt"]``."""
        item_prompt = item["prompt"]
        # note the extra dot after colon. We take this from the original AidanBench code:
        # https://github.com/aidanmclaughlin/AidanBench/blob/a6bb3253ff630c82e7adbc81ce7bc7184c5bd881/benchmark/prompts.py#L7 # noqa: E501
        base_prompt = (
            "Answer the following question:.\n"
            "<question>" + item_prompt + "</question>\n"
            "Provide your answer in <answer></answer> XML tags.\n"
        )
        base_prompt += (
            "Your response should be one direct answer. "
            "Only provide one answer. DO NOT list multiple answers. Please try to be concise.\n"
        )
        return base_prompt

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """AidanBench has no reference answers."""
        return None

    def _calculate_novelty_score(self, messages: list[Message]) -> float:
        """Score how novel the latest answer is compared to all earlier answers.

        Args:
            messages: The USER instruction followed by one or more ASSISTANT answers (latest last).

        Returns:
            ``1 - max(cosine similarity)`` between the embedding of the latest answer and the embeddings of
            each previous answer; 1.0 when there is only a single answer.
        """
        assert messages[0].role == Role.USER
        assert all(msg.role != Role.USER for msg in messages[1:]), "Only the first message should be from USER"
        # The embedding model takes one message sequence per text to embed.
        messages_without_instruction: list[Sequence[Message]] = [[m] for m in messages[1:]]
        if len(messages_without_instruction) == 1:
            return 1.0  # if there's only one response, it's by definition novel
        all_embeddings = self._embedding_model.generate_embeddings(messages_without_instruction)
        new_embedding = all_embeddings[-1]
        previous_embeddings = all_embeddings[:-1]
        similarities = pairwise_cosine_similarity([new_embedding], previous_embeddings)
        assert len(similarities) == 1
        similarities_squeezed = similarities[0]  # "squeeze" the single query row
        assert len(similarities_squeezed) == len(previous_embeddings)
        return 1 - max(similarities_squeezed)

    def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
        """Few-shot prompting is not supported, so no examples are ever sampled."""
        return []

    def _fuse_messages(self, messages: list[Message]) -> list[Message]:
        """
        Takes a list of messages and fuses them into a single message:
        A USER message that also contains all previous model responses, wrapped for the next iterative generation step.

        Args:
            messages: The USER instruction followed by at least one ASSISTANT answer.

        Returns:
            A one-element list holding the fused USER message.
        """
        assert len(messages) >= 2, "There must be at least one USER and one ASSISTANT message"
        assert messages[0].role == Role.USER
        assert all(msg.role == Role.ASSISTANT for msg in messages[1:]), "Only the first message should be from USER"

        instruction_message = messages[0].content
        previous_answers = [msg.content for msg in messages[1:]]

        previous_answers_str = "\n\n".join(
            [
                f"<previous_answer id='{i + 1}'>\n{answer}\n</previous_answer>"
                for i, answer in enumerate(previous_answers)
            ]
        )
        instruction_message += (
            "IMPORTANT: Provide an answer you *HAVE NOT* given previously.\n"
            "Your previous answers are inside of <previous_answers></previous_answers> XML tags.\n"
            "<previous_answers>\n" + previous_answers_str + "\n</previous_answers>"
        )
        return [Message(role=Role.USER, content=instruction_message)]

    def _generation_loop(
        self, llm: "BaseLLM", stop_sequences: list[str] | None, max_tokens: int | None, initial_samples: list[Sample]
    ) -> tuple[list[list[Message]], list[Union["Error", None]]]:
        """Generate answers iteratively until every sample fails the coherence or novelty check.

        Args:
            llm: Model used to generate the answers.
            stop_sequences: Stop sequences forwarded to every generation call.
            max_tokens: Token limit forwarded to every generation call.
            initial_samples: Samples, each holding only the initial USER instruction.

        Returns:
            Per sample, in input order: the message history (initial instruction followed by every generated
            ASSISTANT answer), and the error reported by the last generation attempt (``None`` if it succeeded).
        """
        # Copy the message lists: appending answers must not mutate the callers' Sample objects.
        message_history: list[list[Message]] = [list(s.messages) for s in initial_samples]
        errors: list[Error | None] = [None for _ in initial_samples]
        samples = [(s, False) for s in initial_samples]  # (sample, is_done)

        while not all(is_done for _, is_done in samples):
            # iterative generation loop over the samples that are still running
            not_done_idx = [i for i, (_, is_done) in enumerate(samples) if not is_done]
            new_completions = super().generate_completions(
                llm,
                [samples[i][0] for i in not_done_idx],
                stop_sequences=stop_sequences,
                max_tokens=max_tokens,
            )

            new_samples = list(samples)
            for idx, completion in zip(not_done_idx, new_completions):
                old_sample = samples[idx][0]
                completion_msgs: list[Message] | None = completion.messages
                # Record the outcome of every attempt, including failed ones, so errors reach the result.
                errors[idx] = completion.error

                if completion_msgs is None:
                    # No completion: treat as non-coherent and stop. Novelty and message fusion need at least one
                    # answer, which does not exist when the very first attempt fails.
                    new_samples[idx] = (old_sample, True)
                    continue

                assert completion_msgs[0].role == Role.USER and completion_msgs[-1].role == Role.ASSISTANT
                message_history[idx].append(completion_msgs[-1])  # add latest model response to history

                coherence_score = self._coherence_grader.grade(
                    # Only pass the initial instruction: old_sample.messages[0] is the fused prompt (containing the
                    # previous answers) after the first iteration, whereas message_history[idx][0] is the original.
                    instruction=message_history[idx][0].content,
                    completion=completion_msgs[-1].content,
                    language=LLMLanguage(iso_639_1="en"),
                ).coherence_score
                novelty_score = self._calculate_novelty_score(message_history[idx])

                new_sample = Sample(
                    id=old_sample.id,
                    subject=old_sample.subject,
                    ground_truth=old_sample.ground_truth,
                    messages=self._fuse_messages(message_history[idx]),
                    context=old_sample.context,
                    possible_completions=old_sample.possible_completions,
                )
                # Fail (stop generating) once an answer is incoherent or too similar to an earlier one.
                is_done = coherence_score < COHERENCE_THRESHOLD or novelty_score < NOVELTY_THRESHOLD
                new_samples[idx] = (new_sample, is_done)
            samples = new_samples
        return message_history, errors

    def generate_completions(
        self,
        llm: "BaseLLM",
        samples: list[Sample],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
    ) -> list[Completion]:
        """Run the iterative AidanBench loop and return one Completion per sample.

        The Completion's ``messages`` hold the full history. ``completion``/``raw_completion`` hold all ASSISTANT
        answers concatenated without a separator.

        Raises:
            AssertionError: If a sample does not consist of exactly one USER message.
        """
        assert all(len(s.messages) == 1 and s.messages[0].role == Role.USER for s in samples), (
            "Each sample must have exactly one USER message."
        )
        all_message_histories, errors = self._generation_loop(llm, stop_sequences, max_tokens, samples)

        completion_list: list[Completion] = []
        for sample, messages, error in zip(samples, all_message_histories, errors):
            answers = "".join([msg.content for msg in messages if msg.role == Role.ASSISTANT])
            completion_list.append(
                Completion(
                    id=sample.id,
                    subject=sample.subject,
                    ground_truth=sample.ground_truth,
                    prompt=sample.messages[0].content,
                    prompt_sequence_positions=None,
                    concat_compression=None,
                    messages=messages,
                    completion=answers,
                    raw_completion=answers,
                    raw_completion_sequence_positions=None,
                    context=sample.context,
                    error=error,
                )
            )

        return completion_list
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class AidanBench(AidanBenchOriginal):
    """AidanBench with a cleaned-up question prompt.

    Only the instruction text is overridden: the stray "." after "question:" used by AidanBenchOriginal's
    prompt is dropped; everything else is inherited unchanged.
    """

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        question = item["prompt"]
        # We correct the prompt here by removing the extra dot after the colon.
        prompt_parts = [
            "Answer the following question:\n",
            "<question>" + question + "</question>\n",
            "Provide your answer in <answer></answer> XML tags.\n",
            "Your response should be one direct answer. ",
            "Only provide one answer. DO NOT list multiple answers. Please try to be concise.\n",
        ]
        return "".join(prompt_parts)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
4
|
+
AccuracyLoglikelihood,
|
|
5
|
+
AccuracyNormLoglikelihood,
|
|
6
|
+
)
|
|
7
|
+
from eval_framework.metrics.loglikelihood.confidence_weighted_accuracy import ConfidenceWeightedAccuracy
|
|
8
|
+
from eval_framework.metrics.loglikelihood.dcs import DistributionalCorrectnessScore
|
|
9
|
+
from eval_framework.metrics.loglikelihood.ternary import TernaryScore
|
|
10
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType
|
|
11
|
+
from eval_framework.tasks.utils import get_n_letters
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ARC(BaseTask[str]):
    """ARC dataset: https://huggingface.co/datasets/allenai/ai2_arc

    Multiple-choice questions scored by log-likelihood over the full answer texts.
    """

    NAME = "ARC"
    DATASET_PATH = "ai2_arc"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["ARC-Easy", "ARC-Challenge"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question"] + get_n_letters(5)
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)

        # Five letter keys, because one sample has five answer possibilities.
        self.keys = get_n_letters(5)
        # "answerKey" may be numeric ("1", "2", ...); map such keys onto the letters.
        self.num_to_letter = dict(zip(map(str, range(1, len(self.keys) + 1)), self.keys))

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return "Question: {}\n".format(item["question"])

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        target = self._get_ground_truth(item)
        assert target is not None
        return self._get_cue_text(item) + target

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the text of the correct choice, prefixed with a space."""
        raw_key = item["answerKey"]
        letter = self.num_to_letter.get(raw_key, raw_key)
        correct_text = item["choices"]["text"][self.keys.index(letter)]
        return f" {correct_text}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return every choice text, each prefixed with a space."""
        return list(map(" {}".format, item["choices"]["text"]))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ARC_IDK(ARC):
    """ARC variant that allows the model to abstain with " I do not know." as an extra answer option."""

    NAME = "ARC_IDK"
    METRICS = [
        AccuracyLoglikelihood,
        AccuracyNormLoglikelihood,
        ConfidenceWeightedAccuracy,
        DistributionalCorrectnessScore,
        TernaryScore,
    ]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        abstention_hint = (
            "Answer only if you are confident, since mistakes may be penalised, while correct answers receive points. "
            "It is acceptable to answer with 'I do not know' if you are unsure, and you will receive 0 points."
        )
        return abstention_hint

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        # Regular choices (or none) followed by the abstention option.
        options = list(super()._get_possible_completions(item) or [])
        options.append(" I do not know.")
        return options
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
4
|
+
AccuracyLoglikelihood,
|
|
5
|
+
AccuracyNormLoglikelihood,
|
|
6
|
+
)
|
|
7
|
+
from eval_framework.tasks.base import NO_SUBJECT, BaseTask, Language, ResponseType
|
|
8
|
+
from eval_framework.tasks.utils import get_n_letters
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ARC_DE(BaseTask[str]):
    """ARC-DE dataset: https://huggingface.co/datasets/LeoLM/ArcChallenge_de

    German translation of ARC-Challenge, scored by log-likelihood over the German answer texts.
    """

    NAME = "ARC German"
    DATASET_PATH = "LeoLM/ArcChallenge_de"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = [NO_SUBJECT]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Frage"] + get_n_letters(5)
    LANGUAGE = Language.DEU

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)

        # Five letter keys, because one sample has five answer possibilities.
        self.keys = get_n_letters(5)
        # "answerKey" may be numeric ("1", "2", ...); map such keys onto the letters.
        self.num_to_letter = dict(zip(map(str, range(1, len(self.keys) + 1)), self.keys))

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return "Frage: {}\n".format(item["question_de"])

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        target = self._get_ground_truth(item)
        assert target is not None
        return self._get_cue_text(item) + target

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Antwort:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the German text of the correct choice, prefixed with a space."""
        raw_key = item["answerKey"]
        letter = self.num_to_letter.get(raw_key, raw_key)
        correct_text = item["choices_de"]["text"][self.keys.index(letter)]
        return f" {correct_text}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return every German choice text, each prefixed with a space."""
        return list(map(" {}".format, item["choices_de"]["text"]))
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
4
|
+
AccuracyLoglikelihood,
|
|
5
|
+
AccuracyNormLoglikelihood,
|
|
6
|
+
)
|
|
7
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType
|
|
8
|
+
from eval_framework.tasks.utils import get_n_letters
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ARC_FI(BaseTask[str]):
    """ARC-FI dataset: https://huggingface.co/datasets/LumiOpen/arc_challenge_mt

    Finnish ARC-Challenge items (subject "fi"), scored by log-likelihood over the answer texts. The prompt
    template itself ("Question:"/"Answer:") stays in English.
    """

    NAME = "ARC Finnish"
    DATASET_PATH = "LumiOpen/arc_challenge_mt"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["fi"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question"] + get_n_letters(5)
    LANGUAGE = Language.FIN

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)

        # Five letter keys, because one sample has five answer possibilities.
        self.keys = get_n_letters(5)
        # "answerKey" may be numeric ("1", "2", ...); map such keys onto the letters.
        self.num_to_letter = dict(zip(map(str, range(1, len(self.keys) + 1)), self.keys))

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return "Question: {}\n".format(item["question"])

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        target = self._get_ground_truth(item)
        assert target is not None
        return self._get_cue_text(item) + target

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the text of the correct choice, prefixed with a space."""
        raw_key = item["answerKey"]
        letter = self.num_to_letter.get(raw_key, raw_key)
        correct_text = item["choices"]["text"][self.keys.index(letter)]
        return f" {correct_text}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return every choice text, each prefixed with a space."""
        return list(map(" {}".format, item["choices"]["text"]))
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
4
|
+
AccuracyLoglikelihood,
|
|
5
|
+
AccuracyNormLoglikelihood,
|
|
6
|
+
)
|
|
7
|
+
from eval_framework.tasks.base import BaseTask, Language, ResponseType
|
|
8
|
+
from eval_framework.tasks.utils import get_n_letters
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BELEBELE(BaseTask[str]):
    """BELEBELE dataset: https://huggingface.co/datasets/facebook/belebele

    Reading-comprehension multiple choice: a passage, a question and four lettered options. Scored by
    log-likelihood over the answer letters.
    """

    NAME = "BELEBELE"
    DATASET_PATH = "facebook/belebele"
    SAMPLE_SPLIT = "test"
    # NOTE(review): few-shot examples come from the same split as the evaluated samples — confirm the base
    # class excludes the current item when sampling them.
    FEWSHOT_SPLIT = "test"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = [
        "eng_Latn",
    ]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Answer"] + get_n_letters(4)
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)

        self.keys = get_n_letters(4)
        self.num_to_letter = dict(zip(map(str, range(1, len(self.keys) + 1)), self.keys))

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        return "The following are multiple choice questions (with answers)."

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        passage = item["flores_passage"].strip()
        question = item["question"].strip()
        answer_options = (item["mc_answer1"], item["mc_answer2"], item["mc_answer3"], item["mc_answer4"])
        option_lines = "".join(f"{letter}. {option}\n" for letter, option in zip(self.keys, answer_options))
        return f"{passage}\n\nQuestion: {question}\n{option_lines}"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        target = self._get_ground_truth(item)
        assert target is not None
        return self._get_cue_text(item) + target

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the letter of the correct option (1-based "correct_answer_num"), prefixed with a space."""
        answer_index = int(item["correct_answer_num"]) - 1
        return f" {self.keys[answer_index]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return every option letter, each prefixed with a space."""
        return list(map(" {}".format, self.keys))
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import re
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from eval_framework.metrics.completion.code_execution_pass_at_one import (
|
|
6
|
+
CodeExecutionPassAtOne,
|
|
7
|
+
CodeExecutionPassAtOneContext,
|
|
8
|
+
)
|
|
9
|
+
from eval_framework.tasks.base import (
|
|
10
|
+
RANDOM_SEED,
|
|
11
|
+
BaseTask,
|
|
12
|
+
Language,
|
|
13
|
+
ResponseType,
|
|
14
|
+
Sample,
|
|
15
|
+
SubjectType,
|
|
16
|
+
)
|
|
17
|
+
from eval_framework.tasks.utils import (
|
|
18
|
+
BIG_CODE_BENCH_PACKAGE_MAPPING,
|
|
19
|
+
CallableSerializer,
|
|
20
|
+
_parse_unittest_output,
|
|
21
|
+
unittest_merge_snippets,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Instruction placed in front of every BigCodeBench problem statement.
# Source: https://arxiv.org/pdf/2406.15877 - Figure 14
PROMPT_INSTRUCTION = (
    "Please provide a self-contained Python script, without tests or example usage, "
    "that solves the following problem in a markdown code block:\n"
)

# Prefix used as the start of the model's response (the cue text).
# Source: https://github.com/bigcode-project/bigcodebench/blob/main/bigcodebench/generate.py#L149
RESPONSE_PREFIX = (
    "Below is a Python script with a self-contained function "
    "that solves the problem and passes corresponding tests:\n"
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BigCodeBench(BaseTask[str]):
    """BigCodeBench dataset: https://huggingface.co/datasets/bigcode/bigcodebench

    Subjects:
    - ``"original"``: the cue is ``RESPONSE_PREFIX`` only. The code to run is extracted from
      the model's response with ``extract_executable_code``.
    - ``"calibrated"``: the cue is ``RESPONSE_PREFIX`` followed by the item's ``code_prompt``.
      The model's completion is appended to that ``code_prompt`` before execution.

    Scoring executes the item's unit tests against the generated code (``CodeExecutionPassAtOne``).
    """

    NAME = "BigCodeBench"
    DATASET_PATH = "bigcode/bigcodebench"
    SAMPLE_SPLIT = "v0.1.4"
    FEWSHOT_SPLIT = "v0.1.4"  # (there is no dedicated split, few-shot is not expected for this dataset)
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [CodeExecutionPassAtOne]
    SUBJECTS = ["original", "calibrated"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        """Create the task. Only zero-shot prompting is supported.

        Raises:
            AssertionError: if ``num_fewshot`` is not 0.
        """
        assert num_fewshot == 0, "Fewshot is not supported for BigCodeBench"
        # NOTE: this serializer should be the same class as initialized in the metric,
        # because it encodes the helper functions passed through the sample context.
        self.serializer = CallableSerializer()
        super().__init__(num_fewshot)

    def _load_dataset(self, subject: SubjectType) -> None:
        """Load the dataset version split into ``self.dataset``.

        ``subject`` does not change which data is loaded; both subjects share the same items.
        The sample split is shuffled with a ``random.Random`` seeded by ``RANDOM_SEED``.
        """
        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=None)
        self.dataset = {}

        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            data_list = list(data)

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            if split in (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT):
                self.dataset[split] = data_list

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Prompt = fixed instruction + the item's ``complete_prompt``."""
        return PROMPT_INSTRUCTION + item["complete_prompt"]

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Start of the model's answer; the "calibrated" subject also pre-fills ``code_prompt``."""
        # NOTE(review): item["subject"] is presumably injected by BaseTask — confirm.
        return RESPONSE_PREFIX + (item["code_prompt"] if item["subject"] == "calibrated" else "")

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the canonical solution.

        Scoring is test-based on the generated code, so this value is informational only.
        """
        return item["canonical_solution"]

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Free-form completion task: there is no fixed candidate set."""
        return None

    def _get_context(self, item: dict[str, Any]) -> CodeExecutionPassAtOneContext:
        """Build the execution context consumed by ``CodeExecutionPassAtOne``."""
        return CodeExecutionPassAtOneContext(
            # Hard-coded image; an override via os.environ["DOCKER_CODE_EXECUTION"] was considered.
            run_env="python:3.12",
            code_prompt=item["code_prompt"],
            test_code=item["test"],
            snippet_merge_fn=self.serializer.encode(unittest_merge_snippets),
            output_parse_fn=self.serializer.encode(_parse_unittest_output),
            package_downloads=BIG_CODE_BENCH_PACKAGE_MAPPING,
        )

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Turn the raw model completion into the code that will be executed.

        For the "calibrated" subject the cue ended with ``code_prompt``, so the completion
        continues it and ``code_prompt`` is re-attached in front. Otherwise the code is
        extracted from the response's markdown code block.
        """
        if sample is not None and sample.context is not None and sample.subject == "calibrated":
            assert isinstance(sample.context, CodeExecutionPassAtOneContext), "Expected CodeExecutionPassAtOneContext"
            # sample.context is known to be non-None here; no second check needed.
            return sample.context.code_prompt + completion_text

        return extract_executable_code(completion_text)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class BigCodeBenchInstruct(BigCodeBench):
    """BigCodeBench (instruct prompts): https://huggingface.co/datasets/bigcode/bigcodebench

    Builds the prompt from the item's ``instruct_prompt`` field.
    """

    NAME = "BigCodeBenchInstruct"

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Prompt = fixed instruction + the item's ``instruct_prompt``."""
        instruct_prompt = item["instruct_prompt"]
        return PROMPT_INSTRUCTION + instruct_prompt
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class BigCodeBenchHard(BigCodeBench):
    """BigCodeBench-Hard subset: https://huggingface.co/datasets/bigcode/bigcodebench-hard

    Identical to ``BigCodeBench`` apart from the dataset path. The prompt construction
    (fixed instruction + ``complete_prompt``) is inherited rather than duplicated.
    """

    NAME = "BigCodeBenchHard"
    DATASET_PATH = "bigcode/bigcodebench-hard"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class BigCodeBenchHardInstruct(BigCodeBenchHard):
    """BigCodeBench-Hard (instruct prompts): https://huggingface.co/datasets/bigcode/bigcodebench-hard

    Uses the same dataset as its parent class. The prompt is built from the item's
    ``instruct_prompt`` field.
    """

    NAME = "BigCodeBenchHardInstruct"

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Prompt = fixed instruction + the item's ``instruct_prompt``."""
        return "".join((PROMPT_INSTRUCTION, item["instruct_prompt"]))
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Fence tags treated as Python code blocks: ```python, ```python3 and ```py (matched case-insensitively).
_PYTHON_FENCE_TAG = r"(?:python3?|py)\b"
_NESTED_BLOCK_RE = re.compile(r"```markdown.*?```" + _PYTHON_FENCE_TAG + r"\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
_PYTHON_BLOCK_RE = re.compile(r"```" + _PYTHON_FENCE_TAG + r"\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
_MARKDOWN_BLOCK_RE = re.compile(r"```markdown\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
_GENERIC_BLOCK_RE = re.compile(r"```\s*(.*?)\s*```", re.DOTALL)
# Opening Python fence (with its line break); used when no closing fence exists.
_PYTHON_OPEN_FENCE_RE = re.compile(r"```" + _PYTHON_FENCE_TAG + r"[ \t]*\n?", re.IGNORECASE)


def extract_executable_code(llm_response: str) -> str:
    """Extract the code to execute from a model response.

    Strategies are tried in order and the first match is returned, stripped:

    1. a Python code block nested inside a ```markdown block;
    2. a Python code block (```python / ```python3 / ```py, any letter case);
    3. a ```markdown code block;
    4. any fenced code block;
    5. an unterminated Python code block (e.g. generation cut off at the token limit):
       everything after the opening fence.

    If none of these match, the response is returned unchanged.
    """
    for pattern in (_NESTED_BLOCK_RE, _PYTHON_BLOCK_RE, _MARKDOWN_BLOCK_RE, _GENERIC_BLOCK_RE):
        match = pattern.search(llm_response)
        if match:
            return match.group(1).strip()

    # No closed fence at all: a dangling Python fence would otherwise make the raw response
    # (prose + "```python") the "code", which can never execute.
    opening = _PYTHON_OPEN_FENCE_RE.search(llm_response)
    if opening:
        return llm_response[opening.end() :].strip()

    return llm_response
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
|
|
5
|
+
AccuracyLoglikelihood,
|
|
6
|
+
AccuracyNormLoglikelihood,
|
|
7
|
+
)
|
|
8
|
+
from eval_framework.tasks.base import NO_SUBJECT, RANDOM_SEED, BaseTask, Language, ResponseType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CASEHOLD(BaseTask[str]):
    """CaseHOLD (LexGLUE ``case_hold``) multiple-choice task scored by log-likelihood.

    Each item's ``context`` contains a ``(<HOLDING>)`` placeholder. The text before the
    placeholder is the prompt. Each candidate completion is one of the item's ``endings``
    followed by the text after the placeholder. The correct candidate is
    ``endings[label]``.
    """

    NAME = "CaseHold"
    DATASET_PATH = "lex_glue"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["case_hold"]
    LANGUAGE = Language.ENG

    # Placeholder marking where the holding statement belongs inside ``context``.
    HOLDING_MARKER = "(<HOLDING>)"

    def _load_dataset(self, subject: str) -> None:
        """Load the sample/few-shot splits, keeping only items with exactly one placeholder.

        The sample split is shuffled with a ``random.Random`` seeded by ``RANDOM_SEED``.
        """
        name = subject if subject != NO_SUBJECT else None

        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
        self.dataset = {}

        self.rnd = random.Random(RANDOM_SEED)

        for split, data in hf_dataset.items():
            data_list = list(data)

            if split == self.SAMPLE_SPLIT:
                self.rnd.shuffle(data_list)

            if split in (self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT):
                # Exactly one placeholder guarantees a clean prefix/suffix split below.
                self.dataset[split] = [i for i in data_list if i["context"].count(self.HOLDING_MARKER) == 1]

    def _split_context(self, item: dict[str, Any]) -> list[str]:
        """Split ``item["context"]`` at the first placeholder (returns at most two parts)."""
        return item["context"].split(self.HOLDING_MARKER, maxsplit=1)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Prompt: the context text before the placeholder."""
        return self._split_context(item)[0]

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Correct ending followed by the context text after the placeholder."""
        right = self._split_context(item)[1]
        return f"{item['endings'][item['label']]}{right}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Every ending, each followed by the context text after the placeholder."""
        right = self._split_context(item)[1]
        return [f"{ending}{right}" for ending in item["endings"]]
|