evalscope 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (105)
  1. evalscope/__init__.py +2 -0
  2. evalscope/arguments.py +10 -3
  3. evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py +0 -1
  4. evalscope/backend/rag_eval/utils/llm.py +1 -1
  5. evalscope/benchmarks/__init__.py +20 -1
  6. evalscope/benchmarks/arc/__init__.py +0 -5
  7. evalscope/benchmarks/arc/arc_adapter.py +23 -99
  8. evalscope/benchmarks/bbh/__init__.py +0 -4
  9. evalscope/benchmarks/bbh/bbh_adapter.py +19 -89
  10. evalscope/benchmarks/benchmark.py +70 -59
  11. evalscope/benchmarks/ceval/__init__.py +0 -5
  12. evalscope/benchmarks/ceval/ceval_adapter.py +22 -46
  13. evalscope/benchmarks/cmmlu/__init__.py +0 -5
  14. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +20 -41
  15. evalscope/benchmarks/competition_math/__init__.py +0 -5
  16. evalscope/benchmarks/competition_math/competition_math_adapter.py +29 -371
  17. evalscope/benchmarks/data_adapter.py +114 -85
  18. evalscope/benchmarks/general_qa/__init__.py +0 -5
  19. evalscope/benchmarks/general_qa/general_qa_adapter.py +16 -19
  20. evalscope/benchmarks/gsm8k/__init__.py +0 -4
  21. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +19 -98
  22. evalscope/benchmarks/hellaswag/__init__.py +0 -5
  23. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +23 -96
  24. evalscope/benchmarks/humaneval/__init__.py +0 -4
  25. evalscope/benchmarks/humaneval/humaneval_adapter.py +16 -117
  26. evalscope/benchmarks/mmlu/__init__.py +0 -5
  27. evalscope/benchmarks/mmlu/mmlu_adapter.py +26 -48
  28. evalscope/benchmarks/mmlu_pro/__init__.py +0 -0
  29. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +110 -0
  30. evalscope/benchmarks/race/__init__.py +0 -5
  31. evalscope/benchmarks/race/race_adapter.py +25 -53
  32. evalscope/benchmarks/trivia_qa/__init__.py +0 -5
  33. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +24 -97
  34. evalscope/benchmarks/truthful_qa/__init__.py +0 -5
  35. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +23 -33
  36. evalscope/collections/__init__.py +3 -0
  37. evalscope/collections/evaluator.py +178 -0
  38. evalscope/collections/sampler.py +132 -0
  39. evalscope/collections/schema.py +122 -0
  40. evalscope/config.py +10 -6
  41. evalscope/constants.py +7 -28
  42. evalscope/evaluator/evaluator.py +66 -108
  43. evalscope/evaluator/reviewer/auto_reviewer.py +12 -4
  44. evalscope/metrics/__init__.py +6 -0
  45. evalscope/metrics/bundled_rouge_score/rouge_scorer.py +1 -1
  46. evalscope/metrics/math_accuracy.py +193 -50
  47. evalscope/metrics/metrics.py +7 -4
  48. evalscope/metrics/rouge_metric.py +13 -8
  49. evalscope/models/__init__.py +14 -1
  50. evalscope/models/base_adapter.py +52 -0
  51. evalscope/models/chat_adapter.py +138 -0
  52. evalscope/models/choice_adapter.py +211 -0
  53. evalscope/models/custom_adapter.py +67 -0
  54. evalscope/models/local_model.py +74 -0
  55. evalscope/models/model.py +141 -0
  56. evalscope/models/server_adapter.py +104 -0
  57. evalscope/perf/arguments.py +1 -0
  58. evalscope/perf/benchmark.py +1 -1
  59. evalscope/perf/main.py +3 -1
  60. evalscope/perf/plugin/api/openai_api.py +51 -47
  61. evalscope/perf/utils/local_server.py +1 -0
  62. evalscope/run.py +37 -66
  63. evalscope/run_arena.py +1 -1
  64. evalscope/utils/__init__.py +1 -1
  65. evalscope/utils/chat_service.py +4 -3
  66. evalscope/utils/io_utils.py +8 -0
  67. evalscope/utils/logger.py +4 -0
  68. evalscope/utils/model_utils.py +10 -0
  69. evalscope/utils/utils.py +3 -25
  70. evalscope/version.py +2 -2
  71. {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/METADATA +46 -17
  72. {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/RECORD +81 -92
  73. tests/cli/test_collection.py +53 -0
  74. tests/cli/test_run.py +43 -1
  75. tests/perf/test_perf.py +3 -3
  76. tests/rag/test_mteb.py +3 -2
  77. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +0 -87
  78. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +0 -36
  79. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +0 -26
  80. evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +0 -41
  81. evalscope/backend/rag_eval/ragas/prompts/chinese/CustomNodeFilter/scoring_prompt_chinese.json +0 -7
  82. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +0 -60
  83. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +0 -36
  84. evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +0 -24
  85. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +0 -35
  86. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +0 -30
  87. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +0 -39
  88. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +0 -30
  89. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +0 -39
  90. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +0 -34
  91. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +0 -36
  92. evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +0 -25
  93. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +0 -24
  94. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +0 -39
  95. evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +0 -16
  96. evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +0 -24
  97. evalscope/models/api/__init__.py +0 -3
  98. evalscope/models/dummy_chat_model.py +0 -49
  99. evalscope/models/model_adapter.py +0 -525
  100. evalscope/models/openai_model.py +0 -103
  101. /evalscope/{models/api → third_party/longbench_write/tools}/openai_api.py +0 -0
  102. {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/LICENSE +0 -0
  103. {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/WHEEL +0 -0
  104. {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/entry_points.txt +0 -0
  105. {evalscope-0.8.1.dist-info → evalscope-0.9.0.dist-info}/top_level.txt +0 -0
evalscope/metrics/math_accuracy.py
@@ -1,57 +1,200 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-import re
-from collections import defaultdict
-from tqdm import tqdm
 
-from evalscope.constants import MetricsConstant
+# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
+def is_equiv(str1, str2, verbose=False):
+    if str1 is None and str2 is None:
+        print('WARNING: Both None')
+        return True
+    if str1 is None or str2 is None:
+        return False
 
+    try:
+        ss1 = strip_string(str1)
+        ss2 = strip_string(str2)
+        if verbose:
+            print(ss1, ss2)
+        return ss1 == ss2
+    except Exception:
+        return str1 == str2
 
-def get_last_number(s):
-    match = re.search(r'[-+]?\d*\.\d+|\d+', s[::-1])
-    if match:
-        last_digit = match.group()[::-1]
+
+def remove_boxed(s):
+    if '\\boxed ' in s:
+        left = '\\boxed '
+        assert s[:len(left)] == left
+        return s[len(left):]
+
+    left = '\\boxed{'
+
+    assert s[:len(left)] == left
+    assert s[-1] == '}'
+
+    return s[len(left):-1]
+
+
+def last_boxed_only_string(string):
+    idx = string.rfind('\\boxed')
+    if '\\boxed ' in string:
+        return '\\boxed ' + string.split('\\boxed ')[-1].split('$')[0]
+    if idx < 0:
+        idx = string.rfind('\\fbox')
+        if idx < 0:
+            return None
+
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == '{':
+            num_left_braces_open += 1
+        if string[i] == '}':
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+
+    if right_brace_idx is None:
+        retval = None
     else:
-        last_digit = -100000
-    return float(last_digit)
-
-
-def compute_math_accuracy_one_sample(predict, reference):
-    if isinstance(predict, list):
-        predict = predict[0]
-    if isinstance(reference, list):
-        reference = reference[0]
-    predict_number = get_last_number(predict)
-    reference_number = get_last_number(reference)
-    if abs(predict_number - reference_number) <= MetricsConstant.EPSILON:
-        return 1
+        retval = string[idx:right_brace_idx + 1]
+
+    return retval
+
+
+def fix_fracs(string):
+    substrs = string.split('\\frac')
+    new_str = substrs[0]
+    if len(substrs) > 1:
+        substrs = substrs[1:]
+        for substr in substrs:
+            new_str += '\\frac'
+            if substr[0] == '{':
+                new_str += substr
+            else:
+                try:
+                    assert len(substr) >= 2
+                except AssertionError:
+                    return string
+                a = substr[0]
+                b = substr[1]
+                if b != '{':
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += '{' + a + '}{' + b + '}' + post_substr
+                    else:
+                        new_str += '{' + a + '}{' + b + '}'
+                else:
+                    if len(substr) > 2:
+                        post_substr = substr[2:]
+                        new_str += '{' + a + '}' + b + post_substr
+                    else:
+                        new_str += '{' + a + '}' + b
+    string = new_str
+    return string
+
+
+def fix_a_slash_b(string):
+    if len(string.split('/')) != 2:
+        return string
+    a = string.split('/')[0]
+    b = string.split('/')[1]
+    try:
+        a = int(a)
+        b = int(b)
+        assert string == '{}/{}'.format(a, b)
+        new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
+        return new_string
+    except AssertionError:
+        return string
+
+
+def remove_right_units(string):
+    # "\\text{ " only ever occurs (at least in the val set) when describing units
+    if '\\text{ ' in string:
+        splits = string.split('\\text{ ')
+        assert len(splits) == 2
+        return splits[0]
     else:
-        return 0
-
-
-def compute_math_accuracy(predict_l, reference_l):
-    assert len(predict_l) == len(reference_l)
-    if len(predict_l) == 0:
-        return 0
-    total_cnt = len(predict_l)
-    correct_cnt = 0
-    for predict, reference in zip(predict_l, reference_l):
-        correct_cnt += compute_math_accuracy_one_sample(predict, reference)
-    return {'math accuracy': correct_cnt / total_cnt}
-
-
-def run_math_eval(data_l, md_level=2):
-    print(f"{'#' * md_level} Math Eval(math accuracy)")
-    for data in tqdm(data_l):
-        data['math_accuracy'] = compute_math_accuracy_one_sample(data['gen'], data['target'])
-    task_data_d = defaultdict(list)
-    for data in data_l:
-        for task in data['task_tags']:
-            task_data_d[task].append(data)
-    correct_cnt = sum([data['math_accuracy'] for data in data_l])
-    print(f'[total], count: {len(data_l)}, math accuracy: '
-          f'{correct_cnt / len(data_l) * 100:0.2f}%')
-    for task in task_data_d.keys():
-        correct_cnt = sum([data['math_accuracy'] for data in task_data_d[task]])
-        print(f'[{task}], count: {len(task_data_d[task])}, math accuracy: '
-              f'{correct_cnt / len(task_data_d[task]) * 100:0.2f}%')
+        return string
+
+
+def fix_sqrt(string):
+    if '\\sqrt' not in string:
+        return string
+    splits = string.split('\\sqrt')
+    new_string = splits[0]
+    for split in splits[1:]:
+        if split[0] != '{':
+            a = split[0]
+            new_substr = '\\sqrt{' + a + '}' + split[1:]
+        else:
+            new_substr = '\\sqrt' + split
+        new_string += new_substr
+    return new_string
+
+
+def strip_string(string):
+    # linebreaks
+    string = string.replace('\n', '')
+
+    # remove inverse spaces
+    string = string.replace('\\!', '')
+
+    # replace \\ with \
+    string = string.replace('\\\\', '\\')
+
+    # replace tfrac and dfrac with frac
+    string = string.replace('tfrac', 'frac')
+    string = string.replace('dfrac', 'frac')
+
+    # remove \left and \right
+    string = string.replace('\\left', '')
+    string = string.replace('\\right', '')
+
+    # Remove circ (degrees)
+    string = string.replace('^{\\circ}', '')
+    string = string.replace('^\\circ', '')
+
+    # remove dollar signs
+    string = string.replace('\\$', '')
+
+    # remove units (on the right)
+    string = remove_right_units(string)
+
+    # remove percentage
+    string = string.replace('\\%', '')
+    string = string.replace('\%', '') # noqa: W605
+
+    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
+    string = string.replace(' .', ' 0.')
+    string = string.replace('{.', '{0.')
+    # if empty, return empty string
+    if len(string) == 0:
+        return string
+    if string[0] == '.':
+        string = '0' + string
+
+    # to consider: get rid of e.g. "k = " or "q = " at beginning
+    if len(string.split('=')) == 2:
+        if len(string.split('=')[0]) <= 2:
+            string = string.split('=')[1]
+
+    # fix sqrt3 --> sqrt{3}
+    string = fix_sqrt(string)
+
+    # remove spaces
+    string = string.replace(' ', '')
+
+    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} # noqa: E501
+    string = fix_fracs(string)
+
+    # manually change 0.5 --> \frac{1}{2}
+    if string == '0.5':
+        string = '\\frac{1}{2}'
+
+    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
+    string = fix_a_slash_b(string)
+
+    return string
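
The rewritten math_accuracy.py replaces the old last-number comparison with the Hendrycks MATH normalization: is_equiv() runs both strings through strip_string() and compares the results. A minimal usage sketch follows; it assumes the helpers above are importable from evalscope.metrics.math_accuracy (the file being modified), which is an assumption about the public import path.

    # Sketch: normalize a \boxed{...} model answer and compare it with the reference.
    from evalscope.metrics.math_accuracy import is_equiv, last_boxed_only_string, remove_boxed

    pred = 'The answer is $\\boxed{\\dfrac{1}{2}}$'
    gold = '1/2'

    boxed = last_boxed_only_string(pred)   # '\\boxed{\\dfrac{1}{2}}'
    answer = remove_boxed(boxed)           # '\\dfrac{1}{2}'

    # strip_string maps both sides to '\\frac{1}{2}', so the answers compare equal.
    print(is_equiv(answer, gold))          # True
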
evalscope/metrics/metrics.py
@@ -2,16 +2,12 @@
 # Copyright (c) EleutherAI. and its affiliates.
 # Copyright (c) OpenAI. and its affiliates.
 import itertools
-import jieba
 import math
 import numpy as np
 import random
 import sacrebleu
-import sklearn.metrics
 from collections import defaultdict
 from collections.abc import Iterable
-from nltk import word_tokenize
-from nltk.translate.bleu_score import sentence_bleu
 from typing import Dict, List, Union
 
 
@@ -38,6 +34,8 @@ def median(arr):
 
 
 def matthews_corrcoef(items):
+    import sklearn.metrics
+
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
     preds = unzipped_list[1]
@@ -45,6 +43,8 @@ def matthews_corrcoef(items):
 
 
 def f1_score(items):
+    import sklearn.metrics
+
     unzipped_list = list(zip(*items))
     golds = unzipped_list[0]
     preds = unzipped_list[1]
@@ -150,6 +150,9 @@ def bleu_ngram_one_sample(predict, reference):
        }
 
    """
+    import jieba
+    from nltk import word_tokenize
+    from nltk.translate.bleu_score import sentence_bleu
 
    def is_contains_chinese(strs):
        for _char in strs:
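
These hunks move the heavy optional dependencies (jieba, sklearn, nltk) from module-level to function-level imports, so importing evalscope.metrics no longer requires them unless the corresponding metric is actually computed. A generic illustration of the pattern (not code from the package):

    def f1(items):
        # Deferred import: sklearn is only needed when this metric is actually used.
        import sklearn.metrics
        golds, preds = zip(*items)
        return sklearn.metrics.f1_score(golds, preds)
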
evalscope/metrics/rouge_metric.py
@@ -1,15 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import jieba
-import logging
 from collections import defaultdict
-from pathlib import Path
 from rouge_chinese import Rouge
 from statistics import mean
 from tqdm import tqdm
 
 from evalscope.constants import MetricsConstant
 from evalscope.metrics.bundled_rouge_score import rouge_scorer
+from evalscope.utils.logger import get_logger
+
+logger = get_logger()
 
 
 class DummyTokenizer:
@@ -18,10 +19,6 @@ class DummyTokenizer:
         return text.split()
 
 
-HERE = Path(__file__).absolute().parent
-
-logger = logging.getLogger(__name__)
-
 scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], tokenizer=DummyTokenizer())
 zh_scorer = Rouge()
 
@@ -58,7 +55,11 @@ def compute_rouge_score_one_sample_zh(predict, reference):
         p = ' '.join(jieba.cut(p)) if is_contains_chinese(p) else p
         r = ' '.join(jieba.cut(r)) if is_contains_chinese(r) else r
 
-        score = zh_scorer.get_scores(p, r)[0]
+        try:
+            score = zh_scorer.get_scores(p, r, ignore_empty=True)[0]
+        except Exception as e:
+            logger.warning(f'rouge score error: {p} {r} {e}')
+            continue
         result['rouge-1-r'] = score['rouge-1']['r']
         result['rouge-1-p'] = score['rouge-1']['p']
         result['rouge-1-f'] = score['rouge-1']['f']
@@ -75,7 +76,11 @@ def compute_rouge_score_one_sample_zh(predict, reference):
 def compute_rouge_score_one_sample(predict, reference):
     result = dict()
     for p, r in zip(predict, reference):
-        score = scorer.score(p, r)
+        try:
+            score = scorer.score(p, r)
+        except Exception as e:
+            logger.warning(f'rouge score error: {p} {r} {e}')
+            continue
         result['rouge-1-r'] = score['rouge1'].recall
         result['rouge-1-p'] = score['rouge1'].precision
         result['rouge-1-f'] = score['rouge1'].fmeasure
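
Both ROUGE paths now wrap the per-pair scoring call in try/except: a pair the scorer cannot handle (for example an empty hypothesis) is logged through the module logger and skipped with continue instead of aborting the whole evaluation, and the Chinese path additionally passes ignore_empty=True to rouge_chinese. A usage sketch; the list-of-strings calling convention is inferred from the zip() inside the function and the import path from the file being modified:

    from evalscope.metrics.rouge_metric import compute_rouge_score_one_sample

    # One prediction/reference pair; a failing pair now logs a warning instead of crashing.
    scores = compute_rouge_score_one_sample(['the cat sat on the mat'], ['a cat sat on a mat'])
    print(scores['rouge-1-f'])
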
evalscope/models/__init__.py
@@ -1,3 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-from evalscope.models.model import BaseModel, ChatBaseModel
+from evalscope.models.base_adapter import BaseModelAdapter, initialize_model_adapter
+from evalscope.models.chat_adapter import ChatGenerationModelAdapter
+from evalscope.models.choice_adapter import ContinuationLogitsModelAdapter, MultiChoiceModelAdapter
+from evalscope.models.custom import CustomModel
+from evalscope.models.custom_adapter import CustomModelAdapter
+from evalscope.models.local_model import LocalModel, get_local_model
+from evalscope.models.model import BaseModel, ChatBaseModel, OpenAIModel
+from evalscope.models.server_adapter import ServerModelAdapter
+
+__all__ = [
+    'CustomModel', 'BaseModel', 'ChatBaseModel', 'OpenAIModel', 'BaseModelAdapter', 'ChatGenerationModelAdapter',
+    'MultiChoiceModelAdapter', 'ContinuationLogitsModelAdapter', 'CustomModelAdapter', 'ServerModelAdapter',
+    'LocalModel', 'get_local_model', 'initialize_model_adapter'
+]
evalscope/models/base_adapter.py (new file)
@@ -0,0 +1,52 @@
+import torch
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from evalscope.constants import EvalType
+from evalscope.models.custom import CustomModel
+from evalscope.models.local_model import LocalModel
+
+if TYPE_CHECKING:
+    from evalscope.config import TaskConfig
+
+
+class BaseModelAdapter(ABC):
+
+    def __init__(self, model: Optional[Union[LocalModel, CustomModel]], **kwargs):
+        if model is None:
+            self.model_cfg = kwargs.get('model_cfg', None)
+        elif isinstance(model, LocalModel):
+            self.model = model.model
+            self.model_id = model.model_id
+            self.model_revision = model.model_revision
+            self.device = model.device
+            self.tokenizer = model.tokenizer
+            self.model_cfg = model.model_cfg
+        elif isinstance(model, CustomModel):
+            self.model_cfg = model.config
+        else:
+            raise ValueError(f'Unsupported model type: {type(model)}')
+
+    @abstractmethod
+    @torch.no_grad()
+    def predict(self, *args, **kwargs) -> Any:
+        raise NotImplementedError
+
+
+def initialize_model_adapter(task_cfg: 'TaskConfig', model_adapter_cls: 'BaseModelAdapter', base_model: 'LocalModel'):
+    """Initialize the model adapter based on the task configuration."""
+    if task_cfg.dry_run:
+        from evalscope.models.model import DummyChatModel
+        return DummyChatModel(model_cfg=dict())
+    elif task_cfg.eval_type == EvalType.CUSTOM:
+        if not isinstance(task_cfg.model, CustomModel):
+            raise ValueError(f'Expected evalscope.models.custom.CustomModel, but got {type(task_cfg.model)}.')
+        from evalscope.models import CustomModelAdapter
+        return CustomModelAdapter(custom_model=task_cfg.model)
+    elif task_cfg.eval_type == EvalType.SERVICE:
+        from evalscope.models import ServerModelAdapter
+        return ServerModelAdapter(
+            api_url=task_cfg.api_url, model_id=task_cfg.model, api_key=task_cfg.api_key, seed=task_cfg.seed)
+    else:
+        return model_adapter_cls(
+            model=base_model, generation_config=task_cfg.generation_config, chat_template=task_cfg.chat_template)
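
base_adapter.py introduces the abstract BaseModelAdapter plus initialize_model_adapter, which dispatches on the task configuration: dry runs get a dummy chat model, custom eval types wrap the user's CustomModel, service eval types talk to an API endpoint via ServerModelAdapter, and local models are wrapped by the benchmark's adapter class. The sketch below shows a minimal concrete subclass satisfying the abstract interface; the EchoModelAdapter class and its return shape are hypothetical, only BaseModelAdapter itself comes from this diff.

    import torch

    from evalscope.models import BaseModelAdapter


    class EchoModelAdapter(BaseModelAdapter):
        """Hypothetical adapter that echoes its input; for illustration only."""

        def __init__(self, **kwargs):
            # model=None keeps only model_cfg, as handled in BaseModelAdapter.__init__.
            super().__init__(model=None, model_cfg={'model_id': 'echo'})

        @torch.no_grad()
        def predict(self, inputs, infer_cfg=None):
            # Return the prompt unchanged in an OpenAI-like message structure.
            return {'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': str(inputs)}}]}
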
evalscope/models/chat_adapter.py (new file)
@@ -0,0 +1,138 @@
+import os
+import time
+import torch
+from typing import Union
+
+from evalscope.models.base_adapter import BaseModelAdapter
+from evalscope.models.local_model import LocalModel
+from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage
+from evalscope.utils.logger import get_logger
+from evalscope.utils.model_utils import fix_do_sample_warning
+
+logger = get_logger()
+
+
+class ChatGenerationModelAdapter(BaseModelAdapter):
+    """
+    Chat generation model adapter.
+    """
+
+    def __init__(self, model: LocalModel, **kwargs):
+        super().__init__(model)
+
+        self.generation_config = self._parse_generation_config(self.tokenizer, self.model)
+
+        custom_generation_config = kwargs.pop('generation_config', None)
+        custom_chat_template = kwargs.pop('chat_template', None)
+
+        if custom_generation_config:
+            logger.info('Updating generation config ...')
+            self.generation_config.update(**custom_generation_config)
+
+        if custom_chat_template:
+            self.tokenizer.chat_template = custom_chat_template
+            logger.info(f'Using custom chat template: {custom_chat_template}')
+
+    def _parse_generation_config(self, tokenizer, model):
+        from modelscope import GenerationConfig
+
+        generation_config = getattr(model, 'generation_config', GenerationConfig(do_sample=False))
+
+        try:
+            remote_config = GenerationConfig.from_pretrained(
+                self.model_id, revision=self.model_revision, trust_remote_code=True)
+            generation_config.update(**remote_config.to_dict())
+        except Exception:
+            logger.warning(f'Failed to get generation config of {self.model_id} from model hub, use default.')
+
+        if isinstance(self.model_id, str) and os.path.exists(self.model_id):
+            logger.warning(f'Got local model dir: {self.model_id}')
+
+        if tokenizer.eos_token_id is not None:
+            generation_config.eos_token_id = tokenizer.eos_token_id
+        if tokenizer.pad_token_id is not None:
+            generation_config.pad_token_id = tokenizer.pad_token_id
+        if generation_config.max_new_tokens is None:
+            generation_config.max_new_tokens = 2048
+
+        return generation_config
+
+    def _model_generate(self, query: str, system_prompt: str = None, infer_cfg: dict = {}) -> str:
+        """
+        Args:
+            query: The input query.
+            system_prompt: The system prompt.
+            infer_cfg: The inference configuration.
+        Returns:
+            The prediction result.
+        """
+        # For chat model, use the chat template to format the input
+        if self.tokenizer.chat_template is not None:
+            messages = [ChatMessage(role='user', content=query)]
+            if system_prompt:
+                messages = [ChatMessage(role='system', content=system_prompt)] + messages
+            formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        else:
+            # For base model, use the query as the input
+            formatted_prompt = query
+
+        inputs = self.tokenizer(formatted_prompt, return_tensors='pt', padding=True).to(self.device)
+        input_ids = inputs['input_ids']
+
+        # Process infer_cfg
+        if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1:
+            infer_cfg['do_sample'] = True
+
+        # stop settings
+        stop = infer_cfg.get('stop', None)
+        eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \
+            if stop else self.tokenizer.eos_token_id
+
+        if eos_token_id is not None:
+            infer_cfg['eos_token_id'] = eos_token_id
+            infer_cfg['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
+
+        self.generation_config.update(**infer_cfg)
+        fix_do_sample_warning(self.generation_config)
+
+        # Run inference
+        output_ids = self.model.generate(**inputs, generation_config=self.generation_config)
+
+        response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], skip_special_tokens=True)
+        return response
+
+    @torch.no_grad()
+    def predict(self, inputs: Union[str, dict, list], infer_cfg: dict = {}) -> dict:
+        """
+        Args:
+            inputs: The input data.
+            infer_cfg: The inference configuration.
+        Returns:
+            The prediction result.
+        """
+
+        # Process inputs
+        if isinstance(inputs, str):
+            query = inputs
+            system_prompt = None
+        elif isinstance(inputs, dict):
+            query = inputs['data'][0]
+            system_prompt = inputs.get('system_prompt', None)
+        elif isinstance(inputs, list):
+            query = '\n'.join(inputs)
+            system_prompt = None
+        else:
+            raise TypeError(f'Unsupported inputs type: {type(inputs)}')
+
+        response = self._model_generate(query, system_prompt, infer_cfg)
+
+        choices_list = [
+            ChatCompletionResponseChoice(
+                index=0, message=ChatMessage(content=response, role='assistant'), finish_reason='stop')
+        ]
+
+        res_d = ChatCompletionResponse(
+            model=self.model_id, choices=choices_list, object='chat.completion', created=int(time.time()),
+            usage=None).model_dump(exclude_unset=True)
+
+        return res_d
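
ChatGenerationModelAdapter.predict accepts a plain string, a dict carrying a 'data' list plus an optional 'system_prompt', or a list of strings, and returns an OpenAI-style chat-completion dict built from ChatCompletionResponse. A hedged usage sketch follows; it assumes an already-initialized adapter (for example the one returned by initialize_model_adapter for a local chat model) and uses placeholder generation settings.

    from evalscope.models import ChatGenerationModelAdapter


    def ask(adapter: ChatGenerationModelAdapter, question: str) -> str:
        """Send one question through the adapter and return the assistant text."""
        res = adapter.predict(
            {'data': [question], 'system_prompt': 'You are a terse assistant.'},
            infer_cfg={'max_new_tokens': 64, 'do_sample': False},
        )
        # predict() serializes a ChatCompletionResponse, so the reply sits under choices[0].
        return res['choices'][0]['message']['content']
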