evalscope 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +3 -0
- evalscope/backend/__init__.py +3 -0
- evalscope/backend/base.py +27 -0
- evalscope/backend/opencompass/__init__.py +3 -0
- evalscope/backend/opencompass/api_meta_template.py +64 -0
- evalscope/backend/opencompass/backend_manager.py +247 -0
- evalscope/backend/opencompass/tasks/__init__.py +1 -0
- evalscope/backend/opencompass/tasks/eval_api.py +30 -0
- evalscope/backend/opencompass/tasks/eval_datasets.py +71 -0
- evalscope/backend/vlm_eval_kit/__init__.py +1 -0
- evalscope/backend/vlm_eval_kit/backend_manager.py +153 -0
- evalscope/benchmarks/__init__.py +4 -0
- evalscope/benchmarks/arc/__init__.py +5 -0
- evalscope/benchmarks/arc/ai2_arc.py +148 -0
- evalscope/benchmarks/arc/arc_adapter.py +231 -0
- evalscope/benchmarks/bbh/__init__.py +6 -0
- evalscope/benchmarks/bbh/bbh_adapter.py +308 -0
- evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +23 -0
- evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +33 -0
- evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +72 -0
- evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +78 -0
- evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +42 -0
- evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/navigate.txt +43 -0
- evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +41 -0
- evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +63 -0
- evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/snarks.txt +30 -0
- evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +10 -0
- evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +77 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +17 -0
- evalscope/benchmarks/benchmark.py +65 -0
- evalscope/benchmarks/ceval/__init__.py +5 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +340 -0
- evalscope/benchmarks/ceval/ceval_exam.py +159 -0
- evalscope/benchmarks/cmmlu/__init__.py +5 -0
- evalscope/benchmarks/cmmlu/cmmlu.py +166 -0
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +369 -0
- evalscope/benchmarks/competition_math/__init__.py +5 -0
- evalscope/benchmarks/competition_math/competition_math.py +88 -0
- evalscope/benchmarks/competition_math/competition_math_adapter.py +470 -0
- evalscope/benchmarks/data_adapter.py +263 -0
- evalscope/benchmarks/general_qa/__init__.py +5 -0
- evalscope/benchmarks/general_qa/general_qa_adapter.py +186 -0
- evalscope/benchmarks/gsm8k/__init__.py +5 -0
- evalscope/benchmarks/gsm8k/gsm8k.py +127 -0
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +236 -0
- evalscope/benchmarks/hellaswag/__init__.py +5 -0
- evalscope/benchmarks/hellaswag/hellaswag.py +116 -0
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +222 -0
- evalscope/benchmarks/humaneval/__init__.py +5 -0
- evalscope/benchmarks/humaneval/humaneval.py +82 -0
- evalscope/benchmarks/humaneval/humaneval_adapter.py +21 -0
- evalscope/benchmarks/mmlu/__init__.py +5 -0
- evalscope/benchmarks/mmlu/mmlu.py +174 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +375 -0
- evalscope/benchmarks/race/__init__.py +5 -0
- evalscope/benchmarks/race/race.py +118 -0
- evalscope/benchmarks/race/race_adapter.py +229 -0
- evalscope/benchmarks/trivia_qa/__init__.py +5 -0
- evalscope/benchmarks/trivia_qa/trivia_qa.py +104 -0
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +207 -0
- evalscope/benchmarks/truthful_qa/__init__.py +5 -0
- evalscope/benchmarks/truthful_qa/truthful_qa.py +167 -0
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +351 -0
- evalscope/cache.py +98 -0
- evalscope/cli/__init__.py +1 -0
- evalscope/cli/base.py +20 -0
- evalscope/cli/cli.py +26 -0
- evalscope/cli/start_perf.py +37 -0
- evalscope/cli/start_server.py +138 -0
- evalscope/config.py +165 -0
- evalscope/constants.py +150 -0
- evalscope/evaluator/__init__.py +3 -0
- evalscope/evaluator/evaluator.py +689 -0
- evalscope/evaluator/rating_eval.py +178 -0
- evalscope/evaluator/reviewer/__init__.py +1 -0
- evalscope/evaluator/reviewer/auto_reviewer.py +411 -0
- evalscope/metrics/__init__.py +1 -0
- evalscope/metrics/bundled_rouge_score/__init__.py +14 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +342 -0
- evalscope/metrics/code_metric.py +104 -0
- evalscope/metrics/math_accuracy.py +60 -0
- evalscope/metrics/metrics.py +405 -0
- evalscope/metrics/rouge_metric.py +129 -0
- evalscope/models/__init__.py +4 -0
- evalscope/models/custom/__init__.py +4 -0
- evalscope/models/custom/custom_model.py +53 -0
- evalscope/models/dummy_chat_model.py +50 -0
- evalscope/models/model.py +88 -0
- evalscope/models/model_adapter.py +586 -0
- evalscope/models/openai_model.py +103 -0
- evalscope/models/template.py +1446 -0
- evalscope/perf/__init__.py +0 -0
- evalscope/perf/_logging.py +32 -0
- evalscope/perf/api_plugin_base.py +60 -0
- evalscope/perf/custom_api.py +87 -0
- evalscope/perf/dashscope_api.py +84 -0
- evalscope/perf/dataset_plugin_base.py +64 -0
- evalscope/perf/datasets/__init__.py +0 -0
- evalscope/perf/datasets/line_by_line.py +18 -0
- evalscope/perf/datasets/longalpaca_12k.py +20 -0
- evalscope/perf/datasets/openqa.py +22 -0
- evalscope/perf/how_to_analysis_result.py +24 -0
- evalscope/perf/http_client.py +756 -0
- evalscope/perf/openai_api.py +130 -0
- evalscope/perf/plugin_registry.py +35 -0
- evalscope/perf/query_parameters.py +42 -0
- evalscope/perf/server_sent_event.py +43 -0
- evalscope/preprocess/__init__.py +1 -0
- evalscope/preprocess/tokenizers/__init__.py +0 -0
- evalscope/preprocess/tokenizers/gpt2_tokenizer.py +221 -0
- evalscope/registry/__init__.py +1 -0
- evalscope/registry/tasks/arc.yaml +29 -0
- evalscope/registry/tasks/bbh.yaml +27 -0
- evalscope/registry/tasks/bbh_mini.yaml +27 -0
- evalscope/registry/tasks/ceval.yaml +27 -0
- evalscope/registry/tasks/ceval_mini.yaml +27 -0
- evalscope/registry/tasks/cmmlu.yaml +27 -0
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +28 -0
- evalscope/registry/tasks/general_qa.yaml +27 -0
- evalscope/registry/tasks/gsm8k.yaml +29 -0
- evalscope/registry/tasks/mmlu.yaml +29 -0
- evalscope/registry/tasks/mmlu_mini.yaml +27 -0
- evalscope/run.py +404 -0
- evalscope/run_arena.py +204 -0
- evalscope/run_ms.py +140 -0
- evalscope/summarizer.py +144 -0
- evalscope/third_party/__init__.py +1 -0
- evalscope/third_party/toolbench_static/__init__.py +3 -0
- evalscope/third_party/toolbench_static/eval.py +219 -0
- evalscope/third_party/toolbench_static/infer.py +278 -0
- evalscope/third_party/toolbench_static/llm/__init__.py +1 -0
- evalscope/third_party/toolbench_static/llm/swift_infer.py +45 -0
- evalscope/third_party/toolbench_static/toolbench_static.py +50 -0
- evalscope/tools/__init__.py +1 -0
- evalscope/tools/combine_reports.py +140 -0
- evalscope/tools/gen_mmlu_subject_mapping.py +90 -0
- evalscope/tools/rewrite_eval_results.py +95 -0
- evalscope/utils/__init__.py +4 -0
- evalscope/utils/arena_utils.py +247 -0
- evalscope/utils/completion_parsers.py +87 -0
- evalscope/utils/logger.py +64 -0
- evalscope/utils/task_cfg_parser.py +10 -0
- evalscope/utils/task_utils.py +19 -0
- evalscope/utils/utils.py +625 -0
- evalscope/version.py +4 -0
- evalscope-0.5.0.dist-info/METADATA +566 -0
- evalscope-0.5.0.dist-info/RECORD +165 -0
- evalscope-0.5.0.dist-info/WHEEL +5 -0
- evalscope-0.5.0.dist-info/entry_points.txt +3 -0
- evalscope-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,756 @@
|
|
|
1
|
+
"""LLM performance benchmark client.
|
|
2
|
+
"""
|
|
3
|
+
import argparse
|
|
4
|
+
import asyncio
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
import functools
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import platform
|
|
10
|
+
import signal
|
|
11
|
+
import sqlite3
|
|
12
|
+
import os
|
|
13
|
+
import time
|
|
14
|
+
import base64
|
|
15
|
+
import pickle
|
|
16
|
+
import importlib.util
|
|
17
|
+
import sys
|
|
18
|
+
import platform
|
|
19
|
+
from typing import List, Dict, Optional
|
|
20
|
+
from datetime import datetime, timezone
|
|
21
|
+
import aiohttp
|
|
22
|
+
from http import HTTPStatus
|
|
23
|
+
import aiohttp
|
|
24
|
+
import numpy as np
|
|
25
|
+
from evalscope.perf.plugin_registry import api_registry, dataset_registry
|
|
26
|
+
from evalscope.perf.query_parameters import QueryParameters
|
|
27
|
+
from evalscope.perf.server_sent_event import ServerSentEvent
|
|
28
|
+
# for plugin registry
|
|
29
|
+
from evalscope.perf.dashscope_api import DashScopeApiPlugin
|
|
30
|
+
from evalscope.perf.openai_api import OpenaiPlugin
|
|
31
|
+
from evalscope.perf.datasets.line_by_line import LineByLineDatasetPlugin
|
|
32
|
+
from evalscope.perf.datasets.longalpaca_12k import LongAlpacaDatasetPlugin
|
|
33
|
+
from evalscope.perf.datasets.openqa import OpenqaDatasetPlugin
|
|
34
|
+
from evalscope.perf.custom_api import CustomPlugin
|
|
35
|
+
from evalscope.perf._logging import logger
|
|
36
|
+
|
|
37
|
+
# The plugin classes are imported above for plugin registration; listing them
# here marks the imports as intentional. ``__all__`` must contain *names*
# (strings): class objects make ``from evalscope.perf.http_client import *``
# fail with a TypeError.
__all__ = [
    'DashScopeApiPlugin',
    'OpenaiPlugin',
    'CustomPlugin',
    'LineByLineDatasetPlugin',
    'LongAlpacaDatasetPlugin',
    'OpenqaDatasetPlugin',
]

