evalscope 0.17.0__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (66)
  1. evalscope/benchmarks/bfcl/bfcl_adapter.py +1 -1
  2. evalscope/benchmarks/data_adapter.py +9 -4
  3. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -1
  4. evalscope/benchmarks/general_qa/general_qa_adapter.py +2 -1
  5. evalscope/benchmarks/hle/__init__.py +0 -0
  6. evalscope/benchmarks/hle/hle_adapter.py +118 -0
  7. evalscope/benchmarks/humaneval/humaneval_adapter.py +5 -21
  8. evalscope/benchmarks/mmlu/mmlu_adapter.py +1 -1
  9. evalscope/benchmarks/tau_bench/__init__.py +0 -0
  10. evalscope/benchmarks/tau_bench/tau_bench_adapter.py +110 -0
  11. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +7 -1
  12. evalscope/benchmarks/utils.py +1 -0
  13. evalscope/constants.py +5 -21
  14. evalscope/evaluator/__init__.py +1 -1
  15. evalscope/evaluator/evaluator.py +5 -3
  16. evalscope/metrics/__init__.py +3 -1
  17. evalscope/metrics/completion_parsers.py +7 -0
  18. evalscope/metrics/llm_judge.py +6 -5
  19. evalscope/metrics/metrics.py +19 -7
  20. evalscope/models/__init__.py +4 -8
  21. evalscope/models/adapters/__init__.py +4 -9
  22. evalscope/models/adapters/base_adapter.py +4 -0
  23. evalscope/models/adapters/bfcl_adapter.py +2 -0
  24. evalscope/models/adapters/chat_adapter.py +3 -0
  25. evalscope/models/adapters/choice_adapter.py +4 -0
  26. evalscope/models/adapters/custom_adapter.py +7 -3
  27. evalscope/models/adapters/server_adapter.py +2 -0
  28. evalscope/models/adapters/t2i_adapter.py +3 -0
  29. evalscope/models/adapters/tau_bench_adapter.py +189 -0
  30. evalscope/models/register.py +0 -14
  31. evalscope/perf/arguments.py +13 -0
  32. evalscope/perf/benchmark.py +38 -39
  33. evalscope/perf/http_client.py +30 -86
  34. evalscope/perf/main.py +2 -2
  35. evalscope/perf/plugin/__init__.py +3 -2
  36. evalscope/perf/plugin/api/__init__.py +4 -3
  37. evalscope/perf/plugin/api/base.py +22 -4
  38. evalscope/perf/plugin/api/custom_api.py +212 -55
  39. evalscope/perf/plugin/api/dashscope_api.py +4 -10
  40. evalscope/perf/plugin/api/default_api.py +105 -0
  41. evalscope/perf/plugin/api/openai_api.py +17 -19
  42. evalscope/perf/plugin/datasets/__init__.py +10 -7
  43. evalscope/perf/plugin/datasets/base.py +22 -1
  44. evalscope/perf/plugin/datasets/custom.py +2 -1
  45. evalscope/perf/plugin/datasets/flickr8k.py +4 -27
  46. evalscope/perf/plugin/datasets/kontext_bench.py +28 -0
  47. evalscope/perf/plugin/datasets/line_by_line.py +2 -1
  48. evalscope/perf/plugin/datasets/longalpaca.py +2 -1
  49. evalscope/perf/plugin/datasets/openqa.py +2 -1
  50. evalscope/perf/plugin/datasets/random_dataset.py +15 -4
  51. evalscope/perf/plugin/datasets/random_vl_dataset.py +80 -0
  52. evalscope/perf/plugin/registry.py +36 -16
  53. evalscope/perf/utils/benchmark_util.py +14 -20
  54. evalscope/perf/utils/db_util.py +79 -61
  55. evalscope/utils/io_utils.py +10 -0
  56. evalscope/version.py +2 -2
  57. {evalscope-0.17.0.dist-info → evalscope-0.17.1.dist-info}/METADATA +54 -34
  58. {evalscope-0.17.0.dist-info → evalscope-0.17.1.dist-info}/RECORD +65 -58
  59. tests/cli/test_all.py +18 -2
  60. tests/cli/test_run.py +25 -37
  61. tests/perf/test_perf.py +29 -2
  62. evalscope/models/model.py +0 -189
  63. {evalscope-0.17.0.dist-info → evalscope-0.17.1.dist-info}/LICENSE +0 -0
  64. {evalscope-0.17.0.dist-info → evalscope-0.17.1.dist-info}/WHEEL +0 -0
  65. {evalscope-0.17.0.dist-info → evalscope-0.17.1.dist-info}/entry_points.txt +0 -0
  66. {evalscope-0.17.0.dist-info → evalscope-0.17.1.dist-info}/top_level.txt +0 -0
@@ -53,7 +53,10 @@ def initialize_model_adapter(task_cfg: 'TaskConfig', benchmark: 'DataAdapter', b
     if task_cfg.eval_type == EvalType.SERVICE or task_cfg.api_url is not None:

         if 'server' not in model_adapter_cls_str:
+            logger.warning(f'Output type {model_adapter_cls_str} is not supported for service evaluation. '
+                           f'Using server model adapter instead.')
             model_adapter_cls_str = 'server'
+            benchmark.model_adapter = model_adapter_cls_str

         # init server model adapter
         model_adapter_cls = get_model_adapter(model_adapter_cls_str)
@@ -71,6 +74,7 @@ def initialize_model_adapter(task_cfg: 'TaskConfig', benchmark: 'DataAdapter', b
         logger.warning(f'Output type {model_adapter_cls_str} is not supported for benchmark {benchmark.name}.'
                        f'Using {benchmark.output_types[0]} instead.')
         model_adapter_cls_str = benchmark.output_types[0]
+        benchmark.model_adapter = model_adapter_cls_str

     model_adapter_cls = get_model_adapter(model_adapter_cls_str)
     return model_adapter_cls(
@@ -4,11 +4,13 @@ import uuid
 from typing import Any, List, Optional, Union

 from evalscope.utils.logger import get_logger
+from ..register import register_model_adapter
 from .server_adapter import ServerModelAdapter

 logger = get_logger()


+@register_model_adapter(name='bfcl_server')
 class BFCLAdapter(ServerModelAdapter):
     """
     BFCL model adapter to request remote API model and generate results for BFCL evaluation.
@@ -3,15 +3,18 @@ import time
 import torch
 from typing import Any, Dict, List, Optional, Tuple, Union

+from evalscope.constants import OutputType
 from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, Usage
 from evalscope.utils.logger import get_logger
 from evalscope.utils.model_utils import fix_do_sample_warning
 from ..local_model import LocalModel
+from ..register import register_model_adapter
 from .base_adapter import BaseModelAdapter

 logger = get_logger()


+@register_model_adapter(name=OutputType.GENERATION)
 class ChatGenerationModelAdapter(BaseModelAdapter):
     """
     Chat generation model adapter.
@@ -3,11 +3,14 @@ import time
 import torch
 from typing import List

+from evalscope.constants import OutputType
 from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage
 from ..local_model import LocalModel
+from ..register import register_model_adapter
 from .base_adapter import BaseModelAdapter


+@register_model_adapter(name=OutputType.MULTIPLE_CHOICE)
 class MultiChoiceModelAdapter(BaseModelAdapter):
     """ The multi-choice model adapter. """

@@ -110,6 +113,7 @@ class MultiChoiceModelAdapter(BaseModelAdapter):
         return log_probs, {'tokens': tokens}


+@register_model_adapter(name=OutputType.CONTINUOUS)
 class ContinuationLogitsModelAdapter(MultiChoiceModelAdapter):
     """
     Continuation-logits model adapter.
@@ -1,12 +1,16 @@
-from typing import Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union

-from ..custom import CustomModel
+from ..register import register_model_adapter
 from .base_adapter import BaseModelAdapter

+if TYPE_CHECKING:
+    from ..custom import CustomModel

+
+@register_model_adapter(name='custom')
 class CustomModelAdapter(BaseModelAdapter):

-    def __init__(self, custom_model: CustomModel, **kwargs):
+    def __init__(self, custom_model: 'CustomModel', **kwargs):
         """
         Custom model adapter.

@@ -7,11 +7,13 @@ from typing import List, Optional, Union

 from evalscope.utils.argument_utils import get_supported_params
 from evalscope.utils.logger import get_logger
+from ..register import register_model_adapter
 from .base_adapter import BaseModelAdapter

 logger = get_logger()


+@register_model_adapter(name='server')
 class ServerModelAdapter(BaseModelAdapter):
     """
     Server model adapter to request remote API model and generate results.
@@ -3,15 +3,18 @@ import time
 import torch
 from typing import Any, Dict, List, Optional, Tuple, Union

+from evalscope.constants import OutputType
 from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage
 from evalscope.utils.io_utils import OutputsStructure
 from evalscope.utils.logger import get_logger
 from ..local_model import LocalModel
+from ..register import register_model_adapter
 from .base_adapter import BaseModelAdapter

 logger = get_logger()


+@register_model_adapter(name=OutputType.IMAGE_GENERATION)
 class T2IModelAdapter(BaseModelAdapter):
     """
     Text to image model adapter.
@@ -0,0 +1,189 @@
+import json
+import time
+from typing import Any, Dict, List, Optional, Union
+
+from evalscope.utils.logger import get_logger
+from ..register import register_model_adapter
+from .server_adapter import ServerModelAdapter
+
+logger = get_logger()
+
+
+@register_model_adapter(name='tau_bench_server')
+class TauBenchAdapter(ServerModelAdapter):
+    """
+    TauBench model adapter to request remote API model and generate results for TauBench evaluation.
+    Support multi-turn and single-turn function calling tasks.
+    """
+
+    def __init__(self, api_url: str, model_id: str, api_key: str = 'EMPTY', **kwargs):
+        """
+        Args:
+            api_url: The URL of the remote API model.
+            model_id: The ID of the remote API model.
+            api_key: The API key of the remote API model.
+        """
+        super().__init__(api_url=api_url, model_id=model_id, api_key=api_key, **kwargs)
+
+        self._patch_agent_solve()
+
+    def predict(self, inputs: List[dict], infer_cfg: Optional[dict] = None) -> List[dict]:
+        """
+        Model prediction func. For multi-turn evals, we pass a list[list[message]] to the model
+        where each list is a follow up turn in the conversation
+        each turn is a List[List[Message]]
+
+        Args:
+            inputs (List[dict]): The input data.
+            infer_cfg (dict): Inference configuration.
+
+        Returns:
+            res (List[dict]): The model prediction results.
+        """
+        infer_cfg = infer_cfg or {}
+        results = []
+
+        for input_item in inputs:
+            raw_input = input_item.get('raw_input')
+
+            res_d = self.solve(env_name=raw_input['env_name'], task_index=raw_input['task_index'], infer_cfg=infer_cfg)
+
+            wrapper_res = {
+                'choices': [{
+                    'index': 0,
+                    'message': {
+                        'content': json.dumps(res_d, ensure_ascii=False),
+                        'role': 'assistant'
+                    }
+                }],
+                'created':
+                time.time(),
+                'model':
+                self.model_id,
+                'object':
+                'chat.completion',
+                'usage': {
+                    'completion_tokens': 0,
+                    'prompt_tokens': 0,
+                    'total_tokens': 0
+                }
+            }
+
+            results.append(wrapper_res)
+
+        return results
+
+    def _patch_agent_solve(self):
+        """Patch ToolCallingAgent.solve method to use custom model configuration"""
+        from tau_bench.agents.tool_calling_agent import ToolCallingAgent, message_to_action
+        from tau_bench.envs.base import Env
+        from tau_bench.types import RESPOND_ACTION_NAME, SolveResult
+        from typing import List, Optional
+
+        def patched_solve(self,
+                          env: Env,
+                          task_index: Optional[int] = None,
+                          max_num_steps: int = 30,
+                          infer_cfg: Optional[dict] = {}) -> SolveResult:
+            env_reset_res = env.reset(task_index=task_index)
+            obs = env_reset_res.observation
+            info = env_reset_res.info.model_dump()
+            reward = 0.0
+            messages: List[Dict[str, Any]] = [
+                {
+                    'role': 'system',
+                    'content': self.wiki
+                },
+                {
+                    'role': 'user',
+                    'content': obs
+                },
+            ]
+
+            for step_index in range(max_num_steps):
+                # Use adapter's model configuration instead of agent's
+                request_json = adapter_instance.make_request(
+                    input_item={
+                        'messages': messages,
+                        'tools': self.tools_info
+                    }, infer_cfg=infer_cfg)
+                res = adapter_instance.send_request(request_json)
+
+                next_message = res['choices'][0]['message']
+                action = message_to_action(next_message)
+                env_response = env.step(action)
+                reward = env_response.reward
+                info = {**info, **env_response.info.model_dump()}
+
+                if action.name != RESPOND_ACTION_NAME:
+                    next_message['tool_calls'] = next_message['tool_calls'][:1]
+                    messages.extend([
+                        next_message,
+                        {
+                            'role': 'tool',
+                            'tool_call_id': next_message['tool_calls'][0]['id'],
+                            'name': next_message['tool_calls'][0]['function']['name'],
+                            'content': env_response.observation,
+                        },
+                    ])
+                else:
+                    messages.extend([
+                        next_message,
+                        {
+                            'role': 'user',
+                            'content': env_response.observation
+                        },
+                    ])
+                logger.debug(f'Task: {task_index} Step: {step_index} finished')
+
+                if env_response.done:
+                    break
+
+            return SolveResult(
+                reward=reward,
+                info=info,
+                messages=messages,
+                total_cost=0,
+            )
+
+        adapter_instance = self
+
+        ToolCallingAgent.solve = patched_solve
+
+        return 'ToolCallingAgent.solve patched successfully'
+
+    def solve(self, env_name, task_index, infer_cfg, **kwargs):
+        """
+        Solve a specific task in the TauBench environment.
+
+        Args:
+            env_name (str): The name of the TauBench environment.
+            task_index (int): The index of the task to solve.
+            **kwargs: Additional arguments for the task.
+
+        Returns:
+            dict: The result of the task.
+        """
+        from tau_bench.agents.tool_calling_agent import ToolCallingAgent
+        from tau_bench.envs import get_env
+
+        # This method can be implemented to solve specific tasks in the TauBench environment
+        isolated_env = get_env(
+            env_name=env_name,
+            user_strategy='llm',
+            user_model='dummy',  # Use dummy model to prevent errors
+            user_provider='openai',  # Use dummy provider to prevent errors
+            task_split='test',
+            task_index=task_index,
+        )
+        agent = ToolCallingAgent(
+            tools_info=isolated_env.tools_info,
+            wiki=isolated_env.wiki,
+            model='dummy',  # Use dummy model to prevent errors
+            provider='dummy',  # Use dummy provider to prevent errors
+            temperature=0,  # dummy temperature to prevent errors
+        )
+
+        res = agent.solve(env=isolated_env, task_index=task_index, infer_cfg=infer_cfg)
+
+        return res.model_dump()
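
For orientation, the sketch below shows how this adapter is shaped to be driven, mirroring the predict() implementation above. It is a minimal illustration rather than evalscope documentation: the endpoint, model id, environment name and task index are placeholders, and tau_bench must be installed for solve() to run.

    from evalscope.models.adapters.tau_bench_adapter import TauBenchAdapter

    # Placeholder endpoint and model id for any OpenAI-compatible service.
    adapter = TauBenchAdapter(
        api_url='http://127.0.0.1:8801/v1/chat/completions',
        model_id='my-model',
    )

    # Each input item carries a raw_input dict with the TauBench env name and task index,
    # which is exactly what predict() reads before delegating to solve().
    outputs = adapter.predict(
        inputs=[{'raw_input': {'env_name': 'retail', 'task_index': 0}}],
        infer_cfg={'temperature': 0.0},
    )

    # The result mimics an OpenAI chat completion; the message content is the JSON-encoded SolveResult.
    print(outputs[0]['choices'][0]['message']['content'])
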
@@ -1,6 +1,3 @@
-from evalscope.constants import OutputType
-from .adapters import *
-
 MODEL_ADAPTERS = {}


@@ -42,14 +39,3 @@ def register_model_adapter_class(cls, name=None):
     if name in MODEL_ADAPTERS:
         raise ValueError(f"Model adapter class '{name}' is already registered.")
     MODEL_ADAPTERS[name] = cls
-
-
-# register all model adapters
-register_model_adapter_class(BaseModelAdapter, name='base')
-register_model_adapter_class(ChatGenerationModelAdapter, name=OutputType.GENERATION)
-register_model_adapter_class(ContinuationLogitsModelAdapter, name=OutputType.CONTINUOUS)
-register_model_adapter_class(MultiChoiceModelAdapter, name=OutputType.MULTIPLE_CHOICE)
-register_model_adapter_class(CustomModelAdapter, name='custom')
-register_model_adapter_class(ServerModelAdapter, name='server')
-register_model_adapter_class(BFCLAdapter, name='bfcl_server')
-register_model_adapter_class(T2IModelAdapter, name=OutputType.IMAGE_GENERATION)
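
The explicit calls removed above are replaced by the @register_model_adapter(name=...) decorators added to each adapter class earlier in this diff, so registration now happens at class-definition time. Below is a minimal sketch of how a third-party adapter would hook into the same registry, assuming (as is conventional) that the decorator returns the class unchanged; the MyAdapter class itself is hypothetical, while the decorator and MODEL_ADAPTERS dict come from this diff.

    from evalscope.models.adapters.base_adapter import BaseModelAdapter
    from evalscope.models.register import MODEL_ADAPTERS, register_model_adapter


    @register_model_adapter(name='my_adapter')  # stores the class under 'my_adapter'
    class MyAdapter(BaseModelAdapter):
        """Hypothetical adapter, only here to illustrate decorator-based registration."""

        def predict(self, inputs, infer_cfg=None):
            return [{'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'stub'}}]}]


    assert MODEL_ADAPTERS['my_adapter'] is MyAdapter  # later resolved by name, as in initialize_model_adapter
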
@@ -31,6 +31,7 @@ class Arguments(BaseArgument):
     number: Union[int, List[int]] = 1000  # Number of requests to be made
     parallel: Union[int, List[int]] = 1  # Number of parallel requests
     rate: int = -1  # Rate limit for requests (default: -1, no limit)
+    sleep_interval: int = 5  # Sleep interval between performance runs, in seconds

     # Logging and debugging
     log_every_n_query: int = 10  # Log every N queries
@@ -49,6 +50,11 @@
     prompt: Optional[str] = None  # The prompt text
     query_template: Optional[str] = None  # Template for the query
     apply_chat_template: Optional[bool] = None  # Whether to apply chat template
+    # random vl settings
+    image_width: int = 224  # Width of the image for random VL dataset
+    image_height: int = 224  # Height of the image for random VL dataset
+    image_format: str = 'RGB'  # Image format for random VL dataset
+    image_num: int = 1  # Number of images for random VL dataset

     # Dataset settings
     dataset: str = 'openqa'  # Dataset type (default: 'line_by_line')
@@ -142,6 +148,8 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('-n', '--number', type=int, default=1000, nargs='+', help='How many requests to be made')
     parser.add_argument('--parallel', type=int, default=1, nargs='+', help='Set number of concurrency requests, default 1')  # noqa: E501
     parser.add_argument('--rate', type=int, default=-1, help='Number of requests per second. default None')
+    parser.add_argument(
+        '--sleep-interval', type=int, default=5, help='Sleep interval between performance runs, in seconds. Default 5')  # noqa: E501

     # Logging and debugging
     parser.add_argument('--log-every-n-query', type=int, default=10, help='Logging every n query')
@@ -158,6 +166,11 @@
     parser.add_argument('--query-template', type=str, default=None, help='Specify the query template')
     parser.add_argument(
         '--apply-chat-template', type=argparse.BooleanOptionalAction, default=None, help='Apply chat template to the prompt')  # noqa: E501
+    # random vl settings
+    parser.add_argument('--image-width', type=int, default=224, help='Width of the image for random VL dataset')
+    parser.add_argument('--image-height', type=int, default=224, help='Height of the image for random VL dataset')
+    parser.add_argument('--image-format', type=str, default='RGB', help='Image format for random VL dataset')
+    parser.add_argument('--image-num', type=int, default=1, help='Number of images for random VL dataset')

     # Output settings
     parser.add_argument('--outputs-dir', help='Outputs dir.', default='outputs')
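
Together with the existing number/parallel sweep support, the new fields let a perf run pause between rounds and exercise the random vision-language dataset. A hedged sketch of an Arguments instance using them follows; only the fields added in this hunk are certain, while model, url, the 'openai' api value and the 'random_vl' dataset key are assumptions about the surrounding perf module that this hunk does not show.

    from evalscope.perf.arguments import Arguments

    args = Arguments(
        model='my-model',                                  # assumed field, not shown in this hunk
        url='http://127.0.0.1:8801/v1/chat/completions',   # assumed field, not shown in this hunk
        api='openai',                                      # args.api feeds ApiRegistry in benchmark.py
        number=[50, 100],                                  # list form runs several rounds
        parallel=[1, 4],
        sleep_interval=10,       # new: seconds to wait between rounds of a sweep
        dataset='random_vl',     # assumed registry key for the new random VL dataset plugin
        image_width=512,         # new random-VL settings from this hunk
        image_height=512,
        image_format='RGB',
        image_num=2,
    )
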
@@ -6,15 +6,18 @@ import sqlite3
 import time
 from http import HTTPStatus
 from tqdm import tqdm
-from typing import AsyncGenerator, Dict, List, Tuple
-
-from evalscope.perf.arguments import Arguments
-from evalscope.perf.http_client import AioHttpClient, test_connection
-from evalscope.perf.plugin.registry import ApiRegistry, DatasetRegistry
-from evalscope.perf.utils.benchmark_util import BenchmarkData, BenchmarkMetrics
-from evalscope.perf.utils.db_util import create_result_table, get_result_db_path, insert_benchmark_data, summary_result
-from evalscope.perf.utils.handler import add_signal_handlers, exception_handler
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Tuple
+
 from evalscope.utils.logger import get_logger
+from .arguments import Arguments
+from .http_client import AioHttpClient, test_connection
+from .plugin import ApiRegistry, DatasetRegistry
+from .utils.benchmark_util import BenchmarkData, BenchmarkMetrics
+from .utils.db_util import create_result_table, get_result_db_path, insert_benchmark_data, load_prompt, summary_result
+from .utils.handler import add_signal_handlers, exception_handler
+
+if TYPE_CHECKING:
+    from .plugin import ApiPluginBase, DatasetPluginBase

 logger = get_logger()

@@ -22,28 +25,22 @@ data_process_completed_event = asyncio.Event()


 @exception_handler
-async def get_requests(args: Arguments) -> AsyncGenerator[dict, None]:
-    query_generator_class = ApiRegistry(args.api)
-    query_generator = query_generator_class(args.tokenizer_path)
-
-    def load_prompt(prompt_path_or_text):
-        if prompt_path_or_text.startswith('@'):
-            with open(prompt_path_or_text[1:], 'r', encoding='utf-8') as file:
-                return file.read()
-        return prompt_path_or_text
-
-    async def generate_requests_from_prompt(messages):
-        request = query_generator.build_request(messages, args)
+async def get_requests(args: Arguments, api_plugin: 'ApiPluginBase') -> AsyncGenerator[dict, None]:
+
+    async def generate_requests_from_prompt():
+        prompt = load_prompt(args.prompt)
+        messages = [{'role': 'user', 'content': prompt}] if args.apply_chat_template else prompt
+        request = api_plugin.build_request(messages)
         for _ in range(args.number):
             yield request

     async def generate_requests_from_dataset():
-        message_generator_class = DatasetRegistry(args.dataset)
+        message_generator_class = DatasetRegistry.get_class(args.dataset)
         message_generator = message_generator_class(args)

         dataset_messages = []
         try:
-            for messages in message_generator:
+            for messages in message_generator.build_messages():
                 dataset_messages.append(messages)
         except StopIteration:
             pass
@@ -56,7 +53,7 @@ async def get_requests(args: Arguments) -> AsyncGenerator[dict, None]:

         while count < args.number:
             messages = dataset_messages[dataset_index]
-            request = query_generator.build_request(messages, args)
+            request = api_plugin.build_request(messages)
             if request is not None:
                 yield request
                 count += 1
@@ -64,13 +61,11 @@ async def get_requests(args: Arguments) -> AsyncGenerator[dict, None]:
            dataset_index = (dataset_index + 1) % len(dataset_messages)

     if args.prompt:
-        prompt = load_prompt(args.prompt)
-        messages = [{'role': 'user', 'content': prompt}] if args.apply_chat_template else prompt
-        generator = generate_requests_from_prompt(messages)
+        generator = generate_requests_from_prompt()
     elif args.dataset:
         generator = generate_requests_from_dataset()
     else:
-        raise Exception('Either prompt or dataset is required!')
+        raise ValueError('Either prompt or dataset is required!')

     async for request in generator:
         yield request
@@ -85,9 +80,10 @@ async def send_request(
     request: dict,
     benchmark_data_queue: asyncio.Queue,
     args: Arguments,
+    api_plugin: 'ApiPluginBase',
 ):
     async with semaphore:
-        client = AioHttpClient(args)
+        client = AioHttpClient(args, api_plugin)
         async with client:
             benchmark_data = BenchmarkData(request=request)
             benchmark_data.start_time = time.perf_counter()
@@ -95,7 +91,8 @@
             try:
                 async for is_error, state_code, response_data in client.post(request):
                     if is_error or state_code != HTTPStatus.OK:
-                        logger.error(f'Request: {request} failed, state_code: {state_code}, data: {response_data}')
+                        error_msg = str(response_data) if response_data else 'Unknown error'
+                        logger.error(f'Request: {request} failed, state_code: {state_code}, data: {error_msg}')
                         benchmark_data.success = False
                         break
                     if response_data:
@@ -116,12 +113,9 @@


 @exception_handler
-async def statistic_benchmark_metric(benchmark_data_queue: asyncio.Queue, args: Arguments):
+async def statistic_benchmark_metric(benchmark_data_queue: asyncio.Queue, args: Arguments, api_plugin: 'ApiPluginBase'):
     metrics = BenchmarkMetrics(concurrency=args.parallel)

-    api_plugin_class = ApiRegistry(args.api)
-    api_plugin = api_plugin_class(args.tokenizer_path)
-
     result_db_path = get_result_db_path(args)

     collected_benchmark_data = []
@@ -172,8 +166,8 @@ async def statistic_benchmark_metric(benchmark_data_queue: asyncio.Queue, args:


 @exception_handler
-async def connect_test(args: Arguments) -> bool:
-    if (not args.no_test_connection) and (not await test_connection(args)):
+async def connect_test(args: Arguments, api_plugin) -> bool:
+    if (not args.no_test_connection) and (not await test_connection(args, api_plugin)):
         raise TimeoutError('Test connection failed')


@@ -183,19 +177,24 @@ async def benchmark(args: Arguments) -> Tuple[Dict, Dict]:
     loop = asyncio.get_running_loop()
     add_signal_handlers(loop)

+    # Create API plugin instance for request/response processing
+    api_plugin_class = ApiRegistry.get_class(args.api)
+    api_plugin = api_plugin_class(args)
+
     # init queue
     benchmark_data_queue = asyncio.Queue()
     # reset event
     data_process_completed_event.clear()
     # test connection
-    await connect_test(args)
+    await connect_test(args, api_plugin)
     # start statistic benchmark metric
-    statistic_benchmark_metric_task = asyncio.create_task(statistic_benchmark_metric(benchmark_data_queue, args))
+    statistic_benchmark_metric_task = asyncio.create_task(
+        statistic_benchmark_metric(benchmark_data_queue, args, api_plugin))
     # start send request
     semaphore = asyncio.Semaphore(args.parallel)
     send_request_tasks: List[asyncio.Task] = []
-    async for request in get_requests(args):
-        task = asyncio.create_task(send_request(semaphore, request, benchmark_data_queue, args))
+    async for request in get_requests(args, api_plugin):
+        task = asyncio.create_task(send_request(semaphore, request, benchmark_data_queue, args, api_plugin))
         send_request_tasks.append(task)

     await asyncio.gather(*send_request_tasks, return_exceptions=True)