xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of xinference has been flagged as possibly problematic.

Files changed (95)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +132 -0
  3. xinference/api/restful_api.py +282 -78
  4. xinference/client/handlers.py +3 -0
  5. xinference/client/restful/restful_client.py +108 -75
  6. xinference/constants.py +14 -4
  7. xinference/core/cache_tracker.py +102 -0
  8. xinference/core/chat_interface.py +10 -4
  9. xinference/core/event.py +56 -0
  10. xinference/core/model.py +44 -0
  11. xinference/core/resource.py +19 -12
  12. xinference/core/status_guard.py +4 -0
  13. xinference/core/supervisor.py +278 -87
  14. xinference/core/utils.py +68 -3
  15. xinference/core/worker.py +98 -8
  16. xinference/deploy/cmdline.py +6 -3
  17. xinference/deploy/local.py +2 -2
  18. xinference/deploy/supervisor.py +2 -2
  19. xinference/model/audio/__init__.py +27 -0
  20. xinference/model/audio/core.py +161 -0
  21. xinference/model/audio/model_spec.json +79 -0
  22. xinference/model/audio/utils.py +18 -0
  23. xinference/model/audio/whisper.py +132 -0
  24. xinference/model/core.py +18 -13
  25. xinference/model/embedding/__init__.py +27 -2
  26. xinference/model/embedding/core.py +43 -3
  27. xinference/model/embedding/model_spec.json +24 -0
  28. xinference/model/embedding/model_spec_modelscope.json +24 -0
  29. xinference/model/embedding/utils.py +18 -0
  30. xinference/model/image/__init__.py +12 -1
  31. xinference/model/image/core.py +63 -9
  32. xinference/model/image/utils.py +26 -0
  33. xinference/model/llm/__init__.py +20 -1
  34. xinference/model/llm/core.py +43 -2
  35. xinference/model/llm/ggml/chatglm.py +15 -6
  36. xinference/model/llm/llm_family.json +197 -6
  37. xinference/model/llm/llm_family.py +9 -7
  38. xinference/model/llm/llm_family_modelscope.json +189 -4
  39. xinference/model/llm/pytorch/chatglm.py +3 -3
  40. xinference/model/llm/pytorch/core.py +4 -2
  41. xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
  42. xinference/model/llm/pytorch/utils.py +21 -9
  43. xinference/model/llm/pytorch/yi_vl.py +246 -0
  44. xinference/model/llm/utils.py +57 -4
  45. xinference/model/llm/vllm/core.py +5 -4
  46. xinference/model/rerank/__init__.py +25 -2
  47. xinference/model/rerank/core.py +51 -9
  48. xinference/model/rerank/model_spec.json +6 -0
  49. xinference/model/rerank/model_spec_modelscope.json +7 -0
  50. xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
  51. xinference/model/utils.py +5 -3
  52. xinference/thirdparty/__init__.py +0 -0
  53. xinference/thirdparty/llava/__init__.py +1 -0
  54. xinference/thirdparty/llava/conversation.py +205 -0
  55. xinference/thirdparty/llava/mm_utils.py +122 -0
  56. xinference/thirdparty/llava/model/__init__.py +1 -0
  57. xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
  58. xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
  59. xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
  60. xinference/thirdparty/llava/model/constants.py +6 -0
  61. xinference/thirdparty/llava/model/llava_arch.py +385 -0
  62. xinference/thirdparty/llava/model/llava_llama.py +163 -0
  63. xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
  64. xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
  65. xinference/types.py +1 -1
  66. xinference/web/ui/build/asset-manifest.json +3 -3
  67. xinference/web/ui/build/index.html +1 -1
  68. xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
  69. xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
  75. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
  76. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
  77. xinference/api/oauth2/core.py +0 -93
  78. xinference/model/multimodal/__init__.py +0 -52
  79. xinference/model/multimodal/core.py +0 -467
  80. xinference/model/multimodal/model_spec.json +0 -43
  81. xinference/model/multimodal/model_spec_modelscope.json +0 -45
  82. xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
  83. xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
  92. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
  93. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
  94. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
  95. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/llava/model/llava_arch.py ADDED
@@ -0,0 +1,385 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import os
+ from abc import ABC, abstractmethod
+
+ import torch
+ from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, key_info
+
+ from .clip_encoder.builder import build_vision_tower
+ from .multimodal_projector.builder import build_vision_projector
+
+
+ class LlavaMetaModel:
+     def __init__(self, config):
+         super(LlavaMetaModel, self).__init__(config)
+
+         if hasattr(config, "mm_vision_tower"):
+             config.mm_vision_tower = os.path.join(
+                 key_info["model_path"], config.mm_vision_tower.replace("./", "")
+             )
+             self.vision_tower = build_vision_tower(config, delay_load=True)
+             self.mm_projector = build_vision_projector(config)
+
+     def get_vision_tower(self):
+         vision_tower = getattr(self, "vision_tower", None)
+         if type(vision_tower) is list:
+             vision_tower = vision_tower[0]
+         return vision_tower
+
+     def initialize_vision_modules(self, model_args):
+         vision_tower = model_args.vision_tower
+         mm_vision_select_layer = model_args.mm_vision_select_layer
+         mm_vision_select_feature = model_args.mm_vision_select_feature
+         pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+
+         self.config.mm_vision_tower = vision_tower
+
+         if self.get_vision_tower() is None:
+             vision_tower = build_vision_tower(model_args)
+             self.vision_tower = vision_tower
+         else:
+             vision_tower = self.vision_tower
+             if not vision_tower.is_loaded:
+                 vision_tower.load_model()
+
+         self.config.use_mm_proj = True
+         self.config.mm_projector_type = getattr(
+             model_args, "mm_projector_type", "linear"
+         )
+         self.config.mm_hidden_size = vision_tower.hidden_size
+         self.config.mm_vision_select_layer = mm_vision_select_layer
+         self.config.mm_vision_select_feature = mm_vision_select_feature
+
+         if getattr(self, "mm_projector", None) is None:
+             self.mm_projector = build_vision_projector(self.config)
+
+         if pretrain_mm_mlp_adapter is not None:
+             mm_projector_weights = torch.load(
+                 pretrain_mm_mlp_adapter, map_location="cpu"
+             )
+
+             def get_w(weights, keyword):
+                 return {
+                     k.split(keyword + ".")[1]: v
+                     for k, v in weights.items()
+                     if keyword in k
+                 }
+
+             self.mm_projector.load_state_dict(
+                 get_w(mm_projector_weights, "mm_projector")
+             )
+
+
+ class LlavaMetaForCausalLM(ABC):
+     @abstractmethod
+     def get_model(self):
+         pass
+
+     def get_vision_tower(self):
+         return self.get_model().get_vision_tower()
+
+     def encode_images(self, images):
+         image_features = self.get_model().get_vision_tower()(images)
+         image_features = self.get_model().mm_projector(image_features)
+         return image_features
+
+     def prepare_inputs_labels_for_multimodal(
+         self, input_ids, attention_mask, past_key_values, labels, images
+     ):
+         vision_tower = self.get_vision_tower()
+         if vision_tower is None or images is None or input_ids.shape[1] == 1:
+             if (
+                 past_key_values is not None
+                 and vision_tower is not None
+                 and images is not None
+                 and input_ids.shape[1] == 1
+             ):
+                 new_mask = torch.ones(
+                     (
+                         attention_mask.shape[0],
+                         past_key_values[-1][-1].shape[-2] + 1 - attention_mask.shape[1],
+                     ),
+                     dtype=attention_mask.dtype,
+                     device=attention_mask.device,
+                 )
+                 attention_mask = torch.cat([attention_mask, new_mask], dim=1)
+             return input_ids, attention_mask, past_key_values, None, labels
+
+         if type(images) is list or images.ndim == 5:
+             concat_images = torch.cat([image for image in images], dim=0)
+             image_features = self.encode_images(concat_images)
+             split_sizes = [image.shape[0] for image in images]
+             image_features = torch.split(image_features, split_sizes, dim=0)
+             image_features = [x.flatten(0, 1) for x in image_features]
+         else:
+             image_features = self.encode_images(images)
+
+         new_input_embeds = []
+         new_labels = [] if labels is not None else None
+         cur_image_idx = 0
+         for batch_idx, cur_input_ids in enumerate(input_ids):
+             if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
+                 # multimodal LLM, but the current sample is not multimodal
+                 # cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
+                 # cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
+                 # FIXME: this is a hacky fix, for deepspeed zero3 to work
+                 half_len = cur_input_ids.shape[0] // 2
+                 cur_image_features = image_features[cur_image_idx]
+                 cur_input_embeds_1 = self.get_model().embed_tokens(
+                     cur_input_ids[:half_len]
+                 )
+                 cur_input_embeds_2 = self.get_model().embed_tokens(
+                     cur_input_ids[half_len:]
+                 )
+                 cur_input_embeds = torch.cat(
+                     [cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2],
+                     dim=0,
+                 )
+                 new_input_embeds.append(cur_input_embeds)
+                 if labels is not None:
+                     new_labels.append(labels[batch_idx])
+                 cur_image_idx += 1
+                 continue
+             image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
+             cur_new_input_embeds = []
+             if labels is not None:
+                 cur_labels = labels[batch_idx]
+                 cur_new_labels = []
+                 assert cur_labels.shape == cur_input_ids.shape
+             while image_token_indices.numel() > 0:
+                 cur_image_features = image_features[cur_image_idx]
+                 image_token_start = image_token_indices[0]
+                 if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
+                     self.config, "mm_use_im_start_end", False
+                 ):
+                     cur_new_input_embeds.append(
+                         self.get_model()
+                         .embed_tokens(cur_input_ids[: image_token_start - 1])
+                         .detach()
+                     )
+                     cur_new_input_embeds.append(
+                         self.get_model().embed_tokens(
+                             cur_input_ids[image_token_start - 1 : image_token_start]
+                         )
+                     )
+                     cur_new_input_embeds.append(cur_image_features)
+                     cur_new_input_embeds.append(
+                         self.get_model().embed_tokens(
+                             cur_input_ids[image_token_start + 1 : image_token_start + 2]
+                         )
+                     )
+                     if labels is not None:
+                         cur_new_labels.append(cur_labels[:image_token_start])
+                         cur_new_labels.append(
+                             torch.full(
+                                 (cur_image_features.shape[0],),
+                                 IGNORE_INDEX,
+                                 device=labels.device,
+                                 dtype=labels.dtype,
+                             )
+                         )
+                         cur_new_labels.append(
+                             cur_labels[image_token_start : image_token_start + 1]
+                         )
+                         cur_labels = cur_labels[image_token_start + 2 :]
+                 else:
+                     cur_new_input_embeds.append(
+                         self.get_model().embed_tokens(cur_input_ids[:image_token_start])
+                     )
+                     cur_new_input_embeds.append(cur_image_features)
+                     if labels is not None:
+                         cur_new_labels.append(cur_labels[:image_token_start])
+                         cur_new_labels.append(
+                             torch.full(
+                                 (cur_image_features.shape[0],),
+                                 IGNORE_INDEX,
+                                 device=labels.device,
+                                 dtype=labels.dtype,
+                             )
+                         )
+                         cur_labels = cur_labels[image_token_start + 1 :]
+                 cur_image_idx += 1
+                 if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
+                     self.config, "mm_use_im_start_end", False
+                 ):
+                     cur_input_ids = cur_input_ids[image_token_start + 2 :]
+                 else:
+                     cur_input_ids = cur_input_ids[image_token_start + 1 :]
+                 image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
+             if cur_input_ids.numel() > 0:
+                 if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
+                     self.config, "mm_use_im_start_end", False
+                 ):
+                     cur_new_input_embeds.append(
+                         self.get_model().embed_tokens(cur_input_ids).detach()
+                     )
+                 else:
+                     cur_new_input_embeds.append(
+                         self.get_model().embed_tokens(cur_input_ids)
+                     )
+                 if labels is not None:
+                     cur_new_labels.append(cur_labels)
+             cur_new_input_embeds = [
+                 x.to(device=self.device) for x in cur_new_input_embeds
+             ]
+             cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
+             new_input_embeds.append(cur_new_input_embeds)
+             if labels is not None:
+                 cur_new_labels = torch.cat(cur_new_labels, dim=0)
+                 new_labels.append(cur_new_labels)
+
+         if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
+             max_len = max(x.shape[0] for x in new_input_embeds)
+
+             new_input_embeds_align = []
+             for cur_new_embed in new_input_embeds:
+                 cur_new_embed = torch.cat(
+                     (
+                         cur_new_embed,
+                         torch.zeros(
+                             (max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
+                             dtype=cur_new_embed.dtype,
+                             device=cur_new_embed.device,
+                         ),
+                     ),
+                     dim=0,
+                 )
+                 new_input_embeds_align.append(cur_new_embed)
+             new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
+
+             if labels is not None:
+                 new_labels_align = []
+                 _new_labels = new_labels
+                 for cur_new_label in new_labels:
+                     cur_new_label = torch.cat(
+                         (
+                             cur_new_label,
+                             torch.full(
+                                 (max_len - cur_new_label.shape[0],),
+                                 IGNORE_INDEX,
+                                 dtype=cur_new_label.dtype,
+                                 device=cur_new_label.device,
+                             ),
+                         ),
+                         dim=0,
+                     )
+                     new_labels_align.append(cur_new_label)
+                 new_labels = torch.stack(new_labels_align, dim=0)
+
+             if attention_mask is not None:
+                 new_attention_mask = []
+                 for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(
+                     attention_mask, _new_labels, new_labels
+                 ):
+                     new_attn_mask_pad_left = torch.full(
+                         (cur_new_labels.shape[0] - labels.shape[1],),
+                         True,
+                         dtype=attention_mask.dtype,
+                         device=attention_mask.device,
+                     )
+                     new_attn_mask_pad_right = torch.full(
+                         (cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
+                         False,
+                         dtype=attention_mask.dtype,
+                         device=attention_mask.device,
+                     )
+                     cur_new_attention_mask = torch.cat(
+                         (
+                             new_attn_mask_pad_left,
+                             cur_attention_mask,
+                             new_attn_mask_pad_right,
+                         ),
+                         dim=0,
+                     )
+                     new_attention_mask.append(cur_new_attention_mask)
+                 attention_mask = torch.stack(new_attention_mask, dim=0)
+                 assert attention_mask.shape == new_labels.shape
+         else:
+             new_input_embeds = torch.stack(new_input_embeds, dim=0)
+             if labels is not None:
+                 new_labels = torch.stack(new_labels, dim=0)
+
+             if attention_mask is not None:
+                 new_attn_mask_pad_right = torch.full(
+                     (
+                         attention_mask.shape[0],
+                         new_input_embeds.shape[1] - input_ids.shape[1],
+                     ),
+                     True,
+                     dtype=attention_mask.dtype,
+                     device=attention_mask.device,
+                 )
+                 attention_mask = torch.cat(
+                     (attention_mask, new_attn_mask_pad_right), dim=1
+                 )
+                 assert attention_mask.shape == new_input_embeds.shape[:2]
+
+         return None, attention_mask, past_key_values, new_input_embeds, new_labels
+
+     def initialize_vision_tokenizer(self, model_args, tokenizer):
+         if model_args.mm_use_im_patch_token:
+             tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+             self.resize_token_embeddings(len(tokenizer))
+
+         if model_args.mm_use_im_start_end:
+             num_new_tokens = tokenizer.add_tokens(
+                 [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
+             )
+             self.resize_token_embeddings(len(tokenizer))
+
+             if num_new_tokens > 0:
+                 input_embeddings = self.get_input_embeddings().weight.data
+                 output_embeddings = self.get_output_embeddings().weight.data
+
+                 input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True
+                 )
+                 output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+                     dim=0, keepdim=True
+                 )
+
+                 input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                 output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+             if model_args.tune_mm_mlp_adapter:
+                 for p in self.get_input_embeddings().parameters():
+                     p.requires_grad = True
+                 for p in self.get_output_embeddings().parameters():
+                     p.requires_grad = False
+
+             if model_args.pretrain_mm_mlp_adapter:
+                 mm_projector_weights = torch.load(
+                     model_args.pretrain_mm_mlp_adapter, map_location="cpu"
+                 )
+                 embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
+                 assert num_new_tokens == 2
+                 if input_embeddings.shape == embed_tokens_weight.shape:
+                     input_embeddings[-num_new_tokens:] = embed_tokens_weight[
+                         -num_new_tokens:
+                     ]
+                 elif embed_tokens_weight.shape[0] == num_new_tokens:
+                     input_embeddings[-num_new_tokens:] = embed_tokens_weight
+                 else:
+                     raise ValueError(
+                         f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
+                     )
+         elif model_args.mm_use_im_patch_token:
+             if model_args.tune_mm_mlp_adapter:
+                 for p in self.get_input_embeddings().parameters():
+                     p.requires_grad = False
+                 for p in self.get_output_embeddings().parameters():
+                     p.requires_grad = False
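The heart of the new llava_arch.py is prepare_inputs_labels_for_multimodal, which replaces each image placeholder token in the prompt with the projected vision features before the language model runs. A minimal sketch of that splicing idea with plain tensors (the placeholder id, shapes, and random features below are illustrative, not values taken from the package):

import torch

IMAGE_TOKEN_INDEX = -200  # hypothetical placeholder id marking where the image goes
hidden_size = 8

input_ids = torch.tensor([1, 5, IMAGE_TOKEN_INDEX, 7, 2])    # one tokenized sample
text_embeds = torch.randn(input_ids.shape[0], hidden_size)   # stand-in for embed_tokens(...)
image_features = torch.randn(4, hidden_size)                 # stand-in for encode_images(...)

pos = int(torch.where(input_ids == IMAGE_TOKEN_INDEX)[0][0])
spliced = torch.cat(
    [text_embeds[:pos], image_features, text_embeds[pos + 1 :]], dim=0
)
# 4 text embeddings survive and 4 image feature vectors are inserted in place of the placeholder:
assert spliced.shape == (input_ids.shape[0] - 1 + image_features.shape[0], hidden_size)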
xinference/thirdparty/llava/model/llava_llama.py ADDED
@@ -0,0 +1,163 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+
+
+ class LlavaConfig(LlamaConfig):
+     model_type = "llava"
+
+
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
+     config_class = LlavaConfig
+
+     def __init__(self, config: LlamaConfig):
+         config._flash_attn_2_enabled = True ######set flash attention2!!!!!!
+         super(LlavaLlamaModel, self).__init__(config)
+
+
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
+     config_class = LlavaConfig
+
+     def __init__(self, config):
+         super(LlamaForCausalLM, self).__init__(config)
+         self.model = LlavaLlamaModel(config)
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         (
+             input_ids,
+             attention_mask,
+             past_key_values,
+             inputs_embeds,
+             labels,
+         ) = self.prepare_inputs_labels_for_multimodal(
+             input_ids, attention_mask, past_key_values, labels, images
+         )
+
+         position_ids = attention_mask.long().cumsum(-1) - 1
+         position_ids.masked_fill_(attention_mask == 0, 1)
+         if past_key_values:
+             position_ids = position_ids[:, -1].unsqueeze(-1)
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model/pipeline parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         **kwargs
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 # "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "images": kwargs.get("images", None),
+             }
+         )
+         return model_inputs
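llava_llama.py wires that splicing into a standard transformers causal LM: forward first routes input_ids, labels, and images through prepare_inputs_labels_for_multimodal, then calls the underlying LlamaModel with inputs_embeds. A hedged sketch of loading the vendored class directly (the checkpoint path is hypothetical; inside xinference this wiring is done by the new xinference/model/llm/pytorch/yi_vl.py):

import torch
from xinference.thirdparty.llava.model.constants import key_info
from xinference.thirdparty.llava.model.llava_llama import LlavaLlamaForCausalLM

checkpoint_dir = "/path/to/Yi-VL-6B"  # hypothetical local model directory
key_info["model_path"] = checkpoint_dir  # llava_arch resolves mm_vision_tower relative to this
model = LlavaLlamaForCausalLM.from_pretrained(checkpoint_dir, torch_dtype=torch.bfloat16)
model.get_model().get_vision_tower().load_model()  # the tower is built with delay_load=True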
xinference/thirdparty/llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,64 @@
+ import re
+
+ import torch.nn as nn
+
+
+ class IdentityMap(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, *args, **kwargs):
+         return x
+
+     @property
+     def config(self):
+         return {"mm_projector_type": "identity"}
+
+
+ class SimpleResBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(channels)
+
+         self.proj = nn.Sequential(
+             nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
+         )
+
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
+
+ def build_vision_projector(config, delay_load=False, **kwargs):
+     projector_type = getattr(config, "mm_projector_type", "linear")
+
+     if projector_type == "linear":
+         return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+     use_norm = False
+     if "_Norm" in projector_type:
+         use_norm = True
+         projector_type = projector_type.replace("_Norm", "")
+     mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
+     if mlp_gelu_match:
+         mlp_depth = int(mlp_gelu_match.group(1))
+         if use_norm:
+             modules = [
+                 nn.Linear(config.mm_hidden_size, config.hidden_size),
+                 nn.LayerNorm(config.hidden_size),
+             ]
+         else:
+             modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
+         for _ in range(1, mlp_depth):
+             modules.append(nn.GELU())
+             if use_norm:
+                 modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+                 modules.append(nn.LayerNorm(config.hidden_size))
+             else:
+                 modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+         return nn.Sequential(*modules)
+
+     if projector_type == "identity":
+         return IdentityMap()
+
+     raise ValueError(f"Unknown projector type: {projector_type}")
xinference/types.py CHANGED
@@ -329,7 +329,7 @@ class ModelAndPrompt(BaseModel):
 
  class CreateCompletionTorch(BaseModel):
      echo: bool = echo_field
-     max_tokens: int = max_tokens_field
+     max_tokens: Optional[int] = max_tokens_field
      repetition_penalty: float = repeat_penalty_field
      stop: Optional[Union[str, List[str]]] = stop_field
      stop_token_ids: Optional[Union[int, List[int]]] = none_field
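The only change to xinference/types.py loosens max_tokens on the internal PyTorch completion request so a request may omit the token limit. A minimal sketch of the difference (assuming the other CreateCompletionTorch fields keep their declared defaults):

from xinference.types import CreateCompletionTorch

req = CreateCompletionTorch(max_tokens=None)  # accepted now that the field is Optional[int]
# With the previous `max_tokens: int` annotation, passing None failed pydantic validation.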
xinference/web/ui/build/asset-manifest.json CHANGED
@@ -1,11 +1,11 @@
  {
    "files": {
-     "main.js": "./static/js/main.b83095c2.js",
+     "main.js": "./static/js/main.15822aeb.js",
      "static/media/icon.webp": "./static/media/icon.4603d52c63041e5dfbfd.webp",
      "index.html": "./index.html",
-     "main.b83095c2.js.map": "./static/js/main.b83095c2.js.map"
+     "main.15822aeb.js.map": "./static/js/main.15822aeb.js.map"
    },
    "entrypoints": [
-     "static/js/main.b83095c2.js"
+     "static/js/main.15822aeb.js"
    ]
  }
xinference/web/ui/build/index.html CHANGED
@@ -1 +1 @@
- <!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.b83095c2.js"></script></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>
+ <!doctype html><html lang="en"><head><meta charset="utf-8"/><link rel="icon" href="./favicon.svg"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="theme-color" content="#000000"/><meta name="description" content="Web site created using create-react-app"/><link rel="apple-touch-icon" href="./logo192.png"/><link rel="manifest" href="./manifest.json"/><title>Xinference</title><script defer="defer" src="./static/js/main.15822aeb.js"></script></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>