xinference-0.9.3-py3-none-any.whl → xinference-0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (64)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +47 -18
  3. xinference/api/oauth2/types.py +1 -0
  4. xinference/api/restful_api.py +16 -11
  5. xinference/client/restful/restful_client.py +12 -2
  6. xinference/conftest.py +13 -2
  7. xinference/constants.py +2 -0
  8. xinference/core/supervisor.py +32 -1
  9. xinference/core/worker.py +139 -20
  10. xinference/deploy/cmdline.py +119 -20
  11. xinference/model/llm/__init__.py +6 -0
  12. xinference/model/llm/llm_family.json +711 -10
  13. xinference/model/llm/llm_family_modelscope.json +557 -7
  14. xinference/model/llm/pytorch/chatglm.py +2 -1
  15. xinference/model/llm/pytorch/core.py +2 -0
  16. xinference/model/llm/pytorch/deepseek_vl.py +232 -0
  17. xinference/model/llm/pytorch/internlm2.py +2 -1
  18. xinference/model/llm/pytorch/omnilmm.py +153 -0
  19. xinference/model/llm/sglang/__init__.py +13 -0
  20. xinference/model/llm/sglang/core.py +365 -0
  21. xinference/model/llm/utils.py +46 -13
  22. xinference/model/llm/vllm/core.py +10 -0
  23. xinference/thirdparty/deepseek_vl/__init__.py +31 -0
  24. xinference/thirdparty/deepseek_vl/models/__init__.py +28 -0
  25. xinference/thirdparty/deepseek_vl/models/clip_encoder.py +242 -0
  26. xinference/thirdparty/deepseek_vl/models/image_processing_vlm.py +208 -0
  27. xinference/thirdparty/deepseek_vl/models/modeling_vlm.py +170 -0
  28. xinference/thirdparty/deepseek_vl/models/processing_vlm.py +390 -0
  29. xinference/thirdparty/deepseek_vl/models/projector.py +100 -0
  30. xinference/thirdparty/deepseek_vl/models/sam.py +593 -0
  31. xinference/thirdparty/deepseek_vl/models/siglip_vit.py +681 -0
  32. xinference/thirdparty/deepseek_vl/utils/__init__.py +18 -0
  33. xinference/thirdparty/deepseek_vl/utils/conversation.py +348 -0
  34. xinference/thirdparty/deepseek_vl/utils/io.py +78 -0
  35. xinference/thirdparty/omnilmm/__init__.py +0 -0
  36. xinference/thirdparty/omnilmm/chat.py +216 -0
  37. xinference/thirdparty/omnilmm/constants.py +4 -0
  38. xinference/thirdparty/omnilmm/conversation.py +332 -0
  39. xinference/thirdparty/omnilmm/model/__init__.py +1 -0
  40. xinference/thirdparty/omnilmm/model/omnilmm.py +594 -0
  41. xinference/thirdparty/omnilmm/model/resampler.py +166 -0
  42. xinference/thirdparty/omnilmm/model/utils.py +563 -0
  43. xinference/thirdparty/omnilmm/train/__init__.py +13 -0
  44. xinference/thirdparty/omnilmm/train/train_utils.py +150 -0
  45. xinference/thirdparty/omnilmm/utils.py +134 -0
  46. xinference/web/ui/build/asset-manifest.json +3 -3
  47. xinference/web/ui/build/index.html +1 -1
  48. xinference/web/ui/build/static/js/main.98516614.js +3 -0
  49. xinference/web/ui/build/static/js/main.98516614.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/139969fd25258eb7decc9505f30b779089bba50c402bb5c663008477c7bff73b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/3f357ab57b8e7fade54c667f0e0ebf2787566f72bfdca0fea14e395b5c203753.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/9d7c49815d97539207e5aab2fb967591b5fed7791218a0762539efc9491f36af.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/d0d0b591d9adaf42b83ad6633f8b7c118541a4b80ea957c303d3bf9b86fbad0a.json +1 -0
  54. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/METADATA +21 -5
  55. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/RECORD +60 -31
  56. xinference/web/ui/build/static/js/main.66b1c4fb.js +0 -3
  57. xinference/web/ui/build/static/js/main.66b1c4fb.js.map +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/c2124cfe036b26befcbd386d1d17743b1a58d0b7a041a17bb67f9924400d63c3.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +0 -1
  60. /xinference/web/ui/build/static/js/{main.66b1c4fb.js.LICENSE.txt → main.98516614.js.LICENSE.txt} +0 -0
  61. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/LICENSE +0 -0
  62. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/WHEEL +0 -0
  63. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/entry_points.txt +0 -0
  64. {xinference-0.9.3.dist-info → xinference-0.10.0.dist-info}/top_level.txt +0 -0
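
Among the changes above, 0.10.0 adds two vendored multimodal chat models (DeepSeek-VL under xinference/thirdparty/deepseek_vl/ and OmniLMM under xinference/thirdparty/omnilmm/) plus a new SGLang backend (xinference/model/llm/sglang/core.py). The sketch below shows how one of the newly added vision-language models might be launched through the existing RESTful client; the endpoint, the model name "deepseek-vl-chat", and the size are illustrative assumptions, not values taken from this diff.

# Hedged sketch only: assumes a running xinference 0.10.0 server at
# http://localhost:9997 and that the new DeepSeek-VL model is registered
# under the (assumed) name "deepseek-vl-chat" in llm_family.json.
from xinference.client import Client

client = Client("http://localhost:9997")

# Launch one of the newly added vision-language models (name/size assumed).
model_uid = client.launch_model(
    model_name="deepseek-vl-chat",
    model_format="pytorch",
    model_size_in_billions=7,
)

# Plain text chat shown for brevity; image input goes through the same chat API.
model = client.get_model(model_uid)
response = model.chat(
    prompt="Introduce yourself briefly.",
    generate_config={"max_tokens": 256},
)
print(response)
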
xinference/thirdparty/omnilmm/model/omnilmm.py (new file, +594 lines)
@@ -0,0 +1,594 @@
+ import gc
+ import math
+ from typing import List, Optional, Tuple, Union
+
+ import timm
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.nn import CrossEntropyLoss
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     MistralConfig,
+     MistralForCausalLM,
+     MistralModel,
+ )
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+ )
+
+ from ..model.resampler import Resampler
+ from ..model.utils import build_transform
+
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+
+
+ class OmniLMMConfig(MistralConfig):
+     model_type = "omnilmm"
+
+
+ class Identity(torch.nn.Identity):
+     def forward(self, input: Tensor, **kwargs) -> Tensor:
+         return super().forward(input)
+
+
+ def create_vision_module(config):
+     vision_tower = timm.create_model(
+         "eva02_enormous_patch14_clip_224.laion2b_plus",
+         pretrained=False,
+         num_classes=0,
+         dynamic_img_size=True,
+         dynamic_img_pad=True,
+     )
+
+     if isinstance(vision_tower, timm.models.VisionTransformer):
+         if vision_tower.attn_pool is not None:
+             vision_tower.attn_pool = Identity()
+
+     # use 2nd last layer's output
+     vision_tower.blocks[-1] = Identity()
+
+     embed_dim = config.hidden_size
+     resampler = Resampler(
+         grid_size=int(math.sqrt(config.num_query)),
+         embed_dim=embed_dim,
+         num_heads=embed_dim // 128,
+         kv_dim=vision_tower.embed_dim,
+     )
+     return vision_tower, resampler
+
+
+ class OmniLMMModel(MistralModel):
+     config_class = OmniLMMConfig
+
+     def __init__(
+         self,
+         config: OmniLMMConfig,
+         mm_vision_tower=None,
+         mm_hidden_size=None,
+         tune_clip=True,
+     ):
+         super(OmniLMMModel, self).__init__(config)
+
+         if hasattr(config, "mm_vision_tower"):
+             vision_tower, resampler = create_vision_module(config)
+
+             # print(__file__, 'skip loading vision tower weights')
+
+             # HACK: for FSDP
+             self.vision_tower = [vision_tower]
+             self.resampler = resampler
+             if tune_clip:
+                 self.vision_tower = self.vision_tower[0]
+
+         self.vision_config = lambda x: None
+
+     def initialize_vision_modules(
+         self, vision_tower, no_randaug, num_query, image_size, tune_clip=False
+     ):
+         self.config.mm_vision_tower = vision_tower
+         self.config.use_mm_proj = True
+         self.config.num_query = num_query
+         self.config.image_size = image_size
+
+         if not hasattr(self, "vision_tower"):
+             vision_tower, resampler = create_vision_module(self.config)
+             state_dict = torch.load(
+                 "/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt"
+             )
+             vision_tower.load_state_dict(state_dict, strict=False)
+             del state_dict
+             gc.collect()
+         else:
+             if isinstance(self.vision_tower, list):
+                 vision_tower = self.vision_tower[0]
+             else:
+                 vision_tower = self.vision_tower
+             resampler = self.resampler
+         self.vision_tower = vision_tower if tune_clip else [vision_tower]
+         self.resampler = resampler
+
+         train_img_transform = build_transform(
+             is_train=True,
+             randaug=not no_randaug,
+             input_size=self.config.image_size,
+             std_mode="OPENAI_CLIP",
+         )
+         eval_img_transform = build_transform(
+             is_train=False, input_size=self.config.image_size, std_mode="OPENAI_CLIP"
+         )
+
+         return dict(
+             image_processor=(train_img_transform, eval_img_transform),
+             image_token_len=num_query,
+             vision_config=self.vision_config,
+         )
+
+     def get_vision_embedding(self, pixel_values):
+         if isinstance(self.vision_tower, list):
+             vision_tower = self.vision_tower[0]  # HACK: for FSDP
+         else:
+             vision_tower = self.vision_tower
+
+         dtype = vision_tower.pos_embed.data.dtype
+         vision_embedding = vision_tower.forward_features(pixel_values.type(dtype))
+         if (
+             hasattr(vision_tower, "num_prefix_tokens")
+             and vision_tower.num_prefix_tokens > 0
+         ):
+             vision_embedding = vision_embedding[:, vision_tower.num_prefix_tokens :]
+         res = self.resampler(vision_embedding)
+         return res
+
+     def get_vllm_embedding(self, data):
+         if "vision_hidden_states" not in data:
+             pixel_values_list = data["pixel_values"]
+             vision_hidden_states = []
+             for pixel_values in pixel_values_list:
+                 if len(pixel_values) > 0:
+                     vision_hidden_states.append(
+                         self.get_vision_embedding(pixel_values.unsqueeze(0))[0]
+                     )
+                 else:
+                     vision_hidden_states.append([])
+         else:
+             vision_hidden_states = data["vision_hidden_states"]
+
+         # vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+         inputs_embeds = self.embed_tokens(data["input_ids"])
+         vision_hidden_states = [
+             i.type(inputs_embeds.dtype) if isinstance(i, torch.Tensor) else i
+             for i in vision_hidden_states
+         ]
+
+         # HACK: replace back original embeddings for LLaVA pretraining
+         orig_embeds_params = getattr(self, "orig_embeds_params", None)
+
+         new_input_embeds = []
+         cur_image_idx = 0
+         for cur_input_ids, cur_input_embeds in zip(data["input_ids"], inputs_embeds):
+             if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
+                 # multimodal LLM, but the current sample is not multimodal
+                 cur_input_embeds = cur_input_embeds + (0.0 * dummy_image_features).sum()
+                 new_input_embeds.append(cur_input_embeds)
+                 continue
+
+             if self.vision_config.use_im_start_end:
+                 cur_image_features = vision_hidden_states[cur_image_idx]
+                 num_patches = cur_image_features.shape[0]
+                 if (cur_input_ids == self.vision_config.im_start_token).sum() != (
+                     cur_input_ids == self.vision_config.im_end_token
+                 ).sum():
+                     raise ValueError(
+                         "The number of image start tokens and image end tokens should be the same."
+                     )
+                 image_start_tokens = torch.where(
+                     cur_input_ids == self.vision_config.im_start_token
+                 )[0]
+                 for image_start_token_pos in image_start_tokens:
+                     cur_image_features = vision_hidden_states[cur_image_idx].to(
+                         device=cur_input_embeds.device
+                     )
+                     num_patches = cur_image_features.shape[0]
+                     if (
+                         cur_input_ids[image_start_token_pos + num_patches + 1]
+                         != self.vision_config.im_end_token
+                     ):
+                         raise ValueError(
+                             "The image end token should follow the image start token."
+                         )
+                     if orig_embeds_params is not None:
+                         cur_new_input_embeds = torch.cat(
+                             (
+                                 cur_input_embeds[:image_start_token_pos].detach(),
+                                 cur_input_embeds[
+                                     image_start_token_pos : image_start_token_pos + 1
+                                 ],
+                                 cur_image_features,
+                                 cur_input_embeds[
+                                     image_start_token_pos
+                                     + num_patches
+                                     + 1 : image_start_token_pos
+                                     + num_patches
+                                     + 2
+                                 ],
+                                 cur_input_embeds[
+                                     image_start_token_pos + num_patches + 2 :
+                                 ].detach(),
+                             ),
+                             dim=0,
+                         )
+                     else:
+                         cur_new_input_embeds = torch.cat(
+                             (
+                                 cur_input_embeds[: image_start_token_pos + 1],
+                                 cur_image_features,
+                                 cur_input_embeds[
+                                     image_start_token_pos + num_patches + 1 :
+                                 ],
+                             ),
+                             dim=0,
+                         )
+                     cur_image_idx += 1
+                 new_input_embeds.append(cur_new_input_embeds)
+             else:
+                 raise NotImplementedError
+         inputs_embeds = torch.stack(new_input_embeds, dim=0)
+
+         return inputs_embeds, vision_hidden_states
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         # HACK: replace back original embeddings for LLaVA pretraining
+         orig_embeds_params = getattr(self, "orig_embeds_params", None)
+
+         if inputs_embeds is None and past_key_values is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         vision_tower = getattr(self, "vision_tower", None)
+         if (
+             vision_tower is not None
+             and (input_ids.shape[1] != 1 or self.training)
+             and images is not None
+         ):
+             if type(images) is list:
+                 image_features = []
+                 for image in images:
+                     image_forward_out = self.get_vision_embedding(
+                         image.unsqueeze(0)
+                     )[0]
+                     image_features.append(image_forward_out)
+             else:
+                 image_features = self.get_vision_embedding(images)
+
+             dummy_image_features = torch.zeros(
+                 self.config.num_query,
+                 self.config.hidden_size,
+                 device=inputs_embeds.device,
+                 dtype=inputs_embeds.dtype,
+             )
+
+             new_input_embeds = []
+             cur_image_idx = 0
+             for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
+                 if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
+                     # multimodal LLM, but the current sample is not multimodal
+                     cur_input_embeds = (
+                         cur_input_embeds + (0.0 * dummy_image_features).sum()
+                     )
+                     new_input_embeds.append(cur_input_embeds)
+                     continue
+
+                 if self.vision_config.use_im_start_end:
+                     cur_image_features = image_features[cur_image_idx]
+                     num_patches = cur_image_features.shape[0]
+                     if (
+                         cur_input_ids == self.vision_config.im_start_token
+                     ).sum() != (
+                         cur_input_ids == self.vision_config.im_end_token
+                     ).sum():
+                         raise ValueError(
+                             "The number of image start tokens and image end tokens should be the same."
+                         )
+                     image_start_tokens = torch.where(
+                         cur_input_ids == self.vision_config.im_start_token
+                     )[0]
+                     for image_start_token_pos in image_start_tokens:
+                         cur_image_features = image_features[cur_image_idx].to(
+                             device=cur_input_embeds.device
+                         )
+                         num_patches = cur_image_features.shape[0]
+                         if (
+                             cur_input_ids[image_start_token_pos + num_patches + 1]
+                             != self.vision_config.im_end_token
+                         ):
+                             raise ValueError(
+                                 "The image end token should follow the image start token."
+                             )
+                         if orig_embeds_params is not None:
+                             cur_new_input_embeds = torch.cat(
+                                 (
+                                     cur_input_embeds[
+                                         :image_start_token_pos
+                                     ].detach(),
+                                     cur_input_embeds[
+                                         image_start_token_pos : image_start_token_pos
+                                         + 1
+                                     ],
+                                     cur_image_features,
+                                     cur_input_embeds[
+                                         image_start_token_pos
+                                         + num_patches
+                                         + 1 : image_start_token_pos
+                                         + num_patches
+                                         + 2
+                                     ],
+                                     cur_input_embeds[
+                                         image_start_token_pos + num_patches + 2 :
+                                     ].detach(),
+                                 ),
+                                 dim=0,
+                             )
+                         else:
+                             cur_new_input_embeds = torch.cat(
+                                 (
+                                     cur_input_embeds[: image_start_token_pos + 1],
+                                     cur_image_features,
+                                     cur_input_embeds[
+                                         image_start_token_pos + num_patches + 1 :
+                                     ],
+                                 ),
+                                 dim=0,
+                             )
+                         cur_image_idx += 1
+                     new_input_embeds.append(cur_new_input_embeds)
+                 else:
+                     raise NotImplementedError
+             inputs_embeds = torch.stack(new_input_embeds, dim=0)
+             input_ids = None
+
+         return super(OmniLMMModel, self).forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             **kwargs,
+         )
+
+
+ class OmniLMMForCausalLM(MistralForCausalLM):
+     config_class = OmniLMMConfig
+
+     def __init__(self, config, mm_vision_tower=None, tune_clip=True):
+         super(MistralForCausalLM, self).__init__(config)
+         self.model = OmniLMMModel(
+             config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip
+         )
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
+         # print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
+         # print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             images=images,
+             **kwargs,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model/pipeline parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     # TODO could be removed for generate_vllm()
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         **kwargs,
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "images": kwargs.get("images", None),
+             }
+         )
+         return model_inputs
+
+     def generate_vllm(
+         self,
+         input_ids: torch.LongTensor = None,
+         images: Optional[torch.FloatTensor] = None,
+         vision_hidden_states=None,
+         return_vision_hidden_states=False,
+         **kwargs,
+     ):
+         model_inputs = {"input_ids": input_ids}
+         if vision_hidden_states is None:
+             model_inputs["pixel_values"] = images
+         else:
+             model_inputs["vision_hidden_states"] = vision_hidden_states
+
+         with torch.inference_mode():
+             inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(
+                 model_inputs
+             )
+
+             result = self.generate(inputs_embeds=inputs_embeds, **kwargs)
+
+             if return_vision_hidden_states:
+                 return result, vision_hidden_states
+
+             return result
+
+     def initialize_vision_tokenizer(
+         self, mm_use_im_start_end, tokenizer, device, tune_mm_mlp_adapter=False
+     ):
+         self.model.vision_config.use_im_start_end = mm_use_im_start_end
+         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+         self.resize_token_embeddings(len(tokenizer))
+
+         if mm_use_im_start_end:
+             num_new_tokens = tokenizer.add_tokens(
+                 [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
+             )
+             self.resize_token_embeddings(len(tokenizer))
+             (
+                 self.model.vision_config.im_start_token,
+                 self.model.vision_config.im_end_token,
+             ) = tokenizer.convert_tokens_to_ids(
+                 [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
+             )
+
+             if num_new_tokens > 0:
+                 input_embeddings = self.get_input_embeddings().weight.data
+                 output_embeddings = self.get_output_embeddings().weight.data
+
+                 input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True
+                 )
+                 output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True
+                 )
+
+                 input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                 output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+             # for new sft data
+             num_new_tokens = tokenizer.add_tokens(
+                 ["<box>", "</box>", "<ref>", "</ref>", "<quad>", "</quad>"],
+                 special_tokens=True,
+             )
+             self.resize_token_embeddings(len(tokenizer))
+
+             if num_new_tokens > 0:
+                 input_embeddings = self.get_input_embeddings().weight.data
+                 output_embeddings = self.get_output_embeddings().weight.data
+
+                 input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True
+                 )
+                 output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True
+                 )
+
+                 input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                 output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+             if tune_mm_mlp_adapter:
+                 self.model.orig_embeds_params = [
+                     self.get_input_embeddings().weight.data.clone().to(device=device)
+                 ]
+                 for p in self.get_input_embeddings().parameters():
+                     p.requires_grad = True
+                 for p in self.get_output_embeddings().parameters():
+                     p.requires_grad = False
+
+         self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
+             [DEFAULT_IMAGE_PATCH_TOKEN]
+         )[0]
+         print(
+             f"Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}",
+             flush=True,
+         )
+         # exit()
+
+
+ AutoConfig.register("omnilmm", OmniLMMConfig)
+ AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
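
The two register calls that close the file hook the vendored architecture into transformers' Auto classes: once this module is imported, a checkpoint whose config declares model_type "omnilmm" resolves to OmniLMMConfig and OmniLMMForCausalLM. A minimal loading sketch follows; the checkpoint path is hypothetical and the dtype choice is illustrative.

# Hedged sketch: load an OmniLMM checkpoint through the Auto classes.
# Importing the vendored module below is what triggers the register() calls above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import xinference.thirdparty.omnilmm.model.omnilmm  # noqa: F401  (registers "omnilmm")

ckpt_dir = "/path/to/omnilmm-checkpoint"  # hypothetical local checkpoint directory

tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(
    ckpt_dir,                    # config.json must declare model_type "omnilmm"
    torch_dtype=torch.bfloat16,  # illustrative; pick a dtype your hardware supports
).eval()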