# Completion flags polled by the worker coroutines in this module: the request
# sender stops once its queue is empty and ``_query_send_completed`` is set, and
# the statistics worker likewise checks ``_data_process_completed``.
# NOTE(review): presumably flipped elsewhere in this module — confirm.
_query_send_completed = False
_data_process_completed = False
# Name of the SQLite table that stores one row per benchmarked request.
_table_name = "result"

# Value of ``args.rate`` meaning "enqueue requests without any pacing delay".
UNLIMITED_RATE = -1
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
async def on_request_start(session, context, params):
    """aiohttp trace hook: debug-log the parameters of a request being started."""
    message = f'Starting request: <{params}>'
    logger.debug(message)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
async def on_request_chunk_sent(session, context, params):
    """aiohttp trace hook: debug-log what aiohttp reports for a sent request chunk."""
    message = f'Request body: {params}'
    logger.debug(message)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
async def on_response_chunk_received(session, context, params):
    """aiohttp trace hook: debug-log what aiohttp reports for a received response chunk."""
    message = f'Response info: <{params}>'
    logger.debug(message)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class AioHttpClient:
    """aiohttp wrapper that POSTs JSON bodies to a single URL.

    ``post`` is an async generator of ``(is_error, status_code, data)`` tuples.
    Use the client as an async context manager so the underlying
    ``aiohttp.ClientSession`` is closed on exit.
    """

    def __init__(self,
                 url: str,
                 conn_timeout: int = 120,
                 read_timeout: int = 120,
                 headers: Optional[Dict] = None,
                 debug: bool = False):
        """Create the HTTP session.

        Args:
            url: Endpoint that every ``post`` call is sent to.
            conn_timeout: Connect timeout in seconds.
            read_timeout: Socket read timeout in seconds. The total timeout
                of a request is ``conn_timeout + read_timeout``.
            headers: Extra headers merged over the default user agent.
            debug: When true, set the module logger to DEBUG, attach aiohttp
                trace hooks that log request/response activity, and log every
                raw SSE line.
        """
        client_timeout = aiohttp.ClientTimeout(total=read_timeout + conn_timeout,
                                               connect=conn_timeout,
                                               sock_read=read_timeout)
        self.debug = debug
        trace_configs = []
        if debug:
            logger.setLevel(level=logging.DEBUG)
            trace_config = aiohttp.TraceConfig()
            trace_config.on_request_start.append(on_request_start)
            trace_config.on_request_chunk_sent.append(on_request_chunk_sent)
            # not support server sent event(stream=true)
            trace_config.on_response_chunk_received.append(on_response_chunk_received)
            trace_configs.append(trace_config)
        # One client only has one connection (connector limit=1).
        self.client = aiohttp.ClientSession(trace_configs=trace_configs,
                                            connector=aiohttp.TCPConnector(limit=1),
                                            timeout=client_timeout)
        ua = "modelscope_bench"
        self.headers = {"user-agent": ua}
        if headers:
            self.headers.update(headers)
        self.url = url

    async def __aenter__(self):
        # Return the client so ``async with AioHttpClient(...) as client`` works.
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.client.close()

    async def aio_call(self):
        # NOTE(review): this relies on ``self._handle_request`` and ``self.stream``,
        # neither of which is defined in this class. Unless a subclass provides
        # them, calling it raises AttributeError. The request path used in this
        # module is ``post``.
        response = self._handle_request()
        if self.stream:
            return (item async for item in response)
        else:
            result = await response.__anext__()
            try:
                await response.__anext__()
            except StopAsyncIteration:
                pass
            return result

    async def _handle_stream(self, response):
        """Yield ``(is_error, status_code, data)`` for every SSE data payload.

        Iteration stops at an OpenAI-style ``[DONE]`` payload. Once an SSE event
        named ``error`` (DashScope style) has been seen, every later item is
        flagged ``is_error=True``.
        """
        is_error = False
        status_code = response.status
        async for raw_line in response.content:
            if not raw_line:
                continue
            line = raw_line.decode("utf8").rstrip("\n").rstrip("\r")
            if self.debug:
                logger.debug(line)
            sse_msg = ServerSentEvent.decode(line)
            if not sse_msg:
                continue
            if sse_msg.event == "error":  # dashscope error
                is_error = True
            if sse_msg.data:
                if sse_msg.data.startswith("[DONE]"):  # openai api completed
                    break
                yield (is_error, status_code, sse_msg.data)

    async def _handle_response(self, response: aiohttp.ClientResponse):
        """Translate one HTTP response into ``(is_error, status_code, data)`` tuples.

        - 200 + ``text/event-stream``: one tuple per SSE payload.
        - 200 + ``application/json``: a JSON object whose ``object`` field is
          ``'error'`` is reported as a failure with its ``code``/``message``
          (``None`` when absent); anything else yields the re-serialised JSON.
        - 200 with another content type: the raw body bytes.
        - Non-200: always ``is_error=True``. JSON bodies are re-serialised,
          SSE payloads are JSON-decoded when possible (raw text otherwise),
          and other bodies are decoded as UTF-8 with undecodable bytes replaced.
        """
        content_type = response.content_type
        if response.status == HTTPStatus.OK:
            if "text/event-stream" in content_type:
                async for is_error, status_code, data in self._handle_stream(response):
                    yield (is_error, status_code, data)
            elif "application/json" in content_type:
                content = await response.json()
                if isinstance(content, dict) and content.get('object') == 'error':
                    yield (True, content.get('code'), content.get('message'))
                else:
                    yield (False, HTTPStatus.OK, json.dumps(content))
            else:
                content = await response.read()
                yield (False, HTTPStatus.OK, content)
        elif "application/json" in content_type:
            error = await response.json()
            yield (True, response.status, json.dumps(error))
        elif "text/event-stream" in content_type:
            async for _, _, data in self._handle_stream(response):
                try:
                    error = json.loads(data)
                except json.JSONDecodeError:
                    # Keep the raw payload rather than aborting the request handling.
                    error = data
                yield (True, response.status, error)
        else:
            msg = await response.read()
            yield (True, response.status, msg.decode('utf-8', errors='replace'))

    async def post(self, body):
        """POST ``body`` as JSON to ``self.url`` and yield the parsed results.

        Args:
            body: JSON-serialisable request payload.

        Yields:
            ``(is_error, status_code, data)`` tuples from ``_handle_response``.

        Raises:
            Any exception raised by aiohttp or by response handling is logged
            and re-raised unchanged.
        """
        headers = {"Content-Type": "application/json", **self.headers}
        try:
            response = await self.client.request("POST",
                                                 url=self.url,
                                                 json=body,
                                                 headers=headers)
            async with response:
                async for rsp in self._handle_response(response):
                    yield rsp
        except Exception as e:
            logger.error(e)
            raise
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def dynamic_import_module(dynamic_module_file_path: str):
    """Import a Python source file as a module and return it.

    The module is registered in ``sys.modules`` under the fixed name
    ``'module_request_response_parser'``. It replaces any module previously
    loaded through this function.

    Args:
        dynamic_module_file_path (str): Path of the Python file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for the path
            (for example, an unsupported file suffix).
        Exception: Anything raised while executing the module body propagates
            unchanged. ``sys.modules`` is first restored to its previous state.
    """
    module_name = 'module_request_response_parser'

    dynamic_module_spec = importlib.util.spec_from_file_location(module_name, dynamic_module_file_path)
    if dynamic_module_spec is None or dynamic_module_spec.loader is None:
        raise ImportError('Can not load module from file: %s' % dynamic_module_file_path)
    dynamic_module = importlib.util.module_from_spec(dynamic_module_spec)
    # Register before executing, as importlib's own loader does, so the module
    # can reference itself through sys.modules while it runs.
    previous_module = sys.modules.get(module_name)
    sys.modules[module_name] = dynamic_module
    try:
        dynamic_module_spec.loader.exec_module(dynamic_module)
    except BaseException:
        # Do not leave a half-initialised module registered.
        if previous_module is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous_module
        raise
    return dynamic_module
