xinference 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +47 -18
- xinference/api/oauth2/types.py +1 -0
- xinference/api/restful_api.py +9 -1
- xinference/client/restful/restful_client.py +12 -2
- xinference/conftest.py +13 -2
- xinference/core/supervisor.py +32 -1
- xinference/core/worker.py +139 -20
- xinference/deploy/cmdline.py +119 -20
- xinference/model/llm/__init__.py +4 -0
- xinference/model/llm/llm_family.json +627 -0
- xinference/model/llm/llm_family_modelscope.json +471 -0
- xinference/model/llm/pytorch/core.py +2 -0
- xinference/model/llm/pytorch/deepseek_vl.py +232 -0
- xinference/model/llm/pytorch/omnilmm.py +153 -0
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +3 -0
- xinference/thirdparty/deepseek_vl/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
- xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
- xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
- xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
- xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
- xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
- xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
- xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
- xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
- xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
- xinference/thirdparty/omnilmm/__init__.py +0 -0
- xinference/thirdparty/omnilmm/chat.py +216 -0
- xinference/thirdparty/omnilmm/constants.py +4 -0
- xinference/thirdparty/omnilmm/conversation.py +332 -0
- xinference/thirdparty/omnilmm/model/__init__.py +1 -0
- xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
- xinference/thirdparty/omnilmm/model/resampler.py +166 -0
- xinference/thirdparty/omnilmm/model/utils.py +563 -0
- xinference/thirdparty/omnilmm/train/__init__.py +13 -0
- xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
- xinference/thirdparty/omnilmm/utils.py +134 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.98516614.js +3 -0
- xinference/web/ui/build/static/js/main.98516614.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/139969fd25258eb7decc9505f30b779089bba50c402bb5c663008477c7bff73b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3f357ab57b8e7fade54c667f0e0ebf2787566f72bfdca0fea14e395b5c203753.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9d7c49815d97539207e5aab2fb967591b5fed7791218a0762539efc9491f36af.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d0d0b591d9adaf42b83ad6633f8b7c118541a4b80ea957c303d3bf9b86fbad0a.json +1 -0
- {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/METADATA +18 -5
- {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/RECORD +55 -28
- xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
- xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
- /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.98516614.js.LICENSE.txt} +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/LICENSE +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/WHEEL +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
# Copyright (c) 2023-2024 DeepSeek.
|
|
2
|
+
#
|
|
3
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
4
|
+
# this software and associated documentation files (the "Software"), to deal in
|
|
5
|
+
# the Software without restriction, including without limitation the rights to
|
|
6
|
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
7
|
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
|
8
|
+
# subject to the following conditions:
|
|
9
|
+
#
|
|
10
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
# copies or substantial portions of the Software.
|
|
12
|
+
#
|
|
13
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
15
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
16
|
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
17
|
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
18
|
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
19
|
+
|
|
20
|
+
"""
|
|
21
|
+
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import dataclasses
|
|
25
|
+
from enum import IntEnum, auto
|
|
26
|
+
from typing import Dict, List
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    DeepSeek = auto()
    PLAIN = auto()
    ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    Only the ``DeepSeek``, ``LLAMA2``, ``PLAIN`` and ``ALIGNMENT`` separator
    styles are rendered by :meth:`get_prompt`; any other style raises
    ``ValueError`` there.
    """

    # The name of this template
    name: str
    # The template of the system prompt; formatted with ``system_message=``
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of the two roles as (user role, assistant role). The default
    # used to be a 1-tuple wrapping the pair, which made ``roles[1]`` fail.
    roles: List[str] = ("USER", "ASSISTANT")
    # All messages. Each item is [role, message]. Normalised to a fresh list of
    # lists in __post_init__ so that append_message() works even when a tuple
    # (the default, and what the registered templates pass) is supplied.
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token); the registered templates
    # pass a list of strings here.
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def __post_init__(self):
        # A tuple has no .append(); store a mutable, unshared copy instead.
        self.messages = [list(m) for m in self.messages]

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        Raises:
            ValueError: if ``sep_style`` is not DeepSeek, LLAMA2, PLAIN or
                ALIGNMENT.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)

        if self.sep_style == SeparatorStyle.DeepSeek:
            # User turns end with ``sep`` and assistant turns with ``sep2``; an
            # empty message leaves "Role:" open for the model to continue.
            seps = [self.sep, self.sep2]
            ret = "" if not system_prompt else system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            # A non-empty system message uses the full template, which already
            # opens with "[INST]"; otherwise open the instruction manually.
            ret = system_prompt if self.system_message else "[INST] "
            for i, (role, message) in enumerate(self.messages):
                tag = self.roles[i % 2]
                if message:
                    if type(message) is tuple:  # multimodal message
                        message, _ = message
                    if i == 0:
                        ret += message + " "
                    else:
                        ret += tag + " " + message + seps[i % 2]
                else:
                    ret += tag
            return ret
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        # User turns are rendered as a bare image placeholder.
                        ret += "<image>\n" + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def get_prompt_for_current_round(self, content=None):
        """Get current round formatted question prompt during sft training.

        ``content`` is required for the DeepSeek style, where it is stripped and
        wrapped as "<user role>: <content><sep><assistant role>:". The PLAIN
        style ignores it.

        Raises:
            ValueError: for an unsupported ``sep_style``, or when ``content`` is
                None for the DeepSeek style.
        """
        if self.sep_style == SeparatorStyle.PLAIN:
            formatted_question = "<image>\n"
        elif self.sep_style == SeparatorStyle.DeepSeek:
            if content is None:
                raise ValueError("content is required for the DeepSeek sep_style")
            formatted_question = (
                f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
            )
        else:
            raise ValueError(f"Unsupported sep_style: {self.sep_style}")
        return formatted_question

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def reset_message(self):
        """Reset a new message."""
        self.messages = []

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return a copy that can be mutated without affecting this instance.

        The message list and the stop-criteria lists are copied, so a copy
        handed out by the template registry cannot corrupt the registered
        template.
        """
        stop_str = self.stop_str
        if isinstance(stop_str, list):
            stop_str = list(stop_str)
        stop_token_ids = (
            None if self.stop_token_ids is None else list(self.stop_token_ids)
        )
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=stop_str,
            stop_token_ids=stop_token_ids,
        )

    def dict(self):
        """Return a plain-dict summary of this conversation."""
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# Global registry: template name -> Conversation template.
conv_templates: Dict[str, Conversation] = dict()


def register_conv_template(template: Conversation, override: bool = False):
    """Store ``template`` in the registry under ``template.name``.

    Unless ``override`` is true, re-registering an existing name fails the
    assertion below (note: assertions are skipped under ``python -O``).
    """
    if not override:
        already_registered = template.name in conv_templates
        assert not already_registered, f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Return a copy of the registered template ``name`` (KeyError if unknown)."""
    registered = conv_templates[name]
    return registered.copy()
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# llava_llama2 template: Llama-2 chat layout with a default vision-assistant
# system prompt.
register_conv_template(
    Conversation(
        name="llava_llama2",
        roles=("[INST]", "[/INST]"),
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
        system_message=(
            "You are a helpful language and vision assistant. "
            "You are able to understand the visual content that the user provides, "
            "and assist the user with a variety of tasks using natural language."
        ),
        messages=(),
        offset=0,
        stop_token_ids=[2],
    )
)

