eval-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,227 @@
1
+ import re
2
+ from collections.abc import Callable, Sequence
3
+ from typing import Annotated, NamedTuple, Self, TypeVar, cast
4
+
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+ from eval_framework.metrics.llm.graders.language import detect_language_of
8
+ from eval_framework.utils.helpers import count_bytes
9
+ from template_formatting.formatter import ConcatFormatter, Message, Role
10
+
11
+
12
class ConcatCompression(NamedTuple):
    """Helper class for storing compression info for the concat formatter.

    The concat formatter is used to avoid bias towards special tokens.
    """

    num_bytes: int  # byte count of the concatenated text, as computed by ``count_bytes``
    num_tokens: int  # token count of the concatenated text, as computed by the caller's ``count_tokens``

    @classmethod
    def calculate(
        cls,
        messages: Sequence[Message],
        count_tokens: Callable[[str], int],
        choices: list[str] | None = None,
        completion: str | None = None,
    ) -> Self | None:
        """Calculate the compression info for the given messages and token counting function.

        The messages are rendered to a single string with ``ConcatFormatter``. Exactly one of
        ``choices`` or ``completion`` must be given:

        - ``choices``: bytes/tokens of the rendered prompt plus the sum over all choices, each
          counted separately.
        - ``completion``: bytes/tokens of the rendered prompt and the completion concatenated
          into one string.

        Args:
            messages: Messages to render with the concat formatter.
            count_tokens: Function returning the number of tokens of a string.
            choices: Candidate completions.
            completion: A generated completion.

        Returns:
            The compression info, or ``None`` if any choice is ``None`` or if the resulting byte
            or token count is not positive.

        Raises:
            ValueError: If both or neither of ``choices`` and ``completion`` are provided.
        """
        if (choices is None) == (completion is None):
            raise ValueError("Either choices or completion must be provided, but not both.")
        concat_str = ConcatFormatter().format(messages, output_mode="string")

        if choices is not None:
            # Defensive: entries may be None at runtime despite the ``list[str]`` annotation.
            if any(c is None for c in choices):
                return None
            num_bytes = count_bytes(concat_str) + sum(count_bytes(c) for c in choices)
            num_tokens = count_tokens(concat_str) + sum(count_tokens(c) for c in choices)
        else:
            # The XOR check above guarantees that ``completion`` is not None here.
            full_str = f"{concat_str}{completion}"
            num_bytes = count_bytes(full_str)
            num_tokens = count_tokens(full_str)

        if num_bytes > 0 and num_tokens > 0:
            return cls(num_bytes=num_bytes, num_tokens=num_tokens)
        return None
51
+
52
+
53
class BaseMetricContext(BaseModel):
    """Base class for metric context.

    Extra fields are allowed (``extra="allow"``), so additional keyword arguments passed at
    construction are kept as attributes on the instance.
    """

    model_config = ConfigDict(extra="allow")
57
+
58
+
59
class LanguageMetricContext(BaseMetricContext):
    """Metric context carrying an explicit language.

    ``Completion``'s language helpers use this value instead of running language detection.
    """

    # NOTE(review): presumably the same lower-cased ISO 639-1 style code that the detection
    # fallback in ``Completion`` produces — confirm with the tasks that set it.
    language: str
61
+
62
+
63
class UntemplatedPrompt(BaseMetricContext):
    """Metric context holding an untemplated prompt.

    NOTE(review): presumably the prompt text before chat-template formatting — confirm with producers.
    """

    untemplated_prompt: str
65
+
66
+
67
class Error(BaseModel):
    """Serializable record of an error attached to a completion or loglikelihood result."""

    model_config = ConfigDict(extra="forbid")
    error_class: str  # presumably the exception class name — confirm with producers
    message: str
    traceback: str  # presumably the formatted traceback text — confirm with producers
72
+
73
+
74
class PromptTooLongException(Exception):
    """Raised to signal that a prompt is too long."""
76
+
77
+
78
class BaseCompletion(BaseModel):
    """Fields shared by raw and processed completions: prompt, completion and size info.

    The second element of each ``Annotated`` type is a human-readable description of the field.
    """

    model_config = ConfigDict(extra="forbid")
    prompt: Annotated[str, "prompt as passed to the llm"]
    prompt_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the prompt occupies in the llm architecture (e.g. token count) "
        "or None if the info is not available",
    ]
    completion: Annotated[str, "completion as generated by the llm"]
    concat_compression: Annotated[ConcatCompression | None, "Compression info for the concat formatter."] = None