|
|
190
|
+
|
|
191
|
+
def get_query_template(args):
    """Return ``args.query_template`` with surrounding whitespace stripped.

    A value starting with ``'@'`` is treated as a file path (the text after
    the ``'@'``), like curl's ``--data @file``. The template is then read
    from that file as UTF-8.
    """
    template = args.query_template
    if template.startswith('@'):
        # read from file
        with open(template[1:], 'r', encoding='utf-8') as f:
            template = f.read()
    return template.strip()
|
|
198
|
+
|
|
199
|
+
async def dispatch_requests_worker(request_queue: asyncio.Queue, args):
    """Build API requests and put them on ``request_queue``.

    Requests come from one of two sources:

    - ``args.prompt``: a single prompt. A value starting with ``'@'`` names a
      UTF-8 file (path after the ``'@'``) to read the prompt from.
    - ``args.dataset_path``: messages produced by the dataset plugin
      registered as ``args.dataset``.

    With ``args.number`` set, that many requests are enqueued. The dataset is
    re-iterated if one pass is too short. Without it, prompt mode enqueues
    one request and dataset mode enqueues one pass. Unless ``args.rate`` is
    ``UNLIMITED_RATE``, requests are spaced by exponentially distributed
    intervals with mean ``1 / args.rate`` seconds.

    Returns:
        The number of requests put on the queue.

    Raises:
        Exception: If neither a prompt nor a dataset path is given.
    """
    query_generator_class = api_registry(args.api)
    if not query_generator_class:
        print('Can not find query generator: %s' % args.api)
        sys.exit(1)
    query_generator = query_generator_class(args.tokenizer_path)
    total_query_counter = 0
    query_parameters = QueryParameters(args)
    if args.prompt is not None:
        if args.prompt.startswith("@"):  # read local file as prompt, same as curl --data @file
            with open(args.prompt[1:], 'r', encoding='utf-8') as f:
                prompt = f.read()
        else:
            prompt = args.prompt
        # NOTE(review): a single message dict (not a list) is passed to
        # build_request here; confirm this is the shape API plugins expect.
        messages = {'role': 'user', 'content': prompt}
        request = query_generator.build_request(messages, query_parameters)
        if args.number is None:
            await request_queue.put(request)
            total_query_counter += 1
        else:
            for _ in range(args.number):
                if args.rate != UNLIMITED_RATE:
                    interval = np.random.exponential(1.0 / args.rate)
                    # The next request will be sent after the interval.
                    await asyncio.sleep(interval)
                await request_queue.put(request)
                total_query_counter += 1
    elif args.dataset_path is not None:
        message_generator_class = dataset_registry.get_class(args.dataset)
        if not message_generator_class:
            print('Can not find dataset: %s plugin.' % (args.dataset))
            sys.exit(1)
        # Re-iterate the dataset until a sufficient quantity of queries is sent.
        while True:
            queries_this_pass = 0
            message_generator = message_generator_class(query_parameters)
            for messages in message_generator:
                request = query_generator.build_request(messages, query_parameters)
                if request is None:
                    continue
                await request_queue.put(request)
                total_query_counter += 1
                queries_this_pass += 1
                if args.number is not None and total_query_counter >= args.number:
                    break
                if args.rate == UNLIMITED_RATE:  # no rate limit
                    continue
                # Sample the request interval from the exponential distribution (from vllm).
                interval = np.random.exponential(1.0 / args.rate)
                # The next request will be sent after the interval.
                await asyncio.sleep(interval)
            if args.number is None or total_query_counter >= args.number:
                break
            if queries_this_pass == 0:
                # Another pass would produce nothing either; without this guard the
                # loop would spin forever without yielding to the event loop.
                logger.warning('Dataset %s produced no requests, stop dispatching.' % args.dataset)
                break
    else:
        raise Exception("Prompt or dataset is required!")
    return total_query_counter
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class BenchmarkData(dict):
    """Per-request benchmark record: a plain ``dict`` with two groups of data.

    1. Query info: the request that was sent (``request``).
    2. Response info: the send start time (``start_time``), the receive time
       of every package/chunk (``chunk_times``), the collected response
       messages (``response_messages``), the completion time
       (``completed_time``) and the outcome (``success``).
    """

    def __init__(self, *args, **kwargs):
        # Forward to dict so initial items can be passed like ``dict(...)``.
        # With no arguments this is an empty record, as before.
        super().__init__(*args, **kwargs)
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def calculate_query_stream_metric(benchmark_data):
    """Derive streaming metrics from a request's chunk receive times.

    Args:
        benchmark_data: Mapping with ``start_time`` and ``chunk_times``. The
            caller in this module invokes it only when there are at least two
            chunks.

    Returns:
        ``(first_chunk_latency, n_chunks, n_chunks_time)``:
        - the delay from request start to the first chunk;
        - the chunk count excluding the first and last chunk;
        - the time from the first chunk to the second-to-last chunk.
    """
    chunk_times = benchmark_data["chunk_times"]
    first_chunk_at = chunk_times[0]
    latency_to_first = first_chunk_at - benchmark_data["start_time"]
    # The first and last chunks are left out of the per-chunk statistics.
    middle_count = len(chunk_times) - 2
    middle_elapsed = chunk_times[-2] - first_chunk_at
    return latency_to_first, middle_count, middle_elapsed
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _b64_pickle(obj):
    """Pickle ``obj`` and return it as ASCII base64 text (contains no quote characters)."""
    return base64.b64encode(pickle.dumps(obj)).decode("ascii")


