eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from collections.abc import Callable, Sequence
|
|
3
|
+
from typing import Annotated, NamedTuple, Self, TypeVar, cast
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
from eval_framework.metrics.llm.graders.language import detect_language_of
|
|
8
|
+
from eval_framework.utils.helpers import count_bytes
|
|
9
|
+
from template_formatting.formatter import ConcatFormatter, Message, Role
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConcatCompression(NamedTuple):
|
|
13
|
+
"""Helper class for storing compression info for the concat formatter.
|
|
14
|
+
|
|
15
|
+
The concat formatter is used to avoid bias towards special tokens.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
num_bytes: int
|
|
19
|
+
num_tokens: int
|
|
20
|
+
|
|
21
|
+
@classmethod
|
|
22
|
+
def calculate(
|
|
23
|
+
cls,
|
|
24
|
+
messages: Sequence[Message],
|
|
25
|
+
count_tokens: Callable[[str], int],
|
|
26
|
+
choices: list[str] | None = None,
|
|
27
|
+
completion: str | None = None,
|
|
28
|
+
) -> Self | None:
|
|
29
|
+
"""Calculate the compression info for the given messages and token counting function."""
|
|
30
|
+
if (choices is None) == (completion is None):
|
|
31
|
+
raise ValueError("Either possible_completions or completion must be provided, but not both.")
|
|
32
|
+
concat_str = ConcatFormatter().format(messages, output_mode="string")
|
|
33
|
+
|
|
34
|
+
if choices is not None:
|
|
35
|
+
if any(c is None for c in choices):
|
|
36
|
+
return None
|
|
37
|
+
num_bytes = count_bytes(concat_str) + sum(count_bytes(c) for c in choices)
|
|
38
|
+
num_tokens = count_tokens(concat_str) + sum(count_tokens(c) for c in choices)
|
|
39
|
+
else:
|
|
40
|
+
if completion is None:
|
|
41
|
+
return None
|
|
42
|
+
concat_str = f"{concat_str}{completion}"
|
|
43
|
+
num_bytes = count_bytes(concat_str)
|
|
44
|
+
num_tokens = count_tokens(concat_str)
|
|
45
|
+
|
|
46
|
+
res = cls(num_bytes=num_bytes, num_tokens=num_tokens)
|
|
47
|
+
if res.num_bytes > 0 and res.num_tokens > 0:
|
|
48
|
+
return res
|
|
49
|
+
else:
|
|
50
|
+
return None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class BaseMetricContext(BaseModel):
    """Base class for metric context"""

    # Fields that are not declared on the model are accepted (extra="allow").
    model_config = ConfigDict(extra="allow")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class LanguageMetricContext(BaseMetricContext):
    """Metric context that carries a ``language`` string."""

    language: str
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class UntemplatedPrompt(BaseMetricContext):
    """Metric context that carries an ``untemplated_prompt`` string."""

    untemplated_prompt: str
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class Error(BaseModel):
    """Record of an error as three strings: class name, message and traceback."""

    # Undeclared fields are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")
    error_class: str
    message: str
    traceback: str
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class PromptTooLongException(Exception):
    """Exception type named for over-long prompts; adds no state or behavior of its own."""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class BaseCompletion(BaseModel):
    """Fields shared by ``RawCompletion`` and ``Completion``."""

    # Undeclared fields are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")

    prompt: Annotated[str, "prompt as passed to the llm"]
    prompt_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the prompt occupies in the llm architecture (e.g. token count) "
        "or None if the info is not available",
    ]
    completion: Annotated[str, "completion as generated by the llm"]
    # Optional; defaults to None when no compression info is supplied.
    concat_compression: Annotated[ConcatCompression | None, "Compression info for the concat formatter."] = None
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class RawCompletion(BaseCompletion):
    """Base completion fields plus the completion's sequence positions and an optional error."""

    completion_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the completion occupies in the llm architecture "
        "(e.g. token count) or None if the info is not available",
    ]
    # Defaults to None when no error is attached.
    raw_completion_error: Error | None = None
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class Completion(BaseCompletion):
    """Completion of one sample together with the data needed to score it.

    Extends ``BaseCompletion`` with the sample identity, ground truth, the
    messages the prompt was built from, the raw completion and optional
    metric context / error information.
    """

    id: int
    subject: str
    ground_truth: str | None | list[str]
    messages: list[Message] | None  # needed for LLM as a judge
    raw_completion: Annotated[str, "raw completion as generated by the llm"]
    raw_completion_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the completion occupies in the llm architecture or None "
        "if the info is not available",
    ]
    context: list[BaseMetricContext] | BaseMetricContext | None = None
    error: Error | None = None

    @property
    def ground_truth_list(self) -> list[str] | list[None]:
        """Return the ground truth as a list, wrapping a single value (including ``None``)."""
        if isinstance(self.ground_truth, list):
            return self.ground_truth

        return [self.ground_truth]  # type: ignore[return-value]

    def _user_messages(self) -> list[str]:
        """Return the contents of all user messages in order; ``messages`` must be set."""
        assert self.messages is not None
        return [m.content for m in self.messages if m.role == Role.USER]

    # Use just the raw messages for instructions to LLM judges, not the original prompt with its special formatting.
    # (see https://x.com/karpathy/status/1823418177197646104 for a motivation).
    @property
    def system_user_instruction(self) -> str:
        """Return all system and user message contents joined by blank lines."""
        assert self.messages is not None
        return "\n\n".join([m.content for m in self.messages if m.role in (Role.SYSTEM, Role.USER)])

    @property
    def user_instruction(self) -> str:
        """Return all user message contents joined by blank lines."""
        return "\n\n".join(self._user_messages())

    @property
    def first_user_instruction(self) -> str:
        """Return the first user message content, or ``""`` if there is none."""
        user_messages = self._user_messages()
        return user_messages[0] if user_messages else ""

    @property
    def all_but_first_user_instruction(self) -> str:
        """Return the user message contents after the first, joined by blank lines (``""`` if none)."""
        user_messages = self._user_messages()
        return "\n\n".join(user_messages[1:]) if len(user_messages) > 1 else ""

    @property
    def last_user_instruction(self) -> str:
        """Return the last user message content, or ``""`` if there is none."""
        user_messages = self._user_messages()
        return user_messages[-1] if user_messages else ""

    @property
    def sanitized_completion(self) -> str:
        """Return the completion with spaces inserted inside every ``<|xyz|>`` pattern."""
        # Make sure the completion doesn't contain any obvious special chars either by "breaking" any <|xyz|> pattern.
        return re.sub(r"<\|(\S+)\|>", r"<| \1 |>", self.completion)

    def _declared_language(self) -> str | None:
        """Return the language of the first ``LanguageMetricContext`` in ``context``, if any.

        ``context`` may be a single context or a list of contexts; both forms are searched.
        """
        contexts = self.context if isinstance(self.context, list) else [self.context]
        for ctx in contexts:
            if isinstance(ctx, LanguageMetricContext):
                return ctx.language
        return None

    @staticmethod
    def _detect_language(text: str) -> str:
        """Detect the language of ``text``; return the lower-cased ISO 639-1 name or ``""``."""
        detected_language_object = detect_language_of(text)
        return detected_language_object.iso_code_639_1.name.lower() if detected_language_object else ""

    def get_completion_language(self) -> str:
        """Return the language from the context if declared, else the detected language of ``completion``."""
        declared = self._declared_language()
        if declared is not None:
            return declared
        return self._detect_language(self.completion)

    def get_raw_completion_language(self) -> str:
        """Return the language from the context if declared, else the detected language of ``raw_completion``."""
        declared = self._declared_language()
        if declared is not None:
            return declared
        return self._detect_language(self.raw_completion)

    def get_instruction_language(self) -> str:
        """Return the language from the context if declared, else the detected language of the user instruction."""
        declared = self._declared_language()
        if declared is not None:
            return declared
        # Read user_instruction only when detection is needed: it asserts that messages is set.
        return self._detect_language(self.user_instruction)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class BaseLoglikelihood(BaseModel):
    """Fields shared by ``RawLoglikelihood`` and ``Loglikelihood``."""

    # Undeclared fields are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")

    prompt: str
    prompt_sequence_positions: int | None
    loglikelihoods: dict[str, float]
    loglikelihoods_sequence_positions: dict[str, int]  # Is empty if the model does not provide sequence positions
    # Optional; defaults to None when no compression info is supplied.
    concat_compression: Annotated[ConcatCompression | None, "Compression info for the concat formatter"] = None
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class RawLoglikelihood(BaseLoglikelihood):
    """Base loglikelihood fields plus an optional error record."""

    # Defaults to None when no error is attached.
    raw_loglikelihood_error: Error | None = None
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class Loglikelihood(BaseLoglikelihood):
    """Base loglikelihood fields plus sample id, subject, ground truth and an optional error."""

    id: int
    subject: str
    ground_truth: str | list[str]
    error: Error | None = None

    @property
    def ground_truth_list(self) -> list[str] | list[None]:
        """Return the ground truth as a list, wrapping a single string in a one-element list."""
        truth = self.ground_truth
        return truth if isinstance(truth, list) else [truth]  # type: ignore[return-value]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Type variable bound to BaseMetricContext (any metric context subclass).
