eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,227 @@
1
+ import re
2
+ from collections.abc import Callable, Sequence
3
+ from typing import Annotated, NamedTuple, Self, TypeVar, cast
4
+
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+ from eval_framework.metrics.llm.graders.language import detect_language_of
8
+ from eval_framework.utils.helpers import count_bytes
9
+ from template_formatting.formatter import ConcatFormatter, Message, Role
10
+
11
+
12
+ class ConcatCompression(NamedTuple):
13
+ """Helper class for storing compression info for the concat formatter.
14
+
15
+ The concat formatter is used to avoid bias towards special tokens.
16
+ """
17
+
18
+ num_bytes: int
19
+ num_tokens: int
20
+
21
+ @classmethod
22
+ def calculate(
23
+ cls,
24
+ messages: Sequence[Message],
25
+ count_tokens: Callable[[str], int],
26
+ choices: list[str] | None = None,
27
+ completion: str | None = None,
28
+ ) -> Self | None:
29
+ """Calculate the compression info for the given messages and token counting function."""
30
+ if (choices is None) == (completion is None):
31
+ raise ValueError("Either possible_completions or completion must be provided, but not both.")
32
+ concat_str = ConcatFormatter().format(messages, output_mode="string")
33
+
34
+ if choices is not None:
35
+ if any(c is None for c in choices):
36
+ return None
37
+ num_bytes = count_bytes(concat_str) + sum(count_bytes(c) for c in choices)
38
+ num_tokens = count_tokens(concat_str) + sum(count_tokens(c) for c in choices)
39
+ else:
40
+ if completion is None:
41
+ return None
42
+ concat_str = f"{concat_str}{completion}"
43
+ num_bytes = count_bytes(concat_str)
44
+ num_tokens = count_tokens(concat_str)
45
+
46
+ res = cls(num_bytes=num_bytes, num_tokens=num_tokens)
47
+ if res.num_bytes > 0 and res.num_tokens > 0:
48
+ return res
49
+ else:
50
+ return None
51
+
52
+
53
class BaseMetricContext(BaseModel):
    """Base class for metric context.

    Declares no fields of its own; ``extra="allow"`` lets instances carry fields that are not declared on the model.
    """

    model_config = ConfigDict(extra="allow")
57
+
58
+
59
class LanguageMetricContext(BaseMetricContext):
    """Metric context that carries a language identifier."""

    # Returned unchanged by the Completion.get_*_language() helpers when it is the completion's context.
    # NOTE(review): presumably a lowercase ISO 639-1 code, matching what those helpers detect otherwise — confirm.
    language: str
61
+
62
+
63
class UntemplatedPrompt(BaseMetricContext):
    """Metric context with a single string field, ``untemplated_prompt``."""

    # NOTE(review): presumably the prompt text before any chat template is applied — confirm with the producing task.
    untemplated_prompt: str
65
+
66
+
67
class Error(BaseModel):
    """Serializable record of an exception: class name, message and traceback text."""

    model_config = ConfigDict(extra="forbid")  # undeclared fields are rejected
    error_class: str  # name of the exception class
    message: str  # string form of the exception
    traceback: str  # formatted traceback text
72
+
73
+
74
class PromptTooLongException(Exception):
    """Plain ``Exception`` subclass without extra state.

    NOTE(review): presumably raised when a prompt exceeds a model's limit; the raising code is not in this module.
    """
76
+
77
+
78
class BaseCompletion(BaseModel):
    """Fields shared by ``RawCompletion`` and ``Completion``: prompt, completion text and size information."""

    model_config = ConfigDict(extra="forbid")  # undeclared fields are rejected
    prompt: Annotated[str, "prompt as passed to the llm"]
    prompt_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the prompt occupies in the llm architecture (e.g. token count) "
        "or None if the info is not available",
    ]
    completion: Annotated[str, "completion as generated by the llm"]
    # Optional; stays None unless the producer fills it in.
    concat_compression: Annotated[ConcatCompression | None, "Compression info for the concat formatter."] = None