def _resolve_result_db_path(args, current_time):
    """Return the absolute path of the SQLite result file for this run.

    The path is built from ``args.name`` when given, otherwise from
    ``<model>_benchmark_<current_time>.db``. At most the last two
    '/'-separated components are kept, resolved against the current
    working directory.
    """
    if args.name:
        result_db_path = os.path.join('./', args.name)
    else:
        result_db_path = os.path.join('./', "%s_benchmark_%s.db" % (args.model, current_time))
    result_db_path_split = result_db_path.split('/')[1:]
    if len(result_db_path_split) > 2:
        result_db_path_split = result_db_path_split[-2:]
    return os.path.join(os.getcwd(), "/".join(result_db_path_split))


async def statistic_benchmark_metric_worker(benchmark_data_queue: asyncio.Queue, args):
    """Consume benchmark records, store them in SQLite and aggregate metrics.

    The worker polls ``benchmark_data_queue`` until it is empty and the module
    flag ``_data_process_completed`` is set. Each record becomes one row of
    the ``_table_name`` table in a new SQLite file; the worker exits if that
    file already exists. Running metrics are logged every
    ``args.log_every_n_query`` completed queries. When ``args.wandb_api_key``
    is set, the metrics are also sent to Weights & Biases.

    Returns:
        ``(total_time, n_total_queries, n_succeed_queries, n_failed_queries,
        qps, avg_latency, avg_first_chunk_latency, n_avg_chunks,
        avg_chunk_time, avg_prompt_tokens, avg_completion_tokens,
        avg_token_per_seconds, avg_time_per_token, result_db_path)``.
        Averages that could not be computed remain ``-1``.
    """
    n_succeed_queries = 0
    n_failed_queries = 0
    total_first_chunk_latency = 0
    total_latency = 0.0
    n_total_chunks = 0
    n_total_prompt_tokens = 0
    n_total_completion_tokens = 0
    qps = 0
    concurrency = args.parallel
    start_time = None
    total_chunks_time = 0.0
    # -1 marks an average that has not been computed (yet).
    avg_latency = -1
    avg_first_chunk_latency = -1
    avg_token_per_seconds = -1
    avg_time_per_token = -1
    n_avg_chunks = -1
    avg_chunk_time = -1
    avg_prompt_tokens = -1
    avg_completion_tokens = -1
    total_time = 1  # avoid divide by zero
    n_total_queries = 0
    default_ndigits = 3
    # Metric notes:
    # - avg generate TPS = generated tokens / time.
    # - avg chunk time drops the first and last chunk: the first is merged with
    #   prefill, and the last may generate fewer tokens.
    # - prefill time ~= time to first chunk - avg chunk time.
    # - n tokens per chunk.

    api_plugin_class = api_registry(args.api)
    if not api_plugin_class:
        print('Can not find query generator: %s' % args.api)
        sys.exit(1)
    api_plugin = api_plugin_class(args.tokenizer_path)

    utc_dt = datetime.now(timezone.utc)
    current_time = utc_dt.astimezone().strftime("%Y_%m_%d_%H_%M_%S_%f")
    result_db_path = _resolve_result_db_path(args, current_time)
    result_db_dir = os.path.split(result_db_path)[0]
    if not os.path.exists(result_db_dir):
        os.makedirs(result_db_dir, exist_ok=True)
    print('Save the result to : %s' % result_db_path)
    if os.path.exists(result_db_path):
        print('The db file exist, delete it and start again!.')
        sys.exit(1)

    con = sqlite3.connect(result_db_path)
    try:
        db_cur = con.cursor()
        # Columns from latency onwards are only filled for successful queries.
        # (TPS: output tokens per second; TPOT: time per output token.)
        db_cur.execute(
            "CREATE TABLE %s(request, start_time, chunk_times, success, "
            "response_messages, completed_time, latency, first_chunk_latency, "
            "n_chunks, chunk_time, prompt_tokens, completion_tokens)" % _table_name)
        if args.wandb_api_key is not None:
            import wandb
            # Set before wandb.init() so the setting is in place during initialisation.
            os.environ["WANDB_SILENT"] = "true"
            name = args.name if args.name is not None else '%s_%s' % (args.model, current_time)
            wandb.init(
                project="perf_benchmark",
                name=name,
                # track run metadata
                config={
                    "model": args.model,
                    "time": current_time
                })

        while True:
            try:
                benchmark_data = benchmark_data_queue.get_nowait()
                benchmark_data_queue.task_done()
            except asyncio.QueueEmpty:
                if _data_process_completed:
                    break
                await asyncio.sleep(1)
                continue
            if start_time is None:
                start_time = benchmark_data["start_time"]  # start time with first request start time
            total_time = time.perf_counter() - start_time

            if benchmark_data["success"]:
                n_succeed_queries += 1
                n_query_trunks = len(benchmark_data["chunk_times"])
                query_latency = benchmark_data["completed_time"] - benchmark_data["start_time"]
                if n_query_trunks > 1:
                    query_first_chunk_latency, query_n_chunks, query_n_chunks_time = \
                        calculate_query_stream_metric(benchmark_data)
                else:
                    # Not stream mode: first-chunk latency equals the total latency.
                    query_first_chunk_latency = query_latency
                    query_n_chunks = 1
                    query_n_chunks_time = query_latency

                n_query_prompt_tokens, n_query_completion_tokens = api_plugin.parse_responses(
                    benchmark_data["response_messages"],
                    request=benchmark_data["request"])
                n_total_prompt_tokens += n_query_prompt_tokens
                n_total_completion_tokens += n_query_completion_tokens

                total_first_chunk_latency += query_first_chunk_latency
                total_latency += query_latency
                n_total_chunks += query_n_chunks
                total_chunks_time += query_n_chunks_time

                avg_first_chunk_latency = total_first_chunk_latency / n_succeed_queries
                avg_latency = total_latency / n_succeed_queries
                if n_query_trunks > 1:
                    # Add back the first and last chunk excluded from n_chunks.
                    n_avg_chunks = n_total_chunks / n_succeed_queries + 2
                else:
                    n_avg_chunks = n_total_chunks / n_succeed_queries
                if n_total_chunks > 0:
                    # Two-chunk streams contribute zero middle chunks; avoid 0 / 0.
                    avg_chunk_time = total_chunks_time / n_total_chunks
                avg_prompt_tokens = n_total_prompt_tokens / n_succeed_queries
                avg_completion_tokens = n_total_completion_tokens / n_succeed_queries
                # avg generate tps: generated tokens / time
                avg_token_per_seconds = n_total_completion_tokens / total_time
                if n_total_completion_tokens > 0:
                    avg_time_per_token = total_time / n_total_completion_tokens
                insert_sql = "INSERT INTO %s VALUES('%s', %s, '%s', '%s', '%s', %s, %s, %s, %s, %s, %s, %s)" % (
                    _table_name,
                    _b64_pickle(benchmark_data["request"]),
                    benchmark_data["start_time"],
                    json.dumps(benchmark_data["chunk_times"]),
                    benchmark_data["success"],
                    _b64_pickle(benchmark_data["response_messages"]),
                    benchmark_data["completed_time"],
                    query_latency,
                    query_first_chunk_latency,
                    query_n_chunks,
                    query_n_chunks_time,
                    n_query_prompt_tokens,
                    n_query_completion_tokens,
                )
            else:
                n_failed_queries += 1
                insert_sql = ("INSERT INTO %s(request, start_time, chunk_times, success, "
                              "response_messages, completed_time) "
                              "VALUES('%s', %s, '%s', '%s', '%s', %s)") % (
                    _table_name,
                    _b64_pickle(benchmark_data["request"]),
                    benchmark_data["start_time"],
                    json.dumps(benchmark_data["chunk_times"]),
                    benchmark_data["success"],
                    _b64_pickle(benchmark_data["response_messages"]),
                    benchmark_data["completed_time"],
                )
            n_total_queries = float(n_succeed_queries + n_failed_queries)  # float for calc
            qps = n_succeed_queries / total_time
            db_cur.execute(insert_sql)
            con.commit()
            message = {"Time": round(total_time, default_ndigits),
                       "concurrency": concurrency,
                       "completed": int(n_total_queries),
                       "succeed": n_succeed_queries,
                       "failed": n_failed_queries,
                       "qps": round(qps, default_ndigits),
                       "latency": round(avg_latency, default_ndigits),
                       "time to first token": round(avg_first_chunk_latency, default_ndigits),
                       "throughput(output tokens per second)": round(avg_token_per_seconds, default_ndigits),
                       "time per output token": round(avg_time_per_token, 5),
                       "package per request": round(n_avg_chunks, default_ndigits),
                       "time per package": round(avg_chunk_time, default_ndigits),
                       "input tokens per request": round(avg_prompt_tokens, default_ndigits),
                       "output tokens per request": round(avg_completion_tokens, default_ndigits)}
            if args.wandb_api_key is not None:
                wandb.log(message)
            if int(n_total_queries) % args.log_every_n_query == 0:
                msg = json.dumps(message)
                msg = msg[1:-1].replace('"', '')
                logger.info(msg)
        con.commit()
    finally:
        # Close the database even if parsing or inserting a record fails.
        con.close()
    return (total_time, n_total_queries,
            n_succeed_queries, n_failed_queries,
            qps, avg_latency, avg_first_chunk_latency,
            n_avg_chunks, avg_chunk_time,
            avg_prompt_tokens, avg_completion_tokens,
            avg_token_per_seconds, avg_time_per_token,
            result_db_path)
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def summary_result(expected_number_of_queries,
                   total_time,
                   n_total_queries,
                   n_succeed_queries,
                   n_failed_queries,
                   qps,
                   avg_latency,
                   avg_first_chunk_latency,
                   n_avg_chunks,
                   avg_chunk_time,
                   avg_prompt_tokens,
                   avg_completion_tokens,
                   avg_token_per_seconds,
                   avg_time_per_token,
                   result_db_path, args):
    """Print the benchmark summary and latency percentiles to stdout.

    Aggregate values are printed as passed in. The successful rows of the
    ``_table_name`` table in ``result_db_path`` are then read to print
    nearest-rank percentiles of time to first token and of request latency.
    Percentiles are only printed when there are more successful rows than
    percentile points. A missing (NULL) value is shown as ``inf``.
    """
    print("Benchmarking summary: ")
    print(" Time taken for tests: %.3f seconds" % total_time)
    print(" Expected number of requests: %s" % expected_number_of_queries)
    print(" Number of concurrency: %d" % args.parallel)
    print(" Total requests: %d" % n_total_queries)
    print(" Succeed requests: %d" % n_succeed_queries)
    print(" Failed requests: %d" % n_failed_queries)
    print(" Average QPS: %.3f" % qps)
    print(" Average latency: %.3f" % avg_latency)
    print(" Throughput(average output tokens per second): %.3f" % avg_token_per_seconds)
    print(" Average time to first token: %.3f" % avg_first_chunk_latency)
    print(" Average input tokens per request: %.3f" % avg_prompt_tokens)
    print(" Average output tokens per request: %.3f" % avg_completion_tokens)
    print(" Average time per output token: %.5f" % avg_time_per_token)
    print(" Average package per request: %.3f" % n_avg_chunks)
    print(" Average package latency: %.3f" % avg_chunk_time)

    query_sql = ("SELECT start_time, chunk_times, success, "
                 "completed_time, latency, first_chunk_latency, "
                 "n_chunks, chunk_time, prompt_tokens, completion_tokens "
                 "FROM %s WHERE success='True' ORDER BY first_chunk_latency ASC") % _table_name
    # Column positions in the SELECT above.
    latency_index = 4
    first_chunk_latency_index = 5
    percentiles = [50, 66, 75, 80, 90, 95, 98, 99]

    con = sqlite3.connect(result_db_path)
    try:
        # ``with con`` only manages the transaction; close explicitly below.
        with con:
            rows = con.execute(query_sql).fetchall()
    finally:
        con.close()

    n_success_queries = len(rows)
    if n_success_queries > len(percentiles):
        print(" Percentile of time to first token: ")
        for percentile in percentiles:
            row = rows[int(n_success_queries * percentile / 100)]
            value = row[first_chunk_latency_index]
            print(" p%s: %.4f" % (percentile, value if value is not None else float("inf")))
        print(" Percentile of request latency: ")
        # NULL latencies sort last, consistent with being displayed as inf.
        rows.sort(key=lambda r: (r[latency_index] is None,
                                 r[latency_index] if r[latency_index] is not None else 0.0))
        for percentile in percentiles:
            row = rows[int(n_success_queries * percentile / 100)]
            value = row[latency_index]
            print(" p%s: %.4f" % (percentile, value if value is not None else float("inf")))
    else:
        print(" Too little data to calculate quantiles!")
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
async def send_requests_worker(task_id, request_queue: asyncio.Queue, benchmark_data_queue: asyncio.Queue, args):
    """Drain ``request_queue``, send each request and report its timings.

    For every request a ``BenchmarkData`` record is filled with the request,
    its start time, one timestamp per received chunk, the collected response
    messages, the completion time and a success flag, and is then put on
    ``benchmark_data_queue``. The worker exits once the request queue is empty
    and the module-level flag ``_query_send_completed`` has been set.

    Args:
        task_id: Index of this worker (not used inside the body).
        request_queue: Queue of requests produced by the dispatcher.
        benchmark_data_queue: Queue that receives one ``BenchmarkData`` per request.
        args: Parsed command-line namespace; ``url``, ``connect_timeout``,
            ``read_timeout``, ``headers`` and ``debug`` are read here.
    """
    client = AioHttpClient(args.url,
                           conn_timeout=args.connect_timeout,
                           read_timeout=args.read_timeout,
                           headers=args.headers,
                           debug=args.debug)
    async with client:
        while True:
            # Get a request out of the queue without blocking. task_done() is
            # called immediately, so request_queue.join() only waits for the
            # request to be dequeued, not for its response.
            try:
                request = request_queue.get_nowait()
                request_queue.task_done()
            except asyncio.QueueEmpty:
                if _query_send_completed:
                    break
                await asyncio.sleep(0.01)
                continue  # keep polling for queries
            benchmark_data = BenchmarkData()
            benchmark_data["request"] = request
            benchmark_data["start_time"] = time.perf_counter()
            benchmark_data["chunk_times"] = []
            benchmark_data["success"] = False
            collected_messages = []
            # Bound up front so the bookkeeping below never touches an unbound
            # name when the stream yields nothing or fails before its first chunk.
            response_data = None
            failed = False
            try:
                async for (is_error, state_code, response_data) in client.post(request):
                    if is_error or state_code != HTTPStatus.OK:
                        logger.error("Request: %s failed, state_code: %s, data: %s" %
                                     (request, state_code, response_data))
                        # A non-OK status is a failure even when is_error is False.
                        failed = True
                        break
                    if response_data:
                        collected_messages.append(response_data)  # save the message
                        logger.debug(response_data)
                    benchmark_data["chunk_times"].append(time.perf_counter())

                benchmark_data["response_messages"] = collected_messages
                benchmark_data["completed_time"] = time.perf_counter()
                # An empty stream (no chunk at all) is not counted as a success.
                benchmark_data["success"] = not failed and bool(benchmark_data["chunk_times"])
                await benchmark_data_queue.put(benchmark_data)
            except asyncio.CancelledError:
                # Let cancellation propagate so the event loop can shut the worker down.
                raise
            except Exception as e:
                # Every chunk received so far is already in collected_messages;
                # re-appending response_data here would duplicate the last one.
                benchmark_data["response_messages"] = collected_messages
                benchmark_data["completed_time"] = time.perf_counter()
                await benchmark_data_queue.put(benchmark_data)
                logger.error("Request query: %s exception, response: %s" % (request, response_data))
                logger.exception(e)
