xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/oauth2/auth_service.py +132 -0
- xinference/api/restful_api.py +282 -78
- xinference/client/handlers.py +3 -0
- xinference/client/restful/restful_client.py +108 -75
- xinference/constants.py +14 -4
- xinference/core/cache_tracker.py +102 -0
- xinference/core/chat_interface.py +10 -4
- xinference/core/event.py +56 -0
- xinference/core/model.py +44 -0
- xinference/core/resource.py +19 -12
- xinference/core/status_guard.py +4 -0
- xinference/core/supervisor.py +278 -87
- xinference/core/utils.py +68 -3
- xinference/core/worker.py +98 -8
- xinference/deploy/cmdline.py +6 -3
- xinference/deploy/local.py +2 -2
- xinference/deploy/supervisor.py +2 -2
- xinference/model/audio/__init__.py +27 -0
- xinference/model/audio/core.py +161 -0
- xinference/model/audio/model_spec.json +79 -0
- xinference/model/audio/utils.py +18 -0
- xinference/model/audio/whisper.py +132 -0
- xinference/model/core.py +18 -13
- xinference/model/embedding/__init__.py +27 -2
- xinference/model/embedding/core.py +43 -3
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/utils.py +18 -0
- xinference/model/image/__init__.py +12 -1
- xinference/model/image/core.py +63 -9
- xinference/model/image/utils.py +26 -0
- xinference/model/llm/__init__.py +20 -1
- xinference/model/llm/core.py +43 -2
- xinference/model/llm/ggml/chatglm.py +15 -6
- xinference/model/llm/llm_family.json +197 -6
- xinference/model/llm/llm_family.py +9 -7
- xinference/model/llm/llm_family_modelscope.json +189 -4
- xinference/model/llm/pytorch/chatglm.py +3 -3
- xinference/model/llm/pytorch/core.py +4 -2
- xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
- xinference/model/llm/pytorch/utils.py +21 -9
- xinference/model/llm/pytorch/yi_vl.py +246 -0
- xinference/model/llm/utils.py +57 -4
- xinference/model/llm/vllm/core.py +5 -4
- xinference/model/rerank/__init__.py +25 -2
- xinference/model/rerank/core.py +51 -9
- xinference/model/rerank/model_spec.json +6 -0
- xinference/model/rerank/model_spec_modelscope.json +7 -0
- xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
- xinference/model/utils.py +5 -3
- xinference/thirdparty/__init__.py +0 -0
- xinference/thirdparty/llava/__init__.py +1 -0
- xinference/thirdparty/llava/conversation.py +205 -0
- xinference/thirdparty/llava/mm_utils.py +122 -0
- xinference/thirdparty/llava/model/__init__.py +1 -0
- xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
- xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
- xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
- xinference/thirdparty/llava/model/constants.py +6 -0
- xinference/thirdparty/llava/model/llava_arch.py +385 -0
- xinference/thirdparty/llava/model/llava_llama.py +163 -0
- xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
- xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
- xinference/types.py +1 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
- xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
- xinference/api/oauth2/core.py +0 -93
- xinference/model/multimodal/__init__.py +0 -52
- xinference/model/multimodal/core.py +0 -467
- xinference/model/multimodal/model_spec.json +0 -43
- xinference/model/multimodal/model_spec_modelscope.json +0 -45
- xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
- xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
- /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/llava/model/llava_arch.py
ADDED
@@ -0,0 +1,385 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from abc import ABC, abstractmethod

import torch
from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, key_info

from .clip_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector


class LlavaMetaModel:
    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            config.mm_vision_tower = os.path.join(
                key_info["model_path"], config.mm_vision_tower.replace("./", "")
            )
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            self.vision_tower = vision_tower
        else:
            vision_tower = self.vision_tower
            if not vision_tower.is_loaded:
                vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(
            model_args, "mm_projector_type", "linear"
        )
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, "mm_projector", None) is None:
            self.mm_projector = build_vision_projector(self.config)

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(
                pretrain_mm_mlp_adapter, map_location="cpu"
            )

            def get_w(weights, keyword):
                return {
                    k.split(keyword + ".")[1]: v
                    for k, v in weights.items()
                    if keyword in k
                }

            self.mm_projector.load_state_dict(
                get_w(mm_projector_weights, "mm_projector")
            )


class LlavaMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if (
                past_key_values is not None
                and vision_tower is not None
                and images is not None
                and input_ids.shape[1] == 1
            ):
                new_mask = torch.ones(
                    (
                        attention_mask.shape[0],
                        past_key_values[-1][-1].shape[-2] + 1 - attention_mask.shape[1],
                    ),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([attention_mask, new_mask], dim=1)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                # cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(
                    cur_input_ids[:half_len]
                )
                cur_input_embeds_2 = self.get_model().embed_tokens(
                    cur_input_ids[half_len:]
                )
                cur_input_embeds = torch.cat(
                    [cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2],
                    dim=0,
                )
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_new_input_embeds.append(
                        self.get_model()
                        .embed_tokens(cur_input_ids[: image_token_start - 1])
                        .detach()
                    )
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(
                            cur_input_ids[image_token_start - 1 : image_token_start]
                        )
                    )
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(
                            cur_input_ids[image_token_start + 1 : image_token_start + 2]
                        )
                    )
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(
                            torch.full(
                                (cur_image_features.shape[0],),
                                IGNORE_INDEX,
                                device=labels.device,
                                dtype=labels.dtype,
                            )
                        )
                        cur_new_labels.append(
                            cur_labels[image_token_start : image_token_start + 1]
                        )
                        cur_labels = cur_labels[image_token_start + 2 :]
                else:
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(cur_input_ids[:image_token_start])
                    )
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(
                            torch.full(
                                (cur_image_features.shape[0],),
                                IGNORE_INDEX,
                                device=labels.device,
                                dtype=labels.dtype,
                            )
                        )
                        cur_labels = cur_labels[image_token_start + 1 :]
                cur_image_idx += 1
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_input_ids = cur_input_ids[image_token_start + 2 :]
                else:
                    cur_input_ids = cur_input_ids[image_token_start + 1 :]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
                    self.config, "mm_use_im_start_end", False
                ):
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(cur_input_ids).detach()
                    )
                else:
                    cur_new_input_embeds.append(
                        self.get_model().embed_tokens(cur_input_ids)
                    )
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [
                x.to(device=self.device) for x in cur_new_input_embeds
            ]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat(
                        (
                            cur_new_label,
                            torch.full(
                                (max_len - cur_new_label.shape[0],),
                                IGNORE_INDEX,
                                dtype=cur_new_label.dtype,
                                device=cur_new_label.device,
                            ),
                        ),
                        dim=0,
                    )
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(
                    attention_mask, _new_labels, new_labels
                ):
                    new_attn_mask_pad_left = torch.full(
                        (cur_new_labels.shape[0] - labels.shape[1],),
                        True,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    new_attn_mask_pad_right = torch.full(
                        (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
                        False,
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    cur_new_attention_mask = torch.cat(
                        (
                            new_attn_mask_pad_left,
                            cur_attention_mask,
                            new_attn_mask_pad_right,
                        ),
                        dim=0,
                    )
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_right = torch.full(
                    (
                        attention_mask.shape[0],
                        new_input_embeds.shape[1] - input_ids.shape[1],
                    ),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat(
                    (attention_mask, new_attn_mask_pad_right), dim=1
                )
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True
                )
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True
                )

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(
                    model_args.pretrain_mm_mlp_adapter, map_location="cpu"
                )
                embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[
                        -num_new_tokens:
                    ]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
                    )
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
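As an aside, the get_w helper above is what turns the adapter checkpoint's flat state dict into something load_state_dict accepts. A self-contained sketch of that filtering step, using made-up key names and shapes purely for illustration:

# Illustrative only: mimics the key filtering done when a pretrained
# mm_projector adapter is loaded in initialize_vision_modules above.
import torch

def get_w(weights, keyword):
    # keep entries containing `keyword`, dropping everything up to "<keyword>."
    return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}

ckpt = {  # hypothetical adapter state dict, not from the release
    "model.mm_projector.0.weight": torch.zeros(8, 4),
    "model.mm_projector.0.bias": torch.zeros(8),
    "model.embed_tokens.weight": torch.zeros(16, 4),
}
print(sorted(get_w(ckpt, "mm_projector")))  # ['0.bias', '0.weight']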
xinference/thirdparty/llava/model/llava_llama.py
ADDED
@@ -0,0 +1,163 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .llava_arch import LlavaMetaForCausalLM, LlavaMetaModel


class LlavaConfig(LlamaConfig):
    model_type = "llava"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        config._flash_attn_2_enabled = True  ######set flash attention2!!!!!!
        super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        (
            input_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels,
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images
        )

        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -1].unsqueeze(-1)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                # "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs
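The loss branch in forward() above is the usual shift-by-one causal-LM objective. A standalone sketch with random tensors (the shapes are arbitrary assumptions, not values from the release):

# Reproduces the shift-logits / shift-labels pattern from forward() above
# with dummy data, so the reshaping is easy to follow.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 6, 32
logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)  # predict token t+1
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())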
xinference/thirdparty/llava/model/multimodal_projector/__init__.py
File without changes
xinference/thirdparty/llava/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,64 @@
import re

import torch.nn as nn


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


def build_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    use_norm = False
    if "_Norm" in projector_type:
        use_norm = True
        projector_type = projector_type.replace("_Norm", "")
    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        if use_norm:
            modules = [
                nn.Linear(config.mm_hidden_size, config.hidden_size),
                nn.LayerNorm(config.hidden_size),
            ]
        else:
            modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            if use_norm:
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
                modules.append(nn.LayerNorm(config.hidden_size))
            else:
                modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {projector_type}")
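For orientation, a small sketch of what build_vision_projector above returns for an "mlp2x_gelu" configuration; the hidden sizes are illustrative assumptions, and the function is assumed to be imported from the builder module shown above:

# Hypothetical config values; real ones come from the model's own config.
from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096
)
projector = build_vision_projector(cfg)  # Linear(1024->4096), GELU, Linear(4096->4096)
patches = torch.randn(1, 576, cfg.mm_hidden_size)  # e.g. a 24x24 grid of vision-tower patches
print(projector(patches).shape)  # torch.Size([1, 576, 4096])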
xinference/types.py
CHANGED
@@ -329,7 +329,7 @@ class ModelAndPrompt(BaseModel):
 
 class CreateCompletionTorch(BaseModel):
     echo: bool = echo_field
-    max_tokens: int = max_tokens_field
+    max_tokens: Optional[int] = max_tokens_field
     repetition_penalty: float = repeat_penalty_field
     stop: Optional[Union[str, List[str]]] = stop_field
     stop_token_ids: Optional[Union[int, List[int]]] = none_field
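A minimal sketch of what the Optional[int] change permits, assuming the surrounding fields keep the defaults shown above (the *_field defaults defined in xinference.types):

# After 0.8.3, max_tokens may be left unset or passed explicitly as None.
from xinference.types import CreateCompletionTorch

req = CreateCompletionTorch(max_tokens=None)  # an int was required before this change
print(req.max_tokens)  # None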
xinference/web/ui/build/asset-manifest.json
CHANGED
@@ -1,11 +1,11 @@
 {
   "files": {
-    "main.js": "./static/js/main.b83095c2.js",
+    "main.js": "./static/js/main.15822aeb.js",
     "static/media/icon.webp": "./static/media/icon.4603d52c63041e5dfbfd.webp",
     "index.html": "./index.html",
-    "main.b83095c2.js.map": "./static/js/main.b83095c2.js.map"
+    "main.15822aeb.js.map": "./static/js/main.15822aeb.js.map"
   },
   "entrypoints": [
-    "static/js/main.b83095c2.js"
+    "static/js/main.15822aeb.js"
   ]
 }
xinference/web/ui/build/index.html
CHANGED
@@ -1 +1 @@
-<!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.b83095c2.js"></script></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
+<!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.15822aeb.js"></script></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>