evalscope 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.

Potentially problematic release: this version of evalscope might be problematic. See the registry's advisory for details.

Files changed (45)
  1. evalscope/arguments.py +1 -1
  2. evalscope/backend/rag_eval/utils/llm.py +4 -5
  3. evalscope/benchmarks/alpaca_eval/__init__.py +0 -0
  4. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +109 -0
  5. evalscope/benchmarks/arena_hard/__init__.py +0 -0
  6. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +120 -0
  7. evalscope/benchmarks/arena_hard/utils.py +162 -0
  8. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +2 -5
  9. evalscope/benchmarks/competition_math/competition_math_adapter.py +0 -1
  10. evalscope/benchmarks/data_adapter.py +26 -2
  11. evalscope/benchmarks/data_collection/data_collection_adapter.py +0 -1
  12. evalscope/benchmarks/general_qa/general_qa_adapter.py +5 -11
  13. evalscope/benchmarks/ifeval/ifeval_adapter.py +2 -5
  14. evalscope/benchmarks/live_code_bench/testing_util.py +3 -3
  15. evalscope/benchmarks/mmlu_redux/__init__.py +0 -0
  16. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +182 -0
  17. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +2 -5
  18. evalscope/collections/evaluator.py +1 -1
  19. evalscope/config.py +6 -3
  20. evalscope/constants.py +1 -0
  21. evalscope/evaluator/evaluator.py +5 -4
  22. evalscope/metrics/llm_judge.py +1 -1
  23. evalscope/models/chat_adapter.py +32 -11
  24. evalscope/models/custom_adapter.py +1 -1
  25. evalscope/perf/arguments.py +19 -46
  26. evalscope/perf/benchmark.py +64 -90
  27. evalscope/perf/main.py +1 -1
  28. evalscope/perf/plugin/api/openai_api.py +4 -2
  29. evalscope/perf/plugin/datasets/__init__.py +1 -0
  30. evalscope/perf/plugin/datasets/openqa.py +6 -11
  31. evalscope/perf/plugin/datasets/random_dataset.py +51 -0
  32. evalscope/perf/plugin/datasets/speed_benchmark.py +11 -0
  33. evalscope/perf/utils/db_util.py +5 -2
  34. evalscope/run.py +14 -2
  35. evalscope/version.py +2 -2
  36. {evalscope-0.13.0.dist-info → evalscope-0.13.2.dist-info}/METADATA +42 -78
  37. {evalscope-0.13.0.dist-info → evalscope-0.13.2.dist-info}/RECORD +45 -37
  38. tests/cli/test_all.py +33 -24
  39. tests/cli/test_run.py +69 -22
  40. tests/perf/test_perf.py +23 -0
  41. tests/rag/test_ragas.py +4 -1
  42. {evalscope-0.13.0.dist-info → evalscope-0.13.2.dist-info}/LICENSE +0 -0
  43. {evalscope-0.13.0.dist-info → evalscope-0.13.2.dist-info}/WHEEL +0 -0
  44. {evalscope-0.13.0.dist-info → evalscope-0.13.2.dist-info}/entry_points.txt +0 -0
  45. {evalscope-0.13.0.dist-info → evalscope-0.13.2.dist-info}/top_level.txt +0 -0
evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py ADDED
@@ -0,0 +1,182 @@
+ from collections import defaultdict
+ from typing import Any, Dict
+
+ from evalscope.benchmarks import Benchmark, DataAdapter
+ from evalscope.constants import EvalType, OutputType
+ from evalscope.metrics import exact_match
+ from evalscope.utils.logger import get_logger
+ from evalscope.utils.utils import ResponseParser
+
+ logger = get_logger()
+
+ SUBSET_LIST = [
+     'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology',
+     'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics',
+     'computer_security', 'conceptual_physics', 'econometrics', 'electrical_engineering', 'elementary_mathematics',
+     'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
+     'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
+     'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics', 'high_school_physics',
+     'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history',
+     'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning',
+     'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition',
+     'philosophy', 'prehistory', 'professional_accounting', 'professional_law', 'professional_medicine',
+     'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology',
+     'world_religions'
+ ]
+
+ SUBJECT_MAPPING = {
+     'abstract_algebra': ['Abstract Algebra', 'math', 'STEM'],
+     'anatomy': ['Anatomy', 'health', 'Other'],
+     'astronomy': ['Astronomy', 'physics', 'STEM'],
+     'business_ethics': ['Business Ethics', 'business', 'Other'],
+     'clinical_knowledge': ['Clinical Knowledge', 'health', 'Other'],
+     'college_biology': ['College Biology', 'biology', 'STEM'],
+     'college_chemistry': ['College Chemistry', 'chemistry', 'STEM'],
+     'college_computer_science': ['College Computer Science', 'computer science', 'STEM'],
+     'college_mathematics': ['College Mathematics', 'math', 'STEM'],
+     'college_medicine': ['College Medicine', 'health', 'Other'],
+     'college_physics': ['College Physics', 'physics', 'STEM'],
+     'computer_security': ['Computer Security', 'computer science', 'STEM'],
+     'conceptual_physics': ['Conceptual Physics', 'physics', 'STEM'],
+     'econometrics': ['Econometrics', 'economics', 'Social Science'],
+     'electrical_engineering': ['Electrical Engineering', 'engineering', 'STEM'],
+     'elementary_mathematics': ['Elementary Mathematics', 'math', 'STEM'],
+     'formal_logic': ['Formal Logic', 'philosophy', 'Humanities'],
+     'global_facts': ['Global Facts', 'other', 'Other'],
+     'high_school_biology': ['High School Biology', 'biology', 'STEM'],
+     'high_school_chemistry': ['High School Chemistry', 'chemistry', 'STEM'],
+     'high_school_computer_science': ['High School Computer Science', 'computer science', 'STEM'],
+     'high_school_european_history': ['High School European History', 'history', 'Humanities'],
+     'high_school_geography': ['High School Geography', 'geography', 'Social Science'],
+     'high_school_government_and_politics': ['High School Government And Politics', 'politics', 'Social Science'],
+     'high_school_macroeconomics': ['High School Macroeconomics', 'economics', 'Social Science'],
+     'high_school_mathematics': ['High School Mathematics', 'math', 'STEM'],
+     'high_school_microeconomics': ['High School Microeconomics', 'economics', 'Social Science'],
+     'high_school_physics': ['High School Physics', 'physics', 'STEM'],
+     'high_school_psychology': ['High School Psychology', 'psychology', 'Social Science'],
+     'high_school_statistics': ['High School Statistics', 'math', 'STEM'],
+     'high_school_us_history': ['High School Us History', 'history', 'Humanities'],
+     'high_school_world_history': ['High School World History', 'history', 'Humanities'],
+     'human_aging': ['Human Aging', 'health', 'Other'],
+     'human_sexuality': ['Human Sexuality', 'culture', 'Social Science'],
+     'international_law': ['International Law', 'law', 'Humanities'],
+     'jurisprudence': ['Jurisprudence', 'law', 'Humanities'],
+     'logical_fallacies': ['Logical Fallacies', 'philosophy', 'Humanities'],
+     'machine_learning': ['Machine Learning', 'computer science', 'STEM'],
+     'management': ['Management', 'business', 'Other'],
+     'marketing': ['Marketing', 'business', 'Other'],
+     'medical_genetics': ['Medical Genetics', 'health', 'Other'],
+     'miscellaneous': ['Miscellaneous', 'other', 'Other'],
+     'moral_disputes': ['Moral Disputes', 'philosophy', 'Humanities'],
+     'moral_scenarios': ['Moral Scenarios', 'philosophy', 'Humanities'],
+     'nutrition': ['Nutrition', 'health', 'Other'],
+     'philosophy': ['Philosophy', 'philosophy', 'Humanities'],
+     'prehistory': ['Prehistory', 'history', 'Humanities'],
+     'professional_accounting': ['Professional Accounting', 'other', 'Other'],
+     'professional_law': ['Professional Law', 'law', 'Humanities'],
+     'professional_medicine': ['Professional Medicine', 'health', 'Other'],
+     'professional_psychology': ['Professional Psychology', 'psychology', 'Social Science'],
+     'public_relations': ['Public Relations', 'politics', 'Social Science'],
+     'security_studies': ['Security Studies', 'politics', 'Social Science'],
+     'sociology': ['Sociology', 'culture', 'Social Science'],
+     'us_foreign_policy': ['Us Foreign Policy', 'politics', 'Social Science'],
+     'virology': ['Virology', 'health', 'Other'],
+     'world_religions': ['World Religions', 'philosophy', 'Humanities'],
+ }
+
+
+ @Benchmark.register(
+     name='mmlu_redux',
+     pretty_name='MMLU-Redux',
+     dataset_id='AI-ModelScope/mmlu-redux-2.0',
+     model_adapter=OutputType.GENERATION,
+     output_types=[OutputType.MULTIPLE_CHOICE, OutputType.GENERATION],
+     subset_list=SUBSET_LIST,
+     metric_list=['AverageAccuracy'],
+     few_shot_num=0,
+     train_split=None,
+     eval_split='test',
+     prompt_template=
+     'The following are multiple choice questions (with answers) about {subset_name}. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n{query}', # noqa: E501
+ )
+ class MMLUReduxAdapter(DataAdapter):
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         if self.few_shot_num > 0:
+             self.few_shot_num = 0
+             logger.warning('Few-shot examples are not supported for MMLU-Redux dataset. Setting few_shot_num to 0.')
+
+         self.choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
+         self.category_map = {k: v[-1] for k, v in SUBJECT_MAPPING.items()}
+
+     def gen_prompt(self, input_d: Dict, subset_name: str, few_shot_list: list, **kwargs) -> Any:
+         if self.few_shot_num > 0:
+             prefix = self.format_fewshot_examples(few_shot_list)
+         else:
+             prefix = ''
+         query = prefix + 'Q: ' + input_d['question'] + '\n' + \
+             self.__form_options(input_d['choices']) + '\n'
+
+         full_prompt = self.prompt_template.format(subset_name=subset_name, query=query)
+         return self.gen_prompt_data(full_prompt)
+
+     def format_fewshot_examples(self, few_shot_list):
+         # load few-shot prompts for each category
+         prompts = ''
+         for index, d in enumerate(few_shot_list):
+             prompts += 'Q: ' + d['question'] + '\n' + \
+                 self.__form_options(d['choices']) + '\n'
+         return prompts
+
+     def __form_options(self, options: list):
+         option_str = 'Options are:\n'
+         for opt, choice in zip(options, self.choices):
+             option_str += f'({choice}): {opt}' + '\n'
+         return option_str
+
+     def get_gold_answer(self, input_d: dict) -> str:
+         """
+         Parse the raw input labels (gold).
+
+         Args:
+             input_d: input raw data. Depending on the dataset.
+
+         Returns:
+             The parsed input. e.g. gold answer ... Depending on the dataset.
+         """
+         answer_index = int(input_d['answer'])
+         return self.choices[answer_index]
+
+     def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: str = EvalType.CHECKPOINT) -> str:
+         """
+         Parse the predicted result and extract proper answer.
+
+         Args:
+             result: Predicted answer from the model. Usually a string for chat.
+             raw_input_d: The raw input. Depending on the dataset.
+             eval_type: 'checkpoint' or 'service' or `custom`, default: 'checkpoint'
+
+         Returns:
+             The parsed answer. Depending on the dataset. Usually a string for chat.
+         """
+         if self.model_adapter == OutputType.MULTIPLE_CHOICE:
+             return result
+         else:
+             return ResponseParser.parse_first_option(result)
+
+     def match(self, gold: str, pred: str) -> float:
+         """
+         Match the gold answer and the predicted answer.
+
+         Args:
+             gold (Any): The golden answer. Usually a string for chat/multiple-choice-questions.
+                 e.g. 'A', extracted from get_gold_answer method.
+             pred (Any): The predicted answer. Usually a string for chat/multiple-choice-questions.
+                 e.g. 'B', extracted from parse_pred_result method.
+
+         Returns:
+             The match result. Usually a score (float) for chat/multiple-choice-questions.
+         """
+         return exact_match(gold=gold, pred=pred)
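
The benchmark registered above is selected by the name passed to @Benchmark.register. A minimal sketch of running it through evalscope's task API follows; it assumes the documented TaskConfig/run_task entry points, and the model name and endpoint are placeholders rather than anything taken from this diff:

# Hypothetical smoke test for the new MMLU-Redux benchmark (not part of the diff).
from evalscope import TaskConfig, run_task

task_cfg = TaskConfig(
    model='qwen2.5-7b-instruct',           # placeholder served model
    api_url='http://127.0.0.1:8801/v1',    # placeholder OpenAI-compatible endpoint
    eval_type='service',
    datasets=['mmlu_redux'],               # name from the register call above
    limit=10,                              # evaluate only a small sample per subset
)
run_task(task_cfg=task_cfg)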
evalscope/benchmarks/simple_qa/simple_qa_adapter.py CHANGED
@@ -126,7 +126,7 @@ class SimpleQAAdapter(DataAdapter):
 
      def match(self, gold: str, pred: str) -> float:
          # simple match
-         logger.warning(f'Please use LLMJudge to match the result for SimpleQA')
+         logger.warning(f'Please use LLMJudge to match the result for {self.name}')
          is_correct = 1 if gold.lower().strip() == pred.lower().strip() else 0
          is_incorrect = not is_correct
          is_not_attempted = 0
@@ -159,9 +159,6 @@ class SimpleQAAdapter(DataAdapter):
              review_res_list: [{'is_correct': 1, 'is_incorrect': 0, 'is_not_attempted': 0}, ...]
          """
          # zip dict answers
-         res_dict = defaultdict(list)
-         for res in review_res_list:
-             for key, value in res.items():
-                 res_dict[key].append(value)
+         res_dict = super().compute_dict_metric(review_res_list, **kwargs)
 
          return super().compute_metric(res_dict, **kwargs)
evalscope/collections/evaluator.py CHANGED
@@ -181,7 +181,7 @@ class EvaluatorCollection:
          answers_list = jsonl_to_list(pred_file_path)
          indices = set()
          for answer in answers_list:
-             index = answer[AnswerKeys.ORIGIN_PROMPT].get('index')
+             index = answer.get(AnswerKeys.INDEX)
              answer_dict[index] = answer
              indices.add(index)
          data = []
evalscope/config.py CHANGED
@@ -75,13 +75,13 @@ class TaskConfig:
 
      # LLMJudge arguments
      judge_strategy: str = JudgeStrategy.AUTO
-     judge_worker_num: int = 8
+     judge_worker_num: int = 1
      judge_model_args: Optional[Dict] = field(default_factory=lambda: {})
 
      def __post_init__(self):
          if (not self.model_id) and self.model:
              if isinstance(self.model, CustomModel):
-                 self.model_id = type(self.model).__name__
+                 self.model_id = self.model.config.get('model_id', 'custom_model')
              else:
                  self.model_id = os.path.basename(self.model).rstrip(os.sep)
              # fix path error, see http://github.com/modelscope/evalscope/issues/377
@@ -92,7 +92,10 @@ class TaskConfig:
              self.eval_batch_size = 8 if self.eval_type == EvalType.SERVICE else 1
 
      def to_dict(self):
-         return self.__dict__
+         result = self.__dict__.copy()
+         if isinstance(self.model, CustomModel):
+             result['model'] = self.model.__class__.__name__
+         return result
 
      def __str__(self):
          return json.dumps(self.to_dict(), indent=4, default=str, ensure_ascii=False)
evalscope/constants.py CHANGED
@@ -77,6 +77,7 @@ class ArenaMode:
 
 
  class AnswerKeys:
+     INDEX = 'index'
      ANSWER_ID = 'answer_id'
      RAW_INPUT = 'raw_input'
      ORIGIN_PROMPT = 'origin_prompt'
evalscope/evaluator/evaluator.py CHANGED
@@ -81,7 +81,7 @@ class Evaluator(object):
          for subset_name, prompts_list in prompts.items():
              limit = self.task_cfg.limit or len(prompts_list)
              for index, prompt in enumerate(prompts_list[:limit]):
-                 prompt['index'] = index
+                 prompt[AnswerKeys.INDEX] = index
                  limited_prompts[subset_name].append(prompt)
 
          return limited_prompts
@@ -97,7 +97,8 @@ class Evaluator(object):
          answer_d[AnswerKeys.ANSWER_ID] = answer_id
          answer_d[AnswerKeys.SUBSET_NAME] = subset_name
          answer_d[AnswerKeys.RAW_INPUT] = input_d[AnswerKeys.RAW_INPUT]
-         answer_d[AnswerKeys.ORIGIN_PROMPT] = input_d
+         # answer_d[AnswerKeys.ORIGIN_PROMPT] = input_d
+         answer_d[AnswerKeys.INDEX] = input_d[AnswerKeys.INDEX]
          return answer_d
 
      def _get_answer(self, input_prompts, subset_name, infer_cfg) -> List[dict]:
@@ -117,7 +118,7 @@ class Evaluator(object):
              return answers_list, prompts_list
 
          def get_answered_indices(answers_list: List[Dict]) -> List[int]:
-             indices = [answer[AnswerKeys.ORIGIN_PROMPT].get('index') for answer in answers_list]
+             indices = [answer.get(AnswerKeys.INDEX) for answer in answers_list]
 
              if all(index is None for index in indices):
                  return list(range(len(answers_list)))
@@ -238,7 +239,7 @@ class Evaluator(object):
              pred = pred_content
 
          choice[ReviewKeys.REVIEW] = {
-             ReviewKeys.GOLD: gold_content,
+             ReviewKeys.GOLD: gold_content if gold_content != raw_input_d else '*Same as Input*',
              ReviewKeys.PRED: pred,
              ReviewKeys.RESULT: review_result
          }
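
Taken together, these evaluator changes stop persisting the full origin_prompt and instead key resumption on a top-level index field. A rough sketch of one cached answer record after the change, with made-up values and key names taken from AnswerKeys above:

# Illustrative only: shape of a cached answer record after this change.
from evalscope.constants import AnswerKeys

answer_d = {
    AnswerKeys.ANSWER_ID: 'a1b2c3',          # placeholder id
    AnswerKeys.SUBSET_NAME: 'abstract_algebra',
    AnswerKeys.RAW_INPUT: {'question': '...', 'choices': ['...'], 'answer': 0},
    AnswerKeys.INDEX: 42,                    # now read directly by get_answered_indices
    # AnswerKeys.ORIGIN_PROMPT is no longer stored
}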
evalscope/metrics/llm_judge.py CHANGED
@@ -49,7 +49,7 @@ class LLMJudge:
          """
          self.api_key = api_key or os.environ.get('OPENAI_API_KEY', 'EMPTY')
          self.api_url = api_url or os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1')
-         self.model_id = model_id or os.environ.get('LOCAL_LLM', 'gpt-3.5-turbo')
+         self.model_id = model_id or os.environ.get('LOCAL_LLM', 'gpt-4')
          self.system_prompt = system_prompt or os.environ.get('JUDGE_SYSTEM_PROMPT', None)
          self.prompt_template = prompt_template or os.environ.get('JUDGE_PROMPT_TEMPLATE', DEFAULT_PROMPT_TEMPLATE)
          self.generation_config = generation_config
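
Since the fallback judge model changes here, it may help to see where these values are consumed. A hedged sketch of wiring a judge through TaskConfig, using the field names from the config.py hunk above; the assumption that judge_model_args is forwarded to LLMJudge's constructor, plus the endpoint, key, and dataset name, are not taken from this diff:

# Sketch only: judge configuration via the TaskConfig fields shown in this diff.
from evalscope import TaskConfig

task_cfg = TaskConfig(
    model='qwen2.5-7b-instruct',      # placeholder model under test
    datasets=['simple_qa'],           # assumed registry name for SimpleQA
    judge_worker_num=1,               # new, more conservative default
    judge_model_args={
        'model_id': 'gpt-4',                      # matches the new LOCAL_LLM fallback
        'api_url': 'https://api.openai.com/v1',   # OPENAI_API_BASE equivalent
        'api_key': 'EMPTY',                       # OPENAI_API_KEY equivalent
    },
)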
evalscope/models/chat_adapter.py CHANGED
@@ -1,13 +1,13 @@
  import os
  import time
  import torch
- from typing import List, Union
+ from typing import Any, Dict, List, Tuple, Union
 
  from evalscope.constants import OutputType
  from evalscope.models.base_adapter import BaseModelAdapter
  from evalscope.models.local_model import LocalModel
  from evalscope.models.register import register_model_adapter
- from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage
+ from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, Usage
  from evalscope.utils.logger import get_logger
  from evalscope.utils.model_utils import fix_do_sample_warning
 
@@ -60,7 +60,10 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
 
          return generation_config
 
-     def _model_generate(self, queries: List[str], system_prompts: List[str] = None, infer_cfg: dict = {}) -> List[str]:
+     def _model_generate(self,
+                         queries: List[str],
+                         system_prompts: List[str] = None,
+                         infer_cfg: Dict[str, Any] = None) -> Tuple[List[List[str]], List[int]]:
          """
          Args:
              queries: The input queries.
@@ -69,6 +72,11 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
          Returns:
              The prediction results.
          """
+         if system_prompts is None:
+             system_prompts = []
+         if infer_cfg is None:
+             infer_cfg = {}
+
          # Process infer_cfg
          num_return_sequences = infer_cfg.get('num_return_sequences', 1)
          if num_return_sequences > 1:
@@ -111,7 +119,9 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
          # Run inference
          output_ids = self.model.generate(**inputs, generation_config=self.generation_config)
 
+         # Decode output
          responses = []
+         input_lengths = [len(self.tokenizer.encode(prompt)) for prompt in formatted_prompts]
          for i in range(0, len(output_ids), num_return_sequences):
              query_responses = []
              for j in range(num_return_sequences):
@@ -121,7 +131,7 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
                  query_responses.append(response)
              responses.append(query_responses)
 
-         return responses
+         return responses, input_lengths
 
      @torch.no_grad()
      def predict(self, inputs: List[dict], infer_cfg: dict = {}) -> List[dict]:
@@ -141,22 +151,33 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
              queries.append(input_item['data'][0])
              system_prompts.append(input_item.get('system_prompt', None))
 
-         responses = self._model_generate(queries, system_prompts, infer_cfg)
+         # Run inference
+         responses, input_lengths = self._model_generate(queries, system_prompts, infer_cfg)
 
+         # Process outputs
          results = []
-         for response in responses:
-             choices_list = [
-                 ChatCompletionResponseChoice(
+         for response, input_length in zip(responses, input_lengths):
+             choices_list = []
+             completion_tokens = 0
+
+             for index, one_response in enumerate(response):
+                 choice = ChatCompletionResponseChoice(
                      index=index, message=ChatMessage(content=one_response, role='assistant'), finish_reason='stop')
-                 for index, one_response in enumerate(response)
-             ]
+                 choices_list.append(choice)
+
+                 completion_tokens += len(self.tokenizer.encode(one_response))
+
+             usage = Usage(
+                 prompt_tokens=input_length,
+                 completion_tokens=completion_tokens,
+                 total_tokens=input_length + completion_tokens)
 
              res_d = ChatCompletionResponse(
                  model=self.model_id,
                  choices=choices_list,
                  object='chat.completion',
                  created=int(time.time()),
-                 usage=None).model_dump(exclude_unset=True)
+                 usage=usage).model_dump(exclude_unset=True)
 
              results.append(res_d)
 
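
With token accounting now in place, each locally generated prediction carries a usage block. An illustrative model_dump of one ChatCompletionResponse after this change; all values below are made up:

# Illustrative only: result shape produced by predict() after this change.
res_d = {
    'model': 'qwen2.5-7b-instruct',
    'object': 'chat.completion',
    'created': 1718000000,
    'choices': [{
        'index': 0,
        'message': {'role': 'assistant', 'content': 'The answer is (B).'},
        'finish_reason': 'stop',
    }],
    'usage': {'prompt_tokens': 128, 'completion_tokens': 12, 'total_tokens': 140},
}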
evalscope/models/custom_adapter.py CHANGED
@@ -66,4 +66,4 @@ class CustomModelAdapter(BaseModelAdapter):
              else:
                  raise TypeError(f'Unsupported inputs type: {type(input_prompt)}')
 
-         return self.custom_model.predict(prompts=in_prompts, **kwargs)
+         return self.custom_model.predict(prompts=in_prompts, origin_inputs=inputs, **kwargs)
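
Because CustomModelAdapter now forwards the untouched benchmark items, user-defined models can read them via an origin_inputs keyword. A rough sketch of a compatible subclass; the CustomModel import path, constructor signature, and return shape are assumptions, not shown in this diff:

# Hypothetical custom model accepting the new origin_inputs kwarg.
from typing import List

from evalscope.models import CustomModel  # assumed import path


class EchoModel(CustomModel):

    def __init__(self):
        # config supplies model_id, which TaskConfig.__post_init__ now reads (see config.py hunk)
        super().__init__(config={'model_id': 'echo-model'})

    def predict(self, prompts: List[dict], origin_inputs: List[dict] = None, **kwargs) -> List[dict]:
        # origin_inputs carries the raw items passed through by CustomModelAdapter
        return [{'choices': [{'message': {'role': 'assistant', 'content': str(p)}}]} for p in prompts]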
evalscope/perf/arguments.py CHANGED
@@ -24,9 +24,10 @@ class Arguments:
      connect_timeout: int = 600 # Connection timeout in seconds
      read_timeout: int = 600 # Read timeout in seconds
      api_key: Optional[str] = None
+     no_test_connection: bool = False # Test the connection before starting the benchmark
 
      # Performance and parallelism
-     number: Optional[int] = None # Number of requests to be made
+     number: int = 1000 # Number of requests to be made
      parallel: int = 1 # Number of parallel requests
      rate: int = -1 # Rate limit for requests (default: -1, no limit)
 
@@ -40,8 +41,9 @@ class Arguments:
      outputs_dir: str = DEFAULT_WORK_DIR
 
      # Prompt settings
-     max_prompt_length: int = sys.maxsize # Maximum length of the prompt
+     max_prompt_length: int = 131072 # Maximum length of the prompt
      min_prompt_length: int = 0 # Minimum length of the prompt
+     prefix_length: int = 0 # Length of the prefix, only for random dataset
      prompt: Optional[str] = None # The prompt text
      query_template: Optional[str] = None # Template for the query
 
@@ -58,51 +60,20 @@ class Arguments:
      seed: Optional[int] = 42 # Random seed for reproducibility
      stop: Optional[List[str]] = field(default_factory=list) # Stop sequences for the response
      stop_token_ids: Optional[List[str]] = field(default_factory=list) # Stop token IDs for the response
-     stream: Optional[bool] = None # Whether to stream the response
-     temperature: Optional[float] = None # Temperature setting for the response
+     stream: Optional[bool] = False # Whether to stream the response
+     temperature: float = 0.0 # Temperature setting for the response
      top_p: Optional[float] = None # Top-p (nucleus) sampling setting for the response
      top_k: Optional[int] = None # Top-k sampling setting for the response
+     extra_args: Optional[Dict[str, Any]] = None # Extra arguments
 
      @staticmethod
      def from_args(args):
-         return Arguments(
-             model=args.model,
-             attn_implementation=args.attn_implementation,
-             url=args.url,
-             port=args.port,
-             api_key=args.api_key,
-             connect_timeout=args.connect_timeout,
-             read_timeout=args.read_timeout,
-             number=args.number,
-             parallel=args.parallel,
-             rate=args.rate,
-             log_every_n_query=args.log_every_n_query,
-             headers=args.headers,
-             wandb_api_key=args.wandb_api_key,
-             name=args.name,
-             outputs_dir=args.outputs_dir,
-             debug=args.debug,
-             tokenizer_path=args.tokenizer_path,
-             api=args.api,
-             max_prompt_length=args.max_prompt_length,
-             min_prompt_length=args.min_prompt_length,
-             prompt=args.prompt,
-             query_template=args.query_template,
-             dataset=args.dataset,
-             dataset_path=args.dataset_path,
-             frequency_penalty=args.frequency_penalty,
-             logprobs=args.logprobs,
-             max_tokens=args.max_tokens,
-             min_tokens=args.min_tokens,
-             n_choices=args.n_choices,
-             seed=args.seed,
-             stop=args.stop,
-             stop_token_ids=args.stop_token_ids,
-             stream=args.stream,
-             temperature=args.temperature,
-             top_p=args.top_p,
-             top_k=args.top_k,
-         )
+         # Convert Namespace to a dictionary and filter out None values
+         args_dict = {k: v for k, v in vars(args).items() if v is not None}
+
+         if 'func' in args_dict:
+             del args_dict['func'] # Note: compat CLI arguments
+         return Arguments(**args_dict)
 
      def __post_init__(self):
          self.headers = self.headers or {} # Default to empty dictionary
@@ -153,9 +124,10 @@ def add_argument(parser: argparse.ArgumentParser):
      parser.add_argument('--api-key', type=str, required=False, default=None, help='The API key for authentication')
      parser.add_argument('--connect-timeout', type=int, default=600, help='The network connection timeout')
      parser.add_argument('--read-timeout', type=int, default=600, help='The network read timeout')
+     parser.add_argument('--no-test-connection', action='store_false', default=False, help='Do not test the connection before starting the benchmark') # noqa: E501
 
      # Performance and parallelism
-     parser.add_argument('-n', '--number', type=int, default=None, help='How many requests to be made')
+     parser.add_argument('-n', '--number', type=int, default=1000, help='How many requests to be made')
      parser.add_argument('--parallel', type=int, default=1, help='Set number of concurrency requests, default 1')
      parser.add_argument('--rate', type=int, default=-1, help='Number of requests per second. default None')
 
@@ -168,6 +140,7 @@ def add_argument(parser: argparse.ArgumentParser):
      # Prompt settings
      parser.add_argument('--max-prompt-length', type=int, default=sys.maxsize, help='Maximum input prompt length')
      parser.add_argument('--min-prompt-length', type=int, default=0, help='Minimum input prompt length')
+     parser.add_argument('--prefix-length', type=int, default=0, help='The prefix length')
      parser.add_argument('--prompt', type=str, required=False, default=None, help='Specified the request prompt')
      parser.add_argument('--query-template', type=str, default=None, help='Specify the query template')
 
@@ -189,11 +162,11 @@ def add_argument(parser: argparse.ArgumentParser):
      parser.add_argument('--seed', type=int, help='The random seed', default=42)
      parser.add_argument('--stop', nargs='*', help='The stop tokens', default=None)
      parser.add_argument('--stop-token-ids', nargs='*', help='Set the stop token IDs', default=None)
-     parser.add_argument('--stream', action='store_true', help='Stream output with SSE', default=None)
-     parser.add_argument('--temperature', type=float, help='The sample temperature', default=None)
+     parser.add_argument('--stream', action='store_true', help='Stream output with SSE', default=False)
+     parser.add_argument('--temperature', type=float, help='The sample temperature', default=0.0)
      parser.add_argument('--top-p', type=float, help='Sampling top p', default=None)
      parser.add_argument('--top-k', type=int, help='Sampling top k', default=None)
-
+     parser.add_argument('--extra-args', type=json.loads, default='{}', help='Extra arguments, should in JSON format',)
      # yapf: enable
 
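
Putting the new defaults together, a hypothetical perf run through the Python API might look like the following; the run_perf_benchmark entry point, the served model, the URL, and the 'random' dataset name are assumptions or placeholders, not taken from this hunk:

# Sketch only: exercising the updated perf Arguments fields.
from evalscope.perf.arguments import Arguments
from evalscope.perf.main import run_perf_benchmark  # assumed entry point

args = Arguments(
    model='qwen2.5-7b-instruct',                       # placeholder served model
    url='http://127.0.0.1:8801/v1/chat/completions',   # placeholder endpoint
    api='openai',
    dataset='random',            # assumed name of the new random_dataset plugin
    number=1000,                 # new default instead of None
    parallel=16,
    prefix_length=32,            # new field, only used by the random dataset
    min_prompt_length=128,
    max_prompt_length=1024,      # default is now 131072 instead of sys.maxsize
    stream=True,                 # default is now False; enable SSE explicitly
    temperature=0.0,             # default is now 0.0 instead of None
    extra_args={'ignore_eos': True},   # new pass-through dict for extra request fields
)
run_perf_benchmark(args)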