|
|
585
|
+
|
|
586
|
+
def signal_handler(signal_name, loop):
    """Announce the received signal on stdout, then stop ``loop``.

    Args:
        signal_name: Human-readable signal name (e.g. ``'SIGINT'``).
        loop: Event loop whose ``stop()`` method is called.
    """
    notice = f"Got signal {signal_name}: exit"
    print(notice)
    loop.stop()
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
async def benchmark(args) -> None:
    """Run one complete load test and print its summary.

    Starts a dispatcher that fills ``request_queue``, ``args.parallel`` sender
    workers that drain it, and a statistics worker that consumes
    ``benchmark_data_queue``. It waits for all of them to finish, then calls
    ``summary_result`` with the collected metrics.

    Args:
        args: Parsed command-line namespace (see ``add_argument``).
    """
    global _query_send_completed, _data_process_completed
    # These module-level flags are the workers' stop signals and are set to
    # True below. Reset them so a second benchmark in the same process does not
    # start with stale True values, which would make workers quit early.
    _query_send_completed = False
    _data_process_completed = False

    # Signal handlers via add_signal_handler are only installed off Windows.
    if platform.system() != 'Windows':
        # add SIGINT and SIGTERM handler
        loop = asyncio.get_running_loop()
        for signal_name in {'SIGINT', 'SIGTERM'}:
            loop.add_signal_handler(
                getattr(signal, signal_name),
                functools.partial(signal_handler, signal_name, loop))

    # request_queue carries outgoing requests to the sender workers;
    # benchmark_data_queue carries per-request measurements to the statistics worker.
    request_queue = asyncio.Queue()
    benchmark_data_queue = asyncio.Queue()
    dispatch_task = asyncio.create_task(dispatch_requests_worker(request_queue, args))
    statistic_benchmark_metric_task = asyncio.create_task(
        statistic_benchmark_metric_worker(benchmark_data_queue, args))
    request_tasks: List[asyncio.Task] = [
        asyncio.create_task(send_requests_worker(idx, request_queue, benchmark_data_queue, args))
        for idx in range(args.parallel)
    ]

    expected_number_of_queries = await dispatch_task  # wait for dispatch task complete
    await request_queue.join()
    # Tell the sender workers they may exit once the request queue is empty.
    _query_send_completed = True
    await asyncio.gather(*request_tasks, return_exceptions=True)
    await benchmark_data_queue.join()  # wait until every query result is processed
    _data_process_completed = True
    (total_time, n_total_queries,
     n_succeed_queries, n_failed_queries,
     qps, avg_latency,
     avg_first_chunk_latency, n_avg_chunks,
     avg_chunk_time, avg_prompt_tokens,
     avg_completion_tokens, avg_token_per_seconds,
     avg_time_per_token, result_db_path) = await statistic_benchmark_metric_task

    summary_result(expected_number_of_queries, total_time, n_total_queries, n_succeed_queries,
                   n_failed_queries, qps, avg_latency, avg_first_chunk_latency,
                   n_avg_chunks, avg_chunk_time,
                   avg_prompt_tokens, avg_completion_tokens,
                   avg_token_per_seconds, avg_time_per_token,
                   result_db_path, args)
    # NOTE(review): presumably a short grace period so pending callbacks and
    # connections can settle before the loop closes. Confirm it is needed.
    await asyncio.sleep(0.250)
