xinference 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (59)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +47 -18
  3. xinference/api/oauth2/types.py +1 -0
  4. xinference/api/restful_api.py +9 -1
  5. xinference/client/restful/restful_client.py +12 -2
  6. xinference/conftest.py +13 -2
  7. xinference/core/supervisor.py +32 -1
  8. xinference/core/worker.py +139 -20
  9. xinference/deploy/cmdline.py +119 -20
  10. xinference/model/llm/__init__.py +4 -0
  11. xinference/model/llm/llm_family.json +627 -0
  12. xinference/model/llm/llm_family_modelscope.json +471 -0
  13. xinference/model/llm/pytorch/core.py +2 -0
  14. xinference/model/llm/pytorch/deepseek_vl.py +232 -0
  15. xinference/model/llm/pytorch/omnilmm.py +153 -0
  16. xinference/model/llm/utils.py +11 -1
  17. xinference/model/llm/vllm/core.py +3 -0
  18. xinference/thirdparty/deepseek_vl/__init__.py +31 -0
  19. xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
  20. xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
  21. xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
  22. xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
  23. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
  24. xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
  25. xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
  26. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
  27. xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
  28. xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
  29. xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
  30. xinference/thirdparty/omnilmm/__init__.py +0 -0
  31. xinference/thirdparty/omnilmm/chat.py +216 -0
  32. xinference/thirdparty/omnilmm/constants.py +4 -0
  33. xinference/thirdparty/omnilmm/conversation.py +332 -0
  34. xinference/thirdparty/omnilmm/model/__init__.py +1 -0
  35. xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
  36. xinference/thirdparty/omnilmm/model/resampler.py +166 -0
  37. xinference/thirdparty/omnilmm/model/utils.py +563 -0
  38. xinference/thirdparty/omnilmm/train/__init__.py +13 -0
  39. xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
  40. xinference/thirdparty/omnilmm/utils.py +134 -0
  41. xinference/web/ui/build/asset-manifest.json +3 -3
  42. xinference/web/ui/build/index.html +1 -1
  43. xinference/web/ui/build/static/js/main.98516614.js +3 -0
  44. xinference/web/ui/build/static/js/main.98516614.js.map +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/139969fd25258eb7decc9505f30b779089bba50c402bb5c663008477c7bff73b.json +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/3f357ab57b8e7fade54c667f0e0ebf2787566f72bfdca0fea14e395b5c203753.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/9d7c49815d97539207e5aab2fb967591b5fed7791218a0762539efc9491f36af.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/d0d0b591d9adaf42b83ad6633f8b7c118541a4b80ea957c303d3bf9b86fbad0a.json +1 -0
  49. {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/METADATA +18 -5
  50. {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/RECORD +55 -28
  51. xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
  52. xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
  53. xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
  54. xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
  55. /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.98516614.js.LICENSE.txt} +0 -0
  56. {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/LICENSE +0 -0
  57. {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/WHEEL +0 -0
  58. {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/entry_points.txt +0 -0
  59. {xinference-0.9.4.dist-info → xinference-0.10.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/deepseek_vl/utils/conversation.py
@@ -0,0 +1,348 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ """
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+ """
+
+ import dataclasses
+ from enum import IntEnum, auto
+ from typing import Dict, List
+
+
+ class SeparatorStyle(IntEnum):
+     """Separator styles."""
+
+     ADD_COLON_SINGLE = auto()
+     ADD_COLON_TWO = auto()
+     ADD_COLON_SPACE_SINGLE = auto()
+     NO_COLON_SINGLE = auto()
+     NO_COLON_TWO = auto()
+     ADD_NEW_LINE_SINGLE = auto()
+     LLAMA2 = auto()
+     CHATGLM = auto()
+     CHATML = auto()
+     CHATINTERN = auto()
+     DOLLY = auto()
+     RWKV = auto()
+     PHOENIX = auto()
+     ROBIN = auto()
+     DeepSeek = auto()
+     PLAIN = auto()
+     ALIGNMENT = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that manages prompt templates and keeps all conversation history."""
+
+     # The name of this template
+     name: str
+     # The template of the system prompt
+     system_template: str = "{system_message}"
+     # The system message
+     system_message: str = ""
+     # The names of two roles
+     roles: List[str] = (("USER", "ASSISTANT"),)
+     # All messages. Each item is (role, message).
+     messages: List[List[str]] = ()
+     # The number of few shot examples
+     offset: int = 0
+     # The separator style and configurations
+     sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+     sep: str = "\n"
+     sep2: str = None
+     # Stop criteria (the default one is EOS token)
+     stop_str: str = None
+     # Stops generation if meeting any token in this list
+     stop_token_ids: List[int] = None
+
+     def get_prompt(self) -> str:
+         """Get the prompt for generation."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+
+         if self.sep_style == SeparatorStyle.DeepSeek:
+             seps = [self.sep, self.sep2]
+             if system_prompt == "" or system_prompt is None:
+                 ret = ""
+             else:
+                 ret = system_prompt + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.LLAMA2:
+             seps = [self.sep, self.sep2]
+             if self.system_message:
+                 ret = system_prompt
+             else:
+                 ret = "[INST] "
+             for i, (role, message) in enumerate(self.messages):
+                 tag = self.roles[i % 2]
+                 if message:
+                     if type(message) is tuple:  # multimodal message
+                         message, _ = message
+                     if i == 0:
+                         ret += message + " "
+                     else:
+                         ret += tag + " " + message + seps[i % 2]
+                 else:
+                     ret += tag
+             return ret
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += message + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         elif self.sep_style == SeparatorStyle.ALIGNMENT:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += "<image>\n" + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def get_prompt_for_current_round(self, content=None):
+         """Get current round formatted question prompt during sft training"""
+         if self.sep_style == SeparatorStyle.PLAIN:
+             formatted_question = "<image>\n"
+         elif self.sep_style == SeparatorStyle.DeepSeek:
+             formatted_question = (
+                 f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
+             )
+         else:
+             raise ValueError(f"Unsupported sep_style: {self.sep_style}")
+         return formatted_question
+
+     def set_system_message(self, system_message: str):
+         """Set the system message."""
+         self.system_message = system_message
+
+     def append_message(self, role: str, message: str):
+         """Append a new message."""
+         self.messages.append([role, message])
+
+     def reset_message(self):
+         """Reset a new message."""
+         self.messages = []
+
+     def update_last_message(self, message: str):
+         """Update the last output.
+
+         The last message is typically set to be None when constructing the prompt,
+         so we need to update it in-place after getting the response from a model.
+         """
+         self.messages[-1][1] = message
+
+     def to_gradio_chatbot(self):
+         """Convert the conversation to gradio chatbot format."""
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def to_openai_api_messages(self):
+         """Convert the conversation to OpenAI chat completion format."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+         ret = [{"role": "system", "content": system_prompt}]
+
+         for i, (_, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append({"role": "user", "content": msg})
+             else:
+                 if msg is not None:
+                     ret.append({"role": "assistant", "content": msg})
+         return ret
+
+     def copy(self):
+         return Conversation(
+             name=self.name,
+             system_template=self.system_template,
+             system_message=self.system_message,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             stop_str=self.stop_str,
+             stop_token_ids=self.stop_token_ids,
+         )
+
+     def dict(self):
+         return {
+             "template_name": self.name,
+             "system_message": self.system_message,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+         }
+
+
+ # A global registry for all conversation templates
+ conv_templates: Dict[str, Conversation] = {}
+
+
+ def register_conv_template(template: Conversation, override: bool = False):
+     """Register a new conversation template."""
+     if not override:
+         assert (
+             template.name not in conv_templates
+         ), f"{template.name} has been registered."
+
+     conv_templates[template.name] = template
+
+
+ def get_conv_template(name: str) -> Conversation:
+     """Get a conversation template."""
+     return conv_templates[name].copy()
+
+
+ # llava_llama2 template
+ register_conv_template(
+     Conversation(
+         name="llava_llama2",
+         system_message="You are a helpful language and vision assistant. "
+         "You are able to understand the visual content that the user provides, "
+         "and assist the user with a variety of tasks using natural language.",
+         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+         roles=("[INST]", "[/INST]"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.LLAMA2,
+         sep=" ",
+         sep2=" </s><s>",
+         stop_token_ids=[2],
+     )
+ )
+
+ # llama2 template
+ # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+ register_conv_template(
+     Conversation(
+         name="llama-2",
+         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+         roles=("[INST]", "[/INST]"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.LLAMA2,
+         sep=" ",
+         sep2=" </s><s>",
+         stop_token_ids=[2],
+     )
+ )
+
+
+ # deepseek template
+ register_conv_template(
+     Conversation(
+         name="deepseek",
+         system_template="{system_message}",
+         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+         # "thinking step by step to be sure you get the right answer.",
+         system_message="",
+         roles=("User", "Assistant"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.DeepSeek,
+         sep="\n\n",
+         sep2="<|end▁of▁sentence|>",
+         stop_token_ids=[100001],
+         stop_str=["User:", "<|end▁of▁sentence|>"],
+     )
+ )
+
+ register_conv_template(
+     Conversation(
+         name="plain",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.PLAIN,
+         sep="",
+         sep2="",
+         stop_token_ids=[2],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ register_conv_template(
+     Conversation(
+         name="alignment",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.ALIGNMENT,
+         sep="",
+         sep2="",
+         stop_token_ids=[2],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ if __name__ == "__main__":
+     # print("Llama-2 template:")
+     # conv = get_conv_template("llama-2")
+     # conv.set_system_message("You are a helpful, respectful and honest assistant.")
+     # conv.append_message(conv.roles[0], "Hello!")
+     # conv.append_message(conv.roles[1], "Hi!")
+     # conv.append_message(conv.roles[0], "How are you?")
+     # conv.append_message(conv.roles[1], None)
+     # print(conv.get_prompt())
+
+     # print("\n")
+
+     print("deepseek template:")
+     conv = get_conv_template("deepseek")
+     conv.append_message(conv.roles[0], "Hello!")
+     conv.append_message(conv.roles[1], "Hi! This is Tony.")
+     conv.append_message(conv.roles[0], "Who are you?")
+     conv.append_message(conv.roles[1], "I am a helpful assistant.")
+     conv.append_message(conv.roles[0], "How are you?")
+     conv.append_message(conv.roles[1], None)
+     print(conv.get_prompt())
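
For orientation, a minimal usage sketch (assuming the vendored import path; the rendered string follows directly from get_prompt with the "deepseek" separators registered above):

from xinference.thirdparty.deepseek_vl.utils.conversation import get_conv_template

conv = get_conv_template("deepseek")      # returns a copy, so the registry stays untouched
conv.append_message(conv.roles[0], "Hi")  # appends ["User", "Hi"]
conv.append_message(conv.roles[1], None)  # a None message leaves a trailing "Assistant:"
print(conv.get_prompt())                  # -> "User: Hi\n\nAssistant:"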
xinference/thirdparty/deepseek_vl/utils/io.py
@@ -0,0 +1,78 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ import json
+ from typing import Dict, List
+
+ import PIL.Image
+ import torch
+ from transformers import AutoModelForCausalLM
+
+ from ..models import MultiModalityCausalLM, VLChatProcessor
+
+
+ def load_pretrained_model(model_path: str):
+     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+     tokenizer = vl_chat_processor.tokenizer
+
+     vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
+         model_path, trust_remote_code=True
+     )
+     vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+     return tokenizer, vl_chat_processor, vl_gpt
+
+
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
+     """
+
+     Args:
+         conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
+             [
+                 {
+                     "role": "User",
+                     "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
+                     "images": ["./examples/table_datasets.png"]
+                 },
+                 {"role": "Assistant", "content": ""},
+             ]
+
+     Returns:
+         pil_images (List[PIL.Image.Image]): the list of PIL images.
+
+     """
+
+     pil_images = []
+
+     for message in conversations:
+         if "images" not in message:
+             continue
+
+         for image_path in message["images"]:
+             pil_img = PIL.Image.open(image_path)
+             pil_img = pil_img.convert("RGB")
+             pil_images.append(pil_img)
+
+     return pil_images
+
+
+ def load_json(filepath):
+     with open(filepath, "r") as f:
+         data = json.load(f)
+     return data
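
A short usage sketch for load_pil_images above (the image path is a hypothetical local file; any message without an "images" key is simply skipped):

from xinference.thirdparty.deepseek_vl.utils.io import load_pil_images

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nDescribe this image.",
        "images": ["table.png"],  # hypothetical path, purely illustrative
    },
    {"role": "Assistant", "content": ""},  # no "images" key, so it is skipped
]
pil_images = load_pil_images(conversation)  # one RGB PIL.Image per listed path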
xinference/thirdparty/omnilmm/__init__.py (file without changes)
xinference/thirdparty/omnilmm/chat.py
@@ -0,0 +1,216 @@
+ import base64
+ import io
+ import json
+ import os
+
+ import torch
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+ from PIL import Image
+ from transformers import AutoModel, AutoTokenizer
+
+ from .model.omnilmm import OmniLMMForCausalLM
+ from .model.utils import build_transform
+ from .train.train_utils import omni_preprocess
+ from .utils import disable_torch_init
+
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+
+
+ def init_omni_lmm(model_path, device_map):
+     torch.backends.cuda.matmul.allow_tf32 = True
+     disable_torch_init()
+     model_name = os.path.expanduser(model_path)
+     print(f"Load omni_lmm model and tokenizer from {model_name}")
+     tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=2048)
+
+     if False:
+         # model on multiple devices for small size gpu memory (Nvidia 3090 24G x2)
+         with init_empty_weights():
+             model = OmniLMMForCausalLM.from_pretrained(
+                 model_name, tune_clip=True, torch_dtype=torch.bfloat16
+             )
+         model = load_checkpoint_and_dispatch(
+             model,
+             model_name,
+             dtype=torch.bfloat16,
+             device_map="auto",
+             no_split_module_classes=[
+                 "Eva",
+                 "MistralDecoderLayer",
+                 "ModuleList",
+                 "Resampler",
+             ],
+         )
+     else:
+         model = OmniLMMForCausalLM.from_pretrained(
+             model_name,
+             tune_clip=True,
+             torch_dtype=torch.bfloat16,
+             device_map=device_map,
+         ).to(dtype=torch.bfloat16)
+
+     image_processor = build_transform(
+         is_train=False, input_size=model.model.config.image_size, std_mode="OPENAI_CLIP"
+     )
+
+     mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+     assert mm_use_im_start_end
+
+     tokenizer.add_tokens(
+         [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN],
+         special_tokens=True,
+     )
+
+     vision_config = model.model.vision_config
+     vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
+         [DEFAULT_IMAGE_PATCH_TOKEN]
+     )[0]
+     vision_config.use_im_start_end = mm_use_im_start_end
+     (
+         vision_config.im_start_token,
+         vision_config.im_end_token,
+     ) = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
+     image_token_len = model.model.config.num_query
+
+     return model, image_processor, image_token_len, tokenizer
+
+
+ def expand_question_into_multimodal(
+     question_text, image_token_len, im_st_token, im_ed_token, im_patch_token
+ ):
+     if "<image>" in question_text[0]["content"]:
+         question_text[0]["content"] = question_text[0]["content"].replace(
+             "<image>", im_st_token + im_patch_token * image_token_len + im_ed_token
+         )
+     else:
+         question_text[0]["content"] = (
+             im_st_token
+             + im_patch_token * image_token_len
+             + im_ed_token
+             + "\n"
+             + question_text[0]["content"]
+         )
+     return question_text
+
+
+ def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
+     question = expand_question_into_multimodal(
+         question,
+         image_token_len,
+         DEFAULT_IM_START_TOKEN,
+         DEFAULT_IM_END_TOKEN,
+         DEFAULT_IMAGE_PATCH_TOKEN,
+     )
+
+     conversation = question
+     data_dict = omni_preprocess(
+         sources=[conversation], tokenizer=tokenizer, generation=True
+     )
+
+     data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+     return data_dict
+
+
+ class OmniLMM12B:
+     def __init__(self, model_path, device_map) -> None:
+         model, img_processor, image_token_len, tokenizer = init_omni_lmm(
+             model_path, device_map
+         )
+         self.model = model
+         self.image_token_len = image_token_len
+         self.image_transform = img_processor
+         self.tokenizer = tokenizer
+         self.model.eval()
+
+     def decode(self, image, input_ids):
+         with torch.inference_mode():
+             output = self.model.generate_vllm(
+                 input_ids=input_ids.unsqueeze(0).cuda(),
+                 images=image.unsqueeze(0).half().cuda(),
+                 temperature=0.6,
+                 max_new_tokens=1024,
+                 # num_beams=num_beams,
+                 do_sample=True,
+                 output_scores=True,
+                 return_dict_in_generate=True,
+                 repetition_penalty=1.1,
+                 top_k=30,
+                 top_p=0.9,
+             )
+
+             response = self.tokenizer.decode(
+                 output.sequences[0], skip_special_tokens=True
+             )
+             response = response.strip()
+             return response
+
+     def chat(self, input):
+         try:
+             image = Image.open(io.BytesIO(base64.b64decode(input["image"]))).convert(
+                 "RGB"
+             )
+         except Exception as e:
+             return f"Image decode error: {e}"
+
+         msgs = json.loads(input["question"])
+         input_ids = wrap_question_for_omni_lmm(
+             msgs, self.image_token_len, self.tokenizer
+         )["input_ids"]
+         input_ids = torch.as_tensor(input_ids)
+         # print('input_ids', input_ids)
+         image = self.image_transform(image)
+
+         out = self.decode(image, input_ids)
+
+         return out
+
+
+ def img2base64(file_name):
+     with open(file_name, "rb") as f:
+         encoded_string = base64.b64encode(f.read())
+         return encoded_string
+
+
+ class OmniLMM3B:
+     def __init__(self, model_path, device_map) -> None:
+         self.model = AutoModel.from_pretrained(
+             model_path, trust_remote_code=True, device_map=device_map
+         ).to(dtype=torch.bfloat16)
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_path, trust_remote_code=True
+         )
+         self.model.eval().cuda()
+
+     def chat(self, input):
+         try:
+             image = Image.open(io.BytesIO(base64.b64decode(input["image"]))).convert(
+                 "RGB"
+             )
+         except Exception as e:
+             return f"Image decode error: {e}"
+
+         msgs = json.loads(input["question"])
+
+         answer, context, _ = self.model.chat(
+             image=image,
+             msgs=msgs,
+             context=None,
+             tokenizer=self.tokenizer,
+             sampling=True,
+             temperature=0.7,
+         )
+         return answer
+
+
+ class OmniLMMChat:
+     def __init__(self, model_path, device_map) -> None:
+         if "12B" in model_path:
+             self.model = OmniLMM12B(model_path, device_map)
+         else:
+             self.model = OmniLMM3B(model_path, device_map)
+
+     def chat(self, input):
+         return self.model.chat(input)
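
A sketch of how this wrapper is driven end to end (the model path and image file are placeholders, not values from this diff; a path containing "12B" selects the OmniLMM12B wrapper, anything else falls back to OmniLMM3B):

import json

from xinference.thirdparty.omnilmm.chat import OmniLMMChat, img2base64

chat_model = OmniLMMChat("openbmb/OmniLMM-12B", device_map="cuda")  # placeholder path
msgs = [{"role": "user", "content": "What is in this image?"}]
answer = chat_model.chat(
    {
        "image": img2base64("example.jpg"),  # placeholder image, base64-encoded bytes
        "question": json.dumps(msgs),        # chat() expects the messages as JSON text
    }
)
print(answer)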
xinference/thirdparty/omnilmm/constants.py
@@ -0,0 +1,4 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."