evalscope 0.7.2__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of evalscope might be problematic.
- evalscope/__init__.py +1 -1
- evalscope/arguments.py +73 -0
- evalscope/backend/base.py +6 -2
- evalscope/backend/opencompass/api_meta_template.py +8 -14
- evalscope/backend/opencompass/backend_manager.py +24 -15
- evalscope/backend/opencompass/tasks/eval_api.py +1 -6
- evalscope/backend/opencompass/tasks/eval_datasets.py +26 -28
- evalscope/backend/rag_eval/__init__.py +3 -3
- evalscope/backend/rag_eval/backend_manager.py +21 -25
- evalscope/backend/rag_eval/clip_benchmark/__init__.py +1 -1
- evalscope/backend/rag_eval/clip_benchmark/arguments.py +6 -6
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +62 -79
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +29 -43
- evalscope/backend/rag_eval/clip_benchmark/tasks/image_caption.py +20 -22
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py +16 -23
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_retrieval.py +14 -35
- evalscope/backend/rag_eval/clip_benchmark/utils/webdataset_convert.py +69 -90
- evalscope/backend/rag_eval/cmteb/__init__.py +3 -3
- evalscope/backend/rag_eval/cmteb/arguments.py +25 -27
- evalscope/backend/rag_eval/cmteb/base.py +22 -23
- evalscope/backend/rag_eval/cmteb/task_template.py +15 -17
- evalscope/backend/rag_eval/cmteb/tasks/Classification.py +98 -79
- evalscope/backend/rag_eval/cmteb/tasks/Clustering.py +17 -22
- evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py +17 -19
- evalscope/backend/rag_eval/cmteb/tasks/PairClassification.py +35 -29
- evalscope/backend/rag_eval/cmteb/tasks/Reranking.py +18 -5
- evalscope/backend/rag_eval/cmteb/tasks/Retrieval.py +163 -163
- evalscope/backend/rag_eval/cmteb/tasks/STS.py +126 -104
- evalscope/backend/rag_eval/cmteb/tasks/__init__.py +33 -34
- evalscope/backend/rag_eval/ragas/__init__.py +2 -2
- evalscope/backend/rag_eval/ragas/arguments.py +3 -8
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +9 -9
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/CustomNodeFilter/scoring_prompt_chinese.json +7 -0
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +8 -8
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +7 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +27 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +27 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +21 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +4 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/persona_prompt.py +0 -1
- evalscope/backend/rag_eval/ragas/task_template.py +10 -15
- evalscope/backend/rag_eval/ragas/tasks/__init__.py +1 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +45 -0
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +135 -0
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +17 -133
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +8 -18
- evalscope/backend/rag_eval/utils/clip.py +47 -51
- evalscope/backend/rag_eval/utils/embedding.py +13 -12
- evalscope/backend/rag_eval/utils/llm.py +8 -6
- evalscope/backend/rag_eval/utils/tools.py +12 -11
- evalscope/backend/vlm_eval_kit/__init__.py +1 -1
- evalscope/backend/vlm_eval_kit/custom_dataset.py +7 -8
- evalscope/benchmarks/arc/__init__.py +3 -2
- evalscope/benchmarks/arc/ai2_arc.py +19 -16
- evalscope/benchmarks/arc/arc_adapter.py +32 -24
- evalscope/benchmarks/bbh/__init__.py +1 -2
- evalscope/benchmarks/bbh/bbh_adapter.py +28 -25
- evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/navigate.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/snarks.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +1 -1
- evalscope/benchmarks/benchmark.py +16 -16
- evalscope/benchmarks/ceval/__init__.py +3 -2
- evalscope/benchmarks/ceval/ceval_adapter.py +80 -69
- evalscope/benchmarks/ceval/ceval_exam.py +18 -31
- evalscope/benchmarks/cmmlu/__init__.py +3 -2
- evalscope/benchmarks/cmmlu/cmmlu.py +87 -92
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +109 -155
- evalscope/benchmarks/cmmlu/samples.jsonl +1 -1
- evalscope/benchmarks/competition_math/__init__.py +3 -2
- evalscope/benchmarks/competition_math/competition_math.py +7 -16
- evalscope/benchmarks/competition_math/competition_math_adapter.py +32 -34
- evalscope/benchmarks/data_adapter.py +24 -24
- evalscope/benchmarks/general_qa/__init__.py +3 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +35 -39
- evalscope/benchmarks/gsm8k/__init__.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k.py +6 -12
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +27 -24
- evalscope/benchmarks/hellaswag/__init__.py +3 -2
- evalscope/benchmarks/hellaswag/hellaswag.py +15 -19
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +28 -23
- evalscope/benchmarks/humaneval/__init__.py +1 -1
- evalscope/benchmarks/humaneval/humaneval.py +15 -18
- evalscope/benchmarks/humaneval/humaneval_adapter.py +192 -7
- evalscope/benchmarks/mmlu/__init__.py +3 -2
- evalscope/benchmarks/mmlu/mmlu.py +15 -29
- evalscope/benchmarks/mmlu/mmlu_adapter.py +85 -77
- evalscope/benchmarks/race/__init__.py +3 -2
- evalscope/benchmarks/race/race.py +21 -35
- evalscope/benchmarks/race/race_adapter.py +33 -29
- evalscope/benchmarks/race/samples.jsonl +1 -1
- evalscope/benchmarks/trivia_qa/__init__.py +3 -2
- evalscope/benchmarks/trivia_qa/samples.jsonl +1 -1
- evalscope/benchmarks/trivia_qa/trivia_qa.py +19 -34
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +27 -22
- evalscope/benchmarks/truthful_qa/__init__.py +3 -2
- evalscope/benchmarks/truthful_qa/truthful_qa.py +25 -29
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +36 -37
- evalscope/cli/cli.py +6 -5
- evalscope/cli/start_eval.py +31 -0
- evalscope/cli/start_perf.py +0 -3
- evalscope/cli/start_server.py +27 -41
- evalscope/config.py +154 -96
- evalscope/constants.py +50 -32
- evalscope/evaluator/evaluator.py +97 -377
- evalscope/evaluator/rating_eval.py +12 -33
- evalscope/evaluator/reviewer/auto_reviewer.py +48 -76
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +10 -20
- evalscope/metrics/code_metric.py +3 -9
- evalscope/metrics/math_accuracy.py +3 -6
- evalscope/metrics/metrics.py +21 -21
- evalscope/metrics/rouge_metric.py +11 -25
- evalscope/models/__init__.py +1 -2
- evalscope/models/api/openai_api.py +40 -29
- evalscope/models/custom/__init__.py +0 -1
- evalscope/models/custom/custom_model.py +3 -3
- evalscope/models/dummy_chat_model.py +7 -8
- evalscope/models/model_adapter.py +89 -156
- evalscope/models/openai_model.py +20 -20
- evalscope/perf/arguments.py +16 -3
- evalscope/perf/benchmark.py +9 -11
- evalscope/perf/http_client.py +3 -8
- evalscope/perf/main.py +8 -1
- evalscope/perf/plugin/api/custom_api.py +1 -2
- evalscope/perf/plugin/api/dashscope_api.py +1 -2
- evalscope/perf/plugin/api/openai_api.py +3 -4
- evalscope/perf/plugin/datasets/base.py +1 -2
- evalscope/perf/plugin/datasets/flickr8k.py +1 -2
- evalscope/perf/plugin/datasets/longalpaca.py +1 -2
- evalscope/perf/plugin/datasets/openqa.py +1 -2
- evalscope/perf/plugin/registry.py +3 -3
- evalscope/perf/utils/analysis_result.py +1 -2
- evalscope/perf/utils/benchmark_util.py +5 -6
- evalscope/perf/utils/db_util.py +77 -30
- evalscope/perf/utils/local_server.py +21 -13
- evalscope/registry/config/cfg_arena_zhihu.yaml +1 -1
- evalscope/registry/tasks/arc.yaml +2 -3
- evalscope/registry/tasks/bbh.yaml +3 -4
- evalscope/registry/tasks/bbh_mini.yaml +3 -4
- evalscope/registry/tasks/ceval.yaml +3 -3
- evalscope/registry/tasks/ceval_mini.yaml +3 -4
- evalscope/registry/tasks/cmmlu.yaml +3 -3
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +1 -1
- evalscope/registry/tasks/general_qa.yaml +1 -1
- evalscope/registry/tasks/gsm8k.yaml +2 -2
- evalscope/registry/tasks/mmlu.yaml +3 -3
- evalscope/registry/tasks/mmlu_mini.yaml +3 -3
- evalscope/run.py +153 -381
- evalscope/run_arena.py +21 -25
- evalscope/summarizer.py +27 -40
- evalscope/third_party/longbench_write/README.md +99 -42
- evalscope/third_party/longbench_write/default_task.json +1 -1
- evalscope/third_party/longbench_write/default_task.yaml +8 -7
- evalscope/third_party/longbench_write/eval.py +29 -27
- evalscope/third_party/longbench_write/infer.py +16 -104
- evalscope/third_party/longbench_write/longbench_write.py +5 -4
- evalscope/third_party/longbench_write/resources/judge.txt +1 -1
- evalscope/third_party/longbench_write/tools/data_etl.py +5 -6
- evalscope/third_party/longbench_write/utils.py +0 -1
- evalscope/third_party/toolbench_static/eval.py +14 -15
- evalscope/third_party/toolbench_static/infer.py +48 -69
- evalscope/third_party/toolbench_static/llm/swift_infer.py +4 -12
- evalscope/third_party/toolbench_static/requirements.txt +1 -1
- evalscope/third_party/toolbench_static/toolbench_static.py +4 -3
- evalscope/tools/combine_reports.py +27 -34
- evalscope/tools/rewrite_eval_results.py +15 -47
- evalscope/utils/__init__.py +1 -1
- evalscope/utils/arena_utils.py +18 -48
- evalscope/{perf/utils → utils}/chat_service.py +4 -5
- evalscope/utils/completion_parsers.py +3 -8
- evalscope/utils/io_utils.py +162 -0
- evalscope/utils/logger.py +17 -7
- evalscope/utils/model_utils.py +11 -0
- evalscope/utils/utils.py +5 -306
- evalscope/version.py +2 -2
- {evalscope-0.7.2.dist-info → evalscope-0.8.1.dist-info}/METADATA +123 -118
- evalscope-0.8.1.dist-info/RECORD +285 -0
- tests/cli/test_run.py +53 -15
- tests/perf/test_perf.py +6 -1
- tests/rag/test_clip_benchmark.py +38 -38
- tests/rag/test_mteb.py +3 -2
- tests/rag/test_ragas.py +5 -5
- tests/swift/test_run_swift_eval.py +2 -3
- tests/swift/test_run_swift_vlm_eval.py +2 -3
- tests/swift/test_run_swift_vlm_jugde_eval.py +2 -3
- tests/vlm/test_vlmeval.py +3 -2
- evalscope/backend/rag_eval/ragas/metrics/__init__.py +0 -2
- evalscope/backend/rag_eval/ragas/metrics/multi_modal_faithfulness.py +0 -91
- evalscope/backend/rag_eval/ragas/metrics/multi_modal_relevance.py +0 -99
- evalscope/cache.py +0 -98
- evalscope/models/template.py +0 -1446
- evalscope/run_ms.py +0 -140
- evalscope/utils/task_cfg_parser.py +0 -10
- evalscope/utils/task_utils.py +0 -22
- evalscope-0.7.2.dist-info/RECORD +0 -286
- {evalscope-0.7.2.dist-info → evalscope-0.8.1.dist-info}/LICENSE +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.1.dist-info}/WHEEL +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.1.dist-info}/entry_points.txt +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.1.dist-info}/top_level.txt +0 -0

evalscope/models/api/openai_api.py CHANGED

@@ -1,34 +1,36 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import json
+import requests
 import threading
 import time
 from asyncio import Queue
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
+from typing import Dict, List, Optional, Union

-import
-from typing import Union, List, Optional, Dict
-from concurrent.futures import ThreadPoolExecutor
-from modelscope.utils.logger import get_logger
+from evalscope.utils.logger import get_logger

 logger = get_logger()


 class OpenaiApi:

-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+            self,
+            model: str,
+            openai_api_key,
+            openai_api_base,
+            logprobs: Optional[bool] = False,
+            top_logprobs: Optional[int] = None,
+            max_new_tokens: int = 4096,
+            temperature: Optional[float] = 0.0,
+            repetition_penalty: Optional[float] = 1.0,
+            is_chat: bool = True,
+            verbose: bool = True,
+            retry: int = 3,
+            query_per_second: int = 10,  # TODO
+            **kwargs):

         self.temperature = temperature
         self.repetition_penalty = repetition_penalty
@@ -45,14 +47,17 @@ class OpenaiApi:

         self.token_bucket = TokenBucket(query_per_second, verbose)

-    def generate_simple(self, inputs: Union[List[str]]):
+    def generate_simple(self, inputs: Union[List[str]], num_proc: int = 8):

         def process_one(in_data: str):

             if self.is_chat:
                 data = dict(
                     model=self.model,
-                    messages=[{
+                    messages=[{
+                        'role': 'user',
+                        'content': in_data
+                    }],
                     max_tokens=self.max_tokens,
                     n=1,
                     logprobs=self.logprobs,
@@ -72,7 +77,10 @@ class OpenaiApi:

             # todo
             openai_api_key = self.openai_api_key or ''
-            header = {
+            header = {
+                'Authorization': f'Bearer {openai_api_key}',
+                'content-type': 'application/json',
+            }
             data = json.dumps(data, ensure_ascii=False)

             if self.verbose:
@@ -91,14 +99,18 @@ class OpenaiApi:
             else:
                 return resp['choices'][0]['text'].strip()

-
-
+        results = []
+        with ThreadPoolExecutor(max_workers=num_proc) as executor:
+            # Submit all tasks
+            future_to_task = {executor.submit(process_one, input_one): input_one for input_one in inputs}
+
+            # Show progress bar
+            for future in tqdm(as_completed(future_to_task), total=len(inputs)):
+                results.append(future.result())

         return results

-    def generate(self,
-                 inputs: Union[List[str], List[List]],
-                 **kwargs) -> List[str]:
+    def generate(self, inputs: Union[List[str], List[List]], **kwargs) -> List[str]:
         """
         Generate responses from OpenAI API.

@@ -160,13 +172,12 @@ class OpenaiApi:

             def remove_none_val(input_d: dict):
                 return {k: v for k, v in input_d.items() if v is not None}
+
             data = remove_none_val(data)

             if self.verbose:
                 logger.info(f'>> Post data: {json.dumps(data, ensure_ascii=False)}')
-            raw_response = requests.post(self.url,
-                                         headers=header,
-                                         data=json.dumps(data, ensure_ascii=False))
+            raw_response = requests.post(self.url, headers=header, data=json.dumps(data, ensure_ascii=False))

             response = raw_response.json()
             if self.verbose:
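
Note on the reworked generate_simple above: per-prompt requests are now fanned out over a thread pool and progress is reported with tqdm instead of being processed sequentially. A minimal, self-contained sketch of that pattern (call_api is a hypothetical stand-in for the real process_one, which posts to the OpenAI-compatible endpoint):

from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm


def call_api(prompt: str) -> str:
    # Stand-in for the real per-request work (an HTTP POST to the API).
    return prompt.upper()


def generate_simple(inputs, num_proc: int = 8):
    results = []
    with ThreadPoolExecutor(max_workers=num_proc) as executor:
        # Submit one task per input prompt.
        future_to_task = {executor.submit(call_api, item): item for item in inputs}
        # Collect results as they finish, with a progress bar.
        for future in tqdm(as_completed(future_to_task), total=len(inputs)):
            results.append(future.result())
    return results


print(generate_simple(['hello', 'world'], num_proc=2))

Because results are appended in completion order, the returned list is not guaranteed to follow the input order; callers that need alignment can map each future back to its prompt via the future_to_task dict.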

evalscope/models/custom/custom_model.py CHANGED

@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from abc import ABC, abstractmethod
-from typing import Any, Union, Dict, List
 import torch
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Union


 class CustomModel(ABC):
@@ -11,7 +11,7 @@ class CustomModel(ABC):
         self.kwargs = kwargs

         if config.get('model_id', None) is None:
-            raise ValueError(f
+            raise ValueError(f'**Error: model_id is required in config for CustomModel. Got config: {config}')

     @abstractmethod
     @torch.no_grad()
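
CustomModel itself changes little here (import order plus the completed error message): subclasses still receive a config dict that must contain model_id, and CustomModelAdapter further down in this diff calls predict(prompts=..., **kwargs) on the instance. A minimal hypothetical subclass under those assumptions:

from evalscope.models.custom import CustomModel


class EchoModel(CustomModel):
    # Hypothetical example; config must carry 'model_id', per the check above.

    def __init__(self, config: dict, **kwargs):
        super().__init__(config=config, **kwargs)
        self.model_id = config['model_id']

    def predict(self, prompts, **kwargs):
        # Return one OpenAI-style chat.completion dict per prompt,
        # mirroring the response shape used elsewhere in this diff.
        return [{
            'choices': [{
                'index': 0,
                'message': {'content': str(p), 'role': 'assistant'}
            }],
            'model': self.model_id,
            'object': 'chat.completion',
        } for p in prompts]


model = EchoModel(config={'model_id': 'my-echo-model'})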

evalscope/models/dummy_chat_model.py CHANGED

@@ -2,6 +2,7 @@

 import random
 import time
+
 from evalscope.models import ChatBaseModel
 from evalscope.utils.logger import get_logger

@@ -32,15 +33,13 @@ class DummyChatModel(ChatBaseModel):

         # Build response
         res = {
-            'choices': [
-
-
-            '
-
-            'role': 'assistant'
-            }
+            'choices': [{
+                'index': 0,
+                'message': {
+                    'content': choice,
+                    'role': 'assistant'
                 }
-            ],
+            }],
             'created': time.time(),
             'model': self.MODEL_ID + '-' + self.REVISION,
             'object': 'chat.completion',

evalscope/models/model_adapter.py CHANGED

@@ -1,35 +1,25 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Copyright (c) EleutherAI, Inc. and its affiliates.
 # flake8: noqa
+import numpy as np
 import os
 import sys
-from typing import List, Any, Union, Dict
-import numpy as np
 import time
+import torch
 from abc import ABC, abstractmethod
 from copy import deepcopy
-
-import torch
+from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 from torch import dtype
+from typing import Any, Dict, List, Union

-from evalscope.constants import
+from evalscope.constants import DEFAULT_MODEL_CACHE_DIR
 from evalscope.models.custom import CustomModel
-from evalscope.
+from evalscope.utils.chat_service import ChatMessage
 from evalscope.utils.logger import get_logger
-from
+from evalscope.utils.model_utils import fix_do_sample_warning

 logger = get_logger()

-# Notes:
-# - modelscope>=1.9.5
-
-
-def get_model_cache_dir(root_cache_dir: str):
-    model_cache_dir = os.path.join(root_cache_dir, 'models')
-    model_cache_dir = os.path.expanduser(model_cache_dir)
-    os.makedirs(model_cache_dir, exist_ok=True)
-    return model_cache_dir
-

 class BaseModelAdapter(ABC):
     """
@@ -69,7 +59,7 @@ class MultiChoiceModelAdapter(BaseModelAdapter):
                  torch_dtype: dtype = torch.bfloat16,
                  model_revision: str = None,
                  max_length: int = None,
-                 cache_dir: str =
+                 cache_dir: str = None,
                  **kwargs):
         """
         Args:
@@ -80,11 +70,11 @@ class MultiChoiceModelAdapter(BaseModelAdapter):
             max_length: The max length of input sequence. Default: None.
             **kwargs: Other args.
         """
-        model_cache_dir =
+        model_cache_dir = cache_dir or DEFAULT_MODEL_CACHE_DIR

         self.model_id: str = model_id
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        logger.warning(f'
+        logger.warning(f'Device: {self.device}')

         torch_dtype = torch_dtype if torch_dtype is not None else 'auto'

@@ -93,31 +83,21 @@ class MultiChoiceModelAdapter(BaseModelAdapter):
         model_cfg['device_map'] = device_map
         model_cfg['torch_dtype'] = str(torch_dtype)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                                  cache_dir=model_cache_dir,)
-
-        model = AutoModelForCausalLM.from_pretrained(self.model_id,  # self.model_id
-                                                     revision=model_revision,
-                                                     device_map=device_map,
-                                                     trust_remote_code=True,
-                                                     torch_dtype=torch_dtype,
-                                                     cache_dir=model_cache_dir,)
-
-        # model.generation_config = GenerationConfig.from_pretrained(model_id, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_id,  # self.model_id
+            revision=model_revision,
+            trust_remote_code=True,
+            cache_dir=model_cache_dir,
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,  # self.model_id
+            revision=model_revision,
+            device_map=device_map,
+            trust_remote_code=True,
+            torch_dtype=torch_dtype,
+            cache_dir=model_cache_dir,
+        )

         super().__init__(model=model, tokenizer=tokenizer, model_cfg=model_cfg)

@@ -187,18 +167,16 @@ class MultiChoiceModelAdapter(BaseModelAdapter):
         if softval.dtype in {torch.bfloat16, torch.float16}:
             softval = softval.to(dtype=torch.float32)
         probs = softval.detach().cpu().numpy()
-        pred: str = multi_choices[int(np.argmax(probs))]
+        pred: str = multi_choices[int(np.argmax(probs))]  # Format: A or B or C or D

         res_d = {
-            'choices': [
-
-
-            '
-
-            'role': 'assistant'
-            }
+            'choices': [{
+                'index': 0,
+                'message': {
+                    'content': pred,
+                    'role': 'assistant'
                 }
-            ],
+            }],
             'created': time.time(),
             'model': self.model_id,
             'object': 'chat.completion',
@@ -226,7 +204,7 @@ class ContinuationLogitsModelAdapter(MultiChoiceModelAdapter):
                  device_map: str = 'auto',
                  torch_dtype: dtype = torch.bfloat16,
                  model_revision: str = None,
-                 cache_dir: str =
+                 cache_dir: str = None,
                  **kwargs):
         """
         Continuation-logits model adapter.
@@ -239,12 +217,13 @@ class ContinuationLogitsModelAdapter(MultiChoiceModelAdapter):
             **kwargs: Other args.
         """

-        super().__init__(
-
-
-
-
-
+        super().__init__(
+            model_id=model_id,
+            device_map=device_map,
+            torch_dtype=torch_dtype,
+            model_revision=model_revision,
+            cache_dir=cache_dir,
+            **kwargs)

     @torch.no_grad()
     def predict(self, inputs: dict, infer_cfg: dict = None) -> dict:
@@ -282,15 +261,13 @@ class ContinuationLogitsModelAdapter(MultiChoiceModelAdapter):
         pred_list: list = self.loglikelihood(inputs=inputs['data'], infer_cfg=infer_cfg)

         res_d = {
-            'choices': [
-
-
-            '
-
-            'role': 'assistant'
-            }
+            'choices': [{
+                'index': 0,
+                'message': {
+                    'content': pred_list,
+                    'role': 'assistant'
                 }
-            ],
+            }],
             'created': time.time(),
             'model': self.model_id,
             'object': 'chat.completion',
@@ -347,10 +324,10 @@ class ChatGenerationModelAdapter(BaseModelAdapter):

     def __init__(self,
                  model_id: str,
-                 model_revision: str,
+                 model_revision: str = 'master',
                  device_map: str = 'auto',
-                 torch_dtype: dtype =
-                 cache_dir: str =
+                 torch_dtype: dtype = 'auto',
+                 cache_dir: str = None,
                  **kwargs):
         """
         Chat completion model adapter. Tasks of chat and generation are supported.
@@ -359,17 +336,18 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
             model_id: The model id on ModelScope, or local model_dir.
             model_revision: The model revision on ModelScope. Default: None.
             device_map: The device map for model inference.
-            torch_dtype: The torch dtype for model inference. Default:
+            torch_dtype: The torch dtype for model inference. Default: 'auto'.
             **kwargs: Other args.
         """

         custom_generation_config = kwargs.pop('generation_config', None)
-
+        custom_chat_template = kwargs.pop('chat_template', None)
+        model_cache_dir = cache_dir or DEFAULT_MODEL_CACHE_DIR

         self.model_id: str = model_id
         self.model_revision: str = model_revision
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        logger.warning(f'
+        logger.warning(f'Device: {self.device}')

         torch_dtype = torch_dtype if torch_dtype is not None else 'auto'

@@ -378,72 +356,47 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
         model_cfg['device_map'] = device_map
         model_cfg['torch_dtype'] = str(torch_dtype)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # except:
-        #     model_dir = snapshot_download(self.model_id,
-        #                                   revision=model_revision,
-        #                                   cache_dir=model_cache_dir, )
-        #     logger.warning('**Load model from ModelScope hub **')
-
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id,
-                                                  revision=model_revision,
-                                                  trust_remote_code=True,
-                                                  cache_dir=model_cache_dir,)
-
-        model = AutoModelForCausalLM.from_pretrained(self.model_id,
-                                                     revision=model_revision,
-                                                     device_map=device_map,
-                                                     trust_remote_code=True,
-                                                     torch_dtype=torch_dtype,
-                                                     cache_dir=model_cache_dir,)
-
-        self.origin_tokenizer = deepcopy(tokenizer)
-
-        self.generation_config, self.generation_template = self._parse_generation_config(tokenizer, model)
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_id,
+            revision=model_revision,
+            trust_remote_code=True,
+            cache_dir=model_cache_dir,
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            revision=model_revision,
+            device_map=device_map,
+            trust_remote_code=True,
+            torch_dtype=torch_dtype,
+            cache_dir=model_cache_dir,
+        )
+
+        self.generation_config = self._parse_generation_config(tokenizer, model)

         if custom_generation_config:
-            logger.info('
-            self.generation_config.update(**custom_generation_config
-            logger.info(f'**Generation config init: {self.generation_config.to_dict()}')
+            logger.info('Updating generation config ...')
+            self.generation_config.update(**custom_generation_config)

-
+        if custom_chat_template:
+            tokenizer.chat_template = custom_chat_template
+            logger.info(f'Using custom chat template: {custom_chat_template}')

-
-        from modelscope.utils.hf_util import GenerationConfig
+        super().__init__(model=model, tokenizer=tokenizer, model_cfg=model_cfg)

-
+    def _parse_generation_config(self, tokenizer, model):
+        generation_config = getattr(model, 'generation_config', GenerationConfig(do_sample=False))

         try:
             remote_config = GenerationConfig.from_pretrained(
-                self.model_id,
-                revision=self.model_revision,
-                trust_remote_code=True)
+                self.model_id, revision=self.model_revision, trust_remote_code=True)
             generation_config.update(**remote_config.to_dict())
         except:
             logger.warning(f'Failed to get generation config of {self.model_id} from model hub, use default.')

-        # Parse templates for chat-completion
         if isinstance(self.model_id, str) and os.path.exists(self.model_id):
             logger.warning(f'Got local model dir: {self.model_id}')

-        generation_template = get_template(template_type=self.template_type, tokenizer=tokenizer)
-
         if tokenizer.eos_token_id is not None:
             generation_config.eos_token_id = tokenizer.eos_token_id
         if tokenizer.pad_token_id is not None:
@@ -451,24 +404,19 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
         if generation_config.max_new_tokens is None:
             generation_config.max_new_tokens = 2048

-        return generation_config
+        return generation_config

     def _model_generate(self, query: str, infer_cfg: dict) -> str:
-
-
-
-
-        inputs, _ = self.generation_template.encode(example)
+        messages = [ChatMessage(role='user', content=query)]
+        formatted_prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = self.tokenizer(formatted_prompt, return_tensors='pt', padding=True).to(self.device)
         input_ids = inputs['input_ids']
-        input_ids = torch.tensor(input_ids)[None].to(self.device)
-        attention_mask = torch.ones_like(input_ids).to(self.device)

         # Process infer_cfg
-        infer_cfg = infer_cfg or {}
         if isinstance(infer_cfg.get('num_return_sequences'), int) and infer_cfg['num_return_sequences'] > 1:
             infer_cfg['do_sample'] = True

-        #
+        # stop settings
         stop = infer_cfg.get('stop', None)
         eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0] \
             if stop else self.tokenizer.eos_token_id
@@ -478,25 +426,16 @@ class ChatGenerationModelAdapter(BaseModelAdapter):
             infer_cfg['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token

         self.generation_config.update(**infer_cfg)
-
-        # stopping
-        stop_words = [self.generation_template.suffix[-1]]
-        decode_kwargs = {}
-        stopping_criteria = StoppingCriteriaList(
-            [StopWordsCriteria(self.tokenizer, stop_words, **decode_kwargs)])
+        fix_do_sample_warning(self.generation_config)

         # Run inference
-        output_ids = self.model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            generation_config=self.generation_config,
-            stopping_criteria=stopping_criteria, )
+        output_ids = self.model.generate(**inputs, generation_config=self.generation_config)

-        response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], True
+        response = self.tokenizer.decode(output_ids[0, len(input_ids[0]):], skip_special_tokens=True)
         return response

     @torch.no_grad()
-    def predict(self, inputs: Union[str, dict, list], infer_cfg: dict =
+    def predict(self, inputs: Union[str, dict, list], infer_cfg: dict = {}) -> dict:

         # Process inputs
         if isinstance(inputs, str):
@@ -510,12 +449,7 @@ class ChatGenerationModelAdapter(BaseModelAdapter):

         response = self._model_generate(query, infer_cfg)

-        choices_list = [
-            {'index': 0,
-             'message': {'content': response,
-                         'role': 'assistant'}
-             }
-        ]
+        choices_list = [{'index': 0, 'message': {'content': response, 'role': 'assistant'}}]

         res_d = {
             'choices': choices_list,
@@ -589,4 +523,3 @@ class CustomModelAdapter(BaseModelAdapter):
             raise TypeError(f'Unsupported inputs type: {type(input_prompt)}')

         return self.custom_model.predict(prompts=in_prompts, **kwargs)
-
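
The most substantial behavioural change in model_adapter.py is in ChatGenerationModelAdapter: prompts are now rendered through the tokenizer's own chat template (with an optional chat_template override popped from kwargs) instead of the removed swift-style generation_template, and generation goes straight through model.generate with the parsed GenerationConfig. A rough sketch of that flow, assuming a Hugging Face-style tokenizer/model pair (the model id is illustrative only):

import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer

model_id = 'Qwen/Qwen2-0.5B-Instruct'  # illustrative model id, not taken from this diff
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto', trust_remote_code=True)

# Render the user turn through the tokenizer's chat template, then tokenize.
messages = [{'role': 'user', 'content': 'What is the capital of France?'}]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(formatted_prompt, return_tensors='pt').to(model.device)

# Generate, then decode only the newly produced tokens.
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
response = tokenizer.decode(output_ids[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(response)

In the adapter itself the messages are ChatMessage objects from evalscope.utils.chat_service and the generation parameters come from self.generation_config, as shown in the hunks above.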

evalscope/models/openai_model.py CHANGED

@@ -1,10 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

+import openai
 import os
 import time

-import openai
-
 from evalscope.models import ChatBaseModel
 from evalscope.utils.logger import get_logger

@@ -43,22 +42,25 @@ class OpenAIModel(ChatBaseModel):

         logger.info(f'Using OpenAI model_id: {model_id}')

-        res = self._predict(
-
-
-
-
-
+        res = self._predict(
+            model_id=model_id,
+            sys_prompt=sys_prompt,
+            user_prompt=user_prompt,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            mode=mode)

         return res

-    def _predict(
-
-
-
-
-
-
+    def _predict(
+            self,
+            model_id,
+            sys_prompt,
+            user_prompt,
+            temperature,
+            max_tokens,
+            mode: str = 'chat.completion',
+    ) -> dict:

         res = {}
         openai.api_key = self.api_key
@@ -82,9 +84,8 @@ class OpenAIModel(ChatBaseModel):
                 ans_text = resp['choices'][0]['message']['content']
                 model_id = resp['model']
             else:
-                logger.warning(
-
-                    f'for input {sys_prompt} {user_prompt}')
+                logger.warning(f'OpenAI GPT API call failed: got empty response '
+                               f'for input {sys_prompt} {user_prompt}')
                 ans_text = ''
                 model_id = ''

@@ -98,6 +99,5 @@ class OpenAIModel(ChatBaseModel):
         except Exception as e:
             logger.warning(f'OpenAI API call failed: {e}')
             time.sleep(3)
-        logger.error(
-            f'OpenAI API call failed after {self.MAX_RETRIES} retries')
+        logger.error(f'OpenAI API call failed after {self.MAX_RETRIES} retries')
         return res