xinference-1.6.0-py3-none-any.whl → xinference-1.6.1-py3-none-any.whl

This diff shows the changes between package versions that have been publicly released to a supported registry, as they appear in those registries, and is provided for informational purposes only. A sketch for reproducing the file-level comparison locally follows the file list below.

Potentially problematic release.
Files changed (87)
  1. xinference/_version.py +3 -3
  2. xinference/client/restful/restful_client.py +1 -1
  3. xinference/conftest.py +0 -7
  4. xinference/core/media_interface.py +9 -8
  5. xinference/core/model.py +13 -6
  6. xinference/core/scheduler.py +1 -10
  7. xinference/core/worker.py +0 -10
  8. xinference/model/audio/model_spec.json +53 -1
  9. xinference/model/audio/model_spec_modelscope.json +57 -1
  10. xinference/model/embedding/core.py +19 -11
  11. xinference/model/image/model_spec.json +10 -1
  12. xinference/model/image/model_spec_modelscope.json +20 -0
  13. xinference/model/llm/__init__.py +6 -54
  14. xinference/model/llm/core.py +19 -5
  15. xinference/model/llm/llama_cpp/core.py +59 -3
  16. xinference/model/llm/llama_cpp/memory.py +455 -0
  17. xinference/model/llm/llm_family.json +185 -397
  18. xinference/model/llm/llm_family.py +88 -16
  19. xinference/model/llm/llm_family_modelscope.json +199 -421
  20. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  21. xinference/model/llm/sglang/core.py +4 -0
  22. xinference/model/llm/transformers/__init__.py +27 -6
  23. xinference/model/llm/transformers/chatglm.py +4 -2
  24. xinference/model/llm/transformers/core.py +49 -28
  25. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  26. xinference/model/llm/transformers/gemma3.py +119 -164
  27. xinference/{thirdparty/omnilmm/train → model/llm/transformers/multimodal}/__init__.py +1 -1
  28. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  29. xinference/model/llm/transformers/multimodal/core.py +205 -0
  30. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  31. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  32. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  33. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  34. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  35. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  36. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  37. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  38. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  39. xinference/model/llm/transformers/opt.py +4 -2
  40. xinference/model/llm/transformers/utils.py +6 -37
  41. xinference/model/llm/vllm/core.py +4 -0
  42. xinference/model/rerank/core.py +7 -1
  43. xinference/model/rerank/utils.py +17 -0
  44. xinference/web/ui/build/asset-manifest.json +3 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/js/main.ddf9eaee.js +3 -0
  47. xinference/web/ui/build/static/js/main.ddf9eaee.js.map +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +1 -0
  52. xinference/web/ui/src/locales/en.json +3 -1
  53. xinference/web/ui/src/locales/zh.json +3 -1
  54. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/METADATA +16 -14
  55. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/RECORD +60 -76
  56. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/WHEEL +1 -1
  57. xinference/model/llm/transformers/cogvlm2.py +0 -442
  58. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  59. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  60. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  61. xinference/model/llm/transformers/intern_vl.py +0 -526
  62. xinference/model/llm/transformers/internlm2.py +0 -94
  63. xinference/model/llm/transformers/minicpmv25.py +0 -193
  64. xinference/model/llm/transformers/omnilmm.py +0 -132
  65. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  66. xinference/model/llm/transformers/qwen_vl.py +0 -360
  67. xinference/thirdparty/omnilmm/LICENSE +0 -201
  68. xinference/thirdparty/omnilmm/__init__.py +0 -0
  69. xinference/thirdparty/omnilmm/chat.py +0 -218
  70. xinference/thirdparty/omnilmm/constants.py +0 -4
  71. xinference/thirdparty/omnilmm/conversation.py +0 -332
  72. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  73. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  74. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  75. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  76. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  77. xinference/thirdparty/omnilmm/utils.py +0 -134
  78. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  79. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  84. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.ddf9eaee.js.LICENSE.txt} +0 -0
  85. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/entry_points.txt +0 -0
  86. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/licenses/LICENSE +0 -0
  87. {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/top_level.txt +0 -0
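
For orientation, here is a minimal sketch (not part of either release) of how a file-level comparison like the listing above can be reproduced locally with only the Python standard library. It assumes both wheels have already been downloaded into the working directory under their standard wheel filenames, and it only reports added, removed, and modified archive members, without the per-file line counts or rename detection shown above:

import zipfile

OLD_WHEEL = "xinference-1.6.0-py3-none-any.whl"
NEW_WHEEL = "xinference-1.6.1-py3-none-any.whl"

def member_crcs(path):
    # Map every file inside the wheel to its CRC-32 so unchanged files can be skipped cheaply.
    with zipfile.ZipFile(path) as zf:
        return {info.filename: info.CRC for info in zf.infolist() if not info.is_dir()}

old, new = member_crcs(OLD_WHEEL), member_crcs(NEW_WHEEL)
added = sorted(set(new) - set(old))
removed = sorted(set(old) - set(new))
modified = sorted(name for name in set(old) & set(new) if old[name] != new[name])

print(f"added: {len(added)}, removed: {len(removed)}, modified: {len(modified)}")
for name in added + removed + modified:
    status = "added" if name in added else "removed" if name in removed else "modified"
    print(f"{status}: {name}")
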
xinference/thirdparty/omnilmm/model/omnilmm.py (deleted)
@@ -1,595 +0,0 @@
- import gc
- import math
- from typing import List, Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
- from torch import Tensor
- from torch.nn import CrossEntropyLoss
- from transformers import (
-     AutoConfig,
-     AutoModelForCausalLM,
-     MistralConfig,
-     MistralForCausalLM,
-     MistralModel,
- )
- from transformers.modeling_outputs import (
-     BaseModelOutputWithPast,
-     CausalLMOutputWithPast,
- )
-
- from ..model.resampler import Resampler
- from ..model.utils import build_transform
-
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
- DEFAULT_IM_START_TOKEN = "<im_start>"
- DEFAULT_IM_END_TOKEN = "<im_end>"
-
-
- class OmniLMMConfig(MistralConfig):
-     model_type = "omnilmm"
-
-
- class Identity(torch.nn.Identity):
-     def forward(self, input: Tensor, **kwargs) -> Tensor:
-         return super().forward(input)
-
-
- def create_vision_module(config):
-     import timm
-
-     vision_tower = timm.create_model(
-         "eva02_enormous_patch14_clip_224.laion2b_plus",
-         pretrained=False,
-         num_classes=0,
-         dynamic_img_size=True,
-         dynamic_img_pad=True,
-     )
-
-     if isinstance(vision_tower, timm.models.VisionTransformer):
-         if vision_tower.attn_pool is not None:
-             vision_tower.attn_pool = Identity()
-
-     # use 2nd last layer's output
-     vision_tower.blocks[-1] = Identity()
-
-     embed_dim = config.hidden_size
-     resampler = Resampler(
-         grid_size=int(math.sqrt(config.num_query)),
-         embed_dim=embed_dim,
-         num_heads=embed_dim // 128,
-         kv_dim=vision_tower.embed_dim,
-     )
-     return vision_tower, resampler
-
-
- class OmniLMMModel(MistralModel):
-     config_class = OmniLMMConfig
-
-     def __init__(
-         self,
-         config: OmniLMMConfig,
-         mm_vision_tower=None,
-         mm_hidden_size=None,
-         tune_clip=True,
-     ):
-         super(OmniLMMModel, self).__init__(config)
-
-         if hasattr(config, "mm_vision_tower"):
-             vision_tower, resampler = create_vision_module(config)
-
-             # print(__file__, 'skip loading vision tower weights')
-
-             # HACK: for FSDP
-             self.vision_tower = [vision_tower]
-             self.resampler = resampler
-             if tune_clip:
-                 self.vision_tower = self.vision_tower[0]
-
-         self.vision_config = lambda x: None
-
-     def initialize_vision_modules(
-         self, vision_tower, no_randaug, num_query, image_size, tune_clip=False
-     ):
-         self.config.mm_vision_tower = vision_tower
-         self.config.use_mm_proj = True
-         self.config.num_query = num_query
-         self.config.image_size = image_size
-
-         if not hasattr(self, "vision_tower"):
-             vision_tower, resampler = create_vision_module(self.config)
-             state_dict = torch.load(
-                 "/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt"
-             )
-             vision_tower.load_state_dict(state_dict, strict=False)
-             del state_dict
-             gc.collect()
-         else:
-             if isinstance(self.vision_tower, list):
-                 vision_tower = self.vision_tower[0]
-             else:
-                 vision_tower = self.vision_tower
-             resampler = self.resampler
-         self.vision_tower = vision_tower if tune_clip else [vision_tower]
-         self.resampler = resampler
-
-         train_img_transform = build_transform(
-             is_train=True,
-             randaug=not no_randaug,
-             input_size=self.config.image_size,
-             std_mode="OPENAI_CLIP",
-         )
-         eval_img_transform = build_transform(
-             is_train=False, input_size=self.config.image_size, std_mode="OPENAI_CLIP"
-         )
-
-         return dict(
-             image_processor=(train_img_transform, eval_img_transform),
-             image_token_len=num_query,
-             vision_config=self.vision_config,
-         )
-
-     def get_vision_embedding(self, pixel_values):
-         if isinstance(self.vision_tower, list):
-             vision_tower = self.vision_tower[0]  # HACK: for FSDP
-         else:
-             vision_tower = self.vision_tower
-
-         dtype = vision_tower.pos_embed.data.dtype
-         vision_embedding = vision_tower.forward_features(pixel_values.type(dtype))
-         if (
-             hasattr(vision_tower, "num_prefix_tokens")
-             and vision_tower.num_prefix_tokens > 0
-         ):
-             vision_embedding = vision_embedding[:, vision_tower.num_prefix_tokens :]
-         res = self.resampler(vision_embedding)
-         return res
-
-     def get_vllm_embedding(self, data):
-         if "vision_hidden_states" not in data:
-             pixel_values_list = data["pixel_values"]
-             vision_hidden_states = []
-             for pixel_values in pixel_values_list:
-                 if len(pixel_values) > 0:
-                     vision_hidden_states.append(
-                         self.get_vision_embedding(pixel_values.unsqueeze(0))[0]
-                     )
-                 else:
-                     vision_hidden_states.append([])
-         else:
-             vision_hidden_states = data["vision_hidden_states"]
-
-         # vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
-         inputs_embeds = self.embed_tokens(data["input_ids"])
-         vision_hidden_states = [
-             i.type(inputs_embeds.dtype) if isinstance(i, torch.Tensor) else i
-             for i in vision_hidden_states
-         ]
-
-         # HACK: replace back original embeddings for LLaVA pretraining
-         orig_embeds_params = getattr(self, "orig_embeds_params", None)
-
-         new_input_embeds = []
-         cur_image_idx = 0
-         for cur_input_ids, cur_input_embeds in zip(data["input_ids"], inputs_embeds):
-             if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
-                 # multimodal LLM, but the current sample is not multimodal
-                 cur_input_embeds = cur_input_embeds + (0.0 * dummy_image_features).sum()
-                 new_input_embeds.append(cur_input_embeds)
-                 continue
-
-             if self.vision_config.use_im_start_end:
-                 cur_image_features = vision_hidden_states[cur_image_idx]
-                 num_patches = cur_image_features.shape[0]
-                 if (cur_input_ids == self.vision_config.im_start_token).sum() != (
-                     cur_input_ids == self.vision_config.im_end_token
-                 ).sum():
-                     raise ValueError(
-                         "The number of image start tokens and image end tokens should be the same."
-                     )
-                 image_start_tokens = torch.where(
-                     cur_input_ids == self.vision_config.im_start_token
-                 )[0]
-                 for image_start_token_pos in image_start_tokens:
-                     cur_image_features = vision_hidden_states[cur_image_idx].to(
-                         device=cur_input_embeds.device
-                     )
-                     num_patches = cur_image_features.shape[0]
-                     if (
-                         cur_input_ids[image_start_token_pos + num_patches + 1]
-                         != self.vision_config.im_end_token
-                     ):
-                         raise ValueError(
-                             "The image end token should follow the image start token."
-                         )
-                     if orig_embeds_params is not None:
-                         cur_new_input_embeds = torch.cat(
-                             (
-                                 cur_input_embeds[:image_start_token_pos].detach(),
-                                 cur_input_embeds[
-                                     image_start_token_pos : image_start_token_pos + 1
-                                 ],
-                                 cur_image_features,
-                                 cur_input_embeds[
-                                     image_start_token_pos
-                                     + num_patches
-                                     + 1 : image_start_token_pos
-                                     + num_patches
-                                     + 2
-                                 ],
-                                 cur_input_embeds[
-                                     image_start_token_pos + num_patches + 2 :
-                                 ].detach(),
-                             ),
-                             dim=0,
-                         )
-                     else:
-                         cur_new_input_embeds = torch.cat(
-                             (
-                                 cur_input_embeds[: image_start_token_pos + 1],
-                                 cur_image_features,
-                                 cur_input_embeds[
-                                     image_start_token_pos + num_patches + 1 :
-                                 ],
-                             ),
-                             dim=0,
-                         )
-                     cur_image_idx += 1
-                 new_input_embeds.append(cur_new_input_embeds)
-             else:
-                 raise NotImplementedError
-         inputs_embeds = torch.stack(new_input_embeds, dim=0)
-
-         return inputs_embeds, vision_hidden_states
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         images: Optional[torch.FloatTensor] = None,
-         return_dict: Optional[bool] = None,
-         **kwargs,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         # HACK: replace back original embeddings for LLaVA pretraining
-         orig_embeds_params = getattr(self, "orig_embeds_params", None)
-
-         if inputs_embeds is None and past_key_values is None:
-             inputs_embeds = self.embed_tokens(input_ids)
-
-         vision_tower = getattr(self, "vision_tower", None)
-         if (
-             vision_tower is not None
-             and (input_ids.shape[1] != 1 or self.training)
-             and images is not None
-         ):
-             if type(images) is list:
-                 image_features = []
-                 for image in images:
-                     image_forward_out = self.get_vision_embedding(
-                         image.unsqueeze(0)
-                     )[0]
-                     image_features.append(image_forward_out)
-             else:
-                 image_features = self.get_vision_embedding(images)
-
-             dummy_image_features = torch.zeros(
-                 self.config.num_query,
-                 self.config.hidden_size,
-                 device=inputs_embeds.device,
-                 dtype=inputs_embeds.dtype,
-             )
-
-             new_input_embeds = []
-             cur_image_idx = 0
-             for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
-                 if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
-                     # multimodal LLM, but the current sample is not multimodal
-                     cur_input_embeds = (
-                         cur_input_embeds + (0.0 * dummy_image_features).sum()
-                     )
-                     new_input_embeds.append(cur_input_embeds)
-                     continue
-
-                 if self.vision_config.use_im_start_end:
-                     cur_image_features = image_features[cur_image_idx]
-                     num_patches = cur_image_features.shape[0]
-                     if (
-                         cur_input_ids == self.vision_config.im_start_token
-                     ).sum() != (
-                         cur_input_ids == self.vision_config.im_end_token
-                     ).sum():
-                         raise ValueError(
-                             "The number of image start tokens and image end tokens should be the same."
-                         )
-                     image_start_tokens = torch.where(
-                         cur_input_ids == self.vision_config.im_start_token
-                     )[0]
-                     for image_start_token_pos in image_start_tokens:
-                         cur_image_features = image_features[cur_image_idx].to(
-                             device=cur_input_embeds.device
-                         )
-                         num_patches = cur_image_features.shape[0]
-                         if (
-                             cur_input_ids[image_start_token_pos + num_patches + 1]
-                             != self.vision_config.im_end_token
-                         ):
-                             raise ValueError(
-                                 "The image end token should follow the image start token."
-                             )
-                         if orig_embeds_params is not None:
-                             cur_new_input_embeds = torch.cat(
-                                 (
-                                     cur_input_embeds[
-                                         :image_start_token_pos
-                                     ].detach(),
-                                     cur_input_embeds[
-                                         image_start_token_pos : image_start_token_pos
-                                         + 1
-                                     ],
-                                     cur_image_features,
-                                     cur_input_embeds[
-                                         image_start_token_pos
-                                         + num_patches
-                                         + 1 : image_start_token_pos
-                                         + num_patches
-                                         + 2
-                                     ],
-                                     cur_input_embeds[
-                                         image_start_token_pos + num_patches + 2 :
-                                     ].detach(),
-                                 ),
-                                 dim=0,
-                             )
-                         else:
-                             cur_new_input_embeds = torch.cat(
-                                 (
-                                     cur_input_embeds[: image_start_token_pos + 1],
-                                     cur_image_features,
-                                     cur_input_embeds[
-                                         image_start_token_pos + num_patches + 1 :
-                                     ],
-                                 ),
-                                 dim=0,
-                             )
-                         cur_image_idx += 1
-                     new_input_embeds.append(cur_new_input_embeds)
-                 else:
-                     raise NotImplementedError
-             inputs_embeds = torch.stack(new_input_embeds, dim=0)
-             input_ids = None
-
-         return super(OmniLMMModel, self).forward(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-             **kwargs,
-         )
-
-
- class OmniLMMForCausalLM(MistralForCausalLM):
-     config_class = OmniLMMConfig
-
-     def __init__(self, config, mm_vision_tower=None, tune_clip=True):
-         super(MistralForCausalLM, self).__init__(config)
-         self.model = OmniLMMModel(
-             config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip
-         )
-
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         images: Optional[torch.FloatTensor] = None,
-         return_dict: Optional[bool] = None,
-         **kwargs,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         output_attentions = (
-             output_attentions
-             if output_attentions is not None
-             else self.config.output_attentions
-         )
-         output_hidden_states = (
-             output_hidden_states
-             if output_hidden_states is not None
-             else self.config.output_hidden_states
-         )
-         return_dict = (
-             return_dict if return_dict is not None else self.config.use_return_dict
-         )
-
-         # print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
-         # print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
-         # print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
-
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-             images=images,
-             **kwargs,
-         )
-
-         hidden_states = outputs[0]
-         logits = self.lm_head(hidden_states)
-
-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss()
-             shift_logits = shift_logits.view(-1, self.config.vocab_size)
-             shift_labels = shift_labels.view(-1)
-             # Enable model/pipeline parallelism
-             shift_labels = shift_labels.to(shift_logits.device)
-             loss = loss_fct(shift_logits, shift_labels)
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     # TODO could be removed for generate_vllm()
-     def prepare_inputs_for_generation(
-         self,
-         input_ids,
-         past_key_values=None,
-         attention_mask=None,
-         inputs_embeds=None,
-         **kwargs,
-     ):
-         if past_key_values:
-             input_ids = input_ids[:, -1:]
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         model_inputs.update(
-             {
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-                 "images": kwargs.get("images", None),
-             }
-         )
-         return model_inputs
-
-     def generate_vllm(
-         self,
-         input_ids: torch.LongTensor = None,
-         images: Optional[torch.FloatTensor] = None,
-         vision_hidden_states=None,
-         return_vision_hidden_states=False,
-         **kwargs,
-     ):
-         model_inputs = {"input_ids": input_ids}
-         if vision_hidden_states is None:
-             model_inputs["pixel_values"] = images
-         else:
-             model_inputs["vision_hidden_states"] = vision_hidden_states
-
-         with torch.inference_mode():
-             inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(
-                 model_inputs
-             )
-
-             result = self.generate(inputs_embeds=inputs_embeds, **kwargs)
-
-         if return_vision_hidden_states:
-             return result, vision_hidden_states
-
-         return result
-
-     def initialize_vision_tokenizer(
-         self, mm_use_im_start_end, tokenizer, device, tune_mm_mlp_adapter=False
-     ):
-         self.model.vision_config.use_im_start_end = mm_use_im_start_end
-         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
-         self.resize_token_embeddings(len(tokenizer))
-
-         if mm_use_im_start_end:
-             num_new_tokens = tokenizer.add_tokens(
-                 [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
-             )
-             self.resize_token_embeddings(len(tokenizer))
-             (
-                 self.model.vision_config.im_start_token,
-                 self.model.vision_config.im_end_token,
-             ) = tokenizer.convert_tokens_to_ids(
-                 [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
-             )
-
-             if num_new_tokens > 0:
-                 input_embeddings = self.get_input_embeddings().weight.data
-                 output_embeddings = self.get_output_embeddings().weight.data
-
-                 input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-                     dim=0, keepdim=True
-                 )
-                 output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                     dim=0, keepdim=True
-                 )
-
-                 input_embeddings[-num_new_tokens:] = input_embeddings_avg
-                 output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-         # for new sft data
-         num_new_tokens = tokenizer.add_tokens(
-             ["<box>", "</box>", "<ref>", "</ref>", "<quad>", "</quad>"],
-             special_tokens=True,
-         )
-         self.resize_token_embeddings(len(tokenizer))
-
-         if num_new_tokens > 0:
-             input_embeddings = self.get_input_embeddings().weight.data
-             output_embeddings = self.get_output_embeddings().weight.data
-
-             input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-                 dim=0, keepdim=True
-             )
-             output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                 dim=0, keepdim=True
-             )
-
-             input_embeddings[-num_new_tokens:] = input_embeddings_avg
-             output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-         if tune_mm_mlp_adapter:
-             self.model.orig_embeds_params = [
-                 self.get_input_embeddings().weight.data.clone().to(device=device)
-             ]
-             for p in self.get_input_embeddings().parameters():
-                 p.requires_grad = True
-             for p in self.get_output_embeddings().parameters():
-                 p.requires_grad = False
-
-         self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
-             [DEFAULT_IMAGE_PATCH_TOKEN]
-         )[0]
-         print(
-             f"Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}",
-             flush=True,
-         )
-         # exit()
-
-
- AutoConfig.register("omnilmm", OmniLMMConfig)
- AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)