MetricContext = TypeVar("MetricContext", bound=BaseMetricContext)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def extract_context_metric[MetricContext: BaseMetricContext](
|
|
213
|
+
response: Completion, metric_context_class: type[MetricContext]
|
|
214
|
+
) -> MetricContext:
|
|
215
|
+
assert response.context is not None, "Expected context to be provided in the response"
|
|
216
|
+
if not isinstance(response.context, list):
|
|
217
|
+
assert isinstance(response.context, metric_context_class) or isinstance(response.context, BaseMetricContext), (
|
|
218
|
+
f"Expected context to be of type {metric_context_class.__name__}, got {type(response.context).__name__}"
|
|
219
|
+
)
|
|
220
|
+
return cast(MetricContext, response.context)
|
|
221
|
+
else:
|
|
222
|
+
assert len(response.context) > 0, "Expected context to be provided in the response"
|
|
223
|
+
context = [
|
|
224
|
+
metric_context for metric_context in response.context if isinstance(metric_context, metric_context_class)
|
|
225
|
+
][0]
|
|
226
|
+
assert context is not None, f"Expected {metric_context_class.__name__} to be provided in the response context"
|
|
227
|
+
return cast(MetricContext, context)
|
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import random
|
|
4
|
+
import traceback
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Self, TypeVar
|
|
10
|
+
|
|
11
|
+
import iso639
|
|
12
|
+
from datasets import DatasetDict, DownloadConfig, load_dataset
|
|
13
|
+
from huggingface_hub import HfApi
|
|
14
|
+
from huggingface_hub.errors import RevisionNotFoundError
|
|
15
|
+
from pydantic import BaseModel, ConfigDict
|
|
16
|
+
|
|
17
|
+
from eval_framework.shared.types import BaseMetricContext, Completion, Error, RawCompletion
|
|
18
|
+
from eval_framework.tasks.utils import raise_errors
|
|
19
|
+
from template_formatting.formatter import Message, Role
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from eval_framework.llm.base import BaseLLM
|
|
23
|
+
from eval_framework.metrics.base import BaseMetric
|
|
24
|
+
|
|
25
|
+
# Seed for the random generator that shuffles dataset splits.
RANDOM_SEED = 42
# Subject value that makes the dataset load without a config name.
NO_SUBJECT = "no_subject"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ResponseType(Enum):
    """Response kinds a task can declare: completions or loglikelihoods."""

    COMPLETION = "completion"
    LOGLIKELIHOODS = "loglikelihoods"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Language(Enum):
    """Languages named by upper-cased three-letter codes, with the language name as value."""

    ENG = "English"
    DEU = "German"
    FRA = "French"
    ITA = "Italian"
    SPA = "Spanish"
    POR = "Portuguese"
    NLD = "Dutch"
    FIN = "Finnish"
    SWE = "Swedish"
    ARB = "Arabic"
    POL = "Polish"
    RUS = "Russian"
    UKR = "Ukrainian"
    HRV = "Croatian"
    SRP = "Serbian"

    @classmethod
    def add_members(cls, new_members: dict[str, Any]) -> type["Language"]:
        """Build a new enum containing the current members plus ``new_members``.

        Existing members come first and keep their values; a name in
        ``new_members`` that is already defined is ignored. The result is a new
        enum created under the name of ``cls``; ``cls`` itself is not modified.
        """
        merged = {existing.name: existing.value for existing in cls}
        for key, val in new_members.items():
            # setdefault keeps the value of an already-defined name and appends new names in order.
            merged.setdefault(key, val)
        return Enum(cls.__name__, merged)  # type: ignore[return-value]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Map the upper-cased ``part3`` code of every language in iso639.ALL_LANGUAGES to its name.