|
|
635
|
+
|
|
636
|
+
|
|
637
|
+
def process_number(input):
    """Convert ``input`` to an ``int`` or ``float`` when it parses as one.

    ``int`` is tried first, then ``float``. Anything else is returned
    unchanged, including values the constructors reject with ``TypeError``
    (e.g. ``None``). ``bool`` values are returned as-is because ``int(True)``
    would otherwise silently turn them into ``1``.

    Args:
        input: The raw value, typically a command-line string.

    Returns:
        The parsed ``int``/``float``, or ``input`` itself.
    """
    if isinstance(input, bool):
        return input
    try:
        return int(input)
    except (TypeError, ValueError):
        try:
            return float(input)
        except (TypeError, ValueError):
            return input


# from: https://gist.github.com/vadimkantorov/37518ff88808af840884355c845049ea
class ParseKVAction(argparse.Action):
    """argparse action that collects ``key=value`` tokens into a dict.

    Each token is split on the first ``=`` only, so values may themselves
    contain ``=`` (e.g. base64-padded tokens in an ``Authorization`` header).
    The values ``bool_true``/``bool_false`` (case-insensitive) become
    ``True``/``False``. Every other value goes through ``process_number``, so
    numeric strings become ``int``/``float``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parsed = {}
        for each in values:
            try:
                key, value = each.split("=", 1)
            except ValueError as ex:
                message = "\nTraceback: {}".format(ex)
                message += "\nError on '{}' || It should be 'key=value'".format(
                    each)
                raise argparse.ArgumentError(self, str(message)) from ex
            lowered = value.lower()
            if lowered == 'bool_true':
                value = True
            elif lowered == 'bool_false':
                value = False
            else:
                # Only plain strings are number-parsed; booleans stay booleans.
                value = process_number(value)
            parsed[key] = value
        setattr(namespace, self.dest, parsed)
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
def run_perf_benchmark(args):
    """Blocking entry point: drive ``benchmark(args)`` to completion.

    Args:
        args: Parsed command-line namespace (see ``add_argument``).
    """
    main_coroutine = benchmark(args)
    asyncio.run(main_coroutine)
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def add_argument(parser: argparse.ArgumentParser):
    """Register the performance-benchmark command-line options on ``parser``.

    Args:
        parser: The parser to extend; it is mutated in place.
    """
    parser.add_argument("--model", type=str, required=True,
                        help="The test model name.")
    parser.add_argument("--url", type=str, default="localhost")
    parser.add_argument("--connect-timeout", type=int, default=120,
                        help="The network connection timeout")
    parser.add_argument("--read-timeout", type=int, default=120,
                        help="The network read timeout")
    parser.add_argument("-n", "--number", type=int, default=None,
                        help="How many requests to be made. If None, "
                             "requests are sent based on the dataset or prompt.")
    parser.add_argument("--parallel", type=int, default=1,
                        help="Set number of concurrent requests, default 1")
    parser.add_argument("--rate", type=int, default=UNLIMITED_RATE,
                        help="Number of requests per second (default: %(default)s). "
                             "If it is set to -1, then all the requests are sent at time 0. "
                             "Otherwise, we use a Poisson process to synthesize "
                             "the request arrival times. Mutually exclusive with --parallel.")
    parser.add_argument("--log-every-n-query", type=int, default=10,
                        help="Logging every n query.")
    parser.add_argument("--headers", nargs="+", dest="headers",
                        action=ParseKVAction,
                        help="Extra HTTP headers given as key1=value1 key2=value2. "
                             "The headers will be used for each query. "
                             "You can use this parameter to specify HTTP authorization and other headers.",
                        metavar="KEY1=VALUE1")
    parser.add_argument("--wandb-api-key", type=str, default=None,
                        help="The wandb api key, if set the metric will be saved to wandb.")
    parser.add_argument("--name", type=str,
                        help="The wandb db result name and result db name, default: {model_name}_{current_time}")
    parser.add_argument("--debug", action='store_true', default=False,
                        help='Debug request send.')
    parser.add_argument("--tokenizer-path", type=str, required=False, default=None,
                        help="Specify the tokenizer weight path, used to calculate the number of input and "
                             "output tokens, usually in the same directory as the model weight. "
                             "If the service returns usage info, that will be used instead.")
    parser.add_argument("--api",
                        type=str,
                        default="openai",
                        help="Specify the service api, currently supported: [openai|dashscope]. "
                             "You can define your custom parser with python and specify the python file path, "
                             "see api_plugin_base.py.")
    parser.add_argument("--max-prompt-length", type=int, default=sys.maxsize,
                        help="Maximum input prompt length")
    parser.add_argument("--min-prompt-length", type=int, default=0,
                        help="Minimum input prompt length.")
    parser.add_argument("--prompt", type=str, required=False, default=None,
                        help="Specify the request prompt; all queries will use this prompt. "
                             "You can specify a local file via @file_path, the prompt will be "
                             "the file content.")
    parser.add_argument("--query-template",
                        type=str,
                        default=None,
                        help="Specify the query template, either a json string or a local file "
                             "specified with @local_file_path; "
                             "the model and prompt in the template will be replaced.")
    parser.add_argument("--dataset",
                        type=str,
                        default='line_by_line',
                        help="Specify the dataset [openqa|longalpaca|line_by_line]. "
                             "You can define your custom dataset parser with python and specify the "
                             "python file path, see dataset_plugin_base.py.")
    parser.add_argument("--dataset-path", type=str, required=False,
                        help="Path to the dataset file, used in conjunction with dataset. "
                             "If dataset is None, each line defaults to a prompt.")

    parser.add_argument("--frequency-penalty", type=float, help="The frequency_penalty value.", default=None)
    parser.add_argument("--logprobs", action='store_true', help="The logprobs.", default=None)
    parser.add_argument("--max-tokens", type=int, help="The maximum number of tokens can be generated.", default=None)
    parser.add_argument("--n-choices", type=int, help="How many completion choices to generate.", default=None)
    parser.add_argument("--seed", type=int, help="The random seed.", default=None)
    parser.add_argument("--stop", nargs='*', help="The stop tokens.", default=None)
    # Token ids are integers; without type=int argparse would hand them on as strings.
    parser.add_argument("--stop-token-ids", nargs='*', type=int, help="Set the stop token ids.", default=None)
    parser.add_argument("--stream", action='store_true',
                        help="Stream output with SSE; with the openai interface, "
                             "stream_options.include_usage is added automatically.",
                        default=None)
    parser.add_argument("--temperature", type=float, help="The sample temperature.", default=None)
    parser.add_argument("--top-p", type=float, help="Sampling top p.", default=None)
|
|
747
|
+
|
|
748
|
+
if __name__ == "__main__":
    # Windows: install the selector event-loop policy to avoid
    # "RuntimeError: Event loop is closed".
    if platform.system() == 'Windows':
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    parser = argparse.ArgumentParser(description="Benchmark LLM service performance.")
    add_argument(parser)
    args = parser.parse_args()
    run_perf_benchmark(args)
|