evalscope 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of evalscope might be problematic; the original page linked to more details.

Files changed (59)
  1. evalscope/arguments.py +1 -0
  2. evalscope/benchmarks/aime24/__init__.py +0 -0
  3. evalscope/benchmarks/aime24/aime24_adapter.py +49 -0
  4. evalscope/benchmarks/arc/arc_adapter.py +5 -7
  5. evalscope/benchmarks/bbh/bbh_adapter.py +17 -9
  6. evalscope/benchmarks/benchmark.py +2 -2
  7. evalscope/benchmarks/ceval/ceval_adapter.py +9 -9
  8. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +9 -11
  9. evalscope/benchmarks/competition_math/competition_math_adapter.py +34 -23
  10. evalscope/benchmarks/data_adapter.py +18 -12
  11. evalscope/benchmarks/data_collection/__init__.py +0 -0
  12. evalscope/benchmarks/data_collection/data_collection_adapter.py +71 -0
  13. evalscope/benchmarks/general_mcq/__init__.py +0 -0
  14. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +129 -0
  15. evalscope/benchmarks/general_qa/general_qa_adapter.py +6 -6
  16. evalscope/benchmarks/gpqa/__init__.py +0 -0
  17. evalscope/benchmarks/gpqa/chain_of_thought.txt +81 -0
  18. evalscope/benchmarks/gpqa/gpqa_adapter.py +121 -0
  19. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +8 -13
  20. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +3 -7
  21. evalscope/benchmarks/humaneval/humaneval_adapter.py +5 -6
  22. evalscope/benchmarks/ifeval/ifeval_adapter.py +14 -14
  23. evalscope/benchmarks/ifeval/instructions.py +3 -4
  24. evalscope/benchmarks/iquiz/iquiz_adapter.py +5 -5
  25. evalscope/benchmarks/math_500/__init__.py +0 -0
  26. evalscope/benchmarks/math_500/math_500_adapter.py +49 -0
  27. evalscope/benchmarks/mmlu/mmlu_adapter.py +7 -11
  28. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +27 -15
  29. evalscope/benchmarks/race/race_adapter.py +3 -3
  30. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +1 -2
  31. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +8 -8
  32. evalscope/cli/start_app.py +3 -2
  33. evalscope/collections/evaluator.py +103 -39
  34. evalscope/collections/sampler.py +2 -1
  35. evalscope/collections/schema.py +1 -2
  36. evalscope/config.py +1 -0
  37. evalscope/evaluator/evaluator.py +78 -64
  38. evalscope/metrics/math_parser.py +526 -0
  39. evalscope/metrics/metrics.py +16 -1
  40. evalscope/metrics/named_metrics.py +31 -7
  41. evalscope/models/chat_adapter.py +69 -47
  42. evalscope/models/choice_adapter.py +52 -45
  43. evalscope/models/custom_adapter.py +2 -2
  44. evalscope/models/local_model.py +4 -0
  45. evalscope/models/server_adapter.py +28 -34
  46. evalscope/report/app.py +298 -96
  47. evalscope/run.py +10 -7
  48. evalscope/utils/chat_service.py +2 -2
  49. evalscope/utils/io_utils.py +1 -1
  50. evalscope/version.py +2 -2
  51. {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/METADATA +20 -11
  52. {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/RECORD +57 -47
  53. tests/cli/test_run.py +93 -16
  54. evalscope/benchmarks/ceval/samples.jsonl +0 -1
  55. evalscope/metrics/math_accuracy.py +0 -200
  56. {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/LICENSE +0 -0
  57. {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/WHEEL +0 -0
  58. {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/entry_points.txt +0 -0
  59. {evalscope-0.10.0.dist-info → evalscope-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,526 @@
1
+ """
2
+ The logic in this file largely borrows from Qwen2.5-Math codebase at https://github.com/QwenLM/Qwen2.5-Math:
3
+ """
4
+ # flake8: noqa
5
+ import re
6
+ import regex
7
+ from latex2sympy2 import latex2sympy
8
+ from math import isclose
9
+ from sympy import N, simplify
10
+ from sympy.parsing.latex import parse_latex
11
+ from sympy.parsing.sympy_parser import parse_expr
12
+ from word2number import w2n
13
+
14
+
15
def convert_word_number(text: str) -> str:
    """Convert an English number phrase (e.g. ``"twenty one"``) to its digit string.

    Best effort: if ``w2n.word_to_num`` raises for any reason, ``text`` is
    returned unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        return text
21
+
22
+
23
+ def _fix_fracs(string):
24
+ substrs = string.split('\\frac')
25
+ new_str = substrs[0]
26
+ if len(substrs) > 1:
27
+ substrs = substrs[1:]
28
+ for substr in substrs:
29
+ new_str += '\\frac'
30
+ if len(substr) > 0 and substr[0] == '{':
31
+ new_str += substr
32
+ else:
33
+ try:
34
+ assert len(substr) >= 2
35
+ except Exception:
36
+ return string
37
+ a = substr[0]
38
+ b = substr[1]
39
+ if b != '{':
40
+ if len(substr) > 2:
41
+ post_substr = substr[2:]
42
+ new_str += '{' + a + '}{' + b + '}' + post_substr
43
+ else:
44
+ new_str += '{' + a + '}{' + b + '}'
45
+ else:
46
+ if len(substr) > 2:
47
+ post_substr = substr[2:]
48
+ new_str += '{' + a + '}' + b + post_substr
49
+ else:
50
+ new_str += '{' + a + '}' + b
51
+ string = new_str
52
+ return string
53
+
54
+
55
+ def _fix_a_slash_b(string):
56
+ if len(string.split('/')) != 2:
57
+ return string
58
+ a = string.split('/')[0]
59
+ b = string.split('/')[1]
60
+ try:
61
+ if 'sqrt' not in a:
62
+ a = int(a)
63
+ if 'sqrt' not in b:
64
+ b = int(b)
65
+ assert string == '{}/{}'.format(a, b)
66
+ new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
67
+ return new_string
68
+ except Exception:
69
+ return string
70
+
71
+
72
+ def _fix_sqrt(string):
73
+ _string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
74
+ return _string
75
+
76
+
77
def strip_answer_string(string):
    r"""Normalize a math answer string so that equivalent answers compare equal.

    The steps run in a fixed order (later steps rely on earlier ones):

    - drop newlines and trailing '.';
    - unify LaTeX spellings: array/bmatrix -> pmatrix, tfrac/dfrac -> frac,
      \neq/\leq/\geq -> \ne/\le/\ge;
    - remove \left, \right and brace escapes;
    - turn number words into digits;
    - remove a trailing ``\text{...}`` unit, degree marks, ``$``,
      ``\(``/``\)``, percent signs and ``\text{...}`` wrappers;
    - remove ``x=``/``x\in``/``x\to`` markers (for x, y, z) and a short ``k=``
      prefix;
    - unify infinity and the imaginary unit (``j`` -> ``i``);
    - drop ``.0``-style decimal tails and all spaces;
    - brace ``\sqrt``/``\frac`` arguments and rewrite ``a/b`` as ``\frac{a}{b}``;
    - sort bare comma-separated integer lists.

    Args:
        string: Raw answer; any value is accepted and converted with ``str()``.

    Returns:
        The normalized answer ('' if nothing is left).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace('\n', '')

    # right "."
    string = string.rstrip('.')

    # remove inverse spaces (LaTeX negative thin space "\!")
    string = string.replace('\\!', '')

    # matrix: normalize array/bmatrix environments to pmatrix
    string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string)
    string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string)
    string = string.replace('bmatrix', 'pmatrix')

    # replace tfrac and dfrac with frac
    string = string.replace('tfrac', 'frac')
    string = string.replace('dfrac', 'frac')
    string = (string.replace('\\neq', '\\ne').replace('\\leq', '\\le').replace('\\geq', '\\ge'))

    # remove \left and \right, unescape braces
    string = string.replace('\\left', '')
    string = string.replace('\\right', '')
    string = string.replace('\\{', '{')
    string = string.replace('\\}', '}')

    # \text{five} -> 5; a \text{word} that is not a number word is kept as-is here
    def replace_match(match):
        word = match.group(1).lower()
        converted = convert_word_number(word)
        return match.group(0) if converted == word else converted

    string = re.sub(r'\\text\{([a-zA-Z]+)\}', replace_match, string)

    # Before removing unit, check if the unit is squared (for surface area)
    string = re.sub(r'(cm|inches)\}\^2', r'\1}', string)

    # Remove a trailing \text{...} unit (miles, dollars, ...) unless that leaves nothing
    _string = re.sub(r'\\text{.*?}$', '', string).strip()
    if _string != '' and _string != string:
        string = _string

    # Remove circ (degrees)
    string = string.replace('^{\\circ}', '')
    string = string.replace('^\\circ', '')

    # remove dollar signs and inline-math delimiters
    string = string.replace('\\$', '')
    string = string.replace('$', '')
    string = string.replace('\\(', '').replace('\\)', '')

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r'\\text\{(.*?)\}', r'\1', string)
    for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']:
        string = string.replace(key, '')
    string = string.replace('\\emptyset', r'{}')
    string = string.replace('(-\\infty,\\infty)', '\\mathbb{R}')

    # remove percentage
    string = string.replace('\\%', '')
    # Second pass catches a "\%" formed by the first removal (e.g. from "\\%%").
    # Written raw because a plain '\%' is an invalid escape sequence.
    string = string.replace(r'\%', '')
    string = string.replace('%', '')

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(' .', ' 0.')
    string = string.replace('{.', '{0.')

    # NOTE(review): str.isalnum() is False for any string containing a bracket,
    # so this unwrap can never trigger; kept to match the upstream logic.
    if (string.startswith('{') and string.endswith('}') and string.isalnum()
            or string.startswith('(') and string.endswith(')') and string.isalnum()
            or string.startswith('[') and string.endswith(']') and string.isalnum()):
        string = string[1:-1]

    # inf
    string = string.replace('infinity', '\\infty')
    if '\\infty' not in string:
        string = string.replace('inf', '\\infty')
    # "+\infty" -> "\infty" (the literal used to read '+\\inity', a typo that never matched)
    string = string.replace('+\\infty', '\\infty')

    # and
    string = string.replace('and', '')
    string = string.replace('\\mathbf', '')

    # use regex to remove \mbox{...}
    string = re.sub(r'\\mbox{.*?}', '', string)

    # NOTE(review): upstream meant to strip quotes here, but its
    # ``string.replace("'", '')`` / ``string.replace('"', '')`` calls discarded
    # their results, so quotes were never removed. That behavior is preserved
    # (it also keeps derivative primes such as f'(x)); confirm whether quotes
    # should really be stripped.

    # i, j: treat a lone j as the imaginary unit i
    if 'j' in string and 'i' not in string:
        string = string.replace('j', 'i')

    # replace a.000b where b is not number or b is end, with ab, use regex
    string = re.sub(r'(\d+)\.0*([^\d])', r'\1\2', string)
    string = re.sub(r'(\d+)\.0*$', r'\1', string)

    # if empty, return empty string
    if not string:
        return string
    if string.startswith('.'):
        string = '0' + string

    # get rid of e.g. "k = " or "q = " at beginning (short left-hand side only)
    if string.count('=') == 1:
        lhs, rhs = string.split('=')
        if len(lhs) <= 2:
            string = rhs

    string = _fix_sqrt(string)
    string = string.replace(' ', '')

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    # Remove unnecessary '\' before integers
    string = re.sub(r'\\(?=\-?\d+(\\|\)|,|\]|$))', '', string)

    # Remove grade level (e.g., 12th grade) and just maintain the integer
    string = re.sub(r'thgrade$', '', string)

    # If the answer is a list of integers (without parenthesis), sort them
    if re.fullmatch(r'(\s*-?\d+\s*,)*\s*-?\d+\s*', string):
        try:
            integer_list = list(map(int, string.split(',')))
        except ValueError:
            # e.g. integers too long for int(); fall back to a sentinel list
            integer_list = list(map(int, '-1,-1'.split(',')))
        string = ','.join(map(str, sorted(integer_list)))

    return string
227
+
228
+
229
def extract_answer(pred_str, use_last_number=True):
    r"""Pull the final answer out of a model completion and normalize it.

    Sources are tried in this order (first hit wins):

    1. Minerva style: text between ``final answer is $`` and ``$. I hope``.
    2. The last ``boxed`` occurrence. Its brace-balanced ``{...}`` argument is
       used, or, without braces, the text up to the next ``$``.
    3. Text after the last ``he answer is`` (covers "The/the answer is").
    4. Text after the last ``final answer is``.
    5. Text after the first Chinese "the answer is" marker, up to the first
       blank line.
    6. Otherwise, if ``use_last_number``, the last number in the text (commas
       removed); else ''.

    The candidate is then joined onto one line. A leading ':' and a trailing
    '.' then '/' are dropped, and the result goes through
    ``strip_answer_string``.

    Args:
        pred_str: The model completion.
        use_last_number: Fall back to the last number when no marker is found.

    Returns:
        The normalized answer string ('' when ``boxed`` is the very end of the text).
    """
    # Remove every Cyrillic "ки" (U+043A U+0438).
    # NOTE(review): presumably a step-marker artifact from some prompt/data format; confirm.
    pred_str = pred_str.replace('\u043a\u0438', '')
    if 'final answer is $' in pred_str and '$. I hope' in pred_str:
        # minerva_math
        tmp = pred_str.split('final answer is $', 1)[1]
        pred = tmp.split('$. I hope', 1)[0].strip()
    elif 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if len(ans) == 0:
            return ''
        if ans[0] == '{':
            # Collect the brace-balanced argument of the last \boxed{...}.
            depth = 1
            chars = []
            for c in ans[1:]:
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        break
                chars.append(c)
            pred = ''.join(chars)
        else:
            # e.g. "\boxed 5$": take the text up to the next '$'.
            pred = ans.split('$')[0].strip()
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    elif 'final answer is' in pred_str:
        pred = pred_str.split('final answer is')[-1].strip()
    elif '答案是' in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split('答案是')[1].strip().split('\n\n')[0].strip()
    elif use_last_number:
        # Raw string so '\d' and '\.' are regex escapes, not invalid string
        # escapes (which raise SyntaxWarning on Python 3.12+).
        numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(',', ''))
        pred = numbers[-1] if numbers else ''
    else:
        pred = ''

    # Join multi-line answers into a single line.
    pred = re.sub(r'\n\s*', '', pred)
    if pred.startswith(':'):
        pred = pred[1:]
    if pred.endswith('.'):
        pred = pred[:-1]
    if pred.endswith('/'):
        pred = pred[:-1]
    return strip_answer_string(pred)
285
+
286
+
287
def choice_answer_clean(pred: str):
    """Reduce a free-form prediction to a multiple-choice letter when possible.

    Returns the last standalone A-E letter found, case-insensitively and in
    upper case. Otherwise it returns the cleaned prediction itself. Either way,
    a leading ':', surrounding spaces and trailing '.'/'/' are stripped.
    """
    cleaned = pred.strip('\n').rstrip('.').rstrip('/').strip(' ').lstrip(':')
    letters = re.findall(r'\b(A|B|C|D|E)\b', cleaned.upper())
    answer = letters[-1] if letters else cleaned.strip().strip('.')
    # Strip trailing '.' and '/' once more after picking the answer.
    return answer.rstrip('.').rstrip('/')
299
+
300
+
301
def parse_digits(num):
    r"""Parse ``num`` as a float, or return ``None`` if it is not numeric.

    Commas (thousands separators) are removed first. A value ending in ``%``
    (or LaTeX ``\%``) is parsed without the sign and divided by 100, so
    ``"50%"`` and ``"50\%"`` both give ``0.5``.
    """
    text = str(num).replace(',', '')
    try:
        return float(text)
    except ValueError:
        pass
    if not text.endswith('%'):
        return None
    text = text[:-1]
    if text.endswith('\\'):
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None
315
+
316
+
317
def is_digit(num):
    """Return whether ``parse_digits`` can turn ``num`` into a number (paired helper)."""
    parsed = parse_digits(num)
    return parsed is not None
320
+
321
+
322
def str_to_pmatrix(input_str):
    r"""Convert brace-delimited vectors such as ``{1,2}`` to LaTeX ``pmatrix`` form.

    Each ``{...,...}`` match becomes a column vector
    ``\begin{pmatrix}1\\2\end{pmatrix}``, with each comma turned into the
    LaTeX row break ``\\``. Matches are joined with ``', '``. Returns ``''``
    when no comma-containing braces are found.
    """
    input_str = input_str.strip()
    matrix_str = re.findall(r'\{.*,.*\}', input_str)
    pmatrix_list = []

    for m in matrix_str:
        m = m.strip('{}')
        # Rows must be separated by a double backslash: that is the LaTeX row
        # break, and math_equal splits pmatrix bodies on '\\\\' (two chars).
        # A single '\' would leave the whole vector as one unsplittable row.
        pmatrix = r'\begin{pmatrix}' + m.replace(',', '\\\\') + r'\end{pmatrix}'
        pmatrix_list.append(pmatrix)

    return ', '.join(pmatrix_list)
333
+
334
+
335
def math_equal(
    prediction,
    reference,
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Return True if ``prediction`` and ``reference`` denote the same math answer.

    Checks run in order, and the first success wins:

    0. Case-insensitive string equality. Alternatively, if ``reference`` is a
       letter A-E, it is compared with ``choice_answer_clean(prediction)``.
    1. Numerical equality when both parse as numbers (``parse_digits``). With
       ``include_percentage`` the reference is also tried scaled by 1/100 and
       by 100. If both are numeric and nothing matches, the result is False.
    2. Structural / symbolic equality, tried in this order:
       - bracket-stripped string match;
       - element-wise comparison of ``(..)``/``[..]`` tuples;
       - element-wise comparison of pmatrix/bmatrix entries;
       - equations, as ``lhs - (rhs)`` compared up to sign;
       - a short ``x=`` prefix on one side;
       - finally ``symbolic_equal``.

    Args:
        prediction: Model answer. Non-string values are compared via ``str()``.
        reference: Gold answer. Non-string values are compared via ``str()``.
        include_percentage: Also accept a prediction equal to reference/100 or
            reference*100.
        is_close: Compare numbers with ``numeric_equal`` (relative tolerance)
            instead of ``==``.
        timeout: Accepted for signature compatibility; not used here.
    """
    if prediction is None or reference is None:
        return False
    # str() first: numeric predictions/references have no .strip()/.lower().
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if reference in ['A', 'B', 'C', 'D', 'E'] and choice_answer_clean(str(prediction)) == reference:
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if 'pmatrix' in prediction and 'pmatrix' not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (prediction.startswith('[') and prediction.endswith(']') and not reference.startswith('(')) or (
            prediction.startswith('(') and prediction.endswith(')') and not reference.startswith('[')):
        pred_str = pred_str.strip('[]()')
        ref_str = ref_str.strip('[]()')
    for s in ['{', '}', '(', ')']:
        ref_str = ref_str.replace(s, '')
        pred_str = pred_str.replace(s, '')
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    bracket_pattern = r'(\(|\[).+(\)|\])'
    if regex.match(bracket_pattern, prediction) is not None and regex.match(bracket_pattern, reference) is not None:
        pred_parts = prediction[1:-1].split(',')
        ref_parts = reference[1:-1].split(',')
        if len(pred_parts) == len(ref_parts) and all(
                math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
            return True

    ## matrices: compare entry by entry
    matrix_begins = ('\\begin{pmatrix}', '\\begin{bmatrix}')
    matrix_ends = ('\\end{pmatrix}', '\\end{bmatrix}')
    if (prediction.startswith(matrix_begins) and prediction.endswith(matrix_ends)
            and reference.startswith(matrix_begins) and reference.endswith(matrix_ends)):
        # pmatrix and bmatrix delimiters have equal length, so one slice fits both.
        body = slice(len('\\begin{pmatrix}'), -len('\\end{pmatrix}'))
        pred_lines = [line.strip() for line in prediction[body].split('\\\\') if line.strip()]
        ref_lines = [line.strip() for line in reference[body].split('\\\\') if line.strip()]
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split('&')
                ref_parts = ref_line.split('&')
                if len(pred_parts) != len(ref_parts) or not all(
                        math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
                    matched = False
                    break
        if matched:
            return True

    ## equations
    if prediction.count('=') == 1 and reference.count('=') == 1:
        pred = prediction.split('=')
        pred = f'{pred[0].strip()} - ({pred[1].strip()})'
        ref = reference.split('=')
        ref = f'{ref[0].strip()} - ({ref[1].strip()})'
        # Equations are equal if lhs-rhs agree up to an overall sign.
        if symbolic_equal(pred, ref) or symbolic_equal(f'-({pred})', ref):
            return True
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
            return True
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
            return True

    if symbolic_equal(prediction, reference):
        return True

    return False
466
+
467
+
468
def numeric_equal(prediction: float, reference: float, rel_tol: float = 1e-4, abs_tol: float = 0.0) -> bool:
    """Return True if two numbers agree within a tolerance.

    Thin wrapper over ``math.isclose``. The default relative tolerance is 1e-4
    (0.01%). With the default ``abs_tol=0.0``, a non-zero prediction can never
    match a reference of exactly 0. Pass ``abs_tol`` to allow near-zero matches.
    """
    return isclose(reference, prediction, rel_tol=rel_tol, abs_tol=abs_tol)
470
+
471
+
472
def symbolic_equal(a, b):
    """Return True if ``a`` and ``b`` parse to mathematically equal SymPy objects.

    Each input is parsed with the first of ``parse_latex``, ``parse_expr`` and
    ``latex2sympy`` that succeeds. Each parser is tried first with doubled
    backslashes collapsed, then on the raw text; unparseable inputs stay as
    strings. The parsed values are then compared by, in order:

    - string or ``==`` equality;
    - ``equals`` or ``simplify(a - b) == 0``;
    - equations via ``abs(lhs - rhs)``;
    - numeric value via ``numeric_equal`` on ``N(...)``;
    - matrices rounded to 3 decimals.

    Any comparison that raises counts as a non-match.
    """

    def _to_sympy(text):
        for parser in (parse_latex, parse_expr, latex2sympy):
            try:
                return parser(text.replace('\\\\', '\\'))
            except Exception:
                pass
            try:
                return parser(text)
            except Exception:
                pass
        return text

    expr_a = _to_sympy(a)
    expr_b = _to_sympy(b)

    def _same_text():
        # direct equal
        return str(expr_a) == str(expr_b) or expr_a == expr_b

    def _simplifies_to_zero():
        return expr_a.equals(expr_b) or simplify(expr_a - expr_b) == 0

    def _same_equation():
        return (abs(expr_a.lhs - expr_a.rhs)).equals(abs(expr_b.lhs - expr_b.rhs))

    def _same_value():
        return numeric_equal(float(N(expr_a)), float(N(expr_b)))

    def _same_matrix():
        if expr_a.shape == expr_b.shape:
            rounded_a = expr_a.applyfunc(lambda v: round(v, 3))
            rounded_b = expr_b.applyfunc(lambda v: round(v, 3))
            return rounded_a.equals(rounded_b)
        return False

    for check in (_same_text, _simplifies_to_zero, _same_equation, _same_value, _same_matrix):
        try:
            if check():
                return True
        except Exception:
            pass
    return False
@@ -12,10 +12,25 @@ from collections.abc import Iterable
12
12
  from typing import TYPE_CHECKING, Dict, List, Union
13
13
 
14
14
 
15
- def mean(arr):
15
def mean(arr: list):
    """Arithmetic mean of ``arr``.

    If the first element is a list, ``arr`` is treated as per-sample lists of
    scores and flattened one level before averaging. An empty input (or only
    empty sub-lists) yields ``0.0``, matching ``pass_at_k``, instead of raising
    IndexError / ZeroDivisionError.
    """
    # len() rather than truthiness so array-like inputs keep working.
    if len(arr) > 0 and isinstance(arr[0], list):
        arr = [item for sublist in arr for item in sublist]
    if len(arr) == 0:
        return 0.0
    return sum(arr) / len(arr)
17
19
 
18
20
 
21
def pass_at_k(arr: Union[List[int], List[List[int]]], k: int = 1) -> float:
    """Fraction of samples solved within their first ``k`` attempts.

    ``arr`` takes one of two forms:

    - a flat list of per-sample scores, in which case the result is their
      plain mean and ``k`` is ignored;
    - a list of per-sample attempt lists, in which case a sample counts as
      solved if any of its first ``k`` attempts is truthy.

    An empty ``arr`` yields ``0.0``.
    """
    if not arr:
        return 0.0
    if not isinstance(arr[0], list):
        return sum(arr) / len(arr)
    solved = 0.0
    for attempts in arr:
        if any(attempts[:k]):
            solved += 1.0
    return solved / len(arr)
32
+
33
+
def pop_stddev(arr):
    """Population standard deviation (divides by N, not N-1) of ``arr``.

    Like ``mean``, a list of lists is flattened one level first. This lets
    nested per-sample scores work instead of failing on ``list - float`` in
    the deviation sum. An empty input yields ``0.0``.
    """
    if len(arr) > 0 and isinstance(arr[0], list):
        arr = [item for sublist in arr for item in sublist]
    if len(arr) == 0:
        return 0.0
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu)**2 for x in arr) / len(arr))
@@ -1,7 +1,8 @@
1
1
  from dataclasses import dataclass, field
2
- from typing import Callable
2
+ from functools import partial
3
+ from typing import Callable, Dict
3
4
 
4
- from evalscope.metrics.metrics import mean, weighted_mean
5
+ from evalscope.metrics.metrics import mean, pass_at_k, weighted_mean
5
6
 
6
7
 
7
8
  @dataclass
@@ -10,8 +11,31 @@ class Metric:
10
11
  object: Callable = field(default_factory=lambda: mean)
11
12
 
12
13
 
13
- AverageAccuracy = Metric(name='AverageAccuracy', object=mean)
14
- WeightedAverageAccuracy = Metric(name='WeightedAverageAccuracy', object=weighted_mean)
15
- AverageBLEU = Metric(name='AverageBLEU', object=mean)
16
- WeightedAverageBLEU = Metric(name='WeightedAverageBLEU', object=weighted_mean)
17
- Pass1 = Metric(name='Pass@1', object=mean)
14
class MetricRegistry:
    """Name-to-``Metric`` lookup table used to resolve metrics by name."""

    def __init__(self):
        # Keyed by ``Metric.name``; dict order is registration order.
        self.metrics: Dict[str, Metric] = {}

    def register(self, metric: Metric):
        """Add ``metric`` under ``metric.name``, replacing any metric already registered under that name."""
        self.metrics[metric.name] = metric

    def get(self, name: str) -> Metric:
        """Return the metric registered as ``name``.

        Raises:
            KeyError: if ``name`` is not registered. The message lists the
                available names.
        """
        try:
            return self.metrics[name]
        except KeyError:
            # ``from None``: the bare inner KeyError(name) adds nothing to this message.
            raise KeyError(f'Metric {name} not found in the registry. Available metrics: {self.list_metrics()}') from None

    def list_metrics(self):
        """Return the registered metric names in registration order."""
        return list(self.metrics.keys())
30
+
31
+
32
# Module-level registry instance; the registrations below run once, at import time.
metric_registry = MetricRegistry()

# Register metrics
# ``mean`` flattens nested per-sample score lists one level before averaging.
metric_registry.register(Metric(name='AverageAccuracy', object=mean))
metric_registry.register(Metric(name='WeightedAverageAccuracy', object=weighted_mean))
metric_registry.register(Metric(name='AverageBLEU', object=mean))
metric_registry.register(Metric(name='WeightedAverageBLEU', object=weighted_mean))
# Plain mean of all scores. For nested per-sample lists this differs from
# 'Pass@1' below, which counts a sample as solved only if its first attempt is truthy.
metric_registry.register(Metric(name='AveragePass@1', object=mean))
# 'Pass@1' ... 'Pass@16': ``pass_at_k`` with ``k`` bound via ``functools.partial``.
# NOTE: the loop variable ``k`` stays defined at module level (value 16) afterwards.
for k in range(1, 17):
    metric_registry.register(Metric(name=f'Pass@{k}', object=partial(pass_at_k, k=k)))