evalscope 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of evalscope might be problematic.
- evalscope/__init__.py +2 -0
- evalscope/arguments.py +10 -3
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py +0 -1
- evalscope/backend/rag_eval/utils/llm.py +1 -1
- evalscope/benchmarks/__init__.py +20 -1
- evalscope/benchmarks/arc/__init__.py +0 -5
- evalscope/benchmarks/arc/arc_adapter.py +23 -99
- evalscope/benchmarks/bbh/__init__.py +0 -4
- evalscope/benchmarks/bbh/bbh_adapter.py +19 -89
- evalscope/benchmarks/benchmark.py +70 -59
- evalscope/benchmarks/ceval/__init__.py +0 -5
- evalscope/benchmarks/ceval/ceval_adapter.py +22 -46
- evalscope/benchmarks/cmmlu/__init__.py +0 -5
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +20 -41
- evalscope/benchmarks/competition_math/__init__.py +0 -5
- evalscope/benchmarks/competition_math/competition_math_adapter.py +29 -371
- evalscope/benchmarks/data_adapter.py +114 -85
- evalscope/benchmarks/general_qa/__init__.py +0 -5
- evalscope/benchmarks/general_qa/general_qa_adapter.py +16 -19
- evalscope/benchmarks/gsm8k/__init__.py +0 -4
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +19 -98
- evalscope/benchmarks/hellaswag/__init__.py +0 -5
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +23 -96
- evalscope/benchmarks/humaneval/__init__.py +0 -4
- evalscope/benchmarks/humaneval/humaneval_adapter.py +16 -117
- evalscope/benchmarks/mmlu/__init__.py +0 -5
- evalscope/benchmarks/mmlu/mmlu_adapter.py +26 -48
- evalscope/benchmarks/mmlu_pro/__init__.py +0 -0
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +110 -0
- evalscope/benchmarks/race/__init__.py +0 -5
- evalscope/benchmarks/race/race_adapter.py +25 -53
- evalscope/benchmarks/trivia_qa/__init__.py +0 -5
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +24 -97
- evalscope/benchmarks/truthful_qa/__init__.py +0 -5
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +23 -33
- evalscope/collections/__init__.py +3 -0
- evalscope/collections/evaluator.py +178 -0
- evalscope/collections/sampler.py +132 -0
- evalscope/collections/schema.py +122 -0
- evalscope/config.py +10 -6
- evalscope/constants.py +7 -28
- evalscope/evaluator/evaluator.py +66 -108
- evalscope/evaluator/reviewer/auto_reviewer.py +12 -4
- evalscope/metrics/__init__.py +6 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +1 -1
- evalscope/metrics/math_accuracy.py +193 -50
- evalscope/metrics/metrics.py +7 -4
- evalscope/metrics/rouge_metric.py +13 -8
- evalscope/models/__init__.py +14 -1
- evalscope/models/base_adapter.py +52 -0
- evalscope/models/chat_adapter.py +138 -0
- evalscope/models/choice_adapter.py +211 -0
- evalscope/models/custom_adapter.py +67 -0
- evalscope/models/local_model.py +74 -0
- evalscope/models/model.py +141 -0
- evalscope/models/server_adapter.py +104 -0
- evalscope/perf/arguments.py +1 -0
- evalscope/perf/benchmark.py +1 -1
- evalscope/perf/main.py +3 -1
- evalscope/perf/plugin/api/openai_api.py +51 -47
- evalscope/perf/utils/local_server.py +1 -0
- evalscope/run.py +37 -66
- evalscope/run_arena.py +1 -1
- evalscope/utils/__init__.py +1 -1
- evalscope/utils/chat_service.py +4 -3
- evalscope/utils/io_utils.py +8 -0
- evalscope/utils/logger.py +4 -0
- evalscope/utils/model_utils.py +10 -0
- evalscope/utils/utils.py +3 -25
- evalscope/version.py +2 -2
- {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/METADATA +46 -17
- {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/RECORD +81 -92
- tests/cli/test_collection.py +53 -0
- tests/cli/test_run.py +43 -1
- tests/perf/test_perf.py +3 -3
- tests/rag/test_mteb.py +3 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +0 -87
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +0 -36
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +0 -26
- evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +0 -41
- evalscope/backend/rag_eval/ragas/prompts/chinese/CustomNodeFilter/scoring_prompt_chinese.json +0 -7
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +0 -60
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +0 -36
- evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +0 -24
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +0 -35
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +0 -30
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +0 -39
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +0 -30
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +0 -39
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +0 -34
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +0 -36
- evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +0 -25
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +0 -24
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +0 -39
- evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +0 -16
- evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +0 -24
- evalscope/models/api/__init__.py +0 -3
- evalscope/models/dummy_chat_model.py +0 -49
- evalscope/models/model_adapter.py +0 -525
- evalscope/models/openai_model.py +0 -103
- /evalscope/{models/api → third_party/longbench_write/tools}/openai_api.py +0 -0
- {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/LICENSE +0 -0
- {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/WHEEL +0 -0
- {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/top_level.txt +0 -0
evalscope/metrics/math_accuracy.py
CHANGED

@@ -1,57 +1,200 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-import re
-from collections import defaultdict
-from tqdm import tqdm
 
-[removed import line, truncated in the source view]
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
+def is_equiv(str1, str2, verbose=False):
+    if str1 is None and str2 is None:
+        print('WARNING: Both None')
+        return True
+    if str1 is None or str2 is None:
+        return False
 
+    try:
+        ss1 = strip_string(str1)
+        ss2 = strip_string(str2)
+        if verbose:
+            print(ss1, ss2)
+        return ss1 == ss2
+    except Exception:
+        return str1 == str2
 
-[removed lines 10-13, truncated in the source view]
+
+def remove_boxed(s):
+    if '\\boxed ' in s:
+        left = '\\boxed '
+        assert s[:len(left)] == left
+        return s[len(left):]
+
+    left = '\\boxed{'
+
+    assert s[:len(left)] == left
+    assert s[-1] == '}'
+
+    return s[len(left):-1]
+
+
+def last_boxed_only_string(string):
+    idx = string.rfind('\\boxed')
+    if '\\boxed ' in string:
+        return '\\boxed ' + string.split('\\boxed ')[-1].split('$')[0]
+    if idx < 0:
+        idx = string.rfind('\\fbox')
+        if idx < 0:
+            return None
+
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == '{':
+            num_left_braces_open += 1
+        if string[i] == '}':
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+
+    if right_brace_idx is None:
+        retval = None
     else:
-[removed lines 15-27, truncated in the source view]
+        retval = string[idx:right_brace_idx + 1]
+
+    return retval
+
+
+def fix_fracs(string):
+    substrs = string.split('\\frac')
+    new_str = substrs[0]
+    if len(substrs) > 1:
+        substrs = substrs[1:]
+        for substr in substrs:
+            new_str += '\\frac'
+            if substr[0] == '{':
+                new_str += substr
+            else:
+                try:
+                    assert len(substr) >= 2
+                except AssertionError:
+                    return string
+                a = substr[0]
+                b = substr[1]
+                if b != '{':
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += '{' + a + '}{' + b + '}' + post_substr
+                    else:
+                        new_str += '{' + a + '}{' + b + '}'
+                else:
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += '{' + a + '}' + b + post_substr
+                    else:
+                        new_str += '{' + a + '}' + b
+    string = new_str
+    return string
+
+
+def fix_a_slash_b(string):
+    if len(string.split('/')) != 2:
+        return string
+    a = string.split('/')[0]
+    b = string.split('/')[1]
+    try:
+        a = int(a)
+        b = int(b)
+        assert string == '{}/{}'.format(a, b)
+        new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
+        return new_string
+    except AssertionError:
+        return string
+
+
+def remove_right_units(string):
+    # "\\text{ " only ever occurs (at least in the val set) when describing units
+    if '\\text{ ' in string:
+        splits = string.split('\\text{ ')
+        assert len(splits) == 2
+        return splits[0]
     else:
-[removed lines 29-57, truncated in the source view]
+        return string
+
+
+def fix_sqrt(string):
+    if '\\sqrt' not in string:
+        return string
+    splits = string.split('\\sqrt')
+    new_string = splits[0]
+    for split in splits[1:]:
+        if split[0] != '{':
+            a = split[0]
+            new_substr = '\\sqrt{' + a + '}' + split[1:]
+        else:
+            new_substr = '\\sqrt' + split
+        new_string += new_substr
+    return new_string
+
+
+def strip_string(string):
+    # linebreaks
+    string = string.replace('\n', '')
+
+    # remove inverse spaces
+    string = string.replace('\\!', '')
+
+    # replace \\ with \
+    string = string.replace('\\\\', '\\')
+
+    # replace tfrac and dfrac with frac
+    string = string.replace('tfrac', 'frac')
+    string = string.replace('dfrac', 'frac')
+
+    # remove \left and \right
+    string = string.replace('\\left', '')
+    string = string.replace('\\right', '')
+
+    # Remove circ (degrees)
+    string = string.replace('^{\\circ}', '')
+    string = string.replace('^\\circ', '')
+
+    # remove dollar signs
+    string = string.replace('\\$', '')
+
+    # remove units (on the right)
+    string = remove_right_units(string)
+
+    # remove percentage
+    string = string.replace('\\%', '')
+    string = string.replace('\%', '')  # noqa: W605
+
+    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
+    string = string.replace(' .', ' 0.')
+    string = string.replace('{.', '{0.')
+    # if empty, return empty string
+    if len(string) == 0:
+        return string
+    if string[0] == '.':
+        string = '0' + string
+
+    # to consider: get rid of e.g. "k = " or "q = " at beginning
+    if len(string.split('=')) == 2:
+        if len(string.split('=')[0]) <= 2:
+            string = string.split('=')[1]
+
+    # fix sqrt3 --> sqrt{3}
+    string = fix_sqrt(string)
+
+    # remove spaces
+    string = string.replace(' ', '')
+
+    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}  # noqa: E501
+    string = fix_fracs(string)
+
+    # manually change 0.5 --> \frac{1}{2}
+    if string == '0.5':
+        string = '\\frac{1}{2}'
+
+    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
+    string = fix_a_slash_b(string)
+
+    return string
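For context, the helpers above are the standard Hendrycks MATH answer-normalization pipeline. A minimal usage sketch, assuming they are importable from evalscope.metrics.math_accuracy (the module listed in this release); the example strings are illustrative:

# Pull the last \boxed{...} span out of a completion, strip the wrapper, then compare
# after strip_string() normalization (which rewrites 0.5 to \frac{1}{2}).
from evalscope.metrics.math_accuracy import is_equiv, last_boxed_only_string, remove_boxed

completion = 'The probability is therefore $\\boxed{\\frac{1}{2}}$.'
reference = '0.5'

pred = last_boxed_only_string(completion)   # '\\boxed{\\frac{1}{2}}'
pred = remove_boxed(pred)                   # '\\frac{1}{2}'
print(is_equiv(pred, reference))            # True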
evalscope/metrics/metrics.py
CHANGED

@@ -2,16 +2,12 @@
 # Copyright (c) EleutherAI. and its affiliates.
 # Copyright (c) OpenAI. and its affiliates.
 import itertools
-import jieba
 import math
 import numpy as np
 import random
 import sacrebleu
-import sklearn.metrics
 from collections import defaultdict
 from collections.abc import Iterable
-from nltk import word_tokenize
-from nltk.translate.bleu_score import sentence_bleu
 from typing import Dict, List, Union
 
 
@@ -38,6 +34,8 @@ def median(arr):
 
 
 def matthews_corrcoef(items):
+    import sklearn.metrics
+
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
     preds = unzipped_list[1]
@@ -45,6 +43,8 @@ def matthews_corrcoef(items):
 
 
 def f1_score(items):
+    import sklearn.metrics
+
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
     preds = unzipped_list[1]
@@ -150,6 +150,9 @@ def bleu_ngram_one_sample(predict, reference):
     }
 
     """
+    import jieba
+    from nltk import word_tokenize
+    from nltk.translate.bleu_score import sentence_bleu
 
     def is_contains_chinese(strs):
         for _char in strs:
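The change here moves optional heavy dependencies (jieba, sklearn, nltk) from module level into the functions that use them, so importing evalscope.metrics no longer requires those packages to be installed. An illustrative sketch of the same lazy-import pattern (not evalscope code):

def matthews_corrcoef_lazy(golds, preds):
    # Deferred import: sklearn is only needed if this metric is actually computed,
    # so the enclosing module can be imported without sklearn present.
    import sklearn.metrics
    return sklearn.metrics.matthews_corrcoef(golds, preds)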
evalscope/metrics/rouge_metric.py
CHANGED

@@ -1,15 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import jieba
-import logging
 from collections import defaultdict
-from pathlib import Path
 from rouge_chinese import Rouge
 from statistics import mean
 from tqdm import tqdm
 
 from evalscope.constants import MetricsConstant
 from evalscope.metrics.bundled_rouge_score import rouge_scorer
+from evalscope.utils.logger import get_logger
+
+logger = get_logger()
 
 
 class DummyTokenizer:
@@ -18,10 +19,6 @@ class DummyTokenizer:
         return text.split()
 
 
-HERE = Path(__file__).absolute().parent
-
-logger = logging.getLogger(__name__)
-
 scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], tokenizer=DummyTokenizer())
 zh_scorer = Rouge()
 
@@ -58,7 +55,11 @@ def compute_rouge_score_one_sample_zh(predict, reference):
         p = ' '.join(jieba.cut(p)) if is_contains_chinese(p) else p
         r = ' '.join(jieba.cut(r)) if is_contains_chinese(r) else r
 
-[removed line, truncated in the source view]
+        try:
+            score = zh_scorer.get_scores(p, r, ignore_empty=True)[0]
+        except Exception as e:
+            logger.warning(f'rouge score error: {p} {r} {e}')
+            continue
         result['rouge-1-r'] = score['rouge-1']['r']
         result['rouge-1-p'] = score['rouge-1']['p']
         result['rouge-1-f'] = score['rouge-1']['f']
@@ -75,7 +76,11 @@ def compute_rouge_score_one_sample_zh(predict, reference):
 def compute_rouge_score_one_sample(predict, reference):
     result = dict()
     for p, r in zip(predict, reference):
-[removed line, truncated in the source view]
+        try:
+            score = scorer.score(p, r)
+        except Exception as e:
+            logger.warning(f'rouge score error: {p} {r} {e}')
+            continue
        result['rouge-1-r'] = score['rouge1'].recall
        result['rouge-1-p'] = score['rouge1'].precision
        result['rouge-1-f'] = score['rouge1'].fmeasure
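With the try/except above, a pair that the scorer rejects (for example an empty prediction) is logged and skipped instead of aborting the whole evaluation. A hedged usage sketch, assuming the function is importable from evalscope.metrics.rouge_metric and that the keys follow the 'rouge-1-*' naming shown in the diff:

from evalscope.metrics.rouge_metric import compute_rouge_score_one_sample_zh

# Chinese text is segmented with jieba before scoring; the strings here are illustrative.
scores = compute_rouge_score_one_sample_zh(['今天天气很好'], ['今天的天气不错'])
print(scores['rouge-1-f'])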
evalscope/models/__init__.py
CHANGED

@@ -1,3 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-from evalscope.models. [removed line, truncated in the source view]
+from evalscope.models.base_adapter import BaseModelAdapter, initialize_model_adapter
+from evalscope.models.chat_adapter import ChatGenerationModelAdapter
+from evalscope.models.choice_adapter import ContinuationLogitsModelAdapter, MultiChoiceModelAdapter
+from evalscope.models.custom import CustomModel
+from evalscope.models.custom_adapter import CustomModelAdapter
+from evalscope.models.local_model import LocalModel, get_local_model
+from evalscope.models.model import BaseModel, ChatBaseModel, OpenAIModel
+from evalscope.models.server_adapter import ServerModelAdapter
+
+__all__ = [
+    'CustomModel', 'BaseModel', 'ChatBaseModel', 'OpenAIModel', 'BaseModelAdapter', 'ChatGenerationModelAdapter',
+    'MultiChoiceModelAdapter', 'ContinuationLogitsModelAdapter', 'CustomModelAdapter', 'ServerModelAdapter',
+    'LocalModel', 'get_local_model', 'initialize_model_adapter'
+]
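These re-exports give the refactored adapters a flat public surface, so downstream code can import them from the package root. A one-line sketch (names taken from __all__ above; whether these are the recommended entry points is an assumption):

from evalscope.models import ChatGenerationModelAdapter, ServerModelAdapter, get_local_model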
evalscope/models/base_adapter.py
ADDED

@@ -0,0 +1,52 @@
+import torch
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from evalscope.constants import EvalType
+from evalscope.models.custom import CustomModel
+from evalscope.models.local_model import LocalModel
+
+if TYPE_CHECKING:
+    from evalscope.config import TaskConfig
+
+
+class BaseModelAdapter(ABC):
+
+    def __init__(self, model: Optional[Union[LocalModel, CustomModel]], **kwargs):
+        if model is None:
+            self.model_cfg = kwargs.get('model_cfg', None)
+        elif isinstance(model, LocalModel):
+            self.model = model.model
+            self.model_id = model.model_id
+            self.model_revision = model.model_revision
+            self.device = model.device
+            self.tokenizer = model.tokenizer
+            self.model_cfg = model.model_cfg
+        elif isinstance(model, CustomModel):
+            self.model_cfg = model.config
+        else:
+            raise ValueError(f'Unsupported model type: {type(model)}')
+
+    @abstractmethod
+    @torch.no_grad()
+    def predict(self, *args, **kwargs) -> Any:
+        raise NotImplementedError
+
+
+def initialize_model_adapter(task_cfg: 'TaskConfig', model_adapter_cls: 'BaseModelAdapter', base_model: 'LocalModel'):
+    """Initialize the model adapter based on the task configuration."""
+    if task_cfg.dry_run:
+        from evalscope.models.model import DummyChatModel
+        return DummyChatModel(model_cfg=dict())
+    elif task_cfg.eval_type == EvalType.CUSTOM:
+        if not isinstance(task_cfg.model, CustomModel):
+            raise ValueError(f'Expected evalscope.models.custom.CustomModel, but got {type(task_cfg.model)}.')
+        from evalscope.models import CustomModelAdapter
+        return CustomModelAdapter(custom_model=task_cfg.model)
+    elif task_cfg.eval_type == EvalType.SERVICE:
+        from evalscope.models import ServerModelAdapter
+        return ServerModelAdapter(
+            api_url=task_cfg.api_url, model_id=task_cfg.model, api_key=task_cfg.api_key, seed=task_cfg.seed)
+    else:
+        return model_adapter_cls(
+            model=base_model, generation_config=task_cfg.generation_config, chat_template=task_cfg.chat_template)
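initialize_model_adapter dispatches on the task configuration: dry_run returns a DummyChatModel, EvalType.CUSTOM wraps a user-supplied CustomModel, EvalType.SERVICE targets an OpenAI-compatible endpoint, and anything else instantiates the benchmark's adapter class over the local checkpoint. A sketch of the service branch, using only the keyword arguments visible above; the endpoint, model id, key and seed are placeholders:

from evalscope.models import ServerModelAdapter

adapter = ServerModelAdapter(
    api_url='http://127.0.0.1:8000/v1/chat/completions',  # placeholder endpoint
    model_id='my-served-model',                           # placeholder model id
    api_key='EMPTY',                                      # placeholder key
    seed=42)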
evalscope/models/chat_adapter.py
ADDED

@@ -0,0 +1,138 @@
+import os
+import time
+import torch
+from typing import Union
+
+from evalscope.models.base_adapter import BaseModelAdapter
+from evalscope.models.local_model import LocalModel
+from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage
+from evalscope.utils.logger import get_logger
+from evalscope.utils.model_utils import fix_do_sample_warning
+
+logger = get_logger()
+
+
+class ChatGenerationModelAdapter(BaseModelAdapter):
+    """
+    Chat generation model adapter.
+    """
+
+    def __init__(self, model: LocalModel, **kwargs):
+        super().__init__(model)
+
+        self.generation_config = self._parse_generation_config(self.tokenizer, self.model)
+
+        custom_generation_config = kwargs.pop('generation_config', None)
+        custom_chat_template = kwargs.pop('chat_template', None)
+
+        if custom_generation_config:
+            logger.info('Updating generation config ...')
+            self.generation_config.update(**custom_generation_config)
+
+        if custom_chat_template:
+            self.tokenizer.chat_template = custom_chat_template
+            logger.info(f'Using custom chat template: {custom_chat_template}')
+
+    def _parse_generation_config(self, tokenizer, model):
+        from modelscope import GenerationConfig
+
+        generation_config = getattr(model, 'generation_config', GenerationConfig(do_sample=False))
+
+        try:
+            remote_config = GenerationConfig.from_pretrained(
+                self.model_id, revision=self.model_revision, trust_remote_code=True)
+            generation_config.update(**remote_config.to_dict())
+        except Exception:
+            logger.warning(f'Failed to get generation config of {self.model_id} from model hub, use default.')
+
+        if isinstance(self.model_id, str) and os.path.exists(self.model_id):
+            logger.warning(f'Got local model dir: {self.model_id}')
+
+        if tokenizer.eos_token_id is not None:
+            generation_config.eos_token_id = tokenizer.eos_token_id
+        if tokenizer.pad_token_id is not None:
+            generation_config.pad_token_id = tokenizer.pad_token_id
+        if generation_config.max_new_tokens is None:
+            generation_config.max_new_tokens = 2048
+
+        return generation_config
+
+    def _model_generate(self, query: str, system_prompt: str = None, infer_cfg: dict = {}) -> str:
+        """
+        Args:
+            query: The input query.
+            system_prompt: The system prompt.
+            infer_cfg: The inference configuration.
+        Returns:
+            The prediction result.
+        """
+        # For chat model, use the chat template to format the input
+        if self.tokenizer.chat_template is not None:
+            messages = [ChatMessage(role='user', content=query)]
+            if system_prompt:
+                messages = [ChatMessage(role='system', content=system_prompt)] + messages
+            formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        else:
+            # For base model, use the query as the input
+            formatted_prompt = query
+
+        inputs = self.tokenizer(formatted_prompt, return_tensors='pt', padding=True).to(self.device)
+        input_ids = inputs['input_ids']
+
+        # Process infer_cfg
+        if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1:
+            infer_cfg['do_sample'] = True
+
+        # stop settings
+        stop = infer_cfg.get('stop', None)
+        eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \
+            if stop else self.tokenizer.eos_token_id
+
+        if eos_token_id is not None:
+            infer_cfg['eos_token_id'] = eos_token_id
+            infer_cfg['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
+
+        self.generation_config.update(**infer_cfg)
+        fix_do_sample_warning(self.generation_config)
+
+        # Run inference
+        output_ids = self.model.generate(**inputs, generation_config=self.generation_config)
+
+        response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], skip_special_tokens=True)
+        return response
+
+    @torch.no_grad()
+    def predict(self, inputs: Union[str, dict, list], infer_cfg: dict = {}) -> dict:
+        """
+        Args:
+            inputs: The input data.
+            infer_cfg: The inference configuration.
+        Returns:
+            The prediction result.
+        """
+
+        # Process inputs
+        if isinstance(inputs, str):
+            query = inputs
+            system_prompt = None
+        elif isinstance(inputs, dict):
+            query = inputs['data'][0]
+            system_prompt = inputs.get('system_prompt', None)
+        elif isinstance(inputs, list):
+            query = '\n'.join(inputs)
+            system_prompt = None
+        else:
+            raise TypeError(f'Unsupported inputs type: {type(inputs)}')
+
+        response = self._model_generate(query, system_prompt, infer_cfg)
+
+        choices_list = [
+            ChatCompletionResponseChoice(
+                index=0, message=ChatMessage(content=response, role='assistant'), finish_reason='stop')
+        ]
+
+        res_d = ChatCompletionResponse(
+            model=self.model_id, choices=choices_list, object='chat.completion', created=int(time.time()),
+            usage=None).model_dump(exclude_unset=True)
+
+        return res_d
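predict() accepts a plain string, a dict with a 'data' list plus optional 'system_prompt', or a list of strings, and returns an OpenAI-style chat.completion dict. A hedged call sketch; constructing the underlying LocalModel is not shown in this diff and is assumed here:

from evalscope.models import ChatGenerationModelAdapter

# `local_model` stands for an already-built evalscope.models.LocalModel wrapping a local
# checkpoint (e.g. obtained via get_local_model); its construction is assumed, not shown.
adapter = ChatGenerationModelAdapter(model=local_model, generation_config={'max_new_tokens': 64})

res = adapter.predict(
    {'data': ['Translate "good morning" into French.'],
     'system_prompt': 'You are a helpful assistant.'},
    infer_cfg={'do_sample': False})
print(res['choices'][0]['message']['content'])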