88
+
89
+
90
class RawCompletion(BaseCompletion):
    """Completion data as produced by the model call, plus the completion size and an optional error.

    NOTE(review): presumably this is the form before task post-processing (``Completion`` keeps both
    ``completion`` and ``raw_completion``) — confirm against the response generator.
    """

    completion_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the completion occupies in the llm architecture "
        "(e.g. token count) or None if the info is not available",
    ]
    raw_completion_error: Error | None = None
97
+
98
+
99
class Completion(BaseCompletion):
    """A completion for one sample, enriched with everything metrics need.

    Besides the prompt/completion fields of ``BaseCompletion``, it carries the sample id and
    subject, the ground truth, the chat messages (used to build LLM-judge instructions), the
    raw completion, optional metric context (a single context or a list), and an optional error.
    """

    id: int
    subject: str
    ground_truth: str | None | list[str]
    messages: list[Message] | None  # needed for LLM as a judge
    raw_completion: Annotated[str, "raw completion as generated by the llm"]
    raw_completion_sequence_positions: Annotated[
        int | None,
        "number of sequence positions that the completion occupies in the llm architecture or None "
        "if the info is not available",
    ]
    context: list[BaseMetricContext] | BaseMetricContext | None = None
    error: Error | None = None

    @property
    def ground_truth_list(self) -> list[str] | list[None]:
        """``ground_truth`` as a list; a scalar value (including ``None``) becomes a one-element list."""
        if isinstance(self.ground_truth, list):
            return self.ground_truth

        return [self.ground_truth]  # type: ignore[return-value]

    # Use just the raw messages for instructions to LLM judges, not the original prompt with its special formatting.
    # (see https://x.com/karpathy/status/1823418177197646104 for a motivation).
    @property
    def system_user_instruction(self) -> str:
        """Contents of all system and user messages, joined by blank lines."""
        assert self.messages is not None
        return "\n\n".join([m.content for m in self.messages if m.role in (Role.SYSTEM, Role.USER)])

    @property
    def user_instruction(self) -> str:
        """Contents of all user messages, joined by blank lines."""
        assert self.messages is not None
        return "\n\n".join([m.content for m in self.messages if m.role == Role.USER])

    @property
    def first_user_instruction(self) -> str:
        """Content of the first user message, or ``""`` if there is none."""
        assert self.messages is not None
        user_messages = [m.content for m in self.messages if m.role == Role.USER]
        return user_messages[0] if user_messages else ""

    @property
    def all_but_first_user_instruction(self) -> str:
        """Contents of all user messages after the first, joined by blank lines (``""`` if fewer than two)."""
        assert self.messages is not None
        user_messages = [m.content for m in self.messages if m.role == Role.USER]
        return "\n\n".join(user_messages[1:]) if len(user_messages) > 1 else ""

    @property
    def last_user_instruction(self) -> str:
        """Content of the last user message, or ``""`` if there is none."""
        assert self.messages is not None
        user_messages = [m.content for m in self.messages if m.role == Role.USER]
        return user_messages[-1] if user_messages else ""

    @property
    def sanitized_completion(self) -> str:
        """``completion`` with ``<|xyz|>`` patterns broken up as ``<| xyz |>``."""
        # Make sure the completion doesn't contain any obvious special chars either by "breaking" any <|xyz|> pattern.
        # The quantifier is non-greedy so that adjacent patterns such as "<|a|><|b|><|c|>" are each broken; a greedy
        # ``\S+`` would span from the first "<|" to the last "|>" and leave the inner patterns intact.
        return re.sub(r"<\|(\S+?)\|>", r"<| \1 |>", self.completion)

    def _context_language(self) -> str | None:
        """Return ``language`` of the first ``LanguageMetricContext`` in ``context``, or ``None``.

        ``context`` may be a single context or a list of contexts; both forms are searched.
        """
        contexts = self.context if isinstance(self.context, list) else [self.context]
        return next((ctx.language for ctx in contexts if isinstance(ctx, LanguageMetricContext)), None)

    @staticmethod
    def _detect_language_code(text: str) -> str:
        """Detect the language of ``text``; return the lower-cased ISO 639-1 code name, or ``""`` if undetected."""
        detected_language_object = detect_language_of(text)
        return detected_language_object.iso_code_639_1.name.lower() if detected_language_object else ""

    def get_completion_language(self) -> str:
        """Language of ``completion``: taken from a ``LanguageMetricContext`` if present, otherwise detected."""
        language = self._context_language()
        return language if language is not None else self._detect_language_code(self.completion)

    def get_raw_completion_language(self) -> str:
        """Language of ``raw_completion``: taken from a ``LanguageMetricContext`` if present, otherwise detected."""
        language = self._context_language()
        return language if language is not None else self._detect_language_code(self.raw_completion)

    def get_instruction_language(self) -> str:
        """Language of the user instructions: taken from a ``LanguageMetricContext`` if present, otherwise detected.

        ``user_instruction`` (which requires ``messages``) is only evaluated when detection is needed.
        """
        language = self._context_language()
        return language if language is not None else self._detect_language_code(self.user_instruction)
