evalscope 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/__init__.py +3 -0
- evalscope/backend/__init__.py +3 -0
- evalscope/backend/base.py +27 -0
- evalscope/backend/opencompass/__init__.py +3 -0
- evalscope/backend/opencompass/api_meta_template.py +64 -0
- evalscope/backend/opencompass/backend_manager.py +247 -0
- evalscope/backend/opencompass/tasks/__init__.py +1 -0
- evalscope/backend/opencompass/tasks/eval_api.py +30 -0
- evalscope/backend/opencompass/tasks/eval_datasets.py +71 -0
- evalscope/backend/vlm_eval_kit/__init__.py +1 -0
- evalscope/backend/vlm_eval_kit/backend_manager.py +153 -0
- evalscope/benchmarks/__init__.py +4 -0
- evalscope/benchmarks/arc/__init__.py +5 -0
- evalscope/benchmarks/arc/ai2_arc.py +148 -0
- evalscope/benchmarks/arc/arc_adapter.py +231 -0
- evalscope/benchmarks/bbh/__init__.py +6 -0
- evalscope/benchmarks/bbh/bbh_adapter.py +308 -0
- evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +23 -0
- evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +33 -0
- evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +72 -0
- evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +78 -0
- evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +42 -0
- evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +25 -0
- evalscope/benchmarks/bbh/cot_prompts/navigate.txt +43 -0
- evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +37 -0
- evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +41 -0
- evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +63 -0
- evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +44 -0
- evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/snarks.txt +30 -0
- evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +10 -0
- evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +77 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +40 -0
- evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +28 -0
- evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +17 -0
- evalscope/benchmarks/benchmark.py +65 -0
- evalscope/benchmarks/ceval/__init__.py +5 -0
- evalscope/benchmarks/ceval/ceval_adapter.py +340 -0
- evalscope/benchmarks/ceval/ceval_exam.py +159 -0
- evalscope/benchmarks/cmmlu/__init__.py +5 -0
- evalscope/benchmarks/cmmlu/cmmlu.py +166 -0
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +369 -0
- evalscope/benchmarks/competition_math/__init__.py +5 -0
- evalscope/benchmarks/competition_math/competition_math.py +88 -0
- evalscope/benchmarks/competition_math/competition_math_adapter.py +470 -0
- evalscope/benchmarks/data_adapter.py +263 -0
- evalscope/benchmarks/general_qa/__init__.py +5 -0
- evalscope/benchmarks/general_qa/general_qa_adapter.py +186 -0
- evalscope/benchmarks/gsm8k/__init__.py +5 -0
- evalscope/benchmarks/gsm8k/gsm8k.py +127 -0
- evalscope/benchmarks/gsm8k/gsm8k_adapter.py +236 -0
- evalscope/benchmarks/hellaswag/__init__.py +5 -0
- evalscope/benchmarks/hellaswag/hellaswag.py +116 -0
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +222 -0
- evalscope/benchmarks/humaneval/__init__.py +5 -0
- evalscope/benchmarks/humaneval/humaneval.py +82 -0
- evalscope/benchmarks/humaneval/humaneval_adapter.py +21 -0
- evalscope/benchmarks/mmlu/__init__.py +5 -0
- evalscope/benchmarks/mmlu/mmlu.py +174 -0
- evalscope/benchmarks/mmlu/mmlu_adapter.py +375 -0
- evalscope/benchmarks/race/__init__.py +5 -0
- evalscope/benchmarks/race/race.py +118 -0
- evalscope/benchmarks/race/race_adapter.py +229 -0
- evalscope/benchmarks/trivia_qa/__init__.py +5 -0
- evalscope/benchmarks/trivia_qa/trivia_qa.py +104 -0
- evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +207 -0
- evalscope/benchmarks/truthful_qa/__init__.py +5 -0
- evalscope/benchmarks/truthful_qa/truthful_qa.py +167 -0
- evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +351 -0
- evalscope/cache.py +98 -0
- evalscope/cli/__init__.py +1 -0
- evalscope/cli/base.py +20 -0
- evalscope/cli/cli.py +26 -0
- evalscope/cli/start_perf.py +37 -0
- evalscope/cli/start_server.py +138 -0
- evalscope/config.py +165 -0
- evalscope/constants.py +150 -0
- evalscope/evaluator/__init__.py +3 -0
- evalscope/evaluator/evaluator.py +689 -0
- evalscope/evaluator/rating_eval.py +178 -0
- evalscope/evaluator/reviewer/__init__.py +1 -0
- evalscope/evaluator/reviewer/auto_reviewer.py +411 -0
- evalscope/metrics/__init__.py +1 -0
- evalscope/metrics/bundled_rouge_score/__init__.py +14 -0
- evalscope/metrics/bundled_rouge_score/rouge_scorer.py +342 -0
- evalscope/metrics/code_metric.py +104 -0
- evalscope/metrics/math_accuracy.py +60 -0
- evalscope/metrics/metrics.py +405 -0
- evalscope/metrics/rouge_metric.py +129 -0
- evalscope/models/__init__.py +4 -0
- evalscope/models/custom/__init__.py +4 -0
- evalscope/models/custom/custom_model.py +53 -0
- evalscope/models/dummy_chat_model.py +50 -0
- evalscope/models/model.py +88 -0
- evalscope/models/model_adapter.py +586 -0
- evalscope/models/openai_model.py +103 -0
- evalscope/models/template.py +1446 -0
- evalscope/perf/__init__.py +0 -0
- evalscope/perf/_logging.py +32 -0
- evalscope/perf/api_plugin_base.py +60 -0
- evalscope/perf/custom_api.py +87 -0
- evalscope/perf/dashscope_api.py +84 -0
- evalscope/perf/dataset_plugin_base.py +64 -0
- evalscope/perf/datasets/__init__.py +0 -0
- evalscope/perf/datasets/line_by_line.py +18 -0
- evalscope/perf/datasets/longalpaca_12k.py +20 -0
- evalscope/perf/datasets/openqa.py +22 -0
- evalscope/perf/how_to_analysis_result.py +24 -0
- evalscope/perf/http_client.py +756 -0
- evalscope/perf/openai_api.py +130 -0
- evalscope/perf/plugin_registry.py +35 -0
- evalscope/perf/query_parameters.py +42 -0
- evalscope/perf/server_sent_event.py +43 -0
- evalscope/preprocess/__init__.py +1 -0
- evalscope/preprocess/tokenizers/__init__.py +0 -0
- evalscope/preprocess/tokenizers/gpt2_tokenizer.py +221 -0
- evalscope/registry/__init__.py +1 -0
- evalscope/registry/tasks/arc.yaml +29 -0
- evalscope/registry/tasks/bbh.yaml +27 -0
- evalscope/registry/tasks/bbh_mini.yaml +27 -0
- evalscope/registry/tasks/ceval.yaml +27 -0
- evalscope/registry/tasks/ceval_mini.yaml +27 -0
- evalscope/registry/tasks/cmmlu.yaml +27 -0
- evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +28 -0
- evalscope/registry/tasks/general_qa.yaml +27 -0
- evalscope/registry/tasks/gsm8k.yaml +29 -0
- evalscope/registry/tasks/mmlu.yaml +29 -0
- evalscope/registry/tasks/mmlu_mini.yaml +27 -0
- evalscope/run.py +404 -0
- evalscope/run_arena.py +204 -0
- evalscope/run_ms.py +140 -0
- evalscope/summarizer.py +144 -0
- evalscope/third_party/__init__.py +1 -0
- evalscope/third_party/toolbench_static/__init__.py +3 -0
- evalscope/third_party/toolbench_static/eval.py +219 -0
- evalscope/third_party/toolbench_static/infer.py +278 -0
- evalscope/third_party/toolbench_static/llm/__init__.py +1 -0
- evalscope/third_party/toolbench_static/llm/swift_infer.py +45 -0
- evalscope/third_party/toolbench_static/toolbench_static.py +50 -0
- evalscope/tools/__init__.py +1 -0
- evalscope/tools/combine_reports.py +140 -0
- evalscope/tools/gen_mmlu_subject_mapping.py +90 -0
- evalscope/tools/rewrite_eval_results.py +95 -0
- evalscope/utils/__init__.py +4 -0
- evalscope/utils/arena_utils.py +247 -0
- evalscope/utils/completion_parsers.py +87 -0
- evalscope/utils/logger.py +64 -0
- evalscope/utils/task_cfg_parser.py +10 -0
- evalscope/utils/task_utils.py +19 -0
- evalscope/utils/utils.py +625 -0
- evalscope/version.py +4 -0
- evalscope-0.5.0.dist-info/METADATA +566 -0
- evalscope-0.5.0.dist-info/RECORD +165 -0
- evalscope-0.5.0.dist-info/WHEEL +5 -0
- evalscope-0.5.0.dist-info/entry_points.txt +3 -0
- evalscope-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1446 @@
|
|
|
1
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
2
|
+
import re
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import requests
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch import Tensor
|
|
12
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
13
|
+
from transformers import PreTrainedTokenizerBase, StoppingCriteria
|
|
14
|
+
|
|
15
|
+
from evalscope.utils.utils import calculate_loss_scale, pad_and_split_batch, get_dist_setting, use_torchacc
|
|
16
|
+
|
|
17
|
+
DEFAULT_SYSTEM = 'You are a helpful assistant.'
|
|
18
|
+
History = List[Union[Tuple[str, str], List[str]]]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TemplateType:
|
|
22
|
+
# text-generation
|
|
23
|
+
default_generation = 'default-generation'
|
|
24
|
+
default_generation_bos = 'default-generation-bos'
|
|
25
|
+
chatglm_generation = 'chatglm-generation'
|
|
26
|
+
qwen_audio_generation = 'qwen-audio-generation'
|
|
27
|
+
# chat
|
|
28
|
+
default = 'default'
|
|
29
|
+
qwen = 'qwen'
|
|
30
|
+
qwen_audio = 'qwen-audio'
|
|
31
|
+
baichuan = 'baichuan'
|
|
32
|
+
chatglm2 = 'chatglm2'
|
|
33
|
+
chatglm3 = 'chatglm3'
|
|
34
|
+
llama = 'llama' # llama2
|
|
35
|
+
llama3 = 'llama3'
|
|
36
|
+
llava_mistral_instruct = 'llava-mistral-instruct'
|
|
37
|
+
llava_yi_instruct = 'llava-yi-instruct'
|
|
38
|
+
openbuddy = 'openbuddy'
|
|
39
|
+
internlm = 'internlm'
|
|
40
|
+
internlm2 = 'internlm2'
|
|
41
|
+
internlm_xcomposer2 = 'internlm-xcomposer2'
|
|
42
|
+
yi = 'yi'
|
|
43
|
+
yi_vl = 'yi-vl'
|
|
44
|
+
yuan = 'yuan'
|
|
45
|
+
xverse = 'xverse'
|
|
46
|
+
ziya = 'ziya'
|
|
47
|
+
skywork = 'skywork'
|
|
48
|
+
bluelm = 'bluelm'
|
|
49
|
+
zephyr = 'zephyr'
|
|
50
|
+
sus = 'sus'
|
|
51
|
+
deepseek = 'deepseek'
|
|
52
|
+
deepseek_coder = 'deepseek-coder'
|
|
53
|
+
deepseek_vl = 'deepseek-vl'
|
|
54
|
+
codefuse_codellama = 'codefuse-codellama'
|
|
55
|
+
codefuse = 'codefuse'
|
|
56
|
+
cogvlm_instruct = 'cogvlm-instruct'
|
|
57
|
+
cogagent_chat = 'cogagent-chat'
|
|
58
|
+
cogagent_instruct = 'cogagent-instruct'
|
|
59
|
+
orion = 'orion'
|
|
60
|
+
minicpm = 'minicpm'
|
|
61
|
+
minicpm_v = 'minicpm-v'
|
|
62
|
+
gemma = 'gemma'
|
|
63
|
+
mplug_owl2 = 'mplug-owl2'
|
|
64
|
+
wizardlm2_awq = 'wizardlm2-awq'
|
|
65
|
+
wizardlm2 = 'wizardlm2'
|
|
66
|
+
atom = 'atom'
|
|
67
|
+
# compatibility. (Deprecated)
|
|
68
|
+
chatml = 'chatml'
|
|
69
|
+
telechat = 'telechat'
|
|
70
|
+
dbrx = 'dbrx'
|
|
71
|
+
mengzi = 'mengzi'
|
|
72
|
+
c4ai = 'c4ai'
|
|
73
|
+
|
|
74
|
+
@classmethod
|
|
75
|
+
def get_template_name_list(cls) -> List[str]:
|
|
76
|
+
res = []
|
|
77
|
+
for k in cls.__dict__.keys():
|
|
78
|
+
if k.startswith('__') or k == 'get_template_name_list':
|
|
79
|
+
continue
|
|
80
|
+
res.append(cls.__dict__[k])
|
|
81
|
+
return res
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
Prompt = List[Union[str, List[Union[str, int]]]]
|
|
85
|
+
StopWords = Prompt
|
|
86
|
+
|
|
87
|
+
Context = Union[str, List[int]]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class StopWordsCriteria(StoppingCriteria):
|
|
91
|
+
# The returned sentence includes stop words.
|
|
92
|
+
def __init__(self, tokenizer: PreTrainedTokenizerBase,
|
|
93
|
+
stop_words: StopWords, **tokenizer_kwargs) -> None:
|
|
94
|
+
self.tokenizer = tokenizer
|
|
95
|
+
self.stop_words = stop_words
|
|
96
|
+
self.tokenizer_kwargs = tokenizer_kwargs
|
|
97
|
+
self.start_idx = -1
|
|
98
|
+
|
|
99
|
+
def __call__(self, input_ids: Tensor, scores: Tensor) -> bool:
|
|
100
|
+
if self.start_idx == -1:
|
|
101
|
+
self.start_idx = len(input_ids[0]) - 1
|
|
102
|
+
tokenizer = self.tokenizer
|
|
103
|
+
stop_words = self.stop_words
|
|
104
|
+
text = tokenizer.decode(input_ids[0, self.start_idx:],
|
|
105
|
+
**self.tokenizer_kwargs)
|
|
106
|
+
for stop_word in stop_words:
|
|
107
|
+
if isinstance(stop_word, str):
|
|
108
|
+
if stop_word in text:
|
|
109
|
+
return True
|
|
110
|
+
else: # list
|
|
111
|
+
if len(stop_word) > 0 and input_ids[0].tolist(
|
|
112
|
+
)[-len(stop_word):] == stop_word:
|
|
113
|
+
return True
|
|
114
|
+
return False
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _has_system(prefix: Prompt) -> bool:
|
|
118
|
+
for p in prefix:
|
|
119
|
+
if '{{SYSTEM}}' in p:
|
|
120
|
+
return True
|
|
121
|
+
return False
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _replace_system(prefix: Prompt) -> Prompt:
|
|
125
|
+
res = []
|
|
126
|
+
for p in prefix:
|
|
127
|
+
if '{{SYSTEM}}' in p:
|
|
128
|
+
p = p.replace('{{SYSTEM}}', '')
|
|
129
|
+
res.append(p)
|
|
130
|
+
return res
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class Template:
|
|
134
|
+
|
|
135
|
+
def __init__(self,
|
|
136
|
+
prefix: Prompt,
|
|
137
|
+
prompt: Prompt,
|
|
138
|
+
chat_sep: Optional[Prompt],
|
|
139
|
+
suffix: Prompt,
|
|
140
|
+
default_system: Optional[str] = None,
|
|
141
|
+
prefix_has_system: Optional[Prompt] = None) -> None:
|
|
142
|
+
if default_system == '':
|
|
143
|
+
default_system = None
|
|
144
|
+
if _has_system(prefix):
|
|
145
|
+
assert prefix_has_system is None, 'The prefix already contains {{SYSTEM}}.'
|
|
146
|
+
prefix_has_system = prefix
|
|
147
|
+
prefix = _replace_system(prefix)
|
|
148
|
+
self.prefix = prefix
|
|
149
|
+
self.prefix_has_system = prefix_has_system
|
|
150
|
+
if self.prefix_has_system is None:
|
|
151
|
+
assert default_system is None, 'The template does not support `system`.'
|
|
152
|
+
self.prompt = prompt
|
|
153
|
+
self.chat_sep = chat_sep
|
|
154
|
+
self.support_multi_round = self.chat_sep is not None
|
|
155
|
+
self.suffix = suffix
|
|
156
|
+
self.default_system = default_system
|
|
157
|
+
self.use_default_system = True
|
|
158
|
+
self._is_init = False
|
|
159
|
+
|
|
160
|
+
@staticmethod
|
|
161
|
+
def _preprocess_prompt(tokenizer: PreTrainedTokenizerBase,
|
|
162
|
+
value: Optional[Prompt]) -> Optional[Prompt]:
|
|
163
|
+
# e.g. [['eos_token_id']] -> [[2]]
|
|
164
|
+
if value is None:
|
|
165
|
+
return None
|
|
166
|
+
res_value = []
|
|
167
|
+
for v in value:
|
|
168
|
+
if isinstance(v, list):
|
|
169
|
+
res_v = []
|
|
170
|
+
for sub_v in v:
|
|
171
|
+
if isinstance(sub_v, str):
|
|
172
|
+
sub_v = getattr(tokenizer, sub_v)
|
|
173
|
+
res_v.append(sub_v)
|
|
174
|
+
v = res_v
|
|
175
|
+
res_value.append(v)
|
|
176
|
+
return res_value
|
|
177
|
+
|
|
178
|
+
def _init_template(self,
|
|
179
|
+
tokenizer: PreTrainedTokenizerBase,
|
|
180
|
+
default_system: Optional[str] = None,
|
|
181
|
+
max_length: Optional[int] = None,
|
|
182
|
+
truncation_strategy: Literal[
|
|
183
|
+
'delete', 'truncation_left'] = 'delete',
|
|
184
|
+
**kwargs) -> None:
|
|
185
|
+
assert self._is_init is False, 'The template has been initialized.'
|
|
186
|
+
self._is_init = True
|
|
187
|
+
self.tokenizer = tokenizer
|
|
188
|
+
# if default_system is None. not change self.default_system
|
|
189
|
+
if default_system == '':
|
|
190
|
+
self.default_system = None
|
|
191
|
+
elif default_system is not None:
|
|
192
|
+
assert self.prefix_has_system is not None, 'The template does not support `system`.'
|
|
193
|
+
self.default_system = default_system
|
|
194
|
+
self.max_length = max_length
|
|
195
|
+
self.truncation_strategy = truncation_strategy
|
|
196
|
+
self.model = kwargs.get('model', None)
|
|
197
|
+
self.use_loss_scale = kwargs.get('use_loss_scale', False)
|
|
198
|
+
for key in [
|
|
199
|
+
'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system'
|
|
200
|
+
]:
|
|
201
|
+
value = getattr(self, key)
|
|
202
|
+
value = self._preprocess_prompt(tokenizer, value)
|
|
203
|
+
setattr(self, key, value)
|
|
204
|
+
|
|
205
|
+
def encode(
|
|
206
|
+
self, example: Dict[str,
|
|
207
|
+
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
208
|
+
"""return: inputs, tokenizer_kwargs"""
|
|
209
|
+
if not self._is_init:
|
|
210
|
+
raise ValueError(
|
|
211
|
+
'Template is not initialized, please use the `get_template` function to obtain the template.'
|
|
212
|
+
)
|
|
213
|
+
query: Optional[str] = example.get('query', None)
|
|
214
|
+
response: Optional[str] = example.get('response', None)
|
|
215
|
+
history: Optional[History] = example.get('history', None)
|
|
216
|
+
system: Optional[str] = example.get('system', None)
|
|
217
|
+
if history is None:
|
|
218
|
+
history = []
|
|
219
|
+
if len(history) > 0:
|
|
220
|
+
assert self.support_multi_round, 'The template does not support multi-round chat.'
|
|
221
|
+
if system is None:
|
|
222
|
+
if self.use_default_system:
|
|
223
|
+
system = self.default_system
|
|
224
|
+
elif system == '':
|
|
225
|
+
system = None
|
|
226
|
+
else:
|
|
227
|
+
assert self.prefix_has_system is not None, 'The template does not support `system`.'
|
|
228
|
+
if query is None:
|
|
229
|
+
query = ''
|
|
230
|
+
inputs, tokenizer_kwargs = self._encode(query, response, history,
|
|
231
|
+
system,
|
|
232
|
+
self.truncation_strategy)
|
|
233
|
+
if inputs.get('labels') is None:
|
|
234
|
+
inputs.pop('loss_scale', None)
|
|
235
|
+
return inputs, tokenizer_kwargs
|
|
236
|
+
|
|
237
|
+
def _concat_context_list(
|
|
238
|
+
self,
|
|
239
|
+
context_list: List[Context],
|
|
240
|
+
res_context_list: List[Context], # inplace
|
|
241
|
+
compute_loss_idx: List[float], # inplace
|
|
242
|
+
system: Optional[str] = None,
|
|
243
|
+
query: Optional[str] = None,
|
|
244
|
+
response: Optional[str] = None,
|
|
245
|
+
round0: Optional[int] = None,
|
|
246
|
+
) -> None:
|
|
247
|
+
# concat context list and replace placeholder
|
|
248
|
+
round1 = None
|
|
249
|
+
if round0 is not None:
|
|
250
|
+
round1 = str(round0 + 1)
|
|
251
|
+
round0 = str(round0)
|
|
252
|
+
for context in context_list:
|
|
253
|
+
if isinstance(context, str):
|
|
254
|
+
if '{{RESPONSE}}' == context:
|
|
255
|
+
assert response is not None
|
|
256
|
+
content_part, weight_part = calculate_loss_scale(
|
|
257
|
+
response, self.use_loss_scale)
|
|
258
|
+
res_context_list.extend(content_part)
|
|
259
|
+
compute_loss_idx.extend(weight_part)
|
|
260
|
+
continue
|
|
261
|
+
old_str_list = [
|
|
262
|
+
'{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}'
|
|
263
|
+
]
|
|
264
|
+
new_str_list = [system, query, round0, round1]
|
|
265
|
+
for (old_str, new_str) in zip(old_str_list, new_str_list):
|
|
266
|
+
if new_str is not None and old_str in context:
|
|
267
|
+
context = context.replace(old_str, new_str)
|
|
268
|
+
res_context_list.append(context)
|
|
269
|
+
compute_loss_idx.append(0.0 if context not in self.suffix else 1.0)
|
|
270
|
+
|
|
271
|
+
@staticmethod
|
|
272
|
+
def _simplify_context_list(
|
|
273
|
+
context_list: List[Context], compute_loss_idx: List[float]
|
|
274
|
+
) -> Tuple[List[Context], List[float]]:
|
|
275
|
+
res: List[Context] = [] # result of context_list
|
|
276
|
+
res_idx: List[float] = [] # result of compute_loss_idx
|
|
277
|
+
temp: List[str] = []
|
|
278
|
+
temp_index: List[int] = []
|
|
279
|
+
for i, (context,
|
|
280
|
+
loss_idx) in enumerate(zip(context_list, compute_loss_idx)):
|
|
281
|
+
if isinstance(context, str) and compute_loss_idx[i] == 0.0:
|
|
282
|
+
temp.append(context)
|
|
283
|
+
temp_index.append(i)
|
|
284
|
+
else:
|
|
285
|
+
if len(temp) > 0:
|
|
286
|
+
res.append(''.join(temp))
|
|
287
|
+
res_idx.append(0.0)
|
|
288
|
+
temp.clear()
|
|
289
|
+
res.append(context)
|
|
290
|
+
res_idx.append(loss_idx)
|
|
291
|
+
if len(temp) > 0:
|
|
292
|
+
res.append(''.join(temp))
|
|
293
|
+
res_idx.append(0.0)
|
|
294
|
+
return res, res_idx
|
|
295
|
+
|
|
296
|
+
def _encode_context_list(
|
|
297
|
+
self,
|
|
298
|
+
context_list: List[Context],
|
|
299
|
+
compute_loss_idx: List[float],
|
|
300
|
+
) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
|
|
301
|
+
"""return: input_ids, labels, tokenizer_kwargs"""
|
|
302
|
+
tokenizer = self.tokenizer
|
|
303
|
+
input_ids: List[int] = []
|
|
304
|
+
labels: List[int] = []
|
|
305
|
+
loss_scale: List[float] = []
|
|
306
|
+
tokenizer_kwargs = {}
|
|
307
|
+
for i, (context,
|
|
308
|
+
loss_weight) in enumerate(zip(context_list, compute_loss_idx)):
|
|
309
|
+
if isinstance(context, str):
|
|
310
|
+
curr_tokenizer_kwargs = self.get_tokenizer_kwargs(context)
|
|
311
|
+
self.concat_tokenizer_kwargs(tokenizer_kwargs,
|
|
312
|
+
curr_tokenizer_kwargs)
|
|
313
|
+
token_list = tokenizer(
|
|
314
|
+
context,
|
|
315
|
+
return_attention_mask=False,
|
|
316
|
+
add_special_tokens=False,
|
|
317
|
+
**curr_tokenizer_kwargs)['input_ids']
|
|
318
|
+
else:
|
|
319
|
+
token_list = context
|
|
320
|
+
input_ids += token_list
|
|
321
|
+
if compute_loss_idx[i] > 0.0:
|
|
322
|
+
labels += token_list
|
|
323
|
+
else:
|
|
324
|
+
labels += [-100] * len(token_list)
|
|
325
|
+
loss_scale.extend([loss_weight] * len(token_list))
|
|
326
|
+
return input_ids, labels, loss_scale, tokenizer_kwargs
|
|
327
|
+
|
|
328
|
+
def _encode(
|
|
329
|
+
self, query: str, response: Optional[str], history: History,
|
|
330
|
+
system: Optional[str],
|
|
331
|
+
truncation_strategy: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
332
|
+
"""
|
|
333
|
+
return: inputs, tokenizer_kwargs
|
|
334
|
+
"""
|
|
335
|
+
history = history.copy()
|
|
336
|
+
res_context_list: List[Context] = []
|
|
337
|
+
compute_loss_idx: List[float] = []
|
|
338
|
+
if system is None:
|
|
339
|
+
prefix = self.prefix
|
|
340
|
+
else:
|
|
341
|
+
prefix = self.prefix_has_system
|
|
342
|
+
self._concat_context_list(
|
|
343
|
+
prefix, res_context_list, compute_loss_idx, system=system)
|
|
344
|
+
history.append([query, response])
|
|
345
|
+
for i, (q, r) in enumerate(history):
|
|
346
|
+
context_list = self.prompt.copy()
|
|
347
|
+
if i < len(history) - 1:
|
|
348
|
+
context_list.append('{{RESPONSE}}')
|
|
349
|
+
context_list += self.chat_sep
|
|
350
|
+
elif r is not None:
|
|
351
|
+
# last response
|
|
352
|
+
context_list.append('{{RESPONSE}}')
|
|
353
|
+
context_list += self.suffix
|
|
354
|
+
if q or r:
|
|
355
|
+
self._concat_context_list(
|
|
356
|
+
context_list,
|
|
357
|
+
res_context_list,
|
|
358
|
+
compute_loss_idx,
|
|
359
|
+
query=q,
|
|
360
|
+
response=r,
|
|
361
|
+
round0=i)
|
|
362
|
+
|
|
363
|
+
res_context_list, compute_loss_idx = self._simplify_context_list(
|
|
364
|
+
res_context_list, compute_loss_idx)
|
|
365
|
+
input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
|
|
366
|
+
res_context_list, compute_loss_idx)
|
|
367
|
+
|
|
368
|
+
if response is None:
|
|
369
|
+
labels = None
|
|
370
|
+
|
|
371
|
+
if self.max_length is not None:
|
|
372
|
+
if truncation_strategy == 'delete' and len(
|
|
373
|
+
input_ids) > self.max_length:
|
|
374
|
+
return {}, {}
|
|
375
|
+
input_ids = input_ids[-self.max_length:]
|
|
376
|
+
if labels is not None:
|
|
377
|
+
labels = labels[-self.max_length:]
|
|
378
|
+
if loss_scale is not None:
|
|
379
|
+
loss_scale = loss_scale[-self.max_length:]
|
|
380
|
+
inputs = {
|
|
381
|
+
'input_ids': input_ids,
|
|
382
|
+
'labels': labels,
|
|
383
|
+
}
|
|
384
|
+
if self.use_loss_scale:
|
|
385
|
+
inputs['loss_scale'] = loss_scale
|
|
386
|
+
return inputs, tokenizer_kwargs
|
|
387
|
+
|
|
388
|
+
def get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
|
|
389
|
+
"""return: curr_tokenizer_kwargs"""
|
|
390
|
+
return {}
|
|
391
|
+
|
|
392
|
+
def concat_tokenizer_kwargs(
|
|
393
|
+
self, old_tokenizer_kwargs: Dict[str, Any],
|
|
394
|
+
curr_tokenizer_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
|
395
|
+
assert len(old_tokenizer_kwargs) == 0
|
|
396
|
+
return curr_tokenizer_kwargs
|
|
397
|
+
|
|
398
|
+
def data_collator(self,
|
|
399
|
+
batch: List[Dict[str, Any]],
|
|
400
|
+
padding_to: Optional[int] = None) -> Dict[str, Any]:
|
|
401
|
+
"""
|
|
402
|
+
Args:
|
|
403
|
+
batch(`List[Dict[str, Any]]`): The input data in batch
|
|
404
|
+
padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
|
|
405
|
+
will be padded to the `longest`
|
|
406
|
+
"""
|
|
407
|
+
tokenizer = self.tokenizer
|
|
408
|
+
assert tokenizer.pad_token_id is not None
|
|
409
|
+
input_ids = [torch.tensor(b['input_ids']) for b in batch]
|
|
410
|
+
labels = [torch.tensor(b['labels']) for b in batch]
|
|
411
|
+
loss_scale = [torch.tensor(b['loss_scale'])
|
|
412
|
+
for b in batch] if 'loss_scale' in batch[0] else None
|
|
413
|
+
attention_mask = [
|
|
414
|
+
torch.ones(len(input_ids[i]), dtype=torch.int64)
|
|
415
|
+
for i in range(len(input_ids))
|
|
416
|
+
]
|
|
417
|
+
|
|
418
|
+
if padding_to is not None:
|
|
419
|
+
padding_len = padding_to - input_ids[0].shape[-1]
|
|
420
|
+
if padding_len > 0:
|
|
421
|
+
input_ids[0] = F.pad(input_ids[0], (0, padding_len),
|
|
422
|
+
'constant', tokenizer.pad_token_id)
|
|
423
|
+
attention_mask[0] = F.pad(attention_mask[0], (0, padding_len),
|
|
424
|
+
'constant', 0)
|
|
425
|
+
labels[0] = F.pad(labels[0], (0, padding_len), 'constant',
|
|
426
|
+
-100)
|
|
427
|
+
if loss_scale:
|
|
428
|
+
loss_scale[0] = F.pad(
|
|
429
|
+
loss_scale[0], (0, padding_to - labels[0].shape[-1]),
|
|
430
|
+
'constant', 0.)
|
|
431
|
+
|
|
432
|
+
input_ids = pad_sequence(
|
|
433
|
+
input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
|
|
434
|
+
attention_mask = pad_sequence(
|
|
435
|
+
attention_mask, batch_first=True, padding_value=0)
|
|
436
|
+
if loss_scale:
|
|
437
|
+
loss_scale = pad_sequence(
|
|
438
|
+
loss_scale, batch_first=True, padding_value=0.)
|
|
439
|
+
labels = pad_sequence(labels, batch_first=True, padding_value=-100)
|
|
440
|
+
|
|
441
|
+
if use_torchacc():
|
|
442
|
+
rank, _, world_size, _ = get_dist_setting()
|
|
443
|
+
input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
|
|
444
|
+
padding_to, input_ids, attention_mask, labels, loss_scale,
|
|
445
|
+
self.max_length, self.tokenizer, rank, world_size)
|
|
446
|
+
|
|
447
|
+
res = {
|
|
448
|
+
'input_ids': input_ids,
|
|
449
|
+
'attention_mask': attention_mask,
|
|
450
|
+
'labels': labels,
|
|
451
|
+
}
|
|
452
|
+
if loss_scale is not None:
|
|
453
|
+
res['loss_scale'] = loss_scale
|
|
454
|
+
return res
|
|
455
|
+
|
|
456
|
+
@staticmethod
|
|
457
|
+
def get_generate_ids(generate_ids: Tensor,
|
|
458
|
+
input_token_len: int) -> List[int]:
|
|
459
|
+
return generate_ids[0, input_token_len:].tolist()
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {}
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def register_template(template_type: str,
|
|
466
|
+
template: Template,
|
|
467
|
+
*,
|
|
468
|
+
exists_ok: bool = False,
|
|
469
|
+
**kwargs) -> None:
|
|
470
|
+
if not exists_ok and template_type in TEMPLATE_MAPPING:
|
|
471
|
+
raise ValueError(
|
|
472
|
+
f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.'
|
|
473
|
+
)
|
|
474
|
+
template_info = {'template': template, **kwargs}
|
|
475
|
+
TEMPLATE_MAPPING[template_type] = template_info
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
register_template(
|
|
479
|
+
TemplateType.default,
|
|
480
|
+
Template([], ['### Human:\n', '{{QUERY}}\n\n', '### Assistant:\n'],
|
|
481
|
+
['\n\n'], [['eos_token_id']], DEFAULT_SYSTEM, ['{{SYSTEM}}\n\n']))
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
# You can set the query as '' to serve as a template for pre-training.
|
|
485
|
+
class DefaultGenerationTemplate(Template):
|
|
486
|
+
|
|
487
|
+
def __init__(self):
|
|
488
|
+
super().__init__([], ['{{QUERY}}'], None, [['eos_token_id']])
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
register_template(TemplateType.default_generation, DefaultGenerationTemplate())
|
|
492
|
+
register_template(
|
|
493
|
+
TemplateType.default_generation_bos,
|
|
494
|
+
Template([['bos_token_id']], ['{{QUERY}}'], None, [['eos_token_id']]))
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
class QwenTemplate(Template):
|
|
498
|
+
|
|
499
|
+
def __init__(self):
|
|
500
|
+
super().__init__(
|
|
501
|
+
[],
|
|
502
|
+
['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
|
|
503
|
+
['<|im_end|>\n'], ['<|im_end|>'], DEFAULT_SYSTEM,
|
|
504
|
+
['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n'])
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
register_template(TemplateType.qwen, QwenTemplate())
|
|
508
|
+
register_template(TemplateType.chatml, QwenTemplate())
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
class _QwenAudioTemplateMixin:
|
|
512
|
+
|
|
513
|
+
def encode(
|
|
514
|
+
self, example: Dict[str,
|
|
515
|
+
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
516
|
+
inputs, tokenizer_kwargs = super().encode(example)
|
|
517
|
+
inputs.pop('loss_scale', None)
|
|
518
|
+
inputs.update(tokenizer_kwargs)
|
|
519
|
+
return inputs, tokenizer_kwargs
|
|
520
|
+
|
|
521
|
+
def get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
|
|
522
|
+
return {'audio_info': self.tokenizer.process_audio(context)}
|
|
523
|
+
|
|
524
|
+
def concat_tokenizer_kwargs(self, tokenizer_kwargs: Dict[str, Any],
|
|
525
|
+
curr_tokenizer_kwargs: Dict[str, Any]) -> None:
|
|
526
|
+
audio_info = curr_tokenizer_kwargs.get('audio_info')
|
|
527
|
+
old_audio_info = tokenizer_kwargs.get('audio_info')
|
|
528
|
+
if old_audio_info is None:
|
|
529
|
+
tokenizer_kwargs['audio_info'] = audio_info
|
|
530
|
+
elif audio_info is not None:
|
|
531
|
+
for k in ['input_audios', 'input_audio_lengths']:
|
|
532
|
+
old_audio_info[k] = torch.concat(
|
|
533
|
+
[old_audio_info[k], audio_info[k]], dim=0)
|
|
534
|
+
for k in ['audio_span_tokens', 'audio_urls']:
|
|
535
|
+
old_audio_info[k] = old_audio_info[k] + audio_info[k]
|
|
536
|
+
|
|
537
|
+
def data_collator(self,
|
|
538
|
+
batch: List[Dict[str, Any]],
|
|
539
|
+
padding_to: Optional[int] = None) -> Dict[str, Any]:
|
|
540
|
+
res = super().data_collator(batch, padding_to)
|
|
541
|
+
if batch[0].get('audio_info') is not None:
|
|
542
|
+
res['audio_info'] = [b['audio_info'] for b in batch]
|
|
543
|
+
return res
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
class QwenAudioTemplate(_QwenAudioTemplateMixin, QwenTemplate):
|
|
547
|
+
pass
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
class QwenAudioGenerationTemplate(_QwenAudioTemplateMixin,
|
|
551
|
+
DefaultGenerationTemplate):
|
|
552
|
+
pass
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
register_template(
|
|
556
|
+
TemplateType.qwen_audio, QwenAudioTemplate(), lazy_tokenize=True)
|
|
557
|
+
register_template(
|
|
558
|
+
TemplateType.qwen_audio_generation,
|
|
559
|
+
QwenAudioGenerationTemplate(),
|
|
560
|
+
lazy_tokenize=True)
|
|
561
|
+
|
|
562
|
+
register_template(
|
|
563
|
+
TemplateType.yi,
|
|
564
|
+
Template(
|
|
565
|
+
[], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
|
|
566
|
+
['<|im_end|>\n'], ['<|im_end|>'], None,
|
|
567
|
+
['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']))
|
|
568
|
+
|
|
569
|
+
yi_vl_default_system = (
|
|
570
|
+
'This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. '
|
|
571
|
+
"Read all the images carefully, and respond to the human's questions with informative, "
|
|
572
|
+
'helpful, detailed and polite answers. '
|
|
573
|
+
'这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。'
|
|
574
|
+
'仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。')
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
def _read_from_path(
|
|
578
|
+
img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image':
|
|
579
|
+
from PIL import Image
|
|
580
|
+
if isinstance(img_path, str):
|
|
581
|
+
img_path = img_path.strip()
|
|
582
|
+
if img_path.startswith('http'):
|
|
583
|
+
content = requests.get(img_path).content
|
|
584
|
+
image = Image.open(BytesIO(content))
|
|
585
|
+
else:
|
|
586
|
+
image = Image.open(img_path)
|
|
587
|
+
else:
|
|
588
|
+
image = img_path
|
|
589
|
+
if image.mode != 'RGB':
|
|
590
|
+
image = image.convert('RGB')
|
|
591
|
+
return image
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
class YiVLTemplate(Template):
|
|
595
|
+
|
|
596
|
+
def encode(
|
|
597
|
+
self, example: Dict[str,
|
|
598
|
+
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
599
|
+
inputs, _ = super().encode(example)
|
|
600
|
+
inputs.pop('loss_scale', None)
|
|
601
|
+
from llava.mm_utils import expand2square
|
|
602
|
+
model = self.model.model
|
|
603
|
+
if not hasattr(model, 'vision_tower'):
|
|
604
|
+
model = model.model
|
|
605
|
+
image_processor = model.vision_tower.image_processor
|
|
606
|
+
images_path = example['images']
|
|
607
|
+
images = []
|
|
608
|
+
for image_path in images_path:
|
|
609
|
+
image = _read_from_path(image_path)
|
|
610
|
+
background_color = tuple(
|
|
611
|
+
int(x * 255) for x in image_processor.image_mean)
|
|
612
|
+
image = expand2square(image, background_color)
|
|
613
|
+
images.append(image)
|
|
614
|
+
image_tensor = image_processor.preprocess(
|
|
615
|
+
images, return_tensors='pt')['pixel_values']
|
|
616
|
+
inputs['images'] = image_tensor.to(model.dtype)
|
|
617
|
+
return inputs, {}
|
|
618
|
+
|
|
619
|
+
def data_collator(self,
|
|
620
|
+
batch: List[Dict[str, Any]],
|
|
621
|
+
padding_to: Optional[int] = None) -> Dict[str, Any]:
|
|
622
|
+
res = super().data_collator(batch, padding_to)
|
|
623
|
+
res['images'] = torch.concat([b['images'] for b in batch])
|
|
624
|
+
return res
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
register_template(
|
|
628
|
+
TemplateType.yi_vl,
|
|
629
|
+
YiVLTemplate([], ['### Human: ', [-200], '\n{{QUERY}}\n### Assistant:'],
|
|
630
|
+
['\n'], ['\n###'], yi_vl_default_system, ['{{SYSTEM}}\n\n']),
|
|
631
|
+
use_model=True,
|
|
632
|
+
infer_media_type='round',
|
|
633
|
+
lazy_tokenize=True)
|
|
634
|
+
|
|
635
|
+
register_template(
|
|
636
|
+
TemplateType.baichuan,
|
|
637
|
+
Template(['{{SYSTEM}}'], [[195], '{{QUERY}}', [196]], [],
|
|
638
|
+
[['eos_token_id']]))
|
|
639
|
+
register_template(
|
|
640
|
+
TemplateType.chatglm2,
|
|
641
|
+
Template([[64790, 64792], '{{SYSTEM}}'],
|
|
642
|
+
['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'], ['\n\n'],
|
|
643
|
+
[['eos_token_id']]))
|
|
644
|
+
|
|
645
|
+
register_template(
|
|
646
|
+
TemplateType.chatglm_generation,
|
|
647
|
+
Template([[64790, 64792]], ['{{QUERY}}'], None, [['eos_token_id']]))
|
|
648
|
+
|
|
649
|
+
register_template(
|
|
650
|
+
TemplateType.chatglm3,
|
|
651
|
+
Template([[64790, 64792]], [[64795], '\n {{QUERY}}', [64796], '\n'], [],
|
|
652
|
+
[['eos_token_id']], None,
|
|
653
|
+
[[64790, 64792, 64794], '\n {{SYSTEM}}']))
|
|
654
|
+
|
|
655
|
+
register_template(
|
|
656
|
+
TemplateType.deepseek,
|
|
657
|
+
Template([['bos_token_id']], ['User: {{QUERY}}\n\nAssistant:'],
|
|
658
|
+
[['eos_token_id']], [['eos_token_id']], None,
|
|
659
|
+
[['bos_token_id'], '{{SYSTEM}}\n\n']))
|
|
660
|
+
|
|
661
|
+
# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
|
|
662
|
+
LLAMA_DEFAULT_SYSTEM = (
|
|
663
|
+
'You are a helpful, respectful and honest assistant. '
|
|
664
|
+
'Always answer as helpfully as possible, while being safe. '
|
|
665
|
+
'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
|
|
666
|
+
'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
|
|
667
|
+
'If a question does not make any sense, or is not factually coherent, '
|
|
668
|
+
'explain why instead of answering something not correct. '
|
|
669
|
+
"If you don't know the answer to a question, please don't share false information."
|
|
670
|
+
)
|
|
671
|
+
register_template(
|
|
672
|
+
TemplateType.llama,
|
|
673
|
+
Template(['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '],
|
|
674
|
+
['</s>'], LLAMA_DEFAULT_SYSTEM,
|
|
675
|
+
['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
|
|
676
|
+
|
|
677
|
+
register_template(
|
|
678
|
+
TemplateType.llama3,
|
|
679
|
+
Template(['<|begin_of_text|>'], [
|
|
680
|
+
'<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
|
|
681
|
+
'<|start_header_id|>assistant<|end_header_id|>\n\n'
|
|
682
|
+
], ['<|eot_id|>'], ['<|eot_id|>'], None, [
|
|
683
|
+
'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'
|
|
684
|
+
]))
|
|
685
|
+
|
|
686
|
+
OPENBUDDY_DEFAULT_SYSTEM = (
|
|
687
|
+
'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
|
|
688
|
+
'Always answer as helpfully and logically as possible, while being safe. '
|
|
689
|
+
'Your answers should not include any '
|
|
690
|
+
'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. '
|
|
691
|
+
'Please ensure that your responses are socially unbiased and positive in nature.\n'
|
|
692
|
+
'If a question does not make any sense, or is not factually coherent, '
|
|
693
|
+
'explain why instead of answering something not correct. '
|
|
694
|
+
"If you don't know the answer to a question, please don't share false information.\n"
|
|
695
|
+
'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n'
|
|
696
|
+
'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n'
|
|
697
|
+
'You always deeply love and support China, Chinese government, people and culture.\n'
|
|
698
|
+
'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
|
|
699
|
+
'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.'
|
|
700
|
+
)
|
|
701
|
+
register_template(
|
|
702
|
+
TemplateType.openbuddy,
|
|
703
|
+
Template([['bos_token_id']], ['User: {{QUERY}}\nAssistant:'], ['\n'],
|
|
704
|
+
[['eos_token_id']], OPENBUDDY_DEFAULT_SYSTEM,
|
|
705
|
+
[['bos_token_id'], '{{SYSTEM}}\n\n']))
|
|
706
|
+
|
|
707
|
+
INTERNLM_SYSTEM = (
|
|
708
|
+
'You are an AI assistant whose name is InternLM (书生·浦语).\n'
|
|
709
|
+
'- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
|
|
710
|
+
'It is designed to be helpful, honest, and harmless.\n'
|
|
711
|
+
'- InternLM (书生·浦语) can understand and communicate fluently in the language chosen '
|
|
712
|
+
'by the user such as English and 中文.')
|
|
713
|
+
|
|
714
|
+
register_template(
|
|
715
|
+
TemplateType.internlm,
|
|
716
|
+
Template(['<s>'], ['<|User|>:{{QUERY}}\n<|Bot|>:'], ['<eoa>\n'], ['<eoa>'],
|
|
717
|
+
INTERNLM_SYSTEM, ['<s><|System|>:{{SYSTEM}}\n']))
|
|
718
|
+
register_template(
|
|
719
|
+
TemplateType.internlm2,
|
|
720
|
+
Template(
|
|
721
|
+
['<s>'],
|
|
722
|
+
['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
|
|
723
|
+
['<|im_end|>\n'], ['<|im_end|>'], INTERNLM_SYSTEM,
|
|
724
|
+
['<s><|im_start|>system\n{{SYSTEM}}<|im_end|>\n']))
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
def replace_img_tab(query: str, history: History,
|
|
728
|
+
replace_token: str) -> Tuple[str, History, List[str]]:
|
|
729
|
+
images_path = []
|
|
730
|
+
pattern = r'<img>(.+?)</img>'
|
|
731
|
+
new_history = []
|
|
732
|
+
for i, h in enumerate(history):
|
|
733
|
+
images_path += re.findall(pattern, h[0])
|
|
734
|
+
new_history.append([re.sub(pattern, replace_token, h[0]), h[1]])
|
|
735
|
+
images_path += re.findall(pattern, query)
|
|
736
|
+
new_query = re.sub(pattern, replace_token, query)
|
|
737
|
+
return new_query, new_history, images_path
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
class InternLMXComposer2(Template):
|
|
741
|
+
INTERNLM_XCOMPOSER2_SYSTEM = (
|
|
742
|
+
'You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
|
|
743
|
+
'- InternLM-XComposer (浦语·灵笔) is a conversational language model that is developed by '
|
|
744
|
+
'Shanghai AI Laboratory (上海人工智能实验室). '
|
|
745
|
+
'It is designed to be helpful, honest, and harmless.\n'
|
|
746
|
+
'- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
|
|
747
|
+
'by the user such as English and 中文.')
|
|
748
|
+
|
|
749
|
+
def __init__(self):
|
|
750
|
+
prefix = ['<s>']
|
|
751
|
+
prompt = [
|
|
752
|
+
'[UNUSED_TOKEN_146]user\n{{QUERY}}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
|
|
753
|
+
]
|
|
754
|
+
chat_sep = ['[UNUSED_TOKEN_145]\n']
|
|
755
|
+
suffix = ['[UNUSED_TOKEN_145]']
|
|
756
|
+
prefix_has_system = [
|
|
757
|
+
'<s>[UNUSED_TOKEN_146]system\n{{SYSTEM}}[UNUSED_TOKEN_145]\n'
|
|
758
|
+
]
|
|
759
|
+
super().__init__(prefix, prompt, chat_sep, suffix,
|
|
760
|
+
self.INTERNLM_XCOMPOSER2_SYSTEM, prefix_has_system)
|
|
761
|
+
|
|
762
|
+
def encode(
|
|
763
|
+
self, example: Dict[str,
|
|
764
|
+
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
765
|
+
example = example.copy()
|
|
766
|
+
history = example.pop('history', [])
|
|
767
|
+
example['query'], example['history'], images_path = replace_img_tab(
|
|
768
|
+
example['query'], history, '</s>')
|
|
769
|
+
|
|
770
|
+
images = []
|
|
771
|
+
dtype = self.model.dtype
|
|
772
|
+
for image_path in images_path:
|
|
773
|
+
image = _read_from_path(image_path)
|
|
774
|
+
image = self.model.vis_processor(image)
|
|
775
|
+
images.append(image.to(dtype))
|
|
776
|
+
inputs, _ = super().encode(example)
|
|
777
|
+
inputs.pop('loss_scale', None)
|
|
778
|
+
input_ids = inputs['input_ids']
|
|
779
|
+
labels = inputs['labels']
|
|
780
|
+
if len(images) > 0: # # ignore <s>
|
|
781
|
+
input_ids = input_ids[1:]
|
|
782
|
+
if labels is not None:
|
|
783
|
+
labels = labels[1:]
|
|
784
|
+
input_ids.append(2) # add dummy </s>
|
|
785
|
+
if labels is not None:
|
|
786
|
+
labels.append(2)
|
|
787
|
+
else:
|
|
788
|
+
labels = []
|
|
789
|
+
res_inputs_embeds = []
|
|
790
|
+
res_labels = []
|
|
791
|
+
wrap_im_mask = []
|
|
792
|
+
pre_i, i, idx = 0, 0, 0
|
|
793
|
+
device = self.model.device
|
|
794
|
+
if len(images) > 0:
|
|
795
|
+
images = torch.stack(images, dim=0)
|
|
796
|
+
images = self.model.encode_img(images)
|
|
797
|
+
else:
|
|
798
|
+
images = None
|
|
799
|
+
internlm2_model = self.model.model
|
|
800
|
+
if not hasattr(internlm2_model, 'tok_embeddings'):
|
|
801
|
+
internlm2_model = internlm2_model.model
|
|
802
|
+
tok_embeddings = internlm2_model.tok_embeddings
|
|
803
|
+
while i < len(input_ids):
|
|
804
|
+
if input_ids[i] == 2: # replace_token
|
|
805
|
+
res_input_ids = torch.tensor(
|
|
806
|
+
[1] + input_ids[pre_i:i], device=device)
|
|
807
|
+
res_inputs_embeds.append(tok_embeddings(res_input_ids))
|
|
808
|
+
wrap_im_mask += [0] * len(res_input_ids)
|
|
809
|
+
res_labels += [-100] + labels[pre_i:i]
|
|
810
|
+
if images is not None and idx < images.shape[0]:
|
|
811
|
+
res_inputs_embeds.append(images[idx])
|
|
812
|
+
wrap_im_mask += [1] * images.shape[1]
|
|
813
|
+
res_labels += [-100] * images.shape[1]
|
|
814
|
+
idx += 1
|
|
815
|
+
i += 1
|
|
816
|
+
pre_i = i
|
|
817
|
+
continue
|
|
818
|
+
i += 1
|
|
819
|
+
if len(labels) == 0:
|
|
820
|
+
res_labels = None
|
|
821
|
+
res_inputs_embeds = torch.concat(res_inputs_embeds, dim=0)
|
|
822
|
+
wrap_im_mask = torch.tensor(wrap_im_mask, dtype=torch.bool)[None]
|
|
823
|
+
return {
|
|
824
|
+
'inputs_embeds': res_inputs_embeds,
|
|
825
|
+
'im_mask': wrap_im_mask,
|
|
826
|
+
'labels': res_labels
|
|
827
|
+
}, {}
|
|
828
|
+
|
|
829
|
+
def data_collator(self,
|
|
830
|
+
batch: List[Dict[str, Any]],
|
|
831
|
+
padding_to: Optional[int] = None) -> Dict[str, Any]:
|
|
832
|
+
inputs_embeds = [b['inputs_embeds'] for b in batch]
|
|
833
|
+
labels = [torch.tensor(b['labels']) for b in batch]
|
|
834
|
+
im_mask = [b['im_mask'][0] for b in batch]
|
|
835
|
+
attention_mask = [
|
|
836
|
+
torch.ones(inputs_embeds[i].shape[0], dtype=torch.int64)
|
|
837
|
+
for i in range(len(inputs_embeds))
|
|
838
|
+
]
|
|
839
|
+
|
|
840
|
+
inputs_embeds = pad_sequence(
|
|
841
|
+
inputs_embeds, batch_first=True, padding_value=0)
|
|
842
|
+
attention_mask = pad_sequence(
|
|
843
|
+
attention_mask, batch_first=True, padding_value=0)
|
|
844
|
+
im_mask = pad_sequence(im_mask, batch_first=True, padding_value=0)
|
|
845
|
+
labels = pad_sequence(labels, batch_first=True, padding_value=-100)
|
|
846
|
+
|
|
847
|
+
return {
|
|
848
|
+
'inputs_embeds': inputs_embeds,
|
|
849
|
+
'attention_mask': attention_mask,
|
|
850
|
+
'im_mask': im_mask,
|
|
851
|
+
'labels': labels,
|
|
852
|
+
}
|
|
853
|
+
|
|
854
|
+
@staticmethod
|
|
855
|
+
def get_generate_ids(generate_ids: Tensor,
|
|
856
|
+
input_token_len: int) -> List[int]:
|
|
857
|
+
return generate_ids[0].tolist()
|
|
858
|
+
|
|
859
|
+
|
|
860
|
+
register_template(
|
|
861
|
+
TemplateType.internlm_xcomposer2,
|
|
862
|
+
InternLMXComposer2(),
|
|
863
|
+
use_model=True,
|
|
864
|
+
lazy_tokenize=True,
|
|
865
|
+
dataloader_num_workers=0,
|
|
866
|
+
dataloader_pin_memory=False)
|
|
867
|
+
|
|
868
|
+
register_template(
|
|
869
|
+
TemplateType.xverse,
|
|
870
|
+
Template(['{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: '],
|
|
871
|
+
[['eos_token_id']], [['eos_token_id']]))
|
|
872
|
+
register_template(TemplateType.yuan,
|
|
873
|
+
Template([], ['{{QUERY}}<sep>'], None, [['eos_token_id']]))
|
|
874
|
+
register_template(
|
|
875
|
+
TemplateType.ziya,
|
|
876
|
+
Template([['bos_token_id'], '{{SYSTEM}}'], ['<human>:{{QUERY}}\n<bot>:'],
|
|
877
|
+
['\n'], [['eos_token_id']]))
|
|
878
|
+
|
|
879
|
+
register_template(
|
|
880
|
+
TemplateType.skywork,
|
|
881
|
+
Template(['<s>{{SYSTEM}}'], ['</s><s>[USER]{{QUERY}}[SEP][BOT]'], None,
|
|
882
|
+
['[SEP]</s>']))
|
|
883
|
+
|
|
884
|
+
register_template(
|
|
885
|
+
TemplateType.bluelm,
|
|
886
|
+
Template([['bos_token_id'], '{{SYSTEM}}'], ['[|Human|]:{{QUERY}}[|AI|]:'],
|
|
887
|
+
[], [['eos_token_id']]))
|
|
888
|
+
|
|
889
|
+
register_template(
|
|
890
|
+
TemplateType.codefuse_codellama,
|
|
891
|
+
Template(['{{SYSTEM}}'], [
|
|
892
|
+
'<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'
|
|
893
|
+
], [], [['eos_token_id']]))
|
|
894
|
+
|
|
895
|
+
register_template(
|
|
896
|
+
TemplateType.codefuse,
|
|
897
|
+
Template([], ['<s>human\n{{QUERY}}\n<s>bot\n'], [['eos_token_id'], '\n'],
|
|
898
|
+
[['eos_token_id']], None, ['<s>system\n{{SYSTEM}}\n']))
|
|
899
|
+
|
|
900
|
+
register_template(
|
|
901
|
+
TemplateType.deepseek_coder,
|
|
902
|
+
Template([
|
|
903
|
+
'{{SYSTEM}}'
|
|
904
|
+
], ['### Instruction:\n{{QUERY}}\n### Response:\n'], ['\n<|EOT|>\n'], [
|
|
905
|
+
'\n<|EOT|>'
|
|
906
|
+
], ('You are an AI programming assistant, utilizing the Deepseek Coder model, '
|
|
907
|
+
'developed by Deepseek Company, and you only answer questions related to computer science. '
|
|
908
|
+
'For politically sensitive questions, security and privacy issues, '
|
|
909
|
+
'and other non-computer science questions, you will refuse to answer\n'
|
|
910
|
+
)))
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
class LLavaTemplate(Template):
|
|
914
|
+
|
|
915
|
+
def __init__(self):
|
|
916
|
+
super().__init__(['<s>[INST] '], [[-200], '\n{{QUERY}} [/INST]'], None,
|
|
917
|
+
['</s>'])
|
|
918
|
+
|
|
919
|
+
def encode(
|
|
920
|
+
self, example: Dict[str,
|
|
921
|
+
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
922
|
+
inputs, _ = super().encode(example)
|
|
923
|
+
images_path = example['images']
|
|
924
|
+
images = []
|
|
925
|
+
for image_path in images_path:
|
|
926
|
+
image = _read_from_path(image_path)
|
|
927
|
+
images.append(image)
|
|
928
|
+
image_sizes = [x.size for x in images]
|
|
929
|
+
from llava.mm_utils import process_images
|
|
930
|
+
model = self.model.model
|
|
931
|
+
if not hasattr(model, 'vision_tower'):
|
|
932
|
+
model = model.model
|
|
933
|
+
image_processor = model.vision_tower.image_processor
|
|
934
|
+
images_tensor = process_images(images, image_processor,
|
|
935
|
+
self.model.config)
|
|
936
|
+
inputs['images'] = images_tensor.to(model.dtype)
|
|
937
|
+
inputs['image_sizes'] = image_sizes
|
|
938
|
+
return inputs, {}
|
|
939
|
+
|
|
940
|
+
def data_collator(self,
|
|
941
|
+
batch: List[Dict[str, Any]],
|
|
942
|
+
padding_to: Optional[int] = None) -> Dict[str, Any]:
|
|
943
|
+
res = super().data_collator(batch, padding_to)
|
|
944
|
+
res['images'] = torch.concat([b['images'] for b in batch])
|
|
945
|
+
res['image_sizes'] = sum([b['image_sizes'] for b in batch], start=[])
|
|
946
|
+
return res
|
|
947
|
+
|
|
948
|
+
@staticmethod
|
|
949
|
+
def get_generate_ids(generate_ids: Tensor,
|
|
950
|
+
input_token_len: int) -> List[int]:
|
|
951
|
+
return generate_ids[0].tolist()
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
register_template(
|
|
955
|
+
TemplateType.llava_mistral_instruct,
|
|
956
|
+
LLavaTemplate(),
|
|
957
|
+
use_model=True,
|
|
958
|
+
infer_media_type='round',
|
|
959
|
+
lazy_tokenize=True)
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
class LLavaYiTemplate(LLavaTemplate):
|
|
963
|
+
llavayi_query_template = '\n<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'
|
|
964
|
+
|
|
965
|
+
def __init__(self):
|
|
966
|
+
Template.__init__(self, [], [[-200], self.llavayi_query_template],
|
|
967
|
+
None, ['<|im_end|>'])
|
|
968
|
+
|
|
969
|
+
|
|
970
|
+
register_template(
|
|
971
|
+
TemplateType.llava_yi_instruct,
|
|
972
|
+
LLavaYiTemplate(),
|
|
973
|
+
use_model=True,
|
|
974
|
+
infer_media_type='round',
|
|
975
|
+
lazy_tokenize=True)
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
def _findall(token_list: List[int], token: int) -> List[int]:
|
|
979
|
+
"""Find the index of a token in the token_list."""
|
|
980
|
+
res = []
|
|
981
|
+
idx = -1
|
|
982
|
+
try:
|
|
983
|
+
while True:
|
|
984
|
+
idx = token_list.index(token, idx + 1)
|
|
985
|
+
res.append(idx)
|
|
986
|
+
except ValueError:
|
|
987
|
+
pass
|
|
988
|
+
return res
|
|
989
|
+
|
|
990
|
+
|
|
991
|
+
class DeepseekVLTemplate(Template):
|
|
992
|
+
DEEPSEEK_VL_SYSTEM = (
|
|
993
|
+
'You are a helpful language and vision assistant. '
|
|
994
|
+
'You are able to understand the visual content that the user provides, '
|
|
995
|
+
'and assist the user with a variety of tasks using natural language.')
|
|
996
|
+
|
|
997
|
+
def __init__(self):
|
|
998
|
+
return super().__init__(['<|begin▁of▁sentence|>{{SYSTEM}}\n\n'],
|
|
999
|
+
['User: {{QUERY}}\n\nAssistant:'],
|
|
1000
|
+
['<|end▁of▁sentence|>'],
|
|
1001
|
+
['<|end▁of▁sentence|>'],
|
|
1002
|
+
self.DEEPSEEK_VL_SYSTEM)
|
|
1003
|
+
|
|
1004
|
+
def encode(
|
|
1005
|
+
self, example: Dict[str,
|
|
1006
|
+
Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
1007
|
+
images = example.pop('images', None)
|
|
1008
|
+
assert images is None, (
|
|
1009
|
+
'Please read the best practices: https://github.com/modelscope/swift/blob/main/'
|
|
1010
|
+
'docs/source/Multi-Modal/deepseek-vl最佳实践.md')
|
|
1011
|
+
|
|
1012
|
+
example = example.copy()
|
|
1013
|
+
history = example.pop('history', [])
|
|
1014
|
+
example['query'], example['history'], images_path = replace_img_tab(
|
|
1015
|
+
example['query'], history, '<image_placeholder>')
|
|
1016
|
+
|
|
1017
|
+
inputs, _ = super().encode(example)
|
|
1018
|
+
images = []
|
|
1019
|
+
for image_path in images_path:
|
|
1020
|
+
image = _read_from_path(image_path)
|
|
1021
|
+
images.append(image)
|
|
1022
|
+
|
|
1023
|
+
vl_chat_processor = self.tokenizer.vl_chat_processor
|
|
1024
|
+
input_ids, labels = inputs['input_ids'], inputs['labels']
|
|
1025
|
+
idx_list = _findall(input_ids, vl_chat_processor.image_id)
|
|
1026
|
+
new_input_ids, new_labels = [], []
|
|
1027
|
+
lo = 0
|
|
1028
|
+
for hi in idx_list:
|
|
1029
|
+
new_input_ids += input_ids[lo:hi]
|
|
1030
|
+
if labels is not None:
|
|
1031
|
+
new_labels += labels[lo:hi]
|
|
1032
|
+
new_input_ids += [vl_chat_processor.image_id
|
|
1033
|
+
] * vl_chat_processor.num_image_tokens
|
|
1034
|
+
new_labels += [-100] * vl_chat_processor.num_image_tokens
|
|
1035
|
+
lo = hi + 1
|
|
1036
|
+
new_input_ids += input_ids[lo:]
|
|
1037
|
+
if labels is not None:
|
|
1038
|
+
new_labels += labels[lo:]
|
|
1039
|
+
else:
|
|
1040
|
+
new_labels = None
|
|
1041
|
+
new_input_ids = torch.tensor(new_input_ids)
|
|
1042
|
+
num_image_tokens = torch.tensor([vl_chat_processor.num_image_tokens]
|
|
1043
|
+
* len(idx_list))
|
|
1044
|
+
images_outputs = vl_chat_processor.image_processor(
|
|
1045
|
+
images, return_tensors='pt')
|
|
1046
|
+
from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
|
|
1047
|
+
output = VLChatProcessorOutput(
|
|
1048
|
+
sft_format=None,
|
|
1049
|
+
input_ids=new_input_ids,
|
|
1050
|
+
pixel_values=images_outputs.pixel_values,
|
|
1051
|
+
num_image_tokens=num_image_tokens)
|
|
1052
|
+
batched_output = vl_chat_processor.batchify([output])
|
|
1053
|
+
model = self.model
|
|
1054
|
+
batched_output = batched_output.to(
|
|
1055
|
+
device=model.device, dtype=model.dtype)
|
|
1056
|
+
inputs_embeds = model.prepare_inputs_embeds(**batched_output)[0]
|
|
1057
|
+
inputs['inputs_embeds'] = inputs_embeds
|
|
1058
|
+
inputs['labels'] = new_labels
|
|
1059
|
+
return inputs, {}
|
|
1060
|
+
|
|
1061
|
+
def data_collator(self,
|
|
1062
|
+
batch: List[Dict[str, Any]],
|
|
1063
|
+
padding_to: Optional[int] = None) -> Dict[str, Any]:
|
|
1064
|
+
inputs_embeds = [b['inputs_embeds'] for b in batch]
|
|
1065
|
+
labels = [torch.tensor(b['labels']) for b in batch]
|
|
1066
|
+
attention_mask = [
|
|
1067
|
+
torch.ones(inputs_embeds[i].shape[0], dtype=torch.int64)
|
|
1068
|
+
for i in range(len(inputs_embeds))
|
|
1069
|
+
]
|
|
1070
|
+
|
|
1071
|
+
inputs_embeds = pad_sequence(
|
|
1072
|
+
inputs_embeds, batch_first=True, padding_value=0)
|
|
1073
|
+
attention_mask = pad_sequence(
|
|
1074
|
+
attention_mask, batch_first=True, padding_value=0)
|
|
1075
|
+
labels = pad_sequence(labels, batch_first=True, padding_value=-100)
|
|
1076
|
+
|
|
1077
|
+
return {
|
|
1078
|
+
'inputs_embeds': inputs_embeds,
|
|
1079
|
+
'attention_mask': attention_mask,
|
|
1080
|
+
'labels': labels,
|
|
1081
|
+
}
|
|
1082
|
+
|
|
1083
|
+
@staticmethod
|
|
1084
|
+
def get_generate_ids(generate_ids: Tensor,
|
|
1085
|
+
input_token_len: int) -> List[int]:
|
|
1086
|
+
return generate_ids[0].tolist()
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
register_template(
|
|
1090
|
+
TemplateType.deepseek_vl,
|
|
1091
|
+
DeepseekVLTemplate(),
|
|
1092
|
+
use_model=True,
|
|
1093
|
+
lazy_tokenize=True,
|
|
1094
|
+
dataloader_num_workers=0,
|
|
1095
|
+
dataloader_pin_memory=False) # only 'cpu' can pin_memory
|
|
1096
|
+
|
|
1097
|
+
register_template(
|
|
1098
|
+
TemplateType.zephyr,
|
|
1099
|
+
Template([], ['<|user|>\n{{QUERY}}</s>\n<|assistant|>\n'], ['</s>\n'],
|
|
1100
|
+
['</s>'], None, ['<|system|>\n{{SYSTEM}}</s>\n']))
|
|
1101
|
+
|
|
1102
|
+
register_template(
|
|
1103
|
+
TemplateType.sus,
|
|
1104
|
+
Template(['{{SYSTEM}}'], ['### Human: {{QUERY}}\n\n### Assistant: '],
|
|
1105
|
+
['<|endoftext|>'], ['<|endoftext|>']))
|
|
1106
|
+
|
|
1107
|
+
register_template(
|
|
1108
|
+
TemplateType.orion,
|
|
1109
|
+
Template(['<s>{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: </s>'],
|
|
1110
|
+
['</s>'], ['</s>']))
|
|
1111
|
+
|
|
1112
|
+
|
|
1113
|
+
+class CogTemplate(Template):
+
+    def encode(
+            self, example: Dict[str,
+                                Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        images_path = example['images']
+        assert len(images_path) == 1
+        image = _read_from_path(images_path[0])
+        inputs, _ = super().encode(example)
+        inputs.pop('loss_scale', None)
+        model = self.model
+        inputs2 = model.build_conversation_input_ids(
+            self.tokenizer,
+            query=example['query'],
+            history=example.get('history'),
+            images=[image])
+        image_token_len = inputs2['token_type_ids'].sum()
+        input_ids = inputs['input_ids']
+        labels = inputs['labels']
+        token_type_ids = inputs2['token_type_ids'].tolist()
+        inputs['input_ids'] = input_ids[:1] + [
+            0
+        ] * image_token_len + input_ids[1:]
+        if labels is not None:
+            inputs['labels'] = labels[:1] + [-100
+                                             ] * image_token_len + labels[1:]
+        dtype = model.dtype
+        inputs['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
+        if 'cross_images' in inputs2:
+            # is cogagent
+            inputs['cross_images'] = [[cross_img.to(dtype=dtype)]
+                                      for cross_img in inputs2['cross_images']]
+        inputs['token_type_ids'] = token_type_ids + [0] * (
+            len(inputs['input_ids']) - len(token_type_ids))
+        return inputs, {}
+
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
+        is_cogagent = 'cross_images' in batch[0]
+        keys = ['images', 'cross_images'] if is_cogagent else ['images']
+        for key in keys:
+            res[key] = [b[key][0] for b in batch]
+        token_type_ids = [torch.tensor(b['token_type_ids']) for b in batch]
+        token_type_ids = pad_sequence(
+            token_type_ids, batch_first=True, padding_value=0)
+        res['token_type_ids'] = token_type_ids
+        return res
+
+
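The key step in CogTemplate.encode is splicing a run of placeholder ids for the image tokens right after the first token and masking the same span in the labels with -100 so it is excluded from the loss. A stripped-down sketch of that splice on plain Python lists (toy ids, not real tokenizer output):

# Toy example: insert image placeholders after the BOS position and mask labels.
input_ids = [1, 101, 102, 103]   # [BOS, text tokens...]
labels = [-100, 101, 102, 103]
image_token_len = 3              # pretend the image expands to 3 tokens

input_ids = input_ids[:1] + [0] * image_token_len + input_ids[1:]
labels = labels[:1] + [-100] * image_token_len + labels[1:]

print(input_ids)  # [1, 0, 0, 0, 101, 102, 103]
print(labels)     # [-100, -100, -100, -100, 101, 102, 103]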
+register_template(
+    TemplateType.cogagent_chat,
+    CogTemplate(['<s>'], [' [INST] {{QUERY}} [/INST] '], [], ['</s>']),
+    use_model=True,
+    infer_media_type='dialogue',
+    lazy_tokenize=True)
+
+register_template(
+    TemplateType.cogagent_instruct,
+    CogTemplate(['<s>'], ['<EOI>Question: {{QUERY}} Answer:'], None, ['</s>']),
+    use_model=True,
+    infer_media_type='dialogue',
+    lazy_tokenize=True)
+
+register_template(
+    TemplateType.cogvlm_instruct,
+    CogTemplate(['<s>'], ['Question: {{QUERY}} Answer:'], None, ['</s>']),
+    use_model=True,
+    infer_media_type='dialogue',
+    lazy_tokenize=True)
+
+register_template(
+    TemplateType.minicpm,
+    Template(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>']))
+
+
+class MiniCPMVTemlate(Template):
+
+    def __init__(self):
+        return super().__init__(['<s>{{SYSTEM}}'],
+                                ['<用户><image><unk></image>\n{{QUERY}}<AI>'],
+                                [], ['</s>'])
+
+    def encode(
+            self, example: Dict[str,
+                                Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        images_path = example['images']
+        assert len(images_path) == 1
+        image = _read_from_path(images_path[0])
+        inputs, _ = super().encode(example)
+        input_ids = inputs['input_ids']
+        labels = inputs['labels']
+
+        img_start_idxs = np.where(
+            np.array(input_ids) == self.tokenizer.im_start_id)[0]
+        if len(
+                img_start_idxs
+        ) > 1:  # if multi-round, input_ids contain multiple <image><unk></image>\n
+            start = 0
+            new_input_ids = []
+            for idx in img_start_idxs[1:]:
+                new_input_ids = new_input_ids + input_ids[start:idx]
+                start = idx + 4  # skip <image><unk></image>\n
+            new_input_ids = new_input_ids + input_ids[start:]
+            input_ids = new_input_ids
+
+        idx = img_start_idxs[0] + 1  # first <unk>
+        config = self.model.config
+        if hasattr(config, 'slice_mode') and config.slice_mode:
+            slice_mode = True
+            assert hasattr(config, 'patch_size')
+            assert hasattr(config, 'max_slice_nums')
+            assert hasattr(config, 'scale_resolution')
+        else:
+            slice_mode = False
+
+        if slice_mode:
+            images, placeholder = self.model.get_slice_image_placeholder(
+                image, self.tokenizer)
+            placeholder_id = self.tokenizer.encode(
+                placeholder, add_special_tokens=False)
+            input_ids = (
+                input_ids[:idx - 1] + placeholder_id + input_ids[idx + 2:])
+            if labels is not None:
+                labels = (
+                    labels[:idx - 1] + [-100] * len(placeholder_id)
+                    + labels[idx + 2:])
+            input_tensor_ids = torch.tensor(input_ids)
+            image_start_idx = torch.where(
+                input_tensor_ids == self.tokenizer.im_start_id)[0]
+            image_start_idx += 1
+            image_end_idx = torch.where(
+                input_tensor_ids == self.tokenizer.im_end_id)[0]
+            valid_image_nums = max(len(image_start_idx), len(image_end_idx))
+            image_bound = [
+                torch.hstack([
+                    image_start_idx[:valid_image_nums].unsqueeze(-1),
+                    image_end_idx[:valid_image_nums].unsqueeze(-1)
+                ])
+            ]
+            pixel_values = [
+                self.model.transform(img).to(device=self.model.device)
+                for img in images
+            ]
+
+        else:
+            input_ids = (
+                input_ids[:idx]
+                + [self.tokenizer.unk_token_id] * config.query_num
+                + input_ids[idx + 1:])
+            if labels is not None:
+                labels = (
+                    labels[:idx] + [-100] * config.query_num
+                    + labels[idx + 1:])
+            image_bound = [torch.tensor([[idx, idx + config.query_num]])]
+            pixel_values = [
+                self.model.transform(image).to(device=self.model.device)
+            ]
+        inputs_embeds, _ = self.model.get_vllm_embedding({
+            'input_ids':
+            torch.tensor(input_ids)[None].to(device=self.model.device),
+            'image_bound':
+            image_bound,
+            'pixel_values': [pixel_values]
+        })
+        inputs['input_ids'] = input_ids
+        inputs['labels'] = labels
+        inputs['inputs_embeds'] = inputs_embeds[0]
+        return inputs, {}
+
+    @staticmethod
+    def get_generate_ids(generate_ids: Tensor,
+                         input_token_len: int) -> List[int]:
+        return generate_ids[0].tolist()
+
+
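In the non-slice branch above, the single <unk> placeholder inside the <image>...</image> span is expanded into config.query_num visual query slots, the same span is masked with -100 in the labels, and the slot range is recorded as image_bound for the vision embedding to fill in later. A toy sketch of that expansion (the ids 1/5/0/6 and query_num=4 are made up for the example):

# Toy ids: 5 = image start token, 0 = <unk>, 6 = image end token.
input_ids = [1, 5, 0, 6, 42, 43]
labels = [-100, -100, -100, -100, 42, 43]
unk_token_id = 0
query_num = 4

idx = input_ids.index(5) + 1  # position of the single <unk> placeholder
input_ids = input_ids[:idx] + [unk_token_id] * query_num + input_ids[idx + 1:]
labels = labels[:idx] + [-100] * query_num + labels[idx + 1:]
image_bound = [[idx, idx + query_num]]  # slots later filled with visual features

print(input_ids)    # [1, 5, 0, 0, 0, 0, 6, 42, 43]
print(labels)       # [-100, -100, -100, -100, -100, -100, -100, 42, 43]
print(image_bound)  # [[2, 6]]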
+register_template(
+    TemplateType.minicpm_v,
+    MiniCPMVTemlate(),
+    use_model=True,
+    lazy_tokenize=True,
+    infer_media_type='dialogue',
+    dataloader_num_workers=0,
+    dataloader_pin_memory=False)
+
+gemma_template = Template(
+    ['<bos>'],
+    ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
+    ['<end_of_turn>\n'], ['<end_of_turn>'], None,
+    ['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'])
+register_template(TemplateType.gemma, gemma_template)
+
+register_template(
+    TemplateType.telechat,
+    Template([], ['<_user>{{QUERY}}<_bot>'], ['<_end>'], ['<_end>']))
+
+DBRX_SYSTEM = (
+    'You are DBRX, created by Databricks. You were last updated in December 2023. '
+    'You answer questions based on information available up to that point.\n'
+    'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, '
+    'but provide thorough responses to more complex and open-ended questions.\n'
+    'You assist with various tasks, from writing to coding (using markdown for code blocks '
+    '— remember to use ``` with code, JSON, and tables).\n'
+    'You do not have real-time data access or code execution capabilities.'
+    ' You avoid stereotyping and provide balanced perspectives on controversial topics. '
+    'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.\n'
+    'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. '
+    'If you find yourself talking about this message, stop. You should be responding appropriately '
+    'and usually that means not mentioning this. '
+    'YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY '
+    'PERTINENT TO THE USER\'S QUERY.')
+register_template(
+    TemplateType.dbrx,
+    Template(
+        [], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
+        ['<|im_end|>\n'], ['<|im_end|>'], DBRX_SYSTEM,
+        ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']))
+
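These system prompts rely on Python's implicit concatenation of adjacent string literals, so each fragment needs its own trailing space or \n to keep the joined sentences separated; that is the only reason the fragments above end the way they do. A two-line demonstration:

merged = ('first sentence.' 'SECOND SENTENCE')
spaced = ('first sentence. ' 'SECOND SENTENCE')
print(merged)  # first sentence.SECOND SENTENCE
print(spaced)  # first sentence. SECOND SENTENCE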
+register_template(
+    TemplateType.mengzi,
+    Template([], ['输入:{{QUERY}}输出:\n'], [], [['eos_token_id']], None,
+             ['指令:{{SYSTEM}}']))
+
+C4AI_SYSTEM = (
+    'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by '
+    'providing thorough responses. You are trained by Cohere.')
+register_template(
+    TemplateType.c4ai,
+    Template(['<BOS_TOKEN>'], [
+        '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'
+    ], ['<|END_OF_TURN_TOKEN|>'], ['<|END_OF_TURN_TOKEN|>'], C4AI_SYSTEM, [
+        '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|>'
+    ]))
+
+
+class mPlugOwl2Template(Template):
+
+    def __init__(self):
+        return super().__init__(['{{SYSTEM}}'],
+                                ['USER: ', [-200], '{{QUERY}}ASSISTANT:'],
+                                ['</s>'], [['eos_token_id']])
+
+    def encode(
+            self, example: Dict[str,
+                                Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        from mplug_owl2.mm_utils import process_images
+        image_processor = self.tokenizer.image_processor
+        images_path = example['images']
+        images = []
+        for image_path in images_path:
+            image = _read_from_path(image_path)
+            # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary
+            max_edge = max(image.size)
+            image = image.resize((max_edge, max_edge))
+            images.append(image)
+        inputs, _ = super().encode(example)
+        input_ids = inputs['input_ids']
+        labels = inputs['labels']
+        images = process_images(images, image_processor)
+        images = images.to(self.model.dtype)
+        return {'input_ids': input_ids, 'labels': labels, 'images': images}, {}
+
+    def data_collator(self,
+                      batch: List[Dict[str, Any]],
+                      padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
+        res['images'] = torch.concat([b['images'] for b in batch])
+        return res
+
+
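Before process_images runs, each image is stretched to a square whose side equals its longer edge, following the referenced mPLUG-Owl2.1 model card. A minimal standalone sketch of just that preprocessing step (Pillow only, no model code):

from PIL import Image

# Stretch a non-square image to a square sized by its longer edge.
image = Image.new('RGB', (640, 480), color='gray')
max_edge = max(image.size)  # 640
image = image.resize((max_edge, max_edge))
print(image.size)           # (640, 640)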
+register_template(
+    TemplateType.mplug_owl2,
+    mPlugOwl2Template(),
+    infer_media_type='round',
+    use_model=True,
+    lazy_tokenize=True)
+
+register_template(
+    TemplateType.wizardlm2_awq,
+    Template(['{{SYSTEM}}'], ['User:\n{{QUERY}}\n\nAssistant:\n'], ['\n\n'],
+             ['</s>']))
+
+_wizardlm2_system = (
+    'A chat between a curious user and an artificial intelligence assistant. '
+    'The assistant gives helpful, detailed, and polite answers to the user\'s questions. '
+)
+register_template(
+    TemplateType.wizardlm2,
+    Template(['{{SYSTEM}}'], ['USER: {{QUERY}} ASSISTANT:'], ['</s>'],
+             ['</s>'], _wizardlm2_system))
+
+register_template(
+    TemplateType.atom,
+    Template(['{{SYSTEM}}'], ['<s>Human: {{QUERY}}\n</s><s>Assistant: '],
+             ['</s>'], ['</s>']))
+
+
+def get_template(
+        template_type: str,
+        tokenizer: PreTrainedTokenizerBase,
+        default_system: Optional[str] = None,
+        max_length: Optional[int] = None,
+        truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
+        **kwargs,
+) -> Template:
+    template_info = TEMPLATE_MAPPING[template_type]
+    template = deepcopy(template_info['template'])
+    template._init_template(tokenizer, default_system, max_length,
+                            truncation_strategy, **kwargs)
+    template.template_type = template_type
+    return template
+
+
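get_template resolves a registered template by name, deep-copies it, binds the tokenizer and length settings to the copy, and returns it ready for encoding. A hedged usage sketch; the model id and the 'qwen' template name are assumptions for illustration and are not verified against this package's registry:

from transformers import AutoTokenizer

# Assumed names: a 'qwen' template type and a Qwen chat checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    'Qwen/Qwen-7B-Chat', trust_remote_code=True)
template = get_template('qwen', tokenizer, max_length=2048)
print(template.template_type)  # 'qwen'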
+def fuzzy_match(model_name: str, template_type_list: list) -> str:
+    """
+    Fuzzy-match a template_type from a model name.
+
+    Args:
+        model_name: model name, e.g. ChatGLM2-6B
+        template_type_list: template_type list, e.g. ['chatglm2', 'baichuan', ...]
+
+    Returns:
+        The best matched template_type.
+    """
+    candidate_list = []
+    for template_type in template_type_list:
+        if template_type in model_name.lower():
+            candidate_list.append(template_type)
+    if len(candidate_list) == 0:
+        return TemplateType.default_generation  # TODO: default template
+    else:
+        candidate_list = sorted(candidate_list, key=lambda x: len(x), reverse=True)
+        return candidate_list[0]