88
+
89
+
90
class RawCompletion(BaseCompletion):
    """Completion as returned by an LLM, before it is turned into a ``Completion``."""

    completion_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the completion occupies in the llm architecture "
        "(e.g. token count) or None if the info is not available",
    ]
    # Set when generation failed; None otherwise.
    raw_completion_error: Error | None = None
97
+
98
+
99
class Completion(BaseCompletion):
    """Completion for one sample: processed and raw model output plus the sample's metadata."""

    id: int
    subject: str
    ground_truth: str | None | list[str]
    messages: list[Message] | None  # needed for LLM as a judge
    raw_completion: Annotated[str, "raw completion as generated by the llm"]
    raw_completion_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the completion occupies in the llm architecture or None "
        "if the info is not available",
    ]
    context: list[BaseMetricContext] | BaseMetricContext | None = None
    error: Error | None = None

    @property
    def ground_truth_list(self) -> list[str] | list[None]:
        """Ground truth as a list; a non-list ground truth (including ``None``) is wrapped in a one-element list."""
        if not isinstance(self.ground_truth, list):
            return [self.ground_truth]  # type: ignore[return-value]
        return self.ground_truth

    def _user_message_contents(self) -> list[str]:
        """Contents of all user messages in order; requires ``messages`` to be set."""
        assert self.messages is not None
        return [message.content for message in self.messages if message.role == Role.USER]

    # Use just the raw messages for instructions to LLM judges, not the original prompt with its special formatting.
    # (see https://x.com/karpathy/status/1823418177197646104 for a motivation).
    @property
    def system_user_instruction(self) -> str:
        """Contents of all system and user messages, separated by blank lines."""
        assert self.messages is not None
        wanted_roles = (Role.SYSTEM, Role.USER)
        return "\n\n".join(message.content for message in self.messages if message.role in wanted_roles)

    @property
    def user_instruction(self) -> str:
        """Contents of all user messages, separated by blank lines."""
        return "\n\n".join(self._user_message_contents())

    @property
    def first_user_instruction(self) -> str:
        """Content of the first user message, or an empty string if there is none."""
        contents = self._user_message_contents()
        if not contents:
            return ""
        return contents[0]

    @property
    def all_but_first_user_instruction(self) -> str:
        """Contents of every user message after the first, separated by blank lines (empty string if none)."""
        # Joining an empty slice already yields "", which covers the zero- and one-message cases.
        return "\n\n".join(self._user_message_contents()[1:])

    @property
    def last_user_instruction(self) -> str:
        """Content of the last user message, or an empty string if there is none."""
        contents = self._user_message_contents()
        if not contents:
            return ""
        return contents[-1]

    @property
    def sanitized_completion(self) -> str:
        """Completion with every ``<|xyz|>`` pattern rewritten to ``<| xyz |>``."""
        # Make sure the completion doesn't contain any obvious special chars either by "breaking" any <|xyz|> pattern.
        special_token_pattern = r"<\|(\S+)\|>"
        spaced_out = r"<| \1 |>"
        return re.sub(special_token_pattern, spaced_out, self.completion)

    def _context_language(self) -> str | None:
        """Language stored in the context if the context is a single ``LanguageMetricContext``, else ``None``."""
        if self.context and isinstance(self.context, LanguageMetricContext):
            return self.context.language
        return None

    @staticmethod
    def _detect_language(text: str) -> str:
        """Lowercased ISO 639-1 name of the language detected in ``text``, or "" if nothing is detected."""
        detected = detect_language_of(text)
        if not detected:
            return ""
        return detected.iso_code_639_1.name.lower()

    def get_completion_language(self) -> str:
        """Language from the context if available, otherwise the language detected in ``completion``."""
        from_context = self._context_language()
        return from_context if from_context is not None else self._detect_language(self.completion)

    def get_raw_completion_language(self) -> str:
        """Language from the context if available, otherwise the language detected in ``raw_completion``."""
        from_context = self._context_language()
        return from_context if from_context is not None else self._detect_language(self.raw_completion)

    def get_instruction_language(self) -> str:
        """Language from the context if available, otherwise the language detected in ``user_instruction``."""
        from_context = self._context_language()
        # user_instruction is only evaluated on the detection path (it asserts that messages are present).
        return from_context if from_context is not None else self._detect_language(self.user_instruction)
