evalscope 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of evalscope might be problematic.

Files changed (25)
  1. evalscope/backend/opencompass/backend_manager.py +2 -0
  2. evalscope/backend/vlm_eval_kit/backend_manager.py +1 -1
  3. evalscope/benchmarks/benchmark.py +1 -1
  4. evalscope/evaluator/evaluator.py +3 -3
  5. evalscope/models/api/__init__.py +3 -0
  6. evalscope/models/api/openai_api.py +228 -0
  7. evalscope/perf/http_client.py +5 -5
  8. evalscope/third_party/longbench_write/__init__.py +3 -0
  9. evalscope/third_party/longbench_write/eval.py +284 -0
  10. evalscope/third_party/longbench_write/infer.py +217 -0
  11. evalscope/third_party/longbench_write/longbench_write.py +88 -0
  12. evalscope/third_party/longbench_write/resources/__init__.py +1 -0
  13. evalscope/third_party/longbench_write/resources/judge.txt +31 -0
  14. evalscope/third_party/longbench_write/resources/longbench_write.jsonl +120 -0
  15. evalscope/third_party/longbench_write/resources/longbench_write_en.jsonl +60 -0
  16. evalscope/third_party/longbench_write/resources/longwrite_ruler.jsonl +48 -0
  17. evalscope/third_party/longbench_write/tools/__init__.py +1 -0
  18. evalscope/third_party/longbench_write/tools/data_etl.py +155 -0
  19. evalscope/third_party/longbench_write/utils.py +37 -0
  20. evalscope/version.py +2 -2
  21. {evalscope-0.5.3.dist-info → evalscope-0.5.4.dist-info}/METADATA +24 -32
  22. {evalscope-0.5.3.dist-info → evalscope-0.5.4.dist-info}/RECORD +25 -11
  23. {evalscope-0.5.3.dist-info → evalscope-0.5.4.dist-info}/WHEEL +0 -0
  24. {evalscope-0.5.3.dist-info → evalscope-0.5.4.dist-info}/entry_points.txt +0 -0
  25. {evalscope-0.5.3.dist-info → evalscope-0.5.4.dist-info}/top_level.txt +0 -0

evalscope/backend/opencompass/backend_manager.py
@@ -242,4 +242,6 @@ if __name__ == '__main__':
             'limit': 5
         }
     )
+    all_datasets = OpenCompassBackendManager.list_datasets()
+    print(f'all_datasets: {all_datasets}')
     oc_backend_manager.run()
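
For orientation, the two added lines above expose dataset discovery as a class-level call. A minimal, hedged usage sketch in Python follows; the import path is inferred from the file location in the file list and may differ from what the package actually re-exports.

# Hedged sketch: list the datasets the OpenCompass backend knows about before running a task.
# Import path assumed from evalscope/backend/opencompass/backend_manager.py.
from evalscope.backend.opencompass.backend_manager import OpenCompassBackendManager

all_datasets = OpenCompassBackendManager.list_datasets()  # class-level call, as used in the diff
print(f'all_datasets: {all_datasets}')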

evalscope/backend/vlm_eval_kit/backend_manager.py
@@ -1,5 +1,5 @@
 from typing import Optional, Union
-from evalscope.utils import is_module_installed, get_module_path, get_valid_list, yaml_to_dict, json_to_dict
+from evalscope.utils import is_module_installed, get_valid_list
 from evalscope.backend.base import BackendManager
 from evalscope.utils.logger import get_logger
 from functools import partial

evalscope/benchmarks/benchmark.py
@@ -46,7 +46,7 @@ class Benchmark(object):
 
             dataset.dataset_name = dataset_name.split('/')[-1]
             dataset.subset_name = subset
-            dataset.split = split
+            # dataset.split = split
             return dataset
         elif hub == 'HuggingFace':
             # TODO: implement this by xingjun.wxj@alibaba-inc.com

evalscope/evaluator/evaluator.py
@@ -244,8 +244,8 @@ class Evaluator(object):
             answer_d[AnswerKeys.ORIGIN_PROMPT] = input_prompt
 
             if debug:
-                logger.debug(f'**input_prompt: {json.dumps(input_prompt, ensure_ascii=False)} \n')
-                logger.debug(f'**predicted ans: {json.dumps(answer_d, ensure_ascii=False)} \n')
+                logger.info(f'**input_prompt: {json.dumps(input_prompt, ensure_ascii=False)} \n')
+                logger.info(f'**predicted ans: {json.dumps(answer_d, ensure_ascii=False)} \n')
 
             answers_list.append(answer_d)
 
@@ -349,7 +349,7 @@ class Evaluator(object):
             review_d = self._get_review(answer_d=answer_d, review_id=review_id, reviewer_spec=reviewer_spec)
 
             if debug:
-                logger.debug(review_d)
+                logger.info(review_d)
 
             reviews_list.append(review_d)
 

evalscope/models/api/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from evalscope.models.api.openai_api import OpenaiApi

evalscope/models/api/openai_api.py
@@ -0,0 +1,228 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import json
+import threading
+import time
+from asyncio import Queue
+
+import requests
+from typing import Union, List, Optional, Dict
+from concurrent.futures import ThreadPoolExecutor
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class OpenaiApi:
+
+    def __init__(self,
+                 model: str,
+                 openai_api_key,
+                 openai_api_base,
+                 logprobs: Optional[bool] = False,
+                 top_logprobs: Optional[int] = None,
+                 max_new_tokens: int = 4096,
+                 temperature: Optional[float] = 0.0,
+                 repetition_penalty: Optional[float] = 1.0,
+                 is_chat: bool = True,
+                 verbose: bool = True,
+                 retry: int = 3,
+                 query_per_second: int = 10,  # TODO
+                 **kwargs):
+
+        self.temperature = temperature
+        self.repetition_penalty = repetition_penalty
+        self.max_tokens = max_new_tokens
+        self.logprobs = logprobs
+        self.top_logprobs = top_logprobs
+
+        self.openai_api_key = openai_api_key
+        self.url = openai_api_base
+        self.model = model
+        self.is_chat = is_chat
+        self.retry = retry
+        self.verbose = verbose
+
+        self.token_bucket = TokenBucket(query_per_second, verbose)
+
+    def generate_simple(self, inputs: Union[List[str]]):
+
+        def process_one(in_data: str):
+
+            if self.is_chat:
+                data = dict(
+                    model=self.model,
+                    messages=[{'role': 'user', 'content': in_data}],
+                    max_tokens=self.max_tokens,
+                    n=1,
+                    logprobs=self.logprobs,
+                    top_logprobs=self.top_logprobs,
+                    stop=None,
+                    temperature=self.temperature,
+                    repetition_penalty=self.repetition_penalty,
+                )
+            else:
+                data = dict(
+                    model=self.model,
+                    prompt=in_data,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    repetition_penalty=self.repetition_penalty,
+                )
+
+            # todo
+            openai_api_key = self.openai_api_key or ''
+            header = {'Authorization': f'Bearer ', 'content-type': 'application/json', }
+            data = json.dumps(data, ensure_ascii=False)
+
+            if self.verbose:
+                print(f'>>data in generate_simple: {data}')
+
+            resp = requests.post(self.url, headers=header, data=data)
+            resp = resp.json()
+            if self.verbose:
+                print(f'>>resp in generate_simple: {resp}')
+
+            if self.logprobs:
+                return resp['choices']
+            else:
+                if self.is_chat:
+                    return resp['choices'][0]['message']['content'].strip()
+                else:
+                    return resp['choices'][0]['text'].strip()
+
+        with ThreadPoolExecutor() as executor:
+            results = list(executor.map(process_one, inputs))
+
+        return results
+
+    def generate(self,
+                 inputs: Union[List[str], List[List]],
+                 **kwargs) -> List[str]:
+        """
+        Generate responses from OpenAI API.
+
+        Args:
+            inputs: The input messages for the model. It can be a string or a list of messages.
+                e.g. ['who are you ?', 'what is your name ?']
+                e.g. [[{'role': 'user', 'content': 'who are you ?'}], ...]
+            kwargs: The optional arguments for the model.
+        """
+        results = []
+        # with ThreadPoolExecutor() as executor:
+        #     results = list(executor.map(self._generate, inputs))
+
+        for input in inputs:
+            results.append(self._generate(input))
+
+        return results
+
+    def _generate(self, messages: Union[str, List[Dict]]) -> str:
+
+        if isinstance(messages, str):
+            messages = [{'role': 'user', 'content': messages}]
+
+        max_num_retries = 0
+        while max_num_retries < self.retry:
+            # self.wait()
+
+            header = {
+                'Authorization': f'Bearer {self.openai_api_key}',
+                'content-type': 'application/json',
+            }
+
+            try:
+                if self.is_chat:
+                    data = dict(
+                        model=self.model,
+                        messages=messages,
+                        max_tokens=self.max_tokens,
+                        n=1,
+                        logprobs=self.logprobs,
+                        top_logprobs=self.top_logprobs,
+                        stop=None,
+                        temperature=self.temperature,
+                        repetition_penalty=self.repetition_penalty,
+                    )
+                else:
+                    # TODO: This is a temporary solution for non-chat models.
+                    input_prompts = []
+                    for msg in messages:
+                        input_prompts.append(msg['content'])
+
+                    data = dict(
+                        model=self.model,
+                        prompt='\n'.join(input_prompts),
+                        max_tokens=self.max_tokens,
+                        temperature=self.temperature,
+                        repetition_penalty=self.repetition_penalty,
+                    )
+
+                def remove_none_val(input_d: dict):
+                    return {k: v for k, v in input_d.items() if v is not None}
+                data = remove_none_val(data)
+
+                if self.verbose:
+                    logger.info(f'>> Post data: {json.dumps(data, ensure_ascii=False)}')
+                raw_response = requests.post(self.url,
+                                             headers=header,
+                                             data=json.dumps(data, ensure_ascii=False))
+
+                response = raw_response.json()
+                if self.verbose:
+                    logger.info(f'>> response: {response}')
+
+                if self.logprobs:
+                    return response['choices']
+                else:
+                    if self.is_chat:
+                        return response['choices'][0]['message']['content'].strip()
+                    else:
+                        return response['choices'][0]['text'].strip()
+
+            except Exception as e:
+                logger.error(f'Error occurs: {str(e)}')
+                max_num_retries += 1
+                continue
+
+    def wait(self):
+        return self.token_bucket.get_token()
+
+
+class TokenBucket:
+    """A token bucket for rate limiting.
+
+    Args:
+        query_per_second (float): The rate of the token bucket.
+    """
+
+    def __init__(self, rate, verbose=False):
+        self._rate = rate
+        self._tokens = threading.Semaphore(0)
+        self.started = False
+        self._request_queue = Queue()
+        self.logger = get_logger()
+        self.verbose = verbose
+
+    def _add_tokens(self):
+        """Add tokens to the bucket."""
+        while True:
+            if self._tokens._value < self._rate:
+                self._tokens.release()
+            time.sleep(1 / self._rate)
+
+    def get_token(self):
+        """Get a token from the bucket."""
+        if not self.started:
+            self.started = True
+            threading.Thread(target=self._add_tokens, daemon=True).start()
+        self._tokens.acquire()
+        if self.verbose:
+            cur_time = time.time()
+            while not self._request_queue.empty():
+                if cur_time - self._request_queue.queue[0] > 60:
+                    self._request_queue.get()
+                else:
+                    break
+            self._request_queue.put(cur_time)
+            self.logger.info(f'Current RPM {self._request_queue.qsize()}.')
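
The new OpenaiApi client above is self-contained, so a short usage sketch may help. It relies only on the signatures visible in the diff; the model id, API key, and endpoint URL are placeholders. Note that generate_simple as released builds its Authorization header without interpolating the key, so the sketch goes through generate instead.

# Hedged usage sketch for OpenaiApi (model id, key, and endpoint URL are placeholders).
from evalscope.models.api import OpenaiApi  # re-export added in evalscope/models/api/__init__.py

client = OpenaiApi(
    model='gpt-4o-mini',                    # placeholder model id
    openai_api_key='sk-...',                # placeholder key
    openai_api_base='https://api.openai.com/v1/chat/completions',
    is_chat=True,
    max_new_tokens=512,
    temperature=0.0,
)

# generate() accepts plain strings or chat-style message lists and returns a list of strings.
print(client.generate(['Who are you?']))
print(client.generate([[{'role': 'user', 'content': 'What is your name?'}]]))

Rate limiting via TokenBucket only engages when wait() is called; in the code above the self.wait() call inside _generate is commented out, so query_per_second has no effect as released.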

evalscope/perf/http_client.py
@@ -51,15 +51,15 @@ UNLIMITED_RATE = -1
 
 
 async def on_request_start(session, context, params):
-    logger.debug(f'Starting request: <{params}>')
+    logger.info(f'Starting request: <{params}>')
 
 
 async def on_request_chunk_sent(session, context, params):
-    logger.debug(f'Request body: {params}')
+    logger.info(f'Request body: {params}')
 
 
 async def on_response_chunk_received(session, context, params):
-    logger.debug(f'Response info: <{params}>')
+    logger.info(f'Response info: <{params}>')
 
 
 class AioHttpClient:
@@ -116,7 +116,7 @@ class AioHttpClient:
                 line = line.decode("utf8")
                 line = line.rstrip("\n").rstrip("\r")
                 if self.debug:
-                    logger.debug(line)
+                    logger.info(line)
                 sse_msg = ServerSentEvent.decode(line)
                 if not sse_msg:
                     continue
@@ -567,7 +567,7 @@ async def send_requests_worker(task_id, request_queue: asyncio.Queue, benchmark_
         else:
             if response_data:
                 collected_messages.append(response_data)  # save the message
-                logger.debug(response_data)
+                logger.info(response_data)
             benchmark_data["chunk_times"].append(time.perf_counter())
 
         benchmark_data["response_messages"] = collected_messages

evalscope/third_party/longbench_write/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from evalscope.third_party.longbench_write.longbench_write import run_task

evalscope/third_party/longbench_write/eval.py
@@ -0,0 +1,284 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright (c) ZhipuAI, Inc. and its affiliates.
+import multiprocessing
+import os
+import json
+import random
+import re
+from concurrent.futures import ThreadPoolExecutor
+
+import matplotlib.pyplot as plt
+import numpy as np
+import requests
+from tqdm import tqdm
+
+from evalscope.utils import jsonl_to_list
+from evalscope.utils import get_logger
+
+logger = get_logger()
+
+"""
+This script is used to evaluate results of predictions for the LongWriter model.
+Refer to https://github.com/THUDM/LongWriter for more details.
+
+EvalLength:
+    Evaluate the length of the generated responses.
+    Metrics:
+        score_l: The average score of the length evaluation.
+
+EvalQuality:
+    Evaluate the quality of the generated responses by using Judge Model.
+    Metrics:
+        score_q: The average score of the quality evaluation.
+"""
+
+
+class EvalLength:
+
+    EVAL_L = 'eval_length'
+
+    def __init__(self, model: str, pred_path: str, output_dir: str):
+        self.model = model
+        self.pred_path = pred_path
+        self.output_dir = output_dir
+
+        self.model_id_path = self.model.strip(os.sep).replace(os.sep, '__')
+
+    @staticmethod
+    def score(x, y):
+        if y > x:
+            return 100 * max(0, 1. - (y / x - 1) / 3)
+        else:
+            return 100 * max(0, 1. - (x / y - 1) / 2)
+
+    def eval(self, dump_res: bool = True):
+        # example = {"prompt": "Write an outline for a short 100-word blog post about xxx", "type": "Community Forum", "length": 100, "response_length": 103, "response": "I. Introduction A. xxx"}
+        predictions = [json.loads(line) for line in open(self.pred_path, encoding='utf-8')]
+        x, y, scores = [], [], []
+
+        for pred in tqdm(predictions, total=len(predictions), desc=f'Process of eval_l: '):
+            x.append(pred["length"])
+            y.append(pred["response_length"])
+            scores.append(self.score(pred["length"], pred["response_length"]))
+
+        avg_score_l = np.mean(scores)
+        logger.info(f'Average score of length evaluation: {avg_score_l:.2f}')
+
+        # Dump to output file
+        if dump_res:
+            output_res_path = f'{self.output_dir}/{self.model_id_path}/{self.EVAL_L}.jsonl'
+            with open(output_res_path, 'w') as f:
+                f.write(json.dumps({'score_l': avg_score_l, 'scores': scores}, ensure_ascii=False) + '\n')
+            logger.info(f'Successfully dumped evaluation results to {output_res_path}')
+
+        return x, y, scores
+
+    def plot(self, x: list, y: list):
+        plt = self.plot_img(x, y)
+        output_pic_path = f'{self.output_dir}/{self.model_id_path}/eval_length_scatter.png'
+        plt.savefig(output_pic_path)
+        logger.info(f'Successfully saved scatter plot to {output_pic_path}')
+
+    @staticmethod
+    def plot_img(x: list, y: list):
+        # set plt size 6x6
+        plt.figure(figsize=(6, 6))
+        lmt = 25000
+        # plot x, y
+        plt.scatter(x, y, s=100, c='r', alpha=0.3)
+        # plot x=y
+        plt.plot([0, lmt], [0, lmt], 'k--')
+        plt.xscale('log')
+        plt.yscale('log')
+        plt.xlim(50, lmt)
+        plt.ylim(50, lmt)
+        plt.xlabel('Required Length', fontsize=20, fontweight='bold')
+        plt.ylabel('Output Length', fontsize=20, fontweight='bold')
+        plt.xticks(fontsize=24)
+        plt.yticks(fontsize=24)
+        plt.tight_layout()
+
+        return plt
+
+
+class EvalQuality:
+
+    EVAL_Q = 'eval_quality'
+    OPENAI_BASE_URL = 'https://api.openai.com/v1/chat/completions'
+    DIMS = ["Relevance", "Accuracy", "Coherence", "Clarity", "Breadth and Depth", "Reading Experience"]
+
+    def __init__(self,
+                 model: str,
+                 pred_path: str,
+                 output_dir: str,
+                 prompt_template_path: str,
+                 openai_api_key: str = None,
+                 openai_api_base: str = OPENAI_BASE_URL,
+                 openai_gpt_model: str = 'gpt-4o-2024-05-13',
+                 generation_kwargs: dict = None,
+                 proc_num: int = 8):
+
+        self.model = model
+        self.openai_api_base = openai_api_base
+        self.pred_path = pred_path
+        self.output_dir = output_dir
+        self.proc_num = proc_num
+        self.eval_scores = []
+
+        assert os.path.exists(self.pred_path), f'Prediction file not found: {self.pred_path}'
+
+        # Default: temperature=0.5, max_new_tokens=1024, stop=None
+        if generation_kwargs is None:
+            self.generation_kwargs = dict({
+                'max_new_tokens': 1024,
+                'temperature': 0.5,
+                'stop': None,
+            })
+        else:
+            self.generation_kwargs = generation_kwargs
+
+        self.prompt_template: str = open(prompt_template_path, 'r', encoding='utf-8').read()
+
+        self.model_id_path = self.model.strip(os.sep).replace(os.sep, '__')
+        self.output_res_path = f'{self.output_dir}/{self.model_id_path}/{self.EVAL_Q}.jsonl'
+
+        self.openai_api_key: str = openai_api_key
+        self.openai_gpt_model = openai_gpt_model
+        assert self.openai_api_key, 'Please set `OPENAI_API_KEY` in environment variables.'
+
+    def get_response_gpt4(self, prompt, temperature=0.5, max_new_tokens=1024, stop=None):
+        tries = 0
+        while tries < 1:
+            tries += 1
+            try:
+                headers = {
+                    'Authorization': "Bearer {}".format(self.openai_api_key),
+                }
+                messages = [
+                    {'role': 'user', 'content': prompt},
+                ]
+                resp = requests.post(self.openai_api_base, json={
+                    "model": self.openai_gpt_model,
+                    "messages": messages,
+                    "temperature": temperature,
+                    "max_tokens": max_new_tokens,
+                    "stop": stop,
+                }, headers=headers, timeout=600)
+                if resp.status_code != 200:
+                    raise Exception(resp.text)
+                resp = resp.json()
+                logger.info(f'>>gpt resp: {resp}')
+                break
+            except KeyboardInterrupt as e:
+                raise e
+            except Exception as e:
+                if "maximum context length" in str(e):
+                    raise e
+                elif "triggering" in str(e):
+                    return 'Trigger OpenAI\'s content management policy'
+                logger.error("Error Occurs: \"%s\" Retry ..." % (str(e)))
+        else:
+            logger.error("Max tries. Failed.")
+            return "Max tries. Failed."
+        try:
+            return resp["choices"][0]["message"]["content"]
+        except:
+            return ''
+
+    @staticmethod
+    def extract_info(pattern, text):
+        match = re.search(pattern, text, re.DOTALL)
+        if match:
+            return match.group(1)
+        else:
+            return None
+
+    def process_data(self, item):
+        # for item in tqdm(items, total=len(items), desc=f'Process of eval_q: '):
+        prompt = self.prompt_template.replace('$INST$', item['prompt']).replace('$RESPONSE$', item["response"])
+        scores = None
+        output = self.get_response_gpt4(prompt, **self.generation_kwargs)
+        try:
+            if '```json' in output:
+                output = self.extract_info(r'```json\n(.*?)\n```', output)
+            output = output.replace('\n', '')
+            scores = json.loads(output)
+            for dim in self.DIMS:
+                if dim not in scores:
+                    logger.warning(f'Cannot find score for dimension: {dim} in scores {scores}.')
+                    scores = None
+        except Exception as e:
+            logger.error(f'Error occurs during process data: {str(e)}')
+
+        if scores is None:
+            logger.error(f'Failed to extract scores for item: {item}')
+        else:
+            logger.info(f'>>scores: {scores}')
+            item['scores'] = scores
+
+        return item
+
+    def eval(self):
+
+        data_all = jsonl_to_list(self.pred_path)
+        total = len(data_all)
+        assert total > 0, f'No data found in prediction file: {self.pred_path}'
+
+        random.shuffle(data_all)
+
+        with ThreadPoolExecutor() as executor:
+            self.eval_scores = list(executor.map(self.process_data, data_all))
+
+        # self.process_data(items=data)
+        logger.info(f'>>self.eval_scores: {self.eval_scores}')
+
+        total_score = dict()
+        for dim in self.DIMS:
+            # scores = [float(score[dim]) if dim in score else 3 for score in self.eval_scores]
+            scores = [float(item['scores'][dim]) if 'scores' in item and dim in item['scores'] else 3 for item in self.eval_scores]
+            total_score[dim] = ((sum(scores) / len(scores)) - 1) * 25
+        total_score['total'] = sum(total_score.values()) / len(total_score)
+        logger.info(f'Total score of quality evaluation: {total_score["total"]:.2f}')
+
+        output_res_path: str = f'{self.output_dir}/{self.model_id_path}/{self.EVAL_Q}.jsonl'
+        with open(output_res_path, 'w', encoding='utf-8') as fout:
+            fout.write(json.dumps(total_score, ensure_ascii=False) + '\n')
+
+
+def run_eval(model: str,
+             pred_path: str,
+             output_dir: str,
+             prompt_template_path: str,
+             openai_api_key: str,
+             openai_api_base: str,
+             openai_gpt_model: str,
+             generation_kwargs: dict,
+             proc_num: int,
+             stage: list,
+             ):
+    logger.info(f'Got eval stages: {stage}')
+
+    if 'eval_l' in stage:
+        logger.info(f'Processing evaluation of length for model: {model}')
+        eval_length = EvalLength(model=model,
+                                 pred_path=pred_path,
+                                 output_dir=output_dir)
+        x, y, _ = eval_length.eval()
+        eval_length.plot(x, y)
+    else:
+        logger.warning(f'*** Skip `eval_l` stage ***')
+
+    if 'eval_q' in stage:
+        logger.info(f'Processing evaluation of quality for model: {model}')
+        eval_quality = EvalQuality(model=model,
+                                   pred_path=pred_path,
+                                   output_dir=output_dir,
+                                   prompt_template_path=prompt_template_path,
+                                   openai_api_key=openai_api_key,
+                                   openai_api_base=openai_api_base,
+                                   openai_gpt_model=openai_gpt_model,
+                                   generation_kwargs=generation_kwargs,
+                                   proc_num=proc_num)
+        eval_quality.eval()
+    else:
+        logger.warning('*** Skip `eval_q` stage ***')
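
Since EvalLength.score is the core length metric, a small worked check of the formula follows (illustrative numbers only), together with a hedged run_eval call built from the signature above; the model id, paths, and key are placeholders, and using resources/judge.txt as the prompt template is an assumption.

# Worked check of EvalLength.score: over-long responses lose 1/3 of the score per
# relative unit of excess, over-short responses lose 1/2, clamped to [0, 100].
def score(x, y):
    if y > x:
        return 100 * max(0, 1. - (y / x - 1) / 3)
    return 100 * max(0, 1. - (x / y - 1) / 2)

print(round(score(1000, 1200), 2))  # 93.33: 20% too long   -> 100 * (1 - 0.2 / 3)
print(round(score(1000, 500), 2))   # 50.0 : half the length -> 100 * (1 - 1.0 / 2)

# Hedged run_eval call (placeholders throughout; generation_kwargs=None falls back to the
# defaults defined in EvalQuality).
from evalscope.third_party.longbench_write.eval import run_eval

run_eval(model='ZhipuAI/LongWriter-glm4-9b',              # placeholder model id
         pred_path='outputs/longbench_write/pred.jsonl',  # placeholder prediction file
         output_dir='outputs/longbench_write',
         prompt_template_path='evalscope/third_party/longbench_write/resources/judge.txt',  # assumed judge prompt
         openai_api_key='sk-...',
         openai_api_base='https://api.openai.com/v1/chat/completions',
         openai_gpt_model='gpt-4o-2024-05-13',
         generation_kwargs=None,
         proc_num=8,
         stage=['eval_l', 'eval_q'])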