evalscope 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +3 -0
- evalscope/backend/__init__.py +3 -0
- evalscope/backend/base.py +27 -0
- evalscope/backend/opencompass/__init__.py +3 -0
- evalscope/backend/opencompass/api_meta_template.py +64 -0
- evalscope/backend/opencompass/backend_manager.py +247 -0
- evalscope/backend/opencompass/tasks/__init__.py +1 -0
- evalscope/backend/opencompass/tasks/eval_api.py +30 -0
- evalscope/backend/opencompass/tasks/eval_datasets.py +71 -0
- evalscope/backend/vlm_eval_kit/__init__.py +1 -0
- evalscope/backend/vlm_eval_kit/backend_manager.py +153 -0
- evalscope/benchmarks/__init__.py +4 -0
- evalscope/benchmarks/arc/__init__.py +5 -0
- evalscope/benchmarks/arc/ai2_arc.py +148 -0
- evalscope/benchmarks/arc/arc_adapter.py +231 -0
- evalscope/benchmarks/bbh/__init__.py +6 -0
- evalscope/benchmarks/bbh/bbh_adapter.py +308 -0
- evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +23 -0
- evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +33 -0
- evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +72 -0
- evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +78 -0
- evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +42 -0
- evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/navigate.txt +43 -0
- evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +41 -0
- evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +63 -0
- evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/snarks.txt +30 -0
- evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +10 -0
- evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +77 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +17 -0
- evalscope/benchmarks/benchmark.py +65 -0
- evalscope/benchmarks/ceval/__init__.py +5 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +340 -0
- evalscope/benchmarks/ceval/ceval_exam.py +159 -0
- evalscope/benchmarks/cmmlu/__init__.py +5 -0
- evalscope/benchmarks/cmmlu/cmmlu.py +166 -0
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +369 -0
- evalscope/benchmarks/competition_math/__init__.py +5 -0
- evalscope/benchmarks/competition_math/competition_math.py +88 -0
- evalscope/benchmarks/competition_math/competition_math_adapter.py +470 -0
- evalscope/benchmarks/data_adapter.py +263 -0
- evalscope/benchmarks/general_qa/__init__.py +5 -0
- evalscope/benchmarks/general_qa/general_qa_adapter.py +186 -0
- evalscope/benchmarks/gsm8k/__init__.py +5 -0
- evalscope/benchmarks/gsm8k/gsm8k.py +127 -0
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +236 -0
- evalscope/benchmarks/hellaswag/__init__.py +5 -0
- evalscope/benchmarks/hellaswag/hellaswag.py +116 -0
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +222 -0
- evalscope/benchmarks/humaneval/__init__.py +5 -0
- evalscope/benchmarks/humaneval/humaneval.py +82 -0
- evalscope/benchmarks/humaneval/humaneval_adapter.py +21 -0
- evalscope/benchmarks/mmlu/__init__.py +5 -0
- evalscope/benchmarks/mmlu/mmlu.py +174 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +375 -0
- evalscope/benchmarks/race/__init__.py +5 -0
- evalscope/benchmarks/race/race.py +118 -0
- evalscope/benchmarks/race/race_adapter.py +229 -0
- evalscope/benchmarks/trivia_qa/__init__.py +5 -0
- evalscope/benchmarks/trivia_qa/trivia_qa.py +104 -0
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +207 -0
- evalscope/benchmarks/truthful_qa/__init__.py +5 -0
- evalscope/benchmarks/truthful_qa/truthful_qa.py +167 -0
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +351 -0
- evalscope/cache.py +98 -0
- evalscope/cli/__init__.py +1 -0
- evalscope/cli/base.py +20 -0
- evalscope/cli/cli.py +26 -0
- evalscope/cli/start_perf.py +37 -0
- evalscope/cli/start_server.py +138 -0
- evalscope/config.py +165 -0
- evalscope/constants.py +150 -0
- evalscope/evaluator/__init__.py +3 -0
- evalscope/evaluator/evaluator.py +689 -0
- evalscope/evaluator/rating_eval.py +178 -0
- evalscope/evaluator/reviewer/__init__.py +1 -0
- evalscope/evaluator/reviewer/auto_reviewer.py +411 -0
- evalscope/metrics/__init__.py +1 -0
- evalscope/metrics/bundled_rouge_score/__init__.py +14 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +342 -0
- evalscope/metrics/code_metric.py +104 -0
- evalscope/metrics/math_accuracy.py +60 -0
- evalscope/metrics/metrics.py +405 -0
- evalscope/metrics/rouge_metric.py +129 -0
- evalscope/models/__init__.py +4 -0
- evalscope/models/custom/__init__.py +4 -0
- evalscope/models/custom/custom_model.py +53 -0
- evalscope/models/dummy_chat_model.py +50 -0
- evalscope/models/model.py +88 -0
- evalscope/models/model_adapter.py +586 -0
- evalscope/models/openai_model.py +103 -0
- evalscope/models/template.py +1446 -0
- evalscope/perf/__init__.py +0 -0
- evalscope/perf/_logging.py +32 -0
- evalscope/perf/api_plugin_base.py +60 -0
- evalscope/perf/custom_api.py +87 -0
- evalscope/perf/dashscope_api.py +84 -0
- evalscope/perf/dataset_plugin_base.py +64 -0
- evalscope/perf/datasets/__init__.py +0 -0
- evalscope/perf/datasets/line_by_line.py +18 -0
- evalscope/perf/datasets/longalpaca_12k.py +20 -0
- evalscope/perf/datasets/openqa.py +22 -0
- evalscope/perf/how_to_analysis_result.py +24 -0
- evalscope/perf/http_client.py +756 -0
- evalscope/perf/openai_api.py +130 -0
- evalscope/perf/plugin_registry.py +35 -0
- evalscope/perf/query_parameters.py +42 -0
- evalscope/perf/server_sent_event.py +43 -0
- evalscope/preprocess/__init__.py +1 -0
- evalscope/preprocess/tokenizers/__init__.py +0 -0
- evalscope/preprocess/tokenizers/gpt2_tokenizer.py +221 -0
- evalscope/registry/__init__.py +1 -0
- evalscope/registry/tasks/arc.yaml +29 -0
- evalscope/registry/tasks/bbh.yaml +27 -0
- evalscope/registry/tasks/bbh_mini.yaml +27 -0
- evalscope/registry/tasks/ceval.yaml +27 -0
- evalscope/registry/tasks/ceval_mini.yaml +27 -0
- evalscope/registry/tasks/cmmlu.yaml +27 -0
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +28 -0
- evalscope/registry/tasks/general_qa.yaml +27 -0
- evalscope/registry/tasks/gsm8k.yaml +29 -0
- evalscope/registry/tasks/mmlu.yaml +29 -0
- evalscope/registry/tasks/mmlu_mini.yaml +27 -0
- evalscope/run.py +404 -0
- evalscope/run_arena.py +204 -0
- evalscope/run_ms.py +140 -0
- evalscope/summarizer.py +144 -0
- evalscope/third_party/__init__.py +1 -0
- evalscope/third_party/toolbench_static/__init__.py +3 -0
- evalscope/third_party/toolbench_static/eval.py +219 -0
- evalscope/third_party/toolbench_static/infer.py +278 -0
- evalscope/third_party/toolbench_static/llm/__init__.py +1 -0
- evalscope/third_party/toolbench_static/llm/swift_infer.py +45 -0
- evalscope/third_party/toolbench_static/toolbench_static.py +50 -0
- evalscope/tools/__init__.py +1 -0
- evalscope/tools/combine_reports.py +140 -0
- evalscope/tools/gen_mmlu_subject_mapping.py +90 -0
- evalscope/tools/rewrite_eval_results.py +95 -0
- evalscope/utils/__init__.py +4 -0
- evalscope/utils/arena_utils.py +247 -0
- evalscope/utils/completion_parsers.py +87 -0
- evalscope/utils/logger.py +64 -0
- evalscope/utils/task_cfg_parser.py +10 -0
- evalscope/utils/task_utils.py +19 -0
- evalscope/utils/utils.py +625 -0
- evalscope/version.py +4 -0
- evalscope-0.5.0.dist-info/METADATA +566 -0
- evalscope-0.5.0.dist-info/RECORD +165 -0
- evalscope-0.5.0.dist-info/WHEEL +5 -0
- evalscope-0.5.0.dist-info/entry_points.txt +3 -0
- evalscope-0.5.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,219 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import json
+from dataclasses import dataclass
+
+from rouge import Rouge
+import os
+
+
+@dataclass
+class EvalArgs:
+    input_path: str
+    output_path: str
+
+
+def run_eval(args: EvalArgs):
+    print(f'*** Start evaluation with eval args: {args}\n')
+
+    args.input_path = os.path.join(args.input_path, 'predictions.json')
+    args.output_path = os.path.join(args.output_path, 'metrics.json')
+
+    def evaluate_rougel(cand_list: list, ref_list: list):
+        if len(ref_list) == 0:
+            return 0
+        rouge = Rouge()
+        rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
+        rougel = rouge_score["rouge-l"]["f"]
+        return rougel
+
+    def evaluate_action_em(cand_list: list, ref_list: list):
+        if len(ref_list) == 0:
+            return 0
+        em = 0
+        for cand, ref in zip(cand_list, ref_list):
+            em += (1 if cand == ref else 0)
+        return em / len(cand_list)
+
+    def evaluate_action_input_f1(action_pred: list, action_ref: list, cand_list: list, ref_list: list):
+        easy_f1 = []
+        hard_f1 = []
+        f1 = []
+        for i in range(len(action_pred)):
+            ref_action = action_ref[i]
+            pred_action = action_pred[i]
+
+            ref_input = ref_list[i]
+            cand_input = cand_list[i]
+
+            if ref_action != pred_action:
+                easy_f1.append(0)
+                hard_f1.append(0)
+                f1.append(0)
+            else:
+                try:
+                    ref_input_json = json.loads(ref_input)
+                    try:
+                        cand_input_json = json.loads(cand_input)
+                        half_match = 0
+                        full_match = 0
+                        if ref_input_json == {}:
+                            if cand_input_json == {}:
+                                easy_f1.append(1)
+                                f1.append(1)
+                            else:
+                                easy_f1.append(0)
+                                f1.append(0)
+                        else:
+                            for k, v in ref_input_json.items():
+                                if k in cand_input_json.keys():
+                                    if cand_input_json[k] == v:
+                                        full_match += 1
+                                    else:
+                                        half_match += 1
+
+                            recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30)
+                            precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30)
+                            hard_f1.append((2 * recall * precision) / (recall + precision))
+                            f1.append((2 * recall * precision) / (recall + precision))
+                    except:
+                        # cand_input = cand_input.replace("\n","").replace("\"","")
+                        # ref_input = cand_input.replace("\n","").replace("\"","")
+                        # rouge = Rouge()
+                        # rouge_score = rouge.get_scores(hyps=[cand_input], refs=[ref_input], avg=True)
+                        if ref_input_json == {}:
+                            easy_f1.append(0)
+                        else:
+                            hard_f1.append(0)
+                        # hard_f1.append(rouge_score["rouge-l"]["f"])
+                        # f1.append(rouge_score["rouge-l"]["f"])
+                        f1.append(0)
+                except:
+                    pass
+
+        return sum(easy_f1) / len(easy_f1) + 1e-30, sum(hard_f1) / len(hard_f1) + 1e-30, sum(f1) / len(f1) + 1e-30
+
+    with open(args.input_path, encoding='utf-8') as f:
+        data = json.load(f)
+
+    def parse_action(text):
+        action = "None"
+        action_input = "{}"
+        if 'Action Input:' in text:
+            input_idx = text.rindex('Action Input:')
+            action_input = text[input_idx + len('Action Input:'):].strip()
+        else:
+            action_input = '{}'
+
+        if 'Action:' in text:
+            action_idx = text.rindex('Action:')
+            action = text[action_idx + len('Action:'):].strip()
+            if 'Action Input:' in action:
+                input_idx = action.index('Action Input:')
+                action = action[:input_idx].strip()
+        else:
+            action = 'none'
+        return action, action_input
+
+    def parse_output(text):
+        action, action_input = parse_action(text)
+        if action == "Finish":
+            try:
+                action_input = json.loads(action_input)
+                # print(action_input)
+                # print(json.dumps(action_input,indent=2))
+                return_type = action_input["return_type"]
+                if return_type == "give_answer":
+                    if "final_answer" in action_input.keys():
+                        answer = str(action_input['final_answer'])
+                        if answer.strip() in ['', '.', ',']:
+                            answer = "None"
+                    else:
+                        answer = "None"
+                    return "finish", action, action_input, answer
+                else:
+                    return "give up", None, None, None
+            except:
+                return "give up", None, None, None
+        else:
+            plan = 'call'
+            answer = None
+        return plan, action, action_input, answer
+
+    plan_ref = []
+    plan_pred = []
+    hallu_cases = []
+    error_cases = []
+    new_data = []
+    answer_ref = []
+    action_ref = []
+    action_input_ref = []
+    hallu_ref = 0
+    answer_pred = []
+    action_pred = []
+    action_input_pred = []
+    hallu_pred = 0
+    for d in data:
+        reference = d['target']
+        prediction = d['predictions']
+        ref_plan, ref_action, ref_input, ref_ans = parse_output(reference)
+        # ref_plan: call
+        # ref_action: spott
+        # ref_input: {"is_id": "city center" }
+        # ref_ans: None
+
+        pred_plan, pred_action, pred_input, pred_ans = parse_output(prediction)
+        if ref_action is not None and ref_action == "invalid_hallucination_function_name":
+            continue
+        if pred_action is not None and ref_action != 'none' and ref_action not in [t['name'] for t in d['tools']]:
+            continue
+
+        if pred_action is not None and pred_action != 'none' and pred_action not in [t['name'] for t in d['tools']]:
+            hallu_pred += 1
+            hallu_cases.append(d)
+
+        plan_ref.append(ref_plan)
+        plan_pred.append(pred_plan)
+        if ref_plan == 'give up':
+            pass
+        elif ref_plan == 'finish':
+            answer_ref.append(ref_ans)
+            if pred_ans is None:
+                answer_pred.append('none')
+            else:
+                answer_pred.append(pred_ans)
+        else:
+            action_ref.append(ref_action)
+            action_input_ref.append(ref_input)
+            if pred_action is None:
+                action_pred.append('none')
+            else:
+                action_pred.append(pred_action)
+
+            if pred_input is None:
+                action_input_pred.append('{}')
+            else:
+                action_input_pred.append(pred_input)
+
+    metric = {}
+    rouge = evaluate_rougel(answer_pred, answer_ref)
+    plan_em = evaluate_action_em(cand_list=plan_pred, ref_list=plan_ref)
+    action_em = evaluate_action_em(cand_list=action_pred, ref_list=action_ref)
+    easy_f1, hard_f1, f1 = evaluate_action_input_f1(action_pred, action_ref, action_input_pred, action_input_ref)
+    hallu_rate = hallu_pred / len(data)
+    metric['rouge'] = rouge
+    metric['plan_em'] = plan_em
+    metric['action_em'] = action_em
+    metric['easy_f1'] = easy_f1
+    metric['hard_f1'] = hard_f1
+    metric['f1'] = f1
+    metric['hallu_rate'] = hallu_rate
+
+    if not os.path.exists(os.path.dirname(args.output_path)):
+        os.makedirs(os.path.dirname(args.output_path))
+    print(metric)
+    with open(args.output_path, 'w', encoding='utf-8') as f:
+        json.dump(metric, f, indent=2)
+
+    with open(args.output_path.replace('metrics.json', 'hallu_cases.json'), 'w', encoding='utf-8') as f:
+        json.dump(hallu_cases, f, indent=2)
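The +219-line hunk above is, by the manifest line counts, evalscope/third_party/toolbench_static/eval.py. Its run_eval(EvalArgs) reads predictions.json from input_path, parses the ToolBench-style 'Action:' / 'Action Input:' fields out of both target and prediction, and writes metrics.json plus hallu_cases.json under output_path. A minimal invocation sketch follows; the directory names are placeholders, not paths shipped with the package.

from evalscope.third_party.toolbench_static.eval import EvalArgs, run_eval

# Assumed layout: outputs/in_domain/predictions.json was already produced by run_infer.
args = EvalArgs(
    input_path='outputs/in_domain',    # placeholder directory containing predictions.json
    output_path='outputs/in_domain')   # metrics.json and hallu_cases.json are written here
run_eval(args)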
@@ -0,0 +1,278 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+from dataclasses import dataclass, field
+import json
+import os
+from rouge import Rouge
+import time
+from urllib3.exceptions import MaxRetryError, NewConnectionError
+import requests
+
+
+def evaluate_rouge_l(cand_list: list, ref_list: list):
+    if len(ref_list) == 0:
+        return 0
+    rouge = Rouge()
+    rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
+    rougel = rouge_score["rouge-l"]["f"]
+    return rougel
+
+
+def nested_load_test_data(data_path):
+    test_raw_data = []
+    if os.path.isdir(data_path):
+        for f in os.listdir(data_path):
+            temp_test = nested_load_test_data(os.path.join(data_path, f))
+            test_raw_data += temp_test
+        return test_raw_data
+    elif os.path.isfile(data_path) and data_path.endswith('.json'):
+        print("Load data from", data_path)
+        temp_data = json.load(open(data_path, "r"))
+        test_raw_data = temp_data
+        return test_raw_data
+    else:
+        return []
+
+
+def baichuan_call(context: list, system: str):
+    url = "https://api.baichuan-ai.com/v1/chat/completions"
+    api_key = "sk-xxx"
+
+    new_msg = []
+    new_msg.append({
+        "role": 'system',
+        'content': system})
+    for m in context:
+        if m['role'] == "user":
+            new_msg.append({
+                'role': 'user', 'content': m['content']
+            })
+        elif m['role'] == "function":
+            new_msg.append({
+                'role': 'user', 'content': m['content']
+            })
+        elif m['role'] == 'assistant':
+            new_msg.append({
+                'role': 'assistant', 'content': m['content']
+            })
+    # print(json.dumps(new_msg, indent=2))
+    data = {
+        "model": "Baichuan2-Turbo",
+        "messages": new_msg,
+        "stream": False
+    }
+
+    json_data = json.dumps(data)
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": "Bearer " + api_key
+    }
+
+    for i in range(5):
+        res = None
+        try:
+            res = requests.post(url, data=json_data, headers=headers, timeout=60)
+            res = res._content.decode('utf-8')
+            res = json.loads(res)
+            return res["choices"][0]["message"]["content"]
+        except KeyError:
+            print(res)
+            time.sleep(1)
+            continue
+        except ConnectionError:
+            time.sleep(5)
+            continue
+        except MaxRetryError:
+            time.sleep(5)
+            continue
+        except NewConnectionError:
+            time.sleep(5)
+            continue
+    return ""
+
+
+def minimax_call(context: list, system: str):
+    group_id = "your-id"
+    api_key = "your-xxx"
+
+    url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"
+    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+
+    # construct message
+    system_prompt = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。" \
+                    "MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。"
+    system_prompt += ('\n' + system)
+
+    new_msg = []
+    for m in context:
+        if m['role'] == "user":
+            new_msg.append({
+                'sender_type': 'USER', 'sender_name': 'user', 'text': m['content']
+            })
+        elif m['role'] == "function":
+            new_msg.append({
+                'sender_type': 'USER', 'sender_name': 'funtion', 'text': m['content']
+            })
+        elif m['role'] == 'assistant':
+            new_msg.append({
+                'sender_type': 'BOT', 'sender_name': 'MM智能助理', 'text': m['content']
+            })
+
+    request_body = {
+        "model": "abab6-chat",
+        # "model": "abab5.5s-chat",
+        "tokens_to_generate": 8192,
+        "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
+        "messages": new_msg,
+        "bot_setting": [
+            {
+                "bot_name": "MM智能助理",
+                "content": system_prompt,
+            }
+        ],
+    }
+    response = requests.post(url, headers=headers, json=request_body)
+    status_code = response.status_code
+    for i in range(5):
+        try:
+            if status_code == 200:
+                reply = response.json()["reply"]
+                if len(reply) == 0:
+                    print("limit rate")
+                    time.sleep(8)
+                    continue
+                print(f'>>return: {reply}')
+                return reply
+            else:
+                print(response._content)
+                time.sleep(5)
+        except KeyError:
+            print(response)
+            time.sleep(5)
+            continue
+    return ""
+
+
+def swift_call(context: list, system: str, swift_infer_obj):
+    query_d: dict = context[-1]
+    history_list = context[: -1]
+
+    query: str = query_d['content']
+    history_msg = []
+
+    tmp_list = []
+    for idx, item in enumerate(history_list):
+
+        if idx % 2 == 0:
+            tmp_list.append(item['content'])
+        else:
+            tmp_list.append(item['content'])
+            history_msg.append(tuple(tmp_list))
+            tmp_list = []
+
+    try:
+        resp_str: str = swift_infer_obj.predict(system=system, query=query, history=history_msg)
+    except Exception as e:
+        print(e)
+        resp_str = ''
+
+    return resp_str
+
+
+@dataclass
+class InferArgs:
+    model_name_or_path: str
+    model_type: str
+    data_path: str
+    output_dir: str
+    deploy_type: str
+    max_new_tokens: int = 2048
+    num_infer_samples: int = None
+
+
+def run_infer(args: InferArgs):
+
+    if args.deploy_type == 'swift':
+        from evalscope.third_party.toolbench_static.llm.swift_infer import SwiftInfer, SwiftInferArgs
+        swift_infer_args = SwiftInferArgs(model_id_or_path=args.model_name_or_path,
+                                          model_type=args.model_type,
+                                          max_new_tokens=args.max_new_tokens)
+        swift_infer = SwiftInfer(args=swift_infer_args)
+    else:
+        swift_infer = None
+
+    # load data
+    infer_samples = nested_load_test_data(args.data_path)
+    if args.num_infer_samples is not None:
+        infer_samples = infer_samples[:args.num_infer_samples]
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    if os.path.exists(os.path.join(args.output_dir, 'predictions.json')):
+        with open(os.path.join(args.output_dir, 'predictions.json')) as f:
+            processed_samples = json.load(f)
+    else:
+        processed_samples = []
+    preds = []
+    refs = []
+    for i, o in enumerate(infer_samples):
+        if i < len(processed_samples) and "predictions" in processed_samples[i].keys():
+            infer_samples[i]['predictions'] = processed_samples[i]['predictions']
+            refs.append(processed_samples[i]['target'])
+            preds.append(processed_samples[i]['predictions'])
+            continue
+
+        system = o['messages'][0]['content']
+        new_msg = o['messages'][1:]
+
+        print('================================')
+        print('case', str(i))
+
+        if args.deploy_type == 'minimax':
+            response_text = minimax_call(new_msg, system)
+        # elif model_args.model_type == 'xingchen':
+        #     response_text = spark_call(new_msg, system)
+        # elif model_args.model_type == 'xingchen_v2':
+        #     response_text = spark_call_v2(new_msg, system, model_args)
+        elif args.deploy_type == 'baichuan':
+            response_text = baichuan_call(new_msg, system)
+        elif args.deploy_type == 'swift':
+            assert swift_infer is not None, 'ModelScope Swift infer process is not initialized.'
+            response_text = swift_call(new_msg, system, swift_infer)
+        else:
+            raise NotImplementedError
+
+        candidate = response_text
+        print(candidate)
+        if candidate.startswith(': '):
+            candidate = candidate[2:]
+        if candidate.strip() in ['', '.', ',']:
+            candidate = 'none'
+        reference = infer_samples[i]['target']
+        infer_samples[i]['predictions'] = candidate
+        if reference.strip() in ['', '.', ',']:
+            reference = "none"
+        refs.append(reference)
+        preds.append(candidate)
+
+        with open(os.path.join(args.output_dir, 'predictions.json'), 'w') as f:
+            json.dump(infer_samples[:i + 1], f, indent=4)
+
+    rouge_l = round(evaluate_rouge_l(preds, refs), 2)
+    print('\n*** Overall rouge:', rouge_l)
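This +278-line hunk (evalscope/third_party/toolbench_static/infer.py, matching the manifest counts) adds run_infer(InferArgs), which loads JSON samples via nested_load_test_data, dispatches each conversation to a backend chosen by deploy_type ('swift', 'baichuan' or 'minimax'), and rewrites predictions.json after every case so an interrupted run can resume from the already processed samples. A hedged usage sketch; the model path, model_type and data locations below are placeholders:

from evalscope.third_party.toolbench_static.infer import InferArgs, run_infer

args = InferArgs(
    model_name_or_path='/path/to/local/checkpoint',   # placeholder; loaded only when deploy_type='swift'
    model_type='qwen-7b-chat',                        # placeholder ModelScope Swift model_type
    data_path='data/toolbench/in_domain.json',        # placeholder .json file (or a directory of .json files)
    output_dir='outputs/in_domain',                   # predictions.json is checkpointed here after each sample
    deploy_type='swift',                              # 'baichuan' / 'minimax' call the hosted APIs instead
    num_infer_samples=10)                             # optional cap for a quick smoke test
run_infer(args)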
@@ -0,0 +1 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
@@ -0,0 +1,45 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass
+
+from swift.llm import (
+    get_model_tokenizer, get_template, inference, get_default_template_type,
+)
+from swift.utils import seed_everything
+
+# TODO: Support custom model for swift infer
+
+
+@dataclass
+class SwiftInferArgs:
+    model_id_or_path: str
+    model_type: str
+    max_new_tokens: int = 2048
+
+
+class SwiftInfer:
+
+    def __init__(self, args: SwiftInferArgs):
+        model_type = args.model_type
+        template_type = get_default_template_type(model_type)
+        model, tokenizer = get_model_tokenizer(model_type,
+                                               model_id_or_path=args.model_id_or_path,
+                                               model_kwargs={'device_map': 'auto'})
+        model.generation_config.max_new_tokens = args.max_new_tokens
+        print(f'** Generation config: {model.generation_config}')
+
+        template = get_template(template_type, tokenizer)
+        seed_everything(42)
+
+        self.tokenizer = tokenizer
+        self.model = model
+        self.template = template
+
+    def predict(self, system: str, query: str, history: list):
+
+        response, history = inference(self.model,
+                                      self.template,
+                                      query=query,
+                                      system=system,
+                                      history=history)
+
+        return response
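The +45-line hunk above (llm/swift_infer.py per the manifest) wraps ModelScope Swift: SwiftInfer loads a model and chat template through get_model_tokenizer and get_template, and answers one query at a time through inference(). A minimal standalone sketch, assuming Swift is installed and the placeholder model below is available:

from evalscope.third_party.toolbench_static.llm.swift_infer import SwiftInfer, SwiftInferArgs

infer = SwiftInfer(SwiftInferArgs(
    model_id_or_path='qwen/Qwen-7B-Chat',    # placeholder ModelScope model id
    model_type='qwen-7b-chat'))              # placeholder Swift model_type
reply = infer.predict(
    system='You are a tool-using assistant.',
    query='Which tool should be called first?',
    history=[])                              # earlier turns as (user, assistant) tuples
print(reply)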
@@ -0,0 +1,50 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from typing import Union
+from copy import deepcopy
+
+from evalscope.third_party.toolbench_static.infer import InferArgs, run_infer
+from evalscope.third_party.toolbench_static.eval import EvalArgs, run_eval
+from evalscope.utils import yaml_to_dict, get_logger, json_to_dict
+
+logger = get_logger()
+
+
+def run_task(task_cfg: Union[str, dict]):
+
+    if isinstance(task_cfg, str):
+        if task_cfg.endswith('.yaml'):
+            task_cfg: dict = yaml_to_dict(task_cfg)
+        elif task_cfg.endswith('.json'):
+            task_cfg: dict = json_to_dict(task_cfg)
+        else:
+            raise ValueError(f'Unsupported file format: {task_cfg}, should be yaml or json file.')
+
+    # Run inference for each domain
+    infer_args: dict = task_cfg['infer_args']
+    for domain in ['in_domain', 'out_of_domain']:
+        domain_infer_args = deepcopy(infer_args)
+        domain_infer_args.update({'data_path': os.path.join(infer_args['data_path'], f'{domain}.json')})
+        domain_infer_args.update({'output_dir': os.path.join(infer_args['output_dir'], domain)})
+
+        task_infer_args = InferArgs(**domain_infer_args)
+        print(f'**Run infer config: {task_infer_args}')
+        run_infer(task_infer_args)
+
+    # Run evaluation for each domain
+    eval_args: dict = task_cfg['eval_args']
+    for domain in ['in_domain', 'out_of_domain']:
+        domain_eval_args = deepcopy(eval_args)
+        domain_eval_args.update({'input_path': os.path.join(eval_args['input_path'], domain)})
+        domain_eval_args.update({'output_path': os.path.join(eval_args['output_path'], domain)})
+
+        task_eval_args = EvalArgs(**domain_eval_args)
+        print(f'**Run eval config: {task_eval_args}')
+        run_eval(task_eval_args)
+
+
+if __name__ == '__main__':
+    # task_cfg_file = 'config_default.yaml'
+    task_cfg_file = 'config_default.json'
+
+    run_task(task_cfg=task_cfg_file)
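The +50-line hunk above (toolbench_static.py, the package's entry point for this benchmark) adds run_task, which accepts either a YAML/JSON path or a dict with 'infer_args' and 'eval_args' blocks, then runs inference followed by evaluation for the 'in_domain' and 'out_of_domain' splits, appending the split name to data_path, output_dir, input_path and output_path itself. A sketch of a config dict it would accept; every path and model name is a placeholder:

from evalscope.third_party.toolbench_static.toolbench_static import run_task

task_cfg = {
    'infer_args': {
        'model_name_or_path': '/path/to/local/checkpoint',  # placeholder
        'model_type': 'qwen-7b-chat',                        # placeholder Swift model_type
        'data_path': 'data/toolbench',                       # run_task appends in_domain.json / out_of_domain.json
        'output_dir': 'outputs/toolbench',                   # per-split subdirectories are created
        'deploy_type': 'swift',
    },
    'eval_args': {
        'input_path': 'outputs/toolbench',    # should match output_dir above
        'output_path': 'outputs/toolbench',   # metrics.json / hallu_cases.json written per split
    },
}
run_task(task_cfg=task_cfg)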
@@ -0,0 +1 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.