181
+
182
+
183
class BaseLoglikelihood(BaseModel):
    """Fields shared by ``RawLoglikelihood`` and ``Loglikelihood``."""

    model_config = ConfigDict(extra="forbid")  # undeclared fields are rejected
    prompt: str
    prompt_sequence_positions: int | None
    # NOTE(review): presumably keyed by candidate completion text — confirm against the LLM implementations.
    loglikelihoods: dict[str, float]
    loglikelihoods_sequence_positions: dict[str, int]  # Is empty if the model does not provide sequence positions
    concat_compression: Annotated[ConcatCompression | None, "Compression info for the concat formatter"] = None
190
+
191
+
192
class RawLoglikelihood(BaseLoglikelihood):
    """Log-likelihood result with an optional error record and no sample metadata."""

    # None unless an error was recorded for this result.
    raw_loglikelihood_error: Error | None = None
194
+
195
+
196
class Loglikelihood(BaseLoglikelihood):
    """Log-likelihood result enriched with the sample's id, subject and ground truth."""

    id: int
    subject: str
    ground_truth: str | list[str]
    error: Error | None = None

    @property
    def ground_truth_list(self) -> list[str] | list[None]:
        """Ground truth as a list; a single string is wrapped in a one-element list."""
        if not isinstance(self.ground_truth, list):
            return [self.ground_truth]  # type: ignore[return-value]
        return self.ground_truth
207
+
208
+
209
# Module-level type variable bounded by BaseMetricContext.
# NOTE: extract_context_metric() below declares its own type parameter of the same name in its signature,
# so this TypeVar is a separate object that is only reachable through the module namespace.
MetricContext = TypeVar("MetricContext", bound=BaseMetricContext)
210
+
211
+
212
+ def extract_context_metric[MetricContext: BaseMetricContext](
213
+ response: Completion, metric_context_class: type[MetricContext]
214
+ ) -> MetricContext:
215
+ assert response.context is not None, "Expected context to be provided in the response"
216
+ if not isinstance(response.context, list):
217
+ assert isinstance(response.context, metric_context_class) or isinstance(response.context, BaseMetricContext), (
218
+ f"Expected context to be of type {metric_context_class.__name__}, got {type(response.context).__name__}"
219
+ )
220
+ return cast(MetricContext, response.context)
221
+ else:
222
+ assert len(response.context) > 0, "Expected context to be provided in the response"
223
+ context = [
224
+ metric_context for metric_context in response.context if isinstance(metric_context, metric_context_class)
225
+ ][0]
226
+ assert context is not None, f"Expected {metric_context_class.__name__} to be provided in the response context"
227
+ return cast(MetricContext, context)
@@ -0,0 +1,6 @@
1
# Register all tasks on import
from .task_names import register_all_tasks

# Runs as an import-time side effect of importing this package.
register_all_tasks()

# Remove the helper from the package namespace once it has run.
del register_all_tasks
@@ -0,0 +1,392 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import traceback
5
+ from abc import ABC, abstractmethod
6
+ from collections.abc import Iterable
7
+ from enum import Enum
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING, Any, Self, TypeVar
10
+
11
+ import iso639
12
+ from datasets import DatasetDict, DownloadConfig, load_dataset
13
+ from huggingface_hub import HfApi
14
+ from huggingface_hub.errors import RevisionNotFoundError
15
+ from pydantic import BaseModel, ConfigDict
16
+
17
+ from eval_framework.shared.types import BaseMetricContext, Completion, Error, RawCompletion
18
+ from eval_framework.tasks.utils import raise_errors
19
+ from template_formatting.formatter import Message, Role
20
+
21
+ if TYPE_CHECKING:
22
+ from eval_framework.llm.base import BaseLLM
23
+ from eval_framework.metrics.base import BaseMetric
24
+
25
# Seed of the random.Random instance that BaseTask._shuffle_splits() creates.
RANDOM_SEED = 42
# Placeholder subject: BaseTask._load_dataset() passes name=None to the dataset loader for it.
NO_SUBJECT = "no_subject"
27
+
28
+
29
class ResponseType(Enum):
    """Kind of model response; tasks declare one via ``BaseTask.RESPONSE_TYPE``."""

    COMPLETION = "completion"
    LOGLIKELIHOODS = "loglikelihoods"
