eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,580 @@
1
+ import random
2
+ import re
3
+ from typing import Any
4
+
5
+ from eval_framework.metrics.completion.accuracy_completion import AccuracyCompletion
6
+ from eval_framework.metrics.completion.language_checker import LanguageRawConsistencyChecker
7
+ from eval_framework.metrics.completion.math_reasoning_completion import MathReasoningCompletion
8
+ from eval_framework.tasks.base import NO_SUBJECT, RANDOM_SEED, BaseTask, Language, ResponseType, Sample, SubjectType
9
+
10
+
11
class MATHReasoning(BaseTask[str]):
    """Base class for math reasoning completion tasks scored on a single extracted final answer.

    Subclasses in this module (AIME2024, MATH500, MATH, MATHLvl5, GSM8KReasoning) provide the
    dataset, the prompt and ``_get_ground_truth``. This class provides:

    * ``_extract_answer``: the argument of the last ``\\boxed{...}`` / ``\\fbox{...}`` in a text, or
      group 1 of a regex match;
    * ``_strip_string`` and its helpers: answer normalization adapted from the EleutherAI
      lm-evaluation-harness ``hendrycks_math`` task.

    pass@1 evaluation
    """

    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [MathReasoningCompletion]
    SUBJECTS = [NO_SUBJECT]
    # Fallback for an "Answer: ..." line (case-insensitive); group 1 is the answer text.
    ANSWER_PATTERN = r"(?i)Answer\s*:\s*(.*)"
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Max tokens are going to be determined by the model.
        # however GPT paper and results used 1024 tokens, s1 used 2048

    def _extract_answer(
        self, string: str, extract_from_boxed: bool = True, extract_regex: str = ANSWER_PATTERN
    ) -> str | None:
        """Extract the final answer from ``string``.

        :param string: A model completion or a reference solution.
        :param extract_from_boxed: If true, return the argument of the last ``\\boxed{...}`` or
            ``\\fbox{...}`` in ``string``; otherwise use ``extract_regex``.
        :param extract_regex: Pattern used when ``extract_from_boxed`` is false; group 1 of its
            first match is returned.
        :return: The extracted text, or ``None`` if there is no box command, the last one is not
            immediately followed by ``{``, its braces are unbalanced, or the regex does not match.
        """
        if not extract_from_boxed:
            match = re.search(extract_regex, string)
            return match.group(1) if match else None

        # Use whichever box command occurs last in the text.
        start = max(string.rfind("\\boxed"), string.rfind("\\fbox"))
        if start == -1:
            return None

        command = "\\boxed" if string.startswith("\\boxed", start) else "\\fbox"
        open_idx = start + len(command)
        # Only the braced form is accepted; e.g. "\boxed 5" or "\boxed {5}" yield None.
        if open_idx >= len(string) or string[open_idx] != "{":
            return None

        close_idx = self._find_closing_bracket(string, open_idx)
        if close_idx == -1:
            return None
        return string[open_idx + 1 : close_idx]

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Return the normalized last boxed answer of ``completion_text``, or ``"[no_answer]"``."""
        assert isinstance(completion_text, str)
        extracted_answer = self._extract_answer(completion_text)
        if extracted_answer is None:
            return "[no_answer]"
        return self._strip_string(extracted_answer)

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
        raise NotImplementedError("This method should be implemented in subclasses")

    # The following code is coming from the Eleuther AI lm-evaluation-harness repository
    # Subject to MIT License

    # This needs a major refactoring but is kept as is for consistency with the original code

    def _find_closing_bracket(self, string: str, start_index: int) -> int:
        """
        Finds the index of the closing '}' for a '{' at the given start index.

        :param string: The input string containing '{' and '}' brackets.
        :param start_index: The index where the opening '{' is located.
        :return: The index of the corresponding closing '}' or -1 if not found.
        :raises ValueError: If ``start_index`` does not point to a '{'.
        """
        if start_index < 0 or start_index >= len(string) or string[start_index] != "{":
            raise ValueError("The start_index must point to a '{' character.")

        depth = 0  # Track the nesting level of brackets
        for i in range(start_index, len(string)):
            if string[i] == "{":
                depth += 1  # Increase depth for each opening bracket
            elif string[i] == "}":
                depth -= 1  # Decrease depth for each closing bracket
                if depth == 0:
                    return i  # Found the matching closing bracket

        return -1  # No matching '}' found

    def _split_text_command(self, string: str, search: str = r"\text{") -> tuple[str, str, str]:
        """
        Split ``string`` around its first ``search`` command (default ``\\text{``).

        :param string: The input LaTeX string.
        :param search: The command to search for, including its opening brace.
        :return: Tuple ``(before_text, inside_text, after_text)``: everything before the command,
            the content inside its braces, and everything after the matching closing ``}``.
            If ``search`` is not found (or does not end at a ``{``), returns ``(string, "", "")``.
            If no closing bracket is found, returns ``(before_text, remaining_string, "")``.
        """
        search_len = len(search)
        search_start = string.find(search)

        # If the command is not found, return the entire string in `before_text`
        if search_start == -1:
            return string, "", ""

        # Ensure `{` follows the search term
        content_start = search_start + search_len - 1
        if content_start >= len(string) or string[content_start] != "{":
            return string, "", ""

        # Find the corresponding closing bracket
        closing_index = self._find_closing_bracket(string, start_index=content_start)

        # If no closing bracket is found, return remaining string as "inside_text"
        if closing_index == -1:
            return string[:search_start], string[content_start + 1 :], ""

        before_text = string[:search_start]  # Everything before the command
        inside_text = string[content_start + 1 : closing_index]  # Content inside the braces
        after_text = string[closing_index + 1 :]  # Everything after the closing `}`

        return before_text, inside_text, after_text

    # https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py#L144
    def _remove_right_units(self, string: str) -> str:
        """Drop ``\\text{...}`` unit annotations from an answer.

        * no ``\\text{``: returned unchanged;
        * several: everything before the first ``\\text{`` (not whitespace-stripped);
        * exactly one: the stripped text before it, else the stripped text after its closing
          brace, else the stripped content inside it.
        """
        # "\text{ " only ever occurs (at least in the val set) when describing units
        count = string.count(r"\text{")
        if count == 0:
            return string
        if count > 1:
            return string.split(r"\text{", maxsplit=1)[0]

        before, inside, after = self._split_text_command(string)
        if before.strip():
            return before.strip()
        if after.strip():
            return after.strip()
        return inside.strip()

    # Based on https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py#L154
    def _fix_sqrt(self, string: str) -> str:
        """Brace single-character ``\\sqrt`` arguments, e.g. ``\\sqrt3`` -> ``\\sqrt{3}``.

        A ``\\sqrt`` with nothing after it (end of string or directly followed by another
        ``\\sqrt``) is kept as is instead of raising ``IndexError``.
        """
        if "\\sqrt" not in string:
            return string
        head, *tails = string.split("\\sqrt")
        pieces = [head]
        for tail in tails:
            if not tail or tail[0] == "{":
                pieces.append("\\sqrt" + tail)
            else:
                pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        return "".join(pieces)

    # Based on https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py#L97
    def _fix_fracs(self, string: str) -> str:
        """Brace single-character ``\\frac`` arguments.

        ``\\frac12`` -> ``\\frac{1}{2}``, ``\\frac1b`` -> ``\\frac{1}{b}``, ``\\frac1{72}`` ->
        ``\\frac{1}{72}``. Already-braced numerators are left alone. If some ``\\frac`` is followed
        by exactly one non-brace character, the original string is returned unchanged.
        """
        parts = string.split("\\frac")
        if len(parts) <= 1:
            return string
        new_str = parts[0]
        for part in parts[1:]:
            new_str += "\\frac"
            if not part:
                continue
            if part[0] == "{":
                new_str += part
                continue
            if len(part) < 2:
                return string
            a, b = part[0], part[1]
            new_str += "{" + a + "}{"
            if b != "{":
                new_str += b + "}"
            new_str += part[2:]
        return new_str

    # Based on https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py#L129
    def _fix_a_slash_b(self, string: str) -> str:
        """Rewrite a plain integer fraction ``a/b`` as ``\\frac{a}{b}``; anything else is unchanged."""
        parts = string.split("/")
        if len(parts) != 2:
            return string
        a, b = parts
        try:
            a_int = int(a)
            b_int = int(b)
        except ValueError:
            return string
        # Only canonical spellings are rewritten (e.g. not " 1/2" or "01/2").
        if string != f"{a_int}/{b_int}":
            return string
        return "\\frac{" + str(a_int) + "}{" + str(b_int) + "}"

    # Based on https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py#L169
    def _strip_string(self, string: str) -> str:
        """Normalize a LaTeX answer so that equivalent spellings compare equal.

        Follows ``strip_string`` from lm-evaluation-harness ``hendrycks_math``. It:

        * drops line breaks, ``\\!``, ``\\left``/``\\right``, degree marks, ``\\$``, ``\\%``,
          right-hand ``\\text{...}`` units, a short leading ``var =`` and all spaces;
        * canonicalizes ``\\sqrt``/``\\frac`` arguments, ``0.5`` and integer ``a/b``;
        * removes leading zeros from integer parts.
        """
        # Plain (non-raw) literals: each pattern holds a single backslash, the way LaTeX commands
        # appear in the extracted text (matching the "\\boxed"/"\\sqrt"/"\\frac" searches above).
        replacements = [
            ("\n", ""),  # linebreaks
            ("\\!", ""),  # remove inverse spaces
            ("\\\\", "\\"),  # replace \\ with \
            ("tfrac", "frac"),  # replace tfrac with frac
            ("dfrac", "frac"),  # replace dfrac with frac
            ("\\left", ""),  # remove \left
            ("\\right", ""),  # remove \right
            ("^{\\circ}", ""),  # remove circ
            ("^\\circ", ""),  # remove circ
            ("\\$", ""),  # remove \$
        ]
        for pattern, replacement in replacements:
            string = string.replace(pattern, replacement)

        # remove units (on the right)
        string = self._remove_right_units(string)

        replacements = [
            ("\\%", ""),  # remove percentage
            (" .", " 0."),  # " 0." equivalent to " ."
            ("{.", "{0."),  # "{0." equivalent to "{."
        ]
        for pattern, replacement in replacements:
            string = string.replace(pattern, replacement)

        # if empty, return empty string
        if not string:
            return string
        # Add "0" if "." is the start of the string
        if string[0] == ".":
            string = "0" + string
        # Get rid of e.g. "k = " or "x = y = " at beginning
        parts = [s.strip() for s in string.split("=")]
        if len(parts) == 2 and len(parts[0]) <= 2:
            string = parts[1]
        elif len(parts) > 2:
            if all(len(part) <= 2 and re.match(r"^[a-zA-Z]\w*$", part) for part in parts[:-1]):  # noqa: W605
                string = parts[-1]

        # fix sqrt3 --> sqrt{3}
        string = self._fix_sqrt(string)

        # remove spaces
        string = string.replace(" ", "")

        # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2},
        # etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
        string = self._fix_fracs(string)

        # manually change 0.5 --> \frac{1}{2}
        if string == "0.5":
            string = "\\frac{1}{2}"

        # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
        string = self._fix_a_slash_b(string)

        # Remove leading zeros of integer parts ("007" -> "7", "00.5" -> "0.5"). Zeros that follow a
        # word character, a decimal point, a digit-group comma or a closing brace are kept, so
        # "1.05", "10,000" and "x_05" stay unchanged.
        string = re.sub(r"(?<![\w.,}])0+(?=\d)", "", string)

        return string
300
+
301
+
302
class AIME2024(MATHReasoning):
    """AIME 2024 dataset: https://huggingface.co/datasets/HuggingFaceH4/aime_2024

    This dataset contains a single train split of 30 questions.
    Data contains
    ID | Problem | Solution | Answer

    Zero-shot only; pass@1 evaluation on the boxed answer of the completion.
    """

    NAME = "AIME2024"
    DATASET_PATH = "HuggingFaceH4/aime_2024"
    SAMPLE_SPLIT = "train"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [MathReasoningCompletion, LanguageRawConsistencyChecker]
    SUBJECTS = [NO_SUBJECT]
    LANGUAGE = Language.ENG

    # https://github.com/NVIDIA/NeMo-Skills/blob/main/nemo_skills/prompt/config/llama3-instruct/math.yaml
    QUERY_TEMPLATE = """Solve the following math problem efficiently and clearly:

- For simple problems (2 steps or fewer):
Provide a concise solution with minimal explanation.

- For complex problems (3 steps or more):
Use this step-by-step format:

## Step 1: [Concise description]
[Brief explanation and calculations]

## Step 2: [Concise description]
[Brief explanation and calculations]

...

Regardless of the approach, always conclude with:

Therefore, the final answer is: $\\boxed{{answer}}$. I hope it is correct.

Where [answer] is just the final number or expression that solves the problem.

Problem: {Question}"""  # noqa: E501
    # Matches the closing sentence requested by QUERY_TEMPLATE; group 1 is the text in between.
    # The periods are escaped so that they only match a literal ".".
    ANSWER_PATTERN = r"Therefore, the final answer is:(.*?)\. I hope it is correct\."

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "AIME evaluation does not include few shot"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return self.QUERY_TEMPLATE.format(Question=item["problem"])

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
        """Return the reference answer without zero padding ("025" -> "25", "000" -> "0")."""
        # Valid answers in this dataset range from 0-999 and may have leading zeros. Keep a single
        # "0" for a zero answer instead of producing the empty string.
        return item["answer"].lstrip("0") or "0"
356
+
357
+
358
class MATH500(MATHReasoning):
    """MATH-500 benchmark: https://huggingface.co/datasets/HuggingFaceH4/MATH-500

    A single ``test`` split of 500 problems with fields ID | Problem | Solution | Answer.
    Zero-shot only; pass@1 evaluation on the normalized extracted answer.
    """

    NAME = "MATH500"
    DATASET_PATH = "HuggingFaceH4/MATH-500"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "test"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [MathReasoningCompletion, LanguageRawConsistencyChecker]
    SUBJECTS = [NO_SUBJECT]
    LANGUAGE = Language.ENG

    # Adapted from OpenAI's math_eval.py (c) 2024 OpenAI – MIT License – https://github.com/openai/simple-evals/blob/main/math_eval.py
    QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{Question}

Remember to put your answer in $\\boxed{{answer}}$

where [answer] is just the final number or expression that solves the problem.
""".strip()  # noqa: E501

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "MATH-500 evaluation does not include few shot"
        super().__init__(num_fewshot)

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Extract and normalize the model's answer.

        The last boxed expression is preferred; when there is none, group 1 of the first
        ``ANSWER_PATTERN`` ("Answer: ...") match is used. Returns ``"[no_answer]"`` if neither
        yields a value.
        """
        candidate = self._extract_answer(completion_text)
        if candidate is None:
            candidate = self._extract_answer(
                completion_text, extract_from_boxed=False, extract_regex=self.ANSWER_PATTERN
            )
        if candidate is None:
            return "[no_answer]"
        return self._strip_string(candidate)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return self.QUERY_TEMPLATE.format(Question=item["problem"])

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
        # The dataset provides the final answer directly in the "answer" field.
        return item["answer"]
411
+
412
+
413
class MATH(MATHReasoning):
    """Hendrycks MATH benchmark: https://huggingface.co/datasets/EleutherAI/hendrycks_math

    One subject per problem category. Samples come from the ``test`` split and few-shot examples
    from ``train``. The reference answer is the last boxed expression of the worked solution.
    """

    NAME = "Math"
    DATASET_PATH = "EleutherAI/hendrycks_math"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [MathReasoningCompletion, LanguageRawConsistencyChecker]
    SUBJECTS = [
        "algebra",
        "counting_and_probability",
        "geometry",
        "intermediate_algebra",
        "number_theory",
        "prealgebra",
        "precalculus",
    ]
    LANGUAGE = Language.ENG

    # Adapted from OpenAI's math_eval.py (c) 2024 OpenAI – MIT License – https://github.com/openai/simple-evals/blob/main/math_eval.py
    QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{Question}

Remember to put your answer in $\\boxed{{answer}}$

where [answer] is just the final number or expression that solves the problem.
""".strip()  # noqa: E501

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Presumably cuts generation when the model starts writing another "Problem";
        # NOTE(review): confirm how BaseTask / the runner consumes stop_sequences.
        self.stop_sequences = ["\nProblem:", "\nProblem", "\n\nProblem:", "\n\nProblem"]

    def extract_last_two_dollar_text(self, s: str) -> str:
        """Return the content of the last ``$...$`` pair in ``s``, or ``""`` if there is none.

        Pairs are matched left to right without overlap and cannot span a newline, so for
        ``"$a$ b $c"`` the result is ``"a"``.
        """
        spans = re.findall(r"\$(.*?)\$", s)
        return spans[-1] if spans else ""

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Extract and normalize the model's answer, trying three candidates in order.

        1. The last boxed expression in the whole completion.
        2. A boxed expression inside the last ``$...$`` span. NOTE(review): this is only non-empty
           when that span itself contains a box command, so it rarely differs from (1); the raw
           ``$...$`` content is never used on its own.
        3. Group 1 of the first ``ANSWER_PATTERN`` ("Answer: ...") match.

        Empty or missing candidates are skipped; ``"[no_answer]"`` is returned if all are.
        """
        candidates = (
            self._extract_answer(completion_text),
            self._extract_answer(self.extract_last_two_dollar_text(completion_text)),
            self._extract_answer(completion_text, extract_from_boxed=False, extract_regex=self.ANSWER_PATTERN),
        )
        for candidate in candidates:
            if candidate:
                return self._strip_string(candidate)
        return "[no_answer]"

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return self.QUERY_TEMPLATE.format(Question=item["problem"]) + "\n"

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        # Few-shot targets show the full worked solution after "Answer: ".
        return f"Answer: {item['solution']}"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
        # Last boxed expression of the reference solution (not normalized here).
        return self._extract_answer(item["solution"])
490
+
491
+
492
class MATHLvl5(MATH):
    """MATH restricted to items whose ``level`` field is ``"Level 5"``.

    Both the sample and the few-shot split are filtered. The sample split is shuffled with a
    ``RANDOM_SEED``-seeded RNG before filtering.
    """

    NAME = "Math Lvl 5"

    def _load_dataset(self, subject: SubjectType) -> None:
        config_name = None if subject == NO_SUBJECT else subject

        hf_dataset = self._load_hf_dataset(path=self.DATASET_PATH, name=config_name)
        self.dataset = {}

        self.rnd = random.Random(RANDOM_SEED)
        kept_splits = [self.SAMPLE_SPLIT, self.FEWSHOT_SPLIT]

        for split_name, split_data in hf_dataset.items():
            rows = list(split_data)

            # Only the sample split is shuffled; this is the only consumer of self.rnd here.
            if split_name == self.SAMPLE_SPLIT:
                self.rnd.shuffle(rows)

            if split_name not in kept_splits:
                continue
            self.dataset[split_name] = [row for row in rows if row["level"] == "Level 5"]

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None | list[str]:
        # Last boxed expression of the reference solution, as in MATH.
        return self._extract_answer(item["solution"])
514
+
515
+
516
class GSM8KReasoning(MATHReasoning):
    """GSM8K with a step-by-step reasoning prompt: https://huggingface.co/datasets/openai/gsm8k

    Zero-shot only. The prediction is the last boxed expression of the completion (commas and
    surrounding whitespace removed). If there is none, a GSM8K-style ``#### <number>`` line is
    used. The reference answer is the number after ``####`` in the dataset's answer field.
    """

    NAME = "GSM8KReasoning"
    DATASET_PATH = "gsm8k"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [AccuracyCompletion, LanguageRawConsistencyChecker]
    SUBJECTS = ["main"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Question", "Answer"]
    LANGUAGE = Language.ENG

    # Reasoning prompt template that encourages step-by-step thinking with boxed answers
    QUERY_TEMPLATE = """\
Solve the following math problem step by step. Think through the problem carefully and show your reasoning.

Please provide your answer in the format: $\\boxed{{answer}}$ where answer is the final numerical result.

Question: {question}

Answer:"""

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "GSM8K Reasoning is designed for zero-shot evaluation only"
        super().__init__(num_fewshot)
        self.stop_sequences: list[str] = []

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Cut the completion at each stop sequence, then extract the answer."""
        truncated = completion_text
        for stop in self.stop_sequences:
            truncated = truncated.partition(stop)[0]
        return self._extract_answer_with_fallback(truncated)

    def _extract_answer_fallback(self, completion: str) -> str:
        """Return the number after ``#### `` with commas removed, or ``"[invalid]"`` if absent."""
        found = re.search(r"#### (\-?[0-9\.\,]+)", completion)
        if found is None:
            return "[invalid]"
        return found.group(1).strip().replace(",", "")

    def _extract_answer_with_fallback(self, completion: str) -> str:
        """Prefer the last boxed expression (commas/whitespace removed); else use ``####``."""
        boxed = self._extract_answer(completion)
        if boxed is None:
            return self._extract_answer_fallback(completion)
        return boxed.replace(",", "").strip()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return self.QUERY_TEMPLATE.format(question=item["question"])

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        return self._extract_answer_fallback(item["answer"])