eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import signal
|
|
3
|
+
from collections.abc import Callable, Iterable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sympy import Basic, S, SympifyError, factor, simplify
|
|
7
|
+
from sympy.parsing.latex import parse_latex
|
|
8
|
+
from sympy.parsing.latex.errors import LaTeXParsingError
|
|
9
|
+
|
|
10
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
11
|
+
from eval_framework.shared.types import Completion
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def timeout_handler(signum: Any, frame: Any) -> None:
    """SIGALRM handler: abort whatever is running by raising ``TimeoutError``.

    Both arguments are the standard (signal number, stack frame) pair passed to signal
    handlers; they are ignored.
    """
    raise TimeoutError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MathReasoningCompletion(BaseMetric[Completion]):
    #
    # Math Reasoning Completion (symbolic)
    #
    # This metric evaluates the correctness of the completion of a math reasoning task without
    # correcting LaTeX expressions. Normalization happens on the strings only, to remove
    # formatting and units. The normalized completion is first compared to the ground truth(s)
    # symbolically with SymPy (parse, factor, simplify); if that does not succeed, it falls back
    # to a plain string comparison.
    #
    # NOTE: the SymPy steps are time-limited with SIGALRM, which is only available on POSIX
    # systems and whose handler may only be installed from the main thread.
    #

    NAME = "Math Reasoning Completion (symbolic)"

    # Substitutions to apply to the final answer, as (regex pattern, replacement template).
    # Replacement templates go through re.sub escape processing, so a literal backslash must be
    # written as r"\\" -- a plain "\\text" would turn "\t" into a TAB character.
    SUBSTITUTIONS = [
        (r"\ban\b(?!\w)", ""),  # Remove "an" if not part of a word
        (r"\ba\b(?!\w)", ""),  # Remove "a" if not part of a word
        (r"\.\$", "$"),  # Replace ".$" with "$"
        (r"\\\$", ""),  # Remove "\$"
        (r"\\ ", ""),  # Remove "\ " (escaped space)
        (r"\s+", ""),  # Remove all spaces
        (r"\\mbox", r"\\text"),  # Replace "\mbox" with "\text" (keeps the backslash)
        (r",\\text\{and\}", ","),  # Replace ",\text{and}" with ","
        (r"\\text\{and\}", ","),  # Replace "\text{and}" with ","
        (r"\\text\{m\}", r"\\text{}"),  # Replace "\text{m}" with "\text{}"
    ]

    # Expressions to remove from the end of the final answer.
    # Most of these are units or counted nouns which the ground truth does not have.
    REMOVED_EXPRESSIONS_UNITS = [
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
    ]

    # Formatting fragments removed anywhere in the final answer (plain substring removal).
    REMOVED_EXPRESSIONS_FORMAT = [
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
    ]

    def normalize_expression(self, final_answer: str) -> str:
        """
        Function to normalize LaTeX expressions.

        Applies SUBSTITUTIONS, strips a trailing unit from REMOVED_EXPRESSIONS_UNITS, removes
        REMOVED_EXPRESSIONS_FORMAT fragments, unwraps \\text / \\textbf / \\overline / \\boxed,
        adds braces to shorthand \\frac / \\sqrt arguments, drops "$" and, for a single
        comma-grouped integer such as "1,000", the thousands separators.
        :param final_answer: raw LaTeX expression
        :return: normalized LaTeX expression
        NOTE: Units are only removed at the end of the string, because substituting them anywhere
        corrupted unrelated tokens, i.e., turned "infty" into "iny" by removing "ft".
        """
        for before, after in self.SUBSTITUTIONS:
            final_answer = re.sub(before, after, final_answer)
        for expr in self.REMOVED_EXPRESSIONS_UNITS:
            # Safely remove units at the end, allowing optional space before the unit
            final_answer = re.sub(rf"(.*?)\s*({re.escape(expr)})$", r"\1", final_answer)
        for expr in self.REMOVED_EXPRESSIONS_FORMAT:
            # Safely remove formatting expressions
            final_answer = final_answer.replace(expr, "")
        # keep only the first $...$ span if there is one
        final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", r"$\3$", final_answer)
        final_answer = re.sub(r"(\\text\{)(.*?)(\})", r"\2", final_answer)
        final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", r"\2", final_answer)
        final_answer = re.sub(r"(\\overline\{)(.*?)(\})", r"\2", final_answer)
        # greedy: unwrap up to the last closing brace so nested braces inside \boxed survive
        final_answer = re.sub(r"(\\boxed\{)(.*)(\})", r"\2", final_answer)
        # "\frac12" -> "\frac{1}{2}", "\sqrt2" -> "\sqrt{2}"
        final_answer = re.sub(r"(frac)([^{])(.)", r"frac{\2}{\3}", final_answer)
        final_answer = re.sub(r"(sqrt)([^{])", r"sqrt{\2}", final_answer)
        final_answer = final_answer.replace("$", "")
        # Only strip commas if it's a single numeric value with optional commas (like "1,000")
        if re.fullmatch(r"\d{1,3}(,\d{3})*", final_answer):
            final_answer = final_answer.replace(",", "")
        return final_answer

    def check_for_equation(self, final_answer: str) -> list[str]:
        """
        Check if the final answer is an equation and split it at every "=".
        :param final_answer: the expression to evaluate
        :return: the "="-separated parts (e.g. [lhs, rhs]), or [final_answer] if there is no "="
        """
        if isinstance(final_answer, str) and "=" in final_answer:
            return final_answer.split("=")
        else:
            return [final_answer]

    def _safe_simplify_expression(self, expression: Basic, timeout: int = 10) -> Basic:
        """
        Factor and simplify an expression under a time limit, catching SymPy failures.

        The pending alarm is always cancelled before returning, on success and on failure, so
        no TimeoutError can fire later in unrelated code.
        :param expression: SymPy expression
        :param timeout: Time limit in seconds (default: 10 seconds).
        :return: simplified expression, or S.NaN on SympifyError, recursion overflow or timeout
        """
        signal.signal(signal.SIGALRM, timeout_handler)  # Set timeout signal
        signal.alarm(timeout)  # Set timeout duration
        try:
            factored = factor(expression)
            return simplify(factored)
        except (SympifyError, TimeoutError, RecursionError):
            return S.NaN
        finally:
            signal.alarm(0)  # Disable timeout on every exit path

    def _any_symb_correct(self, response_list: Iterable[Basic], ground_truth_list: Iterable[Basic]) -> bool:
        """
        Check if any response is symbolically equal to any ground truth; return True at first match.

        A pair matches when the absolute value of their simplified difference is below 1e-12.
        Pairs that cannot be compared this way are skipped instead of aborting the whole search.
        Examples are relations/equations, raw-string ground truths, and symbolic or NaN
        differences; for these SymPy/Python raise TypeError or ValueError.
        :param response_list: list of responses
        :param ground_truth_list: list of ground truths
        :return: True if any response is correct
        """
        # iterated once per response, so materialize in case an iterator was passed
        ground_truths = list(ground_truth_list)
        tolerance = 1e-12
        for answer in response_list:
            for ground_truth in ground_truths:
                try:
                    difference = self._safe_simplify_expression(answer - ground_truth)
                    if abs(difference) < tolerance:
                        return True
                except (TypeError, ValueError):
                    # this pair is not numerically comparable; try the next one
                    continue
        return False

    def _apply_safely(self, func: Callable[[Basic], Basic], list_of_expressions: list[Basic]) -> None:
        """
        Apply func to every expression and replace the originals in place.

        A RecursionError for one element replaces that element with S.NaN; any other exception
        propagates to the caller.
        :param func: function to apply to each expression
        :param list_of_expressions: list of sympy expressions (mutated in place)
        """
        for i, expression in enumerate(list_of_expressions):
            try:
                list_of_expressions[i] = func(expression)
            except RecursionError:
                list_of_expressions[i] = S.NaN

    def calculate(self, response: Completion) -> list[MetricResult]:
        """
        Calculate the accuracy of the completion

        performs several verification and simplification steps
        to ensure that the completion is correct

        the completion may either be a latex or string response
        which sympy will parse, factor, and simplify; if the symbolic
        comparison does not succeed, the normalized completion is
        compared to the ground truth(s) as a string

        :param response: Completion object
        :return: list of MetricResult
        """
        ground_truths: list = []
        INVALID_ANSWER = S.NaN
        timeout = 10
        # latex parse all ingested ground truth values for math reasoning
        for gt in response.ground_truth_list:
            signal.signal(signal.SIGALRM, timeout_handler)  # Set timeout signal
            signal.alarm(timeout)  # Set timeout duration
            try:
                # NOTE: parses f(x)=0,\quadf(x)=x-1,\quadf(x)=-x+1 to Eq(f(x), 0) ONLY
                ground_truths.append(parse_latex(gt))
            except Exception:
                # keep the raw value; it can still take part in the string comparison
                ground_truths.append(gt)
            finally:
                signal.alarm(0)  # never leave the alarm armed after a failed parse

        normalized_response = self.normalize_expression(response.completion)
        response_list = self.check_for_equation(normalized_response)
        try:
            symb_is_correct = self._is_symbolically_equiv(response_list, ground_truths, INVALID_ANSWER)
        except Exception:
            symb_is_correct = False

        if symb_is_correct:
            is_correct = True
        elif isinstance(response.ground_truth, str):
            # fall back to string comparison against the single ground truth
            is_correct = self._is_str_correct(normalized_response, response.ground_truth)
        else:
            # ground truth given as a list: accept a string match against any of its entries
            string_truths = [gt for gt in response.ground_truth_list if isinstance(gt, str)]
            is_correct = self._any_str_correct([normalized_response], string_truths)

        return [
            MetricResult(metric_name=self.NAME, value=float(is_correct), higher_is_better=True, error=response.error)
        ]

    def _any_str_correct(self, response_list: list, ground_truths: list) -> bool:
        """
        Check if any of the responses are correct and return true at first match
        :param response_list: list of responses
        :param ground_truths: list of ground truths
        :return: True if any response is correct
        """
        for response in response_list:
            for ground_truth in ground_truths:
                if self._is_str_correct(response, ground_truth):
                    return True
        return False

    def _is_str_correct(self, str1: str, str2: str) -> bool:
        """
        Check if two (already normalized) strings are equal.

        If one side has more "=" signs than the other, its leading "="-separated parts are
        dropped until the counts agree. This accounts for generations such as "b=1" with
        ground truth "x=b=1". None is tolerated: two Nones are equal, one None is never equal.
        :param str1: first string (the response)
        :param str2: second string (the ground truth)
        :return: True if the strings are equal
        """
        # None checks must come first: the "=" counting below needs real strings
        if str1 is None and str2 is None:
            return True
        if str1 is None or str2 is None:
            return False
        if str1.count("=") < str2.count("="):
            return self._is_str_correct(str1, str2[str2.index("=") + 1 :])
        if str1.count("=") > str2.count("="):
            return self._is_str_correct(str1[str1.index("=") + 1 :], str2)
        return str1 == str2

    def _is_symbolically_equiv(
        self, response_list: list[str], ground_truths: list, default_invalid: Basic = S.NaN
    ) -> bool:
        """
        Check if any of the responses is symbolically equal to any of the ground truths.

        NOTE: parses/simplifies response_list and ground_truths in place.
        :param response_list: list of responses (LaTeX strings)
        :param ground_truths: list of ground truths (SymPy expressions or raw values)
        :param default_invalid: unused; kept for backward compatibility of the signature
        :return: True if any response matches any ground truth
        """
        try:
            self._apply_safely(parse_latex, response_list)
        except (LaTeXParsingError, SympifyError, TypeError):
            # an unparsable response cannot be symbolically equivalent
            return False

        # map objects dont catch errors, so we use safe apply here
        self._apply_safely(self._safe_simplify_expression, ground_truths)
        self._apply_safely(self._safe_simplify_expression, response_list)

        # check if any of the simplified responses match any of the simplified ground truths
        try:
            return self._any_symb_correct(response_list, ground_truths)
        except ValueError:
            return False
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import unicodedata
|
|
3
|
+
|
|
4
|
+
from eval_framework.metrics.base import (
|
|
5
|
+
BaseMetric,
|
|
6
|
+
MetricResult,
|
|
7
|
+
)
|
|
8
|
+
from eval_framework.shared.types import Completion, Error, LanguageMetricContext, extract_context_metric
|
|
9
|
+
|
|
10
|
+
# Dictionary of "none" words in different languages
|
|
11
|
+
# "None" answers per language code; used to recognise (and score) NIAH tasks with no needle.
NONE_DICT = dict(
    en=["none"],  # English
    ko=["없음"],  # Korean
    pl=["brak"],  # Polish
    zh=["无"],  # Chinese
    vi=["Không có"],  # Vietnamese
    ja=["なし", "数字はありません"],  # Japanese
    ta=["ஏதுமில்லை"],  # Tamil
    hu=["nincs"],  # Hungarian
    fr=["aucun"],  # French
    no=["ingen"],  # Norwegian
    uk=["немає", "Нема"],  # Ukrainian
    ru=["нет"],  # Russian
    de=["Keine vorhanden"],  # German
    es=["ninguno"],  # Spanish
    sv=["inga"],  # Swedish
    fi=["ei mikään"],  # Finnish
    cs=["žádné", "žádná"],  # Czech
    sr=["nema"],  # Serbian
    pt=["nenhum"],  # Portuguese
    it=["nessuno"],  # Italian
    fa=["هیچ کدام"],  # Persian
    sw=["hakuna"],  # Swahili
    nl=["geen"],  # Dutch
    st=["ha ho letho"],  # Sesotho
    hi=["कोई नहीं"],  # Hindi
    da=["ingen"],  # Danish
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def clean_text(text: str) -> str:
    """Trim, lower-case, and delete every space and zero-width non-joiner (U+200C) from text."""
    lowered = text.strip().lower()
    # translate() with None values deletes those code points in a single pass
    return lowered.translate({0x200C: None, ord(" "): None})
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class NIAHAccuracy(BaseMetric[Completion]):
    """Metric for Needle in a Haystack tasks.

    Scores 1.0 in two cases, 0.0 otherwise:
    - The ground truth is a set of numbers, and the completion contains exactly that set of
      multi-digit numbers.
    - The ground truth is a "none" word, and the completion says "none" in the instruction
      language without mentioning any multi-digit number.
    """

    NAME = "NIAHAccuracy"

    # Every "none" word across all languages; a ground truth in this set marks a "none" task.
    # Built once at class creation instead of on every call.
    NONE_VALUES = frozenset(word for words in NONE_DICT.values() for word in words)

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        context = extract_context_metric(response, LanguageMetricContext)

        ground_truths = [gt for gt in response.ground_truth_list if gt is not None]

        try:
            # Extract language from metadata
            assert response.context is not None
            language = context.language

            # Get model's answer
            model_answer = response.completion

            # Determine which comparison function to use based on the task.
            # (An empty ground-truth list raises IndexError here and is reported as an error result.)
            if ground_truths[0] in self.NONE_VALUES:
                is_correct = self._compare_none(language, model_answer)
            else:
                is_correct = self._compare_numbers(language, ground_truths, model_answer)

            return [
                MetricResult(
                    metric_name=self.NAME, value=float(is_correct), higher_is_better=True, error=response.error
                )
            ]
        except Exception as e:
            error = Error(error_class=e.__class__.__name__, message=str(e), traceback="")
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]

    @staticmethod
    def _instruction_language(lang: str) -> str:
        """Return the second "-"-separated component of lang (e.g. "en-de" -> "de"), or lang itself.

        NOTE(review): presumably lang is "<haystack language>-<instruction language>"; confirm.
        """
        return lang.split("-")[1] if "-" in lang else lang

    def _compare_numbers(self, lang: str, correct_answer: list[str], model_answer: str) -> bool:
        """Compare numbers for regular NIAH tasks.

        Returns True only if the answer contains no "none" word of the instruction language and
        its distinct multi-digit numbers are exactly the expected numbers.
        """
        inst_lang = self._instruction_language(lang)

        if not model_answer:
            return False

        processed_model_answer = unicodedata.normalize("NFKC", model_answer)

        # Auto-fail if a "none" word of the instruction language (verbatim or in its cleaned
        # form) appears in the answer. Note the answer itself is not lower-cased here.
        none_words = NONE_DICT.get(inst_lang, ["none"])
        for word in none_words:
            if word in processed_model_answer or clean_text(word) in processed_model_answer:
                return False

        # Extract all numeric substrings, drop single digits, dedupe while preserving order
        numeric_strings = re.findall(r"\d+", processed_model_answer)
        numeric_strings = [num for num in numeric_strings if len(num) > 1]
        numeric_strings = list(dict.fromkeys(numeric_strings))

        # If no numerics are found after processing, return False
        if not numeric_strings:
            return False

        try:
            extracted_numbers = [int(num) for num in numeric_strings]
            # Convert correct answers to integers to ensure numeric comparison
            correct_converted = [int(item) for item in correct_answer]
        except (TypeError, ValueError):
            return False

        # The answer must list exactly as many numbers as expected, and the same ones
        if len(extracted_numbers) != len(correct_converted):
            return False
        return set(extracted_numbers) == set(correct_converted)

    def _compare_none(self, lang: str, model_answer: str) -> bool:
        """Compare for NIAH none tasks.

        Returns True if the cleaned answer (lower-cased, spaces removed) contains a "none" word
        of the instruction language and no multi-digit number. Unknown languages fall back to
        English, as in _compare_numbers.
        """
        inst_lang = self._instruction_language(lang)

        processed_model_answer = clean_text(unicodedata.normalize("NFKC", model_answer))
        none_words = [clean_text(word) for word in NONE_DICT.get(inst_lang, ["none"])]

        # Remove single digit numbers from the processed answer
        processed_model_answer = re.sub(r"\b\d\b", "", processed_model_answer)

        # If any multi-digit numbers remain, the model reported a needle: not a "none" answer
        if re.findall(r"\d\d+", processed_model_answer):
            return False

        # Correct only if any of the none words is present
        return any(word in processed_model_answer for word in none_words)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
