xinference 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/client/restful/restful_client.py +1 -1
- xinference/conftest.py +0 -7
- xinference/core/media_interface.py +9 -8
- xinference/core/model.py +13 -6
- xinference/core/scheduler.py +1 -10
- xinference/core/worker.py +0 -10
- xinference/model/audio/model_spec.json +53 -1
- xinference/model/audio/model_spec_modelscope.json +57 -1
- xinference/model/embedding/core.py +19 -11
- xinference/model/image/model_spec.json +10 -1
- xinference/model/image/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +6 -54
- xinference/model/llm/core.py +19 -5
- xinference/model/llm/llama_cpp/core.py +59 -3
- xinference/model/llm/llama_cpp/memory.py +455 -0
- xinference/model/llm/llm_family.json +185 -397
- xinference/model/llm/llm_family.py +88 -16
- xinference/model/llm/llm_family_modelscope.json +199 -421
- xinference/model/llm/llm_family_openmind_hub.json +0 -34
- xinference/model/llm/sglang/core.py +4 -0
- xinference/model/llm/transformers/__init__.py +27 -6
- xinference/model/llm/transformers/chatglm.py +4 -2
- xinference/model/llm/transformers/core.py +49 -28
- xinference/model/llm/transformers/deepseek_v2.py +6 -49
- xinference/model/llm/transformers/gemma3.py +119 -164
- xinference/{thirdparty/omnilmm/train → model/llm/transformers/multimodal}/__init__.py +1 -1
- xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
- xinference/model/llm/transformers/multimodal/core.py +205 -0
- xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
- xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
- xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
- xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
- xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
- xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
- xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
- xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
- xinference/model/llm/transformers/opt.py +4 -2
- xinference/model/llm/transformers/utils.py +6 -37
- xinference/model/llm/vllm/core.py +4 -0
- xinference/model/rerank/core.py +7 -1
- xinference/model/rerank/utils.py +17 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.ddf9eaee.js +3 -0
- xinference/web/ui/build/static/js/main.ddf9eaee.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +1 -0
- xinference/web/ui/src/locales/en.json +3 -1
- xinference/web/ui/src/locales/zh.json +3 -1
- {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/METADATA +16 -14
- {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/RECORD +60 -76
- {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/WHEEL +1 -1
- xinference/model/llm/transformers/cogvlm2.py +0 -442
- xinference/model/llm/transformers/cogvlm2_video.py +0 -333
- xinference/model/llm/transformers/deepseek_vl.py +0 -280
- xinference/model/llm/transformers/glm_edge_v.py +0 -213
- xinference/model/llm/transformers/intern_vl.py +0 -526
- xinference/model/llm/transformers/internlm2.py +0 -94
- xinference/model/llm/transformers/minicpmv25.py +0 -193
- xinference/model/llm/transformers/omnilmm.py +0 -132
- xinference/model/llm/transformers/qwen2_audio.py +0 -179
- xinference/model/llm/transformers/qwen_vl.py +0 -360
- xinference/thirdparty/omnilmm/LICENSE +0 -201
- xinference/thirdparty/omnilmm/__init__.py +0 -0
- xinference/thirdparty/omnilmm/chat.py +0 -218
- xinference/thirdparty/omnilmm/constants.py +0 -4
- xinference/thirdparty/omnilmm/conversation.py +0 -332
- xinference/thirdparty/omnilmm/model/__init__.py +0 -1
- xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
- xinference/thirdparty/omnilmm/model/resampler.py +0 -166
- xinference/thirdparty/omnilmm/model/utils.py +0 -578
- xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
- xinference/thirdparty/omnilmm/utils.py +0 -134
- xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
- xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
- /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.ddf9eaee.js.LICENSE.txt} +0 -0
- {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.6.0.dist-info → xinference-1.6.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/omnilmm/model/omnilmm.py
@@ -1,595 +0,0 @@
-import gc
-import math
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from torch import Tensor
-from torch.nn import CrossEntropyLoss
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    MistralConfig,
-    MistralForCausalLM,
-    MistralModel,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-)
-
-from ..model.resampler import Resampler
-from ..model.utils import build_transform
-
-DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
-DEFAULT_IM_START_TOKEN = "<im_start>"
-DEFAULT_IM_END_TOKEN = "<im_end>"
-
-
-class OmniLMMConfig(MistralConfig):
-    model_type = "omnilmm"
-
-
-class Identity(torch.nn.Identity):
-    def forward(self, input: Tensor, **kwargs) -> Tensor:
-        return super().forward(input)
-
-
-def create_vision_module(config):
-    import timm
-
-    vision_tower = timm.create_model(
-        "eva02_enormous_patch14_clip_224.laion2b_plus",
-        pretrained=False,
-        num_classes=0,
-        dynamic_img_size=True,
-        dynamic_img_pad=True,
-    )
-
-    if isinstance(vision_tower, timm.models.VisionTransformer):
-        if vision_tower.attn_pool is not None:
-            vision_tower.attn_pool = Identity()
-
-    # use 2nd last layer's output
-    vision_tower.blocks[-1] = Identity()
-
-    embed_dim = config.hidden_size
-    resampler = Resampler(
-        grid_size=int(math.sqrt(config.num_query)),
-        embed_dim=embed_dim,
-        num_heads=embed_dim // 128,
-        kv_dim=vision_tower.embed_dim,
-    )
-    return vision_tower, resampler
-
-
-class OmniLMMModel(MistralModel):
-    config_class = OmniLMMConfig
-
-    def __init__(
-        self,
-        config: OmniLMMConfig,
-        mm_vision_tower=None,
-        mm_hidden_size=None,
-        tune_clip=True,
-    ):
-        super(OmniLMMModel, self).__init__(config)
-
-        if hasattr(config, "mm_vision_tower"):
-            vision_tower, resampler = create_vision_module(config)
-
-            # print(__file__, 'skip loading vision tower weights')
-
-            # HACK: for FSDP
-            self.vision_tower = [vision_tower]
-            self.resampler = resampler
-            if tune_clip:
-                self.vision_tower = self.vision_tower[0]
-
-        self.vision_config = lambda x: None
-
-    def initialize_vision_modules(
-        self, vision_tower, no_randaug, num_query, image_size, tune_clip=False
-    ):
-        self.config.mm_vision_tower = vision_tower
-        self.config.use_mm_proj = True
-        self.config.num_query = num_query
-        self.config.image_size = image_size
-
-        if not hasattr(self, "vision_tower"):
-            vision_tower, resampler = create_vision_module(self.config)
-            state_dict = torch.load(
-                "/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt"
-            )
-            vision_tower.load_state_dict(state_dict, strict=False)
-            del state_dict
-            gc.collect()
-        else:
-            if isinstance(self.vision_tower, list):
-                vision_tower = self.vision_tower[0]
-            else:
-                vision_tower = self.vision_tower
-            resampler = self.resampler
-        self.vision_tower = vision_tower if tune_clip else [vision_tower]
-        self.resampler = resampler
-
-        train_img_transform = build_transform(
-            is_train=True,
-            randaug=not no_randaug,
-            input_size=self.config.image_size,
-            std_mode="OPENAI_CLIP",
-        )
-        eval_img_transform = build_transform(
-            is_train=False, input_size=self.config.image_size, std_mode="OPENAI_CLIP"
-        )
-
-        return dict(
-            image_processor=(train_img_transform, eval_img_transform),
-            image_token_len=num_query,
-            vision_config=self.vision_config,
-        )
-
-    def get_vision_embedding(self, pixel_values):
-        if isinstance(self.vision_tower, list):
-            vision_tower = self.vision_tower[0]  # HACK: for FSDP
-        else:
-            vision_tower = self.vision_tower
-
-        dtype = vision_tower.pos_embed.data.dtype
-        vision_embedding = vision_tower.forward_features(pixel_values.type(dtype))
-        if (
-            hasattr(vision_tower, "num_prefix_tokens")
-            and vision_tower.num_prefix_tokens > 0
-        ):
-            vision_embedding = vision_embedding[:, vision_tower.num_prefix_tokens :]
-        res = self.resampler(vision_embedding)
-        return res
-
-    def get_vllm_embedding(self, data):
-        if "vision_hidden_states" not in data:
-            pixel_values_list = data["pixel_values"]
-            vision_hidden_states = []
-            for pixel_values in pixel_values_list:
-                if len(pixel_values) > 0:
-                    vision_hidden_states.append(
-                        self.get_vision_embedding(pixel_values.unsqueeze(0))[0]
-                    )
-                else:
-                    vision_hidden_states.append([])
-        else:
-            vision_hidden_states = data["vision_hidden_states"]
-
-        # vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
-        inputs_embeds = self.embed_tokens(data["input_ids"])
-        vision_hidden_states = [
-            i.type(inputs_embeds.dtype) if isinstance(i, torch.Tensor) else i
-            for i in vision_hidden_states
-        ]
-
-        # HACK: replace back original embeddings for LLaVA pretraining
-        orig_embeds_params = getattr(self, "orig_embeds_params", None)
-
-        new_input_embeds = []
-        cur_image_idx = 0
-        for cur_input_ids, cur_input_embeds in zip(data["input_ids"], inputs_embeds):
-            if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
-                # multimodal LLM, but the current sample is not multimodal
-                cur_input_embeds = cur_input_embeds + (0.0 * dummy_image_features).sum()
-                new_input_embeds.append(cur_input_embeds)
-                continue
-
-            if self.vision_config.use_im_start_end:
-                cur_image_features = vision_hidden_states[cur_image_idx]
-                num_patches = cur_image_features.shape[0]
-                if (cur_input_ids == self.vision_config.im_start_token).sum() != (
-                    cur_input_ids == self.vision_config.im_end_token
-                ).sum():
-                    raise ValueError(
-                        "The number of image start tokens and image end tokens should be the same."
-                    )
-                image_start_tokens = torch.where(
-                    cur_input_ids == self.vision_config.im_start_token
-                )[0]
-                for image_start_token_pos in image_start_tokens:
-                    cur_image_features = vision_hidden_states[cur_image_idx].to(
-                        device=cur_input_embeds.device
-                    )
-                    num_patches = cur_image_features.shape[0]
-                    if (
-                        cur_input_ids[image_start_token_pos + num_patches + 1]
-                        != self.vision_config.im_end_token
-                    ):
-                        raise ValueError(
-                            "The image end token should follow the image start token."
-                        )
-                    if orig_embeds_params is not None:
-                        cur_new_input_embeds = torch.cat(
-                            (
-                                cur_input_embeds[:image_start_token_pos].detach(),
-                                cur_input_embeds[
-                                    image_start_token_pos : image_start_token_pos + 1
-                                ],
-                                cur_image_features,
-                                cur_input_embeds[
-                                    image_start_token_pos
-                                    + num_patches
-                                    + 1 : image_start_token_pos
-                                    + num_patches
-                                    + 2
-                                ],
-                                cur_input_embeds[
-                                    image_start_token_pos + num_patches + 2 :
-                                ].detach(),
-                            ),
-                            dim=0,
-                        )
-                    else:
-                        cur_new_input_embeds = torch.cat(
-                            (
-                                cur_input_embeds[: image_start_token_pos + 1],
-                                cur_image_features,
-                                cur_input_embeds[
-                                    image_start_token_pos + num_patches + 1 :
-                                ],
-                            ),
-                            dim=0,
-                        )
-                    cur_image_idx += 1
-                new_input_embeds.append(cur_new_input_embeds)
-            else:
-                raise NotImplementedError
-        inputs_embeds = torch.stack(new_input_embeds, dim=0)
-
-        return inputs_embeds, vision_hidden_states
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        images: Optional[torch.FloatTensor] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        # HACK: replace back original embeddings for LLaVA pretraining
-        orig_embeds_params = getattr(self, "orig_embeds_params", None)
-
-        if inputs_embeds is None and past_key_values is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        vision_tower = getattr(self, "vision_tower", None)
-        if (
-            vision_tower is not None
-            and (input_ids.shape[1] != 1 or self.training)
-            and images is not None
-        ):
-            if type(images) is list:
-                image_features = []
-                for image in images:
-                    image_forward_out = self.get_vision_embedding(
-                        image.unsqueeze(0)
-                    )[0]
-                    image_features.append(image_forward_out)
-            else:
-                image_features = self.get_vision_embedding(images)
-
-            dummy_image_features = torch.zeros(
-                self.config.num_query,
-                self.config.hidden_size,
-                device=inputs_embeds.device,
-                dtype=inputs_embeds.dtype,
-            )
-
-            new_input_embeds = []
-            cur_image_idx = 0
-            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
-                if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
-                    # multimodal LLM, but the current sample is not multimodal
-                    cur_input_embeds = (
-                        cur_input_embeds + (0.0 * dummy_image_features).sum()
-                    )
-                    new_input_embeds.append(cur_input_embeds)
-                    continue
-
-                if self.vision_config.use_im_start_end:
-                    cur_image_features = image_features[cur_image_idx]
-                    num_patches = cur_image_features.shape[0]
-                    if (
-                        cur_input_ids == self.vision_config.im_start_token
-                    ).sum() != (
-                        cur_input_ids == self.vision_config.im_end_token
-                    ).sum():
-                        raise ValueError(
-                            "The number of image start tokens and image end tokens should be the same."
-                        )
-                    image_start_tokens = torch.where(
-                        cur_input_ids == self.vision_config.im_start_token
-                    )[0]
-                    for image_start_token_pos in image_start_tokens:
-                        cur_image_features = image_features[cur_image_idx].to(
-                            device=cur_input_embeds.device
-                        )
-                        num_patches = cur_image_features.shape[0]
-                        if (
-                            cur_input_ids[image_start_token_pos + num_patches + 1]
-                            != self.vision_config.im_end_token
-                        ):
-                            raise ValueError(
-                                "The image end token should follow the image start token."
-                            )
-                        if orig_embeds_params is not None:
-                            cur_new_input_embeds = torch.cat(
-                                (
-                                    cur_input_embeds[
-                                        :image_start_token_pos
-                                    ].detach(),
-                                    cur_input_embeds[
-                                        image_start_token_pos : image_start_token_pos
-                                        + 1
-                                    ],
-                                    cur_image_features,
-                                    cur_input_embeds[
-                                        image_start_token_pos
-                                        + num_patches
-                                        + 1 : image_start_token_pos
-                                        + num_patches
-                                        + 2
-                                    ],
-                                    cur_input_embeds[
-                                        image_start_token_pos + num_patches + 2 :
-                                    ].detach(),
-                                ),
-                                dim=0,
-                            )
-                        else:
-                            cur_new_input_embeds = torch.cat(
-                                (
-                                    cur_input_embeds[: image_start_token_pos + 1],
-                                    cur_image_features,
-                                    cur_input_embeds[
-                                        image_start_token_pos + num_patches + 1 :
-                                    ],
-                                ),
-                                dim=0,
-                            )
-                        cur_image_idx += 1
-                    new_input_embeds.append(cur_new_input_embeds)
-                else:
-                    raise NotImplementedError
-            inputs_embeds = torch.stack(new_input_embeds, dim=0)
-            input_ids = None
-
-        return super(OmniLMMModel, self).forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            **kwargs,
-        )
-
-
-class OmniLMMForCausalLM(MistralForCausalLM):
-    config_class = OmniLMMConfig
-
-    def __init__(self, config, mm_vision_tower=None, tune_clip=True):
-        super(MistralForCausalLM, self).__init__(config)
-        self.model = OmniLMMModel(
-            config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip
-        )
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        images: Optional[torch.FloatTensor] = None,
-        return_dict: Optional[bool] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
-        output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        # print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
-        # print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
-        # print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            images=images,
-            **kwargs,
-        )
-
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model/pipeline parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    # TODO could be removed for generate_vllm()
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        **kwargs,
-    ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "images": kwargs.get("images", None),
-            }
-        )
-        return model_inputs
-
-    def generate_vllm(
-        self,
-        input_ids: torch.LongTensor = None,
-        images: Optional[torch.FloatTensor] = None,
-        vision_hidden_states=None,
-        return_vision_hidden_states=False,
-        **kwargs,
-    ):
-        model_inputs = {"input_ids": input_ids}
-        if vision_hidden_states is None:
-            model_inputs["pixel_values"] = images
-        else:
-            model_inputs["vision_hidden_states"] = vision_hidden_states
-
-        with torch.inference_mode():
-            inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(
-                model_inputs
-            )
-
-            result = self.generate(inputs_embeds=inputs_embeds, **kwargs)
-
-            if return_vision_hidden_states:
-                return result, vision_hidden_states
-
-            return result
-
-    def initialize_vision_tokenizer(
-        self, mm_use_im_start_end, tokenizer, device, tune_mm_mlp_adapter=False
-    ):
-        self.model.vision_config.use_im_start_end = mm_use_im_start_end
-        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
-        self.resize_token_embeddings(len(tokenizer))
-
-        if mm_use_im_start_end:
-            num_new_tokens = tokenizer.add_tokens(
-                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
-            )
-            self.resize_token_embeddings(len(tokenizer))
-            (
-                self.model.vision_config.im_start_token,
-                self.model.vision_config.im_end_token,
-            ) = tokenizer.convert_tokens_to_ids(
-                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
-            )
-
-            if num_new_tokens > 0:
-                input_embeddings = self.get_input_embeddings().weight.data
-                output_embeddings = self.get_output_embeddings().weight.data
-
-                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True
-                )
-                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                    dim=0, keepdim=True
-                )
-
-                input_embeddings[-num_new_tokens:] = input_embeddings_avg
-                output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-        # for new sft data
-        num_new_tokens = tokenizer.add_tokens(
-            ["<box>", "</box>", "<ref>", "</ref>", "<quad>", "</quad>"],
-            special_tokens=True,
-        )
-        self.resize_token_embeddings(len(tokenizer))
-
-        if num_new_tokens > 0:
-            input_embeddings = self.get_input_embeddings().weight.data
-            output_embeddings = self.get_output_embeddings().weight.data
-
-            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
-                dim=0, keepdim=True
-            )
-            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
-                dim=0, keepdim=True
-            )
-
-            input_embeddings[-num_new_tokens:] = input_embeddings_avg
-            output_embeddings[-num_new_tokens:] = output_embeddings_avg
-
-        if tune_mm_mlp_adapter:
-            self.model.orig_embeds_params = [
-                self.get_input_embeddings().weight.data.clone().to(device=device)
-            ]
-            for p in self.get_input_embeddings().parameters():
-                p.requires_grad = True
-            for p in self.get_output_embeddings().parameters():
-                p.requires_grad = False
-
-        self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
-            [DEFAULT_IMAGE_PATCH_TOKEN]
-        )[0]
-        print(
-            f"Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}",
-            flush=True,
-        )
-        # exit()
-
-
-AutoConfig.register("omnilmm", OmniLMMConfig)
-AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)