sglang 0.4.3__py3-none-any.whl → 0.4.3.post2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (44)
  1. sglang/lang/backend/openai.py +5 -0
  2. sglang/lang/chat_template.py +22 -7
  3. sglang/lang/ir.py +1 -0
  4. sglang/srt/configs/__init__.py +6 -3
  5. sglang/srt/configs/model_config.py +2 -0
  6. sglang/srt/configs/qwen2_5_vl_config.py +1003 -0
  7. sglang/srt/entrypoints/engine.py +17 -2
  8. sglang/srt/hf_transformers_utils.py +2 -3
  9. sglang/srt/layers/attention/flashinfer_backend.py +101 -30
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  12. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  13. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  14. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  15. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  16. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  17. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  18. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  19. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  20. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  21. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  22. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  23. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  24. sglang/srt/managers/image_processor.py +217 -122
  25. sglang/srt/managers/schedule_batch.py +1 -0
  26. sglang/srt/model_executor/forward_batch_info.py +4 -1
  27. sglang/srt/model_executor/model_runner.py +1 -0
  28. sglang/srt/models/deepseek_nextn.py +295 -0
  29. sglang/srt/models/deepseek_v2.py +9 -3
  30. sglang/srt/models/llava.py +2 -1
  31. sglang/srt/models/qwen2_5_vl.py +722 -0
  32. sglang/srt/models/qwen2_vl.py +2 -1
  33. sglang/srt/openai_api/adapter.py +17 -3
  34. sglang/srt/server_args.py +6 -3
  35. sglang/srt/speculative/eagle_worker.py +7 -2
  36. sglang/srt/speculative/spec_info.py +11 -1
  37. sglang/utils.py +99 -19
  38. sglang/version.py +1 -1
  39. {sglang-0.4.3.dist-info → sglang-0.4.3.post2.dist-info}/METADATA +3 -3
  40. {sglang-0.4.3.dist-info → sglang-0.4.3.post2.dist-info}/RECORD +43 -27
  41. sglang/srt/configs/qwen2vl.py +0 -130
  42. {sglang-0.4.3.dist-info → sglang-0.4.3.post2.dist-info}/LICENSE +0 -0
  43. {sglang-0.4.3.dist-info → sglang-0.4.3.post2.dist-info}/WHEEL +0 -0
  44. {sglang-0.4.3.dist-info → sglang-0.4.3.post2.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ sglang/srt/models/qwen2_5_vl.py
@@ -0,0 +1,722 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
+import logging
+from functools import lru_cache, partial
+from typing import Iterable, List, Optional, Tuple, Type
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import AutoModel, Qwen2VLConfig
+from transformers.activations import ACT2FN
+from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
+
+from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2_5_VLMLP(nn.Module):
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int = None,
+        bias: bool = True,
+        hidden_act="silu",
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.gate_proj = ColumnParallelLinear(
+            in_features, hidden_features, bias=bias, quant_config=quant_config
+        )
+        self.up_proj = ColumnParallelLinear(
+            in_features, hidden_features, bias=bias, quant_config=quant_config
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features, in_features, bias=bias, quant_config=quant_config
+        )
+        self.act = ACT2FN[hidden_act]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel_gate, _ = self.gate_proj(x)
+        x_parallel_gate = self.act(x_parallel_gate)
+        x_parallel_up, _ = self.up_proj(x)
+        x_parallel = x_parallel_gate * x_parallel_up
+        x, _ = self.down_proj(x_parallel)
+        return x
+
+
+class Qwen2_5_VisionBlock(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        intermediate_dim: int,
+        num_heads: int,
+        hidden_act="silu",
+        norm_layer: Type[nn.Module] = None,
+        attn_implementation: Optional[str] = "sdpa",
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
+        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+        if attn_implementation == "sdpa":
+            use_context_forward = False
+            use_full_precision_softmax = False
+        elif attn_implementation == "flash_attention_2":
+            use_full_precision_softmax = False
+            use_context_forward = True
+        elif attn_implementation == "eager":
+            use_full_precision_softmax = True
+            use_context_forward = False
+
+        self.attn = VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            use_qkv_parallel=False,
+            use_context_forward=use_context_forward,
+            use_full_precision_softmax=use_full_precision_softmax,
+            flatten_batch=True,
+            quant_config=quant_config,
+        )
+        self.mlp = Qwen2_5_VLMLP(
+            dim, intermediate_dim, hidden_act=hidden_act, quant_config=quant_config
+        )
+
+    def forward(
+        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.norm1(x)
+        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+        attn = self.attn(
+            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+        )
+        attn = rearrange(attn, "b s ... -> s b ...")
+        x = x + attn
+        norm2 = self.norm2(x)
+        mlp = self.mlp(norm2)
+        x = x + mlp
+        return x
+
+
+class Qwen2_5_VisionPatchEmbed(nn.Module):
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_chans: int = 3,
+        embed_dim: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.embed_dim = embed_dim
+
+        kernel_size = [temporal_patch_size, patch_size, patch_size]
+        self.proj = nn.Conv3d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        L, C = x.shape
+        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
+        x = self.proj(x).view(L, self.embed_dim)
+        return x
+
+
+class Qwen2_5_VisionPatchMerger(nn.Module):
+
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        spatial_merge_size: int = 2,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.mlp = nn.ModuleList(
+            [
+                ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias=True,
+                    quant_config=quant_config,
+                ),
+                nn.GELU(),
+                RowParallelLinear(
+                    self.hidden_size, dim, bias=True, quant_config=quant_config
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.ln_q(x)
+        x = x.view(-1, self.hidden_size)
+
+        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
+        x_parallel, _ = mlp_fc1(x)
+        x_parallel = mlp_act(x_parallel)
+        out, _ = mlp_fc2(x_parallel)
+        return out
+
+
+class Qwen2_5_VisionRotaryEmbedding(nn.Module):
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta
+                ** (
+                    torch.arange(
+                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
+                    )
+                    / self.dim
+                )
+            )
+            seq = torch.arange(
+                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+            )
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
+
+
+class Qwen2_5_VisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        vision_config: Qwen2_5_VLVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        patch_size: int = vision_config.patch_size
+        temporal_patch_size: int = vision_config.temporal_patch_size
+        spatial_merge_size: int = vision_config.spatial_merge_size
+        self.spatial_merge_size = spatial_merge_size
+        self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
+        in_chans: int = vision_config.in_chans
+        hidden_size: int = vision_config.hidden_size
+        depth: int = vision_config.depth
+        num_heads: int = vision_config.num_heads
+        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
+        self.window_size = vision_config.window_size
+        self.patch_size = vision_config.patch_size
+        mlp_hidden_size: int = vision_config.intermediate_size
+        self.patch_embed = Qwen2_5_VisionPatchEmbed(
+            patch_size=patch_size,
+            temporal_patch_size=temporal_patch_size,
+            in_chans=in_chans,
+            embed_dim=hidden_size,
+        )
+
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        head_dim = hidden_size // num_heads
+        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+        self.blocks = nn.ModuleList(
+            [
+                Qwen2_5_VisionBlock(
+                    dim=hidden_size,
+                    intermediate_dim=mlp_hidden_size,
+                    num_heads=num_heads,
+                    hidden_act=vision_config.hidden_act,
+                    norm_layer=norm_layer,
+                    attn_implementation="sdpa",
+                    quant_config=quant_config,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.merger = Qwen2_5_VisionPatchMerger(
+            dim=vision_config.out_hidden_size,
+            context_dim=hidden_size,
+            spatial_merge_size=spatial_merge_size,
+            quant_config=quant_config,
+        )
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (
+            self.window_size // self.spatial_merge_size // self.patch_size
+        )
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h, llm_grid_w = (
+                grid_h // self.spatial_merge_size,
+                grid_w // self.spatial_merge_size,
+            )
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w
+            )
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+            index_padded = index_padded.reshape(
+                grid_t,
+                num_windows_h,
+                vit_merger_window_size,
+                num_windows_w,
+                vit_merger_window_size,
+            )
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t,
+                num_windows_h * num_windows_w,
+                vit_merger_window_size,
+                vit_merger_window_size,
+            )
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = (
+                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            )
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+
+        return window_index, cu_window_seqlens
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.gate_proj.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.gate_proj.weight.device
+
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = (
+                hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            wpos_ids = (
+                wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+
+        # compute position embedding
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+        cu_window_seqlens = torch.tensor(
+            cu_window_seqlens,
+            device=x.device,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+        seq_len, _ = x.size()
+
+        x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        x = x[window_index, :, :]
+        x = x.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(
+            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
+        )
+        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+        # transformers
+        x = x.unsqueeze(1)
+        for layer_num, blk in enumerate(self.blocks):
+            if layer_num in self.fullatt_block_indexes:
+                cu_seqlens_now = cu_seqlens
+            else:
+                cu_seqlens_now = cu_window_seqlens
+            x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
+
+        # adapter
+        x = self.merger(x)
+
+        reverse_indices = torch.argsort(window_index)
+        x = x[reverse_indices, :]
+
+        return x
+
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Qwen2_5_VLForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2VLConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.visual = Qwen2_5_VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            # NOTE: Qwen2-VL vision encoder does not support any
+            # quantization method now.
+            quant_config=None,
+        )
+
+        self.model = Qwen2Model(config, quant_config)
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
+        processor = cached_get_processor(self.config._name_or_path)
+        grid_t, grid_h, grid_w = image_grid_thw
+        num_image_tokens = (
+            grid_t
+            * grid_h
+            * grid_w
+            // processor.image_processor.merge_size
+            // processor.image_processor.merge_size
+        )
+        return num_image_tokens
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        new_input_ids = []
+        last_idx = 0
+        image_idx = -1
+        image_inputs.image_offsets = []
+
+        # Get all special token IDs
+        im_start_id = image_inputs.im_start_id
+        im_end_id = image_inputs.im_end_id
+
+        # Find all start and end positions for both types
+        start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
+        end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
+
+        if len(start_indices) != len(end_indices):
+            return input_ids
+        # Process each region (both image and slice)
+        for start_idx, end_idx in zip(start_indices, end_indices):
+            # Add non-image tokens before this region
+            new_input_ids.extend(input_ids[last_idx : start_idx + 1])
+
+            is_image_start = input_ids[start_idx] == im_start_id
+
+            if is_image_start:
+                image_inputs.image_offsets += [start_idx]
+                image_idx += 1
+
+            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens
+
+            # Generate pad_ids
+            pad_values = [image_inputs.pad_values[image_idx]]
+
+            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
+            pad_ids = pad_ids[:num_tokens]
+
+            # Add pad_ids
+            new_input_ids.extend(pad_ids)
+
+            # Update last_idx to after end token
+            last_idx = end_idx
+
+        # Add remaining tokens after last region
+        new_input_ids.extend(input_ids[last_idx:])
+        assert len(input_ids) == len(new_input_ids)
+        return new_input_ids
+
+    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
+        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+        return image_embeds
+
+    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
+        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
+        video_embeds = self.visual(
+            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
+        )
+        return video_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        """Run forward pass for Qwen2_5-VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+                (Use input_metadata.mrope_positions to replace it)
+        """
+        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            positions = forward_batch.mrope_positions
+
+        image_inputs = None
+        if forward_batch.image_inputs is not None:
+            image_inputs = [
+                img for img in forward_batch.image_inputs if img is not None
+            ]
+
+        if (
+            forward_batch.forward_mode.is_decode()
+            or image_inputs is None
+            or len(image_inputs) == 0
+        ):
+            inputs_embeds = self.model.embed_tokens(input_ids)
+        else:
+            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+            # [B, s, hidden_size]
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
+            for i, image in enumerate(forward_batch.image_inputs):
+                if image is None:
+                    continue
+                start_idx = extend_start_loc_cpu[i]
+                prefix_len = prefix_lens_cpu[i]
+
+                pixel_values = image.pixel_values.clone().detach().requires_grad_(False)
+                image_grid_thws = torch.tensor(
+                    np.array(image.image_grid_thws), device="cuda"
+                )
+                image_offsets = image.image_offsets
+                image_input = Qwen2VLImageInputs(
+                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
+                )
+                image_embeds = self._process_image_input(image_input)
+
+                image_embeds_offset = 0
+                for idx, image_offset in enumerate(image_offsets):
+                    if image_offset < prefix_len:
+                        continue
+                    num_image_tokens = self.calculate_num_image_tokens(
+                        image_grid_thws[idx]
+                    )
+
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = left_idx + num_image_tokens
+
+                    tp_size = get_tensor_model_parallel_world_size()
+
+                    hidden_size = image_embeds.shape[-1]
+
+                    if hidden_size % tp_size != 0:
+                        padding_size = tp_size - (hidden_size % tp_size)
+                        image_embeds = F.pad(image_embeds, (0, padding_size))
+                        inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
+
+                    hidden_chunk_size = image_embeds.shape[-1] // tp_size
+                    rank = get_tensor_model_parallel_rank()
+                    start_dim = rank * hidden_chunk_size
+                    end_dim = (rank + 1) * hidden_chunk_size
+                    inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
+                        image_embeds[
+                            image_embeds_offset : image_embeds_offset
+                            + num_image_tokens,
+                            ...,
+                            start_dim:end_dim,
+                        ]
+                    )
+                    image_embeds_offset += num_image_tokens
+
+            input_ids = None
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=inputs_embeds,
+        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "visual" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "visual" in name and "qkv.weight" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.hidden_size
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(
+                        3, visual_num_heads, head_size, visual_embed_dim
+                    )
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
+                elif "visual" in name and "qkv.bias" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.hidden_size
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1)
+
+                if "visual" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                try:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [Qwen2_5_VLForConditionalGeneration]
+AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)