# llama2 template (no default system message)
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        roles=("[INST]", "[/INST]"),
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
        messages=(),
        offset=0,
        stop_token_ids=[2],
    )
)

# deepseek template: "User: ...\n\nAssistant: ...<|end▁of▁sentence|>" turns.
# The system message is intentionally left empty.
register_conv_template(
    Conversation(
        name="deepseek",
        roles=("User", "Assistant"),
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        system_template="{system_message}",
        system_message="",
        messages=(),
        offset=0,
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    )
)

# plain template: concatenates message text with empty separators.
register_conv_template(
    Conversation(
        name="plain",
        roles=("", ""),
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        system_template="",
        system_message="",
        messages=(),
        offset=0,
        stop_token_ids=[2],
        stop_str=["</s>"],
    )
)

# alignment template: user turns become an "<image>" placeholder.
register_conv_template(
    Conversation(
        name="alignment",
        roles=("", ""),
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        system_template="",
        system_message="",
        messages=(),
        offset=0,
        stop_token_ids=[2],
        stop_str=["</s>"],
    )
)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
if __name__ == "__main__":
    # Demo: render a short multi-turn exchange with the "deepseek" template.
    print("deepseek template:")
    conv = get_conv_template("deepseek")
    demo_turns = [
        "Hello!",
        "Hi! This is Tony.",
        "Who are you?",
        "I am a helpful assistant.",
        "How are you?",
        None,  # empty assistant slot, so the prompt ends with "Assistant:"
    ]
    for turn_idx, text in enumerate(demo_turns):
        conv.append_message(conv.roles[turn_idx % 2], text)
    print(conv.get_prompt())
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Copyright (c) 2023-2024 DeepSeek.
|
|
2
|
+
#
|
|
3
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
4
|
+
# this software and associated documentation files (the "Software"), to deal in
|
|
5
|
+
# the Software without restriction, including without limitation the rights to
|
|
6
|
+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
7
|
+
# the Software, and to permit persons to whom the Software is furnished to do so,
|
|
8
|
+
# subject to the following conditions:
|
|
9
|
+
#
|
|
10
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
# copies or substantial portions of the Software.
|
|
12
|
+
#
|
|
13
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
15
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
16
|
+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
17
|
+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
18
|
+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
19
|
+
|
|
20
|
+
import json
|
|
21
|
+
from typing import Dict, List
|
|
22
|
+
|
|
23
|
+
import PIL.Image
|
|
24
|
+
import torch
|
|
25
|
+
from transformers import AutoModelForCausalLM
|
|
26
|
+
|
|
27
|
+
from ..models import MultiModalityCausalLM, VLChatProcessor
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def load_pretrained_model(model_path: str):
    """Load the DeepSeek-VL processor, tokenizer and model from ``model_path``.

    The model is cast to bfloat16, moved with ``.cuda()`` and switched to eval
    mode.

    Returns:
        A ``(tokenizer, vl_chat_processor, vl_gpt)`` tuple.
    """
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True
    )
    model = model.to(torch.bfloat16).cuda().eval()
    return processor.tokenizer, processor, model
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
    """Load every image referenced in ``conversations`` as an RGB PIL image.

    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images, in message
        order and, within a message, in the order of its "images" list.
        Messages without an "images" key are skipped.
    """
    pil_images = []

    for message in conversations:
        for image_path in message.get("images", ()):
            # Open in a context manager so the underlying file is closed
            # deterministically; convert() returns a new, fully loaded image
            # that stays valid after the file is closed.
            with PIL.Image.open(image_path) as pil_img:
                pil_images.append(pil_img.convert("RGB"))

    return pil_images
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def load_json(filepath):
    """Read and return the JSON document stored at ``filepath``.

    The file is decoded as UTF-8 (the encoding JSON text is required to use)
    instead of the platform's locale-dependent default.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
File without changes
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import io
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
|
8
|
+
from PIL import Image
|
|
9
|
+
from transformers import AutoModel, AutoTokenizer
|
|
10
|
+
|
|
11
|
+
from .model.omnilmm import OmniLMMForCausalLM
|
|
12
|
+
from .model.utils import build_transform
|
|
13
|
+
from .train.train_utils import omni_preprocess
|
|
14
|
+
from .utils import disable_torch_init
|
|
15
|
+
|
|
16
|
+
# Placeholder in user text that marks where the image goes.
DEFAULT_IMAGE_TOKEN = "<image>"
# Special tokens used to expand that placeholder: the repeated patch token and
# the start / end markers surrounding the patch run.
(
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
) = ("<im_patch>", "<im_start>", "<im_end>")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def init_omni_lmm(model_path, device_map):
    """Load the OmniLMM-12B model, its tokenizer and its image transform.

    Args:
        model_path: Checkpoint path; ``~`` is expanded.
        device_map: Forwarded to ``OmniLMMForCausalLM.from_pretrained``.

    Returns:
        ``(model, image_processor, image_token_len, tokenizer)`` where
        ``image_token_len`` is ``model.model.config.num_query``, the number of
        patch tokens substituted for one image.

    Raises:
        ValueError: if the checkpoint config does not enable
            ``mm_use_im_start_end``.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    disable_torch_init()
    model_name = os.path.expanduser(model_path)
    print(f"Load omni_lmm model and tokenizer from {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=2048)

    # NOTE: upstream also carried an unreachable ``if False:`` path that sharded
    # the model across several small GPUs with accelerate's init_empty_weights /
    # load_checkpoint_and_dispatch. It could never run and has been dropped.
    model = OmniLMMForCausalLM.from_pretrained(
        model_name,
        tune_clip=True,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    ).to(dtype=torch.bfloat16)

    image_processor = build_transform(
        is_train=False, input_size=model.model.config.image_size, std_mode="OPENAI_CLIP"
    )

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    if not mm_use_im_start_end:
        # Explicit check instead of ``assert`` so it is not stripped under -O:
        # prompts are always wrapped in <im_start>/<im_end> markers.
        raise ValueError(
            f"OmniLMM checkpoint at {model_name} must set mm_use_im_start_end=True"
        )

    tokenizer.add_tokens(
        [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN],
        special_tokens=True,
    )

    # Record the ids of the image special tokens on the vision config.
    vision_config = model.model.vision_config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IMAGE_PATCH_TOKEN]
    )[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    (
        vision_config.im_start_token,
        vision_config.im_end_token,
    ) = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
    image_token_len = model.model.config.num_query

    return model, image_processor, image_token_len, tokenizer
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def expand_question_into_multimodal(
    question_text, image_token_len, im_st_token, im_ed_token, im_patch_token
):
    """Splice the image token span into the first message of ``question_text``.

    The span is ``im_st_token + im_patch_token * image_token_len + im_ed_token``.
    Every "<image>" placeholder in the first message's "content" is replaced by
    the span; when there is no placeholder, the span and a newline are
    prepended instead. The list is modified in place and also returned.
    """
    first_message = question_text[0]
    text = first_message["content"]
    image_span = im_st_token + im_patch_token * image_token_len + im_ed_token
    if "<image>" not in text:
        first_message["content"] = image_span + "\n" + text
    else:
        first_message["content"] = text.replace("<image>", image_span)
    return question_text
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
    """Expand the image placeholder in ``question`` and tokenize it for generation.

    Returns:
        A dict with "input_ids" and "labels", each taken from the first entry
        produced by ``omni_preprocess`` for the single conversation.
    """
    expanded = expand_question_into_multimodal(
        question,
        image_token_len,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_IMAGE_PATCH_TOKEN,
    )
    processed = omni_preprocess(
        sources=[expanded], tokenizer=tokenizer, generation=True
    )
    return {
        "input_ids": processed["input_ids"][0],
        "labels": processed["labels"][0],
    }
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class OmniLMM12B:
    """Chat wrapper around the OmniLMM-12B checkpoint loaded by ``init_omni_lmm``."""

    def __init__(self, model_path, device_map) -> None:
        model, transform, token_len, tokenizer = init_omni_lmm(model_path, device_map)
        self.model = model
        self.image_token_len = token_len
        self.image_transform = transform
        self.tokenizer = tokenizer
        self.model.eval()

    def decode(self, image, input_ids):
        """Sample a response for one transformed image and one prompt.

        Both tensors get a leading batch dimension and are moved with
        ``.cuda()``; the image is cast to float16 first. Returns the decoded
        text with special tokens skipped and surrounding whitespace stripped.
        """
        sampling_options = dict(
            temperature=0.6,
            max_new_tokens=1024,
            do_sample=True,
            output_scores=True,
            return_dict_in_generate=True,
            repetition_penalty=1.1,
            top_k=30,
            top_p=0.9,
        )
        with torch.inference_mode():
            batched_ids = input_ids.unsqueeze(0).cuda()
            batched_images = image.unsqueeze(0).half().cuda()
            generated = self.model.generate_vllm(
                input_ids=batched_ids, images=batched_images, **sampling_options
            )
            text = self.tokenizer.decode(
                generated.sequences[0], skip_special_tokens=True
            )
        return text.strip()

    def chat(self, input):
        """Answer ``input["question"]`` about the base64 image ``input["image"]``.

        ``input["question"]`` is a JSON-encoded message list. Any failure while
        reading or decoding the image is reported as an error string rather
        than raised.
        """
        try:
            raw_bytes = base64.b64decode(input["image"])
            image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:
            return f"Image decode error: {e}"

        msgs = json.loads(input["question"])
        wrapped = wrap_question_for_omni_lmm(msgs, self.image_token_len, self.tokenizer)
        input_ids = torch.as_tensor(wrapped["input_ids"])
        image = self.image_transform(image)
        return self.decode(image, input_ids)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def img2base64(file_name):
    """Return the base64 encoding (as ``bytes``) of the file at ``file_name``."""
    with open(file_name, "rb") as image_file:
        raw = image_file.read()
    return base64.b64encode(raw)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class OmniLMM3B:
    """Chat wrapper around a non-12B OmniLMM checkpoint via its remote-code ``chat``."""

    def __init__(self, model_path, device_map) -> None:
        loaded = AutoModel.from_pretrained(
            model_path, trust_remote_code=True, device_map=device_map
        )
        self.model = loaded.to(dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model.eval().cuda()

    def chat(self, input):
        """Answer ``input["question"]`` about the base64 image ``input["image"]``.

        ``input["question"]`` is a JSON-encoded message list passed as ``msgs``
        to the model's ``chat``. Any failure while reading or decoding the
        image is reported as an error string rather than raised.
        """
        try:
            raw_bytes = base64.b64decode(input["image"])
            image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        except Exception as e:
            return f"Image decode error: {e}"

        msgs = json.loads(input["question"])
        answer, _context, _ = self.model.chat(
            image=image,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7,
        )
        return answer
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class OmniLMMChat:
    """Select the OmniLMM wrapper by checkpoint path and forward chat calls to it."""

    def __init__(self, model_path, device_map) -> None:
        # Paths containing "12B" use the 12B wrapper; anything else uses the 3B one.
        wrapper_cls = OmniLMM12B if "12B" in model_path else OmniLMM3B
        self.model = wrapper_cls(model_path, device_map)

    def chat(self, input):
        """Delegate to the selected wrapper's ``chat``."""
        return self.model.chat(input)
|