181
+
182
+
183
class BaseLoglikelihood(BaseModel):
    """Fields shared by raw and processed loglikelihood results."""

    model_config = ConfigDict(extra="forbid")
    prompt: str
    prompt_sequence_positions: int | None
    loglikelihoods: dict[str, float]  # presumably keyed by candidate completion text — confirm with producers
    loglikelihoods_sequence_positions: dict[str, int]  # Is empty if the model does not provide sequence positions
    concat_compression: Annotated[ConcatCompression | None, "Compression info for the concat formatter"] = None
190
+
191
+
192
class RawLoglikelihood(BaseLoglikelihood):
    """Loglikelihood data as produced by the model call, plus an optional error."""

    raw_loglikelihood_error: Error | None = None
194
+
195
+
196
class Loglikelihood(BaseLoglikelihood):
    """Loglikelihood result for one sample, enriched with id, subject, ground truth and an optional error."""

    id: int
    subject: str
    ground_truth: str | list[str]
    error: Error | None = None

    @property
    def ground_truth_list(self) -> list[str]:
        """``ground_truth`` as a list; a single string becomes a one-element list."""
        if isinstance(self.ground_truth, list):
            return self.ground_truth
        return [self.ground_truth]
207
+
208
+
209
# TypeVar bounded to BaseMetricContext. NOTE(review): `extract_context_metric` declares its own PEP 695 type
# parameter with the same name, which shadows this module-level TypeVar inside that function; this one is
# presumably kept for importers — verify before removing.
MetricContext = TypeVar("MetricContext", bound=BaseMetricContext)
210
+
211
+
212
+ def extract_context_metric[MetricContext: BaseMetricContext](
213
+ response: Completion, metric_context_class: type[MetricContext]
214
+ ) -> MetricContext:
215
+ assert response.context is not None, "Expected context to be provided in the response"
216
+ if not isinstance(response.context, list):
217
+ assert isinstance(response.context, metric_context_class) or isinstance(response.context, BaseMetricContext), (
218
+ f"Expected context to be of type {metric_context_class.__name__}, got {type(response.context).__name__}"
219
+ )
220
+ return cast(MetricContext, response.context)
221
+ else:
222
+ assert len(response.context) > 0, "Expected context to be provided in the response"
223
+ context = [
224
+ metric_context for metric_context in response.context if isinstance(metric_context, metric_context_class)
225
+ ][0]
226
+ assert context is not None, f"Expected {metric_context_class.__name__} to be provided in the response context"
227
+ return cast(MetricContext, context)
@@ -0,0 +1,6 @@
1
# Register all tasks on import: importing this package runs `register_all_tasks()` as a side effect.
from .task_names import register_all_tasks

register_all_tasks()

