evalscope 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +1 -1
- evalscope/arguments.py +73 -0
- evalscope/backend/base.py +5 -1
- evalscope/backend/opencompass/api_meta_template.py +8 -14
- evalscope/backend/opencompass/backend_manager.py +24 -15
- evalscope/backend/opencompass/tasks/eval_api.py +1 -6
- evalscope/backend/opencompass/tasks/eval_datasets.py +26 -28
- evalscope/backend/rag_eval/__init__.py +3 -3
- evalscope/backend/rag_eval/backend_manager.py +21 -25
- evalscope/backend/rag_eval/clip_benchmark/__init__.py +1 -1
- evalscope/backend/rag_eval/clip_benchmark/arguments.py +6 -6
- evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +62 -79
- evalscope/backend/rag_eval/clip_benchmark/task_template.py +29 -43
- evalscope/backend/rag_eval/clip_benchmark/tasks/image_caption.py +20 -22
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_classification.py +16 -23
- evalscope/backend/rag_eval/clip_benchmark/tasks/zeroshot_retrieval.py +14 -35
- evalscope/backend/rag_eval/clip_benchmark/utils/webdataset_convert.py +69 -90
- evalscope/backend/rag_eval/cmteb/__init__.py +3 -3
- evalscope/backend/rag_eval/cmteb/arguments.py +25 -27
- evalscope/backend/rag_eval/cmteb/base.py +22 -23
- evalscope/backend/rag_eval/cmteb/task_template.py +15 -17
- evalscope/backend/rag_eval/cmteb/tasks/Classification.py +98 -79
- evalscope/backend/rag_eval/cmteb/tasks/Clustering.py +17 -22
- evalscope/backend/rag_eval/cmteb/tasks/CustomTask.py +17 -19
- evalscope/backend/rag_eval/cmteb/tasks/PairClassification.py +35 -29
- evalscope/backend/rag_eval/cmteb/tasks/Reranking.py +18 -5
- evalscope/backend/rag_eval/cmteb/tasks/Retrieval.py +163 -163
- evalscope/backend/rag_eval/cmteb/tasks/STS.py +126 -104
- evalscope/backend/rag_eval/cmteb/tasks/__init__.py +33 -34
- evalscope/backend/rag_eval/ragas/__init__.py +2 -2
- evalscope/backend/rag_eval/ragas/arguments.py +3 -8
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +9 -9
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/CustomNodeFilter/scoring_prompt_chinese.json +7 -0
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +8 -8
- evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +7 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +27 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +27 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +5 -5
- evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +21 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +3 -3
- evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +4 -4
- evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +2 -2
- evalscope/backend/rag_eval/ragas/prompts/persona_prompt.py +0 -1
- evalscope/backend/rag_eval/ragas/task_template.py +10 -15
- evalscope/backend/rag_eval/ragas/tasks/__init__.py +1 -1
- evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +45 -0
- evalscope/backend/rag_eval/ragas/tasks/build_transform.py +135 -0
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +17 -133
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +8 -18
- evalscope/backend/rag_eval/utils/clip.py +46 -50
- evalscope/backend/rag_eval/utils/embedding.py +12 -11
- evalscope/backend/rag_eval/utils/llm.py +8 -6
- evalscope/backend/rag_eval/utils/tools.py +12 -11
- evalscope/backend/vlm_eval_kit/__init__.py +1 -1
- evalscope/backend/vlm_eval_kit/custom_dataset.py +7 -8
- evalscope/benchmarks/arc/__init__.py +3 -2
- evalscope/benchmarks/arc/ai2_arc.py +19 -16
- evalscope/benchmarks/arc/arc_adapter.py +32 -24
- evalscope/benchmarks/bbh/__init__.py +1 -2
- evalscope/benchmarks/bbh/bbh_adapter.py +28 -25
- evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/navigate.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/snarks.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +1 -1
- evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +1 -1
- evalscope/benchmarks/benchmark.py +16 -16
- evalscope/benchmarks/ceval/__init__.py +3 -2
- evalscope/benchmarks/ceval/ceval_adapter.py +80 -69
- evalscope/benchmarks/ceval/ceval_exam.py +18 -31
- evalscope/benchmarks/cmmlu/__init__.py +3 -2
- evalscope/benchmarks/cmmlu/cmmlu.py +87 -92
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +109 -155
- evalscope/benchmarks/cmmlu/samples.jsonl +1 -1
- evalscope/benchmarks/competition_math/__init__.py +3 -2
- evalscope/benchmarks/competition_math/competition_math.py +7 -16
- evalscope/benchmarks/competition_math/competition_math_adapter.py +32 -34
- evalscope/benchmarks/data_adapter.py +24 -24
- evalscope/benchmarks/general_qa/__init__.py +3 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +34 -38
- evalscope/benchmarks/gsm8k/__init__.py +1 -1
- evalscope/benchmarks/gsm8k/gsm8k.py +6 -12
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +26 -24
- evalscope/benchmarks/hellaswag/__init__.py +3 -2
- evalscope/benchmarks/hellaswag/hellaswag.py +15 -19
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +27 -23
- evalscope/benchmarks/humaneval/__init__.py +1 -1
- evalscope/benchmarks/humaneval/humaneval.py +15 -18
- evalscope/benchmarks/humaneval/humaneval_adapter.py +0 -1
- evalscope/benchmarks/mmlu/__init__.py +3 -2
- evalscope/benchmarks/mmlu/mmlu.py +15 -29
- evalscope/benchmarks/mmlu/mmlu_adapter.py +85 -77
- evalscope/benchmarks/race/__init__.py +3 -2
- evalscope/benchmarks/race/race.py +21 -35
- evalscope/benchmarks/race/race_adapter.py +32 -29
- evalscope/benchmarks/race/samples.jsonl +1 -1
- evalscope/benchmarks/trivia_qa/__init__.py +3 -2
- evalscope/benchmarks/trivia_qa/samples.jsonl +1 -1
- evalscope/benchmarks/trivia_qa/trivia_qa.py +19 -34
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +27 -22
- evalscope/benchmarks/truthful_qa/__init__.py +3 -2
- evalscope/benchmarks/truthful_qa/truthful_qa.py +25 -29
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +36 -37
- evalscope/cli/cli.py +6 -5
- evalscope/cli/start_eval.py +31 -0
- evalscope/cli/start_perf.py +0 -3
- evalscope/cli/start_server.py +27 -41
- evalscope/config.py +119 -95
- evalscope/constants.py +61 -29
- evalscope/evaluator/__init__.py +1 -0
- evalscope/evaluator/evaluator.py +96 -377
- evalscope/evaluator/humaneval_evaluator.py +158 -0
- evalscope/evaluator/rating_eval.py +12 -33
- evalscope/evaluator/reviewer/auto_reviewer.py +47 -76
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +10 -20
- evalscope/metrics/code_metric.py +3 -9
- evalscope/metrics/math_accuracy.py +3 -6
- evalscope/metrics/metrics.py +21 -21
- evalscope/metrics/rouge_metric.py +11 -25
- evalscope/models/__init__.py +1 -2
- evalscope/models/api/openai_api.py +40 -29
- evalscope/models/custom/__init__.py +0 -1
- evalscope/models/custom/custom_model.py +3 -3
- evalscope/models/dummy_chat_model.py +7 -8
- evalscope/models/model_adapter.py +89 -156
- evalscope/models/openai_model.py +20 -20
- evalscope/perf/arguments.py +15 -3
- evalscope/perf/benchmark.py +7 -9
- evalscope/perf/http_client.py +3 -8
- evalscope/perf/main.py +10 -0
- evalscope/perf/plugin/api/custom_api.py +1 -2
- evalscope/perf/plugin/api/dashscope_api.py +1 -2
- evalscope/perf/plugin/api/openai_api.py +2 -3
- evalscope/perf/plugin/datasets/base.py +1 -2
- evalscope/perf/plugin/datasets/flickr8k.py +1 -2
- evalscope/perf/plugin/datasets/longalpaca.py +1 -2
- evalscope/perf/plugin/datasets/openqa.py +1 -2
- evalscope/perf/utils/analysis_result.py +1 -2
- evalscope/perf/utils/benchmark_util.py +1 -2
- evalscope/perf/utils/db_util.py +11 -8
- evalscope/perf/utils/local_server.py +19 -13
- evalscope/registry/config/cfg_arena_zhihu.yaml +1 -1
- evalscope/registry/tasks/arc.yaml +2 -3
- evalscope/registry/tasks/bbh.yaml +3 -4
- evalscope/registry/tasks/bbh_mini.yaml +3 -4
- evalscope/registry/tasks/ceval.yaml +3 -3
- evalscope/registry/tasks/ceval_mini.yaml +3 -4
- evalscope/registry/tasks/cmmlu.yaml +3 -3
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +1 -1
- evalscope/registry/tasks/general_qa.yaml +1 -1
- evalscope/registry/tasks/gsm8k.yaml +2 -2
- evalscope/registry/tasks/mmlu.yaml +3 -3
- evalscope/registry/tasks/mmlu_mini.yaml +3 -3
- evalscope/run.py +184 -375
- evalscope/run_arena.py +20 -25
- evalscope/summarizer.py +16 -17
- evalscope/third_party/longbench_write/README.md +99 -42
- evalscope/third_party/longbench_write/default_task.json +1 -1
- evalscope/third_party/longbench_write/default_task.yaml +8 -7
- evalscope/third_party/longbench_write/eval.py +29 -28
- evalscope/third_party/longbench_write/infer.py +16 -104
- evalscope/third_party/longbench_write/longbench_write.py +5 -5
- evalscope/third_party/longbench_write/resources/judge.txt +1 -1
- evalscope/third_party/longbench_write/tools/data_etl.py +4 -5
- evalscope/third_party/longbench_write/utils.py +0 -1
- evalscope/third_party/toolbench_static/eval.py +14 -15
- evalscope/third_party/toolbench_static/infer.py +48 -69
- evalscope/third_party/toolbench_static/llm/swift_infer.py +4 -12
- evalscope/third_party/toolbench_static/requirements.txt +1 -1
- evalscope/third_party/toolbench_static/toolbench_static.py +3 -3
- evalscope/tools/combine_reports.py +25 -30
- evalscope/tools/rewrite_eval_results.py +14 -46
- evalscope/utils/__init__.py +0 -1
- evalscope/utils/arena_utils.py +18 -48
- evalscope/{perf/utils → utils}/chat_service.py +3 -4
- evalscope/utils/completion_parsers.py +3 -8
- evalscope/utils/logger.py +9 -7
- evalscope/utils/model_utils.py +11 -0
- evalscope/utils/utils.py +12 -138
- evalscope/version.py +2 -2
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/METADATA +123 -118
- evalscope-0.8.0.dist-info/RECORD +285 -0
- tests/cli/test_run.py +54 -15
- tests/perf/test_perf.py +4 -0
- tests/rag/test_clip_benchmark.py +38 -38
- tests/rag/test_mteb.py +3 -2
- tests/rag/test_ragas.py +5 -5
- tests/swift/test_run_swift_eval.py +2 -3
- tests/swift/test_run_swift_vlm_eval.py +2 -3
- tests/swift/test_run_swift_vlm_jugde_eval.py +2 -3
- evalscope/backend/rag_eval/ragas/metrics/__init__.py +0 -2
- evalscope/backend/rag_eval/ragas/metrics/multi_modal_faithfulness.py +0 -91
- evalscope/backend/rag_eval/ragas/metrics/multi_modal_relevance.py +0 -99
- evalscope/cache.py +0 -98
- evalscope/models/template.py +0 -1446
- evalscope/run_ms.py +0 -140
- evalscope/utils/task_cfg_parser.py +0 -10
- evalscope/utils/task_utils.py +0 -22
- evalscope-0.7.2.dist-info/RECORD +0 -286
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/LICENSE +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/WHEEL +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.7.2.dist-info → evalscope-0.8.0.dist-info}/top_level.txt +0 -0
--- a/evalscope/third_party/longbench_write/infer.py
+++ b/evalscope/third_party/longbench_write/infer.py
@@ -1,18 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Copyright (c) ZhipuAI, Inc. and its affiliates.
 
-import os
 import json
-from typing import List
-
-import torch
 import numpy as np
+import os
 import random
-
-from
+import torch
+from typing import List
 
-from evalscope.third_party.longbench_write.utils import count_words
 from evalscope.models.api import OpenaiApi
+from evalscope.third_party.longbench_write.utils import count_words
 from evalscope.utils import get_logger
 
 logger = get_logger()
@@ -25,39 +22,6 @@ Refer to https://github.com/THUDM/LongWriter for more details.
 """
 
 
-def get_pred(rank, world_size, data, path, max_new_tokens, temperature, tokenizer, fout):
-    device = torch.device(f'cuda:{rank}')
-    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
-    model = model.eval()
-
-    for dt in tqdm(data, total=len(data), desc=f'Infer on rank-{rank}: '):
-        prompt = dt['prompt']
-        if "llama" in path.lower():
-            prompt = f"[INST]{prompt}[/INST]"
-            input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
-            context_length = input.input_ids.shape[-1]
-            output = model.generate(
-                **input,
-                max_new_tokens=max_new_tokens,
-                num_beams=1,
-                do_sample=True,
-                temperature=temperature,
-            )[0]
-            response = tokenizer.decode(output[context_length:], skip_special_tokens=True)
-        else:
-            response, history = model.chat(tokenizer, prompt, history=[], max_new_tokens=max_new_tokens,
-                                           temperature=temperature)
-        dt["response_length"], _ = count_words(response)
-        dt["response"] = response
-
-        logger.info(dt)
-
-        fout.write(json.dumps(dt, ensure_ascii=False) + '\n')
-        fout.flush()
-
-    logger.info(f'Successfully generated predictions for {len(data)} samples.')
-
-
 def seed_everything(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -68,69 +32,13 @@ def seed_everything(seed):
     torch.cuda.manual_seed_all(seed)
 
 
-# def run_infer(model: str,
-#               data_path: str,
-#               output_dir: str,
-#               generation_kwargs: dict = None,
-#               enable: bool = True, ):
-#     """
-#     Process inference for LongWriter model.
-#
-#     Args:
-#         model: The model id of the LongWriter model on ModelScope, or local model path.
-#         data_path: The path to the data file.
-#         output_dir: The output directory for the predictions.
-#         generation_kwargs: The generation arguments for the model.
-#             Attributes: `max_new_tokens`: The maximum number of tokens to generate. `temperature`: The temperature
-#         enable: Whether to run infer process.
-#     """
-#     model_id_path: str = os.path.join(output_dir, model.strip(os.sep).replace(os.sep, '__'))
-#
-#     if not enable:
-#         logger.warning('*** Skip `infer` stage ***')
-#         return f'{model_id_path}/pred.jsonl'
-#
-#     seed_everything(42)
-#
-#     os.makedirs(model_id_path, exist_ok=True)
-#     fout = open(f'{model_id_path}/pred.jsonl', 'w', encoding='utf-8')
-#
-#     if generation_kwargs is None:
-#         generation_kwargs = dict({
-#             'max_new_tokens': 32768,
-#             'temperature': 0.5
-#         })
-#
-#     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
-#     world_size = torch.cuda.device_count()
-#
-#     logger.info(f'>>Input data path: {data_path}')
-#     with open(data_path, encoding='utf-8') as f:
-#         data = [json.loads(line) for line in f]
-#
-#     data_subsets = [data[i::world_size] for i in range(world_size)]
-#     processes = []
-#     for rank in range(world_size):
-#         p = mp.Process(target=get_pred,
-#                        args=(rank, world_size, data_subsets[rank], model, generation_kwargs.get('max_new_tokens'), generation_kwargs.get('temperature'), tokenizer, fout))
-#         p.start()
-#         processes.append(p)
-#
-#     for p in processes:
-#         p.join()
-#
-#     logger.info(f'Finish generating predictions for {model}.')
-#     logger.info(f'Predictions are saved in {model_id_path}/pred.jsonl.')
-#
-#     return f'{model_id_path}/pred.jsonl'
-
-
 def run_infer(model: str,
               data_path: str,
               output_dir: str,
               api_config: dict,
               generation_kwargs: dict = None,
-              enable: bool = True, ):
+              enable: bool = True,
+              proc_num: int = DEFAULT_PROC_NUM):
     """
     Process inference for LongWriter model.
 
@@ -147,6 +55,7 @@ def run_infer(model: str,
        generation_kwargs: The generation arguments for the model.
            Attributes: `max_new_tokens`: The maximum number of tokens to generate. `temperature`: The temperature
        enable: Whether to run infer process.
+        proc_num: calling OpenAI api service with proc_num
     """
     model_id_path: str = os.path.join(output_dir, model.strip(os.sep).replace(os.sep, '__'))
 
@@ -173,7 +82,8 @@ def run_infer(model: str,
 
     api_client = OpenaiApi(model=model,
                            openai_api_key=None,
-                           openai_api_base=api_config.get('openai_api_base', 'http://127.0.0.1:8000/v1/chat/completions'),
+                           openai_api_base=api_config.get('openai_api_base',
+                                                          'http://127.0.0.1:8000/v1/chat/completions'),
                            max_new_tokens=generation_kwargs.get('max_new_tokens', 4096),
                            temperature=generation_kwargs.get('temperature', 0.0),
                            repetition_penalty=generation_kwargs.get('repetition_penalty', 1.0),
@@ -181,9 +91,11 @@
                            verbose=api_config.get('verbose', False),
                            )
 
-    # TODO:
-    results: List[str] = api_client.generate_simple(inputs=[example['prompt'] for example in data_list])
-
+    # TODO: refine generate_simple
+    results: List[str] = api_client.generate_simple(inputs=[example['prompt'] for example in data_list],
+                                                    num_proc=proc_num)
+    assert len(results) == len(data_list), \
+        f'Error: The number of predictions {len(results)} is not equal to the number of inputs {len(data_list)}.'
     logger.info(f'Finish generating predictions with {len(data_list)} samples for {model}')
 
     # Outputs
@@ -191,8 +103,8 @@
     output_pred_file: str = f'{model_id_path}/pred.jsonl'
     with open(output_pred_file, 'w', encoding='utf-8') as f:
         for dt, res in zip(data_list, results):
-            dt["response_length"], _ = count_words(res)
-            dt["response"] = res
+            dt['response_length'], _ = count_words(res)
+            dt['response'] = res
             f.write(json.dumps(dt, ensure_ascii=False) + '\n')
 
     logger.info(f'Predictions are saved in {output_pred_file}')
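The upshot of the infer.py rewrite above: 0.8.0 drops the per-GPU `get_pred` workers and the commented-out multiprocessing path, and instead fans prompts out to an OpenAI-compatible endpoint with `proc_num` concurrent calls. A minimal sketch of the new call path follows (a sketch only, assuming the function still returns the prediction-file path as its commented-out predecessor did; the model id and file paths are placeholders):

from evalscope.third_party.longbench_write.infer import run_infer

pred_file = run_infer(
    model='ZhipuAI/LongWriter-glm4-9b',   # placeholder model id or local path
    data_path='longbench_write.jsonl',    # one JSON object per line, each with a 'prompt' field
    output_dir='./outputs',
    api_config={'openai_api_base': 'http://127.0.0.1:8000/v1/chat/completions'},
    generation_kwargs={'max_new_tokens': 32768, 'temperature': 0.5},
    enable=True,
    proc_num=16,                          # concurrent requests against the API service
)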
--- a/evalscope/third_party/longbench_write/longbench_write.py
+++ b/evalscope/third_party/longbench_write/longbench_write.py
@@ -2,10 +2,9 @@
 import os
 from typing import Union
 
-from evalscope.third_party.longbench_write.infer import run_infer
 from evalscope.third_party.longbench_write.eval import run_eval
-from evalscope.
-from evalscope.utils import get_logger
+from evalscope.third_party.longbench_write.infer import run_infer
+from evalscope.utils import get_logger, json_to_dict, yaml_to_dict
 
 logger = get_logger()
 
@@ -45,7 +44,8 @@ def run_task(task_cfg: Union[str, dict]):
                              verbose=infer_config.get('verbose', False),
                              ),
              generation_kwargs=infer_config.get('generation_kwargs'),
-              enable='infer' in stage)
+              enable='infer' in stage,
+              proc_num=infer_config.get('proc_num', 16))
 
     # Run eval process
     run_eval(model=model,
@@ -77,7 +77,7 @@ if __name__ == '__main__':
             },
 
         eval_config={
-            'openai_api_key':
+            'openai_api_key': None,
             'openai_api_base': 'https://api.openai.com/v1/chat/completions',
             'openai_gpt_model': 'gpt-4o-2024-05-13',
             'generation_kwargs': {'max_new_tokens': 1024, 'temperature': 0.5, 'stop': None},
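Putting the two hunks above together, `proc_num` is threaded from the task config into `run_infer`. A hedged sketch of a task config that exercises it (key names are taken from this diff; the stage list and model id are illustrative):

task_cfg = {
    'stage': ['infer', 'eval_l', 'eval_q'],   # 'infer' toggles run_infer via enable='infer' in stage
    'model': 'ZhipuAI/LongWriter-glm4-9b',    # placeholder
    'infer_config': {
        'openai_api_base': 'http://127.0.0.1:8000/v1/chat/completions',
        'verbose': False,
        'generation_kwargs': {'max_new_tokens': 32768, 'temperature': 0.5},
        'proc_num': 16,                       # new knob; defaults to 16 when omitted
    },
    'eval_config': {
        'openai_api_key': None,
        'openai_api_base': 'https://api.openai.com/v1/chat/completions',
        'openai_gpt_model': 'gpt-4o-2024-05-13',
        'generation_kwargs': {'max_new_tokens': 1024, 'temperature': 0.5, 'stop': None},
    },
}
run_task(task_cfg)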
--- a/evalscope/third_party/longbench_write/resources/judge.txt
+++ b/evalscope/third_party/longbench_write/resources/judge.txt
@@ -28,4 +28,4 @@ $RESPONSE$
 
 </Response>
 
-Please evaluate the quality of the response. You must first provide a brief analysis of its quality, then give a comprehensive analysis with scores for each dimension. The output must strictly follow the JSON format: {"Analysis": ..., "Relevance": ..., "Accuracy": ..., "Coherence": ..., "Clarity": ..., "Breadth and Depth": ..., "Reading Experience": ...}. You do not need to consider whether the response meets the user's length requirements in your evaluation. Ensure that only one integer between 1 and 5 is output for each dimension score.
+Please evaluate the quality of the response. You must first provide a brief analysis of its quality, then give a comprehensive analysis with scores for each dimension. The output must strictly follow the JSON format: {"Analysis": ..., "Relevance": ..., "Accuracy": ..., "Coherence": ..., "Clarity": ..., "Breadth and Depth": ..., "Reading Experience": ...}. You do not need to consider whether the response meets the user's length requirements in your evaluation. Ensure that only one integer between 1 and 5 is output for each dimension score.
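For reference, a judge reply that satisfies this format would be a single JSON object on one line, e.g. (scores illustrative):

{"Analysis": "The response is on-topic and well organized, with minor factual slips.", "Relevance": 5, "Accuracy": 4, "Coherence": 5, "Clarity": 4, "Breadth and Depth": 3, "Reading Experience": 4}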
--- a/evalscope/third_party/longbench_write/tools/data_etl.py
+++ b/evalscope/third_party/longbench_write/tools/data_etl.py
@@ -1,16 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import json
 import os.path
-from typing import List
 import re
-import json
+from typing import List
 
 from evalscope.third_party.longbench_write.eval import EvalLength
-from evalscope.third_party.longbench_write.utils import count_words, chinese_to_arabic
+from evalscope.third_party.longbench_write.utils import chinese_to_arabic, count_words
 from evalscope.utils import jsonl_to_list
 from evalscope.utils.logger import get_logger
 
 logger = get_logger()
-
 """
 This script is used to preprocess the dataset for the LongWriter.
 """
@@ -141,7 +140,7 @@ class DataETL:
         return out_file
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
    # run `no_required_length`: got 1748 exampels left
 
    # Refer to: https://modelscope.cn/datasets/ZhipuAI/LongWriter-6k/files
--- a/evalscope/third_party/toolbench_static/eval.py
+++ b/evalscope/third_party/toolbench_static/eval.py
@@ -1,10 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import json
+import os
 from dataclasses import dataclass
-
 from rouge import Rouge
-import os
 
 
 @dataclass
@@ -24,7 +23,7 @@ def run_eval(args: EvalArgs):
         return 0
     rouge = Rouge()
     rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
-    rougel = rouge_score["rouge-l"]["f"]
+    rougel = rouge_score['rouge-l']['f']
     return rougel
 
 def evaluate_action_em(cand_list: list, ref_list: list):
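For context, the `rouge` package's `get_scores(..., avg=True)` returns a dict keyed by 'rouge-1', 'rouge-2', and 'rouge-l', each holding 'f', 'p', and 'r' floats; the code keeps only the ROUGE-L F1. A standalone sketch of the same scoring step:

from rouge import Rouge

cands = ['the cat sat on the mat']        # model outputs
refs = ['a cat was sitting on the mat']   # references, aligned by index
scores = Rouge().get_scores(hyps=cands, refs=refs, avg=True)
print(scores['rouge-l']['f'])             # ROUGE-L F1 in [0, 1]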
@@ -97,8 +96,8 @@ def run_eval(args: EvalArgs):
         data = json.load(f)
 
     def parse_action(text):
-        action = "None"
-        action_input = "{}"
+        action = 'None'
+        action_input = '{}'
         if 'Action Input:' in text:
             input_idx = text.rindex('Action Input:')
             action_input = text[input_idx + len('Action Input:'):].strip()
@@ -117,24 +116,24 @@ def run_eval(args: EvalArgs):
 
     def parse_output(text):
         action, action_input = parse_action(text)
-        if action == "Finish":
+        if action == 'Finish':
             try:
                 action_input = json.loads(action_input)
                 # print(action_input)
                 # print(json.dumps(action_input,indent=2))
-                return_type = action_input["return_type"]
-                if return_type == "give_answer":
-                    if "final_answer" in action_input.keys():
+                return_type = action_input['return_type']
+                if return_type == 'give_answer':
+                    if 'final_answer' in action_input.keys():
                         answer = str(action_input['final_answer'])
                         if answer.strip() in ['', '.', ',']:
-                            answer = "None"
+                            answer = 'None'
                     else:
-                        answer = "None"
-                    return "finish", action, action_input, answer
+                        answer = 'None'
+                    return 'finish', action, action_input, answer
                 else:
-                    return "give up", None, None, None
+                    return 'give up', None, None, None
             except:
-                return "give up", None, None, None
+                return 'give up', None, None, None
         else:
             plan = 'call'
             answer = None
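To make the parsing contract concrete: these helpers consume ReAct-style completions and, for a 'Finish' action with return_type 'give_answer', surface the final answer. A hedged illustration (the 'Action:' extraction happens in the part of parse_action elided from this hunk):

completion = ('Thought: I have enough information now.\n'
              'Action: Finish\n'
              'Action Input: {"return_type": "give_answer", "final_answer": "Paris"}')
# parse_action(completion)  ->  ('Finish', '{"return_type": "give_answer", "final_answer": "Paris"}')
# parse_output(completion)  ->  ('finish', 'Finish', {...}, 'Paris')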
@@ -163,7 +162,7 @@ def run_eval(args: EvalArgs):
         # ref_ans: None
 
         pred_plan, pred_action, pred_input, pred_ans = parse_output(prediction)
-        if ref_action is not None and ref_action == "invalid_hallucination_function_name":
+        if ref_action is not None and ref_action == 'invalid_hallucination_function_name':
             continue
         if pred_action is not None and ref_action != 'none' and ref_action not in [t['name'] for t in d['tools']]:
             continue
--- a/evalscope/third_party/toolbench_static/infer.py
+++ b/evalscope/third_party/toolbench_static/infer.py
@@ -16,13 +16,13 @@
 
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
 
-from dataclasses import dataclass, field
 import json
 import os
-
+import requests
 import time
+from dataclasses import dataclass, field
+from rouge import Rouge
 from urllib3.exceptions import MaxRetryError, NewConnectionError
-import requests
 
 
 def evaluate_rouge_l(cand_list: list, ref_list: list):
@@ -30,7 +30,7 @@ def evaluate_rouge_l(cand_list: list, ref_list: list):
         return 0
     rouge = Rouge()
    rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
-    rougel = rouge_score["rouge-l"]["f"]
+    rougel = rouge_score['rouge-l']['f']
     return rougel
 
 
@@ -42,8 +42,8 @@ def nested_load_test_data(data_path):
             test_raw_data += temp_test
         return test_raw_data
     elif os.path.isfile(data_path) and data_path.endswith('.json'):
-        print("Load data from", data_path)
-        temp_data = json.load(open(data_path, "r"))
+        print('Load data from', data_path)
+        temp_data = json.load(open(data_path, 'r'))
         test_raw_data = temp_data
         return test_raw_data
     else:
@@ -51,39 +51,24 @@ def nested_load_test_data(data_path):
 
 
 def baichuan_call(context: list, system: str):
-    url = "https://api.baichuan-ai.com/v1/chat/completions"
-    api_key = "sk-xxx"
+    url = 'https://api.baichuan-ai.com/v1/chat/completions'
+    api_key = 'sk-xxx'
 
     new_msg = []
-    new_msg.append({
-        "role": 'system',
-        'content': system})
+    new_msg.append({'role': 'system', 'content': system})
     for m in context:
-        if m['role'] == "user":
-            new_msg.append({
-                'role': 'user', 'content': m['content']
-            })
-        elif m['role'] == "function":
-            new_msg.append({
-                'role': 'user', 'content': m['content']
-            })
+        if m['role'] == 'user':
+            new_msg.append({'role': 'user', 'content': m['content']})
+        elif m['role'] == 'function':
+            new_msg.append({'role': 'user', 'content': m['content']})
         elif m['role'] == 'assistant':
-            new_msg.append({
-                'role': 'assistant', 'content': m['content']
-            })
+            new_msg.append({'role': 'assistant', 'content': m['content']})
     # print(json.dumps(new_msg, indent=2))
-    data = {
-        "model": "Baichuan2-Turbo",
-        "messages": new_msg,
-        "stream": False
-    }
+    data = {'model': 'Baichuan2-Turbo', 'messages': new_msg, 'stream': False}
 
     json_data = json.dumps(data)
 
-    headers = {
-        "Content-Type": "application/json",
-        "Authorization": "Bearer " + api_key
-    }
+    headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer ' + api_key}
 
     for i in range(5):
         res = None
@@ -91,7 +76,7 @@ def baichuan_call(context: list, system: str):
             res = requests.post(url, data=json_data, headers=headers, timeout=60)
             res = res._content.decode('utf-8')
             res = json.loads(res)
-            return res["choices"][0]["message"]["content"]
+            return res['choices'][0]['message']['content']
         except KeyError:
             print(res)
             time.sleep(1)
@@ -105,57 +90,52 @@ def baichuan_call(context: list, system: str):
         except NewConnectionError:
             time.sleep(5)
             continue
-    return ""
+    return ''
 
 
 def minimax_call(context: list, system: str):
-    group_id = "your-id"
-    api_key = "your-xxx"
+    group_id = 'your-id'
+    api_key = 'your-xxx'
 
-    url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}"
-    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+    url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
+    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
 
     # construct message
-    system_prompt = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。" \
-                    "MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。"
+    system_prompt = 'MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。' \
+                    'MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。'
     system_prompt += ('\n' + system)
 
     new_msg = []
     for m in context:
-        if m['role'] == "user":
-            new_msg.append({
-                'sender_type': 'USER', 'sender_name': 'user', 'text': m['content']
-            })
-        elif m['role'] == "function":
-            new_msg.append({
-                'sender_type': 'USER', 'sender_name': 'funtion', 'text': m['content']
-            })
+        if m['role'] == 'user':
+            new_msg.append({'sender_type': 'USER', 'sender_name': 'user', 'text': m['content']})
+        elif m['role'] == 'function':
+            new_msg.append({'sender_type': 'USER', 'sender_name': 'funtion', 'text': m['content']})
         elif m['role'] == 'assistant':
-            new_msg.append({
-                'sender_type': 'BOT', 'sender_name': 'MM智能助理', 'text': m['content']
-            })
+            new_msg.append({'sender_type': 'BOT', 'sender_name': 'MM智能助理', 'text': m['content']})
 
     request_body = {
-        "model": "abab6-chat",
+        'model': 'abab6-chat',
         # "model": "abab5.5s-chat",
-        "tokens_to_generate": 8192,
-        "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
-        "messages": new_msg,
-        "bot_setting": [
-            {
-                "bot_name": "MM智能助理",
-                "content": system_prompt,
-            }
-        ],
+        'tokens_to_generate': 8192,
+        'reply_constraints': {
+            'sender_type': 'BOT',
+            'sender_name': 'MM智能助理'
+        },
+        'messages': new_msg,
+        'bot_setting': [{
+            'bot_name': 'MM智能助理',
+            'content': system_prompt,
+        }],
     }
     response = requests.post(url, headers=headers, json=request_body)
     status_code = response.status_code
     for i in range(5):
         try:
             if status_code == 200:
-                reply = response.json()["reply"]
+                reply = response.json()['reply']
                 if len(reply) == 0:
-                    print("limit rate")
+                    print('limit rate')
                     time.sleep(8)
                     continue
             print(f'>>return: {reply}')
@@ -167,12 +147,12 @@ def minimax_call(context: list, system: str):
             print(response)
             time.sleep(5)
             continue
-    return ""
+    return ''
 
 
 def swift_call(context: list, system: str, swift_infer_obj):
     query_d: dict = context[-1]
-    history_list = context[0:-1]
+    history_list = context[:-1]
     query: str = query_d['content']
     history_msg = []
@@ -211,9 +191,8 @@ def run_infer(args: InferArgs):
 
     if args.deploy_type == 'swift':
         from evalscope.third_party.toolbench_static.llm.swift_infer import SwiftInfer, SwiftInferArgs
-        swift_infer_args = SwiftInferArgs(model_id_or_path=args.model_name_or_path,
-                                          model_type=args.model_type,
-                                          max_new_tokens=args.max_new_tokens)
+        swift_infer_args = SwiftInferArgs(
+            model_id_or_path=args.model_name_or_path, model_type=args.model_type, max_new_tokens=args.max_new_tokens)
         swift_infer = SwiftInfer(args=swift_infer_args)
     else:
         swift_infer = None
@@ -232,7 +211,7 @@ def run_infer(args: InferArgs):
     preds = []
     refs = []
     for i, o in enumerate(infer_samples):
-        if i < len(processed_samples) and "predictions" in processed_samples[i].keys():
+        if i < len(processed_samples) and 'predictions' in processed_samples[i].keys():
             infer_samples[i]['predictions'] = processed_samples[i]['predictions']
             refs.append(processed_samples[i]['target'])
             preds.append(processed_samples[i]['predictions'])
@@ -267,7 +246,7 @@ def run_infer(args: InferArgs):
         reference = infer_samples[i]['target']
         infer_samples[i]['predictions'] = candidate
         if reference.strip() in ['', '.', ',']:
-            reference = "none"
+            reference = 'none'
         refs.append(reference)
         preds.append(candidate)
 
--- a/evalscope/third_party/toolbench_static/llm/swift_infer.py
+++ b/evalscope/third_party/toolbench_static/llm/swift_infer.py
@@ -1,9 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from dataclasses import dataclass
-
-from swift.llm import (
-    get_model_tokenizer, get_template, inference, get_default_template_type,
-)
+from swift.llm import get_default_template_type, get_model_tokenizer, get_template, inference
 from swift.utils import seed_everything
 
 # TODO: Support custom model for swift infer
@@ -21,9 +18,8 @@ class SwiftInfer:
     def __init__(self, args: SwiftInferArgs):
         model_type = args.model_type
         template_type = get_default_template_type(model_type)
-        model, tokenizer = get_model_tokenizer(model_type,
-                                               model_id_or_path=args.model_id_or_path,
-                                               model_kwargs={'device_map': 'auto'})
+        model, tokenizer = get_model_tokenizer(
+            model_type, model_id_or_path=args.model_id_or_path, model_kwargs={'device_map': 'auto'})
         model.generation_config.max_new_tokens = args.max_new_tokens
         print(f'** Generation config: {model.generation_config}')
 
@@ -36,10 +32,6 @@ class SwiftInfer:
 
     def predict(self, system: str, query: str, history: list):
 
-        response, history = inference(self.model,
-                                      self.template,
-                                      query=query,
-                                      system=system,
-                                      history=history)
+        response, history = inference(self.model, self.template, query=query, system=system, history=history)
 
         return response
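After this cleanup, SwiftInfer reads as a thin wrapper over swift.llm. A hedged usage sketch (field names follow SwiftInferArgs as constructed in infer.py above; the model ids are placeholders):

from evalscope.third_party.toolbench_static.llm.swift_infer import SwiftInfer, SwiftInferArgs

args = SwiftInferArgs(
    model_id_or_path='qwen/Qwen-7B-Chat',   # placeholder ModelScope id or local path
    model_type='qwen-7b-chat',              # placeholder ms-swift model type
    max_new_tokens=2048,
)
infer = SwiftInfer(args=args)
reply = infer.predict(system='You are a helpful assistant.', query='List your tools.', history=[])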
--- a/evalscope/third_party/toolbench_static/requirements.txt
+++ b/evalscope/third_party/toolbench_static/requirements.txt
@@ -1,2 +1,2 @@
 ms-swift>=2.1.0
-rouge
+rouge
--- a/evalscope/third_party/toolbench_static/toolbench_static.py
+++ b/evalscope/third_party/toolbench_static/toolbench_static.py
@@ -1,11 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-from typing import Union
 from copy import deepcopy
+from typing import Union
 
-from evalscope.third_party.toolbench_static.infer import InferArgs, run_infer
 from evalscope.third_party.toolbench_static.eval import EvalArgs, run_eval
-from evalscope.
+from evalscope.third_party.toolbench_static.infer import InferArgs, run_infer
+from evalscope.utils import get_logger, json_to_dict, yaml_to_dict
 
 logger = get_logger()
 