sglang 0.4.3__py3-none-any.whl → 0.4.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/backend/openai.py +5 -0
- sglang/lang/chat_template.py +22 -7
- sglang/lang/ir.py +1 -0
- sglang/srt/configs/__init__.py +6 -3
- sglang/srt/configs/model_config.py +2 -0
- sglang/srt/configs/qwen2_5_vl_config.py +1003 -0
- sglang/srt/entrypoints/engine.py +16 -1
- sglang/srt/hf_transformers_utils.py +2 -3
- sglang/srt/managers/image_processor.py +217 -122
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/models/deepseek_nextn.py +295 -0
- sglang/srt/models/deepseek_v2.py +4 -1
- sglang/srt/models/llava.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +722 -0
- sglang/srt/models/qwen2_vl.py +2 -1
- sglang/srt/openai_api/adapter.py +17 -3
- sglang/srt/server_args.py +6 -3
- sglang/srt/speculative/eagle_worker.py +7 -2
- sglang/srt/speculative/spec_info.py +11 -1
- sglang/utils.py +99 -19
- sglang/version.py +1 -1
- {sglang-0.4.3.dist-info → sglang-0.4.3.post1.dist-info}/METADATA +2 -2
- {sglang-0.4.3.dist-info → sglang-0.4.3.post1.dist-info}/RECORD +26 -24
- sglang/srt/configs/qwen2vl.py +0 -130
- {sglang-0.4.3.dist-info → sglang-0.4.3.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.dist-info → sglang-0.4.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.3.dist-info → sglang-0.4.3.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,722 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
# Adapted from
|
3
|
+
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
|
4
|
+
# Copyright 2024 The Qwen team.
|
5
|
+
# Copyright 2023 The vLLM team.
|
6
|
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
7
|
+
#
|
8
|
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
9
|
+
# and OPT implementations in this library. It has been modified from its
|
10
|
+
# original forms to accommodate minor architectural differences compared
|
11
|
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
12
|
+
#
|
13
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
14
|
+
# you may not use this file except in compliance with the License.
|
15
|
+
# You may obtain a copy of the License at
|
16
|
+
#
|
17
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
18
|
+
#
|
19
|
+
# Unless required by applicable law or agreed to in writing, software
|
20
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
21
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
22
|
+
# See the License for the specific language governing permissions and
|
23
|
+
# limitations under the License.
|
24
|
+
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
25
|
+
import logging
|
26
|
+
from functools import lru_cache, partial
|
27
|
+
from typing import Iterable, List, Optional, Tuple, Type
|
28
|
+
|
29
|
+
import numpy as np
|
30
|
+
import torch
|
31
|
+
import torch.nn as nn
|
32
|
+
import torch.nn.functional as F
|
33
|
+
from einops import rearrange
|
34
|
+
from transformers import AutoModel, Qwen2VLConfig
|
35
|
+
from transformers.activations import ACT2FN
|
36
|
+
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
37
|
+
|
38
|
+
from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
|
39
|
+
from sglang.srt.distributed import (
|
40
|
+
get_tensor_model_parallel_rank,
|
41
|
+
get_tensor_model_parallel_world_size,
|
42
|
+
)
|
43
|
+
from sglang.srt.hf_transformers_utils import get_processor
|
44
|
+
from sglang.srt.layers.attention.vision import VisionAttention
|
45
|
+
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
46
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
47
|
+
from sglang.srt.layers.pooler import Pooler, PoolingType
|
48
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
49
|
+
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
50
|
+
from sglang.srt.managers.schedule_batch import ImageInputs
|
51
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
52
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
53
|
+
from sglang.srt.models.qwen2 import Qwen2Model
|
54
|
+
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
|
55
|
+
|
56
|
+
# Module-level logger for this model file.
logger = logging.getLogger(__name__)
|
57
|
+
|
58
|
+
|
59
|
+
class Qwen2_5_VLMLP(nn.Module):
    """Gated feed-forward network used inside each Qwen2.5-VL vision block.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))``. The gate and up
    projections are column-parallel, and the down projection is row-parallel.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        linear_kwargs = {"bias": bias, "quant_config": quant_config}
        # Registration order (gate, up, down, act) is kept stable.
        self.gate_proj = ColumnParallelLinear(
            in_features, hidden_features, **linear_kwargs
        )
        self.up_proj = ColumnParallelLinear(
            in_features, hidden_features, **linear_kwargs
        )
        self.down_proj = RowParallelLinear(
            hidden_features, in_features, **linear_kwargs
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP.

        The parallel linear layers return ``(output, bias)`` tuples; only the
        output is used.
        """
        gate, _ = self.gate_proj(x)
        up, _ = self.up_proj(x)
        out, _ = self.down_proj(self.act(gate) * up)
        return out
|
88
|
+
|
89
|
+
|
90
|
+
class Qwen2_5_VisionBlock(nn.Module):
    """Pre-norm transformer block of the Qwen2.5-VL vision tower.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``. Both
    norms are ``Qwen2RMSNorm(dim, eps=1e-6)``, and the MLP is a
    :class:`Qwen2_5_VLMLP`.
    """

    # attn_implementation -> (use_context_forward, use_full_precision_softmax)
    _ATTN_IMPL_FLAGS = {
        "sdpa": (False, False),
        "flash_attention_2": (True, False),
        "eager": (False, True),
    }

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        hidden_act="silu",
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the block.

        Args:
            dim: hidden size of the block input/output.
            intermediate_dim: hidden size of the gated MLP.
            num_heads: number of attention heads.
            hidden_act: key into ``ACT2FN`` for the MLP activation.
            norm_layer: accepted for signature compatibility only; the block
                always uses ``Qwen2RMSNorm`` with ``eps=1e-6``.
            attn_implementation: one of ``"sdpa"``, ``"flash_attention_2"``,
                or ``"eager"``. Selects the VisionAttention flags.
            quant_config: forwarded to the attention and MLP linear layers.

        Raises:
            ValueError: if ``attn_implementation`` is not a supported value.
                Previously this surfaced as an UnboundLocalError.
        """
        super().__init__()
        try:
            use_context_forward, use_full_precision_softmax = (
                self._ATTN_IMPL_FLAGS[attn_implementation]
            )
        except KeyError:
            raise ValueError(
                f"Unsupported attn_implementation {attn_implementation!r}; "
                f"expected one of {sorted(self._ATTN_IMPL_FLAGS)}"
            ) from None

        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=False,
            use_context_forward=use_context_forward,
            use_full_precision_softmax=use_full_precision_softmax,
            flatten_batch=True,
            quant_config=quant_config,
        )
        self.mlp = Qwen2_5_VLMLP(
            dim, intermediate_dim, hidden_act=hidden_act, quant_config=quant_config
        )

    def forward(
        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual add.

        ``x`` is laid out ``(seq, batch, ...)``, per the rearrange patterns
        below. VisionAttention is called with ``(batch, seq, ...)``.
        ``cu_seqlens`` and ``rotary_pos_emb`` are passed through unchanged.
        """
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        norm2 = self.norm2(x)
        mlp = self.mlp(norm2)
        x = x + mlp
        return x
|
145
|
+
|
146
|
+
|
147
|
+
class Qwen2_5_VisionPatchEmbed(nn.Module):
    """Project flattened pixel patches to ``embed_dim`` features.

    Each input row holds one flattened patch of shape
    ``(in_chans, temporal_patch_size, patch_size, patch_size)``. A Conv3d whose
    kernel equals its stride maps every patch to a single output vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_shape = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(num_patches, patch_numel)`` input to ``(num_patches, embed_dim)``."""
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        return self.proj(patches).view(num_patches, self.embed_dim)
|
171
|
+
|
172
|
+
|
173
|
+
class Qwen2_5_VisionPatchMerger(nn.Module):
    """Fuse each group of ``spatial_merge_size**2`` patch features into one token.

    The features are RMS-normalised, then consecutive rows are concatenated
    into vectors of ``context_dim * spatial_merge_size**2``. These are passed
    through a Linear-GELU-Linear MLP that outputs ``dim`` features.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        merged_dim = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_dim
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        # Kept as a ModuleList (fc1, GELU, fc2) so parameter names stay
        # mlp.0.* / mlp.2.* for checkpoint loading.
        fc1 = ColumnParallelLinear(
            merged_dim, merged_dim, bias=True, quant_config=quant_config
        )
        act = nn.GELU()
        fc2 = RowParallelLinear(merged_dim, dim, bias=True, quant_config=quant_config)
        self.mlp = nn.ModuleList([fc1, act, fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, group consecutive rows, and project to ``dim``."""
        fc1, act, fc2 = self.mlp
        grouped = self.ln_q(x).view(-1, self.hidden_size)
        hidden, _ = fc1(grouped)
        out, _ = fc2(act(hidden))
        return out
|
209
|
+
|
210
|
+
|
211
|
+
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
    """Rotary-embedding frequency table for the vision tower, grown on demand.

    ``forward(seqlen)`` returns ``outer(arange(seqlen), inv_freq)`` with shape
    ``(seqlen, ceil(dim / 2))``. When a longer table is needed, it is rebuilt
    for twice the requested length, so growth is amortised.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = self._compute_inv_freq(device=None)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def _compute_inv_freq(self, device) -> torch.Tensor:
        """Return ``1 / theta**(2i / dim)`` in float32 on ``device``."""
        exponents = (
            torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        )
        return 1.0 / (self.theta**exponents)

    def update_freqs_cache(self, seqlen: int) -> None:
        """Make sure the cached table covers at least ``seqlen`` positions.

        ``seqlen`` may be a Python int or a 0-d tensor. It is coerced to
        ``int`` once, so the cache bookkeeping never holds a device tensor
        that would force a host sync on every comparison.
        """
        seqlen = int(seqlen)
        if self._freqs_cached is not None and seqlen <= self._seq_len_cached:
            return
        seqlen *= 2
        self._seq_len_cached = seqlen
        # Recomputed in float32 so the table keeps full precision even if the
        # buffer was cast (e.g. to bf16) together with the module.
        self.inv_freq = self._compute_inv_freq(self.inv_freq.device)
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the first ``seqlen`` rows of the frequency table."""
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]
|
244
|
+
|
245
|
+
|
246
|
+
class Qwen2_5_VisionTransformer(nn.Module):
    """Vision tower of Qwen2.5-VL.

    Pipeline: patch embedding -> transformer blocks (windowed attention,
    except the blocks listed in ``fullatt_block_indexes``, which attend over
    each whole image) -> patch merger. Tokens are processed in window order
    and restored to the original merged-token order at the end.
    """

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        self.spatial_merge_size = spatial_merge_size
        # Number of raw patches fused into one merged (LLM) token.
        self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
        in_chans: int = vision_config.in_chans
        hidden_size: int = vision_config.hidden_size
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        mlp_hidden_size: int = vision_config.intermediate_size
        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=hidden_size,
        )

        # NOTE(review): passed to every block, but Qwen2_5_VisionBlock does
        # not appear to use norm_layer (it builds its own RMSNorms) — confirm.
        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = hidden_size // num_heads
        # head_dim // 2 rotary dims: half for the row position, half for the
        # column position (concatenated in rot_pos_emb).
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen2_5_VisionBlock(
                    dim=hidden_size,
                    intermediate_dim=mlp_hidden_size,
                    num_heads=num_heads,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                )
                for _ in range(depth)
            ]
        )
        self.merger = Qwen2_5_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            quant_config=quant_config,
        )

    def get_window_index(self, grid_thw):
        """Compute the window-order permutation and window boundaries.

        Args:
            grid_thw: per-image ``(t, h, w)`` patch-grid sizes. Entries must be
                tensors, because ``.item()`` is called on their product.

        Returns:
            ``(window_index, cu_window_seqlens)``:
            ``window_index`` is a 1-D tensor listing merged-token indices
            (offset per image) window by window. Each window spans
            ``vit_merger_window_size`` x ``vit_merger_window_size`` merged
            tokens.
            ``cu_window_seqlens`` is a Python list of cumulative window lengths,
            counted in raw patches (merged tokens * ``spatial_merge_unit``) and
            starting at 0. It may contain repeated values for empty windows.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # Window edge length measured in merged tokens.
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # Pad the grid up to whole windows with the sentinel -100. When the
            # size is already a multiple, a full extra window of padding is
            # added; those empty windows yield zero-length entries.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            # Bring the two window axes together: (t, window, win_h, win_w).
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Real (non-padding) tokens per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the tower, read from the first block's MLP gate weight."""
        return self.blocks[0].mlp.gate_proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the tower, read from the first block's MLP gate weight."""
        return self.blocks[0].mlp.gate_proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-patch rotary frequencies from (row, column) positions.

        Patch positions are ordered so that each
        ``spatial_merge_size`` x ``spatial_merge_size`` merge group is contiguous,
        and are repeated for every temporal slice ``t``. Row and column
        frequencies are looked up in one table and concatenated per patch.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches of one or more images into merged tokens.

        Args:
            x: flattened pixel patches, one row per patch, as consumed by
                ``patch_embed``.
            grid_thw: ``(num_images, 3)`` tensor of ``(t, h, w)`` patch grids.

        Returns:
            One row per merged token, in the original (non-window) order.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop the repeated boundaries produced by empty (all-padding) windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = x.size()

        # Reorder patches into window order, moving whole merge groups
        # (spatial_merge_unit rows) at a time; rotary embeddings follow.
        x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        # compute cu_seqlens: one segment of h * w patches per temporal slice
        # of each image (used by the full-attention blocks).
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers: add a batch axis of size 1 in the (seq, batch, dim)
        # layout that the blocks rearrange internally.
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)

        # adapter
        x = self.merger(x)

        # Undo the window permutation so outputs follow the original order.
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]

        return x