# Remove the helper from the package namespace so it is not re-exported from `eval_framework.tasks`.
del register_all_tasks
@@ -0,0 +1,314 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Iterable
6
+ from enum import Enum
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Any, Self, TypeVar
9
+
10
+ import iso639
11
+ from datasets import DatasetDict, DownloadConfig, load_dataset
12
+ from huggingface_hub import HfApi
13
+ from huggingface_hub.errors import RevisionNotFoundError
14
+ from pydantic import BaseModel, ConfigDict
15
+
16
+ from eval_framework.shared.types import BaseMetricContext
17
+ from template_formatting.formatter import Message, Role
18
+
19
+ if TYPE_CHECKING:
20
+ from eval_framework.metrics.base import BaseMetric
21
+
22
# Seed of the per-task random generator that shuffles the sample split and draws few-shot examples.
RANDOM_SEED = 42
# Placeholder subject for datasets without named configurations; loaded with `name=None`.
NO_SUBJECT = "no_subject"
24
+
25
+
26
class ResponseType(Enum):
    """Kind of model response a task is evaluated on."""

    COMPLETION = "completion"  # generated text
    LOGLIKELIHOODS = "loglikelihoods"  # loglikelihoods, presumably of the task's possible completions
29
+
30
+
31
class Language(Enum):
    """Languages tested by benchmarks.

    The hand-written members below form the base set; ``add_members`` builds an extended copy.
    """

    ENG = "English"
    DEU = "German"
    FRA = "French"
    ITA = "Italian"
    SPA = "Spanish"
    POR = "Portuguese"
    NLD = "Dutch"
    FIN = "Finnish"
    SWE = "Swedish"
    ARB = "Arabic"
    POL = "Polish"
    RUS = "Russian"
    UKR = "Ukrainian"
    HRV = "Croatian"
    SRP = "Serbian"

    @classmethod
    def add_members(cls, new_members: dict[str, Any]) -> type["Language"]:
        """Return a new enum with the same name, containing this enum's members plus ``new_members``.

        Existing members keep their position and value; entries of ``new_members`` whose name already
        exists are ignored. New members are appended in ``new_members`` order. ``cls`` is not modified.
        """
        combined = {member.name: member.value for member in cls}
        combined.update({key: val for key, val in new_members.items() if key not in combined})
        return Enum(cls.__name__, combined)  # type: ignore[return-value]
55
+
56
+
57
# Collect every ISO 639 language from `iso639.ALL_LANGUAGES`, keyed by its upper-cased part-3 code.
# The hand-written `Language` members keep their values because `add_members` skips existing names.
# NOTE(review): the loop variables `language` and `enum_name` remain as module-level names.
languages: dict[str, str] = {}
for language in iso639.ALL_LANGUAGES:
    enum_name = language.part3.upper()
    languages[enum_name] = language.name

# Rebind `Language` to the extended enum. NOTE(review): `add_members` creates a new Enum class, so members
# of the original class (if captured before this line) are not the same objects as the rebound members.
Language: type[Enum] = Language.add_members(languages)  # type: ignore[no-redef]
63
+
64
+
65
class Sample(BaseModel):
    """One evaluation sample built from a dataset item by a task (see `BaseTask._create_samples`)."""

    model_config = ConfigDict(extra="forbid")
    id: int  # index of the sample within its subject (BaseTask.iterate_samples restarts it per subject)
    subject: str
    messages: list[Message]  # prompt messages (system prompt, few-shot examples, instruction, optional cue)
    ground_truth: str | list[str] | None
    possible_completions: list[str] | None  # candidate completions, e.g. for loglikelihood tasks; None by default
    context: BaseMetricContext | list[BaseMetricContext] | None = None  # extra data passed on to metrics
73
+
74
+
75
# Module-level TypeVar for task subject types. NOTE(review): `BaseTask` declares its own PEP 695 type parameter
# named `SubjectType`, which shadows this TypeVar inside the class; this one is presumably kept for importers.
SubjectType = TypeVar("SubjectType")

