evalscope 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic. Click here for more details.
- evalscope/arguments.py +1 -0
- evalscope/benchmarks/aime24/__init__.py +0 -0
- evalscope/benchmarks/aime24/aime24_adapter.py +49 -0
- evalscope/benchmarks/arc/arc_adapter.py +5 -7
- evalscope/benchmarks/bbh/bbh_adapter.py +17 -9
- evalscope/benchmarks/benchmark.py +2 -2
- evalscope/benchmarks/ceval/ceval_adapter.py +9 -9
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +9 -11
- evalscope/benchmarks/competition_math/competition_math_adapter.py +34 -23
- evalscope/benchmarks/data_adapter.py +18 -12
- evalscope/benchmarks/data_collection/__init__.py +0 -0
- evalscope/benchmarks/data_collection/data_collection_adapter.py +71 -0
- evalscope/benchmarks/general_mcq/__init__.py +0 -0
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +129 -0
- evalscope/benchmarks/general_qa/general_qa_adapter.py +6 -6
- evalscope/benchmarks/gpqa/__init__.py +0 -0
- evalscope/benchmarks/gpqa/chain_of_thought.txt +81 -0
- evalscope/benchmarks/gpqa/gpqa_adapter.py +121 -0
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +8 -13
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +3 -7
- evalscope/benchmarks/humaneval/humaneval_adapter.py +5 -6
- evalscope/benchmarks/ifeval/ifeval_adapter.py +14 -14
- evalscope/benchmarks/ifeval/instructions.py +3 -4
- evalscope/benchmarks/iquiz/iquiz_adapter.py +5 -5
- evalscope/benchmarks/math_500/__init__.py +0 -0
- evalscope/benchmarks/math_500/math_500_adapter.py +49 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +7 -11
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +27 -15
- evalscope/benchmarks/race/race_adapter.py +3 -3
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +1 -2
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +8 -8
- evalscope/cli/start_app.py +3 -2
- evalscope/collections/evaluator.py +103 -39
- evalscope/collections/sampler.py +2 -1
- evalscope/collections/schema.py +1 -2
- evalscope/config.py +1 -0
- evalscope/evaluator/evaluator.py +78 -64
- evalscope/metrics/math_parser.py +526 -0
- evalscope/metrics/metrics.py +16 -1
- evalscope/metrics/named_metrics.py +31 -7
- evalscope/models/chat_adapter.py +69 -47
- evalscope/models/choice_adapter.py +52 -45
- evalscope/models/custom_adapter.py +2 -2
- evalscope/models/local_model.py +4 -0
- evalscope/models/server_adapter.py +28 -34
- evalscope/report/app.py +298 -96
- evalscope/run.py +10 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/io_utils.py +1 -1
- evalscope/version.py +2 -2
- {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/METADATA +20 -11
- {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/RECORD +57 -47
- tests/cli/test_run.py +93 -16
- evalscope/benchmarks/ceval/samples.jsonl +0 -1
- evalscope/metrics/math_accuracy.py +0 -200
- {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/LICENSE +0 -0
- {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/WHEEL +0 -0
- {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
"""
|
|
2
|
+
The logic in this file largely borrows from Qwen2.5-Math codebase at https://github.com/QwenLM/Qwen2.5-Math:
|
|
3
|
+
"""
|
|
4
|
+
# flake8: noqa
|
|
5
|
+
import re
|
|
6
|
+
import regex
|
|
7
|
+
from latex2sympy2 import latex2sympy
|
|
8
|
+
from math import isclose
|
|
9
|
+
from sympy import N, simplify
|
|
10
|
+
from sympy.parsing.latex import parse_latex
|
|
11
|
+
from sympy.parsing.sympy_parser import parse_expr
|
|
12
|
+
from word2number import w2n
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def convert_word_number(text: str) -> str:
    """
    Convert an English number phrase (e.g. "twenty one") into its digit string.

    Conversion is delegated to ``w2n.word_to_num``; if it fails for any reason
    the input text is returned unchanged (best-effort conversion).
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        return text
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _fix_fracs(string):
|
|
24
|
+
substrs = string.split('\\frac')
|
|
25
|
+
new_str = substrs[0]
|
|
26
|
+
if len(substrs) > 1:
|
|
27
|
+
substrs = substrs[1:]
|
|
28
|
+
for substr in substrs:
|
|
29
|
+
new_str += '\\frac'
|
|
30
|
+
if len(substr) > 0 and substr[0] == '{':
|
|
31
|
+
new_str += substr
|
|
32
|
+
else:
|
|
33
|
+
try:
|
|
34
|
+
assert len(substr) >= 2
|
|
35
|
+
except Exception:
|
|
36
|
+
return string
|
|
37
|
+
a = substr[0]
|
|
38
|
+
b = substr[1]
|
|
39
|
+
if b != '{':
|
|
40
|
+
if len(substr) > 2:
|
|
41
|
+
post_substr = substr[2:]
|
|
42
|
+
new_str += '{' + a + '}{' + b + '}' + post_substr
|
|
43
|
+
else:
|
|
44
|
+
new_str += '{' + a + '}{' + b + '}'
|
|
45
|
+
else:
|
|
46
|
+
if len(substr) > 2:
|
|
47
|
+
post_substr = substr[2:]
|
|
48
|
+
new_str += '{' + a + '}' + b + post_substr
|
|
49
|
+
else:
|
|
50
|
+
new_str += '{' + a + '}' + b
|
|
51
|
+
string = new_str
|
|
52
|
+
return string
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _fix_a_slash_b(string):
|
|
56
|
+
if len(string.split('/')) != 2:
|
|
57
|
+
return string
|
|
58
|
+
a = string.split('/')[0]
|
|
59
|
+
b = string.split('/')[1]
|
|
60
|
+
try:
|
|
61
|
+
if 'sqrt' not in a:
|
|
62
|
+
a = int(a)
|
|
63
|
+
if 'sqrt' not in b:
|
|
64
|
+
b = int(b)
|
|
65
|
+
assert string == '{}/{}'.format(a, b)
|
|
66
|
+
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
|
|
67
|
+
return new_string
|
|
68
|
+
except Exception:
|
|
69
|
+
return string
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _fix_sqrt(string):
|
|
73
|
+
_string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
|
|
74
|
+
return _string
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def strip_answer_string(string):
    """
    Normalize a math answer string so that equivalent answers compare equal.

    Among other steps, in order:
    - removes line breaks, trailing periods and ``\\!`` spacing;
    - unifies ``array``/``bmatrix`` to ``pmatrix`` and ``tfrac``/``dfrac`` to ``frac``;
    - drops ``\\left``/``\\right``, a trailing unit in ``\\text{...}``, degree
      signs, dollar signs, ``\\(``/``\\)``, percentages and quotes;
    - converts number words to digits and drops simple variable prefixes
      such as ``x=`` or ``k=``;
    - removes every occurrence of the substring ``and`` as well as ``\\mathbf``
      and ``\\mbox{...}``;
    - adds braces to ``\\sqrt``/``\\frac`` arguments and rewrites ``a/b`` as
      ``\\frac{a}{b}``;
    - strips spaces and redundant trailing zeros (``2.000`` -> ``2``);
    - sorts bare comma-separated integer lists.

    Args:
        string: The raw answer; any object is accepted and converted via ``str``.

    Returns:
        The normalized answer string (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace('\n', '')

    # right "."
    string = string.rstrip('.')

    # remove inverse spaces
    string = string.replace('\\!', '')

    # matrix
    string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string)
    string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string)
    string = string.replace('bmatrix', 'pmatrix')

    # replace tfrac and dfrac with frac
    string = string.replace('tfrac', 'frac')
    string = string.replace('dfrac', 'frac')
    string = (string.replace('\\neq', '\\ne').replace('\\leq', '\\le').replace('\\geq', '\\ge'))

    # remove \left and \right
    string = string.replace('\\left', '')
    string = string.replace('\\right', '')
    string = string.replace('\\{', '{')
    string = string.replace('\\}', '}')

    # Replace number words inside \text{...} with digits, e.g. \text{five} -> 5;
    # keep the original match when the word is not a number.
    def replace_match(match):
        word = match.group(1).lower()
        converted = convert_word_number(word)
        return match.group(0) if converted == word else converted

    string = re.sub(r'\\text\{([a-zA-Z]+)\}', replace_match, string)

    # Before removing unit, check if the unit is squared (for surface area)
    string = re.sub(r'(cm|inches)\}\^2', r'\1}', string)

    # Remove a trailing unit such as \text{miles}, unless nothing would remain
    _string = re.sub(r'\\text{.*?}$', '', string).strip()
    if _string != '' and _string != string:
        string = _string

    # Remove circ (degrees)
    string = string.replace('^{\\circ}', '')
    string = string.replace('^\\circ', '')

    # remove dollar signs
    string = string.replace('\\$', '')
    string = string.replace('$', '')
    string = string.replace('\\(', '').replace('\\)', '')

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r'\\text\{(.*?)\}', r'\1', string)
    for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']:
        string = string.replace(key, '')
    string = string.replace('\\emptyset', r'{}')
    string = string.replace('(-\\infty,\\infty)', '\\mathbb{R}')

    # remove percentage: escaped "\%" first, then any bare "%"
    string = string.replace('\\%', '')
    string = string.replace('%', '')

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(' .', ' 0.')
    string = string.replace('{.', '{0.')

    # Unwrap a single alphanumeric token enclosed in {}, () or [], e.g. "(A)" -> "A".
    # The check must look at the inner text: the full string contains the
    # brackets themselves and is therefore never alphanumeric.
    if (len(string) >= 2 and (string[0], string[-1]) in (('{', '}'), ('(', ')'), ('[', ']'))
            and string[1:-1].isalnum()):
        string = string[1:-1]

    # inf
    string = string.replace('infinity', '\\infty')
    if '\\infty' not in string:
        string = string.replace('inf', '\\infty')
    string = string.replace('+\\infty', '\\infty')

    # and
    string = string.replace('and', '')
    string = string.replace('\\mathbf', '')

    # use regex to remove \mbox{...}
    string = re.sub(r'\\mbox{.*?}', '', string)

    # quote: str.replace returns a new string, so the result must be assigned
    string = string.replace("'", '')
    string = string.replace('"', '')

    # i, j: treat a lone "j" as the imaginary unit "i"
    if 'j' in string and 'i' not in string:
        string = string.replace('j', 'i')

    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r'(\d+)\.0*([^\d])', r'\1\2', string)
    string = re.sub(r'(\d+)\.0*$', r'\1', string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == '.':
        string = '0' + string

    # get rid of e.g. "k = " or "q = " at beginning (left-hand side of <= 2 chars)
    parts = string.split('=')
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    string = _fix_sqrt(string)
    string = string.replace(' ', '')

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    # Remove unnecessary '\' before integers
    string = re.sub(r'\\(?=\-?\d+(\\|\)|,|\]|$))', '', string)

    # Remove grade level (e.g., 12th grade) and just maintain the integer
    string = re.sub(r'thgrade$', '', string)

    # If the answer is a list of integers (without parenthesis), sort them
    if re.fullmatch(r'(\s*-?\d+\s*,)*\s*-?\d+\s*', string):
        try:
            integer_list = list(map(int, string.split(',')))
        except ValueError:
            integer_list = [-1, -1]
        string = ','.join(map(str, sorted(integer_list)))

    return string
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def extract_answer(pred_str, use_last_number=True):
    """
    Extract the final answer from a model completion and normalize it.

    The answer is located with the first convention that matches:
    1. Minerva style: ``final answer is $...$. I hope``;
    2. the contents of the last ``boxed{...}`` (brace-balanced), or the text
       after the last ``boxed`` up to the next ``$``;
    3. the text after the last ``he answer is`` (covers "The/the answer is");
    4. the text after the last ``final answer is``;
    5. the text after the first occurrence of the Chinese phrase meaning
       "the answer is", up to the next blank line;
    6. otherwise, if ``use_last_number`` is True, the last number in the text
       (commas removed).

    The extracted text has newlines removed, a leading ``:`` and a trailing
    ``.``/``/`` dropped, and is then passed through ``strip_answer_string``.

    Args:
        pred_str: The model's raw output text.
        use_last_number: Fall back to the last number in the text when no
            answer marker is found; if False, an empty answer is used instead.

    Returns:
        The normalized answer string ('' if nothing could be extracted).
    """
    # Remove the Cyrillic tag "ки"; presumably a step marker found in some
    # datasets' outputs — TODO confirm origin.
    pred_str = pred_str.replace('\u043a\u0438', '')
    if 'final answer is $' in pred_str and '$. I hope' in pred_str:
        # minerva_math
        tmp = pred_str.split('final answer is $', 1)[1]
        pred = tmp.split('$. I hope', 1)[0].strip()
    elif 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if not ans:
            return ''
        if ans[0] == '{':
            # Collect everything up to the brace that closes the opening "{".
            depth = 1
            chars = []
            for c in ans[1:]:
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        break
                chars.append(c)
            pred = ''.join(chars)
        else:
            pred = ans.split('$')[0].strip()
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    elif 'final answer is' in pred_str:
        pred = pred_str.split('final answer is')[-1].strip()
    elif '答案是' in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split('答案是')[1].strip().split('\n\n')[0].strip()
    elif use_last_number:
        # Raw string: "\d" / "\." in a plain literal are invalid escape sequences.
        pattern = r'-?\d*\.?\d+'
        numbers = re.findall(pattern, pred_str.replace(',', ''))
        pred = numbers[-1] if numbers else ''
    else:
        pred = ''

    # Join multi-line answers into a single line.
    pred = re.sub(r'\n\s*', '', pred)
    if pred.startswith(':'):
        pred = pred[1:]
    if pred.endswith('.'):
        pred = pred[:-1]
    if pred.endswith('/'):
        pred = pred[:-1]
    pred = strip_answer_string(pred)
    return pred
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def choice_answer_clean(pred: str):
    """
    Reduce a free-form answer to a single multiple-choice letter when possible.

    The input is first trimmed: surrounding newlines, trailing ``.`` and ``/``,
    surrounding spaces, and a leading ``:``. The last standalone letter A-E
    (matched case-insensitively, returned upper-case) is then returned. If
    there is no such letter, the trimmed text itself is returned. Trailing
    ``.`` and ``/`` are stripped from the result in both cases.
    """
    cleaned = pred.strip('\n').rstrip('.').rstrip('/').strip(' ').lstrip(':')
    letters = re.findall(r'\b(A|B|C|D|E)\b', cleaned.upper())
    if letters:
        answer = letters[-1]
    else:
        answer = cleaned.strip().strip('.')
    return answer.rstrip('.').rstrip('/')
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def parse_digits(num):
    """
    Parse ``num`` as a float, tolerating thousands separators and percentages.

    Commas are removed first. If the remaining text is not a float literal, a
    trailing ``%`` and then a trailing backslash are stripped (handling both
    ``"50%"`` and ``"50\\%"``), and the parsed value is divided by 100.

    Returns:
        The parsed float, or None if ``num`` cannot be interpreted as a number.
    """
    text = str(num).replace(',', '')
    try:
        return float(text)
    except Exception:
        pass
    if text.endswith('%'):
        text = text[:-1]
    if text.endswith('\\'):
        text = text[:-1]
    try:
        return float(text) / 100
    except Exception:
        return None
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def is_digit(num):
    """Return True if ``parse_digits`` can interpret ``num`` as a number."""
    parsed = parse_digits(num)
    return parsed is not None
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def str_to_pmatrix(input_str):
    """
    Convert brace-delimited, comma-separated vectors into LaTeX pmatrix form.

    Each ``{...,...}`` match becomes a column vector
    ``\\begin{pmatrix}a\\\\b\\end{pmatrix}``: commas turn into LaTeX row breaks
    (two backslashes), which is also the separator ``math_equal`` splits
    matrix rows on.

    Args:
        input_str: Text such as ``"{1,2}"``.

    Returns:
        The pmatrix strings joined with ``", "``; an empty string if nothing matches.
    """
    input_str = input_str.strip()
    # NOTE(review): the pattern is greedy, so "{1,2}, {3,4}" is captured as a
    # single match rather than two vectors.
    matrix_str = re.findall(r'\{.*,.*\}', input_str)
    pmatrix_list = []
    for m in matrix_str:
        m = m.strip('{}')
        # LaTeX separates rows with "\\" (two backslashes); a single backslash
        # would create a bogus control sequence such as "\2".
        pmatrix = r'\begin{pmatrix}' + m.replace(',', '\\\\') + r'\end{pmatrix}'
        pmatrix_list.append(pmatrix)
    return ', '.join(pmatrix_list)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def math_equal(
    prediction,
    reference,
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Decide whether a predicted math answer matches the reference answer.

    The answers are considered equal if any of these checks succeeds:
    1. textual: identical after stripping whitespace and lower-casing;
    2. choice: the reference is one of A-E and ``choice_answer_clean`` of the
       prediction yields that letter;
    3. numerical: both parse as numbers (``parse_digits``) and match, either
       within ``numeric_equal``'s tolerance or exactly. If
       ``include_percentage`` is set, ``reference / 100`` and
       ``reference * 100`` also count. When both sides are numeric, this check
       is final;
    4. structural: bracket-insensitive text equality; element-wise equality of
       ``(...)``/``[...]`` tuples; or element-wise equality of pmatrix/bmatrix
       matrices (all recursive);
    5. equations: ``lhs = rhs`` forms compared via ``symbolic_equal``, or a
       short ``x = value`` compared against a bare value;
    6. symbolic: ``symbolic_equal(prediction, reference)``.

    Args:
        prediction: Predicted answer; non-string values are converted with ``str``.
        reference: Reference answer; non-string values are converted with ``str``.
        include_percentage: Also accept the reference scaled by 1/100 or 100.
        is_close: Use ``numeric_equal`` instead of exact float equality.
        timeout: Accepted for API compatibility; not used by this implementation.

    Returns:
        True if the answers are judged equal, otherwise False (also when
        either argument is None).
    """
    if prediction is None or reference is None:
        return False
    # Normalize to str up front: the original `prediction.strip()` crashed on
    # non-string answers (ints/floats) even though later code expects them.
    prediction = str(prediction)
    reference = str(reference)
    if prediction.strip().lower() == reference.strip().lower():
        return True
    if reference in ['A', 'B', 'C', 'D', 'E'] and choice_answer_clean(prediction) == reference:
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if 'pmatrix' in prediction and 'pmatrix' not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (prediction.startswith('[') and prediction.endswith(']') and not reference.startswith('(')) or (
            prediction.startswith('(') and prediction.endswith(')') and not reference.startswith('[')):
        pred_str = pred_str.strip('[]()')
        ref_str = ref_str.strip('[]()')
    for s in ['{', '}', '(', ')']:
        ref_str = ref_str.replace(s, '')
        pred_str = pred_str.replace(s, '')
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (regex.match(r'(\(|\[).+(\)|\])', prediction) is not None
            and regex.match(r'(\(|\[).+(\)|\])', reference) is not None):
        pred_parts = prediction[1:-1].split(',')
        ref_parts = reference[1:-1].split(',')
        if len(pred_parts) == len(ref_parts) and all(
                math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
            return True

    matrix_begins = ('\\begin{pmatrix}', '\\begin{bmatrix}')
    matrix_ends = ('\\end{pmatrix}', '\\end{bmatrix}')
    if (prediction.startswith(matrix_begins) and prediction.endswith(matrix_ends)
            and reference.startswith(matrix_begins) and reference.endswith(matrix_ends)):
        # pmatrix and bmatrix markers have equal length, so one slice fits both.
        pred_lines = [
            line.strip() for line in prediction[len('\\begin{pmatrix}'):-len('\\end{pmatrix}')].split('\\\\')
            if line.strip()
        ]
        ref_lines = [
            line.strip() for line in reference[len('\\begin{pmatrix}'):-len('\\end{pmatrix}')].split('\\\\')
            if line.strip()
        ]
        # Rows are split on "\\" and cells on "&"; every cell must match.
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split('&')
                ref_parts = ref_line.split('&')
                if len(pred_parts) != len(ref_parts) or not all(
                        math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count('=') == 1 and reference.count('=') == 1:
        # Compare "lhs - (rhs)" forms, allowing an overall sign flip.
        pred = prediction.split('=')
        pred = f'{pred[0].strip()} - ({pred[1].strip()})'
        ref = reference.split('=')
        ref = f'{ref[0].strip()} - ({ref[1].strip()})'
        if symbolic_equal(pred, ref) or symbolic_equal(f'-({pred})', ref):
            return True
    elif (prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference):
        # "x = 5" vs. "5"
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
            return True
    elif (reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction):
        # "5" vs. "x = 5"
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
            return True

    if symbolic_equal(prediction, reference):
        return True

    return False
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def numeric_equal(prediction: float, reference: float):
    """Return True if the two numbers agree within a relative tolerance of 1e-4."""
    relative_tolerance = 1e-4
    return isclose(reference, prediction, rel_tol=relative_tolerance)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def symbolic_equal(a, b):
|
|
473
|
+
|
|
474
|
+
def _parse(s):
|
|
475
|
+
for f in [parse_latex, parse_expr, latex2sympy]:
|
|
476
|
+
try:
|
|
477
|
+
return f(s.replace('\\\\', '\\'))
|
|
478
|
+
except Exception:
|
|
479
|
+
try:
|
|
480
|
+
return f(s)
|
|
481
|
+
except Exception:
|
|
482
|
+
pass
|
|
483
|
+
return s
|
|
484
|
+
|
|
485
|
+
a = _parse(a)
|
|
486
|
+
b = _parse(b)
|
|
487
|
+
|
|
488
|
+
# direct equal
|
|
489
|
+
try:
|
|
490
|
+
if str(a) == str(b) or a == b:
|
|
491
|
+
return True
|
|
492
|
+
except Exception:
|
|
493
|
+
pass
|
|
494
|
+
|
|
495
|
+
# simplify equal
|
|
496
|
+
try:
|
|
497
|
+
if a.equals(b) or simplify(a - b) == 0:
|
|
498
|
+
return True
|
|
499
|
+
except Exception:
|
|
500
|
+
pass
|
|
501
|
+
|
|
502
|
+
# equation equal
|
|
503
|
+
try:
|
|
504
|
+
if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
|
|
505
|
+
return True
|
|
506
|
+
except Exception:
|
|
507
|
+
pass
|
|
508
|
+
|
|
509
|
+
try:
|
|
510
|
+
if numeric_equal(float(N(a)), float(N(b))):
|
|
511
|
+
return True
|
|
512
|
+
except Exception:
|
|
513
|
+
pass
|
|
514
|
+
|
|
515
|
+
# matrix
|
|
516
|
+
try:
|
|
517
|
+
# if a and b are matrix
|
|
518
|
+
if a.shape == b.shape:
|
|
519
|
+
_a = a.applyfunc(lambda x: round(x, 3))
|
|
520
|
+
_b = b.applyfunc(lambda x: round(x, 3))
|
|
521
|
+
if _a.equals(_b):
|
|
522
|
+
return True
|
|
523
|
+
except Exception:
|
|
524
|
+
pass
|
|
525
|
+
|
|
526
|
+
return False
|
evalscope/metrics/metrics.py
CHANGED
|
@@ -12,10 +12,25 @@ from collections.abc import Iterable
|
|
|
12
12
|
from typing import TYPE_CHECKING, Dict, List, Union
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
def mean(arr: list):
    """
    Return the arithmetic mean of ``arr``.

    If the first element is itself a list, ``arr`` is treated as a list of
    lists and flattened one level before averaging.
    """
    values = [x for group in arr for x in group] if isinstance(arr[0], list) else arr
    return sum(values) / len(values)
|
|
17
19
|
|
|
18
20
|
|
|
21
|
+
def pass_at_k(arr: Union[List[int], List[List[int]]], k: int = 1) -> float:
    """
    Return the fraction of problems solved within the first ``k`` attempts.

    ``arr`` can take two forms:
    - a flat list of per-problem scores: their mean is returned and ``k`` is
      not used;
    - a list of per-problem lists of attempt scores: a problem counts as
      solved if any of its first ``k`` attempts is truthy.

    An empty ``arr`` yields 0.0.
    """
    if not arr:
        return 0.0
    if not isinstance(arr[0], list):
        return sum(arr) / len(arr)
    solved = sum(1.0 if any(attempts[:k]) else 0.0 for attempts in arr)
    return solved / len(arr)
|
|
32
|
+
|
|
33
|
+
|
|
19
34
|
def pop_stddev(arr):
    """
    Return the population standard deviation of ``arr``.

    Consistent with ``mean``, a list of lists (detected from the first
    element) is flattened one level first. Previously the mean was computed
    over the flattened values, but the deviations were computed over the
    nested lists, which raised TypeError.
    """
    values = [x for group in arr for x in group] if isinstance(arr[0], list) else arr
    mu = mean(values)
    return math.sqrt(sum((x - mu)**2 for x in values) / len(values))
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from dataclasses import dataclass, field
|
|
2
|
-
from
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import Callable, Dict
|
|
3
4
|
|
|
4
|
-
from evalscope.metrics.metrics import mean, weighted_mean
|
|
5
|
+
from evalscope.metrics.metrics import mean, pass_at_k, weighted_mean
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
@dataclass
|
|
@@ -10,8 +11,31 @@ class Metric:
|
|
|
10
11
|
object: Callable = field(default_factory=lambda: mean)
|
|
11
12
|
|
|
12
13
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
14
|
+
class MetricRegistry:
    """Name-indexed registry of ``Metric`` objects."""

    def __init__(self):
        # Maps metric name -> Metric; registering an existing name overwrites it.
        self.metrics: Dict[str, Metric] = {}

    def register(self, metric: Metric):
        """Register ``metric`` under ``metric.name``, replacing any existing entry."""
        self.metrics[metric.name] = metric

    def get(self, name: str) -> Metric:
        """
        Return the metric registered under ``name``.

        Raises:
            KeyError: If no metric with that name is registered; the message
                lists the available metric names.
        """
        try:
            return self.metrics[name]
        except KeyError:
            # `from None` suppresses the redundant inner KeyError in the traceback.
            raise KeyError(f'Metric {name} not found in the registry. Available metrics: {self.list_metrics()}') from None

    def list_metrics(self):
        """Return the names of all registered metrics, in registration order."""
        return list(self.metrics.keys())
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Module-level registry instance holding the metrics registered below.
metric_registry = MetricRegistry()

# Register metrics
# Plain and weighted averages of per-sample scores (accuracy / BLEU).
metric_registry.register(Metric(name='AverageAccuracy', object=mean))
metric_registry.register(Metric(name='WeightedAverageAccuracy', object=weighted_mean))
metric_registry.register(Metric(name='AverageBLEU', object=mean))
metric_registry.register(Metric(name='WeightedAverageBLEU', object=weighted_mean))
# Pass@1 averaged with `mean`, which flattens nested per-sample score lists.
metric_registry.register(Metric(name='AveragePass@1', object=mean))
# Pass@1 .. Pass@16; functools.partial binds each k at registration time.
for k in range(1, 17):
    metric_registry.register(Metric(name=f'Pass@{k}', object=partial(pass_at_k, k=k)))
|