languages: dict[str, str] = {}
for language in iso639.ALL_LANGUAGES:
    enum_name = language.part3.upper()
    languages.update({enum_name: language.name})

# Rebind Language to an enum extended with those entries; names already defined on the class keep their value.
Language: type[Enum] = Language.add_members(languages)  # type: ignore[no-redef]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Sample(BaseModel):
    """One evaluation sample: id, subject, messages, ground truth, optional completions and context."""

    # Undeclared fields are rejected (extra="forbid").
    model_config = ConfigDict(extra="forbid")

    id: int
    subject: str
    messages: list[Message]
    ground_truth: str | list[str] | None
    possible_completions: list[str] | None
    # Defaults to None when no metric context is attached.
    context: BaseMetricContext | list[BaseMetricContext] | None = None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Unconstrained type variable for the subject type of a task.
SubjectType = TypeVar("SubjectType")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class BaseTask[SubjectType](ABC):
|
|
84
|
+
NAME: str
|
|
85
|
+
DATASET_PATH: str
|
|
86
|
+
SAMPLE_SPLIT: str
|
|
87
|
+
FEWSHOT_SPLIT: str
|
|
88
|
+
RESPONSE_TYPE: ResponseType
|
|
89
|
+
METRICS: list[type["BaseMetric"]]
|
|
90
|
+
SUBJECTS: list[SubjectType]
|
|
91
|
+
HF_REVISION: str | None = None # tag name, or branch name, or commit hash to ensure reproducibility
|
|
92
|
+
|
|
93
|
+
# Words in _get_instruction_text() not to be perturbed. List of words is case insensitive. No special characters
|
|
94
|
+
# or whitespace should be included.
|
|
95
|
+
PERTURBATION_UNMODIFIABLE_WORDS: list[str] | None
|
|
96
|
+
|
|
97
|
+
# The language (or languages) tested by the benchmark. Accepts a single string, a dictionary specifying
|
|
98
|
+
# language by subtopic, or `None` (for tasks not specific to a single language).
|
|
99
|
+
LANGUAGE: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None
|
|
100
|
+
|
|
101
|
+
def __init__(self, num_fewshot: int = 0) -> None:
|
|
102
|
+
self.num_fewshot = num_fewshot
|
|
103
|
+
self.stop_sequences: list[str] | None = None
|
|
104
|
+
self.max_tokens: int | None = None
|
|
105
|
+
|
|
106
|
+
@classmethod
|
|
107
|
+
def with_overwrite(
|
|
108
|
+
cls, num_fewshot: int, *, custom_subjects: list[str] | None, custom_hf_revision: str | None
|
|
109
|
+
) -> Self:
|
|
110
|
+
instance = cls(num_fewshot=num_fewshot)
|
|
111
|
+
|
|
112
|
+
# If custom subjects were provided during initialization, they take precedence over the class-level SUBJECTS.
|
|
113
|
+
filtered_subjects = instance._filter_task_subjects(custom_subjects=custom_subjects)
|
|
114
|
+
if filtered_subjects:
|
|
115
|
+
logger.info(f"Setting SUBJECTS to `{filtered_subjects}` for the task {instance.__class__.__name__}")
|
|
116
|
+
instance.SUBJECTS = filtered_subjects # type: ignore[assignment]
|
|
117
|
+
|
|
118
|
+
# If a custom revision was provided during initialization, it takes precedence over the class-level HF_REVISION.
|
|
119
|
+
if custom_hf_revision:
|
|
120
|
+
logger.info(f"Setting HF revision to `{custom_hf_revision}` for the task {instance.__class__.__name__}")
|
|
121
|
+
instance.HF_REVISION = custom_hf_revision
|
|
122
|
+
|
|
123
|
+
return instance
|
|
124
|
+
|
|
125
|
+
def _filter_task_subjects(self, custom_subjects: list[str] | None) -> list[str] | list[tuple] | None:
|
|
126
|
+
"""Process custom subjects passed from EvalConfig. Check and returns restricted task subjects if specified."""
|
|
127
|
+
if not custom_subjects:
|
|
128
|
+
return None
|
|
129
|
+
|
|
130
|
+
assert hasattr(self, "SUBJECTS") and len(self.SUBJECTS) > 0
|
|
131
|
+
if isinstance(self.SUBJECTS[0], tuple):
|
|
132
|
+
# subjects are specified as strings but we need tuples
|
|
133
|
+
filters = [tuple(item.strip() for item in subject.split(",")) for subject in custom_subjects]
|
|
134
|
+
|
|
135
|
+
# check if all parts of custom subjects exists (* is a wildcard)
|
|
136
|
+
num_items = len(self.SUBJECTS[0])
|
|
137
|
+
legal_values = [
|
|
138
|
+
set([s[i] for s in self.SUBJECTS if isinstance(s, tuple)] + ["*"]) for i in range(num_items)
|
|
139
|
+
]
|
|
140
|
+
|
|
141
|
+
for tpl in filters:
|
|
142
|
+
for i, v in enumerate(tpl):
|
|
143
|
+
assert v in legal_values[i], f"Subject part {v} not found in task {self.__class__.__name__}"
|
|
144
|
+
|
|
145
|
+
# filter task subjects. * is a supported wildcard for a specific item in a tuple, e.g. "DE_DE, *"
|
|
146
|
+
chosen_subjects = []
|
|
147
|
+
for subject in self.SUBJECTS:
|
|
148
|
+
subject_tuple = subject if isinstance(subject, tuple) else tuple(str(subject).split(","))
|
|
149
|
+
for filter in filters:
|
|
150
|
+
if all(filter[i] == "*" or filter[i] == subject_tuple[i] for i in range(num_items)):
|
|
151
|
+
chosen_subjects.append(subject_tuple)
|
|
152
|
+
break
|
|
153
|
+
return chosen_subjects # type: ignore[return-value]
|
|
154
|
+
else:
|
|
155
|
+
for cs in custom_subjects:
|
|
156
|
+
assert cs in self.SUBJECTS, f"Subject {cs} not found in task {self.__class__.__name__}"
|
|
157
|
+
return custom_subjects # type: ignore[return-value]
|
|
158
|
+
|
|
159
|
+
def _load_hf_dataset(self, **kwargs: Any) -> Any:
|
|
160
|
+
# Check if the HF_REVISION is valid before loading the dataset
|
|
161
|
+
if self.HF_REVISION:
|
|
162
|
+
try:
|
|
163
|
+
_ = HfApi().dataset_info(repo_id=kwargs["path"], revision=self.HF_REVISION, timeout=100.0)
|
|
164
|
+
except Exception as e:
|
|
165
|
+
if isinstance(e, RevisionNotFoundError):
|
|
166
|
+
raise e
|
|
167
|
+
|
|
168
|
+
cache_dir: str = os.environ.get("HF_DATASET_CACHE_DIR", f"{Path.home()}/.cache/huggingface/datasets")
|
|
169
|
+
download_config = DownloadConfig(cache_dir=cache_dir, max_retries=5)
|
|
170
|
+
try:
|
|
171
|
+
return load_dataset(
|
|
172
|
+
**kwargs,
|
|
173
|
+
revision=self.HF_REVISION,
|
|
174
|
+
trust_remote_code=True,
|
|
175
|
+
cache_dir=cache_dir,
|
|
176
|
+
download_config=download_config,
|
|
177
|
+
)
|
|
178
|
+
except Exception:
|
|
179
|
+
return load_dataset(
|
|
180
|
+
**kwargs,
|
|
181
|
+
revision=self.HF_REVISION,
|
|
182
|
+
trust_remote_code=True,
|
|
183
|
+
cache_dir=f"{Path.home()}/.cache/eval-framework",
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
def _shuffle_splits(self, hf_dataset: DatasetDict) -> dict[str, Any]:
|
|
187
|
+
dataset = {}
|
|
188
|
+
self.rnd = random.Random(RANDOM_SEED)
|
|
189
|
+
|
|
190
|
+
for split, data in hf_dataset.items():
|
|
191
|
+
if split not in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
|
|
192
|
+
continue
|
|
193
|
+
|
|
194
|
+
data_list = list(data)
|
|
195
|
+
|
|
196
|
+
if split == self.SAMPLE_SPLIT:
|
|
197
|
+
self.rnd.shuffle(data_list)
|
|
198
|
+
|
|
199
|
+
dataset[split] = data_list
|
|
200
|
+
|
|
201
|
+
return dataset
|
|
202
|
+
|
|
203
|
+
def _load_dataset(self, subject: SubjectType) -> None:
|
|
204
|
+
name = subject if subject != NO_SUBJECT else None
|
|
205
|
+
hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
|
|
206
|
+
self.dataset = self._shuffle_splits(hf_dataset=hf_dataset)
|
|
207
|
+
|
|
208
|
+
def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
|
|
209
|
+
return completion_text
|
|
210
|
+
|
|
211
|
+
def _get_example_messages(self, item: dict[str, Any]) -> list[Message]:
|
|
212
|
+
fewshot_examples = self._sample_fewshot_examples(item) if self.num_fewshot > 0 else []
|
|
213
|
+
|
|
214
|
+
example_messages = []
|
|
215
|
+
for fewshot_example in fewshot_examples:
|
|
216
|
+
fewshot_example["subject"] = item["subject"]
|
|
217
|
+
example_messages.extend(self._get_instruction_messages(fewshot_example))
|
|
218
|
+
example_messages.append(
|
|
219
|
+
Message(role=Role.ASSISTANT, content=self._get_fewshot_target_text(fewshot_example))
|
|
220
|
+
)
|
|
221
|
+
return example_messages
|
|
222
|
+
|
|
223
|
+
def _get_messages(self, item: dict[str, Any]) -> list[Message]:
|
|
224
|
+
example_messages = self._get_example_messages(item)
|
|
225
|
+
instruction_message = self._get_instruction_messages(item)
|
|
226
|
+
cue_text = self._get_cue_text(item)
|
|
227
|
+
cue_message = [Message(role=Role.ASSISTANT, content=cue_text)] if cue_text else []
|
|
228
|
+
messages = example_messages + instruction_message + cue_message
|
|
229
|
+
if initial_prompt_text := self._get_initial_prompt_text(item):
|
|
230
|
+
first_message = messages[0]
|
|
231
|
+
assert first_message.role == Role.USER
|
|
232
|
+
first_message.content = f"{initial_prompt_text}\n\n{first_message.content}"
|
|
233
|
+
|
|
234
|
+
if system_prompt_text := self._get_system_prompt_text(item):
|
|
235
|
+
return [Message(role=Role.SYSTEM, content=system_prompt_text)] + messages
|
|
236
|
+
return messages
|
|
237
|
+
|
|
238
|
+
def _get_instruction_messages(self, item: dict[str, Any]) -> list[Message]:
|
|
239
|
+
return [Message(role=Role.USER, content=self._get_instruction_text(item))]
|
|
240
|
+
|
|
241
|
+
def iterate_samples(self, num_samples: int | None = None) -> Iterable[Sample]:
|
|
242
|
+
for subject in self.SUBJECTS:
|
|
243
|
+
self._load_dataset(subject)
|
|
244
|
+
assert len(self.dataset[self.SAMPLE_SPLIT]) > 0
|
|
245
|
+
done = False
|
|
246
|
+
index = 0
|
|
247
|
+
for item in self.dataset[self.SAMPLE_SPLIT]:
|
|
248
|
+
if done:
|
|
249
|
+
break
|
|
250
|
+
item["subject"] = subject
|
|
251
|
+
for sample in self._create_samples(item, index, str(subject)):
|
|
252
|
+
yield sample
|
|
253
|
+
index += 1
|
|
254
|
+
if index == num_samples:
|
|
255
|
+
done = True
|
|
256
|
+
break
|
|
257
|
+
|
|
258
|
+
def _create_samples(self, item: dict[str, Any], index: int, subject: str) -> list[Sample]:
|
|
259
|
+
"""Creates one or more samples from a single dataset item. Default implementation returns single sample."""
|
|
260
|
+
return [
|
|
261
|
+
Sample(
|
|
262
|
+
id=index,
|
|
263
|
+
subject=str(subject),
|
|
264
|
+
messages=self._get_messages(item),
|
|
265
|
+
ground_truth=self._get_ground_truth(item),
|
|
266
|
+
possible_completions=self._get_possible_completions(item),
|
|
267
|
+
context=self._get_context(item),
|
|
268
|
+
)
|
|
269
|
+
]
|
|
270
|
+
|
|
271
|
+
def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
|
|
272
|
+
return ""
|
|
273
|
+
|
|
274
|
+
def _get_system_prompt_text(self, item: dict[str, Any]) -> str | None:
|
|
275
|
+
return None
|
|
276
|
+
|
|
277
|
+
@abstractmethod
def _get_instruction_text(self, item: dict[str, Any]) -> str:
    """Return the instruction text for ``item``; this base version always raises."""
    raise NotImplementedError
|
|
280
|
+
|
|
281
|
+
def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
    """
    Return the ground truth of ``item`` as the target text of a few-shot example.

    The ground truth must be a single string; ``None`` or a list of strings fails the
    assertion.
    """
    ground_truth = self._get_ground_truth(item)
    # ``isinstance`` is also False for None, so one assert covers both checks.
    assert isinstance(ground_truth, str)
    return ground_truth
|
|
286
|
+
|
|
287
|
+
@abstractmethod
|
|
288
|
+
def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
|
|
289
|
+
raise NotImplementedError
|
|
290
|
+
|
|
291
|
+
def _get_cue_text(self, item: dict[str, Any]) -> str:
    """Return the cue text for ``item``; the default is the empty string."""
    return ""
|
|
293
|
+
|
|
294
|
+
def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
|
|
295
|
+
return None
|
|
296
|
+
|
|
297
|
+
def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
    """
    Randomly draw the few-shot examples for ``item`` from the few-shot split.

    When few-shot examples and evaluation samples come from the same split, one extra
    example is drawn so that ``item`` itself can be dropped from the result while still
    leaving up to ``num_fewshot`` examples.
    """
    shares_split = self.FEWSHOT_SPLIT == self.SAMPLE_SPLIT
    pool = self.dataset[self.FEWSHOT_SPLIT]

    if not shares_split:
        return self.rnd.sample(pool, self.num_fewshot)

    candidates = self.rnd.sample(pool, self.num_fewshot + 1)
    without_item = [candidate for candidate in candidates if candidate != item]
    return without_item[: self.num_fewshot]
|
|
305
|
+
|
|
306
|
+
def _get_context(self, item: dict[str, Any]) -> BaseMetricContext | list[BaseMetricContext] | None:
    """Return the metric context (or contexts) for ``item``; the default is ``None``."""
    return None
|
|
308
|
+
|
|
309
|
+
def get_metadata(self) -> dict[str, str | list[str]]:
|
|
310
|
+
return {
|
|
311
|
+
"dataset_path": self.DATASET_PATH,
|
|
312
|
+
"sample_split": self.SAMPLE_SPLIT,
|
|
313
|
+
"fewshot_split": self.FEWSHOT_SPLIT,
|
|
314
|
+
"response_type": self.RESPONSE_TYPE.value,
|
|
315
|
+
"metrics": [m.NAME for m in self.METRICS],
|
|
316
|
+
"subjects": [str(s) for s in self.SUBJECTS],
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
def generate_completions(
    self,
    llm: "BaseLLM",
    samples: list[Sample],
    stop_sequences: list[str] | None = None,
    max_tokens: int | None = None,
) -> list[Completion]:
    """
    Generates one completion per sample.

    If ``llm.generate`` raises and ``raise_errors()`` is falsy, the failure is logged and
    stored as the error of an empty placeholder completion; this fallback is only
    supported for a single sample. Failures during post-processing are likewise stored
    on the affected completion instead of being raised.

    :param llm: model used to generate and post-process the completions
    :param samples: samples to generate completions for
    :param stop_sequences: stop sequences to use in completion generation (``None`` means none)
    :param max_tokens: maximum tokens to use in completion generation
    :return: one completion per sample, in the order of ``samples``
    """
    if stop_sequences is None:
        stop_sequences = []

    raw_completions: list[RawCompletion]
    try:
        raw_completions = llm.generate(samples=samples, stop_sequences=stop_sequences, max_tokens=max_tokens)
    except Exception as e:
        if raise_errors():
            # Bare raise re-raises the active exception with its traceback untouched.
            raise
        logger.info(f"Error: {e.__class__.__name__} {e}")
        assert len(samples) == 1, "LLMs not handling errors are not supported in batch mode"
        # Stand-in result that carries the generation error instead of a completion.
        raw_completions = [
            RawCompletion(
                prompt="",
                prompt_sequence_positions=0,
                completion="",
                completion_sequence_positions=0,
                raw_completion_error=Error(
                    error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
                ),
            )
            for _ in samples
        ]

    completion_list = []
    for idx, sample in enumerate(samples):
        raw_completion = raw_completions[idx]

        # If the prompt already ends with an assistant message, the generated text
        # continues that message; otherwise it becomes a new assistant message.
        if sample.messages and sample.messages[-1].role == Role.ASSISTANT:
            messages = sample.messages[:-1] + [
                Message(role=Role.ASSISTANT, content=sample.messages[-1].content + raw_completion.completion)
            ]
        else:
            messages = sample.messages + [Message(role=Role.ASSISTANT, content=raw_completion.completion)]

        error = None
        try:
            model_post_processed_completion = llm.post_process_completion(raw_completion.completion, sample)
            completion = self.post_process_generated_completion(model_post_processed_completion, sample)
        except Exception as e:
            error = Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
            completion = ""

        completion_list.append(
            Completion(
                id=sample.id,
                subject=sample.subject,
                ground_truth=sample.ground_truth,
                prompt=raw_completion.prompt,
                prompt_sequence_positions=raw_completion.prompt_sequence_positions,
                concat_compression=raw_completion.concat_compression,
                messages=messages,
                completion=completion,
                raw_completion=raw_completion.completion,
                raw_completion_sequence_positions=raw_completion.completion_sequence_positions,
                context=sample.context,
                # A generation error takes precedence over a post-processing error.
                error=raw_completion.raw_completion_error or error,
            )
        )
    return completion_list
|
|
File without changes
|