logger = logging.getLogger(__name__)
78
+
79
+
80
+ class BaseTask[SubjectType](ABC):
81
+ NAME: str
82
+ DATASET_PATH: str
83
+ SAMPLE_SPLIT: str
84
+ FEWSHOT_SPLIT: str
85
+ RESPONSE_TYPE: ResponseType
86
+ METRICS: list[type["BaseMetric"]]
87
+ SUBJECTS: list[SubjectType]
88
+ HF_REVISION: str | None = None # tag name, or branch name, or commit hash to ensure reproducibility
89
+
90
+ # Words in _get_instruction_text() not to be perturbed. List of words is case insensitive. No special characters
91
+ # or whitespace should be included.
92
+ PERTURBATION_UNMODIFIABLE_WORDS: list[str] | None
93
+
94
+ # The language (or languages) tested by the benchmark. Accepts a single string, a dictionary specifying
95
+ # language by subtopic, or `None` (for tasks not specific to a single language).
96
+ LANGUAGE: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None
97
+
98
+ def __init__(self, num_fewshot: int = 0) -> None:
99
+ self.num_fewshot = num_fewshot
100
+ self.stop_sequences: list[str] | None = None
101
+ self.max_tokens: int | None = None
102
+
103
+ @classmethod
104
+ def with_overwrite(
105
+ cls, num_fewshot: int, *, custom_subjects: list[str] | None, custom_hf_revision: str | None
106
+ ) -> Self:
107
+ instance = cls(num_fewshot=num_fewshot)
108
+
109
+ # If custom subjects were provided during initialization, they take precedence over the class-level SUBJECTS.
110
+ filtered_subjects = instance._filter_task_subjects(custom_subjects=custom_subjects)
111
+ if filtered_subjects:
112
+ logger.info(f"Setting SUBJECTS to `{filtered_subjects}` for the task {instance.__class__.__name__}")
113
+ instance.SUBJECTS = filtered_subjects # type: ignore[assignment]
114
+
115
+ # If a custom revision was provided during initialization, it takes precedence over the class-level HF_REVISION.
116
+ if custom_hf_revision:
117
+ logger.info(f"Setting HF revision to `{custom_hf_revision}` for the task {instance.__class__.__name__}")
118
+ instance.HF_REVISION = custom_hf_revision
119
+
120
+ return instance
121
+
122
+ def _filter_task_subjects(self, custom_subjects: list[str] | None) -> list[str] | list[tuple] | None:
123
+ """Process custom subjects passed from EvalConfig. Check and returns restricted task subjects if specified."""
124
+ if not custom_subjects:
125
+ return None
126
+
127
+ assert hasattr(self, "SUBJECTS") and len(self.SUBJECTS) > 0
128
+ if isinstance(self.SUBJECTS[0], tuple):
129
+ # subjects are specified as strings but we need tuples
130
+ filters = [tuple(item.strip() for item in subject.split(",")) for subject in custom_subjects]
131
+
132
+ # check if all parts of custom subjects exists (* is a wildcard)
133
+ num_items = len(self.SUBJECTS[0])
134
+ legal_values = [
135
+ set([s[i] for s in self.SUBJECTS if isinstance(s, tuple)] + ["*"]) for i in range(num_items)
136
+ ]
137
+
138
+ for tpl in filters:
139
+ for i, v in enumerate(tpl):
140
+ assert v in legal_values[i], f"Subject part {v} not found in task {self.__class__.__name__}"
141
+
142
+ # filter task subjects. * is a supported wildcard for a specific item in a tuple, e.g. "DE_DE, *"
143
+ chosen_subjects = []
144
+ for subject in self.SUBJECTS:
145
+ subject_tuple = subject if isinstance(subject, tuple) else tuple(str(subject).split(","))
146
+ for filter in filters:
147
+ if all(filter[i] == "*" or filter[i] == subject_tuple[i] for i in range(num_items)):
148
+ chosen_subjects.append(subject_tuple)
149
+ break
150
+ return chosen_subjects # type: ignore[return-value]
151
+ else:
152
+ for cs in custom_subjects:
153
+ assert cs in self.SUBJECTS, f"Subject {cs} not found in task {self.__class__.__name__}"
154
+ return custom_subjects # type: ignore[return-value]
155
+
156
+ def _load_hf_dataset(self, **kwargs: Any) -> Any:
157
+ # Check if the HF_REVISION is valid before loading the dataset
158
+ if self.HF_REVISION:
159
+ try:
160
+ _ = HfApi().dataset_info(repo_id=kwargs["path"], revision=self.HF_REVISION, timeout=100.0)
161
+ except Exception as e:
162
+ if isinstance(e, RevisionNotFoundError):
163
+ raise e
164
+
165
+ cache_dir: str = os.environ.get("HF_DATASET_CACHE_DIR", f"{Path.home()}/.cache/huggingface/datasets")
166
+ download_config = DownloadConfig(cache_dir=cache_dir, max_retries=5)
167
+ try:
168
+ return load_dataset(
169
+ **kwargs,
170
+ revision=self.HF_REVISION,
171
+ trust_remote_code=True,
172
+ cache_dir=cache_dir,
173
+ download_config=download_config,
174
+ )
175
+ except Exception:
176
+ return load_dataset(
177
+ **kwargs,
178
+ revision=self.HF_REVISION,
179
+ trust_remote_code=True,
180
+ cache_dir=f"{Path.home()}/.cache/eval-framework",
181
+ )
182
+
183
+ def _shuffle_splits(self, hf_dataset: DatasetDict) -> dict[str, list[dict[str, Any]]]:
184
+ dataset = {}
185
+ self.rnd = random.Random(RANDOM_SEED)
186
+
187
+ for split, data in hf_dataset.items():
188
+ if split not in [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]:
189
+ continue
190
+
191
+ data_list = list(data)
192
+
193
+ if split == self.SAMPLE_SPLIT:
194
+ self.rnd.shuffle(data_list)
195
+
196
+ dataset[split] = data_list
197
+
198
+ return dataset
199
+
200
+ def _load_dataset(self, subject: SubjectType) -> None:
201
+ name = subject if subject != NO_SUBJECT else None
202
+ hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=name)
203
+ self.dataset = self._shuffle_splits(hf_dataset=hf_dataset)
204
+
205
+ def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
206
+ return completion_text
207
+
208
+ def _get_example_messages(self, item: dict[str, Any]) -> list[Message]:
209
+ fewshot_examples = self._sample_fewshot_examples(item) if self.num_fewshot > 0 else []
210
+
211
+ example_messages = []
212
+ for fewshot_example in fewshot_examples:
213
+ fewshot_example["subject"] = item["subject"]
214
+ example_messages.extend(self._get_instruction_messages(fewshot_example))
215
+ example_messages.append(
216
+ Message(role=Role.ASSISTANT, content=self._get_fewshot_target_text(fewshot_example))
217
+ )
218
+ return example_messages
219
+
220
+ def _get_messages(self, item: dict[str, Any]) -> list[Message]:
221
+ example_messages = self._get_example_messages(item)
222
+ instruction_message = self._get_instruction_messages(item)
223
+ cue_text = self._get_cue_text(item)
224
+ cue_message = [Message(role=Role.ASSISTANT, content=cue_text)] if cue_text else []
225
+ messages = example_messages + instruction_message + cue_message
226
+ if initial_prompt_text := self._get_initial_prompt_text(item):
227
+ first_message = messages[0]
228
+ assert first_message.role == Role.USER
229
+ first_message.content = f"{initial_prompt_text}\n\n{first_message.content}"
230
+
231
+ if system_prompt_text := self._get_system_prompt_text(item):
232
+ return [Message(role=Role.SYSTEM, content=system_prompt_text)] + messages
233
+ return messages
234
+
235
+ def _get_instruction_messages(self, item: dict[str, Any]) -> list[Message]:
236
+ return [Message(role=Role.USER, content=self._get_instruction_text(item))]
237
+
238
+ def iterate_samples(self, num_samples: int | None = None) -> Iterable[Sample]:
239
+ for subject in self.SUBJECTS:
240
+ self._load_dataset(subject)
241
+ assert len(self.dataset[self.SAMPLE_SPLIT]) > 0
242
+ done = False
243
+ index = 0
244
+ for item in self.dataset[self.SAMPLE_SPLIT]:
245
+ if done:
246
+ break
247
+ item["subject"] = subject
248
+ for sample in self._create_samples(item, index, str(subject)):
249
+ yield sample
250
+ index += 1
251
+ if index == num_samples:
252
+ done = True
253
+ break
254
+
255
+ def _create_samples(self, item: dict[str, Any], index: int, subject: str) -> list[Sample]:
256
+ """Creates one or more samples from a single dataset item. Default implementation returns single sample."""
257
+ return [
258
+ Sample(
259
+ id=index,
260
+ subject=str(subject),
261
+ messages=self._get_messages(item),
262
+ ground_truth=self._get_ground_truth(item),
263
+ possible_completions=self._get_possible_completions(item),
264
+ context=self._get_context(item),
265
+ )
266
+ ]
267
+
268
+ def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
269
+ return ""
270
+
271
+ def _get_system_prompt_text(self, item: dict[str, Any]) -> str | None:
272
+ return None
273
+
274
+ @abstractmethod
275
+ def _get_instruction_text(self, item: dict[str, Any]) -> str:
276
+ raise NotImplementedError
277
+
278
+ def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
279
+ target = self._get_ground_truth(item)
280
+ assert target is not None
281
+ assert isinstance(target, str)
282
+ return target
283
+
284
+ @abstractmethod
285
+ def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
286
+ raise NotImplementedError
287
+
288
+ def _get_cue_text(self, item: dict[str, Any]) -> str:
289
+ return ""
290
+
291
+ def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
292
+ return None
293
+
294
+ def _sample_fewshot_examples(self, item: dict[str, Any]) -> list[dict]:
295
+ if self.FEWSHOT_SPLIT == self.SAMPLE_SPLIT:
296
+ fewshot_examples = self.rnd.sample(self.dataset[self.FEWSHOT_SPLIT], self.num_fewshot + 1)
297
+ fewshot_examples = [example for example in fewshot_examples if example != item]
298
+ fewshot_examples = fewshot_examples[: self.num_fewshot]
299
+ return fewshot_examples
300
+ else:
301
+ return self.rnd.sample(self.dataset[self.FEWSHOT_SPLIT], self.num_fewshot)
302
+
303
+ def _get_context(self, item: dict[str, Any]) -> BaseMetricContext | list[BaseMetricContext] | None:
304
+ return None
305
+
306
+ def get_metadata(self) -> dict[str, str | list[str]]:
307
+ return {
308
+ "dataset_path": self.DATASET_PATH,
309
+ "sample_split": self.SAMPLE_SPLIT,
310
+ "fewshot_split": self.FEWSHOT_SPLIT,
311
+ "response_type": self.RESPONSE_TYPE.value,
312
+ "metrics": [m.NAME for m in self.METRICS],
313
+ "subjects": [str(s) for s in self.SUBJECTS],
314
+ }
File without changes
@@ -0,0 +1,46 @@
1
+ from typing import Any
2
+
3
+ from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
4
+ AccuracyLoglikelihood,
5
+ AccuracyNormLoglikelihood,
6
+ )
7
+ from eval_framework.tasks.base import BaseTask, Language, ResponseType
8
+ from eval_framework.tasks.utils import get_n_letters
9
+
10
+
11
class ARC(BaseTask[str]):
    """ARC dataset: https://huggingface.co/datasets/allenai/ai2_arc

    Multiple-choice science questions, scored by loglikelihood of each answer text after an
    "Answer:" cue. The "ARC-Easy" and "ARC-Challenge" configurations are the subjects.
    """

    NAME = "ARC"
    DATASET_PATH = "ai2_arc"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["ARC-Easy", "ARC-Challenge"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", *get_n_letters(5)]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)

        # Five letters, since at least one sample offers five answer options.
        self.keys = get_n_letters(5)
        # Answer keys may be numeric strings ("1", "2", ...) instead of letters; map them onto letters.
        self.num_to_letter = {str(position + 1): letter for position, letter in enumerate(self.keys)}

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        question = item["question"]
        return f"Question: {question}\n"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        answer = self._get_ground_truth(item)
        assert answer is not None
        cue = self._get_cue_text(item)
        return f"{cue}{answer}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        raw_key = item["answerKey"]
        letter = self.num_to_letter.get(raw_key, raw_key)
        answer_text = item["choices"]["text"][self.keys.index(letter)]
        return f" {answer_text}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        # Leading space so each choice follows the "Answer:" cue naturally.
        return list(map(" {}".format, item["choices"]["text"]))
@@ -0,0 +1,46 @@
1
+ from typing import Any
2
+
3
+ from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
4
+ AccuracyLoglikelihood,
5
+ AccuracyNormLoglikelihood,
6
+ )
7
+ from eval_framework.tasks.base import NO_SUBJECT, BaseTask, Language, ResponseType
8
+ from eval_framework.tasks.utils import get_n_letters
9
+
10
+
11
class ARC_DE(BaseTask[str]):
    """ARC-DE dataset: https://huggingface.co/datasets/LeoLM/ArcChallenge_de

    German translation of ARC-Challenge. Prompts use the ``question_de`` and ``choices_de`` fields.
    """

    NAME = "ARC German"
    DATASET_PATH = "LeoLM/ArcChallenge_de"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = [NO_SUBJECT]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Frage"] + get_n_letters(5)
    LANGUAGE = Language.DEU

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)

        self.keys = get_n_letters(5)  # needs to be 5 because there is one sample with 5 answer possibilities
        # Maps numeric answer keys "1".."5" onto the corresponding entry of self.keys.
        self.num_to_letter = {str(i): letter for i, letter in enumerate(self.keys, start=1)}

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return f"Frage: {item['question_de']}\n"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        ground_truth = self._get_ground_truth(item)
        assert ground_truth is not None
        return f"{self._get_cue_text(item)}{ground_truth}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Antwort:"

    def _correct_choice_index(self, item: dict[str, Any]) -> int:
        """Return the position of the correct answer within ``item["choices_de"]["text"]``.

        Uses ``choices_de["label"]`` when present and containing the answer key. Otherwise the
        key is mapped positionally onto ``self.keys``, as before.
        NOTE(review): assumes ``choices_de`` keeps the same label order as the source item; confirm
        against the dataset schema.

        Raises:
            ValueError: if the answer key matches neither the labels nor ``self.keys``.
        """
        answer_key = item["answerKey"]
        labels = item["choices_de"].get("label")
        if labels and answer_key in labels:
            return list(labels).index(answer_key)
        return self.keys.index(self.num_to_letter.get(answer_key, answer_key))

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        # Leading space matches the formatting of _get_possible_completions.
        return f" {item['choices_de']['text'][self._correct_choice_index(item)]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        return [f" {choice}" for choice in item["choices_de"]["text"]]