32
+
33
+
34
class Language(Enum):
    """Languages a task can be tagged with; member values are English language names."""

    ENG = "English"
    DEU = "German"
    FRA = "French"
    ITA = "Italian"
    SPA = "Spanish"
    POR = "Portuguese"
    NLD = "Dutch"
    FIN = "Finnish"
    SWE = "Swedish"
    ARB = "Arabic"
    POL = "Polish"
    RUS = "Russian"
    UKR = "Ukrainian"
    HRV = "Croatian"
    SRP = "Serbian"

    @classmethod
    def add_members(cls, new_members: dict[str, Any]) -> type["Language"]:
        """Create a new enum with this enum's name, its current members and the additional ``new_members``.

        Existing members come first and keep their values; entries of ``new_members`` whose name is
        already taken are ignored. The enum this is called on is left untouched.
        """
        combined = {existing.name: existing.value for existing in cls}
        for member_name, member_value in new_members.items():
            # setdefault only inserts names that are not present yet.
            combined.setdefault(member_name, member_value)
        return Enum(cls.__name__, combined)  # type: ignore[return-value]
58
+
59
+
60
# Map the upper-cased `part3` code of every entry in iso639.ALL_LANGUAGES to that entry's name.
languages: dict[str, str] = {}
for language in iso639.ALL_LANGUAGES:
    enum_name = language.part3.upper()
    languages[enum_name] = language.name

# Rebind `Language` to an extended enum. add_members() keeps the hand-written members above and skips
# any name that is already present, so their values are not overwritten.
Language: type[Enum] = Language.add_members(languages)  # type: ignore[no-redef]
66
+
67
+
68
class Sample(BaseModel):
    """One evaluation example, as built by ``BaseTask._create_samples()``."""

    model_config = ConfigDict(extra="forbid")  # undeclared fields are rejected
    id: int  # running index handed out by BaseTask.iterate_samples()
    subject: str  # string form of the task subject
    messages: list[Message]  # conversation built by BaseTask._get_messages()
    ground_truth: str | list[str] | None
    possible_completions: list[str] | None
    context: BaseMetricContext | list[BaseMetricContext] | None = None
