evalscope 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (106)
  1. evalscope/backend/rag_eval/clip_benchmark/utils/webdataset_convert.py +230 -0
  2. evalscope/backend/rag_eval/clip_benchmark/utils/webdatasets.txt +43 -0
  3. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +87 -0
  4. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +36 -0
  5. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +26 -0
  6. evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +41 -0
  7. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +60 -0
  8. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +36 -0
  9. evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +22 -0
  10. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +35 -0
  11. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  12. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  13. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  14. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  15. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +34 -0
  16. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +36 -0
  17. evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +25 -0
  18. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  19. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  20. evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +16 -0
  21. evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +24 -0
  22. evalscope/backend/rag_eval/ragas/prompts/persona_prompt.py +18 -0
  23. evalscope/backend/vlm_eval_kit/backend_manager.py +23 -21
  24. evalscope/benchmarks/ceval/samples.jsonl +1 -0
  25. evalscope/benchmarks/cmmlu/samples.jsonl +5 -0
  26. evalscope/benchmarks/mmlu/samples.jsonl +5 -0
  27. evalscope/benchmarks/race/samples.jsonl +5 -0
  28. evalscope/benchmarks/trivia_qa/samples.jsonl +5 -0
  29. evalscope/cli/start_perf.py +8 -11
  30. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +58485 -0
  31. evalscope/metrics/resources/gpt2-zhcn3-v4.json +1 -0
  32. evalscope/metrics/rouge_metric.py +30 -15
  33. evalscope/perf/arguments.py +179 -0
  34. evalscope/perf/benchmark.py +245 -0
  35. evalscope/perf/http_client.py +127 -711
  36. evalscope/perf/main.py +35 -0
  37. evalscope/perf/plugin/__init__.py +2 -0
  38. evalscope/perf/plugin/api/__init__.py +3 -0
  39. evalscope/perf/{api_plugin_base.py → plugin/api/base.py} +17 -18
  40. evalscope/perf/{custom_api.py → plugin/api/custom_api.py} +25 -19
  41. evalscope/perf/{dashscope_api.py → plugin/api/dashscope_api.py} +28 -14
  42. evalscope/perf/{openai_api.py → plugin/api/openai_api.py} +51 -27
  43. evalscope/perf/plugin/datasets/__init__.py +6 -0
  44. evalscope/perf/{dataset_plugin_base.py → plugin/datasets/base.py} +13 -10
  45. evalscope/perf/plugin/datasets/custom.py +21 -0
  46. evalscope/perf/plugin/datasets/flickr8k.py +51 -0
  47. evalscope/perf/{datasets → plugin/datasets}/line_by_line.py +9 -5
  48. evalscope/perf/plugin/datasets/longalpaca.py +28 -0
  49. evalscope/perf/plugin/datasets/openqa.py +38 -0
  50. evalscope/perf/plugin/datasets/speed_benchmark.py +50 -0
  51. evalscope/perf/plugin/registry.py +54 -0
  52. evalscope/perf/{how_to_analysis_result.py → utils/analysis_result.py} +11 -5
  53. evalscope/perf/utils/benchmark_util.py +135 -0
  54. evalscope/perf/utils/chat_service.py +252 -0
  55. evalscope/perf/utils/db_util.py +200 -0
  56. evalscope/perf/utils/handler.py +46 -0
  57. evalscope/perf/utils/local_server.py +139 -0
  58. evalscope/registry/config/cfg_arena.yaml +77 -0
  59. evalscope/registry/config/cfg_arena_zhihu.yaml +63 -0
  60. evalscope/registry/config/cfg_pairwise_baseline.yaml +83 -0
  61. evalscope/registry/config/cfg_single.yaml +78 -0
  62. evalscope/registry/data/prompt_template/lmsys_v2.jsonl +8 -0
  63. evalscope/registry/data/prompt_template/prompt_templates.jsonl +8 -0
  64. evalscope/registry/data/qa_browser/battle.jsonl +634 -0
  65. evalscope/registry/data/qa_browser/category_mapping.yaml +10 -0
  66. evalscope/registry/data/question.jsonl +80 -0
  67. evalscope/third_party/longbench_write/README.md +118 -0
  68. evalscope/third_party/longbench_write/default_task.json +27 -0
  69. evalscope/third_party/longbench_write/default_task.yaml +24 -0
  70. evalscope/third_party/toolbench_static/README.md +118 -0
  71. evalscope/third_party/toolbench_static/config_default.json +15 -0
  72. evalscope/third_party/toolbench_static/config_default.yaml +12 -0
  73. evalscope/third_party/toolbench_static/requirements.txt +2 -0
  74. evalscope/utils/logger.py +18 -20
  75. evalscope/utils/utils.py +41 -42
  76. evalscope/version.py +2 -2
  77. evalscope-0.7.0.dist-info/LICENSE +203 -0
  78. {evalscope-0.6.1.dist-info → evalscope-0.7.0.dist-info}/METADATA +91 -33
  79. {evalscope-0.6.1.dist-info → evalscope-0.7.0.dist-info}/RECORD +99 -29
  80. {evalscope-0.6.1.dist-info → evalscope-0.7.0.dist-info}/WHEEL +1 -1
  81. {evalscope-0.6.1.dist-info → evalscope-0.7.0.dist-info}/top_level.txt +1 -0
  82. tests/cli/__init__.py +1 -0
  83. tests/cli/test_run.py +76 -0
  84. tests/perf/__init__.py +1 -0
  85. tests/perf/test_perf.py +96 -0
  86. tests/rag/test_clip_benchmark.py +85 -0
  87. tests/rag/test_mteb.py +136 -0
  88. tests/rag/test_ragas.py +120 -0
  89. tests/swift/__init__.py +1 -0
  90. tests/swift/test_run_swift_eval.py +146 -0
  91. tests/swift/test_run_swift_vlm_eval.py +128 -0
  92. tests/swift/test_run_swift_vlm_jugde_eval.py +157 -0
  93. tests/test_run_all.py +12 -0
  94. tests/vlm/__init__.py +1 -0
  95. tests/vlm/test_vlmeval.py +59 -0
  96. evalscope/perf/_logging.py +0 -32
  97. evalscope/perf/datasets/longalpaca_12k.py +0 -20
  98. evalscope/perf/datasets/openqa.py +0 -22
  99. evalscope/perf/plugin_registry.py +0 -35
  100. evalscope/perf/query_parameters.py +0 -42
  101. evalscope/perf/server_sent_event.py +0 -43
  102. evalscope/preprocess/tokenizers/gpt2_tokenizer.py +0 -221
  103. /evalscope/perf/{datasets → utils}/__init__.py +0 -0
  104. {evalscope-0.6.1.dist-info → evalscope-0.7.0.dist-info}/entry_points.txt +0 -0
  105. {evalscope/preprocess → tests}/__init__.py +0 -0
  106. {evalscope/preprocess/tokenizers → tests/rag}/__init__.py +0 -0
evalscope/metrics/rouge_metric.py

@@ -9,16 +9,24 @@ from tqdm import tqdm

 from evalscope.constants import MetricsConstant
 from evalscope.metrics.bundled_rouge_score import rouge_scorer
-from evalscope.preprocess.tokenizers.gpt2_tokenizer import DummyTokenizer
+
 from rouge_chinese import Rouge
 import jieba

+
+class DummyTokenizer:
+
+    def tokenize(self, text: str):
+        return text.split()
+
+
 HERE = Path(__file__).absolute().parent

 logger = logging.getLogger(__name__)

-scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'],
-                                  tokenizer=DummyTokenizer())
+scorer = rouge_scorer.RougeScorer(
+    ['rouge1', 'rouge2', 'rougeL'], tokenizer=DummyTokenizer()
+)
 zh_scorer = Rouge()


@@ -28,6 +36,7 @@ def is_contains_chinese(strs):
             return True
     return False

+
 def compute_rouge_score(predict_l, reference_l):
     assert len(predict_l) == len(reference_l)
     if len(predict_l) == 0:
@@ -43,10 +52,14 @@ def compute_rouge_score(predict_l, reference_l):
             result[rouge_key].append(one_sample[rouge_key])
     rlt = {}
     for rouge_key in MetricsConstant.ROUGE_KEYS:
-        rlt[rouge_key] = mean(result[rouge_key]) * 100 if rouge_key in result \
+        rlt[rouge_key] = (
+            mean(result[rouge_key]) * 100
+            if rouge_key in result
             else MetricsConstant.INVALID_VALUE
+        )
     return rlt

+
 def compute_rouge_score_one_sample_zh(predict, reference):
     result = dict()
     for p, r in zip(predict, reference):
@@ -63,9 +76,10 @@ def compute_rouge_score_one_sample_zh(predict, reference):
         result['rouge-l-r'] = score['rouge-l']['r']
         result['rouge-l-p'] = score['rouge-l']['p']
         result['rouge-l-f'] = score['rouge-l']['f']
-
+
     return result

+
 def compute_rouge_score_one_sample(predict, reference):
     result = dict()
     for p, r in zip(predict, reference):
@@ -97,11 +111,9 @@ def _to_table(final_result) -> str:
             if not task:
                 continue
             elif task == 'total':
-                row.append(
-                    f'{final_result["total"]["rouge"][rouge_key] :0.2f}')
+                row.append(f'{final_result["total"]["rouge"][rouge_key] :0.2f}')
             else:
-                row.append(
-                    f'{final_result["tasks"][task]["rouge"][rouge_key] :0.2f}')
+                row.append(f'{final_result["tasks"][task]["rouge"][rouge_key] :0.2f}')
         table.append('\t'.join(row))

     return '\n'.join(table)
@@ -111,19 +123,22 @@ def run_rouge_eval(data_l, md_level=2, report_metric_key='rouge-l-f'):
     print(f"{'#' * md_level} Rouge Eval")
     for data in tqdm(data_l):
         data['rouge'] = compute_rouge_score_one_sample(
-            data['gen_tok_str'], data['reference_tok_str'])
+            data['gen_tok_str'], data['reference_tok_str']
+        )
     task_data_d = defaultdict(list)
     for data in data_l:
         for task in data['task_tags']:
             task_data_d[task].append(data)

     total_rouge = mean([data['rouge'][report_metric_key] for data in data_l])
-    print(f'[total], count: {len(data_l)}, {report_metric_key}: '
-          f'{total_rouge * 100:0.2f}%')
+    print(
+        f'[total], count: {len(data_l)}, {report_metric_key}: '
+        f'{total_rouge * 100:0.2f}%'
+    )

     for task, task_data in task_data_d.items():
-        task_rouge = mean(
-            [data['rouge'][report_metric_key] for data in task_data])
+        task_rouge = mean([data['rouge'][report_metric_key] for data in task_data])
         print(
             f'[{task}], count: {len(task_data_d[task])}, {report_metric_key}: '
-            f'{task_rouge * 100:0.2f}%')
+            f'{task_rouge * 100:0.2f}%'
+        )
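For context on the change above: the DummyTokenizer that previously lived in the now-removed evalscope/preprocess/tokenizers/gpt2_tokenizer.py is inlined here and simply whitespace-splits text that has already been tokenized. A minimal sketch of the same pattern, assuming the public rouge_score package as a stand-in for evalscope's bundled copy (the sample strings are invented):

from rouge_score import rouge_scorer  # public stand-in for evalscope's bundled_rouge_score


class DummyTokenizer:
    """Pass-through tokenizer: inputs are already space-joined token strings."""

    def tokenize(self, text: str):
        return text.split()


scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], tokenizer=DummyTokenizer())

# Invented pre-tokenized reference / prediction pair
scores = scorer.score('the cat sat on the mat', 'a cat sat on the mat')
print({key: round(value.fmeasure, 4) for key, value in scores.items()})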
evalscope/perf/arguments.py (new file)

@@ -0,0 +1,179 @@
+import argparse
+import sys
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import json
+
+
+@dataclass
+class Arguments:
+    # Model and API
+    model: str  # Model identifier
+    attn_implementation: Optional[str] = None  # Attention implementation, only for local inference
+    api: str = 'openai'  # API to be used (default: 'openai')
+    tokenizer_path: Optional[str] = None  # Path to the tokenizer
+
+    # Connection settings
+    url: str = 'http://127.0.0.1:8877/v1/chat/completions'  # URL for the API connection
+    headers: Dict[str, Any] = field(default_factory=dict)  # Custom headers
+    connect_timeout: int = 120  # Connection timeout in seconds
+    read_timeout: int = 120  # Read timeout in seconds
+    api_key: str = 'EMPTY'
+
+    # Performance and parallelism
+    number: Optional[int] = None  # Number of requests to be made
+    parallel: int = 1  # Number of parallel requests
+    rate: int = -1  # Rate limit for requests (default: -1, no limit)
+
+    # Logging and debugging
+    log_every_n_query: int = 10  # Log every N queries
+    debug: bool = False  # Debug mode
+    wandb_api_key: Optional[str] = None  # WandB API key for logging
+    name: Optional[str] = None  # Name for the run
+
+    # Prompt settings
+    max_prompt_length: int = sys.maxsize  # Maximum length of the prompt
+    min_prompt_length: int = 0  # Minimum length of the prompt
+    prompt: Optional[str] = None  # The prompt text
+    query_template: Optional[str] = None  # Template for the query
+
+    # Dataset settings
+    dataset: str = 'openqa'  # Dataset type (default: 'openqa')
+    dataset_path: Optional[str] = None  # Path to the dataset
+
+    # Response settings
+    frequency_penalty: Optional[float] = None  # Frequency penalty for the response
+    logprobs: Optional[bool] = None  # Whether to log probabilities
+    max_tokens: Optional[int] = 2048  # Maximum number of tokens in the response
+    min_tokens: Optional[int] = None  # Minimum number of tokens in the response
+    n_choices: Optional[int] = None  # Number of response choices
+    seed: Optional[int] = 42  # Random seed for reproducibility
+    stop: Optional[List[str]] = field(default_factory=list)  # Stop sequences for the response
+    stop_token_ids: Optional[List[str]] = field(default_factory=list)  # Stop token IDs for the response
+    stream: Optional[bool] = None  # Whether to stream the response
+    temperature: Optional[float] = None  # Temperature setting for the response
+    top_p: Optional[float] = None  # Top-p (nucleus) sampling setting for the response
+
+    @staticmethod
+    def from_args(args):
+
+        return Arguments(
+            model=args.model,
+            attn_implementation=args.attn_implementation,
+            url=args.url,
+            api_key=args.api_key,
+            connect_timeout=args.connect_timeout,
+            read_timeout=args.read_timeout,
+            number=args.number,
+            parallel=args.parallel,
+            rate=args.rate,
+            log_every_n_query=args.log_every_n_query,
+            headers=args.headers,
+            wandb_api_key=args.wandb_api_key,
+            name=args.name,
+            debug=args.debug,
+            tokenizer_path=args.tokenizer_path,
+            api=args.api,
+            max_prompt_length=args.max_prompt_length,
+            min_prompt_length=args.min_prompt_length,
+            prompt=args.prompt,
+            query_template=args.query_template,
+            dataset=args.dataset,
+            dataset_path=args.dataset_path,
+            frequency_penalty=args.frequency_penalty,
+            logprobs=args.logprobs,
+            max_tokens=args.max_tokens,
+            min_tokens=args.min_tokens,
+            n_choices=args.n_choices,
+            seed=args.seed,
+            stop=args.stop,
+            stop_token_ids=args.stop_token_ids,
+            stream=args.stream,
+            temperature=args.temperature,
+            top_p=args.top_p)
+
+    def __post_init__(self):
+        self.headers = self.headers or {}  # Default to empty dictionary
+        if self.api_key:
+            # Assuming the API key is used as a Bearer token
+            self.headers['Authorization'] = f'Bearer {self.api_key}'
+
+    def __str__(self):
+        return json.dumps(self.to_dict(), indent=4, default=str, ensure_ascii=False)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return self.__dict__
+
+
+class ParseKVAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        if not values:
+            setattr(namespace, self.dest, {})
+        else:
+            try:
+                kv_dict = dict(kv.split('=') for kv in values)
+                setattr(namespace, self.dest, kv_dict)
+            except ValueError as e:
+                parser.error(f'Error parsing key-value pairs: {e}')
+
+
+def add_argument(parser: argparse.ArgumentParser):
+    # yapf: disable
+    # Model and API
+    parser.add_argument('--model', type=str, required=True, help='The test model name.')
+    parser.add_argument('--attn-implementation', required=False, default=None, help='Attention implementation')
+    parser.add_argument('--api', type=str, default='openai', help='Specify the service API')
+    parser.add_argument(
+        '--tokenizer-path', type=str, required=False, default=None, help='Specify the tokenizer weight path')
+
+    # Connection settings
+    parser.add_argument('--url', type=str, default='http://127.0.0.1:8877/v1/chat/completions')
+    parser.add_argument('--headers', nargs='+', dest='headers', action=ParseKVAction, help='Extra HTTP headers')
+    parser.add_argument('--api-key', type=str, required=False, default='EMPTY', help='The API key for authentication')
+    parser.add_argument('--connect-timeout', type=int, default=120, help='The network connection timeout')
+    parser.add_argument('--read-timeout', type=int, default=120, help='The network read timeout')
+
+    # Performance and parallelism
+    parser.add_argument('-n', '--number', type=int, default=None, help='How many requests to be made')
+    parser.add_argument('--parallel', type=int, default=1, help='Set the number of concurrent requests, default 1')
+    parser.add_argument('--rate', type=int, default=-1, help='Number of requests per second, default -1 (no limit)')
+
+    # Logging and debugging
+    parser.add_argument('--log-every-n-query', type=int, default=10, help='Log every n queries')
+    parser.add_argument('--debug', action='store_true', default=False, help='Debug request sending')
+    parser.add_argument('--wandb-api-key', type=str, default=None, help='The wandb API key')
+    parser.add_argument('--name', type=str, help='The wandb run name and the result db name')
+
+    # Prompt settings
+    parser.add_argument('--max-prompt-length', type=int, default=sys.maxsize, help='Maximum input prompt length')
+    parser.add_argument('--min-prompt-length', type=int, default=0, help='Minimum input prompt length')
+    parser.add_argument('--prompt', type=str, required=False, default=None, help='Specify the request prompt')
+    parser.add_argument('--query-template', type=str, default=None, help='Specify the query template')
+
+    # Dataset settings
+    parser.add_argument('--dataset', type=str, default='openqa', help='Specify the dataset')
+    parser.add_argument('--dataset-path', type=str, required=False, help='Path to the dataset file')
+
+    # Response settings
+    parser.add_argument('--frequency-penalty', type=float, help='The frequency_penalty value', default=None)
+    parser.add_argument('--logprobs', action='store_true', help='The logprobs', default=None)
+    parser.add_argument(
+        '--max-tokens', type=int, help='The maximum number of tokens that can be generated', default=2048)
+    parser.add_argument(
+        '--min-tokens', type=int, help='The minimum number of tokens that can be generated', default=None)
+    parser.add_argument('--n-choices', type=int, help='How many completion choices to generate', default=None)
+    parser.add_argument('--seed', type=int, help='The random seed', default=42)
+    parser.add_argument('--stop', nargs='*', help='The stop tokens', default=None)
+    parser.add_argument('--stop-token-ids', nargs='*', help='Set the stop token IDs', default=None)
+    parser.add_argument('--stream', action='store_true', help='Stream output with SSE', default=None)
+    parser.add_argument('--temperature', type=float, help='The sample temperature', default=None)
+    parser.add_argument('--top-p', type=float, help='Sampling top p', default=None)
+    # yapf: enable
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Benchmark LLM service performance.')
+    add_argument(parser)
+    return parser.parse_args()
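The Arguments dataclass above is the configuration object behind the reworked perf tooling (see evalscope/perf/main.py in the file list). A rough sketch, not the official entry point, of building it programmatically instead of going through parse_args(); the model name and server URL below are placeholders:

# Unofficial sketch: construct Arguments directly; field names come from the dataclass above.
from evalscope.perf.arguments import Arguments

args = Arguments(
    model='qwen2-7b-instruct',                           # placeholder model identifier
    url='http://127.0.0.1:8877/v1/chat/completions',     # placeholder endpoint
    api='openai',
    dataset='openqa',
    parallel=4,
    number=100,
    max_tokens=1024,
    stream=True,
)
print(args)  # __str__ dumps the resolved configuration as JSON; __post_init__ adds the Bearer header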
evalscope/perf/benchmark.py (new file)

@@ -0,0 +1,245 @@
+import asyncio
+import copy
+import os
+import platform
+import sqlite3
+import threading
+import time
+from http import HTTPStatus
+from typing import List
+
+import json
+import numpy as np
+from tqdm import tqdm
+
+from evalscope.perf.arguments import Arguments
+from evalscope.perf.http_client import AioHttpClient, test_connection
+from evalscope.perf.plugin.registry import ApiRegistry, DatasetRegistry
+from evalscope.perf.utils.benchmark_util import BenchmarkData, BenchmarkMetrics
+from evalscope.perf.utils.db_util import create_result_table, get_result_db_path, insert_benchmark_data, summary_result
+from evalscope.perf.utils.handler import add_signal_handlers, exception_handler
+from evalscope.perf.utils.local_server import start_app
+from evalscope.utils.logger import get_logger
+
+logger = get_logger()
+query_send_completed_event = asyncio.Event()
+data_process_completed_event = asyncio.Event()
+
+
+@exception_handler
+async def dispatch_requests_worker(request_queue: asyncio.Queue, args: Arguments):
+    query_generator_class = ApiRegistry(args.api)
+    query_generator = query_generator_class(args.tokenizer_path)
+
+    def load_prompt(prompt_path_or_text):
+        """Load the prompt from a file or directly from the input text."""
+        if prompt_path_or_text.startswith('@'):
+            with open(prompt_path_or_text[1:], 'r', encoding='utf-8') as file:
+                return file.read()
+        return prompt_path_or_text
+
+    async def dispatch_request(request):
+        """Dispatch a single request with optional rate limiting."""
+        await request_queue.put(request)
+        if args.rate != -1:
+            interval = np.random.exponential(1.0 / args.rate)
+            await asyncio.sleep(interval)
+
+    async def dispatch_requests_from_prompt(messages):
+        """Generate and dispatch requests based on the given prompt."""
+        request = query_generator.build_request(messages, args)
+        if args.number is None:
+            await dispatch_request(request)
+            return 1
+        for _ in range(args.number):
+            await dispatch_request(request)
+        return args.number
+
+    async def dispatch_requests_from_dataset():
+        """Generate and dispatch requests based on the dataset."""
+        total_query_count = 0
+        message_generator_class = DatasetRegistry(args.dataset)
+        message_generator = message_generator_class(args)
+
+        for messages in message_generator:
+            request = query_generator.build_request(messages, args)
+            if request is None:
+                continue
+            await dispatch_request(request)
+            total_query_count += 1
+            if args.number and total_query_count >= args.number:
+                break
+
+        return total_query_count
+
+    # Load prompt or dataset and dispatch requests accordingly
+    if args.prompt:
+        prompt = load_prompt(args.prompt)
+        messages = [{'role': 'user', 'content': prompt}]
+        total_queries = await dispatch_requests_from_prompt(messages)
+    elif args.dataset:
+        total_queries = await dispatch_requests_from_dataset()
+    else:
+        raise Exception('Either prompt or dataset is required!')
+
+    return total_queries
+
+
+@exception_handler
+async def send_requests_worker(
+    task_id,
+    request_queue: asyncio.Queue,
+    benchmark_data_queue: asyncio.Queue,
+    args: Arguments,
+):
+    client = AioHttpClient(args)
+    async with client:
+        while not (query_send_completed_event.is_set() and request_queue.empty()):
+            try:
+                # Attempt to get a request from the queue with a timeout
+                request = await asyncio.wait_for(request_queue.get(), timeout=0.0001)
+                request_queue.task_done()
+            except asyncio.TimeoutError:
+                # If timeout, continue to the next iteration
+                continue
+
+            # Initialize benchmark data for the current request
+            benchmark_data = BenchmarkData(request=request)
+            collected_messages = []
+            try:
+                # Send the request and process the response
+                async for is_error, state_code, response_data in client.post(request):
+                    if is_error or state_code != HTTPStatus.OK:
+                        logger.error(f'Request: {request} failed, state_code: {state_code}, data: {response_data}')
+                        benchmark_data.success = False
+                        break
+                    if response_data:
+                        collected_messages.append(response_data)
+                        benchmark_data.chunk_times.append(time.perf_counter())
+                    benchmark_data.success = True
+                    benchmark_data.update_gpu_usage()
+            except Exception as e:
+                if response_data:
+                    collected_messages.append(response_data)
+                benchmark_data.success = False
+                logger.exception(e)
+                logger.error(f'Request query: {request} exception')
+            finally:
+                # Record completion time and collected messages
+                benchmark_data.completed_time = time.perf_counter()
+                benchmark_data.response_messages = collected_messages
+                await benchmark_data_queue.put(benchmark_data)
+
+
+@exception_handler
+async def statistic_benchmark_metric_worker(benchmark_data_queue: asyncio.Queue, args: Arguments):
+    metrics = BenchmarkMetrics(concurrency=args.parallel)
+
+    api_plugin_class = ApiRegistry(args.api)
+    api_plugin = api_plugin_class(args.tokenizer_path)
+
+    result_db_path = get_result_db_path(args.name, args.model)
+    # Initialize wandb
+    if args.wandb_api_key:
+        import wandb
+        import datetime
+        os.environ['WANDB_SILENT'] = 'true'
+        os.environ['WANDB_DIR'] = './outputs'
+
+        wandb.login(key=args.wandb_api_key)
+        current_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+        name = args.name if args.name else f'{args.model}_{current_time}'
+        wandb.init(project='perf_benchmark', name=name, config=args.to_dict())
+
+    with sqlite3.connect(result_db_path) as con:
+        cursor = con.cursor()
+        create_result_table(cursor)
+        with tqdm(desc='Processing') as pbar:
+            while not (data_process_completed_event.is_set() and benchmark_data_queue.empty()):
+                try:
+                    # Attempt to get benchmark data from the queue with a timeout
+                    benchmark_data = await asyncio.wait_for(benchmark_data_queue.get(), timeout=1)
+                    benchmark_data_queue.task_done()
+                except asyncio.TimeoutError:
+                    # If timeout, continue to the next iteration
+                    continue
+
+                # Update metrics based on the benchmark data
+                metrics.update_metrics(benchmark_data, api_plugin)
+
+                # Insert benchmark data into the database and commit the transaction
+                insert_benchmark_data(cursor, benchmark_data)
+                con.commit()
+
+                # Create a message with the updated metrics
+                message = metrics.create_message()
+
+                # Log the message to wandb if the api key is provided
+                if args.wandb_api_key:
+                    wandb.log(message)
+
+                # Log the message to the logger every n queries
+                if int(metrics.n_total_queries) % args.log_every_n_query == 0:
+                    msg = json.dumps(message, ensure_ascii=False, indent=2)
+                    logger.info(msg)
+
+                pbar.update(1)  # Update the progress bar
+
+    return metrics, result_db_path
+
+
+@exception_handler
+async def start_server(args: Arguments) -> bool:
+    if args.api.startswith('local'):
+        # start local server
+        server = threading.Thread(target=start_app, args=(copy.deepcopy(args), ), daemon=True)
+        server.start()
+
+        if args.dataset.startswith('speed_benchmark'):
+            args.url = 'http://127.0.0.1:8877/v1/completions'
+        else:
+            args.url = 'http://127.0.0.1:8877/v1/chat/completions'
+        args.model = os.path.basename(args.model)
+
+    if not await test_connection(args):
+        raise TimeoutError('Test connection failed')
+
+
+@exception_handler
+async def benchmark(args: Arguments) -> None:
+    if platform.system() != 'Windows':
+        loop = asyncio.get_running_loop()
+        add_signal_handlers(loop)
+
+    request_queue = asyncio.Queue()
+    benchmark_data_queue = asyncio.Queue()
+
+    async def create_send_request_tasks():
+        tasks: List[asyncio.Task] = []
+        for idx in range(args.parallel):
+            task = asyncio.create_task(send_requests_worker(idx, request_queue, benchmark_data_queue, args))
+            tasks.append(task)
+        return tasks

+    async def run_tasks():
+        await start_server(args)
+
+        dispatch_task = asyncio.create_task(dispatch_requests_worker(request_queue, args))
+        statistic_benchmark_metric_task = asyncio.create_task(
+            statistic_benchmark_metric_worker(benchmark_data_queue, args))
+        send_request_tasks = await create_send_request_tasks()
+
+        expected_number_of_queries = await dispatch_task
+        await request_queue.join()
+        query_send_completed_event.set()
+
+        await asyncio.gather(*send_request_tasks, return_exceptions=True)
+        await benchmark_data_queue.join()
+        data_process_completed_event.set()
+
+        metrics, result_db_path = await statistic_benchmark_metric_task
+        summary_result(args, metrics, expected_number_of_queries, result_db_path)
+
+        await asyncio.sleep(0.250)
+
+    await run_tasks()
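benchmark() above wires one dispatcher task, args.parallel sender workers, and a statistics worker together over two asyncio queues, then writes per-request results to a SQLite database and prints a summary. An illustrative sketch of driving it directly, assuming an OpenAI-compatible endpoint is already serving; the released package reaches this coroutine via evalscope/perf/main.py and the CLI, and the values below are placeholders:

# Illustrative only: drive the benchmark coroutine with asyncio.run().
import asyncio

from evalscope.perf.arguments import Arguments
from evalscope.perf.benchmark import benchmark

args = Arguments(
    model='qwen2-7b-instruct',                           # placeholder model identifier
    url='http://127.0.0.1:8877/v1/chat/completions',     # placeholder endpoint
    dataset='openqa',
    parallel=2,
    number=20,
)

if __name__ == '__main__':
    asyncio.run(benchmark(args))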