@@ -0,0 +1,46 @@
1
+ from typing import Any
2
+
3
+ from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
4
+ AccuracyLoglikelihood,
5
+ AccuracyNormLoglikelihood,
6
+ )
7
+ from eval_framework.tasks.base import BaseTask, Language, ResponseType
8
+ from eval_framework.tasks.utils import get_n_letters
9
+
10
+
11
class ARC_FI(BaseTask[str]):
    """ARC-FI dataset: https://huggingface.co/datasets/LumiOpen/arc_challenge_mt

    Finnish ("fi" subset) machine translation of ARC-Challenge. The prompt scaffolding
    ("Question:" / "Answer:") stays in English.
    """

    NAME = "ARC Finnish"
    DATASET_PATH = "LumiOpen/arc_challenge_mt"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["fi"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question"] + get_n_letters(5)
    LANGUAGE = Language.FIN

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)

        self.keys = get_n_letters(5)  # needs to be 5 because there is one sample with 5 answer possibilities
        # Maps numeric answer keys "1".."5" onto the corresponding entry of self.keys.
        self.num_to_letter = {str(i): letter for i, letter in enumerate(self.keys, start=1)}

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return f"Question: {item['question']}\n"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        ground_truth = self._get_ground_truth(item)
        assert ground_truth is not None
        return f"{self._get_cue_text(item)}{ground_truth}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _correct_choice_index(self, item: dict[str, Any]) -> int:
        """Return the position of the correct answer within ``item["choices"]["text"]``.

        Uses the item's ``choices["label"]`` list when it contains the answer key. Otherwise the
        key is mapped positionally onto ``self.keys``, as before.

        Raises:
            ValueError: if the answer key matches neither the labels nor ``self.keys``.
        """
        answer_key = item["answerKey"]
        labels = item["choices"].get("label")
        if labels and answer_key in labels:
            return list(labels).index(answer_key)
        return self.keys.index(self.num_to_letter.get(answer_key, answer_key))

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        # Leading space matches the formatting of _get_possible_completions.
        return f" {item['choices']['text'][self._correct_choice_index(item)]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        return [f" {choice}" for choice in item["choices"]["text"]]