76
+
77
+
78
# NOTE: BaseTask below declares its own `SubjectType` type parameter in its class header, so this
# module-level TypeVar is a separate object.
SubjectType = TypeVar("SubjectType")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
81
+
82
+
83
+ class BaseTask[SubjectType](ABC):
84
+ NAME: str
85
+ DATASET_PATH: str
86
+ SAMPLE_SPLIT: str
87
+ FEWSHOT_SPLIT: str
88
+ RESPONSE_TYPE: ResponseType
89
+ METRICS: list[type["BaseMetric"]]
90
+ SUBJECTS: list[SubjectType]
91
+ HF_REVISION: str | None = None # tag name, or branch name, or commit hash to ensure reproducibility
92
+
93
+ # Words in _get_instruction_text() not to be perturbed. List of words is case insensitive. No special characters
94
+ # or whitespace should be included.
95
+ PERTURBATION_UNMODIFIABLE_WORDS: list[str] | None
96
+
97
+ # The language (or languages) tested by the benchmark. Accepts a single string, a dictionary specifying
98
+ # language by subtopic, or `None` (for tasks not specific to a single language).
99
+ LANGUAGE: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None
100
+
101
+ def __init__(self, num_fewshot: int = 0) -> None:
102
+ self.num_fewshot = num_fewshot
103
+ self.stop_sequences: list[str] | None = None
104
+ self.max_tokens: int | None = None
105
+
106
+ @classmethod
107
+ def with_overwrite(
108
+ cls, num_fewshot: int, *, custom_subjects: list[str] | None, custom_hf_revision: str | None
109
+ ) -> Self:
110
+ instance = cls(num_fewshot=num_fewshot)
111
+
112
+ # If custom subjects were provided during initialization, they take precedence over the class-level SUBJECTS.
113
+ filtered_subjects = instance._filter_task_subjects(custom_subjects=custom_subjects)
114
+ if filtered_subjects:
115
+ logger.info(f"Setting SUBJECTS to `{filtered_subjects}` for the task {instance.__class__.__name__}")
116
+ instance.SUBJECTS = filtered_subjects # type: ignore[assignment]
117
+
118
+ # If a custom revision was provided during initialization, it takes precedence over the class-level HF_REVISION.
119
+ if custom_hf_revision:
120
+ logger.info(f"Setting HF revision to `{custom_hf_revision}` for the task {instance.__class__.__name__}")
121
+ instance.HF_REVISION = custom_hf_revision
122
+
123
+ return instance
124
+
125
+ def _filter_task_subjects(self, custom_subjects: list[str] | None) -> list[str] | list[tuple] | None:
126
+ """Process custom subjects passed from EvalConfig. Check and returns restricted task subjects if specified."""
127
+ if not custom_subjects:
128
+ return None
129
+
130
+ assert hasattr(self, "SUBJECTS") and len(self.SUBJECTS) > 0
131
+ if isinstance(self.SUBJECTS[0], tuple):
132
+ # subjects are specified as strings but we need tuples
133
+ filters = [tuple(item.strip() for item in subject.split(",")) for subject in custom_subjects]
134
+
135
+ # check if all parts of custom subjects exists (* is a wildcard)
136
+ num_items = len(self.SUBJECTS[0])
137
+ legal_values = [
138
+ set([s[i] for s in self.SUBJECTS if isinstance(s, tuple)] + ["*"]) for i in range(num_items)
139
+ ]
140
+
141
+ for tpl in filters:
142
+ for i, v in enumerate(tpl):
143
+ assert v in legal_values[i], f"Subject part {v} not found in task {self.__class__.__name__}"
144
+
145
+ # filter task subjects. * is a supported wildcard for a specific item in a tuple, e.g. "DE_DE, *"
146
+ chosen_subjects = []
147
+ for subject in self.SUBJECTS:
148
+ subject_tuple = subject if isinstance(subject, tuple) else tuple(str(subject).split(","))
149
+ for filter in filters:
150
+ if all(filter[i] == "*" or filter[i] == subject_tuple[i] for i in range(num_items)):
151
+ chosen_subjects.append(subject_tuple)
152
+ break
153
+ return chosen_subjects # type: ignore[return-value]
154
+ else:
155
+ for cs in custom_subjects:
156
+ assert cs in self.SUBJECTS, f"Subject {cs} not found in task {self.__class__.__name__}"
157
+ return custom_subjects # type: ignore[return-value]
158
+
159
+ def _load_hf_dataset(self, **kwargs: Any) -> Any:
160
+ # Check if the HF_REVISION is valid before loading the dataset
161
+ if self.HF_REVISION:
162
+ try:
163
+ _ = HfApi().dataset_info(repo_id=kwargs["path"], revision=self.HF_REVISION, timeout=100.0)
164
+ except Exception as e:
165
+ if isinstance(e, RevisionNotFoundError):
166
+ raise e
167
+
168
+ cache_dir: str = os.environ.get("HF_DATASET_CACHE_DIR", f"{Path.home()}/.cache/huggingface/datasets")
169
+ download_config = DownloadConfig(cache_dir=cache_dir, max_retries=5)
170
+ try:
171
+ return load_dataset(
172
+ **kwargs,
173
+ revision=self.HF_REVISION,
174
+ trust_remote_code=True,
175
+ cache_dir=cache_dir,
176
+ download_config=download_config,
177
+ )
178
+ except Exception:
179
+ return load_dataset(
180
+ **kwargs,
181
+ revision=self.HF_REVISION,
182
+ trust_remote_code=True,
183
+ cache_dir=f"{Path.home()}/.cache/eval-framework",
184
+ )
185
+
186
+ def _shuffle_splits(self, hf_dataset: DatasetDict) -> dict[str, Any]:
187
+ dataset = {}
188
+ self.rnd = random.Random(RANDOM_SEED)
189
+
190
+ for split, data in hf_dataset.items():
191
+ if split not in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
192
+ continue
193
+
194
+ data_list = list(data)
195
+
196
+ if split == self.SAMPLE_SPLIT:
197
+ self.rnd.shuffle(data_list)
198
+
199
+ dataset[split] = data_list
200
+
201
+ return dataset
202
+
203
+ def _load_dataset(self, subject: SubjectType) -> None:
204
+ name = subject if subject != NO_SUBJECT else None
205
+ hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
206
+ self.dataset = self._shuffle_splits(hf_dataset=hf_dataset)
207
+
208
+ def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
209
+ return completion_text
210
+
211
+ def _get_example_messages(self, item: dict[str, Any]) -> list[Message]:
212
+ fewshot_examples = self._sample_fewshot_examples(item) if self.num_fewshot > 0 else []
213
+
214
+ example_messages = []
215
+ for fewshot_example in fewshot_examples:
216
+ fewshot_example["subject"] = item["subject"]
217
+ example_messages.extend(self._get_instruction_messages(fewshot_example))
218
+ example_messages.append(
219
+ Message(role=Role.ASSISTANT, content=self._get_fewshot_target_text(fewshot_example))
220
+ )
221
+ return example_messages
222
+
223
+ def _get_messages(self, item: dict[str, Any]) -> list[Message]:
224
+ example_messages = self._get_example_messages(item)
225
+ instruction_message = self._get_instruction_messages(item)
226
+ cue_text = self._get_cue_text(item)
227
+ cue_message = [Message(role=Role.ASSISTANT, content=cue_text)] if cue_text else []
228
+ messages = example_messages + instruction_message + cue_message
229
+ if initial_prompt_text := self._get_initial_prompt_text(item):
230
+ first_message = messages[0]
231
+ assert first_message.role == Role.USER
232
+ first_message.content = f"{initial_prompt_text}\n\n{first_message.content}"
233
+
234
+ if system_prompt_text := self._get_system_prompt_text(item):
235
+ return [Message(role=Role.SYSTEM, content=system_prompt_text)] + messages
236
+ return messages
237
+
238
+ def _get_instruction_messages(self, item: dict[str, Any]) -> list[Message]:
239
+ return [Message(role=Role.USER, content=self._get_instruction_text(item))]
240
+
241
+ def iterate_samples(self, num_samples: int | None = None) -> Iterable[Sample]:
242
+ for subject in self.SUBJECTS:
243
+ self._load_dataset(subject)
244
+ assert len(self.dataset[self.SAMPLE_SPLIT]) > 0
245
+ done = False
246
+ index = 0
247
+ for item in self.dataset[self.SAMPLE_SPLIT]:
248
+ if done:
249
+ break
250
+ item["subject"] = subject
251
+ for sample in self._create_samples(item, index, str(subject)):
252
+ yield sample
253
+ index += 1
254
+ if index == num_samples:
255
+ done = True
256
+ break
257
+
258
+ def _create_samples(self, item: dict[str, Any], index: int, subject: str) -> list[Sample]:
259
+ """Creates one or more samples from a single dataset item. Default implementation returns single sample."""
260
+ return [
261
+ Sample(
262
+ id=index,
263
+ subject=str(subject),
264
+ messages=self._get_messages(item),
265
+ ground_truth=self._get_ground_truth(item),
266
+ possible_completions=self._get_possible_completions(item),
267
+ context=self._get_context(item),
268
+ )
269
+ ]
270
+
271
+ def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
272
+ return ""
273
+
274
+ def _get_system_prompt_text(self, item: dict[str, Any]) -> str | None:
275
+ return None
276
+
277
+ @abstractmethod
278
+ def _get_instruction_text(self, item: dict[str, Any]) -> str:
279
+ raise NotImplementedError
280
+
281
+ def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
282
+ target = self._get_ground_truth(item)
283
+ assert target is not None
284
+ assert isinstance(target, str)
285
+ return target
286
+
287
+ @abstractmethod
288
+ def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
289
+ raise NotImplementedError
290
+
291
+ def _get_cue_text(self, item: dict[str, Any]) -> str:
292
+ return ""
293
+
294
+ def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
295
+ return None
296
+
297
+ def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
298
+ if self.FEWSHOT_SPLIT == self.SAMPLE_SPLIT:
299
+ fewshot_examples = self.rnd.sample(self.dataset[self.FEWSHOT_SPLIT], self.num_fewshot + 1)
300
+ fewshot_examples = [example for example in fewshot_examples if example != item]
301
+ fewshot_examples = fewshot_examples[: self.num_fewshot]
302
+ return fewshot_examples
303
+ else:
304
+ return self.rnd.sample(self.dataset[self.FEWSHOT_SPLIT], self.num_fewshot)
305
+
306
+ def _get_context(self, item: dict[str, Any]) -> BaseMetricContext | list[BaseMetricContext] | None:
307
+ return None
308
+
309
+ def get_metadata(self) -> dict[str, str | list[str]]:
310
+ return {
311
+ "dataset_path": self.DATASET_PATH,
312
+ "sample_split": self.SAMPLE_SPLIT,
313
+ "fewshot_split": self.FEWSHOT_SPLIT,
314
+ "response_type": self.RESPONSE_TYPE.value,
315
+ "metrics": [m.NAME for m in self.METRICS],
316
+ "subjects": [str(s) for s in self.SUBJECTS],
317
+ }
318
+
319
+ def generate_completions(
320
+ self,
321
+ llm: "BaseLLM",
322
+ samples: list[Sample],
323
+ stop_sequences: list[str] | None = None,
324
+ max_tokens: int | None = None,
325
+ ) -> list[Completion]:
326
+ """
327
+ Generates completions for the sample.
328
+ :param sample: sample to generate completions for
329
+ :param stop_sequences: stop sequences to use in completion generation
330
+ :param max_tokens: maximum tokens to use in completion generation
331
+ :return: completion
332
+ """
333
+ if stop_sequences is None:
334
+ stop_sequences = []
335
+
336
+ raw_completions: list[RawCompletion]
337
+ try:
338
+ raw_completions = llm.generate(samples=samples, stop_sequences=stop_sequences, max_tokens=max_tokens)
339
+ except Exception as e:
340
+ if raise_errors():
341
+ raise e
342
+ logger.info(f"Error: {e.__class__.__name__} {e}")
343
+ assert len(samples) == 1, "LLMs not handling errors are not supported in batch mode"
344
+ raw_completions = [
345
+ RawCompletion(
346
+ prompt="",
347
+ prompt_sequence_positions=0,
348
+ completion="",
349
+ completion_sequence_positions=0,
350
+ raw_completion_error=Error(
351
+ error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
352
+ ),
353
+ )
354
+ for _ in range(len(samples))
355
+ ]
356
+
357
+ completion_list = []
358
+ for idx, sample in enumerate(samples):
359
+ raw_completion = raw_completions[idx]
360
+
361
+ if sample.messages and sample.messages[-1].role == Role.ASSISTANT:
362
+ messages = sample.messages[:-1] + [
363
+ Message(role=Role.ASSISTANT, content=sample.messages[-1].content + raw_completion.completion)
364
+ ]
365
+ else:
366
+ messages = sample.messages + [Message(role=Role.ASSISTANT, content=raw_completion.completion)]
367
+
368
+ try:
369
+ error = None
370
+ model_post_processed_completion = llm.post_process_completion(raw_completion.completion, sample)
371
+ completion = self.post_process_generated_completion(model_post_processed_completion, sample)
372
+ except Exception as e:
373
+ error = Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
374
+ completion = ""
375
+
376
+ completion_list.append(
377
+ Completion(
378
+ id=sample.id,
379
+ subject=sample.subject,
380
+ ground_truth=sample.ground_truth,
381
+ prompt=raw_completion.prompt,
382
+ prompt_sequence_positions=raw_completion.prompt_sequence_positions,
383
+ concat_compression=raw_completion.concat_compression,
384
+ messages=messages,
385
+ completion=completion,
386
+ raw_completion=raw_completion.completion,
387
+ raw_completion_sequence_positions=raw_completion.completion_sequence_positions,
388
+ context=sample.context,
389
+ error=raw_completion.raw_completion_error or error,
390
+ )
391
+ )
392
+ return completion_list
File without changes