xinference 1.4.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff shows the content of publicly available package versions as published to their respective public registries and is provided for informational purposes only.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (59)
  1. xinference/_compat.py +1 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +4 -0
  4. xinference/core/model.py +23 -3
  5. xinference/core/supervisor.py +6 -0
  6. xinference/core/worker.py +54 -11
  7. xinference/model/llm/__init__.py +4 -2
  8. xinference/model/llm/core.py +1 -0
  9. xinference/model/llm/llama_cpp/core.py +6 -1
  10. xinference/model/llm/llm_family.json +117 -1
  11. xinference/model/llm/llm_family_modelscope.json +125 -1
  12. xinference/model/llm/reasoning_parser.py +3 -3
  13. xinference/model/llm/sglang/core.py +111 -13
  14. xinference/model/llm/transformers/core.py +1 -0
  15. xinference/model/llm/transformers/deepseek_vl.py +1 -1
  16. xinference/model/llm/transformers/deepseek_vl2.py +287 -0
  17. xinference/model/llm/utils.py +26 -14
  18. xinference/model/llm/vllm/core.py +149 -8
  19. xinference/model/llm/vllm/distributed_executor.py +314 -0
  20. xinference/model/rerank/core.py +16 -11
  21. xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
  22. xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
  23. xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
  24. xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
  25. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
  26. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
  27. xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
  28. xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
  29. xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
  30. xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
  31. xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
  32. xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
  33. xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
  34. xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
  35. xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
  36. xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
  37. xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
  38. xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
  39. xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
  40. xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
  41. xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
  42. xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
  43. xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
  44. xinference/web/ui/build/asset-manifest.json +3 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/js/{main.3cea968e.js → main.5ca4eea1.js} +3 -3
  47. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +1 -0
  50. {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/METADATA +4 -4
  51. {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/RECORD +56 -31
  52. xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
  53. xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
  54. xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
  55. /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.5ca4eea1.js.LICENSE.txt} +0 -0
  56. {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/LICENSE +0 -0
  57. {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/WHEEL +0 -0
  58. {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/entry_points.txt +0 -0
  59. {xinference-1.4.0.dist-info → xinference-1.4.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py
@@ -0,0 +1,697 @@
+from attrdict import AttrDict
+from dataclasses import dataclass
+import logging
+import gc
+
+from einops import rearrange, repeat
+from typing import Optional, List, Tuple, Callable, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers.utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
+from transformers.modeling_outputs import ModelOutput
+from transformers.configuration_utils import PretrainedConfig
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    PreTrainedModel
+)
+from transformers.utils import logging
+
+from .siglip_vit import VisionTransformer
+from .configuration_deepseek import DeepseekV2Config
+from .modeling_deepseek import DeepseekV2ForCausalLM
+
+
+logger = logging.get_logger(__name__)
+
+
+class MlpProjector(nn.Module):
+
+    def __init__(self, cfg):
+
+        super().__init__()
+
+        self.cfg = cfg
+
+        if cfg.projector_type == "identity":
+            modules = nn.Identity()
+
+        elif cfg.projector_type == "linear":
+            modules = nn.Linear(cfg.input_dim, cfg.n_embed)
+
+        elif cfg.projector_type == "mlp_gelu":
+            mlp_depth = cfg.depth
+            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
+            for _ in range(1, mlp_depth):
+                modules.append(nn.GELU())
+                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+            modules = nn.Sequential(*modules)
+
+        elif cfg.projector_type == "downsample_mlp_gelu":
+            mlp_depth = cfg.depth
+            mlp_ratio = cfg.mlp_ratio
+            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
+            for _ in range(1, mlp_depth - 1):
+                modules.append(nn.GELU())
+                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
+            modules.append(nn.GELU())
+            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
+            modules = nn.Sequential(*modules)
+
+        else:
+            raise ValueError(f"Unknown projector type: {cfg.projector_type}")
+
+        if cfg.token_pooling:
+            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
+
+        self.layers = modules
+
+    def forward(self, x):
+        if self.cfg.token_pooling:
+            batch_size, wxh, channels = x.shape
+            w = h = int(wxh ** 0.5)
+            x = x.view(batch_size, w, h, channels)
+            x = x.permute(0, 3, 1, 2)
+            # import ipdb; ipdb.set_trace()
+            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
+            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
+            # concatenate along the channel dimension
+            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
+
+            # pass through the linear layer
+            patches = patches.permute(0, 2, 1, 3).contiguous()
+            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
+
+            x = self.token_pooling_layer(patches)
+
+        elif self.cfg.projector_type == 'downsample_mlp_gelu':
+            bs, hw, input_dim = x.shape
+            h = w = int((hw) ** 0.5)
+
+            """compute padding"""
+            if h % self.cfg.downsample_ratio:
+                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
+            else:
+                pad = 0
+            x = x.reshape(bs, h, w, input_dim)
+            if pad > 0:
+                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
+
+            """4 to 1 concat"""
+            x = x.permute(0, 3, 1, 2)  # B, C, H, W
+            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio,
+                         padding=0)  # B, C*4, HW // 4
+            x = x.permute(0, 2, 1)
+
+        return self.layers(x)
+
+
+class VisionEncoderConfig(PretrainedConfig):
+    model_type: str = "vision"
+
+    model_name: str = "siglip_large_patch16_384"
+    image_size: int = 384
+    patch_size: int = 16
+    width: int = 1024
+    layers: int = 24
+    heads: int = 16
+    mlp_ratio: int = 4
+    global_pool: str = "map"
+    ignore_head: bool = True
+    class_token: bool = False
+    num_classes: int = 0
+    use_checkpoint: bool = False
+    weight_init: str = "skip"
+    deterministic: bool = False
+    num_recomputing_layers: int = 0
+
+    def __init__(
+        self,
+        model_name: str = "siglip_large_patch16_384",
+        image_size: int = 384,
+        patch_size: int = 16,
+        width: int = 1024,
+        layers: int = 24,
+        heads: int = 16,
+        mlp_ratio: int = 4,
+        global_pool: str = "map",
+        ignore_head: bool = True,
+        class_token: bool = False,
+        num_classes: int = 0,
+        use_checkpoint: bool = False,
+        **kwargs
+    ):
+        self.model_name = model_name
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.width = width
+        self.layers = layers
+        self.heads = heads
+        self.mlp_ratio = mlp_ratio
+        self.global_pool = global_pool
+        self.ignore_head = ignore_head
+        self.class_token = class_token
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+
+        super().__init__(**kwargs)
+
+
+class MlpProjectorConfig(PretrainedConfig):
+    model_type = "mlp_projector"
+    projector_type: str = "downsample_mlp_gelu"
+    input_dim: int = 1152
+    n_embed: int = 2048
+    depth: int = 2
+    mlp_ratio: int = 1
+    downsample_ratio: int = 2
+    token_pooling: bool = False
+
+    def __init__(
+        self,
+        projector_type: str = "downsample_mlp_gelu",
+        input_dim: int = 1152,
+        n_embed: int = 2048,
+        depth: int = 2,
+        mlp_ratio: int = 1,
+        downsample_ratio: int = 2,
+        **kwargs
+    ):
+        self.projector_type = projector_type
+        self.input_dim = input_dim
+        self.n_embed = n_embed
+        self.depth = depth
+        self.mlp_ratio = mlp_ratio
+        self.downsample_ratio = downsample_ratio
+
+        super().__init__(**kwargs)
+
+
+@dataclass
+class DeepSeekVLV2CausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for DeepSeek-VL2 causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+            The rope index difference between sequence length and multimodal rope.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    rope_deltas: Optional[torch.LongTensor] = None
+
+
+class DeepseekVLV2Config(PretrainedConfig):
+    model_type = "deepseek_vl_v2"
+    vision_config: VisionEncoderConfig
+    projector_config: MlpProjectorConfig
+    language_config: DeepseekV2Config
+
+    tile_tag: str = "2D"
+    global_view_pos: str = "head"
+    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)
+
+    def __init__(
+        self,
+        tile_tag: str = "tile_tag",
+        global_view_pos: str = "head",
+        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+
+        vision_config = kwargs.get("vision_config", {})
+        self.vision_config = VisionEncoderConfig(**vision_config)
+
+        projector_config = kwargs.get("projector_config", {})
+        self.projector_config = MlpProjectorConfig(**projector_config)
+
+        language_config = kwargs.get("language_config", {})
+        if isinstance(language_config, DeepseekV2Config):
+            self.language_config = language_config
+        else:
+            self.language_config = DeepseekV2Config(**language_config)
+
+        self.tile_tag = tile_tag
+        self.global_view_pos = global_view_pos
+        self.candidate_resolutions = candidate_resolutions
+
+
+class DeepseekVLV2PreTrainedModel(PreTrainedModel):
+    config_class = DeepseekVLV2Config
+    base_model_prefix = "deepseek_vl_v2"
+    _no_split_modules = []
+    _skip_keys_device_placement = "past_key_values"
+
+
+class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
+
+    def __init__(self, config: DeepseekVLV2Config):
+        super().__init__(config)
+
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+        # ----------- vision encoder ------------
+        vision_config = config.vision_config
+        self.vision = VisionTransformer(
+            img_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
+            embed_dim=vision_config.width,
+            depth=vision_config.layers,
+            num_heads=vision_config.heads,
+            mlp_ratio=vision_config.mlp_ratio,
+            class_token=vision_config.class_token,
+            global_pool=vision_config.global_pool,
+            ignore_head=vision_config.ignore_head,
+            weight_init=vision_config.weight_init,
+            num_classes=0,
+            deterministic=vision_config.deterministic,
+            num_recomputing_layers=vision_config.num_recomputing_layers
+        )
+
+        # ----------- vl projector ------------
+        projector_config = config.projector_config
+        self.projector = MlpProjector(projector_config)
+
+        # image token format
+        # FIXME: the current defaults for tile_tag & global_view_pos come from earlier experiments;
+        # the defaults should be removed later and an error raised when no value is provided
+        self.tile_tag = config.tile_tag
+        self.global_view_pos = config.global_view_pos
+
+        # special tokens used to format the image token sequence
+        embed_std = 1 / torch.sqrt(torch.tensor(projector_config.n_embed, dtype=torch.float32))
+        if self.tile_tag == "2D":
+            # <|view_separator|>, <|\n|>
+            self.image_newline = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
+            # fix the typo: view_seperater
+            self.view_seperator = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
+        elif self.tile_tag == "1D":
+            # <|tile_x|>, <|tile_global|>
+            candidate_resolutions = config.candidate_resolutions
+            if len(candidate_resolutions) == 0:
+                raise ValueError(
+                    f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}")
+            tile_variants_num = len(candidate_resolutions)
+            self.tile_indicators = nn.Parameter(
+                torch.randn(size=(tile_variants_num + 1, config.aligner.params.n_embed)) * embed_std
+            )
+        else:
+            raise ValueError(f"tile tag should be either 1D or 2D, but got {self.tile_tag}")
+
+        # ----------- language model ------------
+        language_config = config.language_config
+        self.language = DeepseekV2ForCausalLM(language_config)
+
+    def prepare_inputs_embeds(
+        self,
+        input_ids: torch.LongTensor,
+        images: Optional[torch.FloatTensor] = None,
+        images_seq_mask: Optional[torch.LongTensor] = None,
+        images_spatial_crop: Optional[torch.LongTensor] = None,
+        **ignore_kwargs
+    ):
+        """
+
+        Args:
+            input_ids (torch.LongTensor): [b, T]
+            images (torch.FloatTensor): [b, max_n_images, 3, height, width]
+            images_seq_mask (torch.BoolTensor): [b, T]
+            images_spatial_crop (torch.LongTensor): [b, max_n_images, 2]
+
+        Returns:
+            input_embeds (torch.Tensor): [b, T, D]
+        """
+
+        if images is None or images_spatial_crop.sum() == 0:
+            return self.language.get_input_embeddings()(input_ids)
+
+        bs, max_n_images, _ = images_spatial_crop.shape
+        batch_num_tiles = [0 for _ in range(bs)]
+        total_tiles = []
+        for idx in range(bs):
+            for jdx in range(max_n_images):
+                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
+                if num_width_tiles == 0 or num_height_tiles == 0:
+                    break
+                batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles)
+
+            total_tiles.append(images[idx, :batch_num_tiles[idx]])
+
+        # [batch_all_tiles, 3, height, width]
+        total_tiles = torch.cat(total_tiles, dim=0)
+        assert total_tiles.shape[0] == sum(batch_num_tiles)
+        if total_tiles.shape[0] == 0:
+            return self.language.get_input_embeddings()(input_ids)
+
+        # [batch_all_tiles, vit_seq_len, c]
+        images_feature = self.vision(total_tiles)
+
+        # [batch_all_tiles, hw, D]
+        images_embeds = self.projector(images_feature)
+        _, hw, n_dim = images_embeds.shape
+        h = w = int(hw ** 0.5)
+
+        # put image tokens into the input_embeds, [b, T, D]
+        input_embeds = self.language.get_input_embeddings()(input_ids)
+
+        # fill in the image token sequence according to self.tile_tag & self.global_view_pos
+        tile_index = 0
+        for idx in range(images_spatial_crop.shape[0]):
+            images_in_this_batch = []
+            for jdx in range(images_spatial_crop.shape[1]):
+
+                # extract global & local features
+                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
+                if num_width_tiles == 0 or num_height_tiles == 0:
+                    break
+
+                num_tiles_in_image = num_width_tiles * num_height_tiles
+
+                # [hw, D]
+                global_features = images_embeds[tile_index]
+
+                # [num_height_tiles * num_width_tiles, hw, D]
+                local_features = images_embeds[tile_index + 1: tile_index + 1 + num_tiles_in_image]
+
+                tile_index += num_tiles_in_image + 1
+
+                # format global and local features
+                if self.tile_tag == "2D":
+
+                    # ----------------- global view add newline -----------------
+                    # [hw, D] -> [h, w, D]
+                    global_features = global_features.view(h, w, n_dim)
+                    # [D] -> [h, 1, D]
+                    new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
+                    # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
+                    global_features = torch.cat([global_features, new_lines_in_global], dim=1)
+                    # [h, w + 1, D] -> [h * (w + 1), D]
+                    global_features = global_features.view(-1, n_dim)
+
+                    # ----------------- local view add newline -----------------
+                    # [num_height_tiles * num_width_tiles, h * w, D] -> [num_height_tiles * h, num_width_tiles * w, D]
+                    local_features = rearrange(
+                        local_features,
+                        "(th tw) (h w) d -> (th h) (tw w) d",
+                        th=num_height_tiles,
+                        tw=num_width_tiles,
+                        h=h,
+                        w=w
+                    )
+
+                    # [D] -> [num_height_tiles * h, 1, D]
+                    new_lines_in_local = repeat(
+                        self.image_newline,
+                        "d -> (th h) 1 d",
+                        th=num_height_tiles,
+                        h=h
+                    )
+
+                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                    local_features = torch.cat([local_features, new_lines_in_local], dim=1)
+
+                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                    # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
+                    local_features = local_features.view(-1, n_dim)
+
+                    # ----------------- merge global and local tiles -----------------
+                    if self.global_view_pos == "head":
+                        global_local_features = torch.cat(
+                            [global_features, self.view_seperator[None, :], local_features], dim=0)
+                    else:
+                        global_local_features = torch.cat(
+                            [local_features, self.view_seperator[None, :], global_features], dim=0)
+
+                else:
+                    # abandoned; this branch is never actually taken
+                    global_features = torch.cat(
+                        [self.tile_indicators[0:1], global_features], dim=0
+                    )
+                    local_features = torch.cat(
+                        [self.tile_indicators[1:num_tiles_in_image + 1].unsqueeze(1), local_features], dim=1
+                    )
+                    local_features = rearrange(local_features, 'crop_num hw d -> (crop_num hw) d')
+
+                    if self.global_view_pos == "head":
+                        global_local_features = torch.cat([global_features, local_features], dim=0)
+                    else:
+                        global_local_features = torch.cat([local_features, global_features], dim=0)
+
+                images_in_this_batch.append(global_local_features)
+
+            if len(images_in_this_batch) > 0:
+                images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
+                input_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch)
+
+        return input_embeds
+
+    @torch.no_grad()
+    def incremental_prefilling(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+
+        images: Optional[torch.FloatTensor] = None,
+        images_seq_mask: Optional[torch.LongTensor] = None,
+        images_spatial_crop: Optional[torch.LongTensor] = None,
+        chunk_size: int = 1024
+    ):
+        if inputs_embeds is None:
+            inputs_embeds = self.prepare_inputs_embeds(
+                input_ids=input_ids,
+                images=images,
+                images_seq_mask=images_seq_mask,
+                images_spatial_crop=images_spatial_crop,
+            )
+
+            del images
+            del images_seq_mask
+            del images_spatial_crop
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+        self._clear_cuda_cache()
+
+        bzs, seq_len, _ = inputs_embeds.shape
+        past_key_values = None
+
+        # leave the last token for the next forward
+        prefilling_len = seq_len - 1
+        for i in range(0, prefilling_len, chunk_size):
+            chunk_start = i
+            chunk_end = min(i + chunk_size, prefilling_len)
+            chunk_inputs_embeds = inputs_embeds[:, chunk_start: chunk_end]
+            chunk_attention_mask = attention_mask[:, 0: chunk_end]
+            # print(f"start = {chunk_start}, end = {chunk_end}, prefilling_len = {prefilling_len}, seq_len = {seq_len}")
+
+            # compute position_ids
+            if past_key_values is not None:
+                position_ids = torch.arange(
+                    chunk_start,
+                    chunk_end,
+                    dtype=torch.long,
+                    device=inputs_embeds.device
+                ).unsqueeze(0)
+                past_key_values = self._move_past_key_values_to_gpu(past_key_values, inputs_embeds.device)
+            else:
+                position_ids = None
+
+            # chunk-forward
+            with torch.no_grad():
+                outputs = self.forward(
+                    inputs_embeds=chunk_inputs_embeds,
+                    attention_mask=chunk_attention_mask,
+                    past_key_values=past_key_values,
+                    position_ids=position_ids,
+                    use_cache=True,
+                )
+                # update past_key_values
+                past_key_values = outputs.past_key_values
+                past_key_values = self._move_past_key_values_to_cpu(past_key_values)
+
+                del outputs, position_ids
+                self._clear_cuda_cache()
+
+        prefilling_key_values = []
+        for layer_past in past_key_values:
+            prefilling_key_values.append(
+                (
+                    layer_past[0][:, :, 0: prefilling_len, ...].to(inputs_embeds.device),
+                    layer_past[1][:, :, 0: prefilling_len, ...].to(inputs_embeds.device),
+                )
+            )
+
+        return inputs_embeds, prefilling_key_values
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+
+        images: Optional[torch.FloatTensor] = None,
+        images_seq_mask: Optional[torch.LongTensor] = None,
+        images_spatial_crop: Optional[torch.LongTensor] = None,
+
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+    ):
+
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        if inputs_embeds is None:
+            inputs_embeds = self.prepare_inputs_embeds(
+                input_ids=input_ids,
+                images=images,
+                images_seq_mask=images_seq_mask,
+                images_spatial_crop=images_spatial_crop,
+            )
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+        # print(inputs_embeds.shape)
+        outputs = self.language.forward(
+            input_ids=None,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position
+        )
+
+        return outputs
+
+    def _clear_cuda_cache(self):
+        """clear CUDA memory cache"""
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+    def _move_past_key_values_to_cpu(self, past_key_values):
+        # print(f"past_key_values -> cpu")
+        if past_key_values is None:
+            return None
+        return tuple(tuple(t.cpu() for t in layer) for layer in past_key_values)
+
+    def _move_past_key_values_to_gpu(self, past_key_values, device="cuda:0"):
+        # print(f"past_key_values -> gpu")
+        if past_key_values is None:
+            return None
+        return tuple(tuple(t.to(device) for t in layer) for layer in past_key_values)
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        inputs_embeds=None,
+
+        images: Optional[torch.FloatTensor] = None,
+        images_seq_mask: Optional[torch.LongTensor] = None,
+        images_spatial_crop: Optional[torch.LongTensor] = None,
+
+        attention_mask=None,
+        cache_position=None,
+
+        pixel_values=None,
+        image_sizes=None,
+        num_logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+        model_inputs = self.language.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
+            **kwargs,
+        )
+
+        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+        # Otherwise we need pixel values to be passed to model
+        cache_position = model_inputs["cache_position"]
+        if cache_position[0] == 0:
+            model_inputs["images"] = images
+            model_inputs["images_seq_mask"] = images_seq_mask
+            model_inputs["images_spatial_crop"] = images_spatial_crop
+
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+AutoConfig.register("vision", VisionEncoderConfig)
+AutoConfig.register("mlp_projector", MlpProjectorConfig)
+AutoConfig.register("deepseek_vl_v2", DeepseekVLV2Config)
+AutoModelForCausalLM.register(DeepseekVLV2Config, DeepseekVLV2ForCausalLM)