|
440
|
+
|
441
|
+
|
442
|
+
# Memoised get_processor: repeated calls with the same arguments reuse the
# previously built processor (functools.lru_cache, default maxsize=128).
cached_get_processor = lru_cache(get_processor)
|
443
|
+
|
444
|
+
|
445
|
+
class Qwen2_5_VLForConditionalGeneration(nn.Module):
    """Qwen2.5-VL: a Qwen2 language model whose image tokens come from a vision tower.

    During extend/prefill, the pad tokens inserted by :meth:`pad_input_ids`
    are overwritten with the output of :class:`Qwen2_5_VisionTransformer`.
    The sequence then runs through :class:`Qwen2Model`.
    """

    def __init__(
        self,
        config: Qwen2_5_VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.visual = Qwen2_5_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: the Qwen2.5-VL vision encoder does not support any
            # quantization method yet.
            quant_config=None,
        )

        self.model = Qwen2Model(config, quant_config)

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size, config.hidden_size, quant_config=quant_config
            )

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def _uses_mrope(self) -> bool:
        """Return True when the config requests multimodal rotary embeddings.

        A missing ``rope_scaling`` and an explicit ``None`` both count as
        "no mrope"; neither raises AttributeError.
        """
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        return rope_scaling.get("type", None) == "mrope"

    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]) -> int:
        """Return how many LLM tokens one image occupies after patch merging.

        Args:
            image_grid_thw: ``(t, h, w)`` patch grid of the image. A 3-element
                tensor row is also accepted.

        Returns:
            ``t * h * w // merge_size // merge_size`` as a Python int. The int
            avoids device-tensor slice bounds. ``merge_size`` comes from the
            HF image processor of this checkpoint.
        """
        processor = cached_get_processor(self.config._name_or_path)
        merge_size = processor.image_processor.merge_size
        grid_t, grid_h, grid_w = image_grid_thw
        return int(grid_t * grid_h * grid_w // merge_size // merge_size)

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        """Replace every token between an image start/end pair with a pad value.

        The k-th start/end pair is filled with ``image_inputs.pad_values[k]``.
        ``image_inputs.image_offsets`` is reset and then receives the index of
        each image-start token; the image's pad tokens begin right after it.
        If the numbers of start and end tokens differ, ``input_ids`` is
        returned unchanged.
        """
        image_inputs.image_offsets = []

        im_start_id = image_inputs.im_start_id
        im_end_id = image_inputs.im_end_id
        start_indices = [i for i, tok in enumerate(input_ids) if tok == im_start_id]
        end_indices = [i for i, tok in enumerate(input_ids) if tok == im_end_id]
        if len(start_indices) != len(end_indices):
            return input_ids

        new_input_ids = []
        last_idx = 0
        for image_idx, (start_idx, end_idx) in enumerate(
            zip(start_indices, end_indices)
        ):
            # Copy everything up to and including the image-start token.
            new_input_ids.extend(input_ids[last_idx : start_idx + 1])
            image_inputs.image_offsets.append(start_idx)
            # Tokens strictly between start and end become this image's pad value.
            num_tokens = end_idx - start_idx - 1
            new_input_ids.extend([image_inputs.pad_values[image_idx]] * num_tokens)
            # Resume at the end token so the next copy includes it.
            last_idx = end_idx

        new_input_ids.extend(input_ids[last_idx:])
        assert len(input_ids) == len(new_input_ids)
        return new_input_ids

    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
        """Run the vision tower on ``image_input`` pixel values and grid sizes."""
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
        return image_embeds

    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        """Run the vision tower on ``video_input`` pixel values and grid sizes."""
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds

    def _merge_image_embeds(
        self, inputs_embeds: torch.Tensor, forward_batch: ForwardBatch
    ) -> None:
        """Overwrite the image pad-token rows of ``inputs_embeds`` in place.

        ``inputs_embeds`` holds the extend tokens of all requests back to back.
        Request ``i`` starts at row ``forward_batch.extend_start_loc[i]``, and
        that first row corresponds to sequence position
        ``forward_batch.extend_prefix_lens_cpu[i]``. Image positions come from
        ``image_offsets`` as recorded by :meth:`pad_input_ids`.
        """
        extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
        prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
        for i, image in enumerate(forward_batch.image_inputs):
            if image is None:
                continue
            start_idx = int(extend_start_loc_cpu[i])
            prefix_len = int(prefix_lens_cpu[i])

            pixel_values = image.pixel_values.clone().detach().requires_grad_(False)
            image_grid_thws = torch.tensor(
                np.array(image.image_grid_thws), device=inputs_embeds.device
            )
            image_input = Qwen2VLImageInputs(
                pixel_values=pixel_values, image_grid_thw=image_grid_thws
            )
            # Embeddings of all images of this request, concatenated in order.
            image_embeds = self._process_image_input(image_input)

            image_embeds_offset = 0
            for idx, image_offset in enumerate(image.image_offsets):
                num_image_tokens = self.calculate_num_image_tokens(
                    image_grid_thws[idx]
                )
                src_begin = image_embeds_offset
                # Advance for every image, including ones already in the prefix
                # cache, so each later image reads its own slice.
                image_embeds_offset += num_image_tokens

                # image_offset is the image-start token; the pads follow it.
                image_begin = image_offset + 1
                image_end = image_begin + num_image_tokens
                if image_end <= prefix_len:
                    # Entirely covered by the cached prefix.
                    continue
                # Skip any leading part of the image that lies in the prefix.
                skip = max(prefix_len - image_begin, 0)
                left_idx = start_idx + (image_begin + skip - prefix_len)
                right_idx = start_idx + (image_end - prefix_len)
                # NOTE(review): right_idx is not clipped to the end of this
                # request's extend segment; an image that straddles a
                # chunked-prefill boundary may need extra handling — confirm.
                # Copy whole rows. Writing only a per-rank slice of the hidden
                # dim would leave the remaining columns holding the
                # clamped-token embeddings.
                inputs_embeds[left_idx:right_idx] = image_embeds[
                    src_begin + skip : src_begin + num_image_tokens
                ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run the forward pass for Qwen2.5-VL.

        Args:
            input_ids: Flattened (concatenated) input ids of the batch. On the
                image path they are clamped in place to the vocabulary range,
                because image pad positions hold image-hash values.
            positions: Flattened (concatenated) position ids of the batch.
                **NOTE**: when mrope is enabled (the default setting for
                Qwen2-VL-family open-source models),
                ``forward_batch.mrope_positions`` is used instead and must have
                shape ``(3, seq_len)``. Otherwise, positions have shape
                ``(seq_len,)``.
            forward_batch: batch metadata (forward mode, image inputs, extend
                offsets).
            get_embedding: if True, return pooled embeddings instead of logits.
        """
        uses_mrope = self._uses_mrope()
        if uses_mrope:
            positions = forward_batch.mrope_positions

        has_images = forward_batch.image_inputs is not None and any(
            img is not None for img in forward_batch.image_inputs
        )

        if forward_batch.forward_mode.is_decode() or not has_images:
            inputs_embeds = self.model.embed_tokens(input_ids)
        else:
            if uses_mrope:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

            # Clamp input ids. The ids at image positions hold image-hash
            # values (used for radix-cache prefix matching). Their embeddings
            # are replaced by vision embeddings below, so the clamped values
            # are never used.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
            inputs_embeds = self.model.embed_tokens(input_ids)
            self._merge_image_embeds(inputs_embeds, forward_batch)
            # The clamped ids no longer identify the real tokens at image
            # positions, so they are not forwarded.
            input_ids = None

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
        )

        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HF checkpoint tensors into this module.

        For the language model, q/k/v and gate/up projections are loaded into
        their fused parameters. For the vision tower, qkv weights and biases
        are regrouped so each head's q, k, v are adjacent, and ``attn.qkv.`` is
        renamed to ``attn.qkv_proj.`` to match VisionAttention.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Vision MLP keeps separate gate/up projections.
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name and "qkv.weight" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.hidden_size
                    head_size = visual_embed_dim // visual_num_heads
                    # [q|k|v] blocks -> per-head [q, k, v] groups.
                    loaded_weight = loaded_weight.view(
                        3, visual_num_heads, head_size, visual_embed_dim
                    )
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
                elif "visual" in name and "qkv.bias" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.hidden_size
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)

                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                try:
                    param = params_dict[name]
                except KeyError:
                    logger.error(
                        "Unexpected checkpoint weight %s; known parameters: %s",
                        name,
                        list(params_dict.keys()),
                    )
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
719
|
+
|
720
|
+
|
721
|
+
# Model classes exposed by this module (presumably discovered by sglang's
# model registry via the EntryClass name — confirm).
EntryClass = [Qwen2_5_VLForConditionalGeneration]
# Map the config class to this model for transformers' AutoModel factory.
# exist_ok=True: this config class shares its __name__ with transformers'
# built-in Qwen2_5_VLConfig (newer transformers releases). Without it,
# register() raises "already used by a Transformers model" at import time.
AutoModel.register(
    Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, exist_ok=True
)
|