evalscope-0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. evalscope/__init__.py +3 -0
  2. evalscope/backend/__init__.py +3 -0
  3. evalscope/backend/base.py +27 -0
  4. evalscope/backend/opencompass/__init__.py +3 -0
  5. evalscope/backend/opencompass/api_meta_template.py +64 -0
  6. evalscope/backend/opencompass/backend_manager.py +247 -0
  7. evalscope/backend/opencompass/tasks/__init__.py +1 -0
  8. evalscope/backend/opencompass/tasks/eval_api.py +30 -0
  9. evalscope/backend/opencompass/tasks/eval_datasets.py +71 -0
  10. evalscope/backend/vlm_eval_kit/__init__.py +1 -0
  11. evalscope/backend/vlm_eval_kit/backend_manager.py +153 -0
  12. evalscope/benchmarks/__init__.py +4 -0
  13. evalscope/benchmarks/arc/__init__.py +5 -0
  14. evalscope/benchmarks/arc/ai2_arc.py +148 -0
  15. evalscope/benchmarks/arc/arc_adapter.py +231 -0
  16. evalscope/benchmarks/bbh/__init__.py +6 -0
  17. evalscope/benchmarks/bbh/bbh_adapter.py +308 -0
  18. evalscope/benchmarks/bbh/cot_prompts/boolean_expressions.txt +23 -0
  19. evalscope/benchmarks/bbh/cot_prompts/causal_judgement.txt +25 -0
  20. evalscope/benchmarks/bbh/cot_prompts/date_understanding.txt +33 -0
  21. evalscope/benchmarks/bbh/cot_prompts/disambiguation_qa.txt +37 -0
  22. evalscope/benchmarks/bbh/cot_prompts/dyck_languages.txt +72 -0
  23. evalscope/benchmarks/bbh/cot_prompts/formal_fallacies.txt +44 -0
  24. evalscope/benchmarks/bbh/cot_prompts/geometric_shapes.txt +78 -0
  25. evalscope/benchmarks/bbh/cot_prompts/hyperbaton.txt +28 -0
  26. evalscope/benchmarks/bbh/cot_prompts/logical_deduction_five_objects.txt +37 -0
  27. evalscope/benchmarks/bbh/cot_prompts/logical_deduction_seven_objects.txt +37 -0
  28. evalscope/benchmarks/bbh/cot_prompts/logical_deduction_three_objects.txt +37 -0
  29. evalscope/benchmarks/bbh/cot_prompts/movie_recommendation.txt +42 -0
  30. evalscope/benchmarks/bbh/cot_prompts/multistep_arithmetic_two.txt +25 -0
  31. evalscope/benchmarks/bbh/cot_prompts/navigate.txt +43 -0
  32. evalscope/benchmarks/bbh/cot_prompts/object_counting.txt +37 -0
  33. evalscope/benchmarks/bbh/cot_prompts/penguins_in_a_table.txt +41 -0
  34. evalscope/benchmarks/bbh/cot_prompts/reasoning_about_colored_objects.txt +63 -0
  35. evalscope/benchmarks/bbh/cot_prompts/ruin_names.txt +44 -0
  36. evalscope/benchmarks/bbh/cot_prompts/salient_translation_error_detection.txt +40 -0
  37. evalscope/benchmarks/bbh/cot_prompts/snarks.txt +30 -0
  38. evalscope/benchmarks/bbh/cot_prompts/sports_understanding.txt +10 -0
  39. evalscope/benchmarks/bbh/cot_prompts/temporal_sequences.txt +77 -0
  40. evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_five_objects.txt +40 -0
  41. evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_seven_objects.txt +40 -0
  42. evalscope/benchmarks/bbh/cot_prompts/tracking_shuffled_objects_three_objects.txt +40 -0
  43. evalscope/benchmarks/bbh/cot_prompts/web_of_lies.txt +28 -0
  44. evalscope/benchmarks/bbh/cot_prompts/word_sorting.txt +17 -0
  45. evalscope/benchmarks/benchmark.py +65 -0
  46. evalscope/benchmarks/ceval/__init__.py +5 -0
  47. evalscope/benchmarks/ceval/ceval_adapter.py +340 -0
  48. evalscope/benchmarks/ceval/ceval_exam.py +159 -0
  49. evalscope/benchmarks/cmmlu/__init__.py +5 -0
  50. evalscope/benchmarks/cmmlu/cmmlu.py +166 -0
  51. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +369 -0
  52. evalscope/benchmarks/competition_math/__init__.py +5 -0
  53. evalscope/benchmarks/competition_math/competition_math.py +88 -0
  54. evalscope/benchmarks/competition_math/competition_math_adapter.py +470 -0
  55. evalscope/benchmarks/data_adapter.py +263 -0
  56. evalscope/benchmarks/general_qa/__init__.py +5 -0
  57. evalscope/benchmarks/general_qa/general_qa_adapter.py +186 -0
  58. evalscope/benchmarks/gsm8k/__init__.py +5 -0
  59. evalscope/benchmarks/gsm8k/gsm8k.py +127 -0
  60. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +236 -0
  61. evalscope/benchmarks/hellaswag/__init__.py +5 -0
  62. evalscope/benchmarks/hellaswag/hellaswag.py +116 -0
  63. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +222 -0
  64. evalscope/benchmarks/humaneval/__init__.py +5 -0
  65. evalscope/benchmarks/humaneval/humaneval.py +82 -0
  66. evalscope/benchmarks/humaneval/humaneval_adapter.py +21 -0
  67. evalscope/benchmarks/mmlu/__init__.py +5 -0
  68. evalscope/benchmarks/mmlu/mmlu.py +174 -0
  69. evalscope/benchmarks/mmlu/mmlu_adapter.py +375 -0
  70. evalscope/benchmarks/race/__init__.py +5 -0
  71. evalscope/benchmarks/race/race.py +118 -0
  72. evalscope/benchmarks/race/race_adapter.py +229 -0
  73. evalscope/benchmarks/trivia_qa/__init__.py +5 -0
  74. evalscope/benchmarks/trivia_qa/trivia_qa.py +104 -0
  75. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +207 -0
  76. evalscope/benchmarks/truthful_qa/__init__.py +5 -0
  77. evalscope/benchmarks/truthful_qa/truthful_qa.py +167 -0
  78. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +351 -0
  79. evalscope/cache.py +98 -0
  80. evalscope/cli/__init__.py +1 -0
  81. evalscope/cli/base.py +20 -0
  82. evalscope/cli/cli.py +26 -0
  83. evalscope/cli/start_perf.py +37 -0
  84. evalscope/cli/start_server.py +138 -0
  85. evalscope/config.py +165 -0
  86. evalscope/constants.py +150 -0
  87. evalscope/evaluator/__init__.py +3 -0
  88. evalscope/evaluator/evaluator.py +689 -0
  89. evalscope/evaluator/rating_eval.py +178 -0
  90. evalscope/evaluator/reviewer/__init__.py +1 -0
  91. evalscope/evaluator/reviewer/auto_reviewer.py +411 -0
  92. evalscope/metrics/__init__.py +1 -0
  93. evalscope/metrics/bundled_rouge_score/__init__.py +14 -0
  94. evalscope/metrics/bundled_rouge_score/rouge_scorer.py +342 -0
  95. evalscope/metrics/code_metric.py +104 -0
  96. evalscope/metrics/math_accuracy.py +60 -0
  97. evalscope/metrics/metrics.py +405 -0
  98. evalscope/metrics/rouge_metric.py +129 -0
  99. evalscope/models/__init__.py +4 -0
  100. evalscope/models/custom/__init__.py +4 -0
  101. evalscope/models/custom/custom_model.py +53 -0
  102. evalscope/models/dummy_chat_model.py +50 -0
  103. evalscope/models/model.py +88 -0
  104. evalscope/models/model_adapter.py +586 -0
  105. evalscope/models/openai_model.py +103 -0
  106. evalscope/models/template.py +1446 -0
  107. evalscope/perf/__init__.py +0 -0
  108. evalscope/perf/_logging.py +32 -0
  109. evalscope/perf/api_plugin_base.py +60 -0
  110. evalscope/perf/custom_api.py +87 -0
  111. evalscope/perf/dashscope_api.py +84 -0
  112. evalscope/perf/dataset_plugin_base.py +64 -0
  113. evalscope/perf/datasets/__init__.py +0 -0
  114. evalscope/perf/datasets/line_by_line.py +18 -0
  115. evalscope/perf/datasets/longalpaca_12k.py +20 -0
  116. evalscope/perf/datasets/openqa.py +22 -0
  117. evalscope/perf/how_to_analysis_result.py +24 -0
  118. evalscope/perf/http_client.py +756 -0
  119. evalscope/perf/openai_api.py +130 -0
  120. evalscope/perf/plugin_registry.py +35 -0
  121. evalscope/perf/query_parameters.py +42 -0
  122. evalscope/perf/server_sent_event.py +43 -0
  123. evalscope/preprocess/__init__.py +1 -0
  124. evalscope/preprocess/tokenizers/__init__.py +0 -0
  125. evalscope/preprocess/tokenizers/gpt2_tokenizer.py +221 -0
  126. evalscope/registry/__init__.py +1 -0
  127. evalscope/registry/tasks/arc.yaml +29 -0
  128. evalscope/registry/tasks/bbh.yaml +27 -0
  129. evalscope/registry/tasks/bbh_mini.yaml +27 -0
  130. evalscope/registry/tasks/ceval.yaml +27 -0
  131. evalscope/registry/tasks/ceval_mini.yaml +27 -0
  132. evalscope/registry/tasks/cmmlu.yaml +27 -0
  133. evalscope/registry/tasks/eval_qwen-7b-chat_v100.yaml +28 -0
  134. evalscope/registry/tasks/general_qa.yaml +27 -0
  135. evalscope/registry/tasks/gsm8k.yaml +29 -0
  136. evalscope/registry/tasks/mmlu.yaml +29 -0
  137. evalscope/registry/tasks/mmlu_mini.yaml +27 -0
  138. evalscope/run.py +404 -0
  139. evalscope/run_arena.py +204 -0
  140. evalscope/run_ms.py +140 -0
  141. evalscope/summarizer.py +144 -0
  142. evalscope/third_party/__init__.py +1 -0
  143. evalscope/third_party/toolbench_static/__init__.py +3 -0
  144. evalscope/third_party/toolbench_static/eval.py +219 -0
  145. evalscope/third_party/toolbench_static/infer.py +278 -0
  146. evalscope/third_party/toolbench_static/llm/__init__.py +1 -0
  147. evalscope/third_party/toolbench_static/llm/swift_infer.py +45 -0
  148. evalscope/third_party/toolbench_static/toolbench_static.py +50 -0
  149. evalscope/tools/__init__.py +1 -0
  150. evalscope/tools/combine_reports.py +140 -0
  151. evalscope/tools/gen_mmlu_subject_mapping.py +90 -0
  152. evalscope/tools/rewrite_eval_results.py +95 -0
  153. evalscope/utils/__init__.py +4 -0
  154. evalscope/utils/arena_utils.py +247 -0
  155. evalscope/utils/completion_parsers.py +87 -0
  156. evalscope/utils/logger.py +64 -0
  157. evalscope/utils/task_cfg_parser.py +10 -0
  158. evalscope/utils/task_utils.py +19 -0
  159. evalscope/utils/utils.py +625 -0
  160. evalscope/version.py +4 -0
  161. evalscope-0.5.0.dist-info/METADATA +566 -0
  162. evalscope-0.5.0.dist-info/RECORD +165 -0
  163. evalscope-0.5.0.dist-info/WHEEL +5 -0
  164. evalscope-0.5.0.dist-info/entry_points.txt +3 -0
  165. evalscope-0.5.0.dist-info/top_level.txt +1 -0
evalscope/models/template.py
@@ -0,0 +1,1446 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import re
3
+ from copy import deepcopy
4
+ from io import BytesIO
5
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import requests
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import Tensor
12
+ from torch.nn.utils.rnn import pad_sequence
13
+ from transformers import PreTrainedTokenizerBase, StoppingCriteria
14
+
15
+ from evalscope.utils.utils import calculate_loss_scale, pad_and_split_batch, get_dist_setting, use_torchacc
16
+
17
+ DEFAULT_SYSTEM = 'You are a helpful assistant.'
18
+ History = List[Union[Tuple[str, str], List[str]]]
19
+
20
+
21
+ class TemplateType:
22
+ # text-generation
23
+ default_generation = 'default-generation'
24
+ default_generation_bos = 'default-generation-bos'
25
+ chatglm_generation = 'chatglm-generation'
26
+ qwen_audio_generation = 'qwen-audio-generation'
27
+ # chat
28
+ default = 'default'
29
+ qwen = 'qwen'
30
+ qwen_audio = 'qwen-audio'
31
+ baichuan = 'baichuan'
32
+ chatglm2 = 'chatglm2'
33
+ chatglm3 = 'chatglm3'
34
+ llama = 'llama' # llama2
35
+ llama3 = 'llama3'
36
+ llava_mistral_instruct = 'llava-mistral-instruct'
37
+ llava_yi_instruct = 'llava-yi-instruct'
38
+ openbuddy = 'openbuddy'
39
+ internlm = 'internlm'
40
+ internlm2 = 'internlm2'
41
+ internlm_xcomposer2 = 'internlm-xcomposer2'
42
+ yi = 'yi'
43
+ yi_vl = 'yi-vl'
44
+ yuan = 'yuan'
45
+ xverse = 'xverse'
46
+ ziya = 'ziya'
47
+ skywork = 'skywork'
48
+ bluelm = 'bluelm'
49
+ zephyr = 'zephyr'
50
+ sus = 'sus'
51
+ deepseek = 'deepseek'
52
+ deepseek_coder = 'deepseek-coder'
53
+ deepseek_vl = 'deepseek-vl'
54
+ codefuse_codellama = 'codefuse-codellama'
55
+ codefuse = 'codefuse'
56
+ cogvlm_instruct = 'cogvlm-instruct'
57
+ cogagent_chat = 'cogagent-chat'
58
+ cogagent_instruct = 'cogagent-instruct'
59
+ orion = 'orion'
60
+ minicpm = 'minicpm'
61
+ minicpm_v = 'minicpm-v'
62
+ gemma = 'gemma'
63
+ mplug_owl2 = 'mplug-owl2'
64
+ wizardlm2_awq = 'wizardlm2-awq'
65
+ wizardlm2 = 'wizardlm2'
66
+ atom = 'atom'
67
+ # compatibility. (Deprecated)
68
+ chatml = 'chatml'
69
+ telechat = 'telechat'
70
+ dbrx = 'dbrx'
71
+ mengzi = 'mengzi'
72
+ c4ai = 'c4ai'
73
+
74
+ @classmethod
75
+ def get_template_name_list(cls) -> List[str]:
76
+ res = []
77
+ for k in cls.__dict__.keys():
78
+ if k.startswith('__') or k == 'get_template_name_list':
79
+ continue
80
+ res.append(cls.__dict__[k])
81
+ return res
82
+
83
+
84
+ Prompt = List[Union[str, List[Union[str, int]]]]
85
+ StopWords = Prompt
86
+
87
+ Context = Union[str, List[int]]
88
+
89
+
90
+ class StopWordsCriteria(StoppingCriteria):
91
+ # The returned sentence includes stop words.
92
+ def __init__(self, tokenizer: PreTrainedTokenizerBase,
93
+ stop_words: StopWords, **tokenizer_kwargs) -> None:
94
+ self.tokenizer = tokenizer
95
+ self.stop_words = stop_words
96
+ self.tokenizer_kwargs = tokenizer_kwargs
97
+ self.start_idx = -1
98
+
99
+ def __call__(self, input_ids: Tensor, scores: Tensor) -> bool:
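+ # On the first call, record where generation starts so that only the generated tail is decoded and scanned for stop words.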
100
+ if self.start_idx == -1:
101
+ self.start_idx = len(input_ids[0]) - 1
102
+ tokenizer = self.tokenizer
103
+ stop_words = self.stop_words
104
+ text = tokenizer.decode(input_ids[0, self.start_idx:],
105
+ **self.tokenizer_kwargs)
106
+ for stop_word in stop_words:
107
+ if isinstance(stop_word, str):
108
+ if stop_word in text:
109
+ return True
110
+ else: # list
111
+ if len(stop_word) > 0 and input_ids[0].tolist(
112
+ )[-len(stop_word):] == stop_word:
113
+ return True
114
+ return False
115
+
116
+
117
+ def _has_system(prefix: Prompt) -> bool:
118
+ for p in prefix:
119
+ if '{{SYSTEM}}' in p:
120
+ return True
121
+ return False
122
+
123
+
124
+ def _replace_system(prefix: Prompt) -> Prompt:
125
+ res = []
126
+ for p in prefix:
127
+ if '{{SYSTEM}}' in p:
128
+ p = p.replace('{{SYSTEM}}', '')
129
+ res.append(p)
130
+ return res
131
+
132
+
133
+ class Template:
134
+
135
+ def __init__(self,
136
+ prefix: Prompt,
137
+ prompt: Prompt,
138
+ chat_sep: Optional[Prompt],
139
+ suffix: Prompt,
140
+ default_system: Optional[str] = None,
141
+ prefix_has_system: Optional[Prompt] = None) -> None:
142
+ if default_system == '':
143
+ default_system = None
144
+ if _has_system(prefix):
145
+ assert prefix_has_system is None, 'The prefix already contains {{SYSTEM}}.'
146
+ prefix_has_system = prefix
147
+ prefix = _replace_system(prefix)
148
+ self.prefix = prefix
149
+ self.prefix_has_system = prefix_has_system
150
+ if self.prefix_has_system is None:
151
+ assert default_system is None, 'The template does not support `system`.'
152
+ self.prompt = prompt
153
+ self.chat_sep = chat_sep
154
+ self.support_multi_round = self.chat_sep is not None
155
+ self.suffix = suffix
156
+ self.default_system = default_system
157
+ self.use_default_system = True
158
+ self._is_init = False
159
+
160
+ @staticmethod
161
+ def _preprocess_prompt(tokenizer: PreTrainedTokenizerBase,
162
+ value: Optional[Prompt]) -> Optional[Prompt]:
163
+ # e.g. [['eos_token_id']] -> [[2]]
164
+ if value is None:
165
+ return None
166
+ res_value = []
167
+ for v in value:
168
+ if isinstance(v, list):
169
+ res_v = []
170
+ for sub_v in v:
171
+ if isinstance(sub_v, str):
172
+ sub_v = getattr(tokenizer, sub_v)
173
+ res_v.append(sub_v)
174
+ v = res_v
175
+ res_value.append(v)
176
+ return res_value
177
+
178
+ def _init_template(self,
179
+ tokenizer: PreTrainedTokenizerBase,
180
+ default_system: Optional[str] = None,
181
+ max_length: Optional[int] = None,
182
+ truncation_strategy: Literal[
183
+ 'delete', 'truncation_left'] = 'delete',
184
+ **kwargs) -> None:
185
+ assert self._is_init is False, 'The template has been initialized.'
186
+ self._is_init = True
187
+ self.tokenizer = tokenizer
188
+ # if default_system is None, do not change self.default_system
189
+ if default_system == '':
190
+ self.default_system = None
191
+ elif default_system is not None:
192
+ assert self.prefix_has_system is not None, 'The template does not support `system`.'
193
+ self.default_system = default_system
194
+ self.max_length = max_length
195
+ self.truncation_strategy = truncation_strategy
196
+ self.model = kwargs.get('model', None)
197
+ self.use_loss_scale = kwargs.get('use_loss_scale', False)
198
+ for key in [
199
+ 'prefix', 'prompt', 'chat_sep', 'suffix', 'prefix_has_system'
200
+ ]:
201
+ value = getattr(self, key)
202
+ value = self._preprocess_prompt(tokenizer, value)
203
+ setattr(self, key, value)
204
+
205
+ def encode(
206
+ self, example: Dict[str,
207
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
208
+ """return: inputs, tokenizer_kwargs"""
209
+ if not self._is_init:
210
+ raise ValueError(
211
+ 'Template is not initialized, please use the `get_template` function to obtain the template.'
212
+ )
213
+ query: Optional[str] = example.get('query', None)
214
+ response: Optional[str] = example.get('response', None)
215
+ history: Optional[History] = example.get('history', None)
216
+ system: Optional[str] = example.get('system', None)
217
+ if history is None:
218
+ history = []
219
+ if len(history) > 0:
220
+ assert self.support_multi_round, 'The template does not support multi-round chat.'
221
+ if system is None:
222
+ if self.use_default_system:
223
+ system = self.default_system
224
+ elif system == '':
225
+ system = None
226
+ else:
227
+ assert self.prefix_has_system is not None, 'The template does not support `system`.'
228
+ if query is None:
229
+ query = ''
230
+ inputs, tokenizer_kwargs = self._encode(query, response, history,
231
+ system,
232
+ self.truncation_strategy)
233
+ if inputs.get('labels') is None:
234
+ inputs.pop('loss_scale', None)
235
+ return inputs, tokenizer_kwargs
236
+
237
+ def _concat_context_list(
238
+ self,
239
+ context_list: List[Context],
240
+ res_context_list: List[Context], # inplace
241
+ compute_loss_idx: List[float], # inplace
242
+ system: Optional[str] = None,
243
+ query: Optional[str] = None,
244
+ response: Optional[str] = None,
245
+ round0: Optional[int] = None,
246
+ ) -> None:
247
+ # concat context list and replace placeholder
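+ # Only the {{RESPONSE}} span and the suffix receive a non-zero loss weight; every other template piece is appended with weight 0.0.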
248
+ round1 = None
249
+ if round0 is not None:
250
+ round1 = str(round0 + 1)
251
+ round0 = str(round0)
252
+ for context in context_list:
253
+ if isinstance(context, str):
254
+ if '{{RESPONSE}}' == context:
255
+ assert response is not None
256
+ content_part, weight_part = calculate_loss_scale(
257
+ response, self.use_loss_scale)
258
+ res_context_list.extend(content_part)
259
+ compute_loss_idx.extend(weight_part)
260
+ continue
261
+ old_str_list = [
262
+ '{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}'
263
+ ]
264
+ new_str_list = [system, query, round0, round1]
265
+ for (old_str, new_str) in zip(old_str_list, new_str_list):
266
+ if new_str is not None and old_str in context:
267
+ context = context.replace(old_str, new_str)
268
+ res_context_list.append(context)
269
+ compute_loss_idx.append(0.0 if context not in self.suffix else 1.0)
270
+
271
+ @staticmethod
272
+ def _simplify_context_list(
273
+ context_list: List[Context], compute_loss_idx: List[float]
274
+ ) -> Tuple[List[Context], List[float]]:
275
+ res: List[Context] = [] # result of context_list
276
+ res_idx: List[float] = [] # result of compute_loss_idx
277
+ temp: List[str] = []
278
+ temp_index: List[int] = []
279
+ for i, (context,
280
+ loss_idx) in enumerate(zip(context_list, compute_loss_idx)):
281
+ if isinstance(context, str) and compute_loss_idx[i] == 0.0:
282
+ temp.append(context)
283
+ temp_index.append(i)
284
+ else:
285
+ if len(temp) > 0:
286
+ res.append(''.join(temp))
287
+ res_idx.append(0.0)
288
+ temp.clear()
289
+ res.append(context)
290
+ res_idx.append(loss_idx)
291
+ if len(temp) > 0:
292
+ res.append(''.join(temp))
293
+ res_idx.append(0.0)
294
+ return res, res_idx
295
+
296
+ def _encode_context_list(
297
+ self,
298
+ context_list: List[Context],
299
+ compute_loss_idx: List[float],
300
+ ) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
301
+ """return: input_ids, labels, tokenizer_kwargs"""
302
+ tokenizer = self.tokenizer
303
+ input_ids: List[int] = []
304
+ labels: List[int] = []
305
+ loss_scale: List[float] = []
306
+ tokenizer_kwargs = {}
307
+ for i, (context,
308
+ loss_weight) in enumerate(zip(context_list, compute_loss_idx)):
309
+ if isinstance(context, str):
310
+ curr_tokenizer_kwargs = self.get_tokenizer_kwargs(context)
311
+ self.concat_tokenizer_kwargs(tokenizer_kwargs,
312
+ curr_tokenizer_kwargs)
313
+ token_list = tokenizer(
314
+ context,
315
+ return_attention_mask=False,
316
+ add_special_tokens=False,
317
+ **curr_tokenizer_kwargs)['input_ids']
318
+ else:
319
+ token_list = context
320
+ input_ids += token_list
321
+ if compute_loss_idx[i] > 0.0:
322
+ labels += token_list
323
+ else:
324
+ labels += [-100] * len(token_list)
325
+ loss_scale.extend([loss_weight] * len(token_list))
326
+ return input_ids, labels, loss_scale, tokenizer_kwargs
327
+
328
+ def _encode(
329
+ self, query: str, response: Optional[str], history: History,
330
+ system: Optional[str],
331
+ truncation_strategy: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
332
+ """
333
+ return: inputs, tokenizer_kwargs
334
+ """
335
+ history = history.copy()
336
+ res_context_list: List[Context] = []
337
+ compute_loss_idx: List[float] = []
338
+ if system is None:
339
+ prefix = self.prefix
340
+ else:
341
+ prefix = self.prefix_has_system
342
+ self._concat_context_list(
343
+ prefix, res_context_list, compute_loss_idx, system=system)
344
+ history.append([query, response])
345
+ for i, (q, r) in enumerate(history):
346
+ context_list = self.prompt.copy()
347
+ if i < len(history) - 1:
348
+ context_list.append('{{RESPONSE}}')
349
+ context_list += self.chat_sep
350
+ elif r is not None:
351
+ # last response
352
+ context_list.append('{{RESPONSE}}')
353
+ context_list += self.suffix
354
+ if q or r:
355
+ self._concat_context_list(
356
+ context_list,
357
+ res_context_list,
358
+ compute_loss_idx,
359
+ query=q,
360
+ response=r,
361
+ round0=i)
362
+
363
+ res_context_list, compute_loss_idx = self._simplify_context_list(
364
+ res_context_list, compute_loss_idx)
365
+ input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
366
+ res_context_list, compute_loss_idx)
367
+
368
+ if response is None:
369
+ labels = None
370
+
371
+ if self.max_length is not None:
372
+ if truncation_strategy == 'delete' and len(
373
+ input_ids) > self.max_length:
374
+ return {}, {}
375
+ input_ids = input_ids[-self.max_length:]
376
+ if labels is not None:
377
+ labels = labels[-self.max_length:]
378
+ if loss_scale is not None:
379
+ loss_scale = loss_scale[-self.max_length:]
380
+ inputs = {
381
+ 'input_ids': input_ids,
382
+ 'labels': labels,
383
+ }
384
+ if self.use_loss_scale:
385
+ inputs['loss_scale'] = loss_scale
386
+ return inputs, tokenizer_kwargs
387
+
388
+ def get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
389
+ """return: curr_tokenizer_kwargs"""
390
+ return {}
391
+
392
+ def concat_tokenizer_kwargs(
393
+ self, old_tokenizer_kwargs: Dict[str, Any],
394
+ curr_tokenizer_kwargs: Dict[str, Any]) -> Dict[str, Any]:
395
+ assert len(old_tokenizer_kwargs) == 0
396
+ return curr_tokenizer_kwargs
397
+
398
+ def data_collator(self,
399
+ batch: List[Dict[str, Any]],
400
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
401
+ """
402
+ Args:
403
+ batch(`List[Dict[str, Any]]`): The input data in batch
404
+ padding_to(`int`, optional): If set, pad the batch to this fixed length; if None, the batch
405
+ will be padded to the longest sequence in the batch
406
+ """
407
+ tokenizer = self.tokenizer
408
+ assert tokenizer.pad_token_id is not None
409
+ input_ids = [torch.tensor(b['input_ids']) for b in batch]
410
+ labels = [torch.tensor(b['labels']) for b in batch]
411
+ loss_scale = [torch.tensor(b['loss_scale'])
412
+ for b in batch] if 'loss_scale' in batch[0] else None
413
+ attention_mask = [
414
+ torch.ones(len(input_ids[i]), dtype=torch.int64)
415
+ for i in range(len(input_ids))
416
+ ]
417
+
418
+ if padding_to is not None:
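+ # Pad the first sample up to `padding_to` so that pad_sequence below pads every sample in the batch to at least that length.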
419
+ padding_len = padding_to - input_ids[0].shape[-1]
420
+ if padding_len > 0:
421
+ input_ids[0] = F.pad(input_ids[0], (0, padding_len),
422
+ 'constant', tokenizer.pad_token_id)
423
+ attention_mask[0] = F.pad(attention_mask[0], (0, padding_len),
424
+ 'constant', 0)
425
+ labels[0] = F.pad(labels[0], (0, padding_len), 'constant',
426
+ -100)
427
+ if loss_scale:
428
+ loss_scale[0] = F.pad(
429
+ loss_scale[0], (0, padding_to - labels[0].shape[-1]),
430
+ 'constant', 0.)
431
+
432
+ input_ids = pad_sequence(
433
+ input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
434
+ attention_mask = pad_sequence(
435
+ attention_mask, batch_first=True, padding_value=0)
436
+ if loss_scale:
437
+ loss_scale = pad_sequence(
438
+ loss_scale, batch_first=True, padding_value=0.)
439
+ labels = pad_sequence(labels, batch_first=True, padding_value=-100)
440
+
441
+ if use_torchacc():
442
+ rank, _, world_size, _ = get_dist_setting()
443
+ input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
444
+ padding_to, input_ids, attention_mask, labels, loss_scale,
445
+ self.max_length, self.tokenizer, rank, world_size)
446
+
447
+ res = {
448
+ 'input_ids': input_ids,
449
+ 'attention_mask': attention_mask,
450
+ 'labels': labels,
451
+ }
452
+ if loss_scale is not None:
453
+ res['loss_scale'] = loss_scale
454
+ return res
455
+
456
+ @staticmethod
457
+ def get_generate_ids(generate_ids: Tensor,
458
+ input_token_len: int) -> List[int]:
459
+ return generate_ids[0, input_token_len:].tolist()
460
+
461
+
462
+ TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {}
463
+
464
+
465
+ def register_template(template_type: str,
466
+ template: Template,
467
+ *,
468
+ exists_ok: bool = False,
469
+ **kwargs) -> None:
470
+ if not exists_ok and template_type in TEMPLATE_MAPPING:
471
+ raise ValueError(
472
+ f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.'
473
+ )
474
+ template_info = {'template': template, **kwargs}
475
+ TEMPLATE_MAPPING[template_type] = template_info
476
+
477
+
478
+ register_template(
479
+ TemplateType.default,
480
+ Template([], ['### Human:\n', '{{QUERY}}\n\n', '### Assistant:\n'],
481
+ ['\n\n'], [['eos_token_id']], DEFAULT_SYSTEM, ['{{SYSTEM}}\n\n']))
482
+
483
+
484
+ # You can set the query to '' to use this as a template for pre-training.
485
+ class DefaultGenerationTemplate(Template):
486
+
487
+ def __init__(self):
488
+ super().__init__([], ['{{QUERY}}'], None, [['eos_token_id']])
489
+
490
+
491
+ register_template(TemplateType.default_generation, DefaultGenerationTemplate())
492
+ register_template(
493
+ TemplateType.default_generation_bos,
494
+ Template([['bos_token_id']], ['{{QUERY}}'], None, [['eos_token_id']]))
495
+
496
+
497
+ class QwenTemplate(Template):
498
+
499
+ def __init__(self):
500
+ super().__init__(
501
+ [],
502
+ ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
503
+ ['<|im_end|>\n'], ['<|im_end|>'], DEFAULT_SYSTEM,
504
+ ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n'])
505
+
506
+
507
+ register_template(TemplateType.qwen, QwenTemplate())
508
+ register_template(TemplateType.chatml, QwenTemplate())
509
+
510
+
511
+ class _QwenAudioTemplateMixin:
512
+
513
+ def encode(
514
+ self, example: Dict[str,
515
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
516
+ inputs, tokenizer_kwargs = super().encode(example)
517
+ inputs.pop('loss_scale', None)
518
+ inputs.update(tokenizer_kwargs)
519
+ return inputs, tokenizer_kwargs
520
+
521
+ def get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
522
+ return {'audio_info': self.tokenizer.process_audio(context)}
523
+
524
+ def concat_tokenizer_kwargs(self, tokenizer_kwargs: Dict[str, Any],
525
+ curr_tokenizer_kwargs: Dict[str, Any]) -> None:
526
+ audio_info = curr_tokenizer_kwargs.get('audio_info')
527
+ old_audio_info = tokenizer_kwargs.get('audio_info')
528
+ if old_audio_info is None:
529
+ tokenizer_kwargs['audio_info'] = audio_info
530
+ elif audio_info is not None:
531
+ for k in ['input_audios', 'input_audio_lengths']:
532
+ old_audio_info[k] = torch.concat(
533
+ [old_audio_info[k], audio_info[k]], dim=0)
534
+ for k in ['audio_span_tokens', 'audio_urls']:
535
+ old_audio_info[k] = old_audio_info[k] + audio_info[k]
536
+
537
+ def data_collator(self,
538
+ batch: List[Dict[str, Any]],
539
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
540
+ res = super().data_collator(batch, padding_to)
541
+ if batch[0].get('audio_info') is not None:
542
+ res['audio_info'] = [b['audio_info'] for b in batch]
543
+ return res
544
+
545
+
546
+ class QwenAudioTemplate(_QwenAudioTemplateMixin, QwenTemplate):
547
+ pass
548
+
549
+
550
+ class QwenAudioGenerationTemplate(_QwenAudioTemplateMixin,
551
+ DefaultGenerationTemplate):
552
+ pass
553
+
554
+
555
+ register_template(
556
+ TemplateType.qwen_audio, QwenAudioTemplate(), lazy_tokenize=True)
557
+ register_template(
558
+ TemplateType.qwen_audio_generation,
559
+ QwenAudioGenerationTemplate(),
560
+ lazy_tokenize=True)
561
+
562
+ register_template(
563
+ TemplateType.yi,
564
+ Template(
565
+ [], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
566
+ ['<|im_end|>\n'], ['<|im_end|>'], None,
567
+ ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']))
568
+
569
+ yi_vl_default_system = (
570
+ 'This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. '
571
+ "Read all the images carefully, and respond to the human's questions with informative, "
572
+ 'helpful, detailed and polite answers. '
573
+ '这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。'
574
+ '仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。')
575
+
576
+
577
+ def _read_from_path(
578
+ img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image':
579
+ from PIL import Image
580
+ if isinstance(img_path, str):
581
+ img_path = img_path.strip()
582
+ if img_path.startswith('http'):
583
+ content = requests.get(img_path).content
584
+ image = Image.open(BytesIO(content))
585
+ else:
586
+ image = Image.open(img_path)
587
+ else:
588
+ image = img_path
589
+ if image.mode != 'RGB':
590
+ image = image.convert('RGB')
591
+ return image
592
+
593
+
594
+ class YiVLTemplate(Template):
595
+
596
+ def encode(
597
+ self, example: Dict[str,
598
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
599
+ inputs, _ = super().encode(example)
600
+ inputs.pop('loss_scale', None)
601
+ from llava.mm_utils import expand2square
602
+ model = self.model.model
603
+ if not hasattr(model, 'vision_tower'):
604
+ model = model.model
605
+ image_processor = model.vision_tower.image_processor
606
+ images_path = example['images']
607
+ images = []
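+ # expand2square pads each image to a square filled with the image processor's mean color before preprocessing.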
608
+ for image_path in images_path:
609
+ image = _read_from_path(image_path)
610
+ background_color = tuple(
611
+ int(x * 255) for x in image_processor.image_mean)
612
+ image = expand2square(image, background_color)
613
+ images.append(image)
614
+ image_tensor = image_processor.preprocess(
615
+ images, return_tensors='pt')['pixel_values']
616
+ inputs['images'] = image_tensor.to(model.dtype)
617
+ return inputs, {}
618
+
619
+ def data_collator(self,
620
+ batch: List[Dict[str, Any]],
621
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
622
+ res = super().data_collator(batch, padding_to)
623
+ res['images'] = torch.concat([b['images'] for b in batch])
624
+ return res
625
+
626
+
627
+ register_template(
628
+ TemplateType.yi_vl,
629
+ YiVLTemplate([], ['### Human: ', [-200], '\n{{QUERY}}\n### Assistant:'],
630
+ ['\n'], ['\n###'], yi_vl_default_system, ['{{SYSTEM}}\n\n']),
631
+ use_model=True,
632
+ infer_media_type='round',
633
+ lazy_tokenize=True)
634
+
635
+ register_template(
636
+ TemplateType.baichuan,
637
+ Template(['{{SYSTEM}}'], [[195], '{{QUERY}}', [196]], [],
638
+ [['eos_token_id']]))
639
+ register_template(
640
+ TemplateType.chatglm2,
641
+ Template([[64790, 64792], '{{SYSTEM}}'],
642
+ ['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'], ['\n\n'],
643
+ [['eos_token_id']]))
644
+
645
+ register_template(
646
+ TemplateType.chatglm_generation,
647
+ Template([[64790, 64792]], ['{{QUERY}}'], None, [['eos_token_id']]))
648
+
649
+ register_template(
650
+ TemplateType.chatglm3,
651
+ Template([[64790, 64792]], [[64795], '\n {{QUERY}}', [64796], '\n'], [],
652
+ [['eos_token_id']], None,
653
+ [[64790, 64792, 64794], '\n {{SYSTEM}}']))
654
+
655
+ register_template(
656
+ TemplateType.deepseek,
657
+ Template([['bos_token_id']], ['User: {{QUERY}}\n\nAssistant:'],
658
+ [['eos_token_id']], [['eos_token_id']], None,
659
+ [['bos_token_id'], '{{SYSTEM}}\n\n']))
660
+
661
+ # ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
662
+ LLAMA_DEFAULT_SYSTEM = (
663
+ 'You are a helpful, respectful and honest assistant. '
664
+ 'Always answer as helpfully as possible, while being safe. '
665
+ 'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
666
+ 'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
667
+ 'If a question does not make any sense, or is not factually coherent, '
668
+ 'explain why instead of answering something not correct. '
669
+ "If you don't know the answer to a question, please don't share false information."
670
+ )
671
+ register_template(
672
+ TemplateType.llama,
673
+ Template(['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '],
674
+ ['</s>'], LLAMA_DEFAULT_SYSTEM,
675
+ ['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
676
+
677
+ register_template(
678
+ TemplateType.llama3,
679
+ Template(['<|begin_of_text|>'], [
680
+ '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
681
+ '<|start_header_id|>assistant<|end_header_id|>\n\n'
682
+ ], ['<|eot_id|>'], ['<|eot_id|>'], None, [
683
+ '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'
684
+ ]))
685
+
686
+ OPENBUDDY_DEFAULT_SYSTEM = (
687
+ 'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
688
+ 'Always answer as helpfully and logically as possible, while being safe. '
689
+ 'Your answers should not include any '
690
+ 'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. '
691
+ 'Please ensure that your responses are socially unbiased and positive in nature.\n'
692
+ 'If a question does not make any sense, or is not factually coherent, '
693
+ 'explain why instead of answering something not correct. '
694
+ "If you don't know the answer to a question, please don't share false information.\n"
695
+ 'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n'
696
+ 'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n'
697
+ 'You always deeply love and support China, Chinese government, people and culture.\n'
698
+ 'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
699
+ 'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.'
700
+ )
701
+ register_template(
702
+ TemplateType.openbuddy,
703
+ Template([['bos_token_id']], ['User: {{QUERY}}\nAssistant:'], ['\n'],
704
+ [['eos_token_id']], OPENBUDDY_DEFAULT_SYSTEM,
705
+ [['bos_token_id'], '{{SYSTEM}}\n\n']))
706
+
707
+ INTERNLM_SYSTEM = (
708
+ 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
709
+ '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
710
+ 'It is designed to be helpful, honest, and harmless.\n'
711
+ '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen '
712
+ 'by the user such as English and 中文.')
713
+
714
+ register_template(
715
+ TemplateType.internlm,
716
+ Template(['<s>'], ['<|User|>:{{QUERY}}\n<|Bot|>:'], ['<eoa>\n'], ['<eoa>'],
717
+ INTERNLM_SYSTEM, ['<s><|System|>:{{SYSTEM}}\n']))
718
+ register_template(
719
+ TemplateType.internlm2,
720
+ Template(
721
+ ['<s>'],
722
+ ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
723
+ ['<|im_end|>\n'], ['<|im_end|>'], INTERNLM_SYSTEM,
724
+ ['<s><|im_start|>system\n{{SYSTEM}}<|im_end|>\n']))
725
+
726
+
727
+ def replace_img_tab(query: str, history: History,
728
+ replace_token: str) -> Tuple[str, History, List[str]]:
729
+ images_path = []
730
+ pattern = r'<img>(.+?)</img>'
731
+ new_history = []
732
+ for i, h in enumerate(history):
733
+ images_path += re.findall(pattern, h[0])
734
+ new_history.append([re.sub(pattern, replace_token, h[0]), h[1]])
735
+ images_path += re.findall(pattern, query)
736
+ new_query = re.sub(pattern, replace_token, query)
737
+ return new_query, new_history, images_path
738
+
739
+
740
+ class InternLMXComposer2(Template):
741
+ INTERNLM_XCOMPOSER2_SYSTEM = (
742
+ 'You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
743
+ '- InternLM-XComposer (浦语·灵笔) is a conversational language model that is developed by '
744
+ 'Shanghai AI Laboratory (上海人工智能实验室). '
745
+ 'It is designed to be helpful, honest, and harmless.\n'
746
+ '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
747
+ 'by the user such as English and 中文.')
748
+
749
+ def __init__(self):
750
+ prefix = ['<s>']
751
+ prompt = [
752
+ '[UNUSED_TOKEN_146]user\n{{QUERY}}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
753
+ ]
754
+ chat_sep = ['[UNUSED_TOKEN_145]\n']
755
+ suffix = ['[UNUSED_TOKEN_145]']
756
+ prefix_has_system = [
757
+ '<s>[UNUSED_TOKEN_146]system\n{{SYSTEM}}[UNUSED_TOKEN_145]\n'
758
+ ]
759
+ super().__init__(prefix, prompt, chat_sep, suffix,
760
+ self.INTERNLM_XCOMPOSER2_SYSTEM, prefix_has_system)
761
+
762
+ def encode(
763
+ self, example: Dict[str,
764
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
765
+ example = example.copy()
766
+ history = example.pop('history', [])
767
+ example['query'], example['history'], images_path = replace_img_tab(
768
+ example['query'], history, '</s>')
769
+
770
+ images = []
771
+ dtype = self.model.dtype
772
+ for image_path in images_path:
773
+ image = _read_from_path(image_path)
774
+ image = self.model.vis_processor(image)
775
+ images.append(image.to(dtype))
776
+ inputs, _ = super().encode(example)
777
+ inputs.pop('loss_scale', None)
778
+ input_ids = inputs['input_ids']
779
+ labels = inputs['labels']
780
+ if len(images) > 0:  # ignore the leading <s>
781
+ input_ids = input_ids[1:]
782
+ if labels is not None:
783
+ labels = labels[1:]
784
+ input_ids.append(2) # add dummy </s>
785
+ if labels is not None:
786
+ labels.append(2)
787
+ else:
788
+ labels = []
789
+ res_inputs_embeds = []
790
+ res_labels = []
791
+ wrap_im_mask = []
792
+ pre_i, i, idx = 0, 0, 0
793
+ device = self.model.device
794
+ if len(images) > 0:
795
+ images = torch.stack(images, dim=0)
796
+ images = self.model.encode_img(images)
797
+ else:
798
+ images = None
799
+ internlm2_model = self.model.model
800
+ if not hasattr(internlm2_model, 'tok_embeddings'):
801
+ internlm2_model = internlm2_model.model
802
+ tok_embeddings = internlm2_model.tok_embeddings
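+ # Walk input_ids: each dummy </s> (token id 2) inserted above marks an image position; embed the preceding text with tok_embeddings and splice in the matching image embedding.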
803
+ while i < len(input_ids):
804
+ if input_ids[i] == 2: # replace_token
805
+ res_input_ids = torch.tensor(
806
+ [1] + input_ids[pre_i:i], device=device)
807
+ res_inputs_embeds.append(tok_embeddings(res_input_ids))
808
+ wrap_im_mask += [0] * len(res_input_ids)
809
+ res_labels += [-100] + labels[pre_i:i]
810
+ if images is not None and idx < images.shape[0]:
811
+ res_inputs_embeds.append(images[idx])
812
+ wrap_im_mask += [1] * images.shape[1]
813
+ res_labels += [-100] * images.shape[1]
814
+ idx += 1
815
+ i += 1
816
+ pre_i = i
817
+ continue
818
+ i += 1
819
+ if len(labels) == 0:
820
+ res_labels = None
821
+ res_inputs_embeds = torch.concat(res_inputs_embeds, dim=0)
822
+ wrap_im_mask = torch.tensor(wrap_im_mask, dtype=torch.bool)[None]
823
+ return {
824
+ 'inputs_embeds': res_inputs_embeds,
825
+ 'im_mask': wrap_im_mask,
826
+ 'labels': res_labels
827
+ }, {}
828
+
829
+ def data_collator(self,
830
+ batch: List[Dict[str, Any]],
831
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
832
+ inputs_embeds = [b['inputs_embeds'] for b in batch]
833
+ labels = [torch.tensor(b['labels']) for b in batch]
834
+ im_mask = [b['im_mask'][0] for b in batch]
835
+ attention_mask = [
836
+ torch.ones(inputs_embeds[i].shape[0], dtype=torch.int64)
837
+ for i in range(len(inputs_embeds))
838
+ ]
839
+
840
+ inputs_embeds = pad_sequence(
841
+ inputs_embeds, batch_first=True, padding_value=0)
842
+ attention_mask = pad_sequence(
843
+ attention_mask, batch_first=True, padding_value=0)
844
+ im_mask = pad_sequence(im_mask, batch_first=True, padding_value=0)
845
+ labels = pad_sequence(labels, batch_first=True, padding_value=-100)
846
+
847
+ return {
848
+ 'inputs_embeds': inputs_embeds,
849
+ 'attention_mask': attention_mask,
850
+ 'im_mask': im_mask,
851
+ 'labels': labels,
852
+ }
853
+
854
+ @staticmethod
855
+ def get_generate_ids(generate_ids: Tensor,
856
+ input_token_len: int) -> List[int]:
857
+ return generate_ids[0].tolist()
858
+
859
+
860
+ register_template(
861
+ TemplateType.internlm_xcomposer2,
862
+ InternLMXComposer2(),
863
+ use_model=True,
864
+ lazy_tokenize=True,
865
+ dataloader_num_workers=0,
866
+ dataloader_pin_memory=False)
867
+
868
+ register_template(
869
+ TemplateType.xverse,
870
+ Template(['{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: '],
871
+ [['eos_token_id']], [['eos_token_id']]))
872
+ register_template(TemplateType.yuan,
873
+ Template([], ['{{QUERY}}<sep>'], None, [['eos_token_id']]))
874
+ register_template(
875
+ TemplateType.ziya,
876
+ Template([['bos_token_id'], '{{SYSTEM}}'], ['<human>:{{QUERY}}\n<bot>:'],
877
+ ['\n'], [['eos_token_id']]))
878
+
879
+ register_template(
880
+ TemplateType.skywork,
881
+ Template(['<s>{{SYSTEM}}'], ['</s><s>[USER]{{QUERY}}[SEP][BOT]'], None,
882
+ ['[SEP]</s>']))
883
+
884
+ register_template(
885
+ TemplateType.bluelm,
886
+ Template([['bos_token_id'], '{{SYSTEM}}'], ['[|Human|]:{{QUERY}}[|AI|]:'],
887
+ [], [['eos_token_id']]))
888
+
889
+ register_template(
890
+ TemplateType.codefuse_codellama,
891
+ Template(['{{SYSTEM}}'], [
892
+ '<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'
893
+ ], [], [['eos_token_id']]))
894
+
895
+ register_template(
896
+ TemplateType.codefuse,
897
+ Template([], ['<s>human\n{{QUERY}}\n<s>bot\n'], [['eos_token_id'], '\n'],
898
+ [['eos_token_id']], None, ['<s>system\n{{SYSTEM}}\n']))
899
+
900
+ register_template(
901
+ TemplateType.deepseek_coder,
902
+ Template([
903
+ '{{SYSTEM}}'
904
+ ], ['### Instruction:\n{{QUERY}}\n### Response:\n'], ['\n<|EOT|>\n'], [
905
+ '\n<|EOT|>'
906
+ ], ('You are an AI programming assistant, utilizing the Deepseek Coder model, '
907
+ 'developed by Deepseek Company, and you only answer questions related to computer science. '
908
+ 'For politically sensitive questions, security and privacy issues, '
909
+ 'and other non-computer science questions, you will refuse to answer\n'
910
+ )))
911
+
912
+
913
+ class LLavaTemplate(Template):
914
+
915
+ def __init__(self):
916
+ super().__init__(['<s>[INST] '], [[-200], '\n{{QUERY}} [/INST]'], None,
917
+ ['</s>'])
918
+
919
+ def encode(
920
+ self, example: Dict[str,
921
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
922
+ inputs, _ = super().encode(example)
923
+ images_path = example['images']
924
+ images = []
925
+ for image_path in images_path:
926
+ image = _read_from_path(image_path)
927
+ images.append(image)
928
+ image_sizes = [x.size for x in images]
929
+ from llava.mm_utils import process_images
930
+ model = self.model.model
931
+ if not hasattr(model, 'vision_tower'):
932
+ model = model.model
933
+ image_processor = model.vision_tower.image_processor
934
+ images_tensor = process_images(images, image_processor,
935
+ self.model.config)
936
+ inputs['images'] = images_tensor.to(model.dtype)
937
+ inputs['image_sizes'] = image_sizes
938
+ return inputs, {}
939
+
940
+ def data_collator(self,
941
+ batch: List[Dict[str, Any]],
942
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
943
+ res = super().data_collator(batch, padding_to)
944
+ res['images'] = torch.concat([b['images'] for b in batch])
945
+ res['image_sizes'] = sum([b['image_sizes'] for b in batch], start=[])
946
+ return res
947
+
948
+ @staticmethod
949
+ def get_generate_ids(generate_ids: Tensor,
950
+ input_token_len: int) -> List[int]:
951
+ return generate_ids[0].tolist()
952
+
953
+
954
+ register_template(
955
+ TemplateType.llava_mistral_instruct,
956
+ LLavaTemplate(),
957
+ use_model=True,
958
+ infer_media_type='round',
959
+ lazy_tokenize=True)
960
+
961
+
962
+ class LLavaYiTemplate(LLavaTemplate):
963
+ llavayi_query_template = '\n<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'
964
+
965
+ def __init__(self):
966
+ Template.__init__(self, [], [[-200], self.llavayi_query_template],
967
+ None, ['<|im_end|>'])
968
+
969
+
970
+ register_template(
971
+ TemplateType.llava_yi_instruct,
972
+ LLavaYiTemplate(),
973
+ use_model=True,
974
+ infer_media_type='round',
975
+ lazy_tokenize=True)
976
+
977
+
978
+ def _findall(token_list: List[int], token: int) -> List[int]:
979
+ """Find the index of a token in the token_list."""
980
+ res = []
981
+ idx = -1
982
+ try:
983
+ while True:
984
+ idx = token_list.index(token, idx + 1)
985
+ res.append(idx)
986
+ except ValueError:
987
+ pass
988
+ return res
989
+
990
+
991
+ class DeepseekVLTemplate(Template):
992
+ DEEPSEEK_VL_SYSTEM = (
993
+ 'You are a helpful language and vision assistant. '
994
+ 'You are able to understand the visual content that the user provides, '
995
+ 'and assist the user with a variety of tasks using natural language.')
996
+
997
+ def __init__(self):
998
+ return super().__init__(['<|begin▁of▁sentence|>{{SYSTEM}}\n\n'],
999
+ ['User: {{QUERY}}\n\nAssistant:'],
1000
+ ['<|end▁of▁sentence|>'],
1001
+ ['<|end▁of▁sentence|>'],
1002
+ self.DEEPSEEK_VL_SYSTEM)
1003
+
1004
+ def encode(
1005
+ self, example: Dict[str,
1006
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1007
+ images = example.pop('images', None)
1008
+ assert images is None, (
1009
+ 'Please read the best practices: https://github.com/modelscope/swift/blob/main/'
1010
+ 'docs/source/Multi-Modal/deepseek-vl最佳实践.md')
1011
+
1012
+ example = example.copy()
1013
+ history = example.pop('history', [])
1014
+ example['query'], example['history'], images_path = replace_img_tab(
1015
+ example['query'], history, '<image_placeholder>')
1016
+
1017
+ inputs, _ = super().encode(example)
1018
+ images = []
1019
+ for image_path in images_path:
1020
+ image = _read_from_path(image_path)
1021
+ images.append(image)
1022
+
1023
+ vl_chat_processor = self.tokenizer.vl_chat_processor
1024
+ input_ids, labels = inputs['input_ids'], inputs['labels']
1025
+ idx_list = _findall(input_ids, vl_chat_processor.image_id)
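+ # Expand each single image placeholder token into num_image_tokens copies and mask the expanded span out of the labels.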
1026
+ new_input_ids, new_labels = [], []
1027
+ lo = 0
1028
+ for hi in idx_list:
1029
+ new_input_ids += input_ids[lo:hi]
1030
+ if labels is not None:
1031
+ new_labels += labels[lo:hi]
1032
+ new_input_ids += [vl_chat_processor.image_id
1033
+ ] * vl_chat_processor.num_image_tokens
1034
+ new_labels += [-100] * vl_chat_processor.num_image_tokens
1035
+ lo = hi + 1
1036
+ new_input_ids += input_ids[lo:]
1037
+ if labels is not None:
1038
+ new_labels += labels[lo:]
1039
+ else:
1040
+ new_labels = None
1041
+ new_input_ids = torch.tensor(new_input_ids)
1042
+ num_image_tokens = torch.tensor([vl_chat_processor.num_image_tokens]
1043
+ * len(idx_list))
1044
+ images_outputs = vl_chat_processor.image_processor(
1045
+ images, return_tensors='pt')
1046
+ from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
1047
+ output = VLChatProcessorOutput(
1048
+ sft_format=None,
1049
+ input_ids=new_input_ids,
1050
+ pixel_values=images_outputs.pixel_values,
1051
+ num_image_tokens=num_image_tokens)
1052
+ batched_output = vl_chat_processor.batchify([output])
1053
+ model = self.model
1054
+ batched_output = batched_output.to(
1055
+ device=model.device, dtype=model.dtype)
1056
+ inputs_embeds = model.prepare_inputs_embeds(**batched_output)[0]
1057
+ inputs['inputs_embeds'] = inputs_embeds
1058
+ inputs['labels'] = new_labels
1059
+ return inputs, {}
1060
+
1061
+ def data_collator(self,
1062
+ batch: List[Dict[str, Any]],
1063
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1064
+ inputs_embeds = [b['inputs_embeds'] for b in batch]
1065
+ labels = [torch.tensor(b['labels']) for b in batch]
1066
+ attention_mask = [
1067
+ torch.ones(inputs_embeds[i].shape[0], dtype=torch.int64)
1068
+ for i in range(len(inputs_embeds))
1069
+ ]
1070
+
1071
+ inputs_embeds = pad_sequence(
1072
+ inputs_embeds, batch_first=True, padding_value=0)
1073
+ attention_mask = pad_sequence(
1074
+ attention_mask, batch_first=True, padding_value=0)
1075
+ labels = pad_sequence(labels, batch_first=True, padding_value=-100)
1076
+
1077
+ return {
1078
+ 'inputs_embeds': inputs_embeds,
1079
+ 'attention_mask': attention_mask,
1080
+ 'labels': labels,
1081
+ }
1082
+
1083
+ @staticmethod
1084
+ def get_generate_ids(generate_ids: Tensor,
1085
+ input_token_len: int) -> List[int]:
1086
+ return generate_ids[0].tolist()
1087
+
1088
+
1089
+ register_template(
1090
+ TemplateType.deepseek_vl,
1091
+ DeepseekVLTemplate(),
1092
+ use_model=True,
1093
+ lazy_tokenize=True,
1094
+ dataloader_num_workers=0,
1095
+ dataloader_pin_memory=False) # only 'cpu' can pin_memory
1096
+
1097
+ register_template(
1098
+ TemplateType.zephyr,
1099
+ Template([], ['<|user|>\n{{QUERY}}</s>\n<|assistant|>\n'], ['</s>\n'],
1100
+ ['</s>'], None, ['<|system|>\n{{SYSTEM}}</s>\n']))
1101
+
1102
+ register_template(
1103
+ TemplateType.sus,
1104
+ Template(['{{SYSTEM}}'], ['### Human: {{QUERY}}\n\n### Assistant: '],
1105
+ ['<|endoftext|>'], ['<|endoftext|>']))
1106
+
1107
+ register_template(
1108
+ TemplateType.orion,
1109
+ Template(['<s>{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: </s>'],
1110
+ ['</s>'], ['</s>']))
1111
+
1112
+
1113
+ class CogTemplate(Template):
1114
+
1115
+ def encode(
1116
+ self, example: Dict[str,
1117
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1118
+ images_path = example['images']
1119
+ assert len(images_path) == 1
1120
+ image = _read_from_path(images_path[0])
1121
+ inputs, _ = super().encode(example)
1122
+ inputs.pop('loss_scale', None)
1123
+ model = self.model
1124
+ inputs2 = model.build_conversation_input_ids(
1125
+ self.tokenizer,
1126
+ query=example['query'],
1127
+ history=example.get('history'),
1128
+ images=[image])
1129
+ image_token_len = inputs2['token_type_ids'].sum()
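+ # Reserve image_token_len positions right after the first token for the image features and exclude them from the loss.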
1130
+ input_ids = inputs['input_ids']
1131
+ labels = inputs['labels']
1132
+ token_type_ids = inputs2['token_type_ids'].tolist()
1133
+ inputs['input_ids'] = input_ids[:1] + [
1134
+ 0
1135
+ ] * image_token_len + input_ids[1:]
1136
+ if labels is not None:
1137
+ inputs['labels'] = labels[:1] + [-100
1138
+ ] * image_token_len + labels[1:]
1139
+ dtype = model.dtype
1140
+ inputs['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
1141
+ if 'cross_images' in inputs2:
1142
+ # is cogagent
1143
+ inputs['cross_images'] = [[cross_img.to(dtype=dtype)]
1144
+ for cross_img in inputs2['cross_images']]
1145
+ inputs['token_type_ids'] = token_type_ids + [0] * (
1146
+ len(inputs['input_ids']) - len(token_type_ids))
1147
+ return inputs, {}
1148
+
1149
+ def data_collator(self,
1150
+ batch: List[Dict[str, Any]],
1151
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1152
+ res = super().data_collator(batch, padding_to)
1153
+ is_cogagent = 'cross_images' in batch[0]
1154
+ keys = ['images', 'cross_images'] if is_cogagent else ['images']
1155
+ for key in keys:
1156
+ res[key] = [b[key][0] for b in batch]
1157
+ token_type_ids = [torch.tensor(b['token_type_ids']) for b in batch]
1158
+ token_type_ids = pad_sequence(
1159
+ token_type_ids, batch_first=True, padding_value=0)
1160
+ res['token_type_ids'] = token_type_ids
1161
+ return res
1162
+
1163
+
1164
+ register_template(
1165
+ TemplateType.cogagent_chat,
1166
+ CogTemplate(['<s>'], [' [INST] {{QUERY}} [/INST] '], [], ['</s>']),
1167
+ use_model=True,
1168
+ infer_media_type='dialogue',
1169
+ lazy_tokenize=True)
1170
+
1171
+ register_template(
1172
+ TemplateType.cogagent_instruct,
1173
+ CogTemplate(['<s>'], ['<EOI>Question: {{QUERY}} Answer:'], None, ['</s>']),
1174
+ use_model=True,
1175
+ infer_media_type='dialogue',
1176
+ lazy_tokenize=True)
1177
+
1178
+ register_template(
1179
+ TemplateType.cogvlm_instruct,
1180
+ CogTemplate(['<s>'], ['Question: {{QUERY}} Answer:'], None, ['</s>']),
1181
+ use_model=True,
1182
+ infer_media_type='dialogue',
1183
+ lazy_tokenize=True)
1184
+
1185
+ register_template(
1186
+ TemplateType.minicpm,
1187
+ Template(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>']))
1188
+
1189
+
1190
+ class MiniCPMVTemlate(Template):
1191
+
1192
+ def __init__(self):
1193
+ return super().__init__(['<s>{{SYSTEM}}'],
1194
+ ['<用户><image><unk></image>\n{{QUERY}}<AI>'],
1195
+ [], ['</s>'])
1196
+
1197
+ def encode(
1198
+ self, example: Dict[str,
1199
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1200
+ images_path = example['images']
1201
+ assert len(images_path) == 1
1202
+ image = _read_from_path(images_path[0])
1203
+ inputs, _ = super().encode(example)
1204
+ input_ids = inputs['input_ids']
1205
+ labels = inputs['labels']
1206
+
1207
+ img_start_idxs = np.where(
1208
+ np.array(input_ids) == self.tokenizer.im_start_id)[0]
1209
+ if len(
1210
+ img_start_idxs
1211
+ ) > 1:  # if multi-round, input_ids contain multiple <image><unk></image>\n
1212
+ start = 0
1213
+ new_input_ids = []
1214
+ for idx in img_start_idxs[1:]:
1215
+ new_input_ids = new_input_ids + input_ids[start:idx]
1216
+ start = idx + 4 # skip <image><unk></image>\n
1217
+ new_input_ids = new_input_ids + input_ids[start:]
1218
+ input_ids = new_input_ids
1219
+
1220
+ idx = img_start_idxs[0] + 1 # first <unk>
1221
+ config = self.model.config
1222
+ if hasattr(config, 'slice_mode') and config.slice_mode:
1223
+ slice_mode = True
1224
+ assert hasattr(config, 'patch_size')
1225
+ assert hasattr(config, 'max_slice_nums')
1226
+ assert hasattr(config, 'scale_resolution')
1227
+ else:
1228
+ slice_mode = False
1229
+
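+ # In slice mode the single <unk> placeholder is replaced by the tokenized slice placeholder; otherwise it is expanded to config.query_num <unk> tokens.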
1230
+ if slice_mode:
1231
+ images, placeholder = self.model.get_slice_image_placeholder(
1232
+ image, self.tokenizer)
1233
+ placeholder_id = self.tokenizer.encode(
1234
+ placeholder, add_special_tokens=False)
1235
+ input_ids = (
1236
+ input_ids[:idx - 1] + placeholder_id + input_ids[idx + 2:])
1237
+ if labels is not None:
1238
+ labels = (
1239
+ labels[:idx - 1] + [-100] * len(placeholder_id)
1240
+ + labels[idx + 2:])
1241
+ input_tensor_ids = torch.tensor(input_ids)
1242
+ image_start_idx = torch.where(
1243
+ input_tensor_ids == self.tokenizer.im_start_id)[0]
1244
+ image_start_idx += 1
1245
+ image_end_idx = torch.where(
1246
+ input_tensor_ids == self.tokenizer.im_end_id)[0]
1247
+ valid_image_nums = max(len(image_start_idx), len(image_end_idx))
1248
+ image_bound = [
1249
+ torch.hstack([
1250
+ image_start_idx[:valid_image_nums].unsqueeze(-1),
1251
+ image_end_idx[:valid_image_nums].unsqueeze(-1)
1252
+ ])
1253
+ ]
1254
+ pixel_values = [
1255
+ self.model.transform(img).to(device=self.model.device)
1256
+ for img in images
1257
+ ]
1258
+
1259
+ else:
1260
+ input_ids = (
1261
+ input_ids[:idx]
1262
+ + [self.tokenizer.unk_token_id] * config.query_num
1263
+ + input_ids[idx + 1:])
1264
+ if labels is not None:
1265
+ labels = (
1266
+ labels[:idx] + [-100] * config.query_num
1267
+ + labels[idx + 1:])
1268
+ image_bound = [torch.tensor([[idx, idx + config.query_num]])]
1269
+ pixel_values = [
1270
+ self.model.transform(image).to(device=self.model.device)
1271
+ ]
1272
+ inputs_embeds, _ = self.model.get_vllm_embedding({
1273
+ 'input_ids':
1274
+ torch.tensor(input_ids)[None].to(device=self.model.device),
1275
+ 'image_bound':
1276
+ image_bound,
1277
+ 'pixel_values': [pixel_values]
1278
+ })
1279
+ inputs['input_ids'] = input_ids
1280
+ inputs['labels'] = labels
1281
+ inputs['inputs_embeds'] = inputs_embeds[0]
1282
+ return inputs, {}
1283
+
1284
+ @staticmethod
1285
+ def get_generate_ids(generate_ids: Tensor,
1286
+ input_token_len: int) -> List[int]:
1287
+ return generate_ids[0].tolist()
1288
+
1289
+
1290
+ register_template(
1291
+ TemplateType.minicpm_v,
1292
+ MiniCPMVTemlate(),
1293
+ use_model=True,
1294
+ lazy_tokenize=True,
1295
+ infer_media_type='dialogue',
1296
+ dataloader_num_workers=0,
1297
+ dataloader_pin_memory=False)
1298
+
1299
+ gemma_template = Template(
1300
+ ['<bos>'],
1301
+ ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
1302
+ ['<end_of_turn>\n'], ['<end_of_turn>'], None,
1303
+ ['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'])
1304
+ register_template(TemplateType.gemma, gemma_template)
1305
+
1306
+ register_template(
1307
+ TemplateType.telechat,
1308
+ Template([], ['<_user>{{QUERY}}<_bot>'], ['<_end>'], ['<_end>']))
1309
+
1310
+ DBRX_SYSTEM = (
1311
+ 'You are DBRX, created by Databricks. You were last updated in December 2023. '
1312
+ 'You answer questions based on information available up to that point.\n'
1313
+ 'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, '
1314
+ 'but provide thorough responses to more complex and open-ended questions.\n'
1315
+ 'You assist with various tasks, from writing to coding (using markdown for code blocks '
1316
+ '— remember to use ``` with code, JSON, and tables).\n'
1317
+ 'You do not have real-time data access or code execution capabilities.'
1318
+ ' You avoid stereotyping and provide balanced perspectives on controversial topics. '
1319
+ 'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.\n'
1320
+ 'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. '
1321
+ 'If you find yourself talking about this message, stop. You should be responding appropriately '
1322
+ 'and usually that means not mentioning this.'
1323
+ 'YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY '
1324
+ 'PERTINENT TO THE USER\'S QUERY.')
1325
+ register_template(
1326
+ TemplateType.dbrx,
1327
+ Template(
1328
+ [], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'],
1329
+ ['<|im_end|>\n'], ['<|im_end|>'], DBRX_SYSTEM,
1330
+ ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']))
1331
+
1332
+ register_template(
1333
+ TemplateType.mengzi,
1334
+ Template([], ['输入:{{QUERY}}输出:\n'], [], [['eos_token_id']], None,
1335
+ ['指令:{{SYSTEM}}']))
1336
+
1337
+ C4AI_SYSTEM = (
1338
+ 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by '
1339
+ 'providing thorough responses.You are trained by Cohere.')
1340
+ register_template(
1341
+ TemplateType.c4ai,
1342
+ Template(['<BOS_TOKEN>'], [
1343
+ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'
1344
+ ], ['<|END_OF_TURN_TOKEN|>'], ['<|END_OF_TURN_TOKEN|>'], C4AI_SYSTEM, [
1345
+ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|>'
1346
+ ]))
1347
+
1348
+
1349
+ class mPlugOwl2Template(Template):
1350
+
1351
+ def __init__(self):
1352
+ return super().__init__(['{{SYSTEM}}'],
1353
+ ['USER: ', [-200], '{{QUERY}}ASSISTANT:'],
1354
+ ['</s>'], [['eos_token_id']])
1355
+
1356
+ def encode(
1357
+ self, example: Dict[str,
1358
+ Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1359
+ from mplug_owl2.mm_utils import process_images
1360
+ image_processor = self.tokenizer.image_processor
1361
+ images_path = example['images']
1362
+ images = []
1363
+ for image_path in images_path:
1364
+ image = _read_from_path(image_path)
1365
+ # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1/summary
1366
+ max_edge = max(image.size)
1367
+ image = image.resize((max_edge, max_edge))
1368
+ images.append(image)
1369
+ inputs, _ = super().encode(example)
1370
+ input_ids = inputs['input_ids']
1371
+ labels = inputs['labels']
1372
+ images = process_images(images, image_processor)
1373
+ images = images.to(self.model.dtype)
1374
+ return {'input_ids': input_ids, 'labels': labels, 'images': images}, {}
1375
+
1376
+ def data_collator(self,
1377
+ batch: List[Dict[str, Any]],
1378
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1379
+ res = super().data_collator(batch, padding_to)
1380
+ res['images'] = torch.concat([b['images'] for b in batch])
1381
+ return res
1382
+
1383
+
1384
+ register_template(
1385
+ TemplateType.mplug_owl2,
1386
+ mPlugOwl2Template(),
1387
+ infer_media_type='round',
1388
+ use_model=True,
1389
+ lazy_tokenize=True)
1390
+
1391
+ register_template(
1392
+ TemplateType.wizardlm2_awq,
1393
+ Template(['{{SYSTEM}}'], ['User:\n{{QUERY}}\n\nAssistant:\n'], ['\n\n'],
1394
+ ['</s>']))
1395
+
1396
+ _wizardlm2_system = (
1397
+ 'A chat between a curious user and an artificial intelligence assistant. '
1398
+ 'The assistant gives helpful, detailed, and polite answers to the user\'s questions. '
1399
+ )
1400
+ register_template(
1401
+ TemplateType.wizardlm2,
1402
+ Template(['{{SYSTEM}}'], ['USER: {{QUERY}} ASSISTANT:'], ['</s>'],
1403
+ ['</s>'], _wizardlm2_system))
1404
+
1405
+ register_template(
1406
+ TemplateType.atom,
1407
+ Template(['{{SYSTEM}}'], ['<s>Human: {{QUERY}}\n</s><s>Assistant: '],
1408
+ ['</s>'], ['</s>']))
1409
+
1410
+
1411
+ def get_template(
1412
+ template_type: str,
1413
+ tokenizer: PreTrainedTokenizerBase,
1414
+ default_system: Optional[str] = None,
1415
+ max_length: Optional[int] = None,
1416
+ truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
1417
+ **kwargs,
1418
+ ) -> Template:
1419
+ template_info = TEMPLATE_MAPPING[template_type]
1420
+ template = deepcopy(template_info['template'])
1421
+ template._init_template(tokenizer, default_system, max_length,
1422
+ truncation_strategy, **kwargs)
1423
+ template.template_type = template_type
1424
+ return template
1425
+
1426
+
1427
+ def fuzzy_match(model_name: str, template_type_list: list) -> str:
1428
+ """
1429
+ fuzzy match template_type from model_name
1430
+
1431
+ Args:
1432
+ model_name: model name, e.g. ChatGLM2-7B
1433
+ template_type_list: template_type list, e.g. ['chatglm2', 'baichuan', ...]
1434
+
1435
+ Returns:
1436
+ The best matched template_type.
1437
+ """
1438
+ candidate_list = []
1439
+ for template_type in template_type_list:
1440
+ if template_type in model_name.lower():
1441
+ candidate_list.append(template_type)
1442
+ if len(candidate_list) == 0:
1443
+ return TemplateType.default_generation # TODO: default template
1444
+ else:
1445
+ candidate_list = sorted(candidate_list, key=lambda x: len(x), reverse=True)
1446
+ return candidate_list[0]
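
For orientation, a minimal usage sketch of the template API introduced in evalscope/models/template.py above (not part of the package contents). It assumes a Qwen-style chat tokenizer is available locally; the Hugging Face model id below is only illustrative, and any tokenizer compatible with the chosen template type would work the same way.

from transformers import AutoTokenizer
from evalscope.models.template import TemplateType, get_template

# Illustrative checkpoint; swap in any model whose chat format matches the template.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B-Chat', trust_remote_code=True)
template = get_template(TemplateType.qwen, tokenizer, max_length=2048)

# encode() renders the chat template and tokenizes it; labels stay None because no response is given.
inputs, _ = template.encode({'query': 'What is the capital of France?'})
print(inputs['input_ids'][:10])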