4
|
+
from eval_framework.shared.types import BaseMetricContext, Completion, extract_context_metric
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PlaceholderCheckerMetricContext(BaseMetricContext):
    """Per-sample metric context read by PlaceholderChecker."""

    # Minimum number of "[...]" placeholders the completion must contain to score 1.0.
    num_placeholders: int
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PlaceholderChecker(BaseMetric[Completion]):
    """Check that a completion contains enough square-bracket placeholders.

    The required count is read from PlaceholderCheckerMetricContext.num_placeholders. Every
    non-greedy "[...]" span on a single line counts as one placeholder. The score is 1.0 when
    at least the required number is present, else 0.0.
    """

    NAME = "Placeholder Check"

    def calculate(self, response: Completion) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        required = extract_context_metric(response, PlaceholderCheckerMetricContext).num_placeholders

        assert required is not None, "Expected 'num_placeholders' in context"
        assert isinstance(required, int), f"'num_placeholders' has incorrect type: {type(required)}"

        found = len(re.findall(r"\[.*?\]", response.completion))
        enough = found >= required
        return [MetricResult(metric_name=self.NAME, value=float(enough), higher_is_better=True, error=response.error)]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from collections import Counter
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Final
|
|
5
|
+
|
|
6
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
7
|
+
from eval_framework.shared.types import Completion
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class WordRepetition(BaseMetric[Completion]):
    """Word Repetition Metric

    Flags completions in which some run of ``window_size`` consecutive words occurs more than
    ``min_repetitions`` times, i.e. repeats at least ``min_repetitions`` times after its first
    occurrence. The metric is 1.0 when such a repetition is found and 0.0 otherwise, so lower
    is better. Example: "hello world hello world" triggers with window_size=2 and
    min_repetitions=1, because the two-word run "hello world" repeats once.
    """

    NAME = "WordRepetition"
    HIGHER_IS_BETTER: Final[bool] = False

    def __init__(self, window_size: int = 128, min_repetitions: int = 1) -> None:
        """
        Configure the repetition detector.

        Args:
            window_size (int): Length, in words, of the runs compared against each other.
            min_repetitions (int): How many times a run must repeat after its first occurrence
                to count as repetition. 1 catches any repetition.

        Raises:
            ValueError: if min_repetitions or window_size is smaller than 1.
        """
        super().__init__()
        self.window_size = window_size
        self.min_repetitions = min_repetitions

        if self.min_repetitions < 1:
            raise ValueError("min_repetitions must be at least 1")
        if self.window_size < 1:
            raise ValueError("window_size must be at least 1")

    def calculate(self, response: Completion) -> list[MetricResult]:
        # An upstream error yields no score; otherwise 1.0/0.0 for repetition found/absent.
        value = (
            None
            if response.error is not None
            else float(
                _has_repetition(
                    text=response.completion,
                    window_size=self.window_size,
                    min_repetitions=self.min_repetitions,
                )
            )
        )
        return [
            MetricResult(
                metric_name=self.NAME,
                value=value,
                higher_is_better=self.HIGHER_IS_BETTER,
                error=response.error,
            )
        ]
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _has_repetition(text: str, window_size: int, min_repetitions: int) -> bool:
    """Return True if some ``window_size``-word sequence in ``text`` repeats enough.

    A sequence qualifies when it occurs more than ``min_repetitions`` times,
    i.e. it is repeated at least ``min_repetitions`` times after its first
    occurrence. Text with fewer than ``window_size`` words produces no
    sequences and therefore returns False.
    """
    counts = Counter(_word_sequences(_to_words(text), window_size))
    # A generator (not a list) lets any() stop at the first qualifying count
    # without first materializing a list of booleans.
    return any(count > min_repetitions for count in counts.values())
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _to_words(text: str) -> Sequence[str]:
|
|
82
|
+
"""A somewhat crude function to tokenize a string into words."""
|
|
83
|
+
return re.findall(r"\b\w+\b", text, re.UNICODE)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _word_sequences(text_words: Sequence[str], window_size: int) -> Sequence[Sequence[str]]:
|
|
87
|
+
"""Get all contiguous sub-sequences of a given size from a word sequence."""
|
|
88
|
+
return [tuple(text_words[i : i + window_size]) for i in range(len(text_words) - window_size + 1)]
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from eval_framework.exceptions import LogicError
|
|
2
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
3
|
+
from eval_framework.metrics.completion.f1 import calculate_f1
|
|
4
|
+
from eval_framework.shared.types import Completion
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ROUGE_1(BaseMetric[Completion]):
    """ROUGE-1: unigram-based F1 between a completion and its ground truth(s).

    When several ground truths are given, the highest score is reported.
    """

    NAME = "ROUGE-1"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score one completion against each ground truth and keep the best.

        Returns a single-element list whose value is:
        - ``None`` when ``response.error`` is set (the error is passed through);
        - 0.0 for an empty completion;
        - otherwise the maximum ROUGE-1 score over all ground truths.

        Raises:
            LogicError: If any ground truth is ``None``. This is only checked
                for error-free, non-empty completions.
        """
        if response.error is not None:
            score = None
        elif response.completion == "":
            score = 0.0
        else:
            if None in response.ground_truth_list:
                raise LogicError("When calculating ROUGE-1 ground_truth cannot be None.")
            # ROUGE-1 captures word-level similarity using unigrams only.
            best = max(_calculate_rouge_1(response.completion, gt) for gt in response.ground_truth_list)  # type: ignore[arg-type]
            score = float(best)

        return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=True, error=response.error)]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _calculate_rouge_1(candidate: str, reference: str) -> float:
    """Return the ROUGE-1 F1 score of ``candidate`` against ``reference``.

    Both texts are tokenized with ``str.split()`` (whitespace only). The
    resulting unigram lists are passed to ``calculate_f1`` with the reference
    first. Only the F1 value is returned; precision and recall are not exposed.
    """
    reference_unigrams = reference.split()
    candidate_unigrams = candidate.split()
    return calculate_f1(reference_unigrams, candidate_unigrams)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from eval_framework.exceptions import LogicError
|
|
2
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
3
|
+
from eval_framework.metrics.completion.f1 import calculate_f1
|
|
4
|
+
from eval_framework.shared.types import Completion
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ROUGE_2(BaseMetric[Completion]):
    """ROUGE-2: bigram-based F1 between a completion and its ground truth(s).

    When several ground truths are given, the highest score is reported.
    """

    NAME = "ROUGE-2"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score one completion against each ground truth and keep the best.

        Returns a single-element list whose value is:
        - ``None`` when ``response.error`` is set (the error is passed through);
        - 0.0 for an empty completion;
        - otherwise the maximum ROUGE-2 score over all ground truths.

        Raises:
            LogicError: If any ground truth is ``None``. This is only checked
                for error-free, non-empty completions.
        """
        if response.error is not None:
            score = None
        elif response.completion == "":
            score = 0.0
        else:
            if None in response.ground_truth_list:
                raise LogicError("When calculating ROUGE-2 ground_truth cannot be None.")
            # Bigrams make ROUGE-2 sensitive, to some extent, to word order
            # and to which words co-occur next to each other.
            best = max(_calculate_rouge_2(response.completion, gt) for gt in response.ground_truth_list)  # type: ignore[arg-type]
            score = float(best)

        return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=True, error=response.error)]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _generate_bigrams(tokens: list[str]) -> list[tuple[str, str]]:
|
|
28
|
+
"""Generate bigrams from a list of tokens."""
|
|
29
|
+
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _calculate_rouge_2(completion: str, ground_truth: str) -> float:
    """Return the ROUGE-2 F1 score of ``completion`` against ``ground_truth``.

    Each text is split on whitespace with ``str.split()`` and turned into
    adjacent-token bigrams. The bigram lists are passed to ``calculate_f1``
    with the ground truth first. Only the F1 value is returned; precision and
    recall are not exposed.
    """
    candidate_bigrams = _generate_bigrams(completion.split())
    reference_bigrams = _generate_bigrams(ground_truth.split())
    return calculate_f1(reference_bigrams, candidate_bigrams)
|